deepdoctection 0.32__py3-none-any.whl → 0.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of deepdoctection might be problematic.
- deepdoctection/__init__.py +8 -25
- deepdoctection/analyzer/dd.py +84 -71
- deepdoctection/dataflow/common.py +9 -5
- deepdoctection/dataflow/custom.py +5 -5
- deepdoctection/dataflow/custom_serialize.py +75 -18
- deepdoctection/dataflow/parallel_map.py +3 -3
- deepdoctection/dataflow/serialize.py +4 -4
- deepdoctection/dataflow/stats.py +3 -3
- deepdoctection/datapoint/annotation.py +78 -56
- deepdoctection/datapoint/box.py +7 -7
- deepdoctection/datapoint/convert.py +6 -6
- deepdoctection/datapoint/image.py +157 -75
- deepdoctection/datapoint/view.py +175 -151
- deepdoctection/datasets/adapter.py +30 -24
- deepdoctection/datasets/base.py +10 -10
- deepdoctection/datasets/dataflow_builder.py +3 -3
- deepdoctection/datasets/info.py +23 -25
- deepdoctection/datasets/instances/doclaynet.py +48 -49
- deepdoctection/datasets/instances/fintabnet.py +44 -45
- deepdoctection/datasets/instances/funsd.py +23 -23
- deepdoctection/datasets/instances/iiitar13k.py +8 -8
- deepdoctection/datasets/instances/layouttest.py +2 -2
- deepdoctection/datasets/instances/publaynet.py +3 -3
- deepdoctection/datasets/instances/pubtables1m.py +18 -18
- deepdoctection/datasets/instances/pubtabnet.py +30 -29
- deepdoctection/datasets/instances/rvlcdip.py +28 -29
- deepdoctection/datasets/instances/xfund.py +51 -30
- deepdoctection/datasets/save.py +6 -6
- deepdoctection/eval/accmetric.py +32 -33
- deepdoctection/eval/base.py +8 -9
- deepdoctection/eval/cocometric.py +13 -12
- deepdoctection/eval/eval.py +32 -26
- deepdoctection/eval/tedsmetric.py +16 -12
- deepdoctection/eval/tp_eval_callback.py +7 -16
- deepdoctection/extern/base.py +339 -134
- deepdoctection/extern/d2detect.py +69 -89
- deepdoctection/extern/deskew.py +11 -10
- deepdoctection/extern/doctrocr.py +81 -64
- deepdoctection/extern/fastlang.py +23 -16
- deepdoctection/extern/hfdetr.py +53 -38
- deepdoctection/extern/hflayoutlm.py +216 -155
- deepdoctection/extern/hflm.py +35 -30
- deepdoctection/extern/model.py +433 -255
- deepdoctection/extern/pdftext.py +15 -15
- deepdoctection/extern/pt/ptutils.py +4 -2
- deepdoctection/extern/tessocr.py +39 -38
- deepdoctection/extern/texocr.py +14 -16
- deepdoctection/extern/tp/tfutils.py +16 -2
- deepdoctection/extern/tp/tpcompat.py +11 -7
- deepdoctection/extern/tp/tpfrcnn/config/config.py +4 -4
- deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +1 -1
- deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +5 -5
- deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +6 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +4 -4
- deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +5 -3
- deepdoctection/extern/tp/tpfrcnn/preproc.py +5 -5
- deepdoctection/extern/tpdetect.py +40 -45
- deepdoctection/mapper/cats.py +36 -40
- deepdoctection/mapper/cocostruct.py +16 -12
- deepdoctection/mapper/d2struct.py +22 -22
- deepdoctection/mapper/hfstruct.py +7 -7
- deepdoctection/mapper/laylmstruct.py +22 -24
- deepdoctection/mapper/maputils.py +9 -10
- deepdoctection/mapper/match.py +33 -2
- deepdoctection/mapper/misc.py +6 -7
- deepdoctection/mapper/pascalstruct.py +4 -4
- deepdoctection/mapper/prodigystruct.py +6 -6
- deepdoctection/mapper/pubstruct.py +84 -92
- deepdoctection/mapper/tpstruct.py +3 -3
- deepdoctection/mapper/xfundstruct.py +33 -33
- deepdoctection/pipe/anngen.py +39 -14
- deepdoctection/pipe/base.py +68 -99
- deepdoctection/pipe/common.py +181 -85
- deepdoctection/pipe/concurrency.py +14 -10
- deepdoctection/pipe/doctectionpipe.py +24 -21
- deepdoctection/pipe/language.py +20 -25
- deepdoctection/pipe/layout.py +18 -16
- deepdoctection/pipe/lm.py +49 -47
- deepdoctection/pipe/order.py +63 -65
- deepdoctection/pipe/refine.py +102 -109
- deepdoctection/pipe/segment.py +157 -162
- deepdoctection/pipe/sub_layout.py +50 -40
- deepdoctection/pipe/text.py +37 -36
- deepdoctection/pipe/transform.py +19 -16
- deepdoctection/train/d2_frcnn_train.py +27 -25
- deepdoctection/train/hf_detr_train.py +22 -18
- deepdoctection/train/hf_layoutlm_train.py +49 -48
- deepdoctection/train/tp_frcnn_train.py +10 -11
- deepdoctection/utils/concurrency.py +1 -1
- deepdoctection/utils/context.py +13 -6
- deepdoctection/utils/develop.py +4 -4
- deepdoctection/utils/env_info.py +52 -14
- deepdoctection/utils/file_utils.py +6 -11
- deepdoctection/utils/fs.py +41 -14
- deepdoctection/utils/identifier.py +2 -2
- deepdoctection/utils/logger.py +15 -15
- deepdoctection/utils/metacfg.py +7 -7
- deepdoctection/utils/pdf_utils.py +39 -14
- deepdoctection/utils/settings.py +188 -182
- deepdoctection/utils/tqdm.py +1 -1
- deepdoctection/utils/transform.py +14 -9
- deepdoctection/utils/types.py +104 -0
- deepdoctection/utils/utils.py +7 -7
- deepdoctection/utils/viz.py +70 -69
- {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/METADATA +7 -4
- deepdoctection-0.34.dist-info/RECORD +146 -0
- {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/WHEEL +1 -1
- deepdoctection/utils/detection_types.py +0 -68
- deepdoctection-0.32.dist-info/RECORD +0 -146
- {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/LICENSE +0 -0
- {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/top_level.txt +0 -0
deepdoctection/eval/eval.py
CHANGED
@@ -22,7 +22,7 @@ Module for `Evaluator`
 from __future__ import annotations
 
 from copy import deepcopy
-from typing import Any,
+from typing import Any, Generator, Literal, Mapping, Optional, Type, Union, overload
 
 import numpy as np
 from lazy_imports import try_import

@@ -33,13 +33,13 @@ from ..datasets.base import DatasetBase
 from ..mapper.cats import filter_cat, remove_cats
 from ..mapper.d2struct import to_wandb_image
 from ..mapper.misc import maybe_load_image, maybe_remove_image, maybe_remove_image_from_category
-from ..pipe.base import
+from ..pipe.base import PipelineComponent
 from ..pipe.common import PageParsingService
 from ..pipe.concurrency import MultiThreadPipelineComponent
 from ..pipe.doctectionpipe import DoctectionPipe
-from ..utils.detection_types import ImageType
 from ..utils.logger import LoggingRecord, logger
 from ..utils.settings import DatasetType, LayoutType, TypeOrStr, get_type
+from ..utils.types import PixelValues
 from ..utils.viz import interactive_imshow
 from .base import MetricBase
 
@@ -90,7 +90,7 @@ class Evaluator:
     def __init__(
         self,
         dataset: DatasetBase,
-        component_or_pipeline: Union[
+        component_or_pipeline: Union[PipelineComponent, DoctectionPipe],
         metric: Union[Type[MetricBase], MetricBase],
         num_threads: int = 2,
         run: Optional[wandb.sdk.wandb_run.Run] = None,

@@ -108,14 +108,14 @@ class Evaluator:
         self.pipe: Optional[DoctectionPipe] = None
 
         # when passing a component, we will process prediction on num_threads
-        if isinstance(component_or_pipeline,
+        if isinstance(component_or_pipeline, PipelineComponent):
             logger.info(
                 LoggingRecord(
                     f"Building multi threading pipeline component to increase prediction throughput. "
                     f"Using {num_threads} threads"
                 )
             )
-            pipeline_components:
+            pipeline_components: list[PipelineComponent] = []
 
             for _ in range(num_threads - 1):
                 copy_pipe_component = component_or_pipeline.clone()
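The narrowed annotation `component_or_pipeline: Union[PipelineComponent, DoctectionPipe]` documents the two ways the evaluator can be fed: a single pipeline component (cloned internally across `num_threads` worker threads, as the hunk above shows) or a full `DoctectionPipe`. A minimal wiring sketch, assuming `layout_component` is an already configured `PipelineComponent` (placeholder name) and that the dataset and metric used here are available in the installed version; none of this is taken from the diff itself:

import deepdoctection as dd

dataset = dd.get_dataset("publaynet")   # assumption: any registered dataset with ground truth
metric = dd.CocoMetric                  # assumption: a MetricBase subclass or instance

# a PipelineComponent is wrapped in a MultiThreadPipelineComponent,
# a DoctectionPipe is used as-is
evaluator = dd.Evaluator(dataset, layout_component, metric, num_threads=2)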
@@ -139,14 +139,14 @@ class Evaluator:
 
         self.wandb_table_agent: Optional[WandbTableAgent]
         if run is not None:
-            if self.dataset.dataset_info.type == DatasetType.
+            if self.dataset.dataset_info.type == DatasetType.OBJECT_DETECTION:
                 self.wandb_table_agent = WandbTableAgent(
                     run,
                     self.dataset.dataset_info.name,
                     50,
                     self.dataset.dataflow.categories.get_categories(filtered=True),
                 )
-            elif self.dataset.dataset_info.type == DatasetType.
+            elif self.dataset.dataset_info.type == DatasetType.TOKEN_CLASSIFICATION:
                 if hasattr(self.metric, "sub_cats"):
                     sub_cat_key, sub_cat_val_list = list(self.metric.sub_cats.items())[0]
                     sub_cat_val = sub_cat_val_list[0]

@@ -178,16 +178,16 @@ class Evaluator:
     @overload
     def run(
         self, output_as_dict: Literal[False] = False, **dataflow_build_kwargs: Union[str, int]
-    ) ->
+    ) -> list[dict[str, float]]:
         ...
 
     @overload
-    def run(self, output_as_dict: Literal[True], **dataflow_build_kwargs: Union[str, int]) ->
+    def run(self, output_as_dict: Literal[True], **dataflow_build_kwargs: Union[str, int]) -> dict[str, float]:
         ...
 
     def run(
         self, output_as_dict: bool = False, **dataflow_build_kwargs: Union[str, int]
-    ) -> Union[
+    ) -> Union[list[dict[str, float]], dict[str, float]]:
         """
         Start evaluation process and return the results.
 
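Both overloads resolve to the same runtime implementation; `output_as_dict` only changes the shape of the result. A short usage sketch, continuing the hypothetical `evaluator` from above (`max_datapoints` is just one example of a dataflow build kwarg):

# default: a list of result dicts
results_per_entry = evaluator.run(max_datapoints=100)             # list[dict[str, float]]

# flattened variant
results = evaluator.run(output_as_dict=True, max_datapoints=100)  # dict[str, float]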
@@ -246,11 +246,11 @@ class Evaluator:
         possible_cats_in_datapoint = self.dataset.dataflow.categories.get_categories(as_dict=False, filtered=True)
 
         # clean-up procedure depends on the dataset type
-        if self.dataset.dataset_info.type == DatasetType.
+        if self.dataset.dataset_info.type == DatasetType.OBJECT_DETECTION:
             # we keep all image annotations that will not be generated through processing
-            anns_to_keep = {ann for ann in possible_cats_in_datapoint if ann not in meta_anns
-            sub_cats_to_remove = meta_anns
-            relationships_to_remove = meta_anns
+            anns_to_keep = {ann for ann in possible_cats_in_datapoint if ann not in meta_anns.image_annotations}
+            sub_cats_to_remove = meta_anns.sub_categories
+            relationships_to_remove = meta_anns.relationships
             # removing annotations takes place in three steps: First we remove all image annotations. Then, with all
             # remaining image annotations we check, if the image attribute (with Image instance !) is not empty and
             # remove it as well, if necessary. In the last step we remove all sub categories and relationships, if

@@ -262,19 +262,19 @@ class Evaluator:
                 remove_cats(sub_categories=sub_cats_to_remove, relationships=relationships_to_remove),
             )
 
-        elif self.dataset.dataset_info.type == DatasetType.
-            summary_sub_cats_to_remove = meta_anns
+        elif self.dataset.dataset_info.type == DatasetType.SEQUENCE_CLASSIFICATION:
+            summary_sub_cats_to_remove = meta_anns.summaries
             df_pr = MapData(df_pr, remove_cats(summary_sub_categories=summary_sub_cats_to_remove))
 
-        elif self.dataset.dataset_info.type == DatasetType.
-            sub_cats_to_remove = meta_anns
+        elif self.dataset.dataset_info.type == DatasetType.TOKEN_CLASSIFICATION:
+            sub_cats_to_remove = meta_anns.sub_categories
             df_pr = MapData(df_pr, remove_cats(sub_categories=sub_cats_to_remove))
         else:
             raise NotImplementedError()
 
         return df_pr
 
-    def compare(self, interactive: bool = False, **kwargs: Union[str, int]) -> Generator[
+    def compare(self, interactive: bool = False, **kwargs: Union[str, int]) -> Generator[PixelValues, None, None]:
         """
         Visualize ground truth and prediction datapoint. Given a dataflow config it will run predictions per sample
         and concat the prediction image (with predicted bounding boxes) with ground truth image.

@@ -293,6 +293,8 @@ class Evaluator:
         show_words = kwargs.pop("show_words", False)
         show_token_class = kwargs.pop("show_token_class", True)
         ignore_default_token_class = kwargs.pop("ignore_default_token_class", False)
+        floating_text_block_categories = kwargs.pop("floating_text_block_categories", None)
+        include_residual_text_containers = kwargs.pop("include_residual_Text_containers", True)
 
         df_gt = self.dataset.dataflow.build(**kwargs)
         df_pr = self.dataset.dataflow.build(**kwargs)

@@ -301,7 +303,11 @@ class Evaluator:
         df_pr = MapData(df_pr, deepcopy)
         df_pr = self._clean_up_predict_dataflow_annotations(df_pr)
 
-        page_parsing_component = PageParsingService(
+        page_parsing_component = PageParsingService(
+            text_container=LayoutType.WORD,
+            floating_text_block_categories=floating_text_block_categories,  # type: ignore
+            include_residual_text_container=bool(include_residual_text_containers),
+        )
         df_gt = page_parsing_component.predict_dataflow(df_gt)
 
         if self.pipe_component:
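`compare` is now typed as a generator of `PixelValues` arrays, and the two new kwargs are popped before the remaining kwargs reach the dataflow builder. A sketch under the same assumptions as above; note that `include_residual_Text_containers` is spelled with a capital "T", exactly as in the hunk above:

for canvas in evaluator.compare(
    interactive=False,
    max_datapoints=5,
    floating_text_block_categories=None,
    include_residual_Text_containers=True,
):
    ...  # `canvas` is a numpy array with ground truth and prediction rendered side by side

With `interactive=True` the images are presumably displayed via `interactive_imshow`, which is imported at the top of the module.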
@@ -357,8 +363,8 @@ class WandbTableAgent:
         wandb_run: wandb.sdk.wandb_run.Run,
         dataset_name: str,
         num_samples: int,
-        categories: Mapping[
-        sub_categories: Optional[Mapping[
+        categories: Mapping[int, TypeOrStr],
+        sub_categories: Optional[Mapping[int, TypeOrStr]] = None,
         cat_to_sub_cat: Optional[Mapping[TypeOrStr, TypeOrStr]] = None,
     ):
         """

@@ -385,8 +391,8 @@ class WandbTableAgent:
         self._counter = 0
 
         # Table logging utils
-        self._table_cols:
-        self._table_rows:
+        self._table_cols: list[str] = ["file_name", "image"]
+        self._table_rows: list[Any] = []
         self._table_ref = None
 
     def dump(self, dp: Image) -> Image:

@@ -439,4 +445,4 @@ class WandbTableAgent:
         eval_art.add(self._build_table(), self.dataset_name)
         self._run.use_artifact(eval_art)
         eval_art.wait()
-        self._table_ref = eval_art.get(self.dataset_name).data  # type:ignore
+        self._table_ref = eval_art.get(self.dataset_name).data  # type: ignore
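`WandbTableAgent` is instantiated by the evaluator itself when a W&B run is passed, so the usual entry point is the `run` argument of `Evaluator` rather than the agent class. A hedged sketch, assuming a working wandb login and the same placeholder objects as above:

import wandb
import deepdoctection as dd

wandb_run = wandb.init(project="dd-eval")   # arbitrary project name
evaluator = dd.Evaluator(dataset, layout_component, metric, num_threads=2, run=wandb_run)

# for object-detection datasets the evaluator builds a WandbTableAgent that logs
# up to 50 sample rows ("file_name", "image") to a W&B table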
deepdoctection/eval/tedsmetric.py
CHANGED

@@ -18,17 +18,18 @@ Tree distance similarity metric taken from <https://github.com/ibm-aur-nlp/PubTa
 
 import statistics
 from collections import defaultdict, deque
-from typing import Any,
+from typing import Any, Callable, Optional
 
 from lazy_imports import try_import
 
 from ..dataflow import DataFlow, DataFromList, MapData, MultiThreadMapData
+from ..datapoint.image import Image
 from ..datapoint.view import Page
 from ..datasets.base import DatasetCategories
-from ..utils.detection_types import JsonDict
 from ..utils.file_utils import Requirement, get_apted_requirement, get_distance_requirement, get_lxml_requirement
 from ..utils.logger import LoggingRecord, logger
 from ..utils.settings import LayoutType
+from ..utils.types import MetricResults
 from .base import MetricBase
 from .registry import metric_registry
 

@@ -59,7 +60,7 @@ class TableTree(Tree):
         tag: str,
         colspan: Optional[int] = None,
         rowspan: Optional[int] = None,
-        content: Optional[
+        content: Optional[list[str]] = None,
     ) -> None:
         self.tag = tag
         self.colspan = colspan

@@ -107,7 +108,7 @@ class TEDS:
 
     def __init__(self, structure_only: bool = False):
         self.structure_only = structure_only
-        self.__tokens__:
+        self.__tokens__: list[str] = []
 
     def tokenize(self, node: TableTree) -> None:
         """Tokenizes table cells"""

@@ -149,7 +150,7 @@ class TEDS:
             return new_node
         return None
 
-    def evaluate(self, inputs:
+    def evaluate(self, inputs: tuple[str, str]) -> float:
         """Computes TEDS score between the prediction and the ground truth of a
         given sample
         """

@@ -188,7 +189,7 @@ class TEDS:
         return 0.0
 
 
-def teds_metric(gt_list:
+def teds_metric(gt_list: list[str], predict_list: list[str], structure_only: bool) -> tuple[float, int]:
     """
     Computes tree edit distance score (TEDS) between the prediction and the ground truth of a batch of samples. The
     approach to measure similarity of tables by means of their html representation has been adovacated in
@@ -221,13 +222,16 @@ class TedsMetric(MetricBase):
     """
 
     metric = teds_metric  # type: ignore
-    mapper = Page.from_image
+    mapper: Callable[[Image, LayoutType, list[LayoutType]], Page] = Page.from_image
+    text_container: LayoutType = LayoutType.WORD
+    floating_text_block_categories = [LayoutType.TABLE]
+
     structure_only = False
 
     @classmethod
     def dump(
         cls, dataflow_gt: DataFlow, dataflow_predictions: DataFlow, categories: DatasetCategories
-    ) ->
+    ) -> tuple[list[str], list[str]]:
         dataflow_gt.reset_state()
         dataflow_predictions.reset_state()
 

@@ -235,11 +239,11 @@ class TedsMetric(MetricBase):
         gt_dict = defaultdict(list)
         pred_dict = defaultdict(list)
         for dp_gt, dp_pred in zip(dataflow_gt, dataflow_predictions):
-            page_gt = cls.mapper(dp_gt,
+            page_gt = cls.mapper(dp_gt, cls.text_container, cls.floating_text_block_categories)
             for table in page_gt.tables:
                 gt_dict[page_gt.image_id].append(table.html)
 
-            page_pred = cls.mapper(dp_pred,
+            page_pred = cls.mapper(dp_pred, cls.text_container, cls.floating_text_block_categories)
             for table in page_pred.tables:
                 pred_dict[page_pred.image_id].append(table.html)
 

@@ -254,12 +258,12 @@ class TedsMetric(MetricBase):
     @classmethod
     def get_distance(
         cls, dataflow_gt: DataFlow, dataflow_predictions: DataFlow, categories: DatasetCategories
-    ) ->
+    ) -> list[MetricResults]:
         html_gt_list, html_pr_list = cls.dump(dataflow_gt, dataflow_predictions, categories)
 
         score, num_samples = cls.metric(html_gt_list, html_pr_list, cls.structure_only)  # type: ignore
         return [{"teds_score": score, "num_samples": num_samples}]
 
     @classmethod
-    def get_requirements(cls) ->
+    def get_requirements(cls) -> list[Requirement]:
         return [get_apted_requirement(), get_distance_requirement(), get_lxml_requirement()]
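`TedsMetric` is a regular `MetricBase` subclass, so it plugs straight into the `Evaluator` from `eval.py` above. A hedged sketch, assuming `table_pipeline` is a table-recognition pipeline or component (placeholder) and a dataset whose ground truth contains tables; the metric additionally requires the apted, distance and lxml packages (see `get_requirements`):

import deepdoctection as dd
from deepdoctection.eval.tedsmetric import TedsMetric

evaluator = dd.Evaluator(dataset, table_pipeline, TedsMetric, num_threads=2)
results = evaluator.run(output_as_dict=True, max_datapoints=50)
# get_distance returns [{"teds_score": ..., "num_samples": ...}] (see the hunk above)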
deepdoctection/eval/tp_eval_callback.py
CHANGED

@@ -27,8 +27,7 @@ from typing import Mapping, Optional, Sequence, Type, Union
 from lazy_imports import try_import
 
 from ..datasets import DatasetBase
-from ..
-from ..pipe.base import PredictorPipelineComponent
+from ..pipe.base import PipelineComponent
 from ..utils.logger import LoggingRecord, logger
 from ..utils.metacfg import AttrDict
 from ..utils.settings import ObjectTypes

@@ -65,9 +64,10 @@ class EvalCallback(Callback): # pylint: disable=R0903
         category_names: Optional[Union[ObjectTypes, Sequence[ObjectTypes]]],
         sub_categories: Optional[Union[Mapping[ObjectTypes, ObjectTypes], Mapping[ObjectTypes, Sequence[ObjectTypes]]]],
         metric: Union[Type[MetricBase], MetricBase],
-        pipeline_component:
+        pipeline_component: PipelineComponent,
         in_names: str,
         out_names: str,
+        cfg: AttrDict,
         **build_eval_kwargs: str,
     ) -> None:
         """

@@ -89,12 +89,7 @@ class EvalCallback(Callback): # pylint: disable=R0903
         self.num_gpu = get_num_gpu()
         self.category_names = category_names
         self.sub_categories = sub_categories
-
-            raise TypeError(
-                f"pipeline_component.predictor must be of type TPFrcnnDetector but is "
-                f"type {type(pipeline_component.predictor)}"
-            )
-        self.cfg = pipeline_component.predictor.model.cfg
+        self.cfg = cfg
         if _use_replicated(self.cfg):
             self.evaluator = Evaluator(dataset, pipeline_component, metric, num_threads=self.num_gpu * 2)
         else:

@@ -105,13 +100,9 @@ class EvalCallback(Callback): # pylint: disable=R0903
         if self.evaluator.pipe_component is None:
             raise TypeError("self.evaluator.pipe_component cannot be None")
         for idx, comp in enumerate(self.evaluator.pipe_component.pipe_components):
-            if
-
-
-                raise TypeError(
-                    f"comp.predictor mus be of type TPFrcnnDetector but is of type {type(comp.predictor)}"
-                )
-            comp.predictor.tp_predictor = self._build_predictor(idx % self.num_gpu)
+            if hasattr(comp, "predictor"):
+                if hasattr(comp.predictor, "tp_predictor"):
+                    comp.predictor.tp_predictor = self._build_predictor(idx % self.num_gpu)
 
     def _build_predictor(self, idx: int) -> OnlinePredictor:
         return self.trainer.get_predictor(self.in_names, self.out_names, device=idx)