deepdoctection 0.37.3__py3-none-any.whl → 0.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of deepdoctection might be problematic. Click here for more details.
- deepdoctection/__init__.py +5 -1
- deepdoctection/analyzer/_config.py +2 -1
- deepdoctection/analyzer/dd.py +6 -5
- deepdoctection/analyzer/factory.py +16 -6
- deepdoctection/configs/conf_dd_one.yaml +126 -85
- deepdoctection/datapoint/box.py +2 -4
- deepdoctection/datapoint/convert.py +14 -8
- deepdoctection/datapoint/image.py +12 -5
- deepdoctection/datapoint/view.py +151 -53
- deepdoctection/extern/hfdetr.py +4 -3
- deepdoctection/extern/model.py +6 -97
- deepdoctection/mapper/cats.py +21 -10
- deepdoctection/mapper/match.py +0 -22
- deepdoctection/mapper/misc.py +12 -2
- deepdoctection/mapper/pubstruct.py +1 -1
- deepdoctection/pipe/doctectionpipe.py +20 -3
- deepdoctection/pipe/lm.py +20 -5
- deepdoctection/pipe/refine.py +6 -13
- deepdoctection/pipe/segment.py +225 -46
- deepdoctection/pipe/sub_layout.py +40 -22
- deepdoctection/train/hf_layoutlm_train.py +3 -1
- deepdoctection/utils/pdf_utils.py +17 -9
- {deepdoctection-0.37.3.dist-info → deepdoctection-0.39.dist-info}/METADATA +15 -5
- {deepdoctection-0.37.3.dist-info → deepdoctection-0.39.dist-info}/RECORD +27 -27
- {deepdoctection-0.37.3.dist-info → deepdoctection-0.39.dist-info}/WHEEL +1 -1
- {deepdoctection-0.37.3.dist-info → deepdoctection-0.39.dist-info}/LICENSE +0 -0
- {deepdoctection-0.37.3.dist-info → deepdoctection-0.39.dist-info}/top_level.txt +0 -0
|
@@ -109,8 +109,13 @@ def _proto_process(
|
|
|
109
109
|
|
|
110
110
|
|
|
111
111
|
@curry
|
|
112
|
-
def _to_image(
|
|
113
|
-
|
|
112
|
+
def _to_image(
|
|
113
|
+
dp: Union[str, Mapping[str, Union[str, bytes]]],
|
|
114
|
+
dpi: Optional[int] = None,
|
|
115
|
+
width: Optional[int] = None,
|
|
116
|
+
height: Optional[int] = None,
|
|
117
|
+
) -> Optional[Image]:
|
|
118
|
+
return to_image(dp, dpi, width, height)
|
|
114
119
|
|
|
115
120
|
|
|
116
121
|
def _doc_to_dataflow(path: PathLikeOrStr, max_datapoints: Optional[int] = None) -> DataFlow:
|
|
@@ -188,7 +193,19 @@ class DoctectionPipe(Pipeline):
|
|
|
188
193
|
|
|
189
194
|
df = MapData(df, _proto_process(path, doc_path))
|
|
190
195
|
if dataset_dataflow is None:
|
|
191
|
-
|
|
196
|
+
if dpi := os.environ["DPI"]:
|
|
197
|
+
df = MapData(df, _to_image(dpi=int(dpi))) # pylint: disable=E1120
|
|
198
|
+
else:
|
|
199
|
+
width, height = kwargs.get("width", ""), kwargs.get("height", "")
|
|
200
|
+
if not width or not height:
|
|
201
|
+
width = os.environ["IMAGE_WIDTH"]
|
|
202
|
+
height = os.environ["IMAGE_HEIGHT"]
|
|
203
|
+
if not width or not height:
|
|
204
|
+
raise ValueError(
|
|
205
|
+
"DPI, IMAGE_WIDTH and IMAGE_HEIGHT are all None, but "
|
|
206
|
+
"either DPI or IMAGE_WIDTH and IMAGE_HEIGHT must be set"
|
|
207
|
+
)
|
|
208
|
+
df = MapData(df, _to_image(width=int(width), height=int(height))) # pylint: disable=E1120
|
|
192
209
|
return df
|
|
193
210
|
|
|
194
211
|
@staticmethod
|
deepdoctection/pipe/lm.py
CHANGED
|
@@ -24,6 +24,7 @@ from copy import copy
|
|
|
24
24
|
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Sequence, Union
|
|
25
25
|
|
|
26
26
|
from ..datapoint.image import Image
|
|
27
|
+
from ..extern.base import SequenceClassResult
|
|
27
28
|
from ..mapper.laylmstruct import image_to_layoutlm_features, image_to_lm_features
|
|
28
29
|
from ..utils.settings import BioTag, LayoutType, ObjectTypes, PageType, TokenClasses, WordType
|
|
29
30
|
from .base import MetaAnnotation, PipelineComponent
|
|
@@ -264,6 +265,7 @@ class LMSequenceClassifierService(PipelineComponent):
|
|
|
264
265
|
padding: Literal["max_length", "do_not_pad", "longest"] = "max_length",
|
|
265
266
|
truncation: bool = True,
|
|
266
267
|
return_overflowing_tokens: bool = False,
|
|
268
|
+
use_other_as_default_category: bool = False
|
|
267
269
|
) -> None:
|
|
268
270
|
"""
|
|
269
271
|
:param tokenizer: Tokenizer, typing allows currently anything. This will be changed in the future
|
|
@@ -279,11 +281,16 @@ class LMSequenceClassifierService(PipelineComponent):
|
|
|
279
281
|
:param return_overflowing_tokens: If a sequence (due to a truncation strategy) overflows the overflowing tokens
|
|
280
282
|
can be returned as an additional batch element. Not that in this case, the number of input
|
|
281
283
|
batch samples will be smaller than the output batch samples.
|
|
284
|
+
:param use_other_as_default_category: When predicting document classes, it might be possible that some pages
|
|
285
|
+
do not get sent to the model because they are empty. If set to `True` it
|
|
286
|
+
will assign images with no features the category `TokenClasses.OTHER`.
|
|
287
|
+
|
|
282
288
|
"""
|
|
283
289
|
self.language_model = language_model
|
|
284
290
|
self.padding = padding
|
|
285
291
|
self.truncation = truncation
|
|
286
292
|
self.return_overflowing_tokens = return_overflowing_tokens
|
|
293
|
+
self.use_other_as_default_category = use_other_as_default_category
|
|
287
294
|
self.tokenizer = tokenizer
|
|
288
295
|
self.mapping_to_lm_input_func = self.image_to_features_func(self.language_model.image_to_features_mapping())
|
|
289
296
|
super().__init__(self._get_name(), self.language_model.model_id)
|
|
@@ -299,12 +306,20 @@ class LMSequenceClassifierService(PipelineComponent):
|
|
|
299
306
|
|
|
300
307
|
def serve(self, dp: Image) -> None:
|
|
301
308
|
lm_input = self.mapping_to_lm_input_func(**self.required_kwargs)(dp)
|
|
309
|
+
lm_output = None
|
|
302
310
|
if lm_input is None:
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
311
|
+
if self.use_other_as_default_category:
|
|
312
|
+
class_id = self.language_model.categories.get_categories(as_dict=True,
|
|
313
|
+
name_as_key=True).get(TokenClasses.OTHER, 1)
|
|
314
|
+
lm_output = SequenceClassResult(class_name=TokenClasses.OTHER,
|
|
315
|
+
class_id = class_id,
|
|
316
|
+
score=-1.)
|
|
317
|
+
else:
|
|
318
|
+
lm_output = self.language_model.predict(**lm_input)
|
|
319
|
+
if lm_output:
|
|
320
|
+
self.dp_manager.set_summary_annotation(
|
|
321
|
+
PageType.DOCUMENT_TYPE, lm_output.class_name, lm_output.class_id, None, lm_output.score
|
|
322
|
+
)
|
|
308
323
|
|
|
309
324
|
def clone(self) -> LMSequenceClassifierService:
|
|
310
325
|
return self.__class__(
|
deepdoctection/pipe/refine.py
CHANGED
|
@@ -295,28 +295,21 @@ def _html_table(
|
|
|
295
295
|
return html
|
|
296
296
|
|
|
297
297
|
|
|
298
|
-
def generate_html_string(table: ImageAnnotation) -> list[str]:
|
|
298
|
+
def generate_html_string(table: ImageAnnotation, cell_names: Sequence[ObjectTypes]) -> list[str]:
|
|
299
299
|
"""
|
|
300
300
|
Takes the table segmentation by using table cells row number, column numbers etc. and generates a html
|
|
301
301
|
representation.
|
|
302
302
|
|
|
303
303
|
:param table: An annotation that has a not None image and fully segmented cell annotation.
|
|
304
|
+
:param cell_names: List of cell names that are used for the table segmentation. Note: It must be ensured that
|
|
305
|
+
that all cells have a row number, column number, row span and column span and that the dissection
|
|
306
|
+
by rows and columns is completely covered by cells.
|
|
304
307
|
:return: HTML representation of the table
|
|
305
308
|
"""
|
|
306
309
|
if table.image is None:
|
|
307
310
|
raise ImageError("table.image cannot be None")
|
|
308
311
|
table_image = table.image
|
|
309
|
-
cells = table_image.get_annotation(
|
|
310
|
-
category_names=[
|
|
311
|
-
LayoutType.CELL,
|
|
312
|
-
CellType.HEADER,
|
|
313
|
-
CellType.BODY,
|
|
314
|
-
CellType.SPANNING,
|
|
315
|
-
CellType.ROW_HEADER,
|
|
316
|
-
CellType.COLUMN_HEADER,
|
|
317
|
-
CellType.PROJECTED_ROW_HEADER,
|
|
318
|
-
]
|
|
319
|
-
)
|
|
312
|
+
cells = table_image.get_annotation(category_names=cell_names)
|
|
320
313
|
number_of_rows = table_image.summary.get_sub_category(TableType.NUMBER_OF_ROWS).category_id
|
|
321
314
|
number_of_cols = table_image.summary.get_sub_category(TableType.NUMBER_OF_COLUMNS).category_id
|
|
322
315
|
table_list = []
|
|
@@ -485,7 +478,7 @@ class TableSegmentationRefinementService(PipelineComponent):
|
|
|
485
478
|
self.dp_manager.set_summary_annotation(
|
|
486
479
|
TableType.MAX_COL_SPAN, TableType.MAX_COL_SPAN, max_col_span, annotation_id=table.annotation_id
|
|
487
480
|
)
|
|
488
|
-
html = generate_html_string(table)
|
|
481
|
+
html = generate_html_string(table, self.cell_names)
|
|
489
482
|
self.dp_manager.set_container_annotation(TableType.HTML, -1, TableType.HTML, table.annotation_id, html)
|
|
490
483
|
|
|
491
484
|
def clone(self) -> TableSegmentationRefinementService:
|
deepdoctection/pipe/segment.py
CHANGED
|
@@ -28,13 +28,13 @@ from typing import Literal, Optional, Sequence, Union
|
|
|
28
28
|
import numpy as np
|
|
29
29
|
|
|
30
30
|
from ..datapoint.annotation import ImageAnnotation
|
|
31
|
-
from ..datapoint.box import BoundingBox, global_to_local_coords, intersection_boxes, iou
|
|
31
|
+
from ..datapoint.box import BoundingBox, global_to_local_coords, intersection_box, intersection_boxes, iou, merge_boxes
|
|
32
32
|
from ..datapoint.image import Image
|
|
33
33
|
from ..extern.base import DetectionResult
|
|
34
34
|
from ..mapper.maputils import MappingContextManager
|
|
35
35
|
from ..mapper.match import match_anns_by_intersection
|
|
36
36
|
from ..utils.error import ImageError
|
|
37
|
-
from ..utils.settings import CellType, LayoutType, ObjectTypes, Relationships, TableType
|
|
37
|
+
from ..utils.settings import CellType, LayoutType, ObjectTypes, Relationships, TableType, TypeOrStr, get_type
|
|
38
38
|
from .base import MetaAnnotation, PipelineComponent
|
|
39
39
|
from .refine import generate_html_string
|
|
40
40
|
from .registry import pipeline_component_registry
|
|
@@ -55,6 +55,15 @@ class SegmentationResult:
|
|
|
55
55
|
cs: int
|
|
56
56
|
|
|
57
57
|
|
|
58
|
+
@dataclass
|
|
59
|
+
class ItemHeaderResult:
|
|
60
|
+
"""
|
|
61
|
+
Simple mutable storage for item header results
|
|
62
|
+
"""
|
|
63
|
+
|
|
64
|
+
annotation_id: str
|
|
65
|
+
|
|
66
|
+
|
|
58
67
|
def choose_items_by_iou(
|
|
59
68
|
dp: Image,
|
|
60
69
|
item_proposals: list[ImageAnnotation],
|
|
@@ -314,7 +323,7 @@ def _tile_by_stretching_rows_leftwise_column_downwise(
|
|
|
314
323
|
|
|
315
324
|
|
|
316
325
|
def tile_tables_with_items_per_table(
|
|
317
|
-
dp: Image, table: ImageAnnotation, item_name:
|
|
326
|
+
dp: Image, table: ImageAnnotation, item_name: ObjectTypes, stretch_rule: Literal["left", "equal"] = "left"
|
|
318
327
|
) -> Image:
|
|
319
328
|
"""
|
|
320
329
|
Tiling a table with items (i.e. rows or columns). To ensure that every position in a table can be assigned to a row
|
|
@@ -355,9 +364,9 @@ def tile_tables_with_items_per_table(
|
|
|
355
364
|
|
|
356
365
|
def stretch_items(
|
|
357
366
|
dp: Image,
|
|
358
|
-
table_name:
|
|
359
|
-
row_name:
|
|
360
|
-
col_name:
|
|
367
|
+
table_name: ObjectTypes,
|
|
368
|
+
row_name: ObjectTypes,
|
|
369
|
+
col_name: ObjectTypes,
|
|
361
370
|
remove_iou_threshold_rows: float,
|
|
362
371
|
remove_iou_threshold_cols: float,
|
|
363
372
|
) -> Image:
|
|
@@ -491,7 +500,7 @@ def create_intersection_cells(
|
|
|
491
500
|
cols: Sequence[ImageAnnotation],
|
|
492
501
|
table_annotation_id: str,
|
|
493
502
|
cell_class_id: int,
|
|
494
|
-
sub_item_names: Sequence[
|
|
503
|
+
sub_item_names: Sequence[ObjectTypes],
|
|
495
504
|
) -> tuple[Sequence[DetectionResult], Sequence[SegmentationResult]]:
|
|
496
505
|
"""
|
|
497
506
|
Given rows and columns with row- and column number sub categories, create a list of `DetectionResult` and
|
|
@@ -511,6 +520,7 @@ def create_intersection_cells(
|
|
|
511
520
|
detect_result_cells = []
|
|
512
521
|
segment_result_cells = []
|
|
513
522
|
idx = 0
|
|
523
|
+
break_outer_loop = False
|
|
514
524
|
for row in rows:
|
|
515
525
|
for col in cols:
|
|
516
526
|
detect_result_cells.append(
|
|
@@ -531,17 +541,59 @@ def create_intersection_cells(
|
|
|
531
541
|
)
|
|
532
542
|
)
|
|
533
543
|
idx += 1
|
|
534
|
-
# it is possible to have less intersection boxes, e.g. if one cell has height/width 0
|
|
544
|
+
# it is possible to have less intersection boxes, e.g. if one cell has height/width 0. We need to break both
|
|
545
|
+
# loops.
|
|
535
546
|
if idx >= len(boxes_cells):
|
|
547
|
+
break_outer_loop = True
|
|
536
548
|
break
|
|
549
|
+
if break_outer_loop:
|
|
550
|
+
break
|
|
537
551
|
return detect_result_cells, segment_result_cells
|
|
538
552
|
|
|
539
553
|
|
|
554
|
+
def header_cell_to_item_detect_result(
|
|
555
|
+
dp: Image,
|
|
556
|
+
table: ImageAnnotation,
|
|
557
|
+
item_name: ObjectTypes,
|
|
558
|
+
item_header_name: ObjectTypes,
|
|
559
|
+
segment_rule: Literal["iou", "ioa"],
|
|
560
|
+
threshold: float,
|
|
561
|
+
) -> list[ItemHeaderResult]:
|
|
562
|
+
"""
|
|
563
|
+
Match header cells to items (rows or columns) based on intersection-over-union (iou) or intersection-over-area (ioa)
|
|
564
|
+
and return a list of ItemHeaderResult.
|
|
565
|
+
|
|
566
|
+
:param dp: The image containing the table and items.
|
|
567
|
+
:param table: The table image annotation.
|
|
568
|
+
:param item_name: The type of items (e.g., rows or columns) to match with header cells.
|
|
569
|
+
:param item_header_name: The type of header cells to match with items.
|
|
570
|
+
:param segment_rule: The rule to use for matching, either 'iou' or 'ioa'.
|
|
571
|
+
:param threshold: The iou/ioa threshold for matching header cells with items.
|
|
572
|
+
:return: A list of ItemHeaderResult containing the matched header cells.
|
|
573
|
+
"""
|
|
574
|
+
child_ann_ids = table.get_relationship(Relationships.CHILD)
|
|
575
|
+
item_index, _, items, _ = match_anns_by_intersection(
|
|
576
|
+
dp,
|
|
577
|
+
item_header_name,
|
|
578
|
+
item_name,
|
|
579
|
+
segment_rule,
|
|
580
|
+
threshold,
|
|
581
|
+
True,
|
|
582
|
+
child_ann_ids,
|
|
583
|
+
child_ann_ids,
|
|
584
|
+
)
|
|
585
|
+
item_headers = []
|
|
586
|
+
for idx, item in enumerate(items):
|
|
587
|
+
if idx in item_index:
|
|
588
|
+
item_headers.append(ItemHeaderResult(annotation_id=item.annotation_id))
|
|
589
|
+
return item_headers
|
|
590
|
+
|
|
591
|
+
|
|
540
592
|
def segment_pubtables(
|
|
541
593
|
dp: Image,
|
|
542
594
|
table: ImageAnnotation,
|
|
543
|
-
item_names: Sequence[
|
|
544
|
-
spanning_cell_names: Sequence[
|
|
595
|
+
item_names: Sequence[ObjectTypes],
|
|
596
|
+
spanning_cell_names: Sequence[ObjectTypes],
|
|
545
597
|
segment_rule: Literal["iou", "ioa"],
|
|
546
598
|
threshold_rows: float,
|
|
547
599
|
threshold_cols: float,
|
|
@@ -553,7 +605,7 @@ def segment_pubtables(
|
|
|
553
605
|
|
|
554
606
|
Row and column positions as well as row and column lengths are determined for all types of spanning cells.
|
|
555
607
|
All simple cells that are covered by a spanning cell as well in the table position (double allocation) are then
|
|
556
|
-
|
|
608
|
+
replaced by the spanning cell and deactivated.
|
|
557
609
|
|
|
558
610
|
:param dp: Image
|
|
559
611
|
:param table: table ImageAnnotation
|
|
@@ -566,6 +618,7 @@ def segment_pubtables(
|
|
|
566
618
|
to the column.
|
|
567
619
|
:return: A list of len(number of cells) of SegmentationResult for spanning cells
|
|
568
620
|
"""
|
|
621
|
+
|
|
569
622
|
child_ann_ids = table.get_relationship(Relationships.CHILD)
|
|
570
623
|
cell_index_rows, row_index, _, _ = match_anns_by_intersection(
|
|
571
624
|
dp,
|
|
@@ -600,29 +653,77 @@ def segment_pubtables(
|
|
|
600
653
|
for idx, cell in enumerate(spanning_cells):
|
|
601
654
|
cell_positions_rows = cell_index_rows == idx
|
|
602
655
|
rows_of_cell = [rows[k] for k in row_index[cell_positions_rows]]
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
row_number =
|
|
656
|
+
if rows_of_cell:
|
|
657
|
+
min_row_cell = min(rows_of_cell, key=lambda row: row.get_sub_category(CellType.ROW_NUMBER).category_id)
|
|
658
|
+
max_row_cell = max(rows_of_cell, key=lambda row: row.get_sub_category(CellType.ROW_NUMBER).category_id)
|
|
659
|
+
max_row = max_row_cell.get_sub_category(CellType.ROW_NUMBER).category_id
|
|
660
|
+
min_row = min_row_cell.get_sub_category(CellType.ROW_NUMBER).category_id
|
|
661
|
+
rs = max_row - min_row + 1
|
|
662
|
+
row_number = min_row
|
|
610
663
|
else:
|
|
664
|
+
rs = 0
|
|
611
665
|
row_number = 0
|
|
612
666
|
|
|
613
667
|
cell_positions_cols = cell_index_cols == idx
|
|
614
668
|
cols_of_cell = [columns[k] for k in col_index[cell_positions_cols]]
|
|
615
|
-
cs = (
|
|
616
|
-
max(col.get_sub_category(CellType.COLUMN_NUMBER).category_id for col in cols_of_cell)
|
|
617
|
-
- min(col.get_sub_category(CellType.COLUMN_NUMBER).category_id for col in cols_of_cell)
|
|
618
|
-
+ 1
|
|
619
|
-
)
|
|
620
669
|
|
|
621
|
-
if
|
|
622
|
-
|
|
670
|
+
if cols_of_cell:
|
|
671
|
+
min_col_cell = min(
|
|
672
|
+
cols_of_cell, key=lambda col: col.get_sub_category(CellType.COLUMN_NUMBER).category_id
|
|
673
|
+
)
|
|
674
|
+
max_col_cell = max(
|
|
675
|
+
cols_of_cell, key=lambda col: col.get_sub_category(CellType.COLUMN_NUMBER).category_id
|
|
676
|
+
)
|
|
677
|
+
max_col = max_col_cell.get_sub_category(CellType.COLUMN_NUMBER).category_id
|
|
678
|
+
min_col = min_col_cell.get_sub_category(CellType.COLUMN_NUMBER).category_id
|
|
679
|
+
cs = max_col - min_col + 1
|
|
680
|
+
col_number = min_col
|
|
623
681
|
else:
|
|
682
|
+
cs = 0
|
|
624
683
|
col_number = 0
|
|
625
684
|
|
|
685
|
+
if rows_of_cell and cols_of_cell:
|
|
686
|
+
# We resize all bounding boxes of spanning cells so that they match with the grid structure, determined
|
|
687
|
+
# by the rows ans columns.
|
|
688
|
+
merge_box_image_row = merge_boxes(
|
|
689
|
+
*[min_row_cell.get_bounding_box(dp.image_id), max_row_cell.get_bounding_box(dp.image_id)]
|
|
690
|
+
)
|
|
691
|
+
merge_box_image_column = merge_boxes(
|
|
692
|
+
*[min_col_cell.get_bounding_box(dp.image_id), max_col_cell.get_bounding_box(dp.image_id)]
|
|
693
|
+
)
|
|
694
|
+
merge_box_image = intersection_box(merge_box_image_row, merge_box_image_column)
|
|
695
|
+
merge_box_table_row = merge_boxes(
|
|
696
|
+
*[
|
|
697
|
+
min_row_cell.get_bounding_box(table.annotation_id),
|
|
698
|
+
max_row_cell.get_bounding_box(table.annotation_id),
|
|
699
|
+
]
|
|
700
|
+
)
|
|
701
|
+
merge_box_table_column = merge_boxes(
|
|
702
|
+
*[
|
|
703
|
+
min_col_cell.get_bounding_box(table.annotation_id),
|
|
704
|
+
max_col_cell.get_bounding_box(table.annotation_id),
|
|
705
|
+
]
|
|
706
|
+
)
|
|
707
|
+
merge_box_table = intersection_box(merge_box_table_row, merge_box_table_column)
|
|
708
|
+
merge_box_spanning_cell_row = merge_boxes(
|
|
709
|
+
*[
|
|
710
|
+
min_row_cell.get_bounding_box(min_row_cell.annotation_id),
|
|
711
|
+
max_row_cell.get_bounding_box(max_row_cell.annotation_id),
|
|
712
|
+
]
|
|
713
|
+
)
|
|
714
|
+
merge_box_spanning_cell_column = merge_boxes(
|
|
715
|
+
*[
|
|
716
|
+
min_col_cell.get_bounding_box(min_col_cell.annotation_id),
|
|
717
|
+
max_col_cell.get_bounding_box(max_col_cell.annotation_id),
|
|
718
|
+
]
|
|
719
|
+
)
|
|
720
|
+
merge_box_spanning_cell = intersection_box(merge_box_spanning_cell_row, merge_box_spanning_cell_column)
|
|
721
|
+
if cell.image is None:
|
|
722
|
+
raise ImageError("cell.image cannot be None")
|
|
723
|
+
cell.image.set_embedding(dp.image_id, merge_box_image)
|
|
724
|
+
cell.image.set_embedding(table.annotation_id, merge_box_table)
|
|
725
|
+
cell.image.set_embedding(cell.annotation_id, merge_box_spanning_cell)
|
|
726
|
+
|
|
626
727
|
raw_table_segments.append(
|
|
627
728
|
SegmentationResult(
|
|
628
729
|
annotation_id=cell.annotation_id,
|
|
@@ -674,10 +775,10 @@ class TableSegmentationService(PipelineComponent):
|
|
|
674
775
|
tile_table_with_items: bool,
|
|
675
776
|
remove_iou_threshold_rows: float,
|
|
676
777
|
remove_iou_threshold_cols: float,
|
|
677
|
-
table_name:
|
|
678
|
-
cell_names: Sequence[
|
|
679
|
-
item_names: Sequence[
|
|
680
|
-
sub_item_names: Sequence[
|
|
778
|
+
table_name: TypeOrStr,
|
|
779
|
+
cell_names: Sequence[TypeOrStr],
|
|
780
|
+
item_names: Sequence[TypeOrStr],
|
|
781
|
+
sub_item_names: Sequence[TypeOrStr],
|
|
681
782
|
stretch_rule: Literal["left", "equal"] = "left",
|
|
682
783
|
):
|
|
683
784
|
"""
|
|
@@ -705,10 +806,10 @@ class TableSegmentationService(PipelineComponent):
|
|
|
705
806
|
self.tile_table = tile_table_with_items
|
|
706
807
|
self.remove_iou_threshold_rows = remove_iou_threshold_rows
|
|
707
808
|
self.remove_iou_threshold_cols = remove_iou_threshold_cols
|
|
708
|
-
self.table_name = table_name
|
|
709
|
-
self.cell_names = cell_names
|
|
710
|
-
self.item_names = item_names # row names must be before column name
|
|
711
|
-
self.sub_item_names = sub_item_names
|
|
809
|
+
self.table_name = get_type(table_name)
|
|
810
|
+
self.cell_names = [get_type(cell_name) for cell_name in cell_names]
|
|
811
|
+
self.item_names = [get_type(item_name) for item_name in item_names] # row names must be before column name
|
|
812
|
+
self.sub_item_names = [get_type(sub_item_name) for sub_item_name in sub_item_names]
|
|
712
813
|
self.stretch_rule = stretch_rule
|
|
713
814
|
self.item_iou_threshold = 0.0001
|
|
714
815
|
super().__init__("table_segment")
|
|
@@ -876,11 +977,13 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
876
977
|
remove_iou_threshold_rows: float,
|
|
877
978
|
remove_iou_threshold_cols: float,
|
|
878
979
|
cell_class_id: int,
|
|
879
|
-
table_name:
|
|
880
|
-
cell_names: Sequence[
|
|
881
|
-
spanning_cell_names: Sequence[
|
|
882
|
-
item_names: Sequence[
|
|
883
|
-
sub_item_names: Sequence[
|
|
980
|
+
table_name: TypeOrStr,
|
|
981
|
+
cell_names: Sequence[TypeOrStr],
|
|
982
|
+
spanning_cell_names: Sequence[TypeOrStr],
|
|
983
|
+
item_names: Sequence[TypeOrStr],
|
|
984
|
+
sub_item_names: Sequence[TypeOrStr],
|
|
985
|
+
item_header_cell_names: Sequence[TypeOrStr],
|
|
986
|
+
item_header_thresholds: Sequence[float],
|
|
884
987
|
cell_to_image: bool = True,
|
|
885
988
|
crop_cell_image: bool = False,
|
|
886
989
|
stretch_rule: Literal["left", "equal"] = "left",
|
|
@@ -900,6 +1003,11 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
900
1003
|
:param spanning_cell_names: layout type of spanning cells
|
|
901
1004
|
:param item_names: layout type of items (e.g. row and column)
|
|
902
1005
|
:param sub_item_names: layout type of sub items (e.g. row number and column number)
|
|
1006
|
+
:param item_header_cell_names: layout type of item header cells (e.g. CellType.COLUMN_HEADER,
|
|
1007
|
+
CellType.ROW_HEADER). Note that column header, resp. row header will be first assigned to rows, resp. columns
|
|
1008
|
+
and then transferred to cells.
|
|
1009
|
+
:param item_header_thresholds: iou/ioa threshold for matching header cells with items. The first threshold
|
|
1010
|
+
corresponds to matching the first entry of item_names.
|
|
903
1011
|
:param cell_to_image: If set to 'True' it will create an 'Image' for LayoutType.cell
|
|
904
1012
|
:param crop_cell_image: If set to 'True' it will crop a numpy array image for LayoutType.cell.
|
|
905
1013
|
Requires 'cell_to_image=True'
|
|
@@ -909,17 +1017,20 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
909
1017
|
self.threshold_rows = threshold_rows
|
|
910
1018
|
self.threshold_cols = threshold_cols
|
|
911
1019
|
self.tile_table = tile_table_with_items
|
|
912
|
-
self.table_name = table_name
|
|
913
|
-
self.cell_names = cell_names
|
|
914
|
-
self.spanning_cell_names = spanning_cell_names
|
|
1020
|
+
self.table_name = get_type(table_name)
|
|
1021
|
+
self.cell_names = [get_type(cell_name) for cell_name in cell_names]
|
|
1022
|
+
self.spanning_cell_names = [get_type(cell_name) for cell_name in spanning_cell_names]
|
|
915
1023
|
self.remove_iou_threshold_rows = remove_iou_threshold_rows
|
|
916
1024
|
self.remove_iou_threshold_cols = remove_iou_threshold_cols
|
|
917
1025
|
self.cell_class_id = cell_class_id
|
|
918
1026
|
self.cell_to_image = cell_to_image
|
|
919
1027
|
self.crop_cell_image = crop_cell_image
|
|
920
|
-
self.item_names = item_names # row names must be before column name
|
|
921
|
-
self.sub_item_names = sub_item_names
|
|
1028
|
+
self.item_names = [get_type(item_name) for item_name in item_names] # row names must be before column name
|
|
1029
|
+
self.sub_item_names = [get_type(item_name) for item_name in sub_item_names]
|
|
922
1030
|
self.stretch_rule = stretch_rule
|
|
1031
|
+
self.item_header_cell_names = [get_type(item_name) for item_name in item_header_cell_names]
|
|
1032
|
+
self.item_header_thresholds = item_header_thresholds
|
|
1033
|
+
|
|
923
1034
|
super().__init__("table_transformer_segment")
|
|
924
1035
|
|
|
925
1036
|
def serve(self, dp: Image) -> None:
|
|
@@ -932,10 +1043,18 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
932
1043
|
self.remove_iou_threshold_cols,
|
|
933
1044
|
)
|
|
934
1045
|
table_anns = dp.get_annotation(category_names=self.table_name)
|
|
1046
|
+
has_item_headers = {item: False for item in self.item_names}
|
|
935
1047
|
for table in table_anns:
|
|
936
1048
|
item_ann_ids = table.get_relationship(Relationships.CHILD)
|
|
937
|
-
for item_sub_item_name in zip(
|
|
938
|
-
|
|
1049
|
+
for item_sub_item_name in zip(
|
|
1050
|
+
self.item_names, self.sub_item_names, self.item_header_cell_names, self.item_header_thresholds
|
|
1051
|
+
): # one pass for rows and one for cols
|
|
1052
|
+
item_name, sub_item_name, item_header_cell_name, item_header_threshold = (
|
|
1053
|
+
item_sub_item_name[0],
|
|
1054
|
+
item_sub_item_name[1],
|
|
1055
|
+
item_sub_item_name[2],
|
|
1056
|
+
item_sub_item_name[3],
|
|
1057
|
+
)
|
|
939
1058
|
if self.tile_table:
|
|
940
1059
|
dp = tile_tables_with_items_per_table(dp, table, item_name, self.stretch_rule)
|
|
941
1060
|
items = dp.get_annotation(category_names=item_name, annotation_ids=item_ann_ids)
|
|
@@ -949,10 +1068,24 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
949
1068
|
)
|
|
950
1069
|
)
|
|
951
1070
|
|
|
1071
|
+
item_headers_detect_results = header_cell_to_item_detect_result(
|
|
1072
|
+
dp, table, item_name, item_header_cell_name, self.segment_rule, item_header_threshold
|
|
1073
|
+
)
|
|
1074
|
+
if item_headers_detect_results:
|
|
1075
|
+
has_item_headers[item_name] = True
|
|
1076
|
+
|
|
952
1077
|
for item_number, item in enumerate(items, 1):
|
|
953
1078
|
self.dp_manager.set_category_annotation(
|
|
954
1079
|
sub_item_name, item_number, sub_item_name, item.annotation_id
|
|
955
1080
|
)
|
|
1081
|
+
for item_header_detect_result in item_headers_detect_results:
|
|
1082
|
+
self.dp_manager.set_category_annotation(
|
|
1083
|
+
category_name=item_header_cell_name,
|
|
1084
|
+
category_id=None,
|
|
1085
|
+
sub_cat_key=item_header_cell_name,
|
|
1086
|
+
annotation_id=item_header_detect_result.annotation_id,
|
|
1087
|
+
)
|
|
1088
|
+
|
|
956
1089
|
rows = dp.get_annotation(category_names=self.item_names[0], annotation_ids=item_ann_ids)
|
|
957
1090
|
columns = dp.get_annotation(category_names=self.item_names[1], annotation_ids=item_ann_ids)
|
|
958
1091
|
detect_result_cells, segment_result_cells = create_intersection_cells(
|
|
@@ -979,6 +1112,7 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
979
1112
|
CellType.COLUMN_SPAN, segment_result.cs, CellType.COLUMN_SPAN, segment_result.annotation_id
|
|
980
1113
|
)
|
|
981
1114
|
cell_rn_cn_to_ann_id[(segment_result.row_num, segment_result.col_num)] = segment_result.annotation_id
|
|
1115
|
+
|
|
982
1116
|
spanning_cell_raw_segments = segment_pubtables(
|
|
983
1117
|
dp,
|
|
984
1118
|
table,
|
|
@@ -988,7 +1122,15 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
988
1122
|
self.threshold_rows,
|
|
989
1123
|
self.threshold_cols,
|
|
990
1124
|
)
|
|
1125
|
+
|
|
991
1126
|
for segment_result in spanning_cell_raw_segments:
|
|
1127
|
+
if (
|
|
1128
|
+
(segment_result.rs == 1 and segment_result.cs == 1)
|
|
1129
|
+
or segment_result.rs == 0
|
|
1130
|
+
or segment_result.cs == 0
|
|
1131
|
+
):
|
|
1132
|
+
self.dp_manager.deactivate_annotation(segment_result.annotation_id)
|
|
1133
|
+
continue
|
|
992
1134
|
self.dp_manager.set_category_annotation(
|
|
993
1135
|
CellType.ROW_NUMBER, segment_result.row_num, CellType.ROW_NUMBER, segment_result.annotation_id
|
|
994
1136
|
)
|
|
@@ -1009,6 +1151,19 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
1009
1151
|
cell_ann_id = cell_rn_cn_to_ann_id[cell_position]
|
|
1010
1152
|
self.dp_manager.deactivate_annotation(cell_ann_id)
|
|
1011
1153
|
|
|
1154
|
+
for segment_result in spanning_cell_raw_segments:
|
|
1155
|
+
if (
|
|
1156
|
+
(segment_result.rs == 1 and segment_result.cs == 1)
|
|
1157
|
+
or segment_result.rs == 0
|
|
1158
|
+
or segment_result.cs == 0
|
|
1159
|
+
):
|
|
1160
|
+
continue
|
|
1161
|
+
for rs in range(segment_result.rs):
|
|
1162
|
+
for cs in range(segment_result.cs):
|
|
1163
|
+
cell_rn_cn_to_ann_id[
|
|
1164
|
+
(segment_result.row_num + rs, segment_result.col_num + cs)
|
|
1165
|
+
] = segment_result.annotation_id
|
|
1166
|
+
|
|
1012
1167
|
cells = []
|
|
1013
1168
|
if table.image:
|
|
1014
1169
|
cells = table.image.get_annotation(category_names=self.cell_names)
|
|
@@ -1022,6 +1177,28 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
1022
1177
|
number_of_cols = 0
|
|
1023
1178
|
max_row_span = 0
|
|
1024
1179
|
max_col_span = 0
|
|
1180
|
+
|
|
1181
|
+
for idx, item_vals in enumerate(zip(self.item_names, self.item_header_cell_names, self.sub_item_names)):
|
|
1182
|
+
item_obj_type, item_header_cell_name, sub_item_name = item_vals[0], item_vals[1], item_vals[2]
|
|
1183
|
+
|
|
1184
|
+
if has_item_headers[item_obj_type]:
|
|
1185
|
+
items = dp.get_annotation(category_names=item_obj_type)
|
|
1186
|
+
|
|
1187
|
+
for item_ann in items:
|
|
1188
|
+
if item_header_cell_name in item_ann.sub_categories:
|
|
1189
|
+
item_number = item_ann.get_sub_category(sub_item_name).category_id
|
|
1190
|
+
for key, value in cell_rn_cn_to_ann_id.items():
|
|
1191
|
+
if key[idx] == item_number:
|
|
1192
|
+
cell_ann = dp.get_annotation(annotation_ids=value)[0]
|
|
1193
|
+
self.dp_manager.set_category_annotation(
|
|
1194
|
+
item_header_cell_name, None, item_header_cell_name, cell_ann.annotation_id
|
|
1195
|
+
)
|
|
1196
|
+
else:
|
|
1197
|
+
cell_ann = dp.get_annotation(annotation_ids=value)[0]
|
|
1198
|
+
self.dp_manager.set_category_annotation(
|
|
1199
|
+
item_header_cell_name, None, CellType.BODY, cell_ann.annotation_id
|
|
1200
|
+
)
|
|
1201
|
+
|
|
1025
1202
|
# TODO: the summaries should be sub categories of the underlying ann
|
|
1026
1203
|
self.dp_manager.set_summary_annotation(
|
|
1027
1204
|
TableType.NUMBER_OF_ROWS, TableType.NUMBER_OF_ROWS, number_of_rows, annotation_id=table.annotation_id
|
|
@@ -1038,7 +1215,7 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
1038
1215
|
self.dp_manager.set_summary_annotation(
|
|
1039
1216
|
TableType.MAX_COL_SPAN, TableType.MAX_COL_SPAN, max_col_span, annotation_id=table.annotation_id
|
|
1040
1217
|
)
|
|
1041
|
-
html = generate_html_string(table)
|
|
1218
|
+
html = generate_html_string(table, self.cell_names + self.spanning_cell_names)
|
|
1042
1219
|
self.dp_manager.set_container_annotation(TableType.HTML, -1, TableType.HTML, table.annotation_id, html)
|
|
1043
1220
|
|
|
1044
1221
|
def clone(self) -> PubtablesSegmentationService:
|
|
@@ -1055,6 +1232,8 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
1055
1232
|
self.spanning_cell_names,
|
|
1056
1233
|
self.item_names,
|
|
1057
1234
|
self.sub_item_names,
|
|
1235
|
+
self.item_header_cell_names,
|
|
1236
|
+
self.item_header_thresholds,
|
|
1058
1237
|
self.cell_to_image,
|
|
1059
1238
|
self.crop_cell_image,
|
|
1060
1239
|
self.stretch_rule,
|