deepdoctection-0.31-py3-none-any.whl → deepdoctection-0.33-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of deepdoctection has been flagged.
- deepdoctection/__init__.py +16 -29
- deepdoctection/analyzer/dd.py +70 -59
- deepdoctection/configs/conf_dd_one.yaml +34 -31
- deepdoctection/dataflow/common.py +9 -5
- deepdoctection/dataflow/custom.py +5 -5
- deepdoctection/dataflow/custom_serialize.py +75 -18
- deepdoctection/dataflow/parallel_map.py +3 -3
- deepdoctection/dataflow/serialize.py +4 -4
- deepdoctection/dataflow/stats.py +3 -3
- deepdoctection/datapoint/annotation.py +41 -56
- deepdoctection/datapoint/box.py +9 -8
- deepdoctection/datapoint/convert.py +6 -6
- deepdoctection/datapoint/image.py +56 -44
- deepdoctection/datapoint/view.py +245 -150
- deepdoctection/datasets/__init__.py +1 -4
- deepdoctection/datasets/adapter.py +35 -26
- deepdoctection/datasets/base.py +14 -12
- deepdoctection/datasets/dataflow_builder.py +3 -3
- deepdoctection/datasets/info.py +24 -26
- deepdoctection/datasets/instances/doclaynet.py +51 -51
- deepdoctection/datasets/instances/fintabnet.py +46 -46
- deepdoctection/datasets/instances/funsd.py +25 -24
- deepdoctection/datasets/instances/iiitar13k.py +13 -10
- deepdoctection/datasets/instances/layouttest.py +4 -3
- deepdoctection/datasets/instances/publaynet.py +5 -5
- deepdoctection/datasets/instances/pubtables1m.py +24 -21
- deepdoctection/datasets/instances/pubtabnet.py +32 -30
- deepdoctection/datasets/instances/rvlcdip.py +30 -30
- deepdoctection/datasets/instances/xfund.py +26 -26
- deepdoctection/datasets/save.py +6 -6
- deepdoctection/eval/__init__.py +1 -4
- deepdoctection/eval/accmetric.py +32 -33
- deepdoctection/eval/base.py +8 -9
- deepdoctection/eval/cocometric.py +15 -13
- deepdoctection/eval/eval.py +41 -37
- deepdoctection/eval/tedsmetric.py +30 -23
- deepdoctection/eval/tp_eval_callback.py +16 -19
- deepdoctection/extern/__init__.py +2 -7
- deepdoctection/extern/base.py +339 -134
- deepdoctection/extern/d2detect.py +85 -113
- deepdoctection/extern/deskew.py +14 -11
- deepdoctection/extern/doctrocr.py +141 -130
- deepdoctection/extern/fastlang.py +27 -18
- deepdoctection/extern/hfdetr.py +71 -62
- deepdoctection/extern/hflayoutlm.py +504 -211
- deepdoctection/extern/hflm.py +230 -0
- deepdoctection/extern/model.py +488 -302
- deepdoctection/extern/pdftext.py +23 -19
- deepdoctection/extern/pt/__init__.py +1 -3
- deepdoctection/extern/pt/nms.py +6 -2
- deepdoctection/extern/pt/ptutils.py +29 -19
- deepdoctection/extern/tessocr.py +39 -38
- deepdoctection/extern/texocr.py +18 -18
- deepdoctection/extern/tp/tfutils.py +57 -9
- deepdoctection/extern/tp/tpcompat.py +21 -14
- deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
- deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/config/config.py +13 -10
- deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +18 -8
- deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +14 -9
- deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
- deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +22 -17
- deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +21 -14
- deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +19 -11
- deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
- deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
- deepdoctection/extern/tp/tpfrcnn/preproc.py +12 -8
- deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
- deepdoctection/extern/tpdetect.py +45 -53
- deepdoctection/mapper/__init__.py +3 -8
- deepdoctection/mapper/cats.py +27 -29
- deepdoctection/mapper/cocostruct.py +10 -10
- deepdoctection/mapper/d2struct.py +27 -26
- deepdoctection/mapper/hfstruct.py +13 -8
- deepdoctection/mapper/laylmstruct.py +178 -37
- deepdoctection/mapper/maputils.py +12 -11
- deepdoctection/mapper/match.py +2 -2
- deepdoctection/mapper/misc.py +11 -9
- deepdoctection/mapper/pascalstruct.py +4 -4
- deepdoctection/mapper/prodigystruct.py +5 -5
- deepdoctection/mapper/pubstruct.py +84 -92
- deepdoctection/mapper/tpstruct.py +5 -5
- deepdoctection/mapper/xfundstruct.py +33 -33
- deepdoctection/pipe/__init__.py +1 -1
- deepdoctection/pipe/anngen.py +12 -14
- deepdoctection/pipe/base.py +52 -106
- deepdoctection/pipe/common.py +72 -59
- deepdoctection/pipe/concurrency.py +16 -11
- deepdoctection/pipe/doctectionpipe.py +24 -21
- deepdoctection/pipe/language.py +20 -25
- deepdoctection/pipe/layout.py +20 -16
- deepdoctection/pipe/lm.py +75 -105
- deepdoctection/pipe/order.py +194 -89
- deepdoctection/pipe/refine.py +111 -124
- deepdoctection/pipe/segment.py +156 -161
- deepdoctection/pipe/{cell.py → sub_layout.py} +50 -40
- deepdoctection/pipe/text.py +37 -36
- deepdoctection/pipe/transform.py +19 -16
- deepdoctection/train/__init__.py +6 -12
- deepdoctection/train/d2_frcnn_train.py +48 -41
- deepdoctection/train/hf_detr_train.py +41 -30
- deepdoctection/train/hf_layoutlm_train.py +153 -135
- deepdoctection/train/tp_frcnn_train.py +32 -31
- deepdoctection/utils/concurrency.py +1 -1
- deepdoctection/utils/context.py +13 -6
- deepdoctection/utils/develop.py +4 -4
- deepdoctection/utils/env_info.py +87 -125
- deepdoctection/utils/file_utils.py +6 -11
- deepdoctection/utils/fs.py +22 -18
- deepdoctection/utils/identifier.py +2 -2
- deepdoctection/utils/logger.py +16 -15
- deepdoctection/utils/metacfg.py +7 -7
- deepdoctection/utils/mocks.py +93 -0
- deepdoctection/utils/pdf_utils.py +11 -11
- deepdoctection/utils/settings.py +185 -181
- deepdoctection/utils/tqdm.py +1 -1
- deepdoctection/utils/transform.py +14 -9
- deepdoctection/utils/types.py +104 -0
- deepdoctection/utils/utils.py +7 -7
- deepdoctection/utils/viz.py +74 -72
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/METADATA +30 -21
- deepdoctection-0.33.dist-info/RECORD +146 -0
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/WHEEL +1 -1
- deepdoctection/utils/detection_types.py +0 -68
- deepdoctection-0.31.dist-info/RECORD +0 -144
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/LICENSE +0 -0
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/top_level.txt +0 -0
deepdoctection/eval/eval.py
CHANGED
@@ -19,36 +19,35 @@
 """
 Module for `Evaluator`
 """
-
-__all__ = ["Evaluator"]
+from __future__ import annotations

 from copy import deepcopy
-from typing import Any,
+from typing import Any, Generator, Literal, Mapping, Optional, Type, Union, overload

 import numpy as np
+from lazy_imports import try_import

 from ..dataflow import CacheData, DataFlow, DataFromList, MapData
 from ..datapoint.image import Image
 from ..datasets.base import DatasetBase
 from ..mapper.cats import filter_cat, remove_cats
+from ..mapper.d2struct import to_wandb_image
 from ..mapper.misc import maybe_load_image, maybe_remove_image, maybe_remove_image_from_category
-from ..pipe.base import
+from ..pipe.base import PipelineComponent
 from ..pipe.common import PageParsingService
 from ..pipe.concurrency import MultiThreadPipelineComponent
 from ..pipe.doctectionpipe import DoctectionPipe
-from ..utils.detection_types import ImageType
-from ..utils.file_utils import detectron2_available, wandb_available
 from ..utils.logger import LoggingRecord, logger
 from ..utils.settings import DatasetType, LayoutType, TypeOrStr, get_type
+from ..utils.types import PixelValues
 from ..utils.viz import interactive_imshow
 from .base import MetricBase

-
+with try_import() as wb_import_guard:
     import wandb  # pylint:disable=W0611
     from wandb import Artifact, Table

-
-from ..mapper.d2struct import to_wandb_image
+__all__ = ["Evaluator"]


 class Evaluator:

@@ -91,10 +90,10 @@ class Evaluator:
     def __init__(
         self,
         dataset: DatasetBase,
-        component_or_pipeline: Union[
+        component_or_pipeline: Union[PipelineComponent, DoctectionPipe],
         metric: Union[Type[MetricBase], MetricBase],
         num_threads: int = 2,
-        run: Optional[
+        run: Optional[wandb.sdk.wandb_run.Run] = None,
     ) -> None:
         """
         Evaluating a pipeline component on a dataset with a given metric.

@@ -109,14 +108,14 @@ class Evaluator:
         self.pipe: Optional[DoctectionPipe] = None

         # when passing a component, we will process prediction on num_threads
-        if isinstance(component_or_pipeline,
+        if isinstance(component_or_pipeline, PipelineComponent):
             logger.info(
                 LoggingRecord(
                     f"Building multi threading pipeline component to increase prediction throughput. "
                     f"Using {num_threads} threads"
                 )
             )
-            pipeline_components:
+            pipeline_components: list[PipelineComponent] = []

             for _ in range(num_threads - 1):
                 copy_pipe_component = component_or_pipeline.clone()

@@ -140,14 +139,14 @@ class Evaluator:

         self.wandb_table_agent: Optional[WandbTableAgent]
         if run is not None:
-            if self.dataset.dataset_info.type == DatasetType.
+            if self.dataset.dataset_info.type == DatasetType.OBJECT_DETECTION:
                 self.wandb_table_agent = WandbTableAgent(
                     run,
                     self.dataset.dataset_info.name,
                     50,
                     self.dataset.dataflow.categories.get_categories(filtered=True),
                 )
-            elif self.dataset.dataset_info.type == DatasetType.
+            elif self.dataset.dataset_info.type == DatasetType.TOKEN_CLASSIFICATION:
                 if hasattr(self.metric, "sub_cats"):
                     sub_cat_key, sub_cat_val_list = list(self.metric.sub_cats.items())[0]
                     sub_cat_val = sub_cat_val_list[0]

@@ -179,16 +178,16 @@ class Evaluator:
     @overload
     def run(
         self, output_as_dict: Literal[False] = False, **dataflow_build_kwargs: Union[str, int]
-    ) ->
+    ) -> list[dict[str, float]]:
         ...

     @overload
-    def run(self, output_as_dict: Literal[True], **dataflow_build_kwargs: Union[str, int]) ->
+    def run(self, output_as_dict: Literal[True], **dataflow_build_kwargs: Union[str, int]) -> dict[str, float]:
         ...

     def run(
         self, output_as_dict: bool = False, **dataflow_build_kwargs: Union[str, int]
-    ) -> Union[
+    ) -> Union[list[dict[str, float]], dict[str, float]]:
         """
         Start evaluation process and return the results.

@@ -247,11 +246,11 @@ class Evaluator:
         possible_cats_in_datapoint = self.dataset.dataflow.categories.get_categories(as_dict=False, filtered=True)

         # clean-up procedure depends on the dataset type
-        if self.dataset.dataset_info.type == DatasetType.
+        if self.dataset.dataset_info.type == DatasetType.OBJECT_DETECTION:
             # we keep all image annotations that will not be generated through processing
-            anns_to_keep = {ann for ann in possible_cats_in_datapoint if ann not in meta_anns
-            sub_cats_to_remove = meta_anns
-            relationships_to_remove = meta_anns
+            anns_to_keep = {ann for ann in possible_cats_in_datapoint if ann not in meta_anns.image_annotations}
+            sub_cats_to_remove = meta_anns.sub_categories
+            relationships_to_remove = meta_anns.relationships
             # removing annotations takes place in three steps: First we remove all image annotations. Then, with all
             # remaining image annotations we check, if the image attribute (with Image instance !) is not empty and
             # remove it as well, if necessary. In the last step we remove all sub categories and relationships, if

@@ -263,19 +262,19 @@ class Evaluator:
                 remove_cats(sub_categories=sub_cats_to_remove, relationships=relationships_to_remove),
             )

-        elif self.dataset.dataset_info.type == DatasetType.
-            summary_sub_cats_to_remove = meta_anns
+        elif self.dataset.dataset_info.type == DatasetType.SEQUENCE_CLASSIFICATION:
+            summary_sub_cats_to_remove = meta_anns.summaries
             df_pr = MapData(df_pr, remove_cats(summary_sub_categories=summary_sub_cats_to_remove))

-        elif self.dataset.dataset_info.type == DatasetType.
-            sub_cats_to_remove = meta_anns
+        elif self.dataset.dataset_info.type == DatasetType.TOKEN_CLASSIFICATION:
+            sub_cats_to_remove = meta_anns.sub_categories
             df_pr = MapData(df_pr, remove_cats(sub_categories=sub_cats_to_remove))
         else:
             raise NotImplementedError()

         return df_pr

-    def compare(self, interactive: bool = False, **kwargs: Union[str, int]) ->
+    def compare(self, interactive: bool = False, **kwargs: Union[str, int]) -> Generator[PixelValues, None, None]:
         """
         Visualize ground truth and prediction datapoint. Given a dataflow config it will run predictions per sample
         and concat the prediction image (with predicted bounding boxes) with ground truth image.

@@ -292,6 +291,8 @@ class Evaluator:
         show_layouts = kwargs.pop("show_layouts", True)
         show_table_structure = kwargs.pop("show_table_structure", True)
         show_words = kwargs.pop("show_words", False)
+        show_token_class = kwargs.pop("show_token_class", True)
+        ignore_default_token_class = kwargs.pop("ignore_default_token_class", False)

         df_gt = self.dataset.dataflow.build(**kwargs)
         df_pr = self.dataset.dataflow.build(**kwargs)

@@ -300,7 +301,7 @@ class Evaluator:
         df_pr = MapData(df_pr, deepcopy)
         df_pr = self._clean_up_predict_dataflow_annotations(df_pr)

-        page_parsing_component = PageParsingService(text_container=LayoutType.
+        page_parsing_component = PageParsingService(text_container=LayoutType.WORD)
         df_gt = page_parsing_component.predict_dataflow(df_gt)

         if self.pipe_component:

@@ -321,18 +322,21 @@ class Evaluator:
                 show_layouts=show_layouts,
                 show_table_structure=show_table_structure,
                 show_words=show_words,
+                show_token_class=show_token_class,
+                ignore_default_token_class=ignore_default_token_class,
             ), dp_pred.viz(
                 show_tables=show_tables,
                 show_layouts=show_layouts,
                 show_table_structure=show_table_structure,
                 show_words=show_words,
+                show_token_class=show_token_class,
+                ignore_default_token_class=ignore_default_token_class,
             )
             img_concat = np.concatenate((img_gt, img_pred), axis=1)
             if interactive:
                 interactive_imshow(img_concat)
             else:
-
-                return None
+                yield img_concat


 class WandbTableAgent:

@@ -350,11 +354,11 @@ class WandbTableAgent:

     def __init__(
         self,
-        wandb_run:
+        wandb_run: wandb.sdk.wandb_run.Run,
         dataset_name: str,
         num_samples: int,
-        categories: Mapping[
-        sub_categories: Optional[Mapping[
+        categories: Mapping[int, TypeOrStr],
+        sub_categories: Optional[Mapping[int, TypeOrStr]] = None,
         cat_to_sub_cat: Optional[Mapping[TypeOrStr, TypeOrStr]] = None,
     ):
         """

@@ -381,8 +385,8 @@ class WandbTableAgent:
         self._counter = 0

         # Table logging utils
-        self._table_cols:
-        self._table_rows:
+        self._table_cols: list[str] = ["file_name", "image"]
+        self._table_rows: list[Any] = []
         self._table_ref = None

     def dump(self, dp: Image) -> Image:

@@ -409,7 +413,7 @@
         self._table_rows = []
         self._counter = 0

-    def _build_table(self) ->
+    def _build_table(self) -> Table:
         """
         Builds wandb.Table object for logging evaluation

@@ -435,4 +439,4 @@
         eval_art.add(self._build_table(), self.dataset_name)
         self._run.use_artifact(eval_art)
         eval_art.wait()
-        self._table_ref = eval_art.get(self.dataset_name).data  # type:ignore
+        self._table_ref = eval_art.get(self.dataset_name).data  # type: ignore
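The hunks above turn `Evaluator.run` into typed overloads (`list[dict[str, float]]`, or `dict[str, float]` when `output_as_dict=True`) and change `Evaluator.compare` from a `None`-returning method into a generator yielding concatenated ground-truth/prediction images. A minimal usage sketch against the new signatures, not part of the diff; `dataset`, `component` and `metric` are placeholders for any `DatasetBase`, `PipelineComponent` and `MetricBase` objects, and `max_datapoints` is assumed to be a valid dataflow build argument:

from deepdoctection.eval import Evaluator

# `dataset`, `component`, `metric` stand in for objects created elsewhere (placeholders).
evaluator = Evaluator(dataset, component, metric, num_threads=2)

scores = evaluator.run(max_datapoints=100)             # list[dict[str, float]]
scores_dict = evaluator.run(True, max_datapoints=100)  # dict[str, float] via the new overload

# compare() now yields one PixelValues array (np.ndarray) per sample instead of returning None.
for idx, img_concat in enumerate(evaluator.compare(interactive=False, max_datapoints=5)):
    print(idx, img_concat.shape)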
deepdoctection/eval/tedsmetric.py
CHANGED

@@ -18,30 +18,34 @@ Tree distance similarity metric taken from <https://github.com/ibm-aur-nlp/PubTa

 import statistics
 from collections import defaultdict, deque
-from typing import Any,
+from typing import Any, Callable, Optional
+
+from lazy_imports import try_import

 from ..dataflow import DataFlow, DataFromList, MapData, MultiThreadMapData
+from ..datapoint.image import Image
 from ..datapoint.view import Page
 from ..datasets.base import DatasetCategories
-from ..utils.
-from ..utils.file_utils import (
-    Requirement,
-    apted_available,
-    distance_available,
-    get_apted_requirement,
-    get_distance_requirement,
-    get_lxml_requirement,
-    lxml_available,
-)
+from ..utils.file_utils import Requirement, get_apted_requirement, get_distance_requirement, get_lxml_requirement
 from ..utils.logger import LoggingRecord, logger
 from ..utils.settings import LayoutType
+from ..utils.types import MetricResults
 from .base import MetricBase
 from .registry import metric_registry

-
-import distance  # type: ignore
+with try_import() as ap_import_guard:
     from apted import APTED, Config  # type: ignore
     from apted.helpers import Tree  # type: ignore
+
+
+if not ap_import_guard.is_successful():
+    from ..utils.mocks import Config, Tree
+
+
+with try_import() as ds_import_guard:
+    import distance  # type: ignore
+
+with try_import() as lx_import_guard:
     from lxml import etree


@@ -56,7 +60,7 @@ class TableTree(Tree):
         tag: str,
         colspan: Optional[int] = None,
         rowspan: Optional[int] = None,
-        content: Optional[
+        content: Optional[list[str]] = None,
     ) -> None:
         self.tag = tag
         self.colspan = colspan

@@ -104,7 +108,7 @@ class TEDS:

     def __init__(self, structure_only: bool = False):
         self.structure_only = structure_only
-        self.__tokens__:
+        self.__tokens__: list[str] = []

     def tokenize(self, node: TableTree) -> None:
         """Tokenizes table cells"""

@@ -146,7 +150,7 @@ class TEDS:
             return new_node
         return None

-    def evaluate(self, inputs:
+    def evaluate(self, inputs: tuple[str, str]) -> float:
         """Computes TEDS score between the prediction and the ground truth of a
         given sample
         """

@@ -185,7 +189,7 @@ class TEDS:
         return 0.0


-def teds_metric(gt_list:
+def teds_metric(gt_list: list[str], predict_list: list[str], structure_only: bool) -> tuple[float, int]:
     """
     Computes tree edit distance score (TEDS) between the prediction and the ground truth of a batch of samples. The
     approach to measure similarity of tables by means of their html representation has been adovacated in

@@ -218,13 +222,16 @@ class TedsMetric(MetricBase):
     """

     metric = teds_metric  # type: ignore
-    mapper = Page.from_image
+    mapper: Callable[[Image, LayoutType, list[LayoutType]], Page] = Page.from_image
+    text_container: LayoutType = LayoutType.WORD
+    floating_text_block_categories = [LayoutType.TABLE]
+
     structure_only = False

     @classmethod
     def dump(
         cls, dataflow_gt: DataFlow, dataflow_predictions: DataFlow, categories: DatasetCategories
-    ) ->
+    ) -> tuple[list[str], list[str]]:
         dataflow_gt.reset_state()
         dataflow_predictions.reset_state()

@@ -232,11 +239,11 @@ class TedsMetric(MetricBase):
         gt_dict = defaultdict(list)
         pred_dict = defaultdict(list)
         for dp_gt, dp_pred in zip(dataflow_gt, dataflow_predictions):
-            page_gt = cls.mapper(dp_gt,
+            page_gt = cls.mapper(dp_gt, cls.text_container, cls.floating_text_block_categories)
             for table in page_gt.tables:
                 gt_dict[page_gt.image_id].append(table.html)

-            page_pred = cls.mapper(dp_pred,
+            page_pred = cls.mapper(dp_pred, cls.text_container, cls.floating_text_block_categories)
             for table in page_pred.tables:
                 pred_dict[page_pred.image_id].append(table.html)

@@ -251,12 +258,12 @@ class TedsMetric(MetricBase):
     @classmethod
     def get_distance(
         cls, dataflow_gt: DataFlow, dataflow_predictions: DataFlow, categories: DatasetCategories
-    ) ->
+    ) -> list[MetricResults]:
         html_gt_list, html_pr_list = cls.dump(dataflow_gt, dataflow_predictions, categories)

         score, num_samples = cls.metric(html_gt_list, html_pr_list, cls.structure_only)  # type: ignore
         return [{"teds_score": score, "num_samples": num_samples}]

     @classmethod
-    def get_requirements(cls) ->
+    def get_requirements(cls) -> list[Requirement]:
         return [get_apted_requirement(), get_distance_requirement(), get_lxml_requirement()]
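The import block now uses `lazy_imports.try_import` guards plus mock fallbacks from `utils.mocks` instead of module-level `*_available()` checks. A small sketch of the pattern as it appears in the diff, with a made-up optional dependency standing in for `apted`/`distance`/`lxml`:

from lazy_imports import try_import

with try_import() as import_guard:
    import some_optional_lib  # placeholder for apted, distance, lxml, ...

# Nothing is raised at import time; the guard only records whether the import worked.
if not import_guard.is_successful():
    some_optional_lib = None  # deepdoctection substitutes lightweight stand-ins from utils.mocks

def run_feature() -> None:
    # The missing dependency only surfaces when the feature is actually used,
    # e.g. when get_requirements() is checked before an evaluation run.
    if some_optional_lib is None:
        raise ModuleNotFoundError("some_optional_lib is required for this feature")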
deepdoctection/eval/tp_eval_callback.py
CHANGED

@@ -19,13 +19,15 @@
 Module for EvalCallback in Tensorpack
 """

+from __future__ import annotations
+
 from itertools import count
 from typing import Mapping, Optional, Sequence, Type, Union

+from lazy_imports import try_import
+
 from ..datasets import DatasetBase
-from ..
-from ..pipe.base import PredictorPipelineComponent
-from ..utils.file_utils import tensorpack_available
+from ..pipe.base import PipelineComponent
 from ..utils.logger import LoggingRecord, logger
 from ..utils.metacfg import AttrDict
 from ..utils.settings import ObjectTypes

@@ -33,12 +35,15 @@ from .base import MetricBase
 from .eval import Evaluator

 # pylint: disable=import-error
-
+with try_import() as import_guard:
     from tensorpack.callbacks import Callback
     from tensorpack.predict import OnlinePredictor
     from tensorpack.utils.gpu import get_num_gpu
 # pylint: enable=import-error

+if not import_guard.is_successful():
+    from ..utils.mocks import Callback
+

 # The following class is modified from
 # https://github.com/tensorpack/tensorpack/blob/master/examples/FasterRCNN/eval.py

@@ -53,15 +58,16 @@ class EvalCallback(Callback):  # pylint: disable=R0903

     _chief_only = False

-    def __init__(
+    def __init__(  # pylint: disable=W0231
         self,
         dataset: DatasetBase,
         category_names: Optional[Union[ObjectTypes, Sequence[ObjectTypes]]],
         sub_categories: Optional[Union[Mapping[ObjectTypes, ObjectTypes], Mapping[ObjectTypes, Sequence[ObjectTypes]]]],
         metric: Union[Type[MetricBase], MetricBase],
-        pipeline_component:
+        pipeline_component: PipelineComponent,
         in_names: str,
         out_names: str,
+        cfg: AttrDict,
         **build_eval_kwargs: str,
     ) -> None:
         """

@@ -83,12 +89,7 @@ class EvalCallback(Callback):  # pylint: disable=R0903
         self.num_gpu = get_num_gpu()
         self.category_names = category_names
         self.sub_categories = sub_categories
-
-            raise TypeError(
-                f"pipeline_component.predictor must be of type TPFrcnnDetector but is "
-                f"type {type(pipeline_component.predictor)}"
-            )
-        self.cfg = pipeline_component.predictor.model.cfg
+        self.cfg = cfg
         if _use_replicated(self.cfg):
             self.evaluator = Evaluator(dataset, pipeline_component, metric, num_threads=self.num_gpu * 2)
         else:

@@ -99,13 +100,9 @@ class EvalCallback(Callback):  # pylint: disable=R0903
         if self.evaluator.pipe_component is None:
             raise TypeError("self.evaluator.pipe_component cannot be None")
         for idx, comp in enumerate(self.evaluator.pipe_component.pipe_components):
-            if
-
-
-                raise TypeError(
-                    f"comp.predictor mus be of type TPFrcnnDetector but is of type {type(comp.predictor)}"
-                )
-            comp.predictor.tp_predictor = self._build_predictor(idx % self.num_gpu)
+            if hasattr(comp, "predictor"):
+                if hasattr(comp.predictor, "tp_predictor"):
+                    comp.predictor.tp_predictor = self._build_predictor(idx % self.num_gpu)

     def _build_predictor(self, idx: int) -> OnlinePredictor:
         return self.trainer.get_predictor(self.in_names, self.out_names, device=idx)
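`EvalCallback` now receives the Tensorpack config explicitly through the new `cfg: AttrDict` parameter instead of reading `pipeline_component.predictor.model.cfg`, and it accepts any `PipelineComponent`, swapping predictors by duck typing (`hasattr`) rather than enforcing `TPFrcnnDetector`. A hedged construction sketch; every argument value below is a placeholder from a hypothetical Tensorpack training script and is not defined in this diff:

callback = EvalCallback(
    dataset=my_dataset,                      # DatasetBase (placeholder)
    category_names=None,
    sub_categories=None,
    metric=my_metric,                        # MetricBase class or instance (placeholder)
    pipeline_component=my_layout_component,  # any PipelineComponent now allowed
    in_names="image",                        # placeholder input tensor name
    out_names="output/boxes",                # placeholder output tensor name
    cfg=my_frcnn_config,                     # AttrDict, previously derived from the predictor
    max_datapoints="500",                    # forwarded to dataflow.build via **build_eval_kwargs
)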
deepdoctection/extern/__init__.py
CHANGED

@@ -19,8 +19,8 @@
 Wrappers for models of external libraries as well as implementation of the Cascade-RCNN model of Tensorpack.
 """

-from ..utils.file_utils import detectron2_available, tensorpack_available
 from .base import *
+from .d2detect import *
 from .deskew import *
 from .doctrocr import *
 from .fastlang import *

@@ -30,9 +30,4 @@ from .model import *
 from .pdftext import *
 from .tessocr import *
 from .texocr import *  # type: ignore
-
-if tensorpack_available():
-    from .tpdetect import *
-
-if detectron2_available():
-    from .d2detect import *
+from .tpdetect import *