deepdoctection 0.32__py3-none-any.whl → 0.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. Click here for more details.

Files changed (111)
  1. deepdoctection/__init__.py +8 -25
  2. deepdoctection/analyzer/dd.py +84 -71
  3. deepdoctection/dataflow/common.py +9 -5
  4. deepdoctection/dataflow/custom.py +5 -5
  5. deepdoctection/dataflow/custom_serialize.py +75 -18
  6. deepdoctection/dataflow/parallel_map.py +3 -3
  7. deepdoctection/dataflow/serialize.py +4 -4
  8. deepdoctection/dataflow/stats.py +3 -3
  9. deepdoctection/datapoint/annotation.py +78 -56
  10. deepdoctection/datapoint/box.py +7 -7
  11. deepdoctection/datapoint/convert.py +6 -6
  12. deepdoctection/datapoint/image.py +157 -75
  13. deepdoctection/datapoint/view.py +175 -151
  14. deepdoctection/datasets/adapter.py +30 -24
  15. deepdoctection/datasets/base.py +10 -10
  16. deepdoctection/datasets/dataflow_builder.py +3 -3
  17. deepdoctection/datasets/info.py +23 -25
  18. deepdoctection/datasets/instances/doclaynet.py +48 -49
  19. deepdoctection/datasets/instances/fintabnet.py +44 -45
  20. deepdoctection/datasets/instances/funsd.py +23 -23
  21. deepdoctection/datasets/instances/iiitar13k.py +8 -8
  22. deepdoctection/datasets/instances/layouttest.py +2 -2
  23. deepdoctection/datasets/instances/publaynet.py +3 -3
  24. deepdoctection/datasets/instances/pubtables1m.py +18 -18
  25. deepdoctection/datasets/instances/pubtabnet.py +30 -29
  26. deepdoctection/datasets/instances/rvlcdip.py +28 -29
  27. deepdoctection/datasets/instances/xfund.py +51 -30
  28. deepdoctection/datasets/save.py +6 -6
  29. deepdoctection/eval/accmetric.py +32 -33
  30. deepdoctection/eval/base.py +8 -9
  31. deepdoctection/eval/cocometric.py +13 -12
  32. deepdoctection/eval/eval.py +32 -26
  33. deepdoctection/eval/tedsmetric.py +16 -12
  34. deepdoctection/eval/tp_eval_callback.py +7 -16
  35. deepdoctection/extern/base.py +339 -134
  36. deepdoctection/extern/d2detect.py +69 -89
  37. deepdoctection/extern/deskew.py +11 -10
  38. deepdoctection/extern/doctrocr.py +81 -64
  39. deepdoctection/extern/fastlang.py +23 -16
  40. deepdoctection/extern/hfdetr.py +53 -38
  41. deepdoctection/extern/hflayoutlm.py +216 -155
  42. deepdoctection/extern/hflm.py +35 -30
  43. deepdoctection/extern/model.py +433 -255
  44. deepdoctection/extern/pdftext.py +15 -15
  45. deepdoctection/extern/pt/ptutils.py +4 -2
  46. deepdoctection/extern/tessocr.py +39 -38
  47. deepdoctection/extern/texocr.py +14 -16
  48. deepdoctection/extern/tp/tfutils.py +16 -2
  49. deepdoctection/extern/tp/tpcompat.py +11 -7
  50. deepdoctection/extern/tp/tpfrcnn/config/config.py +4 -4
  51. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +1 -1
  52. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +5 -5
  53. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +6 -6
  54. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +4 -4
  55. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +5 -3
  56. deepdoctection/extern/tp/tpfrcnn/preproc.py +5 -5
  57. deepdoctection/extern/tpdetect.py +40 -45
  58. deepdoctection/mapper/cats.py +36 -40
  59. deepdoctection/mapper/cocostruct.py +16 -12
  60. deepdoctection/mapper/d2struct.py +22 -22
  61. deepdoctection/mapper/hfstruct.py +7 -7
  62. deepdoctection/mapper/laylmstruct.py +22 -24
  63. deepdoctection/mapper/maputils.py +9 -10
  64. deepdoctection/mapper/match.py +33 -2
  65. deepdoctection/mapper/misc.py +6 -7
  66. deepdoctection/mapper/pascalstruct.py +4 -4
  67. deepdoctection/mapper/prodigystruct.py +6 -6
  68. deepdoctection/mapper/pubstruct.py +84 -92
  69. deepdoctection/mapper/tpstruct.py +3 -3
  70. deepdoctection/mapper/xfundstruct.py +33 -33
  71. deepdoctection/pipe/anngen.py +39 -14
  72. deepdoctection/pipe/base.py +68 -99
  73. deepdoctection/pipe/common.py +181 -85
  74. deepdoctection/pipe/concurrency.py +14 -10
  75. deepdoctection/pipe/doctectionpipe.py +24 -21
  76. deepdoctection/pipe/language.py +20 -25
  77. deepdoctection/pipe/layout.py +18 -16
  78. deepdoctection/pipe/lm.py +49 -47
  79. deepdoctection/pipe/order.py +63 -65
  80. deepdoctection/pipe/refine.py +102 -109
  81. deepdoctection/pipe/segment.py +157 -162
  82. deepdoctection/pipe/sub_layout.py +50 -40
  83. deepdoctection/pipe/text.py +37 -36
  84. deepdoctection/pipe/transform.py +19 -16
  85. deepdoctection/train/d2_frcnn_train.py +27 -25
  86. deepdoctection/train/hf_detr_train.py +22 -18
  87. deepdoctection/train/hf_layoutlm_train.py +49 -48
  88. deepdoctection/train/tp_frcnn_train.py +10 -11
  89. deepdoctection/utils/concurrency.py +1 -1
  90. deepdoctection/utils/context.py +13 -6
  91. deepdoctection/utils/develop.py +4 -4
  92. deepdoctection/utils/env_info.py +52 -14
  93. deepdoctection/utils/file_utils.py +6 -11
  94. deepdoctection/utils/fs.py +41 -14
  95. deepdoctection/utils/identifier.py +2 -2
  96. deepdoctection/utils/logger.py +15 -15
  97. deepdoctection/utils/metacfg.py +7 -7
  98. deepdoctection/utils/pdf_utils.py +39 -14
  99. deepdoctection/utils/settings.py +188 -182
  100. deepdoctection/utils/tqdm.py +1 -1
  101. deepdoctection/utils/transform.py +14 -9
  102. deepdoctection/utils/types.py +104 -0
  103. deepdoctection/utils/utils.py +7 -7
  104. deepdoctection/utils/viz.py +70 -69
  105. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/METADATA +7 -4
  106. deepdoctection-0.34.dist-info/RECORD +146 -0
  107. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/WHEEL +1 -1
  108. deepdoctection/utils/detection_types.py +0 -68
  109. deepdoctection-0.32.dist-info/RECORD +0 -146
  110. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/LICENSE +0 -0
  111. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/top_level.txt +0 -0
@@ -20,9 +20,10 @@ Module for pipeline component of table segmentation. Uses row/column detector an
20
20
  ious/ioas of rows and columns.
21
21
  """
22
22
 
23
+ from __future__ import annotations
23
24
 
24
25
  from dataclasses import dataclass
25
- from typing import List, Literal, Optional, Sequence, Tuple, Union
26
+ from typing import Literal, Optional, Sequence, Union
26
27
 
27
28
  import numpy as np
28
29
 
@@ -32,10 +33,9 @@ from ..datapoint.image import Image
32
33
  from ..extern.base import DetectionResult
33
34
  from ..mapper.maputils import MappingContextManager
34
35
  from ..mapper.match import match_anns_by_intersection
35
- from ..utils.detection_types import JsonDict
36
36
  from ..utils.error import ImageError
37
37
  from ..utils.settings import CellType, LayoutType, ObjectTypes, Relationships, TableType
38
- from .base import PipelineComponent
38
+ from .base import MetaAnnotation, PipelineComponent
39
39
  from .refine import generate_html_string
40
40
  from .registry import pipeline_component_registry
41
41
 
@@ -57,10 +57,10 @@ class SegmentationResult:
57
57
 
58
58
  def choose_items_by_iou(
59
59
  dp: Image,
60
- item_proposals: List[ImageAnnotation],
60
+ item_proposals: list[ImageAnnotation],
61
61
  iou_threshold: float,
62
62
  above_threshold: bool = True,
63
- reference_item_proposals: Optional[List[ImageAnnotation]] = None,
63
+ reference_item_proposals: Optional[list[ImageAnnotation]] = None,
64
64
  ) -> Image:
65
65
  """
66
66
  Deactivate image annotations that have ious with each other above some threshold. It will deactivate an annotation
@@ -133,7 +133,7 @@ def stretch_item_per_table(
133
133
  :return: Image
134
134
  """
135
135
 
136
- item_ann_ids = table.get_relationship(Relationships.child)
136
+ item_ann_ids = table.get_relationship(Relationships.CHILD)
137
137
 
138
138
  rows = dp.get_annotation(category_names=row_name, annotation_ids=item_ann_ids)
139
139
  if table.image is None:
@@ -192,13 +192,13 @@ def stretch_item_per_table(
192
192
 
193
193
 
194
194
  def _tile_by_stretching_rows_left_and_rightwise(
195
- dp: Image, items: List[ImageAnnotation], table: ImageAnnotation, item_name: str
195
+ dp: Image, items: list[ImageAnnotation], table: ImageAnnotation, item_name: str
196
196
  ) -> None:
197
197
  if table.image is None:
198
198
  raise ImageError("table.image cannot be None")
199
199
  table_embedding_box = table.get_bounding_box(dp.image_id)
200
200
 
201
- tmp_item_xy = table_embedding_box.uly + 1.0 if item_name == LayoutType.row else table_embedding_box.ulx + 1.0
201
+ tmp_item_xy = table_embedding_box.uly + 1.0 if item_name == LayoutType.ROW else table_embedding_box.ulx + 1.0
202
202
  tmp_item_table_xy = 1.0
203
203
  for idx, item in enumerate(items):
204
204
  with MappingContextManager(
@@ -213,19 +213,19 @@ def _tile_by_stretching_rows_left_and_rightwise(
213
213
  next_item_embedding_box = items[idx + 1].get_bounding_box(dp.image_id)
214
214
  tmp_next_item_xy = (
215
215
  (item_embedding_box.lry + next_item_embedding_box.uly) / 2
216
- if item_name == LayoutType.row
216
+ if item_name == LayoutType.ROW
217
217
  else (item_embedding_box.lrx + next_item_embedding_box.ulx) / 2
218
218
  )
219
219
  else:
220
220
  tmp_next_item_xy = (
221
- table_embedding_box.lry - 1.0 if item_name == LayoutType.row else table_embedding_box.lrx - 1.0
221
+ table_embedding_box.lry - 1.0 if item_name == LayoutType.ROW else table_embedding_box.lrx - 1.0
222
222
  )
223
223
 
224
224
  new_embedding_box = BoundingBox(
225
- ulx=item_embedding_box.ulx if item_name == LayoutType.row else tmp_item_xy,
226
- uly=tmp_item_xy if item_name == LayoutType.row else item_embedding_box.uly,
227
- lrx=item_embedding_box.lrx if item_name == LayoutType.row else tmp_next_item_xy,
228
- lry=tmp_next_item_xy if item_name == LayoutType.row else item_embedding_box.lry,
225
+ ulx=item_embedding_box.ulx if item_name == LayoutType.ROW else tmp_item_xy,
226
+ uly=tmp_item_xy if item_name == LayoutType.ROW else item_embedding_box.uly,
227
+ lrx=item_embedding_box.lrx if item_name == LayoutType.ROW else tmp_next_item_xy,
228
+ lry=tmp_next_item_xy if item_name == LayoutType.ROW else item_embedding_box.lry,
229
229
  absolute_coords=True,
230
230
  )
231
231
  item.image.set_embedding(dp.image_id, new_embedding_box)
@@ -236,19 +236,19 @@ def _tile_by_stretching_rows_left_and_rightwise(
236
236
  next_item_table_embedding_box = items[idx + 1].get_bounding_box(table.annotation_id)
237
237
  tmp_table_next_item_xy = (
238
238
  (item_table_embedding_box.lry + next_item_table_embedding_box.uly) / 2
239
- if item_name == LayoutType.row
239
+ if item_name == LayoutType.ROW
240
240
  else (item_table_embedding_box.lrx + next_item_table_embedding_box.ulx) / 2
241
241
  )
242
242
  else:
243
243
  tmp_table_next_item_xy = (
244
- table.image.height - 1.0 if item_name == LayoutType.row else table.image.width - 1.0
244
+ table.image.height - 1.0 if item_name == LayoutType.ROW else table.image.width - 1.0
245
245
  )
246
246
 
247
247
  new_table_embedding_box = BoundingBox(
248
- ulx=item_table_embedding_box.ulx if item_name == LayoutType.row else tmp_item_table_xy,
249
- uly=tmp_item_table_xy if item_name == LayoutType.row else item_table_embedding_box.uly,
250
- lrx=item_table_embedding_box.lrx if item_name == LayoutType.row else tmp_table_next_item_xy,
251
- lry=tmp_table_next_item_xy if item_name == LayoutType.row else item_table_embedding_box.lry,
248
+ ulx=item_table_embedding_box.ulx if item_name == LayoutType.ROW else tmp_item_table_xy,
249
+ uly=tmp_item_table_xy if item_name == LayoutType.ROW else item_table_embedding_box.uly,
250
+ lrx=item_table_embedding_box.lrx if item_name == LayoutType.ROW else tmp_table_next_item_xy,
251
+ lry=tmp_table_next_item_xy if item_name == LayoutType.ROW else item_table_embedding_box.lry,
252
252
  absolute_coords=True,
253
253
  )
254
254
  item.image.set_embedding(table.annotation_id, new_table_embedding_box)
@@ -256,13 +256,13 @@ def _tile_by_stretching_rows_left_and_rightwise(
256
256
 
257
257
 
258
258
  def _tile_by_stretching_rows_leftwise_column_downwise(
259
- dp: Image, items: List[ImageAnnotation], table: ImageAnnotation, item_name: str
259
+ dp: Image, items: list[ImageAnnotation], table: ImageAnnotation, item_name: str
260
260
  ) -> None:
261
261
  if table.image is None:
262
262
  raise ImageError("table.image cannot be None")
263
263
  table_embedding_box = table.get_bounding_box(dp.image_id)
264
264
 
265
- tmp_item_xy = table_embedding_box.uly + 1.0 if item_name == LayoutType.row else table_embedding_box.ulx + 1.0
265
+ tmp_item_xy = table_embedding_box.uly + 1.0 if item_name == LayoutType.ROW else table_embedding_box.ulx + 1.0
266
266
  tmp_item_table_xy = 1.0
267
267
  for item in items:
268
268
  with MappingContextManager(
@@ -274,16 +274,16 @@ def _tile_by_stretching_rows_leftwise_column_downwise(
274
274
  raise ImageError("item.image cannot be None")
275
275
  item_embedding_box = item.get_bounding_box(dp.image_id)
276
276
  new_embedding_box = BoundingBox(
277
- ulx=item_embedding_box.ulx if item_name == LayoutType.row else tmp_item_xy,
278
- uly=tmp_item_xy if item_name == LayoutType.row else item_embedding_box.uly,
277
+ ulx=item_embedding_box.ulx if item_name == LayoutType.ROW else tmp_item_xy,
278
+ uly=tmp_item_xy if item_name == LayoutType.ROW else item_embedding_box.uly,
279
279
  lrx=item_embedding_box.lrx,
280
280
  lry=item_embedding_box.lry,
281
281
  absolute_coords=True,
282
282
  )
283
283
  item_table_embedding_box = item.get_bounding_box(table.annotation_id)
284
284
  new_table_embedding_box = BoundingBox(
285
- ulx=item_table_embedding_box.ulx if item_name == LayoutType.row else tmp_item_table_xy,
286
- uly=tmp_item_table_xy if item_name == LayoutType.row else item_table_embedding_box.uly,
285
+ ulx=item_table_embedding_box.ulx if item_name == LayoutType.ROW else tmp_item_table_xy,
286
+ uly=tmp_item_table_xy if item_name == LayoutType.ROW else item_table_embedding_box.uly,
287
287
  lrx=item_table_embedding_box.lrx,
288
288
  lry=item_table_embedding_box.lry,
289
289
  absolute_coords=True,
@@ -291,23 +291,23 @@ def _tile_by_stretching_rows_leftwise_column_downwise(
291
291
 
292
292
  if item == items[-1]:
293
293
  new_embedding_box = BoundingBox(
294
- ulx=item_embedding_box.ulx if item_name == LayoutType.row else tmp_item_xy,
295
- uly=tmp_item_xy if item_name == LayoutType.row else item_embedding_box.uly,
296
- lrx=item_embedding_box.lrx if item_name == LayoutType.row else table_embedding_box.lrx - 1.0,
297
- lry=table_embedding_box.lry - 1.0 if item_name == LayoutType.row else item_embedding_box.lry,
294
+ ulx=item_embedding_box.ulx if item_name == LayoutType.ROW else tmp_item_xy,
295
+ uly=tmp_item_xy if item_name == LayoutType.ROW else item_embedding_box.uly,
296
+ lrx=item_embedding_box.lrx if item_name == LayoutType.ROW else table_embedding_box.lrx - 1.0,
297
+ lry=table_embedding_box.lry - 1.0 if item_name == LayoutType.ROW else item_embedding_box.lry,
298
298
  absolute_coords=True,
299
299
  )
300
300
  new_table_embedding_box = BoundingBox(
301
- ulx=item_table_embedding_box.ulx if item_name == LayoutType.row else tmp_item_table_xy,
302
- uly=tmp_item_table_xy if item_name == LayoutType.row else item_table_embedding_box.uly,
303
- lrx=item_table_embedding_box.lrx if item_name == LayoutType.row else table.image.width - 1.0,
304
- lry=table.image.height - 1.0 if item_name == LayoutType.row else item_table_embedding_box.lry,
301
+ ulx=item_table_embedding_box.ulx if item_name == LayoutType.ROW else tmp_item_table_xy,
302
+ uly=tmp_item_table_xy if item_name == LayoutType.ROW else item_table_embedding_box.uly,
303
+ lrx=item_table_embedding_box.lrx if item_name == LayoutType.ROW else table.image.width - 1.0,
304
+ lry=table.image.height - 1.0 if item_name == LayoutType.ROW else item_table_embedding_box.lry,
305
305
  absolute_coords=True,
306
306
  )
307
307
 
308
- tmp_item_xy = item_embedding_box.lry if item_name == LayoutType.row else item_embedding_box.lrx
308
+ tmp_item_xy = item_embedding_box.lry if item_name == LayoutType.ROW else item_embedding_box.lrx
309
309
  tmp_item_table_xy = (
310
- item_table_embedding_box.lry if item_name == LayoutType.row else item_table_embedding_box.lrx
310
+ item_table_embedding_box.lry if item_name == LayoutType.ROW else item_table_embedding_box.lrx
311
311
  )
312
312
  item.image.set_embedding(dp.image_id, new_embedding_box)
313
313
  item.image.set_embedding(table.annotation_id, new_table_embedding_box)
@@ -336,12 +336,12 @@ def tile_tables_with_items_per_table(
336
336
  :return: Image
337
337
  """
338
338
 
339
- item_ann_ids = table.get_relationship(Relationships.child)
339
+ item_ann_ids = table.get_relationship(Relationships.CHILD)
340
340
  items = dp.get_annotation(category_names=item_name, annotation_ids=item_ann_ids)
341
341
 
342
342
  items.sort(
343
343
  key=lambda x: (
344
- x.get_bounding_box(dp.image_id).cx if item_name == LayoutType.column else x.get_bounding_box(dp.image_id).cy
344
+ x.get_bounding_box(dp.image_id).cx if item_name == LayoutType.COLUMN else x.get_bounding_box(dp.image_id).cy
345
345
  )
346
346
  )
347
347
 
@@ -372,7 +372,7 @@ def stretch_items(
372
372
  :param remove_iou_threshold_cols: iou threshold for removing overlapping columns
373
373
  :return: An Image
374
374
  """
375
- table_anns = dp.get_annotation_iter(category_names=table_name)
375
+ table_anns = dp.get_annotation(category_names=table_name)
376
376
 
377
377
  for table in table_anns:
378
378
  dp = stretch_item_per_table(dp, table, row_name, col_name, remove_iou_threshold_rows, remove_iou_threshold_cols)
@@ -380,7 +380,7 @@ def stretch_items(
380
380
  return dp
381
381
 
382
382
 
383
- def _default_segment_table(cells: List[ImageAnnotation]) -> List[SegmentationResult]:
383
+ def _default_segment_table(cells: list[ImageAnnotation]) -> list[SegmentationResult]:
384
384
  """
385
385
  Error segmentation handling when segmentation goes wrong. It will generate a default segmentation, e.g. no real
386
386
  segmentation.
@@ -404,7 +404,7 @@ def segment_table(
404
404
  segment_rule: Literal["iou", "ioa"],
405
405
  threshold_rows: float,
406
406
  threshold_cols: float,
407
- ) -> List[SegmentationResult]:
407
+ ) -> list[SegmentationResult]:
408
408
  """
409
409
  Segments a table,i.e. produces for each cell a SegmentationResult. It uses numbered rows and columns that have to
410
410
  be predicted by an appropriate detector. E.g. for calculating row and rwo spans it first infers the iou of a cell
@@ -424,7 +424,7 @@ def segment_table(
424
424
  :return: A list of len(number of cells) of SegmentationResult.
425
425
  """
426
426
 
427
- child_ann_ids = table.get_relationship(Relationships.child)
427
+ child_ann_ids = table.get_relationship(Relationships.CHILD)
428
428
  cell_index_rows, row_index, _, _ = match_anns_by_intersection(
429
429
  dp,
430
430
  item_names[0],
@@ -459,7 +459,7 @@ def segment_table(
459
459
  rows_of_cell = [rows[k] for k in row_index[cell_positions_rows]]
460
460
  rs = np.count_nonzero(cell_index_rows == idx)
461
461
  if len(rows_of_cell):
462
- row_number = min(int(row.get_sub_category(CellType.row_number).category_id) for row in rows_of_cell)
462
+ row_number = min(row.get_sub_category(CellType.ROW_NUMBER).category_id for row in rows_of_cell)
463
463
  else:
464
464
  row_number = 0
465
465
 
@@ -467,7 +467,7 @@ def segment_table(
467
467
  cols_of_cell = [columns[k] for k in col_index[cell_positions_cols]]
468
468
  cs = np.count_nonzero(cell_index_cols == idx)
469
469
  if len(cols_of_cell):
470
- col_number = min(int(col.get_sub_category(CellType.column_number).category_id) for col in cols_of_cell)
470
+ col_number = min(col.get_sub_category(CellType.COLUMN_NUMBER).category_id for col in cols_of_cell)
471
471
  else:
472
472
  col_number = 0
473
473
 
@@ -492,7 +492,7 @@ def create_intersection_cells(
492
492
  table_annotation_id: str,
493
493
  cell_class_id: int,
494
494
  sub_item_names: Sequence[CellType],
495
- ) -> Tuple[Sequence[DetectionResult], Sequence[SegmentationResult]]:
495
+ ) -> tuple[Sequence[DetectionResult], Sequence[SegmentationResult]]:
496
496
  """
497
497
  Given rows and columns with row- and column number sub categories, create a list of `DetectionResult` and
498
498
  `SegmentationResult` as intersection of all their intersection rectangles.
@@ -518,14 +518,14 @@ def create_intersection_cells(
518
518
  box=boxes_cells[idx].to_list(mode="xyxy"),
519
519
  class_id=cell_class_id,
520
520
  absolute_coords=boxes_cells[idx].absolute_coords,
521
- class_name=LayoutType.cell,
521
+ class_name=LayoutType.CELL,
522
522
  )
523
523
  )
524
524
  segment_result_cells.append(
525
525
  SegmentationResult(
526
526
  annotation_id="",
527
- row_num=int(row.get_sub_category(sub_item_names[0]).category_id),
528
- col_num=int(col.get_sub_category(sub_item_names[1]).category_id),
527
+ row_num=row.get_sub_category(sub_item_names[0]).category_id,
528
+ col_num=col.get_sub_category(sub_item_names[1]).category_id,
529
529
  rs=1,
530
530
  cs=1,
531
531
  )
@@ -545,7 +545,7 @@ def segment_pubtables(
545
545
  segment_rule: Literal["iou", "ioa"],
546
546
  threshold_rows: float,
547
547
  threshold_cols: float,
548
- ) -> List[SegmentationResult]:
548
+ ) -> list[SegmentationResult]:
549
549
  """
550
550
  Segment a table based on the results of `table-transformer-structure-recognition`. The processing assumes that cells
551
551
  have already been generated from the intersection of columns and rows and that column and row numbers have been
@@ -566,7 +566,7 @@ def segment_pubtables(
566
566
  to the column.
567
567
  :return: A list of len(number of cells) of SegmentationResult for spanning cells
568
568
  """
569
- child_ann_ids = table.get_relationship(Relationships.child)
569
+ child_ann_ids = table.get_relationship(Relationships.CHILD)
570
570
  cell_index_rows, row_index, _, _ = match_anns_by_intersection(
571
571
  dp,
572
572
  item_names[0],
@@ -601,25 +601,25 @@ def segment_pubtables(
601
601
  cell_positions_rows = cell_index_rows == idx
602
602
  rows_of_cell = [rows[k] for k in row_index[cell_positions_rows]]
603
603
  rs = (
604
- max(int(row.get_sub_category(CellType.row_number).category_id) for row in rows_of_cell)
605
- - min(int(row.get_sub_category(CellType.row_number).category_id) for row in rows_of_cell)
604
+ max(row.get_sub_category(CellType.ROW_NUMBER).category_id for row in rows_of_cell)
605
+ - min(row.get_sub_category(CellType.ROW_NUMBER).category_id for row in rows_of_cell)
606
606
  + 1
607
607
  )
608
608
  if len(rows_of_cell):
609
- row_number = min(int(row.get_sub_category(CellType.row_number).category_id) for row in rows_of_cell)
609
+ row_number = min(row.get_sub_category(CellType.ROW_NUMBER).category_id for row in rows_of_cell)
610
610
  else:
611
611
  row_number = 0
612
612
 
613
613
  cell_positions_cols = cell_index_cols == idx
614
614
  cols_of_cell = [columns[k] for k in col_index[cell_positions_cols]]
615
615
  cs = (
616
- max(int(col.get_sub_category(CellType.column_number).category_id) for col in cols_of_cell)
617
- - min(int(col.get_sub_category(CellType.column_number).category_id) for col in cols_of_cell)
616
+ max(col.get_sub_category(CellType.COLUMN_NUMBER).category_id for col in cols_of_cell)
617
+ - min(col.get_sub_category(CellType.COLUMN_NUMBER).category_id for col in cols_of_cell)
618
618
  + 1
619
619
  )
620
620
 
621
621
  if len(cols_of_cell):
622
- col_number = min(int(col.get_sub_category(CellType.column_number).category_id) for col in cols_of_cell)
622
+ col_number = min(col.get_sub_category(CellType.COLUMN_NUMBER).category_id for col in cols_of_cell)
623
623
  else:
624
624
  col_number = 0
625
625
 
@@ -694,8 +694,10 @@ class TableSegmentationService(PipelineComponent):
694
694
  :param sub_item_names: cell types of sub items (e.g. row number and column number)
695
695
  :param stretch_rule: Check the description in `tile_tables_with_items_per_table`
696
696
  """
697
- assert segment_rule in ("iou", "ioa"), "segment_rule must be either iou or ioa"
698
- assert stretch_rule in ("left", "equal"), "stretch rule must be either 'left' or 'equal'"
697
+ if segment_rule not in ("iou", "ioa"):
698
+ raise ValueError("segment_rule must be either iou or ioa")
699
+ if stretch_rule not in ("left", "equal"):
700
+ raise ValueError("stretch rule must be either 'left' or 'equal'")
699
701
 
700
702
  self.segment_rule = segment_rule
701
703
  self.threshold_rows = threshold_rows
@@ -722,7 +724,7 @@ class TableSegmentationService(PipelineComponent):
722
724
  )
723
725
  table_anns = dp.get_annotation(category_names=self.table_name)
724
726
  for table in table_anns:
725
- item_ann_ids = table.get_relationship(Relationships.child)
727
+ item_ann_ids = table.get_relationship(Relationships.CHILD)
726
728
  for item_sub_item_name in zip(self.item_names, self.sub_item_names): # one pass for rows and one for cols
727
729
  item_name, sub_item_name = item_sub_item_name[0], item_sub_item_name[1]
728
730
  if self.tile_table:
@@ -740,7 +742,7 @@ class TableSegmentationService(PipelineComponent):
740
742
  items.sort(
741
743
  key=lambda x: (
742
744
  x.get_bounding_box(dp.image_id).cx # pylint: disable=W0640
743
- if item_name == LayoutType.column # pylint: disable=W0640
745
+ if item_name == LayoutType.COLUMN # pylint: disable=W0640
744
746
  else x.get_bounding_box(dp.image_id).cy # pylint: disable=W0640
745
747
  )
746
748
  )
@@ -760,45 +762,45 @@ class TableSegmentationService(PipelineComponent):
760
762
  )
761
763
  for segment_result in raw_table_segments:
762
764
  self.dp_manager.set_category_annotation(
763
- CellType.row_number, segment_result.row_num, CellType.row_number, segment_result.annotation_id
765
+ CellType.ROW_NUMBER, segment_result.row_num, CellType.ROW_NUMBER, segment_result.annotation_id
764
766
  )
765
767
  self.dp_manager.set_category_annotation(
766
- CellType.column_number, segment_result.col_num, CellType.column_number, segment_result.annotation_id
768
+ CellType.COLUMN_NUMBER, segment_result.col_num, CellType.COLUMN_NUMBER, segment_result.annotation_id
767
769
  )
768
770
  self.dp_manager.set_category_annotation(
769
- CellType.row_span, segment_result.rs, CellType.row_span, segment_result.annotation_id
771
+ CellType.ROW_SPAN, segment_result.rs, CellType.ROW_SPAN, segment_result.annotation_id
770
772
  )
771
773
  self.dp_manager.set_category_annotation(
772
- CellType.column_span, segment_result.cs, CellType.column_span, segment_result.annotation_id
774
+ CellType.COLUMN_SPAN, segment_result.cs, CellType.COLUMN_SPAN, segment_result.annotation_id
773
775
  )
774
776
 
775
777
  if table.image:
776
778
  cells = table.image.get_annotation(category_names=self.cell_names)
777
- number_of_rows = max(int(cell.get_sub_category(CellType.row_number).category_id) for cell in cells)
778
- number_of_cols = max(int(cell.get_sub_category(CellType.column_number).category_id) for cell in cells)
779
- max_row_span = max(int(cell.get_sub_category(CellType.row_span).category_id) for cell in cells)
780
- max_col_span = max(int(cell.get_sub_category(CellType.column_span).category_id) for cell in cells)
779
+ number_of_rows = max(cell.get_sub_category(CellType.ROW_NUMBER).category_id for cell in cells)
780
+ number_of_cols = max(cell.get_sub_category(CellType.COLUMN_NUMBER).category_id for cell in cells)
781
+ max_row_span = max(cell.get_sub_category(CellType.ROW_SPAN).category_id for cell in cells)
782
+ max_col_span = max(cell.get_sub_category(CellType.COLUMN_SPAN).category_id for cell in cells)
781
783
  # TODO: the summaries should be sub categories of the underlying ann
782
784
  self.dp_manager.set_summary_annotation(
783
- TableType.number_of_rows,
784
- TableType.number_of_rows,
785
+ TableType.NUMBER_OF_ROWS,
786
+ TableType.NUMBER_OF_ROWS,
785
787
  number_of_rows,
786
788
  annotation_id=table.annotation_id,
787
789
  )
788
790
  self.dp_manager.set_summary_annotation(
789
- TableType.number_of_columns,
790
- TableType.number_of_columns,
791
+ TableType.NUMBER_OF_COLUMNS,
792
+ TableType.NUMBER_OF_COLUMNS,
791
793
  number_of_cols,
792
794
  annotation_id=table.annotation_id,
793
795
  )
794
796
  self.dp_manager.set_summary_annotation(
795
- TableType.max_row_span, TableType.max_row_span, max_row_span, annotation_id=table.annotation_id
797
+ TableType.MAX_ROW_SPAN, TableType.MAX_ROW_SPAN, max_row_span, annotation_id=table.annotation_id
796
798
  )
797
799
  self.dp_manager.set_summary_annotation(
798
- TableType.max_col_span, TableType.max_col_span, max_col_span, annotation_id=table.annotation_id
800
+ TableType.MAX_COL_SPAN, TableType.MAX_COL_SPAN, max_col_span, annotation_id=table.annotation_id
799
801
  )
800
802
 
801
- def clone(self) -> PipelineComponent:
803
+ def clone(self) -> TableSegmentationService:
802
804
  return self.__class__(
803
805
  self.segment_rule,
804
806
  self.threshold_rows,
@@ -813,40 +815,38 @@ class TableSegmentationService(PipelineComponent):
813
815
  self.stretch_rule,
814
816
  )
815
817
 
816
- def get_meta_annotation(self) -> JsonDict:
817
- return dict(
818
- [
819
- ("image_annotations", []),
820
- (
821
- "sub_categories",
822
- {
823
- LayoutType.cell: {
824
- CellType.row_number,
825
- CellType.column_number,
826
- CellType.row_span,
827
- CellType.column_span,
828
- },
829
- LayoutType.row: {CellType.row_number},
830
- LayoutType.column: {CellType.column_number},
831
- },
832
- ),
833
- ("relationships", {}),
834
- ("summaries", []),
835
- ]
818
+ def get_meta_annotation(self) -> MetaAnnotation:
819
+ return MetaAnnotation(
820
+ image_annotations=(),
821
+ sub_categories={
822
+ LayoutType.CELL: {
823
+ CellType.ROW_NUMBER,
824
+ CellType.COLUMN_NUMBER,
825
+ CellType.ROW_SPAN,
826
+ CellType.COLUMN_SPAN,
827
+ },
828
+ LayoutType.ROW: {CellType.ROW_NUMBER},
829
+ LayoutType.COLUMN: {CellType.COLUMN_NUMBER},
830
+ },
831
+ relationships={},
832
+ summaries=(),
836
833
  )
837
834
 
835
+ def clear_predictor(self) -> None:
836
+ """clear predictor. Will do nothing"""
837
+
838
838
 
839
839
  class PubtablesSegmentationService(PipelineComponent):
840
840
  """
841
841
  Table segmentation for table recognition detectors trained on Pubtables1M dataset. It will require `ImageAnnotation`
842
842
  of type `LayoutType.row`, `LayoutType.column` and cells of at least one type `CellType.spanning`,
843
- `CellType.row_header`, `CellType.column_header`, `CellType.projected_row_header`. For table recognition using
843
+ `CellType.ROW_HEADER`, `CellType.COLUMN_HEADER`, `CellType.PROJECTED_ROW_HEADER`. For table recognition using
844
844
  this service build a pipeline as follows:
845
845
 
846
846
  **Example:**
847
847
 
848
848
  layout = ImageLayoutService(layout_detector, to_image=True, crop_image=True)
849
- recognition = SubImageLayoutService(table_recognition_detector, LayoutType.table, {1: 6, 2:7, 3:8, 4:9}, True)
849
+ recognition = SubImageLayoutService(table_recognition_detector, LayoutType.TABLE, {1: 6, 2:7, 3:8, 4:9}, True)
850
850
  segment = PubtablesSegmentationService('ioa', 0.4, 0.4, True, 0.8, 0.8, 7)
851
851
  ...
852
852
 
@@ -933,7 +933,7 @@ class PubtablesSegmentationService(PipelineComponent):
933
933
  )
934
934
  table_anns = dp.get_annotation(category_names=self.table_name)
935
935
  for table in table_anns:
936
- item_ann_ids = table.get_relationship(Relationships.child)
936
+ item_ann_ids = table.get_relationship(Relationships.CHILD)
937
937
  for item_sub_item_name in zip(self.item_names, self.sub_item_names): # one pass for rows and one for cols
938
938
  item_name, sub_item_name = item_sub_item_name[0], item_sub_item_name[1]
939
939
  if self.tile_table:
@@ -944,7 +944,7 @@ class PubtablesSegmentationService(PipelineComponent):
944
944
  items.sort(
945
945
  key=lambda x: (
946
946
  x.get_bounding_box(dp.image_id).cx
947
- if item_name == LayoutType.column # pylint: disable=W0640
947
+ if item_name == LayoutType.COLUMN # pylint: disable=W0640
948
948
  else x.get_bounding_box(dp.image_id).cy
949
949
  )
950
950
  )
@@ -967,16 +967,16 @@ class PubtablesSegmentationService(PipelineComponent):
967
967
  crop_image=self.crop_cell_image,
968
968
  )
969
969
  self.dp_manager.set_category_annotation(
970
- CellType.row_number, segment_result.row_num, CellType.row_number, segment_result.annotation_id
970
+ CellType.ROW_NUMBER, segment_result.row_num, CellType.ROW_NUMBER, segment_result.annotation_id
971
971
  )
972
972
  self.dp_manager.set_category_annotation(
973
- CellType.column_number, segment_result.col_num, CellType.column_number, segment_result.annotation_id
973
+ CellType.COLUMN_NUMBER, segment_result.col_num, CellType.COLUMN_NUMBER, segment_result.annotation_id
974
974
  )
975
975
  self.dp_manager.set_category_annotation(
976
- CellType.row_span, segment_result.rs, CellType.row_span, segment_result.annotation_id
976
+ CellType.ROW_SPAN, segment_result.rs, CellType.ROW_SPAN, segment_result.annotation_id
977
977
  )
978
978
  self.dp_manager.set_category_annotation(
979
- CellType.column_span, segment_result.cs, CellType.column_span, segment_result.annotation_id
979
+ CellType.COLUMN_SPAN, segment_result.cs, CellType.COLUMN_SPAN, segment_result.annotation_id
980
980
  )
981
981
  cell_rn_cn_to_ann_id[(segment_result.row_num, segment_result.col_num)] = segment_result.annotation_id
982
982
  spanning_cell_raw_segments = segment_pubtables(
@@ -990,16 +990,16 @@ class PubtablesSegmentationService(PipelineComponent):
990
990
  )
991
991
  for segment_result in spanning_cell_raw_segments:
992
992
  self.dp_manager.set_category_annotation(
993
- CellType.row_number, segment_result.row_num, CellType.row_number, segment_result.annotation_id
993
+ CellType.ROW_NUMBER, segment_result.row_num, CellType.ROW_NUMBER, segment_result.annotation_id
994
994
  )
995
995
  self.dp_manager.set_category_annotation(
996
- CellType.column_number, segment_result.col_num, CellType.column_number, segment_result.annotation_id
996
+ CellType.COLUMN_NUMBER, segment_result.col_num, CellType.COLUMN_NUMBER, segment_result.annotation_id
997
997
  )
998
998
  self.dp_manager.set_category_annotation(
999
- CellType.row_span, segment_result.rs, CellType.row_span, segment_result.annotation_id
999
+ CellType.ROW_SPAN, segment_result.rs, CellType.ROW_SPAN, segment_result.annotation_id
1000
1000
  )
1001
1001
  self.dp_manager.set_category_annotation(
1002
- CellType.column_span, segment_result.cs, CellType.column_span, segment_result.annotation_id
1002
+ CellType.COLUMN_SPAN, segment_result.cs, CellType.COLUMN_SPAN, segment_result.annotation_id
1003
1003
  )
1004
1004
  cells_to_deactivate = []
1005
1005
  for rs in range(segment_result.rs):
@@ -1013,10 +1013,10 @@ class PubtablesSegmentationService(PipelineComponent):
1013
1013
  if table.image:
1014
1014
  cells = table.image.get_annotation(category_names=self.cell_names)
1015
1015
  if cells:
1016
- number_of_rows = max(int(cell.get_sub_category(CellType.row_number).category_id) for cell in cells)
1017
- number_of_cols = max(int(cell.get_sub_category(CellType.column_number).category_id) for cell in cells)
1018
- max_row_span = max(int(cell.get_sub_category(CellType.row_span).category_id) for cell in cells)
1019
- max_col_span = max(int(cell.get_sub_category(CellType.column_span).category_id) for cell in cells)
1016
+ number_of_rows = max(cell.get_sub_category(CellType.ROW_NUMBER).category_id for cell in cells)
1017
+ number_of_cols = max(cell.get_sub_category(CellType.COLUMN_NUMBER).category_id for cell in cells)
1018
+ max_row_span = max(cell.get_sub_category(CellType.ROW_SPAN).category_id for cell in cells)
1019
+ max_col_span = max(cell.get_sub_category(CellType.COLUMN_SPAN).category_id for cell in cells)
1020
1020
  else:
1021
1021
  number_of_rows = 0
1022
1022
  number_of_cols = 0
@@ -1024,24 +1024,24 @@ class PubtablesSegmentationService(PipelineComponent):
1024
1024
  max_col_span = 0
1025
1025
  # TODO: the summaries should be sub categories of the underlying ann
1026
1026
  self.dp_manager.set_summary_annotation(
1027
- TableType.number_of_rows, TableType.number_of_rows, number_of_rows, annotation_id=table.annotation_id
1027
+ TableType.NUMBER_OF_ROWS, TableType.NUMBER_OF_ROWS, number_of_rows, annotation_id=table.annotation_id
1028
1028
  )
1029
1029
  self.dp_manager.set_summary_annotation(
1030
- TableType.number_of_columns,
1031
- TableType.number_of_columns,
1030
+ TableType.NUMBER_OF_COLUMNS,
1031
+ TableType.NUMBER_OF_COLUMNS,
1032
1032
  number_of_cols,
1033
1033
  annotation_id=table.annotation_id,
1034
1034
  )
1035
1035
  self.dp_manager.set_summary_annotation(
1036
- TableType.max_row_span, TableType.max_row_span, max_row_span, annotation_id=table.annotation_id
1036
+ TableType.MAX_ROW_SPAN, TableType.MAX_ROW_SPAN, max_row_span, annotation_id=table.annotation_id
1037
1037
  )
1038
1038
  self.dp_manager.set_summary_annotation(
1039
- TableType.max_col_span, TableType.max_col_span, max_col_span, annotation_id=table.annotation_id
1039
+ TableType.MAX_COL_SPAN, TableType.MAX_COL_SPAN, max_col_span, annotation_id=table.annotation_id
1040
1040
  )
1041
1041
  html = generate_html_string(table)
1042
- self.dp_manager.set_container_annotation(TableType.html, -1, TableType.html, table.annotation_id, html)
1042
+ self.dp_manager.set_container_annotation(TableType.HTML, -1, TableType.HTML, table.annotation_id, html)
1043
1043
 
1044
- def clone(self) -> PipelineComponent:
1044
+ def clone(self) -> PubtablesSegmentationService:
1045
1045
  return self.__class__(
1046
1046
  self.segment_rule,
1047
1047
  self.threshold_rows,
@@ -1060,48 +1060,43 @@ class PubtablesSegmentationService(PipelineComponent):
1060
1060
  self.stretch_rule,
1061
1061
  )
1062
1062
 
1063
- def get_meta_annotation(self) -> JsonDict:
1064
- return dict(
1065
- [
1066
- ("image_annotations", []),
1067
- (
1068
- "sub_categories",
1069
- {
1070
- LayoutType.cell: {
1071
- CellType.row_number,
1072
- CellType.column_number,
1073
- CellType.row_span,
1074
- CellType.column_span,
1075
- },
1076
- CellType.spanning: {
1077
- CellType.row_number,
1078
- CellType.column_number,
1079
- CellType.row_span,
1080
- CellType.column_span,
1081
- },
1082
- CellType.row_header: {
1083
- CellType.row_number,
1084
- CellType.column_number,
1085
- CellType.row_span,
1086
- CellType.column_span,
1087
- },
1088
- CellType.column_header: {
1089
- CellType.row_number,
1090
- CellType.column_number,
1091
- CellType.row_span,
1092
- CellType.column_span,
1093
- },
1094
- CellType.projected_row_header: {
1095
- CellType.row_number,
1096
- CellType.column_number,
1097
- CellType.row_span,
1098
- CellType.column_span,
1099
- },
1100
- LayoutType.row: {CellType.row_number},
1101
- LayoutType.column: {CellType.column_number},
1102
- },
1103
- ),
1104
- ("relationships", {}),
1105
- ("summaries", []),
1106
- ]
1063
+ def get_meta_annotation(self) -> MetaAnnotation:
1064
+ return MetaAnnotation(
1065
+ image_annotations=(),
1066
+ sub_categories={
1067
+ LayoutType.CELL: {
1068
+ CellType.ROW_NUMBER,
1069
+ CellType.COLUMN_NUMBER,
1070
+ CellType.ROW_SPAN,
1071
+ CellType.COLUMN_SPAN,
1072
+ },
1073
+ CellType.SPANNING: {
1074
+ CellType.ROW_NUMBER,
1075
+ CellType.COLUMN_NUMBER,
1076
+ CellType.ROW_SPAN,
1077
+ CellType.COLUMN_SPAN,
1078
+ },
1079
+ CellType.ROW_HEADER: {
1080
+ CellType.ROW_NUMBER,
1081
+ CellType.COLUMN_NUMBER,
1082
+ CellType.ROW_SPAN,
1083
+ CellType.COLUMN_SPAN,
1084
+ },
1085
+ CellType.COLUMN_HEADER: {
1086
+ CellType.ROW_NUMBER,
1087
+ CellType.COLUMN_NUMBER,
1088
+ CellType.ROW_SPAN,
1089
+ CellType.COLUMN_SPAN,
1090
+ },
1091
+ CellType.PROJECTED_ROW_HEADER: {
1092
+ CellType.ROW_NUMBER,
1093
+ CellType.COLUMN_NUMBER,
1094
+ CellType.ROW_SPAN,
1095
+ CellType.COLUMN_SPAN,
1096
+ },
1097
+ LayoutType.ROW: {CellType.ROW_NUMBER},
1098
+ LayoutType.COLUMN: {CellType.COLUMN_NUMBER},
1099
+ },
1100
+ relationships={},
1101
+ summaries=(),
1107
1102
  )