deepdoctection 0.32__py3-none-any.whl → 0.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. Click here for more details.

Files changed (111) hide show
  1. deepdoctection/__init__.py +8 -25
  2. deepdoctection/analyzer/dd.py +84 -71
  3. deepdoctection/dataflow/common.py +9 -5
  4. deepdoctection/dataflow/custom.py +5 -5
  5. deepdoctection/dataflow/custom_serialize.py +75 -18
  6. deepdoctection/dataflow/parallel_map.py +3 -3
  7. deepdoctection/dataflow/serialize.py +4 -4
  8. deepdoctection/dataflow/stats.py +3 -3
  9. deepdoctection/datapoint/annotation.py +78 -56
  10. deepdoctection/datapoint/box.py +7 -7
  11. deepdoctection/datapoint/convert.py +6 -6
  12. deepdoctection/datapoint/image.py +157 -75
  13. deepdoctection/datapoint/view.py +175 -151
  14. deepdoctection/datasets/adapter.py +30 -24
  15. deepdoctection/datasets/base.py +10 -10
  16. deepdoctection/datasets/dataflow_builder.py +3 -3
  17. deepdoctection/datasets/info.py +23 -25
  18. deepdoctection/datasets/instances/doclaynet.py +48 -49
  19. deepdoctection/datasets/instances/fintabnet.py +44 -45
  20. deepdoctection/datasets/instances/funsd.py +23 -23
  21. deepdoctection/datasets/instances/iiitar13k.py +8 -8
  22. deepdoctection/datasets/instances/layouttest.py +2 -2
  23. deepdoctection/datasets/instances/publaynet.py +3 -3
  24. deepdoctection/datasets/instances/pubtables1m.py +18 -18
  25. deepdoctection/datasets/instances/pubtabnet.py +30 -29
  26. deepdoctection/datasets/instances/rvlcdip.py +28 -29
  27. deepdoctection/datasets/instances/xfund.py +51 -30
  28. deepdoctection/datasets/save.py +6 -6
  29. deepdoctection/eval/accmetric.py +32 -33
  30. deepdoctection/eval/base.py +8 -9
  31. deepdoctection/eval/cocometric.py +13 -12
  32. deepdoctection/eval/eval.py +32 -26
  33. deepdoctection/eval/tedsmetric.py +16 -12
  34. deepdoctection/eval/tp_eval_callback.py +7 -16
  35. deepdoctection/extern/base.py +339 -134
  36. deepdoctection/extern/d2detect.py +69 -89
  37. deepdoctection/extern/deskew.py +11 -10
  38. deepdoctection/extern/doctrocr.py +81 -64
  39. deepdoctection/extern/fastlang.py +23 -16
  40. deepdoctection/extern/hfdetr.py +53 -38
  41. deepdoctection/extern/hflayoutlm.py +216 -155
  42. deepdoctection/extern/hflm.py +35 -30
  43. deepdoctection/extern/model.py +433 -255
  44. deepdoctection/extern/pdftext.py +15 -15
  45. deepdoctection/extern/pt/ptutils.py +4 -2
  46. deepdoctection/extern/tessocr.py +39 -38
  47. deepdoctection/extern/texocr.py +14 -16
  48. deepdoctection/extern/tp/tfutils.py +16 -2
  49. deepdoctection/extern/tp/tpcompat.py +11 -7
  50. deepdoctection/extern/tp/tpfrcnn/config/config.py +4 -4
  51. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +1 -1
  52. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +5 -5
  53. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +6 -6
  54. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +4 -4
  55. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +5 -3
  56. deepdoctection/extern/tp/tpfrcnn/preproc.py +5 -5
  57. deepdoctection/extern/tpdetect.py +40 -45
  58. deepdoctection/mapper/cats.py +36 -40
  59. deepdoctection/mapper/cocostruct.py +16 -12
  60. deepdoctection/mapper/d2struct.py +22 -22
  61. deepdoctection/mapper/hfstruct.py +7 -7
  62. deepdoctection/mapper/laylmstruct.py +22 -24
  63. deepdoctection/mapper/maputils.py +9 -10
  64. deepdoctection/mapper/match.py +33 -2
  65. deepdoctection/mapper/misc.py +6 -7
  66. deepdoctection/mapper/pascalstruct.py +4 -4
  67. deepdoctection/mapper/prodigystruct.py +6 -6
  68. deepdoctection/mapper/pubstruct.py +84 -92
  69. deepdoctection/mapper/tpstruct.py +3 -3
  70. deepdoctection/mapper/xfundstruct.py +33 -33
  71. deepdoctection/pipe/anngen.py +39 -14
  72. deepdoctection/pipe/base.py +68 -99
  73. deepdoctection/pipe/common.py +181 -85
  74. deepdoctection/pipe/concurrency.py +14 -10
  75. deepdoctection/pipe/doctectionpipe.py +24 -21
  76. deepdoctection/pipe/language.py +20 -25
  77. deepdoctection/pipe/layout.py +18 -16
  78. deepdoctection/pipe/lm.py +49 -47
  79. deepdoctection/pipe/order.py +63 -65
  80. deepdoctection/pipe/refine.py +102 -109
  81. deepdoctection/pipe/segment.py +157 -162
  82. deepdoctection/pipe/sub_layout.py +50 -40
  83. deepdoctection/pipe/text.py +37 -36
  84. deepdoctection/pipe/transform.py +19 -16
  85. deepdoctection/train/d2_frcnn_train.py +27 -25
  86. deepdoctection/train/hf_detr_train.py +22 -18
  87. deepdoctection/train/hf_layoutlm_train.py +49 -48
  88. deepdoctection/train/tp_frcnn_train.py +10 -11
  89. deepdoctection/utils/concurrency.py +1 -1
  90. deepdoctection/utils/context.py +13 -6
  91. deepdoctection/utils/develop.py +4 -4
  92. deepdoctection/utils/env_info.py +52 -14
  93. deepdoctection/utils/file_utils.py +6 -11
  94. deepdoctection/utils/fs.py +41 -14
  95. deepdoctection/utils/identifier.py +2 -2
  96. deepdoctection/utils/logger.py +15 -15
  97. deepdoctection/utils/metacfg.py +7 -7
  98. deepdoctection/utils/pdf_utils.py +39 -14
  99. deepdoctection/utils/settings.py +188 -182
  100. deepdoctection/utils/tqdm.py +1 -1
  101. deepdoctection/utils/transform.py +14 -9
  102. deepdoctection/utils/types.py +104 -0
  103. deepdoctection/utils/utils.py +7 -7
  104. deepdoctection/utils/viz.py +70 -69
  105. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/METADATA +7 -4
  106. deepdoctection-0.34.dist-info/RECORD +146 -0
  107. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/WHEEL +1 -1
  108. deepdoctection/utils/detection_types.py +0 -68
  109. deepdoctection-0.32.dist-info/RECORD +0 -146
  110. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/LICENSE +0 -0
  111. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/top_level.txt +0 -0
@@ -19,11 +19,13 @@
19
19
  Module for refining methods of table segmentation. The refining methods lead ultimately to a table structure which
20
20
  enables html table representations
21
21
  """
22
+ from __future__ import annotations
23
+
22
24
  from collections import defaultdict
23
25
  from copy import copy
24
26
  from dataclasses import asdict
25
27
  from itertools import chain, product
26
- from typing import DefaultDict, List, Optional, Sequence, Set, Tuple, Union
28
+ from typing import DefaultDict, Optional, Sequence, Union
27
29
 
28
30
  import networkx as nx # type: ignore
29
31
 
@@ -32,16 +34,15 @@ from ..datapoint.box import merge_boxes
32
34
  from ..datapoint.image import Image
33
35
  from ..extern.base import DetectionResult
34
36
  from ..mapper.maputils import MappingContextManager
35
- from ..utils.detection_types import JsonDict
36
- from ..utils.error import AnnotationError, ImageError
37
+ from ..utils.error import ImageError
37
38
  from ..utils.settings import CellType, LayoutType, ObjectTypes, Relationships, TableType, get_type
38
- from .base import PipelineComponent
39
+ from .base import MetaAnnotation, PipelineComponent
39
40
  from .registry import pipeline_component_registry
40
41
 
41
42
  __all__ = ["TableSegmentationRefinementService", "generate_html_string"]
42
43
 
43
44
 
44
- def tiles_to_cells(dp: Image, table: ImageAnnotation) -> List[Tuple[Tuple[int, int], str]]:
45
+ def tiles_to_cells(dp: Image, table: ImageAnnotation) -> list[tuple[tuple[int, int], str]]:
45
46
  """
46
47
  Creation of a table parquet: A table is divided into a tile parquet with the (number of rows) x
47
48
  (the number of columns) tiles.
@@ -53,17 +54,17 @@ def tiles_to_cells(dp: Image, table: ImageAnnotation) -> List[Tuple[Tuple[int, i
53
54
  :return: Image
54
55
  """
55
56
 
56
- cell_ann_ids = table.get_relationship(Relationships.child)
57
+ cell_ann_ids = table.get_relationship(Relationships.CHILD)
57
58
  cells = dp.get_annotation(
58
- category_names=[LayoutType.cell, CellType.header, CellType.body], annotation_ids=cell_ann_ids
59
+ category_names=[LayoutType.CELL, CellType.HEADER, CellType.BODY], annotation_ids=cell_ann_ids
59
60
  )
60
61
  tile_to_cells = []
61
62
 
62
63
  for cell in cells:
63
- row_number = int(cell.get_sub_category(CellType.row_number).category_id)
64
- col_number = int(cell.get_sub_category(CellType.column_number).category_id)
65
- rs = int(cell.get_sub_category(CellType.row_span).category_id)
66
- cs = int(cell.get_sub_category(CellType.column_span).category_id)
64
+ row_number = cell.get_sub_category(CellType.ROW_NUMBER).category_id
65
+ col_number = cell.get_sub_category(CellType.COLUMN_NUMBER).category_id
66
+ rs = cell.get_sub_category(CellType.ROW_SPAN).category_id
67
+ cs = cell.get_sub_category(CellType.COLUMN_SPAN).category_id
67
68
  for k in range(rs):
68
69
  for l in range(cs):
69
70
  assert cell.annotation_id is not None, cell.annotation_id
@@ -73,15 +74,15 @@ def tiles_to_cells(dp: Image, table: ImageAnnotation) -> List[Tuple[Tuple[int, i
73
74
 
74
75
 
75
76
  def connected_component_tiles(
76
- tile_to_cell_list: List[Tuple[Tuple[int, int], str]]
77
- ) -> Tuple[List[Set[Tuple[int, int]]], DefaultDict[Tuple[int, int], List[str]]]:
77
+ tile_to_cell_list: list[tuple[tuple[int, int], str]]
78
+ ) -> tuple[list[set[tuple[int, int]]], DefaultDict[tuple[int, int], list[str]]]:
78
79
  """
79
80
  The assignment of bricks to their cell occupancy induces a graph, with bricks as corners and cell edges. Cells that
80
81
  lie on top of several bricks connect the underlying bricks. The graph generated according to this procedure is
81
82
  usually multiple connected. The related components and the tile/cell ids assignment are determined.
82
83
 
83
- :param tile_to_cell_list: List of tuples with tile position and cell ids
84
- :return: List of set with tiles that belong to the same connected component and a dict with tiles as keys and
84
+ :param tile_to_cell_list: list of tuples with tile position and cell ids
85
+ :return: list of set with tiles that belong to the same connected component and a dict with tiles as keys and
85
86
  assigned list of cell ids as values.
86
87
  """
87
88
  cell_to_tile_list = [(cell_position[1], cell_position[0]) for cell_position in tile_to_cell_list]
@@ -107,7 +108,7 @@ def connected_component_tiles(
107
108
  connected_components_tiles = []
108
109
 
109
110
  for component in connected_components_cell:
110
- tiles: Set[Tuple[int, int]] = set()
111
+ tiles: set[tuple[int, int]] = set()
111
112
  for cell in component:
112
113
  tiles = tiles.union(set(cell_to_tile_dict[cell])) # type: ignore
113
114
  connected_components_tiles.append(tiles)
@@ -115,7 +116,7 @@ def connected_component_tiles(
115
116
  return connected_components_tiles, tile_to_cell_dict
116
117
 
117
118
 
118
- def _missing_tile(inputs: Set[Tuple[int, int]]) -> Optional[Tuple[int, int]]:
119
+ def _missing_tile(inputs: set[tuple[int, int]]) -> Optional[tuple[int, int]]:
119
120
  min_x, min_y, max_x, max_y = (
120
121
  min(a[0] for a in inputs),
121
122
  min(a[1] for a in inputs),
@@ -131,15 +132,15 @@ def _missing_tile(inputs: Set[Tuple[int, int]]) -> Optional[Tuple[int, int]]:
131
132
 
132
133
 
133
134
  def _find_component(
134
- tile: Tuple[int, int], reduced_connected_tiles: List[Set[Tuple[int, int]]]
135
- ) -> Optional[Set[Tuple[int, int]]]:
135
+ tile: tuple[int, int], reduced_connected_tiles: list[set[tuple[int, int]]]
136
+ ) -> Optional[set[tuple[int, int]]]:
136
137
  for comp in reduced_connected_tiles:
137
138
  if tile in comp:
138
139
  return comp
139
140
  return None
140
141
 
141
142
 
142
- def _merge_components(reduced_connected_tiles: List[Set[Tuple[int, int]]]) -> List[Set[Tuple[int, int]]]:
143
+ def _merge_components(reduced_connected_tiles: list[set[tuple[int, int]]]) -> list[set[tuple[int, int]]]:
143
144
  new_reduced_connected_tiles = []
144
145
  for connected_tile in reduced_connected_tiles:
145
146
  out = _missing_tile(connected_tile)
@@ -161,17 +162,17 @@ def _merge_components(reduced_connected_tiles: List[Set[Tuple[int, int]]]) -> Li
161
162
  return new_reduced_connected_tiles
162
163
 
163
164
 
164
- def generate_rectangle_tiling(connected_components_tiles: List[Set[Tuple[int, int]]]) -> List[Set[Tuple[int, int]]]:
165
+ def generate_rectangle_tiling(connected_components_tiles: list[set[tuple[int, int]]]) -> list[set[tuple[int, int]]]:
165
166
  """
166
167
  The determined connected components imply that all cells have to be combined which are above a connected component.
167
168
  In addition, however, it must also be taken into account that cells must be rectangular. This means that related
168
169
  components have to be combined whose combined cells above do not create a rectangular tiling. All tiles are combined
169
170
  in such a way that all cells above them combine to form a rectangular scheme.
170
171
 
171
- :param connected_components_tiles: List of set with tiles that belong to the same connected component
172
- :return: List of sets with tiles, the cells on top of which together form a rectangular scheme
172
+ :param connected_components_tiles: list of set with tiles that belong to the same connected component
173
+ :return: list of sets with tiles, the cells on top of which together form a rectangular scheme
173
174
  """
174
- rectangle_tiling: List[Set[Tuple[int, int]]] = []
175
+ rectangle_tiling: list[set[tuple[int, int]]] = []
175
176
  inputs = connected_components_tiles
176
177
 
177
178
  while rectangle_tiling != inputs:
@@ -183,25 +184,25 @@ def generate_rectangle_tiling(connected_components_tiles: List[Set[Tuple[int, in
183
184
 
184
185
 
185
186
  def rectangle_cells(
186
- rectangle_tiling: List[Set[Tuple[int, int]]], tile_to_cell_dict: DefaultDict[Tuple[int, int], List[str]]
187
- ) -> List[Set[str]]:
187
+ rectangle_tiling: list[set[tuple[int, int]]], tile_to_cell_dict: DefaultDict[tuple[int, int], list[str]]
188
+ ) -> list[set[str]]:
188
189
  """
189
190
  All cells are determined that are located above combined connected components and form a rectangular scheme.
190
191
 
191
- :param rectangle_tiling: List of sets with tiles, the cells on top of which together form a rectangular scheme
192
+ :param rectangle_tiling: list of sets with tiles, the cells on top of which together form a rectangular scheme
192
193
  :param tile_to_cell_dict: Dict with tiles as keys and assigned list of cell ids as values.
193
- :return: List of set of cell ids that form a rectangular scheme
194
+ :return: list of set of cell ids that form a rectangular scheme
194
195
  """
195
- rectangle_tiling_cells: List[Set[str]] = []
196
+ rectangle_tiling_cells: list[set[str]] = []
196
197
  for rect_tiling_component in rectangle_tiling:
197
- rect_cell_component: Set[str] = set()
198
+ rect_cell_component: set[str] = set()
198
199
  for el in rect_tiling_component:
199
200
  rect_cell_component = rect_cell_component.union(set(tile_to_cell_dict[el]))
200
201
  rectangle_tiling_cells.append(rect_cell_component)
201
202
  return rectangle_tiling_cells
202
203
 
203
204
 
204
- def _tiling_to_cell_position(inputs: Set[Tuple[int, int]]) -> Tuple[int, int, int, int]:
205
+ def _tiling_to_cell_position(inputs: set[tuple[int, int]]) -> tuple[int, int, int, int]:
205
206
  row_number = min(a[0] for a in inputs)
206
207
  col_number = min(a[1] for a in inputs)
207
208
  row_span = max(abs(a[0] - b[0]) + 1 for a in inputs for b in inputs)
@@ -210,8 +211,8 @@ def _tiling_to_cell_position(inputs: Set[Tuple[int, int]]) -> Tuple[int, int, in
210
211
 
211
212
 
212
213
  def _html_cell(
213
- cell_position: Union[Tuple[int, int, int, int], Tuple[()]], position_filled_list: List[Tuple[int, int]]
214
- ) -> List[str]:
214
+ cell_position: Union[tuple[int, int, int, int], tuple[()]], position_filled_list: list[tuple[int, int]]
215
+ ) -> list[str]:
215
216
  """
216
217
  Html table cell string generation
217
218
  """
@@ -238,12 +239,12 @@ def _html_cell(
238
239
 
239
240
 
240
241
  def _html_row(
241
- row_list: List[Tuple[int, int, int, int]],
242
- position_filled_list: List[Tuple[int, int]],
242
+ row_list: list[tuple[int, int, int, int]],
243
+ position_filled_list: list[tuple[int, int]],
243
244
  this_row: int,
244
245
  number_of_cols: int,
245
- row_ann_id_list: List[str],
246
- ) -> List[str]:
246
+ row_ann_id_list: list[str],
247
+ ) -> list[str]:
247
248
  """
248
249
  Html table row string generation
249
250
  """
@@ -275,16 +276,16 @@ def _html_row(
275
276
 
276
277
 
277
278
  def _html_table(
278
- table_list: List[Tuple[int, List[Tuple[int, int, int, int]]]],
279
- cells_ann_list: List[Tuple[int, List[str]]],
279
+ table_list: list[tuple[int, list[tuple[int, int, int, int]]]],
280
+ cells_ann_list: list[tuple[int, list[str]]],
280
281
  number_of_rows: int,
281
282
  number_of_cols: int,
282
- ) -> List[str]:
283
+ ) -> list[str]:
283
284
  """
284
285
  Html table string generation
285
286
  """
286
287
  html = ["<table>"]
287
- position_filled: List[Tuple[int, int]] = []
288
+ position_filled: list[tuple[int, int]] = []
288
289
  for idx in range(1, number_of_rows + 1):
289
290
  row_idx = list(filter(lambda x: x[0] == idx, table_list))[0][1] # pylint:disable=W0640
290
291
  row_ann_ids = list(filter(lambda x: x[0] == idx, cells_ann_list))[0][1] # pylint:disable=W0640
@@ -294,7 +295,7 @@ def _html_table(
294
295
  return html
295
296
 
296
297
 
297
- def generate_html_string(table: ImageAnnotation) -> List[str]:
298
+ def generate_html_string(table: ImageAnnotation) -> list[str]:
298
299
  """
299
300
  Takes the table segmentation by using table cells row number, column numbers etc. and generates a html
300
301
  representation.
@@ -307,36 +308,36 @@ def generate_html_string(table: ImageAnnotation) -> List[str]:
307
308
  table_image = table.image
308
309
  cells = table_image.get_annotation(
309
310
  category_names=[
310
- LayoutType.cell,
311
- CellType.header,
312
- CellType.body,
313
- CellType.spanning,
314
- CellType.row_header,
315
- CellType.column_header,
316
- CellType.projected_row_header,
311
+ LayoutType.CELL,
312
+ CellType.HEADER,
313
+ CellType.BODY,
314
+ CellType.SPANNING,
315
+ CellType.ROW_HEADER,
316
+ CellType.COLUMN_HEADER,
317
+ CellType.PROJECTED_ROW_HEADER,
317
318
  ]
318
319
  )
319
- number_of_rows = int(table_image.summary.get_sub_category(TableType.number_of_rows).category_id)
320
- number_of_cols = int(table_image.summary.get_sub_category(TableType.number_of_columns).category_id)
320
+ number_of_rows = table_image.summary.get_sub_category(TableType.NUMBER_OF_ROWS).category_id
321
+ number_of_cols = table_image.summary.get_sub_category(TableType.NUMBER_OF_COLUMNS).category_id
321
322
  table_list = []
322
323
  cells_ann_list = []
323
324
  for row_number in range(1, number_of_rows + 1):
324
325
  cells_of_row = list(
325
326
  sorted(
326
327
  filter(
327
- lambda cell: cell.get_sub_category(CellType.row_number).category_id
328
- == str(row_number), # pylint: disable=W0640
328
+ lambda cell: cell.get_sub_category(CellType.ROW_NUMBER).category_id
329
+ == row_number, # pylint: disable=W0640
329
330
  cells,
330
331
  ),
331
- key=lambda cell: cell.get_sub_category(CellType.column_number).category_id,
332
+ key=lambda cell: cell.get_sub_category(CellType.COLUMN_NUMBER).category_id,
332
333
  )
333
334
  )
334
335
  row_list = [
335
336
  (
336
- int(cell.get_sub_category(CellType.row_number).category_id),
337
- int(cell.get_sub_category(CellType.column_number).category_id),
338
- int(cell.get_sub_category(CellType.row_span).category_id),
339
- int(cell.get_sub_category(CellType.column_span).category_id),
337
+ cell.get_sub_category(CellType.ROW_NUMBER).category_id,
338
+ cell.get_sub_category(CellType.COLUMN_NUMBER).category_id,
339
+ cell.get_sub_category(CellType.ROW_SPAN).category_id,
340
+ cell.get_sub_category(CellType.COLUMN_SPAN).category_id,
340
341
  )
341
342
  for cell in cells_of_row
342
343
  ]
@@ -421,23 +422,23 @@ class TableSegmentationRefinementService(PipelineComponent):
421
422
  det_result = DetectionResult(
422
423
  box=merged_box.to_list(mode="xyxy"),
423
424
  score=-1.0,
424
- class_id=int(cells[0].category_id),
425
+ class_id=cells[0].category_id,
425
426
  class_name=get_type(cells[0].category_name),
426
427
  )
427
428
  new_cell_ann_id = self.dp_manager.set_image_annotation(det_result, table.annotation_id)
428
429
  if new_cell_ann_id is not None:
429
430
  row_number, col_number, row_span, col_span = _tiling_to_cell_position(tiling)
430
431
  self.dp_manager.set_category_annotation(
431
- CellType.row_number, row_number, CellType.row_number, new_cell_ann_id
432
+ CellType.ROW_NUMBER, row_number, CellType.ROW_NUMBER, new_cell_ann_id
432
433
  )
433
434
  self.dp_manager.set_category_annotation(
434
- CellType.column_number, col_number, CellType.column_number, new_cell_ann_id
435
+ CellType.COLUMN_NUMBER, col_number, CellType.COLUMN_NUMBER, new_cell_ann_id
435
436
  )
436
437
  self.dp_manager.set_category_annotation(
437
- CellType.row_span, row_span, CellType.row_span, new_cell_ann_id
438
+ CellType.ROW_SPAN, row_span, CellType.ROW_SPAN, new_cell_ann_id
438
439
  )
439
440
  self.dp_manager.set_category_annotation(
440
- CellType.column_span, col_span, CellType.column_span, new_cell_ann_id
441
+ CellType.COLUMN_SPAN, col_span, CellType.COLUMN_SPAN, new_cell_ann_id
441
442
  )
442
443
  else:
443
444
  # DetectionResult cannot be dumped, hence merged_box must already exist. Hence, it must
@@ -453,66 +454,58 @@ class TableSegmentationRefinementService(PipelineComponent):
453
454
  cell.deactivate()
454
455
 
455
456
  cells = table.image.get_annotation(category_names=self.cell_names)
456
- number_of_rows = max(int(cell.get_sub_category(CellType.row_number).category_id) for cell in cells)
457
- number_of_cols = max(int(cell.get_sub_category(CellType.column_number).category_id) for cell in cells)
458
- max_row_span = max(int(cell.get_sub_category(CellType.row_span).category_id) for cell in cells)
459
- max_col_span = max(int(cell.get_sub_category(CellType.column_span).category_id) for cell in cells)
457
+ number_of_rows = max(cell.get_sub_category(CellType.ROW_NUMBER).category_id for cell in cells)
458
+ number_of_cols = max(cell.get_sub_category(CellType.COLUMN_NUMBER).category_id for cell in cells)
459
+ max_row_span = max(cell.get_sub_category(CellType.ROW_SPAN).category_id for cell in cells)
460
+ max_col_span = max(cell.get_sub_category(CellType.COLUMN_SPAN).category_id for cell in cells)
460
461
  # TODO: the summaries should be sub categories of the underlying ann
461
- if table.image.summary is not None:
462
- if (
463
- TableType.number_of_rows in table.image.summary.sub_categories
464
- and TableType.number_of_columns in table.image.summary.sub_categories
465
- and TableType.max_row_span in table.image.summary.sub_categories
466
- and TableType.max_col_span in table.image.summary.sub_categories
467
- ):
468
- table.image.summary.remove_sub_category(TableType.number_of_rows)
469
- table.image.summary.remove_sub_category(TableType.number_of_columns)
470
- table.image.summary.remove_sub_category(TableType.max_row_span)
471
- table.image.summary.remove_sub_category(TableType.max_col_span)
472
- else:
473
- raise AnnotationError(
474
- "Table summary does not contain sub categories TableType.number_of_rows, "
475
- "TableType.number_of_columns, TableType.max_row_span, TableType.max_col_span"
476
- )
462
+ if (
463
+ TableType.NUMBER_OF_ROWS in table.image.summary.sub_categories
464
+ and TableType.NUMBER_OF_COLUMNS in table.image.summary.sub_categories
465
+ and TableType.MAX_ROW_SPAN in table.image.summary.sub_categories
466
+ and TableType.MAX_COL_SPAN in table.image.summary.sub_categories
467
+ ):
468
+ table.image.summary.remove_sub_category(TableType.NUMBER_OF_ROWS)
469
+ table.image.summary.remove_sub_category(TableType.NUMBER_OF_COLUMNS)
470
+ table.image.summary.remove_sub_category(TableType.MAX_ROW_SPAN)
471
+ table.image.summary.remove_sub_category(TableType.MAX_COL_SPAN)
477
472
 
478
473
  self.dp_manager.set_summary_annotation(
479
- TableType.number_of_rows, TableType.number_of_rows, number_of_rows, annotation_id=table.annotation_id
474
+ TableType.NUMBER_OF_ROWS, TableType.NUMBER_OF_ROWS, number_of_rows, annotation_id=table.annotation_id
480
475
  )
481
476
  self.dp_manager.set_summary_annotation(
482
- TableType.number_of_columns,
483
- TableType.number_of_columns,
477
+ TableType.NUMBER_OF_COLUMNS,
478
+ TableType.NUMBER_OF_COLUMNS,
484
479
  number_of_cols,
485
480
  annotation_id=table.annotation_id,
486
481
  )
487
482
  self.dp_manager.set_summary_annotation(
488
- TableType.max_row_span, TableType.max_row_span, max_row_span, annotation_id=table.annotation_id
483
+ TableType.MAX_ROW_SPAN, TableType.MAX_ROW_SPAN, max_row_span, annotation_id=table.annotation_id
489
484
  )
490
485
  self.dp_manager.set_summary_annotation(
491
- TableType.max_col_span, TableType.max_col_span, max_col_span, annotation_id=table.annotation_id
486
+ TableType.MAX_COL_SPAN, TableType.MAX_COL_SPAN, max_col_span, annotation_id=table.annotation_id
492
487
  )
493
488
  html = generate_html_string(table)
494
- self.dp_manager.set_container_annotation(TableType.html, -1, TableType.html, table.annotation_id, html)
489
+ self.dp_manager.set_container_annotation(TableType.HTML, -1, TableType.HTML, table.annotation_id, html)
495
490
 
496
- def clone(self) -> PipelineComponent:
491
+ def clone(self) -> TableSegmentationRefinementService:
497
492
  return self.__class__(self.table_name, self.cell_names)
498
493
 
499
- def get_meta_annotation(self) -> JsonDict:
500
- return dict(
501
- [
502
- ("image_annotations", []),
503
- (
504
- "sub_categories",
505
- {
506
- LayoutType.cell: {
507
- CellType.row_number,
508
- CellType.column_number,
509
- CellType.row_span,
510
- CellType.column_span,
511
- },
512
- LayoutType.table: {TableType.html},
513
- },
514
- ),
515
- ("relationships", {}),
516
- ("summaries", []),
517
- ]
494
+ def get_meta_annotation(self) -> MetaAnnotation:
495
+ return MetaAnnotation(
496
+ image_annotations=(),
497
+ sub_categories={
498
+ LayoutType.CELL: {
499
+ CellType.ROW_NUMBER,
500
+ CellType.COLUMN_NUMBER,
501
+ CellType.ROW_SPAN,
502
+ CellType.COLUMN_SPAN,
503
+ },
504
+ LayoutType.TABLE: {TableType.HTML},
505
+ },
506
+ relationships={},
507
+ summaries=(),
518
508
  )
509
+
510
+ def clear_predictor(self) -> None:
511
+ pass