deepdoctection 0.31__py3-none-any.whl → 0.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (131):
  1. deepdoctection/__init__.py +16 -29
  2. deepdoctection/analyzer/dd.py +70 -59
  3. deepdoctection/configs/conf_dd_one.yaml +34 -31
  4. deepdoctection/dataflow/common.py +9 -5
  5. deepdoctection/dataflow/custom.py +5 -5
  6. deepdoctection/dataflow/custom_serialize.py +75 -18
  7. deepdoctection/dataflow/parallel_map.py +3 -3
  8. deepdoctection/dataflow/serialize.py +4 -4
  9. deepdoctection/dataflow/stats.py +3 -3
  10. deepdoctection/datapoint/annotation.py +41 -56
  11. deepdoctection/datapoint/box.py +9 -8
  12. deepdoctection/datapoint/convert.py +6 -6
  13. deepdoctection/datapoint/image.py +56 -44
  14. deepdoctection/datapoint/view.py +245 -150
  15. deepdoctection/datasets/__init__.py +1 -4
  16. deepdoctection/datasets/adapter.py +35 -26
  17. deepdoctection/datasets/base.py +14 -12
  18. deepdoctection/datasets/dataflow_builder.py +3 -3
  19. deepdoctection/datasets/info.py +24 -26
  20. deepdoctection/datasets/instances/doclaynet.py +51 -51
  21. deepdoctection/datasets/instances/fintabnet.py +46 -46
  22. deepdoctection/datasets/instances/funsd.py +25 -24
  23. deepdoctection/datasets/instances/iiitar13k.py +13 -10
  24. deepdoctection/datasets/instances/layouttest.py +4 -3
  25. deepdoctection/datasets/instances/publaynet.py +5 -5
  26. deepdoctection/datasets/instances/pubtables1m.py +24 -21
  27. deepdoctection/datasets/instances/pubtabnet.py +32 -30
  28. deepdoctection/datasets/instances/rvlcdip.py +30 -30
  29. deepdoctection/datasets/instances/xfund.py +26 -26
  30. deepdoctection/datasets/save.py +6 -6
  31. deepdoctection/eval/__init__.py +1 -4
  32. deepdoctection/eval/accmetric.py +32 -33
  33. deepdoctection/eval/base.py +8 -9
  34. deepdoctection/eval/cocometric.py +15 -13
  35. deepdoctection/eval/eval.py +41 -37
  36. deepdoctection/eval/tedsmetric.py +30 -23
  37. deepdoctection/eval/tp_eval_callback.py +16 -19
  38. deepdoctection/extern/__init__.py +2 -7
  39. deepdoctection/extern/base.py +339 -134
  40. deepdoctection/extern/d2detect.py +85 -113
  41. deepdoctection/extern/deskew.py +14 -11
  42. deepdoctection/extern/doctrocr.py +141 -130
  43. deepdoctection/extern/fastlang.py +27 -18
  44. deepdoctection/extern/hfdetr.py +71 -62
  45. deepdoctection/extern/hflayoutlm.py +504 -211
  46. deepdoctection/extern/hflm.py +230 -0
  47. deepdoctection/extern/model.py +488 -302
  48. deepdoctection/extern/pdftext.py +23 -19
  49. deepdoctection/extern/pt/__init__.py +1 -3
  50. deepdoctection/extern/pt/nms.py +6 -2
  51. deepdoctection/extern/pt/ptutils.py +29 -19
  52. deepdoctection/extern/tessocr.py +39 -38
  53. deepdoctection/extern/texocr.py +18 -18
  54. deepdoctection/extern/tp/tfutils.py +57 -9
  55. deepdoctection/extern/tp/tpcompat.py +21 -14
  56. deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
  57. deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
  58. deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
  59. deepdoctection/extern/tp/tpfrcnn/config/config.py +13 -10
  60. deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
  61. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +18 -8
  62. deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
  63. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +14 -9
  64. deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
  65. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +22 -17
  66. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +21 -14
  67. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +19 -11
  68. deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
  69. deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
  70. deepdoctection/extern/tp/tpfrcnn/preproc.py +12 -8
  71. deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
  72. deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
  73. deepdoctection/extern/tpdetect.py +45 -53
  74. deepdoctection/mapper/__init__.py +3 -8
  75. deepdoctection/mapper/cats.py +27 -29
  76. deepdoctection/mapper/cocostruct.py +10 -10
  77. deepdoctection/mapper/d2struct.py +27 -26
  78. deepdoctection/mapper/hfstruct.py +13 -8
  79. deepdoctection/mapper/laylmstruct.py +178 -37
  80. deepdoctection/mapper/maputils.py +12 -11
  81. deepdoctection/mapper/match.py +2 -2
  82. deepdoctection/mapper/misc.py +11 -9
  83. deepdoctection/mapper/pascalstruct.py +4 -4
  84. deepdoctection/mapper/prodigystruct.py +5 -5
  85. deepdoctection/mapper/pubstruct.py +84 -92
  86. deepdoctection/mapper/tpstruct.py +5 -5
  87. deepdoctection/mapper/xfundstruct.py +33 -33
  88. deepdoctection/pipe/__init__.py +1 -1
  89. deepdoctection/pipe/anngen.py +12 -14
  90. deepdoctection/pipe/base.py +52 -106
  91. deepdoctection/pipe/common.py +72 -59
  92. deepdoctection/pipe/concurrency.py +16 -11
  93. deepdoctection/pipe/doctectionpipe.py +24 -21
  94. deepdoctection/pipe/language.py +20 -25
  95. deepdoctection/pipe/layout.py +20 -16
  96. deepdoctection/pipe/lm.py +75 -105
  97. deepdoctection/pipe/order.py +194 -89
  98. deepdoctection/pipe/refine.py +111 -124
  99. deepdoctection/pipe/segment.py +156 -161
  100. deepdoctection/pipe/{cell.py → sub_layout.py} +50 -40
  101. deepdoctection/pipe/text.py +37 -36
  102. deepdoctection/pipe/transform.py +19 -16
  103. deepdoctection/train/__init__.py +6 -12
  104. deepdoctection/train/d2_frcnn_train.py +48 -41
  105. deepdoctection/train/hf_detr_train.py +41 -30
  106. deepdoctection/train/hf_layoutlm_train.py +153 -135
  107. deepdoctection/train/tp_frcnn_train.py +32 -31
  108. deepdoctection/utils/concurrency.py +1 -1
  109. deepdoctection/utils/context.py +13 -6
  110. deepdoctection/utils/develop.py +4 -4
  111. deepdoctection/utils/env_info.py +87 -125
  112. deepdoctection/utils/file_utils.py +6 -11
  113. deepdoctection/utils/fs.py +22 -18
  114. deepdoctection/utils/identifier.py +2 -2
  115. deepdoctection/utils/logger.py +16 -15
  116. deepdoctection/utils/metacfg.py +7 -7
  117. deepdoctection/utils/mocks.py +93 -0
  118. deepdoctection/utils/pdf_utils.py +11 -11
  119. deepdoctection/utils/settings.py +185 -181
  120. deepdoctection/utils/tqdm.py +1 -1
  121. deepdoctection/utils/transform.py +14 -9
  122. deepdoctection/utils/types.py +104 -0
  123. deepdoctection/utils/utils.py +7 -7
  124. deepdoctection/utils/viz.py +74 -72
  125. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/METADATA +30 -21
  126. deepdoctection-0.33.dist-info/RECORD +146 -0
  127. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/WHEEL +1 -1
  128. deepdoctection/utils/detection_types.py +0 -68
  129. deepdoctection-0.31.dist-info/RECORD +0 -144
  130. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/LICENSE +0 -0
  131. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/top_level.txt +0 -0

deepdoctection/train/d2_frcnn_train.py

```diff
@@ -18,19 +18,14 @@
 """
 Module for training Detectron2 `GeneralizedRCNN`
 """
-
+from __future__ import annotations
 
 import copy
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Type, Union
+import os
+from pathlib import Path
+from typing import Any, Mapping, Optional, Sequence, Type, Union
 
-from detectron2.config import CfgNode, get_cfg
-from detectron2.data import DatasetMapper, build_detection_train_loader
-from detectron2.data.transforms import RandomFlip, ResizeShortestEdge
-from detectron2.engine import DefaultTrainer, HookBase, default_writers, hooks
-from detectron2.utils import comm
-from detectron2.utils.events import EventWriter, get_event_storage
-from fvcore.nn.precise_bn import get_bn_modules  # type: ignore
-from torch.utils.data import DataLoader, IterableDataset
+from lazy_imports import try_import
 
 from ..datasets.adapter import DatasetAdapter
 from ..datasets.base import DatasetBase
```
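
Both training modules now open with `from __future__ import annotations`, which is what lets the rest of the diff replace `typing.List`/`typing.Dict` with the builtin generics `list`/`dict`: under PEP 563 annotations are stored as strings rather than evaluated, so `list[str]` parses even on Python 3.8. A minimal self-contained illustration (the function below is invented for the example):

```python
from __future__ import annotations  # defer evaluation of annotations (PEP 563)


def parse_overrides(conf_list: list[str]) -> dict[str, str]:
    # "KEY=VALUE" pairs -> mapping; the builtin-generic annotations are
    # never evaluated at runtime, so this also runs on Python 3.8.
    return dict(item.split("=", maxsplit=1) for item in conf_list)


print(parse_overrides(["MODEL.DEVICE=cuda", "OUTPUT_DIR=train_log/frcnn"]))
```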

```diff
@@ -39,22 +34,35 @@ from ..eval.base import MetricBase
 from ..eval.eval import Evaluator
 from ..eval.registry import metric_registry
 from ..extern.d2detect import D2FrcnnDetector
-from ..extern.pt.ptutils import get_num_gpu
 from ..mapper.d2struct import image_to_d2_frcnn_training
-from ..pipe.base import PredictorPipelineComponent
+from ..pipe.base import PipelineComponent
 from ..pipe.registry import pipeline_component_registry
 from ..utils.error import DependencyError
 from ..utils.file_utils import get_wandb_requirement, wandb_available
 from ..utils.logger import LoggingRecord, logger
+from ..utils.types import PathLikeOrStr
 from ..utils.utils import string_to_dict
 
-if wandb_available():
+with try_import() as d2_import_guard:
+    from detectron2.config import CfgNode, get_cfg
+    from detectron2.data import DatasetMapper, build_detection_train_loader
+    from detectron2.data.transforms import RandomFlip, ResizeShortestEdge
+    from detectron2.engine import DefaultTrainer, HookBase, default_writers, hooks
+    from detectron2.utils import comm
+    from detectron2.utils.events import EventWriter, get_event_storage
+    from fvcore.nn.precise_bn import get_bn_modules  # type: ignore
+
+with try_import() as pt_import_guard:
+    from torch import cuda
+    from torch.utils.data import DataLoader, IterableDataset
+
+with try_import() as wb_import_guard:
     import wandb
 
 
 def _set_config(
-    path_config_yaml: str,
-    conf_list: List[str],
+    path_config_yaml: PathLikeOrStr,
+    conf_list: list[str],
     dataset_train: DatasetBase,
     dataset_val: Optional[DatasetBase],
     metric_name: Optional[str],
```
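
The detectron2, torch and wandb imports move inside `try_import()` guards from the `lazy-imports` package, so importing the training module no longer hard-fails when an optional dependency is absent. The pattern works roughly like this (a sketch with a made-up package name; `check()` is the lazy-imports call that re-raises the deferred `ImportError` at first use):

```python
from lazy_imports import try_import

# The context manager swallows a failed import instead of raising it ...
with try_import() as import_guard:
    import some_optional_package  # hypothetical optional dependency


def function_needing_the_package() -> None:
    import_guard.check()  # ... and the original ImportError resurfaces here
    some_optional_package.do_work()
```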

```diff
@@ -69,7 +77,7 @@ def _set_config(
     cfg.WANDB.USE_WANDB = False
     cfg.WANDB.PROJECT = None
     cfg.WANDB.REPO = "deepdoctection"
-    cfg.merge_from_file(path_config_yaml)
+    cfg.merge_from_file(path_config_yaml.as_posix() if isinstance(path_config_yaml, Path) else path_config_yaml)
     cfg.merge_from_list(conf_list)
 
     cfg.TEST.DO_EVAL = (
@@ -84,7 +92,7 @@ def _set_config(
     return cfg
 
 
-def _update_for_eval(config_overwrite: List[str]) -> List[str]:
+def _update_for_eval(config_overwrite: list[str]) -> list[str]:
     ret = [item for item in config_overwrite if not "WANDB" in item]
     return ret
 
@@ -98,7 +106,7 @@ class WandbWriter(EventWriter):
         self,
         project: str,
         repo: str,
-        config: Optional[Union[Dict[str, Any], CfgNode]] = None,
+        config: Optional[Union[dict[str, Any], CfgNode]] = None,
         window_size: int = 20,
         **kwargs: Any,
     ):
@@ -112,7 +120,7 @@ class WandbWriter(EventWriter):
             config = {}
         self._window_size = window_size
         self._run = wandb.init(project=project, config=config, **kwargs) if not wandb.run else wandb.run
-        self._run._label(repo=repo)  # type:ignore
+        self._run._label(repo=repo)
 
     def write(self) -> None:
        storage = get_event_storage()
@@ -121,10 +129,10 @@ class WandbWriter(EventWriter):
         for key, (val, _) in storage.latest_with_smoothing_hint(self._window_size).items():
             log_dict[key] = val
 
-        self._run.log(log_dict)  # type:ignore
+        self._run.log(log_dict)
 
     def close(self) -> None:
-        self._run.finish()  # type:ignore
+        self._run.finish()
 
 
 class D2Trainer(DefaultTrainer):
```
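
With the `# type:ignore` comments dropped, `WandbWriter` is a thin wrapper over the standard wandb run lifecycle. For reference, the init/log/finish sequence it wraps looks like this in isolation (project name invented for the example; the writer's private `_label` call is omitted):

```python
import wandb

# One run per training job: init once, log scalars periodically, then finish.
run = wandb.init(project="deepdoctection-demo", config={"window_size": 20})
for step in range(3):
    run.log({"loss": 1.0 / (step + 1)})
run.finish()
```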

```diff
@@ -140,7 +148,7 @@ class D2Trainer(DefaultTrainer):
         self.build_val_dict: Mapping[str, str] = {}
         super().__init__(cfg)
 
-    def build_hooks(self) -> List[HookBase]:
+    def build_hooks(self) -> list[HookBase]:
         """
         Overwritten from DefaultTrainer. This ensures that the EvalHook is being called before the writer and
         all metrics are being written to JSON, Tensorboard etc.
@@ -192,7 +200,7 @@ class D2Trainer(DefaultTrainer):
 
         return ret
 
-    def build_writers(self) -> List[EventWriter]:
+    def build_writers(self) -> list[EventWriter]:
         """
         Build a list of writers to be using `default_writers()`.
         If you'd like a different list of writers, you can overwrite it in
@@ -221,7 +229,7 @@ class D2Trainer(DefaultTrainer):
             dataset=self.dataset, mapper=self.mapper, total_batch_size=cfg.SOLVER.IMS_PER_BATCH
         )
 
-    def eval_with_dd_evaluator(self, **build_eval_kwargs: str) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
+    def eval_with_dd_evaluator(self, **build_eval_kwargs: str) -> Union[list[dict[str, Any]], dict[str, Any]]:
         """
         Running the Evaluator. This method will be called from the `EvalHook`
 
@@ -238,7 +246,7 @@ class D2Trainer(DefaultTrainer):
     def setup_evaluator(
         self,
         dataset_val: DatasetBase,
-        pipeline_component: PredictorPipelineComponent,
+        pipeline_component: PipelineComponent,
         metric: Union[Type[MetricBase], MetricBase],
         build_val_dict: Optional[Mapping[str, str]] = None,
     ) -> None:
@@ -259,16 +267,14 @@ class D2Trainer(DefaultTrainer):
             dataset_val,
             pipeline_component,
             metric,
-            num_threads=get_num_gpu() * 2,
+            num_threads=cuda.device_count() * 2,
             run=run,
         )
         if build_val_dict:
             self.build_val_dict = build_val_dict
         assert self.evaluator.pipe_component
         for comp in self.evaluator.pipe_component.pipe_components:
-            assert isinstance(comp, PredictorPipelineComponent)
-            assert isinstance(comp.predictor, D2FrcnnDetector)
-            comp.predictor.d2_predictor = None
+            comp.clear_predictor()
 
     @classmethod
     def build_evaluator(cls, cfg, dataset_name):  # type: ignore
```
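
The detector-specific teardown (nulling `comp.predictor.d2_predictor` behind two `isinstance` asserts) becomes a single polymorphic `comp.clear_predictor()` call, so the trainer no longer needs to know which detector class backs each component. The new method itself is not part of this hunk; conceptually it amounts to something like this purely illustrative stand-in:

```python
class PipelineComponent:  # illustrative, not the actual ..pipe.base class
    def __init__(self, predictor) -> None:
        self.predictor = predictor

    def clear_predictor(self) -> None:
        # Release the loaded model so evaluation workers re-instantiate
        # fresh weights instead of inheriting a live predictor.
        self.predictor.model = None  # attribute name is hypothetical
```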

```diff
@@ -276,11 +282,11 @@ class D2Trainer(DefaultTrainer):
 
 
 def train_d2_faster_rcnn(
-    path_config_yaml: str,
+    path_config_yaml: PathLikeOrStr,
     dataset_train: Union[str, DatasetBase],
-    path_weights: str,
-    config_overwrite: Optional[List[str]] = None,
-    log_dir: str = "train_log/frcnn",
+    path_weights: PathLikeOrStr,
+    config_overwrite: Optional[list[str]] = None,
+    log_dir: PathLikeOrStr = "train_log/frcnn",
     build_train_config: Optional[Sequence[str]] = None,
     dataset_val: Optional[DatasetBase] = None,
     build_val_config: Optional[Sequence[str]] = None,
@@ -335,15 +341,15 @@ def train_d2_faster_rcnn(
     :param pipeline_component_name: A pipeline component name to use for validation.
     """
 
-    assert get_num_gpu() > 0, "Has to train with GPU!"
+    assert cuda.device_count() > 0, "Has to train with GPU!"
 
-    build_train_dict: Dict[str, str] = {}
+    build_train_dict: dict[str, str] = {}
     if build_train_config is not None:
         build_train_dict = string_to_dict(",".join(build_train_config))
     if "split" not in build_train_dict:
         build_train_dict["split"] = "train"
 
-    build_val_dict: Dict[str, str] = {}
+    build_val_dict: dict[str, str] = {}
     if build_val_config is not None:
         build_val_dict = string_to_dict(",".join(build_val_config))
     if "split" not in build_val_dict:
```

```diff
@@ -353,9 +359,9 @@ def train_d2_faster_rcnn(
         config_overwrite = []
     conf_list = [
         "MODEL.WEIGHTS",
-        path_weights,
+        os.fspath(path_weights),
         "OUTPUT_DIR",
-        log_dir,
+        os.fspath(log_dir),
     ]
     for conf in config_overwrite:
         key, val = conf.split("=", maxsplit=1)
```
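
Since `path_weights` and `log_dir` may now be `Path` objects while the detectron2 config list expects plain strings, the values are normalized with `os.fspath`, which returns `str` input unchanged and calls `__fspath__()` on path objects:

```python
import os
from pathlib import Path

# Both spellings collapse to the same plain string for the config list.
assert os.fspath("train_log/frcnn") == "train_log/frcnn"
assert os.fspath(Path("train_log") / "frcnn") == os.path.join("train_log", "frcnn")
```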

```diff
@@ -371,11 +377,13 @@ def train_d2_faster_rcnn(
     if metric_name is not None:
         metric = metric_registry.get(metric_name)
 
-    dataset = DatasetAdapter(dataset_train, True, image_to_d2_frcnn_training(False), True, **build_train_dict)
+    dataset = DatasetAdapter(
+        dataset_train, True, image_to_d2_frcnn_training(False), True, number_repetitions=-1, **build_train_dict
+    )
     augment_list = [ResizeShortestEdge(cfg.INPUT.MIN_SIZE_TRAIN, cfg.INPUT.MAX_SIZE_TRAIN), RandomFlip()]
     mapper = DatasetMapper(is_train=True, augmentations=augment_list, image_format="BGR")
 
-    logger.info(LoggingRecord(f"Config: \n {str(cfg)}", cfg.to_dict()))
+    logger.info(LoggingRecord(f"Config: \n {str(cfg)}", dict(cfg)))
 
     trainer = D2Trainer(cfg, dataset, mapper)
     trainer.resume_or_load()
```
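
`DatasetAdapter` now receives `number_repetitions=-1`. Detectron2 trains for a fixed number of iterations rather than epochs, so the dataflow must not run dry mid-run; `-1` reads as "repeat indefinitely", in the spirit of this sketch (hypothetical helper, not deepdoctection's implementation):

```python
import itertools
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def repeat_samples(samples: Iterable[T], number_repetitions: int = -1) -> Iterator[T]:
    saved = list(samples)
    if number_repetitions == -1:
        return itertools.cycle(saved)  # -1: cycle forever for step-based training
    return itertools.chain.from_iterable(itertools.repeat(saved, number_repetitions))


print(list(itertools.islice(repeat_samples([1, 2, 3]), 7)))  # [1, 2, 3, 1, 2, 3, 1]
```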

```diff
@@ -386,7 +394,6 @@ def train_d2_faster_rcnn(
     detector = D2FrcnnDetector(path_config_yaml, path_weights, categories, config_overwrite, cfg.MODEL.DEVICE)
     pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
     pipeline_component = pipeline_component_cls(detector)
-    assert isinstance(pipeline_component, PredictorPipelineComponent)
 
     if metric_name is not None:
         metric = metric_registry.get(metric_name)
```

deepdoctection/train/hf_detr_train.py

```diff
@@ -19,20 +19,13 @@
 Module for training the Hugging Face Detr implementation. Note that this script only trains Tabletransformer-like
 Detr models, which differ slightly from the plain Detr models provided by the transformers library.
 """
+from __future__ import annotations
 
 import copy
-from typing import Any, Dict, List, Optional, Sequence, Type, Union
-
-from torch.nn import Module
-from torch.utils.data import Dataset
-from transformers import (
-    AutoFeatureExtractor,
-    IntervalStrategy,
-    PretrainedConfig,
-    PreTrainedModel,
-    TableTransformerForObjectDetection,
-)
-from transformers.trainer import Trainer, TrainingArguments
+import os
+from typing import Any, Optional, Sequence, Type, Union
+
+from lazy_imports import try_import
 
 from ..datasets.adapter import DatasetAdapter
 from ..datasets.base import DatasetBase
```

```diff
@@ -42,11 +35,27 @@ from ..eval.eval import Evaluator
 from ..eval.registry import metric_registry
 from ..extern.hfdetr import HFDetrDerivedDetector
 from ..mapper.hfstruct import DetrDataCollator, image_to_hf_detr_training
-from ..pipe.base import PredictorPipelineComponent
+from ..pipe.base import PipelineComponent
 from ..pipe.registry import pipeline_component_registry
 from ..utils.logger import LoggingRecord, logger
+from ..utils.types import PathLikeOrStr
 from ..utils.utils import string_to_dict
 
+with try_import() as pt_import_guard:
+    from torch import nn
+    from torch.utils.data import Dataset
+
+with try_import() as hf_import_guard:
+    from transformers import (
+        AutoFeatureExtractor,
+        IntervalStrategy,
+        PretrainedConfig,
+        PreTrainedModel,
+        TableTransformerForObjectDetection,
+        Trainer,
+        TrainingArguments,
+    )
+
 
 class DetrDerivedTrainer(Trainer):
     """
@@ -61,19 +70,19 @@ class DetrDerivedTrainer(Trainer):
 
     def __init__(
         self,
-        model: Union[PreTrainedModel, Module],
+        model: Union[PreTrainedModel, nn.Module],
         args: TrainingArguments,
         data_collator: DetrDataCollator,
         train_dataset: Dataset[Any],
     ):
         self.evaluator: Optional[Evaluator] = None
-        self.build_eval_kwargs: Optional[Dict[str, Any]] = None
+        self.build_eval_kwargs: Optional[dict[str, Any]] = None
         super().__init__(model, args, data_collator, train_dataset)
 
     def setup_evaluator(
         self,
         dataset_val: DatasetBase,
-        pipeline_component: PredictorPipelineComponent,
+        pipeline_component: PipelineComponent,
         metric: Union[Type[MetricBase], MetricBase],
         **build_eval_kwargs: Union[str, int],
     ) -> None:
@@ -90,17 +99,15 @@ class DetrDerivedTrainer(Trainer):
         self.evaluator = Evaluator(dataset_val, pipeline_component, metric, num_threads=1)
         assert self.evaluator.pipe_component
         for comp in self.evaluator.pipe_component.pipe_components:
-            assert isinstance(comp, PredictorPipelineComponent)
-            assert isinstance(comp.predictor, HFDetrDerivedDetector)
-            comp.predictor.hf_detr_predictor = None
+            comp.clear_predictor()
         self.build_eval_kwargs = build_eval_kwargs
 
     def evaluate(
         self,
         eval_dataset: Optional[Dataset[Any]] = None,  # pylint: disable=W0613
-        ignore_keys: Optional[List[str]] = None,  # pylint: disable=W0613
+        ignore_keys: Optional[list[str]] = None,  # pylint: disable=W0613
         metric_key_prefix: str = "eval",  # pylint: disable=W0613
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         """
         Overwritten method from `Trainer`. Arguments will not be used.
         """
@@ -122,12 +129,12 @@ class DetrDerivedTrainer(Trainer):
 
 
 def train_hf_detr(
-    path_config_json: str,
+    path_config_json: PathLikeOrStr,
     dataset_train: Union[str, DatasetBase],
-    path_weights: str,
+    path_weights: PathLikeOrStr,
     path_feature_extractor_config_json: str,
-    config_overwrite: Optional[List[str]] = None,
-    log_dir: str = "train_log/detr",
+    config_overwrite: Optional[list[str]] = None,
+    log_dir: PathLikeOrStr = "train_log/detr",
     build_train_config: Optional[Sequence[str]] = None,
     dataset_val: Optional[DatasetBase] = None,
     build_val_config: Optional[Sequence[str]] = None,
@@ -162,13 +169,13 @@ def train_hf_detr(
     :param pipeline_component_name: A pipeline component name to use for validation
     """
 
-    build_train_dict: Dict[str, str] = {}
+    build_train_dict: dict[str, str] = {}
     if build_train_config is not None:
         build_train_dict = string_to_dict(",".join(build_train_config))
     if "split" not in build_train_dict:
         build_train_dict["split"] = "train"
 
-    build_val_dict: Dict[str, str] = {}
+    build_val_dict: dict[str, str] = {}
     if build_val_config is not None:
         build_val_dict = string_to_dict(",".join(build_val_config))
     if "split" not in build_val_dict:
@@ -184,12 +191,17 @@ def train_hf_detr(
     categories_dict_name_as_key = dataset_train.dataflow.categories.get_categories(name_as_key=True, filtered=True)
 
     dataset = DatasetAdapter(
-        dataset_train, True, image_to_hf_detr_training(category_names=categories), True, **build_train_dict
+        dataset_train,
+        True,
+        image_to_hf_detr_training(category_names=categories),
+        True,
+        number_repetitions=-1,
+        **build_train_dict,
     )
 
     number_samples = len(dataset)
     conf_dict = {
-        "output_dir": log_dir,
+        "output_dir": os.fspath(log_dir),
         "remove_unused_columns": False,
         "per_device_train_batch_size": 2,
         "max_steps": number_samples,
```

```diff
@@ -249,7 +261,6 @@ def train_hf_detr(
     )
     pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
     pipeline_component = pipeline_component_cls(detector)
-    assert isinstance(pipeline_component, PredictorPipelineComponent)
 
     if metric_name is not None:
         metric = metric_registry.get(metric_name)
```