deepdoctection-0.31-py3-none-any.whl → deepdoctection-0.33-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of deepdoctection might be problematic.

Files changed (131)
  1. deepdoctection/__init__.py +16 -29
  2. deepdoctection/analyzer/dd.py +70 -59
  3. deepdoctection/configs/conf_dd_one.yaml +34 -31
  4. deepdoctection/dataflow/common.py +9 -5
  5. deepdoctection/dataflow/custom.py +5 -5
  6. deepdoctection/dataflow/custom_serialize.py +75 -18
  7. deepdoctection/dataflow/parallel_map.py +3 -3
  8. deepdoctection/dataflow/serialize.py +4 -4
  9. deepdoctection/dataflow/stats.py +3 -3
  10. deepdoctection/datapoint/annotation.py +41 -56
  11. deepdoctection/datapoint/box.py +9 -8
  12. deepdoctection/datapoint/convert.py +6 -6
  13. deepdoctection/datapoint/image.py +56 -44
  14. deepdoctection/datapoint/view.py +245 -150
  15. deepdoctection/datasets/__init__.py +1 -4
  16. deepdoctection/datasets/adapter.py +35 -26
  17. deepdoctection/datasets/base.py +14 -12
  18. deepdoctection/datasets/dataflow_builder.py +3 -3
  19. deepdoctection/datasets/info.py +24 -26
  20. deepdoctection/datasets/instances/doclaynet.py +51 -51
  21. deepdoctection/datasets/instances/fintabnet.py +46 -46
  22. deepdoctection/datasets/instances/funsd.py +25 -24
  23. deepdoctection/datasets/instances/iiitar13k.py +13 -10
  24. deepdoctection/datasets/instances/layouttest.py +4 -3
  25. deepdoctection/datasets/instances/publaynet.py +5 -5
  26. deepdoctection/datasets/instances/pubtables1m.py +24 -21
  27. deepdoctection/datasets/instances/pubtabnet.py +32 -30
  28. deepdoctection/datasets/instances/rvlcdip.py +30 -30
  29. deepdoctection/datasets/instances/xfund.py +26 -26
  30. deepdoctection/datasets/save.py +6 -6
  31. deepdoctection/eval/__init__.py +1 -4
  32. deepdoctection/eval/accmetric.py +32 -33
  33. deepdoctection/eval/base.py +8 -9
  34. deepdoctection/eval/cocometric.py +15 -13
  35. deepdoctection/eval/eval.py +41 -37
  36. deepdoctection/eval/tedsmetric.py +30 -23
  37. deepdoctection/eval/tp_eval_callback.py +16 -19
  38. deepdoctection/extern/__init__.py +2 -7
  39. deepdoctection/extern/base.py +339 -134
  40. deepdoctection/extern/d2detect.py +85 -113
  41. deepdoctection/extern/deskew.py +14 -11
  42. deepdoctection/extern/doctrocr.py +141 -130
  43. deepdoctection/extern/fastlang.py +27 -18
  44. deepdoctection/extern/hfdetr.py +71 -62
  45. deepdoctection/extern/hflayoutlm.py +504 -211
  46. deepdoctection/extern/hflm.py +230 -0
  47. deepdoctection/extern/model.py +488 -302
  48. deepdoctection/extern/pdftext.py +23 -19
  49. deepdoctection/extern/pt/__init__.py +1 -3
  50. deepdoctection/extern/pt/nms.py +6 -2
  51. deepdoctection/extern/pt/ptutils.py +29 -19
  52. deepdoctection/extern/tessocr.py +39 -38
  53. deepdoctection/extern/texocr.py +18 -18
  54. deepdoctection/extern/tp/tfutils.py +57 -9
  55. deepdoctection/extern/tp/tpcompat.py +21 -14
  56. deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
  57. deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
  58. deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
  59. deepdoctection/extern/tp/tpfrcnn/config/config.py +13 -10
  60. deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
  61. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +18 -8
  62. deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
  63. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +14 -9
  64. deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
  65. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +22 -17
  66. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +21 -14
  67. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +19 -11
  68. deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
  69. deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
  70. deepdoctection/extern/tp/tpfrcnn/preproc.py +12 -8
  71. deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
  72. deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
  73. deepdoctection/extern/tpdetect.py +45 -53
  74. deepdoctection/mapper/__init__.py +3 -8
  75. deepdoctection/mapper/cats.py +27 -29
  76. deepdoctection/mapper/cocostruct.py +10 -10
  77. deepdoctection/mapper/d2struct.py +27 -26
  78. deepdoctection/mapper/hfstruct.py +13 -8
  79. deepdoctection/mapper/laylmstruct.py +178 -37
  80. deepdoctection/mapper/maputils.py +12 -11
  81. deepdoctection/mapper/match.py +2 -2
  82. deepdoctection/mapper/misc.py +11 -9
  83. deepdoctection/mapper/pascalstruct.py +4 -4
  84. deepdoctection/mapper/prodigystruct.py +5 -5
  85. deepdoctection/mapper/pubstruct.py +84 -92
  86. deepdoctection/mapper/tpstruct.py +5 -5
  87. deepdoctection/mapper/xfundstruct.py +33 -33
  88. deepdoctection/pipe/__init__.py +1 -1
  89. deepdoctection/pipe/anngen.py +12 -14
  90. deepdoctection/pipe/base.py +52 -106
  91. deepdoctection/pipe/common.py +72 -59
  92. deepdoctection/pipe/concurrency.py +16 -11
  93. deepdoctection/pipe/doctectionpipe.py +24 -21
  94. deepdoctection/pipe/language.py +20 -25
  95. deepdoctection/pipe/layout.py +20 -16
  96. deepdoctection/pipe/lm.py +75 -105
  97. deepdoctection/pipe/order.py +194 -89
  98. deepdoctection/pipe/refine.py +111 -124
  99. deepdoctection/pipe/segment.py +156 -161
  100. deepdoctection/pipe/{cell.py → sub_layout.py} +50 -40
  101. deepdoctection/pipe/text.py +37 -36
  102. deepdoctection/pipe/transform.py +19 -16
  103. deepdoctection/train/__init__.py +6 -12
  104. deepdoctection/train/d2_frcnn_train.py +48 -41
  105. deepdoctection/train/hf_detr_train.py +41 -30
  106. deepdoctection/train/hf_layoutlm_train.py +153 -135
  107. deepdoctection/train/tp_frcnn_train.py +32 -31
  108. deepdoctection/utils/concurrency.py +1 -1
  109. deepdoctection/utils/context.py +13 -6
  110. deepdoctection/utils/develop.py +4 -4
  111. deepdoctection/utils/env_info.py +87 -125
  112. deepdoctection/utils/file_utils.py +6 -11
  113. deepdoctection/utils/fs.py +22 -18
  114. deepdoctection/utils/identifier.py +2 -2
  115. deepdoctection/utils/logger.py +16 -15
  116. deepdoctection/utils/metacfg.py +7 -7
  117. deepdoctection/utils/mocks.py +93 -0
  118. deepdoctection/utils/pdf_utils.py +11 -11
  119. deepdoctection/utils/settings.py +185 -181
  120. deepdoctection/utils/tqdm.py +1 -1
  121. deepdoctection/utils/transform.py +14 -9
  122. deepdoctection/utils/types.py +104 -0
  123. deepdoctection/utils/utils.py +7 -7
  124. deepdoctection/utils/viz.py +74 -72
  125. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/METADATA +30 -21
  126. deepdoctection-0.33.dist-info/RECORD +146 -0
  127. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/WHEEL +1 -1
  128. deepdoctection/utils/detection_types.py +0 -68
  129. deepdoctection-0.31.dist-info/RECORD +0 -144
  130. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/LICENSE +0 -0
  131. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/top_level.txt +0 -0
deepdoctection/eval/eval.py

@@ -19,36 +19,35 @@
 """
 Module for `Evaluator`
 """
-
-__all__ = ["Evaluator"]
+from __future__ import annotations

 from copy import deepcopy
-from typing import Any, Dict, List, Literal, Mapping, Optional, Type, Union, overload
+from typing import Any, Generator, Literal, Mapping, Optional, Type, Union, overload

 import numpy as np
+from lazy_imports import try_import

 from ..dataflow import CacheData, DataFlow, DataFromList, MapData
 from ..datapoint.image import Image
 from ..datasets.base import DatasetBase
 from ..mapper.cats import filter_cat, remove_cats
+from ..mapper.d2struct import to_wandb_image
 from ..mapper.misc import maybe_load_image, maybe_remove_image, maybe_remove_image_from_category
-from ..pipe.base import LanguageModelPipelineComponent, PredictorPipelineComponent
+from ..pipe.base import PipelineComponent
 from ..pipe.common import PageParsingService
 from ..pipe.concurrency import MultiThreadPipelineComponent
 from ..pipe.doctectionpipe import DoctectionPipe
-from ..utils.detection_types import ImageType
-from ..utils.file_utils import detectron2_available, wandb_available
 from ..utils.logger import LoggingRecord, logger
 from ..utils.settings import DatasetType, LayoutType, TypeOrStr, get_type
+from ..utils.types import PixelValues
 from ..utils.viz import interactive_imshow
 from .base import MetricBase

-if wandb_available():
+with try_import() as wb_import_guard:
     import wandb  # pylint:disable=W0611
     from wandb import Artifact, Table

-if wandb_available() and detectron2_available():
-    from ..mapper.d2struct import to_wandb_image
+__all__ = ["Evaluator"]


 class Evaluator:
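The recurring change in this release is the move away from hand-rolled availability checks such as wandb_available() and detectron2_available() toward the lazy-imports package. A minimal sketch of the pattern, assuming nothing beyond what the hunk above already shows (try_import as a context manager plus its is_successful() method):

    from lazy_imports import try_import

    with try_import() as wandb_guard:
        import wandb  # an ImportError is recorded on the guard instead of raising

    def maybe_log(metrics: dict) -> None:
        # code paths that need the optional dependency consult the guard first
        if wandb_guard.is_successful():
            wandb.log(metrics)

This lets every module be imported unconditionally while deferring the missing-dependency error to the point of use.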
@@ -91,10 +90,10 @@ class Evaluator:
     def __init__(
         self,
         dataset: DatasetBase,
-        component_or_pipeline: Union[PredictorPipelineComponent, LanguageModelPipelineComponent, DoctectionPipe],
+        component_or_pipeline: Union[PipelineComponent, DoctectionPipe],
         metric: Union[Type[MetricBase], MetricBase],
         num_threads: int = 2,
-        run: Optional["wandb.sdk.wandb_run.Run"] = None,
+        run: Optional[wandb.sdk.wandb_run.Run] = None,
     ) -> None:
         """
         Evaluating a pipeline component on a dataset with a given metric.
@@ -109,14 +108,14 @@ class Evaluator:
         self.pipe: Optional[DoctectionPipe] = None

         # when passing a component, we will process prediction on num_threads
-        if isinstance(component_or_pipeline, (PredictorPipelineComponent, LanguageModelPipelineComponent)):
+        if isinstance(component_or_pipeline, PipelineComponent):
             logger.info(
                 LoggingRecord(
                     f"Building multi threading pipeline component to increase prediction throughput. "
                     f"Using {num_threads} threads"
                 )
             )
-            pipeline_components: List[Union[PredictorPipelineComponent, LanguageModelPipelineComponent]] = []
+            pipeline_components: list[PipelineComponent] = []

             for _ in range(num_threads - 1):
                 copy_pipe_component = component_or_pipeline.clone()
@@ -140,14 +139,14 @@

         self.wandb_table_agent: Optional[WandbTableAgent]
         if run is not None:
-            if self.dataset.dataset_info.type == DatasetType.object_detection:
+            if self.dataset.dataset_info.type == DatasetType.OBJECT_DETECTION:
                 self.wandb_table_agent = WandbTableAgent(
                     run,
                     self.dataset.dataset_info.name,
                     50,
                     self.dataset.dataflow.categories.get_categories(filtered=True),
                 )
-            elif self.dataset.dataset_info.type == DatasetType.token_classification:
+            elif self.dataset.dataset_info.type == DatasetType.TOKEN_CLASSIFICATION:
                 if hasattr(self.metric, "sub_cats"):
                     sub_cat_key, sub_cat_val_list = list(self.metric.sub_cats.items())[0]
                     sub_cat_val = sub_cat_val_list[0]
@@ -179,16 +178,16 @@ class Evaluator:
     @overload
     def run(
         self, output_as_dict: Literal[False] = False, **dataflow_build_kwargs: Union[str, int]
-    ) -> List[Dict[str, float]]:
+    ) -> list[dict[str, float]]:
         ...

     @overload
-    def run(self, output_as_dict: Literal[True], **dataflow_build_kwargs: Union[str, int]) -> Dict[str, float]:
+    def run(self, output_as_dict: Literal[True], **dataflow_build_kwargs: Union[str, int]) -> dict[str, float]:
         ...

     def run(
         self, output_as_dict: bool = False, **dataflow_build_kwargs: Union[str, int]
-    ) -> Union[List[Dict[str, float]], Dict[str, float]]:
+    ) -> Union[list[dict[str, float]], dict[str, float]]:
         """
         Start evaluation process and return the results.

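The lowercase generics in these signatures (list[dict[str, float]] instead of List[Dict[str, float]]) are only legal on the older Python versions deepdoctection supports because the module now begins with from __future__ import annotations (PEP 563), which stores annotations as strings rather than evaluating them at import time. A standalone illustration of that mechanism, not code from the package:

    from __future__ import annotations  # postpone annotation evaluation

    def scores() -> list[dict[str, float]]:  # accepted even on Python 3.8
        return [{"f1": 0.87}]

    print(scores.__annotations__["return"])  # prints the unevaluated string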
@@ -247,11 +246,11 @@ class Evaluator:
        possible_cats_in_datapoint = self.dataset.dataflow.categories.get_categories(as_dict=False, filtered=True)

        # clean-up procedure depends on the dataset type
-        if self.dataset.dataset_info.type == DatasetType.object_detection:
+        if self.dataset.dataset_info.type == DatasetType.OBJECT_DETECTION:
            # we keep all image annotations that will not be generated through processing
-            anns_to_keep = {ann for ann in possible_cats_in_datapoint if ann not in meta_anns["image_annotations"]}
-            sub_cats_to_remove = meta_anns["sub_categories"]
-            relationships_to_remove = meta_anns["relationships"]
+            anns_to_keep = {ann for ann in possible_cats_in_datapoint if ann not in meta_anns.image_annotations}
+            sub_cats_to_remove = meta_anns.sub_categories
+            relationships_to_remove = meta_anns.relationships
            # removing annotations takes place in three steps: First we remove all image annotations. Then, with all
            # remaining image annotations we check, if the image attribute (with Image instance !) is not empty and
            # remove it as well, if necessary. In the last step we remove all sub categories and relationships, if
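Note the switch from meta_anns["sub_categories"] to meta_anns.sub_categories: the meta-annotation container is evidently no longer a plain dict but a structured object. A hypothetical shape, inferred solely from the attributes these hunks access (the class name and field types are assumptions; the real definition lives elsewhere in the package):

    from dataclasses import dataclass, field

    @dataclass(frozen=True)
    class MetaAnnotation:  # name is a guess
        image_annotations: tuple = ()
        sub_categories: dict = field(default_factory=dict)
        relationships: dict = field(default_factory=dict)
        summaries: tuple = ()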
@@ -263,19 +262,19 @@ class Evaluator:
                 remove_cats(sub_categories=sub_cats_to_remove, relationships=relationships_to_remove),
             )

-        elif self.dataset.dataset_info.type == DatasetType.sequence_classification:
-            summary_sub_cats_to_remove = meta_anns["summaries"]
+        elif self.dataset.dataset_info.type == DatasetType.SEQUENCE_CLASSIFICATION:
+            summary_sub_cats_to_remove = meta_anns.summaries
             df_pr = MapData(df_pr, remove_cats(summary_sub_categories=summary_sub_cats_to_remove))

-        elif self.dataset.dataset_info.type == DatasetType.token_classification:
-            sub_cats_to_remove = meta_anns["sub_categories"]
+        elif self.dataset.dataset_info.type == DatasetType.TOKEN_CLASSIFICATION:
+            sub_cats_to_remove = meta_anns.sub_categories
             df_pr = MapData(df_pr, remove_cats(sub_categories=sub_cats_to_remove))
         else:
             raise NotImplementedError()

         return df_pr

-    def compare(self, interactive: bool = False, **kwargs: Union[str, int]) -> Optional[ImageType]:
+    def compare(self, interactive: bool = False, **kwargs: Union[str, int]) -> Generator[PixelValues, None, None]:
         """
         Visualize ground truth and prediction datapoint. Given a dataflow config it will run predictions per sample
         and concat the prediction image (with predicted bounding boxes) with ground truth image.
@@ -292,6 +291,8 @@ class Evaluator:
         show_layouts = kwargs.pop("show_layouts", True)
         show_table_structure = kwargs.pop("show_table_structure", True)
         show_words = kwargs.pop("show_words", False)
+        show_token_class = kwargs.pop("show_token_class", True)
+        ignore_default_token_class = kwargs.pop("ignore_default_token_class", False)

         df_gt = self.dataset.dataflow.build(**kwargs)
         df_pr = self.dataset.dataflow.build(**kwargs)
@@ -300,7 +301,7 @@
         df_pr = MapData(df_pr, deepcopy)
         df_pr = self._clean_up_predict_dataflow_annotations(df_pr)

-        page_parsing_component = PageParsingService(text_container=LayoutType.word)
+        page_parsing_component = PageParsingService(text_container=LayoutType.WORD)
         df_gt = page_parsing_component.predict_dataflow(df_gt)

         if self.pipe_component:
@@ -321,18 +322,21 @@
                 show_layouts=show_layouts,
                 show_table_structure=show_table_structure,
                 show_words=show_words,
+                show_token_class=show_token_class,
+                ignore_default_token_class=ignore_default_token_class,
             ), dp_pred.viz(
                 show_tables=show_tables,
                 show_layouts=show_layouts,
                 show_table_structure=show_table_structure,
                 show_words=show_words,
+                show_token_class=show_token_class,
+                ignore_default_token_class=ignore_default_token_class,
             )
             img_concat = np.concatenate((img_gt, img_pred), axis=1)
             if interactive:
                 interactive_imshow(img_concat)
             else:
-                return img_concat
-        return None
+                yield img_concat


 class WandbTableAgent:
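Since compare() now yields one concatenated image per datapoint instead of returning a single array or None, callers iterate over it. A hedged usage sketch; evaluator stands in for a configured Evaluator and the dataflow keyword argument is illustrative:

    evaluator = Evaluator(dataset, pipeline_component, metric)
    for side_by_side in evaluator.compare(interactive=False, max_datapoints=3):
        # ground truth rendered on the left, prediction on the right
        print(side_by_side.shape)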
@@ -350,11 +354,11 @@ class WandbTableAgent:

     def __init__(
         self,
-        wandb_run: "wandb.sdk.wandb_run.Run",
+        wandb_run: wandb.sdk.wandb_run.Run,
         dataset_name: str,
         num_samples: int,
-        categories: Mapping[str, TypeOrStr],
-        sub_categories: Optional[Mapping[str, TypeOrStr]] = None,
+        categories: Mapping[int, TypeOrStr],
+        sub_categories: Optional[Mapping[int, TypeOrStr]] = None,
         cat_to_sub_cat: Optional[Mapping[TypeOrStr, TypeOrStr]] = None,
     ):
         """
@@ -381,8 +385,8 @@ class WandbTableAgent:
         self._counter = 0

         # Table logging utils
-        self._table_cols: List[str] = ["file_name", "image"]
-        self._table_rows: List[Any] = []
+        self._table_cols: list[str] = ["file_name", "image"]
+        self._table_rows: list[Any] = []
         self._table_ref = None

     def dump(self, dp: Image) -> Image:
@@ -409,7 +413,7 @@
         self._table_rows = []
         self._counter = 0

-    def _build_table(self) -> "Table":
+    def _build_table(self) -> Table:
         """
         Builds wandb.Table object for logging evaluation

@@ -435,4 +439,4 @@
         eval_art.add(self._build_table(), self.dataset_name)
         self._run.use_artifact(eval_art)
         eval_art.wait()
-        self._table_ref = eval_art.get(self.dataset_name).data  # type:ignore
+        self._table_ref = eval_art.get(self.dataset_name).data  # type: ignore

deepdoctection/eval/tedsmetric.py

@@ -18,30 +18,34 @@ Tree distance similarity metric taken from <https://github.com/ibm-aur-nlp/PubTa

 import statistics
 from collections import defaultdict, deque
-from typing import Any, List, Optional, Tuple
+from typing import Any, Callable, Optional
+
+from lazy_imports import try_import

 from ..dataflow import DataFlow, DataFromList, MapData, MultiThreadMapData
+from ..datapoint.image import Image
 from ..datapoint.view import Page
 from ..datasets.base import DatasetCategories
-from ..utils.detection_types import JsonDict
-from ..utils.file_utils import (
-    Requirement,
-    apted_available,
-    distance_available,
-    get_apted_requirement,
-    get_distance_requirement,
-    get_lxml_requirement,
-    lxml_available,
-)
+from ..utils.file_utils import Requirement, get_apted_requirement, get_distance_requirement, get_lxml_requirement
 from ..utils.logger import LoggingRecord, logger
 from ..utils.settings import LayoutType
+from ..utils.types import MetricResults
 from .base import MetricBase
 from .registry import metric_registry

-if distance_available() and lxml_available() and apted_available():
-    import distance  # type: ignore
+with try_import() as ap_import_guard:
     from apted import APTED, Config  # type: ignore
     from apted.helpers import Tree  # type: ignore
+
+
+if not ap_import_guard.is_successful():
+    from ..utils.mocks import Config, Tree
+
+
+with try_import() as ds_import_guard:
+    import distance  # type: ignore
+
+with try_import() as lx_import_guard:
     from lxml import etree

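The fallback import from ..utils.mocks exists so that class statements further down, e.g. class TableTree(Tree), still import cleanly when apted is absent. utils/mocks.py is new in this release (+93 lines); a plausible minimal stand-in, offered only as an assumption about its contents:

    class Tree:  # placeholder for apted.helpers.Tree
        """Import-time stand-in; actually computing TEDS still requires apted."""

        def __init__(self, *args, **kwargs):
            raise ModuleNotFoundError("apted must be installed to compute TEDS")

    class Config:  # placeholder for apted.Config
        pass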
@@ -56,7 +60,7 @@ class TableTree(Tree):
         tag: str,
         colspan: Optional[int] = None,
         rowspan: Optional[int] = None,
-        content: Optional[List[str]] = None,
+        content: Optional[list[str]] = None,
     ) -> None:
         self.tag = tag
         self.colspan = colspan
@@ -104,7 +108,7 @@ class TEDS:

     def __init__(self, structure_only: bool = False):
         self.structure_only = structure_only
-        self.__tokens__: List[str] = []
+        self.__tokens__: list[str] = []

     def tokenize(self, node: TableTree) -> None:
         """Tokenizes table cells"""
@@ -146,7 +150,7 @@
             return new_node
         return None

-    def evaluate(self, inputs: Tuple[str, str]) -> float:
+    def evaluate(self, inputs: tuple[str, str]) -> float:
         """Computes TEDS score between the prediction and the ground truth of a
         given sample
         """
@@ -185,7 +189,7 @@
         return 0.0


-def teds_metric(gt_list: List[str], predict_list: List[str], structure_only: bool) -> Tuple[float, int]:
+def teds_metric(gt_list: list[str], predict_list: list[str], structure_only: bool) -> tuple[float, int]:
     """
     Computes tree edit distance score (TEDS) between the prediction and the ground truth of a batch of samples. The
     approach to measure similarity of tables by means of their html representation has been adovacated in
@@ -218,13 +222,16 @@ class TedsMetric(MetricBase):
     """

     metric = teds_metric  # type: ignore
-    mapper = Page.from_image
+    mapper: Callable[[Image, LayoutType, list[LayoutType]], Page] = Page.from_image
+    text_container: LayoutType = LayoutType.WORD
+    floating_text_block_categories = [LayoutType.TABLE]
+
     structure_only = False

     @classmethod
     def dump(
         cls, dataflow_gt: DataFlow, dataflow_predictions: DataFlow, categories: DatasetCategories
-    ) -> Tuple[List[str], List[str]]:
+    ) -> tuple[list[str], list[str]]:
         dataflow_gt.reset_state()
         dataflow_predictions.reset_state()
@@ -232,11 +239,11 @@ class TedsMetric(MetricBase):
         gt_dict = defaultdict(list)
         pred_dict = defaultdict(list)
         for dp_gt, dp_pred in zip(dataflow_gt, dataflow_predictions):
-            page_gt = cls.mapper(dp_gt, LayoutType.word, [LayoutType.table])
+            page_gt = cls.mapper(dp_gt, cls.text_container, cls.floating_text_block_categories)
             for table in page_gt.tables:
                 gt_dict[page_gt.image_id].append(table.html)

-            page_pred = cls.mapper(dp_pred, LayoutType.word, [LayoutType.table])
+            page_pred = cls.mapper(dp_pred, cls.text_container, cls.floating_text_block_categories)
             for table in page_pred.tables:
                 pred_dict[page_pred.image_id].append(table.html)
@@ -251,12 +258,12 @@ class TedsMetric(MetricBase):
     @classmethod
     def get_distance(
         cls, dataflow_gt: DataFlow, dataflow_predictions: DataFlow, categories: DatasetCategories
-    ) -> List[JsonDict]:
+    ) -> list[MetricResults]:
         html_gt_list, html_pr_list = cls.dump(dataflow_gt, dataflow_predictions, categories)

         score, num_samples = cls.metric(html_gt_list, html_pr_list, cls.structure_only)  # type: ignore
         return [{"teds_score": score, "num_samples": num_samples}]

     @classmethod
-    def get_requirements(cls) -> List[Requirement]:
+    def get_requirements(cls) -> list[Requirement]:
         return [get_apted_requirement(), get_distance_requirement(), get_lxml_requirement()]

deepdoctection/eval/tp_eval_callback.py

@@ -19,13 +19,15 @@
 Module for EvalCallback in Tensorpack
 """

+from __future__ import annotations
+
 from itertools import count
 from typing import Mapping, Optional, Sequence, Type, Union

+from lazy_imports import try_import
+
 from ..datasets import DatasetBase
-from ..extern.tpdetect import TPFrcnnDetector
-from ..pipe.base import PredictorPipelineComponent
-from ..utils.file_utils import tensorpack_available
+from ..pipe.base import PipelineComponent
 from ..utils.logger import LoggingRecord, logger
 from ..utils.metacfg import AttrDict
 from ..utils.settings import ObjectTypes
@@ -33,12 +35,15 @@ from .base import MetricBase
 from .eval import Evaluator

 # pylint: disable=import-error
-if tensorpack_available():
+with try_import() as import_guard:
     from tensorpack.callbacks import Callback
     from tensorpack.predict import OnlinePredictor
     from tensorpack.utils.gpu import get_num_gpu
 # pylint: enable=import-error

+if not import_guard.is_successful():
+    from ..utils.mocks import Callback
+

 # The following class is modified from
 # https://github.com/tensorpack/tensorpack/blob/master/examples/FasterRCNN/eval.py
@@ -53,15 +58,16 @@ class EvalCallback(Callback):  # pylint: disable=R0903

     _chief_only = False

-    def __init__(
+    def __init__(  # pylint: disable=W0231
         self,
         dataset: DatasetBase,
         category_names: Optional[Union[ObjectTypes, Sequence[ObjectTypes]]],
         sub_categories: Optional[Union[Mapping[ObjectTypes, ObjectTypes], Mapping[ObjectTypes, Sequence[ObjectTypes]]]],
         metric: Union[Type[MetricBase], MetricBase],
-        pipeline_component: PredictorPipelineComponent,
+        pipeline_component: PipelineComponent,
         in_names: str,
         out_names: str,
+        cfg: AttrDict,
         **build_eval_kwargs: str,
     ) -> None:
         """
@@ -83,12 +89,7 @@
         self.num_gpu = get_num_gpu()
         self.category_names = category_names
         self.sub_categories = sub_categories
-        if not isinstance(pipeline_component.predictor, TPFrcnnDetector):
-            raise TypeError(
-                f"pipeline_component.predictor must be of type TPFrcnnDetector but is "
-                f"type {type(pipeline_component.predictor)}"
-            )
-        self.cfg = pipeline_component.predictor.model.cfg
+        self.cfg = cfg
         if _use_replicated(self.cfg):
             self.evaluator = Evaluator(dataset, pipeline_component, metric, num_threads=self.num_gpu * 2)
         else:
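With cfg now an explicit constructor argument, the callback no longer digs the config out of pipeline_component.predictor.model.cfg, so the TPFrcnnDetector isinstance checks become unnecessary. A hedged construction sketch; only the parameter names come from the signature above, all values are placeholders:

    callback = EvalCallback(
        dataset=my_dataset,
        category_names=None,
        sub_categories=None,
        metric=my_metric,
        pipeline_component=my_component,
        in_names="image",
        out_names="output",
        cfg=detector.model.cfg,  # caller supplies the config directly
    )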
@@ -99,13 +100,9 @@
         if self.evaluator.pipe_component is None:
             raise TypeError("self.evaluator.pipe_component cannot be None")
         for idx, comp in enumerate(self.evaluator.pipe_component.pipe_components):
-            if not isinstance(comp, PredictorPipelineComponent):
-                raise TypeError(f"comp must be of type PredictorPipelineComponent but is type {type(comp)}")
-            if not isinstance(comp.predictor, TPFrcnnDetector):
-                raise TypeError(
-                    f"comp.predictor mus be of type TPFrcnnDetector but is of type {type(comp.predictor)}"
-                )
-            comp.predictor.tp_predictor = self._build_predictor(idx % self.num_gpu)
+            if hasattr(comp, "predictor"):
+                if hasattr(comp.predictor, "tp_predictor"):
+                    comp.predictor.tp_predictor = self._build_predictor(idx % self.num_gpu)

     def _build_predictor(self, idx: int) -> OnlinePredictor:
         return self.trainer.get_predictor(self.in_names, self.out_names, device=idx)

deepdoctection/extern/__init__.py

@@ -19,8 +19,8 @@
 Wrappers for models of external libraries as well as implementation of the Cascade-RCNN model of Tensorpack.
 """

-from ..utils.file_utils import detectron2_available, tensorpack_available
 from .base import *
+from .d2detect import *
 from .deskew import *
 from .doctrocr import *
 from .fastlang import *
@@ -30,9 +30,4 @@ from .model import *
 from .pdftext import *
 from .tessocr import *
 from .texocr import *  # type: ignore
-
-if tensorpack_available():
-    from .tpdetect import *
-
-if detectron2_available():
-    from .d2detect import *
+from .tpdetect import *