deepdoctection 0.32__py3-none-any.whl → 0.34__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.

Potentially problematic release: this version of deepdoctection has been flagged as possibly problematic. See the registry's advisory page for details.

Files changed (111):
  1. deepdoctection/__init__.py +8 -25
  2. deepdoctection/analyzer/dd.py +84 -71
  3. deepdoctection/dataflow/common.py +9 -5
  4. deepdoctection/dataflow/custom.py +5 -5
  5. deepdoctection/dataflow/custom_serialize.py +75 -18
  6. deepdoctection/dataflow/parallel_map.py +3 -3
  7. deepdoctection/dataflow/serialize.py +4 -4
  8. deepdoctection/dataflow/stats.py +3 -3
  9. deepdoctection/datapoint/annotation.py +78 -56
  10. deepdoctection/datapoint/box.py +7 -7
  11. deepdoctection/datapoint/convert.py +6 -6
  12. deepdoctection/datapoint/image.py +157 -75
  13. deepdoctection/datapoint/view.py +175 -151
  14. deepdoctection/datasets/adapter.py +30 -24
  15. deepdoctection/datasets/base.py +10 -10
  16. deepdoctection/datasets/dataflow_builder.py +3 -3
  17. deepdoctection/datasets/info.py +23 -25
  18. deepdoctection/datasets/instances/doclaynet.py +48 -49
  19. deepdoctection/datasets/instances/fintabnet.py +44 -45
  20. deepdoctection/datasets/instances/funsd.py +23 -23
  21. deepdoctection/datasets/instances/iiitar13k.py +8 -8
  22. deepdoctection/datasets/instances/layouttest.py +2 -2
  23. deepdoctection/datasets/instances/publaynet.py +3 -3
  24. deepdoctection/datasets/instances/pubtables1m.py +18 -18
  25. deepdoctection/datasets/instances/pubtabnet.py +30 -29
  26. deepdoctection/datasets/instances/rvlcdip.py +28 -29
  27. deepdoctection/datasets/instances/xfund.py +51 -30
  28. deepdoctection/datasets/save.py +6 -6
  29. deepdoctection/eval/accmetric.py +32 -33
  30. deepdoctection/eval/base.py +8 -9
  31. deepdoctection/eval/cocometric.py +13 -12
  32. deepdoctection/eval/eval.py +32 -26
  33. deepdoctection/eval/tedsmetric.py +16 -12
  34. deepdoctection/eval/tp_eval_callback.py +7 -16
  35. deepdoctection/extern/base.py +339 -134
  36. deepdoctection/extern/d2detect.py +69 -89
  37. deepdoctection/extern/deskew.py +11 -10
  38. deepdoctection/extern/doctrocr.py +81 -64
  39. deepdoctection/extern/fastlang.py +23 -16
  40. deepdoctection/extern/hfdetr.py +53 -38
  41. deepdoctection/extern/hflayoutlm.py +216 -155
  42. deepdoctection/extern/hflm.py +35 -30
  43. deepdoctection/extern/model.py +433 -255
  44. deepdoctection/extern/pdftext.py +15 -15
  45. deepdoctection/extern/pt/ptutils.py +4 -2
  46. deepdoctection/extern/tessocr.py +39 -38
  47. deepdoctection/extern/texocr.py +14 -16
  48. deepdoctection/extern/tp/tfutils.py +16 -2
  49. deepdoctection/extern/tp/tpcompat.py +11 -7
  50. deepdoctection/extern/tp/tpfrcnn/config/config.py +4 -4
  51. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +1 -1
  52. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +5 -5
  53. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +6 -6
  54. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +4 -4
  55. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +5 -3
  56. deepdoctection/extern/tp/tpfrcnn/preproc.py +5 -5
  57. deepdoctection/extern/tpdetect.py +40 -45
  58. deepdoctection/mapper/cats.py +36 -40
  59. deepdoctection/mapper/cocostruct.py +16 -12
  60. deepdoctection/mapper/d2struct.py +22 -22
  61. deepdoctection/mapper/hfstruct.py +7 -7
  62. deepdoctection/mapper/laylmstruct.py +22 -24
  63. deepdoctection/mapper/maputils.py +9 -10
  64. deepdoctection/mapper/match.py +33 -2
  65. deepdoctection/mapper/misc.py +6 -7
  66. deepdoctection/mapper/pascalstruct.py +4 -4
  67. deepdoctection/mapper/prodigystruct.py +6 -6
  68. deepdoctection/mapper/pubstruct.py +84 -92
  69. deepdoctection/mapper/tpstruct.py +3 -3
  70. deepdoctection/mapper/xfundstruct.py +33 -33
  71. deepdoctection/pipe/anngen.py +39 -14
  72. deepdoctection/pipe/base.py +68 -99
  73. deepdoctection/pipe/common.py +181 -85
  74. deepdoctection/pipe/concurrency.py +14 -10
  75. deepdoctection/pipe/doctectionpipe.py +24 -21
  76. deepdoctection/pipe/language.py +20 -25
  77. deepdoctection/pipe/layout.py +18 -16
  78. deepdoctection/pipe/lm.py +49 -47
  79. deepdoctection/pipe/order.py +63 -65
  80. deepdoctection/pipe/refine.py +102 -109
  81. deepdoctection/pipe/segment.py +157 -162
  82. deepdoctection/pipe/sub_layout.py +50 -40
  83. deepdoctection/pipe/text.py +37 -36
  84. deepdoctection/pipe/transform.py +19 -16
  85. deepdoctection/train/d2_frcnn_train.py +27 -25
  86. deepdoctection/train/hf_detr_train.py +22 -18
  87. deepdoctection/train/hf_layoutlm_train.py +49 -48
  88. deepdoctection/train/tp_frcnn_train.py +10 -11
  89. deepdoctection/utils/concurrency.py +1 -1
  90. deepdoctection/utils/context.py +13 -6
  91. deepdoctection/utils/develop.py +4 -4
  92. deepdoctection/utils/env_info.py +52 -14
  93. deepdoctection/utils/file_utils.py +6 -11
  94. deepdoctection/utils/fs.py +41 -14
  95. deepdoctection/utils/identifier.py +2 -2
  96. deepdoctection/utils/logger.py +15 -15
  97. deepdoctection/utils/metacfg.py +7 -7
  98. deepdoctection/utils/pdf_utils.py +39 -14
  99. deepdoctection/utils/settings.py +188 -182
  100. deepdoctection/utils/tqdm.py +1 -1
  101. deepdoctection/utils/transform.py +14 -9
  102. deepdoctection/utils/types.py +104 -0
  103. deepdoctection/utils/utils.py +7 -7
  104. deepdoctection/utils/viz.py +70 -69
  105. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/METADATA +7 -4
  106. deepdoctection-0.34.dist-info/RECORD +146 -0
  107. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/WHEEL +1 -1
  108. deepdoctection/utils/detection_types.py +0 -68
  109. deepdoctection-0.32.dist-info/RECORD +0 -146
  110. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/LICENSE +0 -0
  111. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,7 @@ Module for `Evaluator`
22
22
  from __future__ import annotations
23
23
 
24
24
  from copy import deepcopy
25
- from typing import Any, Dict, Generator, List, Literal, Mapping, Optional, Type, Union, overload
25
+ from typing import Any, Generator, Literal, Mapping, Optional, Type, Union, overload
26
26
 
27
27
  import numpy as np
28
28
  from lazy_imports import try_import
@@ -33,13 +33,13 @@ from ..datasets.base import DatasetBase
33
33
  from ..mapper.cats import filter_cat, remove_cats
34
34
  from ..mapper.d2struct import to_wandb_image
35
35
  from ..mapper.misc import maybe_load_image, maybe_remove_image, maybe_remove_image_from_category
36
- from ..pipe.base import LanguageModelPipelineComponent, PredictorPipelineComponent
36
+ from ..pipe.base import PipelineComponent
37
37
  from ..pipe.common import PageParsingService
38
38
  from ..pipe.concurrency import MultiThreadPipelineComponent
39
39
  from ..pipe.doctectionpipe import DoctectionPipe
40
- from ..utils.detection_types import ImageType
41
40
  from ..utils.logger import LoggingRecord, logger
42
41
  from ..utils.settings import DatasetType, LayoutType, TypeOrStr, get_type
42
+ from ..utils.types import PixelValues
43
43
  from ..utils.viz import interactive_imshow
44
44
  from .base import MetricBase
45
45
 
@@ -90,7 +90,7 @@ class Evaluator:
90
90
  def __init__(
91
91
  self,
92
92
  dataset: DatasetBase,
93
- component_or_pipeline: Union[PredictorPipelineComponent, LanguageModelPipelineComponent, DoctectionPipe],
93
+ component_or_pipeline: Union[PipelineComponent, DoctectionPipe],
94
94
  metric: Union[Type[MetricBase], MetricBase],
95
95
  num_threads: int = 2,
96
96
  run: Optional[wandb.sdk.wandb_run.Run] = None,
@@ -108,14 +108,14 @@ class Evaluator:
108
108
  self.pipe: Optional[DoctectionPipe] = None
109
109
 
110
110
  # when passing a component, we will process prediction on num_threads
111
- if isinstance(component_or_pipeline, (PredictorPipelineComponent, LanguageModelPipelineComponent)):
111
+ if isinstance(component_or_pipeline, PipelineComponent):
112
112
  logger.info(
113
113
  LoggingRecord(
114
114
  f"Building multi threading pipeline component to increase prediction throughput. "
115
115
  f"Using {num_threads} threads"
116
116
  )
117
117
  )
118
- pipeline_components: List[Union[PredictorPipelineComponent, LanguageModelPipelineComponent]] = []
118
+ pipeline_components: list[PipelineComponent] = []
119
119
 
120
120
  for _ in range(num_threads - 1):
121
121
  copy_pipe_component = component_or_pipeline.clone()
@@ -139,14 +139,14 @@ class Evaluator:
139
139
 
140
140
  self.wandb_table_agent: Optional[WandbTableAgent]
141
141
  if run is not None:
142
- if self.dataset.dataset_info.type == DatasetType.object_detection:
142
+ if self.dataset.dataset_info.type == DatasetType.OBJECT_DETECTION:
143
143
  self.wandb_table_agent = WandbTableAgent(
144
144
  run,
145
145
  self.dataset.dataset_info.name,
146
146
  50,
147
147
  self.dataset.dataflow.categories.get_categories(filtered=True),
148
148
  )
149
- elif self.dataset.dataset_info.type == DatasetType.token_classification:
149
+ elif self.dataset.dataset_info.type == DatasetType.TOKEN_CLASSIFICATION:
150
150
  if hasattr(self.metric, "sub_cats"):
151
151
  sub_cat_key, sub_cat_val_list = list(self.metric.sub_cats.items())[0]
152
152
  sub_cat_val = sub_cat_val_list[0]
@@ -178,16 +178,16 @@ class Evaluator:
178
178
  @overload
179
179
  def run(
180
180
  self, output_as_dict: Literal[False] = False, **dataflow_build_kwargs: Union[str, int]
181
- ) -> List[Dict[str, float]]:
181
+ ) -> list[dict[str, float]]:
182
182
  ...
183
183
 
184
184
  @overload
185
- def run(self, output_as_dict: Literal[True], **dataflow_build_kwargs: Union[str, int]) -> Dict[str, float]:
185
+ def run(self, output_as_dict: Literal[True], **dataflow_build_kwargs: Union[str, int]) -> dict[str, float]:
186
186
  ...
187
187
 
188
188
  def run(
189
189
  self, output_as_dict: bool = False, **dataflow_build_kwargs: Union[str, int]
190
- ) -> Union[List[Dict[str, float]], Dict[str, float]]:
190
+ ) -> Union[list[dict[str, float]], dict[str, float]]:
191
191
  """
192
192
  Start evaluation process and return the results.
193
193
 
@@ -246,11 +246,11 @@ class Evaluator:
246
246
  possible_cats_in_datapoint = self.dataset.dataflow.categories.get_categories(as_dict=False, filtered=True)
247
247
 
248
248
  # clean-up procedure depends on the dataset type
249
- if self.dataset.dataset_info.type == DatasetType.object_detection:
249
+ if self.dataset.dataset_info.type == DatasetType.OBJECT_DETECTION:
250
250
  # we keep all image annotations that will not be generated through processing
251
- anns_to_keep = {ann for ann in possible_cats_in_datapoint if ann not in meta_anns["image_annotations"]}
252
- sub_cats_to_remove = meta_anns["sub_categories"]
253
- relationships_to_remove = meta_anns["relationships"]
251
+ anns_to_keep = {ann for ann in possible_cats_in_datapoint if ann not in meta_anns.image_annotations}
252
+ sub_cats_to_remove = meta_anns.sub_categories
253
+ relationships_to_remove = meta_anns.relationships
254
254
  # removing annotations takes place in three steps: First we remove all image annotations. Then, with all
255
255
  # remaining image annotations we check, if the image attribute (with Image instance !) is not empty and
256
256
  # remove it as well, if necessary. In the last step we remove all sub categories and relationships, if
@@ -262,19 +262,19 @@ class Evaluator:
262
262
  remove_cats(sub_categories=sub_cats_to_remove, relationships=relationships_to_remove),
263
263
  )
264
264
 
265
- elif self.dataset.dataset_info.type == DatasetType.sequence_classification:
266
- summary_sub_cats_to_remove = meta_anns["summaries"]
265
+ elif self.dataset.dataset_info.type == DatasetType.SEQUENCE_CLASSIFICATION:
266
+ summary_sub_cats_to_remove = meta_anns.summaries
267
267
  df_pr = MapData(df_pr, remove_cats(summary_sub_categories=summary_sub_cats_to_remove))
268
268
 
269
- elif self.dataset.dataset_info.type == DatasetType.token_classification:
270
- sub_cats_to_remove = meta_anns["sub_categories"]
269
+ elif self.dataset.dataset_info.type == DatasetType.TOKEN_CLASSIFICATION:
270
+ sub_cats_to_remove = meta_anns.sub_categories
271
271
  df_pr = MapData(df_pr, remove_cats(sub_categories=sub_cats_to_remove))
272
272
  else:
273
273
  raise NotImplementedError()
274
274
 
275
275
  return df_pr
276
276
 
277
- def compare(self, interactive: bool = False, **kwargs: Union[str, int]) -> Generator[ImageType, None, None]:
277
+ def compare(self, interactive: bool = False, **kwargs: Union[str, int]) -> Generator[PixelValues, None, None]:
278
278
  """
279
279
  Visualize ground truth and prediction datapoint. Given a dataflow config it will run predictions per sample
280
280
  and concat the prediction image (with predicted bounding boxes) with ground truth image.
@@ -293,6 +293,8 @@ class Evaluator:
293
293
  show_words = kwargs.pop("show_words", False)
294
294
  show_token_class = kwargs.pop("show_token_class", True)
295
295
  ignore_default_token_class = kwargs.pop("ignore_default_token_class", False)
296
+ floating_text_block_categories = kwargs.pop("floating_text_block_categories", None)
297
+ include_residual_text_containers = kwargs.pop("include_residual_Text_containers", True)
296
298
 
297
299
  df_gt = self.dataset.dataflow.build(**kwargs)
298
300
  df_pr = self.dataset.dataflow.build(**kwargs)
@@ -301,7 +303,11 @@ class Evaluator:
301
303
  df_pr = MapData(df_pr, deepcopy)
302
304
  df_pr = self._clean_up_predict_dataflow_annotations(df_pr)
303
305
 
304
- page_parsing_component = PageParsingService(text_container=LayoutType.word)
306
+ page_parsing_component = PageParsingService(
307
+ text_container=LayoutType.WORD,
308
+ floating_text_block_categories=floating_text_block_categories, # type: ignore
309
+ include_residual_text_container=bool(include_residual_text_containers),
310
+ )
305
311
  df_gt = page_parsing_component.predict_dataflow(df_gt)
306
312
 
307
313
  if self.pipe_component:
@@ -357,8 +363,8 @@ class WandbTableAgent:
357
363
  wandb_run: wandb.sdk.wandb_run.Run,
358
364
  dataset_name: str,
359
365
  num_samples: int,
360
- categories: Mapping[str, TypeOrStr],
361
- sub_categories: Optional[Mapping[str, TypeOrStr]] = None,
366
+ categories: Mapping[int, TypeOrStr],
367
+ sub_categories: Optional[Mapping[int, TypeOrStr]] = None,
362
368
  cat_to_sub_cat: Optional[Mapping[TypeOrStr, TypeOrStr]] = None,
363
369
  ):
364
370
  """
@@ -385,8 +391,8 @@ class WandbTableAgent:
385
391
  self._counter = 0
386
392
 
387
393
  # Table logging utils
388
- self._table_cols: List[str] = ["file_name", "image"]
389
- self._table_rows: List[Any] = []
394
+ self._table_cols: list[str] = ["file_name", "image"]
395
+ self._table_rows: list[Any] = []
390
396
  self._table_ref = None
391
397
 
392
398
  def dump(self, dp: Image) -> Image:
@@ -439,4 +445,4 @@ class WandbTableAgent:
439
445
  eval_art.add(self._build_table(), self.dataset_name)
440
446
  self._run.use_artifact(eval_art)
441
447
  eval_art.wait()
442
- self._table_ref = eval_art.get(self.dataset_name).data # type:ignore
448
+ self._table_ref = eval_art.get(self.dataset_name).data # type: ignore
@@ -18,17 +18,18 @@ Tree distance similarity metric taken from <https://github.com/ibm-aur-nlp/PubTa
18
18
 
19
19
  import statistics
20
20
  from collections import defaultdict, deque
21
- from typing import Any, List, Optional, Tuple
21
+ from typing import Any, Callable, Optional
22
22
 
23
23
  from lazy_imports import try_import
24
24
 
25
25
  from ..dataflow import DataFlow, DataFromList, MapData, MultiThreadMapData
26
+ from ..datapoint.image import Image
26
27
  from ..datapoint.view import Page
27
28
  from ..datasets.base import DatasetCategories
28
- from ..utils.detection_types import JsonDict
29
29
  from ..utils.file_utils import Requirement, get_apted_requirement, get_distance_requirement, get_lxml_requirement
30
30
  from ..utils.logger import LoggingRecord, logger
31
31
  from ..utils.settings import LayoutType
32
+ from ..utils.types import MetricResults
32
33
  from .base import MetricBase
33
34
  from .registry import metric_registry
34
35
 
@@ -59,7 +60,7 @@ class TableTree(Tree):
59
60
  tag: str,
60
61
  colspan: Optional[int] = None,
61
62
  rowspan: Optional[int] = None,
62
- content: Optional[List[str]] = None,
63
+ content: Optional[list[str]] = None,
63
64
  ) -> None:
64
65
  self.tag = tag
65
66
  self.colspan = colspan
@@ -107,7 +108,7 @@ class TEDS:
107
108
 
108
109
  def __init__(self, structure_only: bool = False):
109
110
  self.structure_only = structure_only
110
- self.__tokens__: List[str] = []
111
+ self.__tokens__: list[str] = []
111
112
 
112
113
  def tokenize(self, node: TableTree) -> None:
113
114
  """Tokenizes table cells"""
@@ -149,7 +150,7 @@ class TEDS:
149
150
  return new_node
150
151
  return None
151
152
 
152
- def evaluate(self, inputs: Tuple[str, str]) -> float:
153
+ def evaluate(self, inputs: tuple[str, str]) -> float:
153
154
  """Computes TEDS score between the prediction and the ground truth of a
154
155
  given sample
155
156
  """
@@ -188,7 +189,7 @@ class TEDS:
188
189
  return 0.0
189
190
 
190
191
 
191
- def teds_metric(gt_list: List[str], predict_list: List[str], structure_only: bool) -> Tuple[float, int]:
192
+ def teds_metric(gt_list: list[str], predict_list: list[str], structure_only: bool) -> tuple[float, int]:
192
193
  """
193
194
  Computes tree edit distance score (TEDS) between the prediction and the ground truth of a batch of samples. The
194
195
  approach to measure similarity of tables by means of their html representation has been adovacated in
@@ -221,13 +222,16 @@ class TedsMetric(MetricBase):
221
222
  """
222
223
 
223
224
  metric = teds_metric # type: ignore
224
- mapper = Page.from_image
225
+ mapper: Callable[[Image, LayoutType, list[LayoutType]], Page] = Page.from_image
226
+ text_container: LayoutType = LayoutType.WORD
227
+ floating_text_block_categories = [LayoutType.TABLE]
228
+
225
229
  structure_only = False
226
230
 
227
231
  @classmethod
228
232
  def dump(
229
233
  cls, dataflow_gt: DataFlow, dataflow_predictions: DataFlow, categories: DatasetCategories
230
- ) -> Tuple[List[str], List[str]]:
234
+ ) -> tuple[list[str], list[str]]:
231
235
  dataflow_gt.reset_state()
232
236
  dataflow_predictions.reset_state()
233
237
 
@@ -235,11 +239,11 @@ class TedsMetric(MetricBase):
235
239
  gt_dict = defaultdict(list)
236
240
  pred_dict = defaultdict(list)
237
241
  for dp_gt, dp_pred in zip(dataflow_gt, dataflow_predictions):
238
- page_gt = cls.mapper(dp_gt, LayoutType.word, [LayoutType.table])
242
+ page_gt = cls.mapper(dp_gt, cls.text_container, cls.floating_text_block_categories)
239
243
  for table in page_gt.tables:
240
244
  gt_dict[page_gt.image_id].append(table.html)
241
245
 
242
- page_pred = cls.mapper(dp_pred, LayoutType.word, [LayoutType.table])
246
+ page_pred = cls.mapper(dp_pred, cls.text_container, cls.floating_text_block_categories)
243
247
  for table in page_pred.tables:
244
248
  pred_dict[page_pred.image_id].append(table.html)
245
249
 
@@ -254,12 +258,12 @@ class TedsMetric(MetricBase):
254
258
  @classmethod
255
259
  def get_distance(
256
260
  cls, dataflow_gt: DataFlow, dataflow_predictions: DataFlow, categories: DatasetCategories
257
- ) -> List[JsonDict]:
261
+ ) -> list[MetricResults]:
258
262
  html_gt_list, html_pr_list = cls.dump(dataflow_gt, dataflow_predictions, categories)
259
263
 
260
264
  score, num_samples = cls.metric(html_gt_list, html_pr_list, cls.structure_only) # type: ignore
261
265
  return [{"teds_score": score, "num_samples": num_samples}]
262
266
 
263
267
  @classmethod
264
- def get_requirements(cls) -> List[Requirement]:
268
+ def get_requirements(cls) -> list[Requirement]:
265
269
  return [get_apted_requirement(), get_distance_requirement(), get_lxml_requirement()]
@@ -27,8 +27,7 @@ from typing import Mapping, Optional, Sequence, Type, Union
27
27
  from lazy_imports import try_import
28
28
 
29
29
  from ..datasets import DatasetBase
30
- from ..extern.tpdetect import TPFrcnnDetector
31
- from ..pipe.base import PredictorPipelineComponent
30
+ from ..pipe.base import PipelineComponent
32
31
  from ..utils.logger import LoggingRecord, logger
33
32
  from ..utils.metacfg import AttrDict
34
33
  from ..utils.settings import ObjectTypes
@@ -65,9 +64,10 @@ class EvalCallback(Callback): # pylint: disable=R0903
65
64
  category_names: Optional[Union[ObjectTypes, Sequence[ObjectTypes]]],
66
65
  sub_categories: Optional[Union[Mapping[ObjectTypes, ObjectTypes], Mapping[ObjectTypes, Sequence[ObjectTypes]]]],
67
66
  metric: Union[Type[MetricBase], MetricBase],
68
- pipeline_component: PredictorPipelineComponent,
67
+ pipeline_component: PipelineComponent,
69
68
  in_names: str,
70
69
  out_names: str,
70
+ cfg: AttrDict,
71
71
  **build_eval_kwargs: str,
72
72
  ) -> None:
73
73
  """
@@ -89,12 +89,7 @@ class EvalCallback(Callback): # pylint: disable=R0903
89
89
  self.num_gpu = get_num_gpu()
90
90
  self.category_names = category_names
91
91
  self.sub_categories = sub_categories
92
- if not isinstance(pipeline_component.predictor, TPFrcnnDetector):
93
- raise TypeError(
94
- f"pipeline_component.predictor must be of type TPFrcnnDetector but is "
95
- f"type {type(pipeline_component.predictor)}"
96
- )
97
- self.cfg = pipeline_component.predictor.model.cfg
92
+ self.cfg = cfg
98
93
  if _use_replicated(self.cfg):
99
94
  self.evaluator = Evaluator(dataset, pipeline_component, metric, num_threads=self.num_gpu * 2)
100
95
  else:
@@ -105,13 +100,9 @@ class EvalCallback(Callback): # pylint: disable=R0903
105
100
  if self.evaluator.pipe_component is None:
106
101
  raise TypeError("self.evaluator.pipe_component cannot be None")
107
102
  for idx, comp in enumerate(self.evaluator.pipe_component.pipe_components):
108
- if not isinstance(comp, PredictorPipelineComponent):
109
- raise TypeError(f"comp must be of type PredictorPipelineComponent but is type {type(comp)}")
110
- if not isinstance(comp.predictor, TPFrcnnDetector):
111
- raise TypeError(
112
- f"comp.predictor mus be of type TPFrcnnDetector but is of type {type(comp.predictor)}"
113
- )
114
- comp.predictor.tp_predictor = self._build_predictor(idx % self.num_gpu)
103
+ if hasattr(comp, "predictor"):
104
+ if hasattr(comp.predictor, "tp_predictor"):
105
+ comp.predictor.tp_predictor = self._build_predictor(idx % self.num_gpu)
115
106
 
116
107
  def _build_predictor(self, idx: int) -> OnlinePredictor:
117
108
  return self.trainer.get_predictor(self.in_names, self.out_names, device=idx)