deepdoctection 0.42.0__py3-none-any.whl → 0.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of deepdoctection might be problematic. Click here for more details.
- deepdoctection/__init__.py +2 -1
- deepdoctection/analyzer/__init__.py +2 -1
- deepdoctection/analyzer/config.py +904 -0
- deepdoctection/analyzer/dd.py +36 -62
- deepdoctection/analyzer/factory.py +311 -141
- deepdoctection/configs/conf_dd_one.yaml +100 -44
- deepdoctection/configs/profiles.jsonl +32 -0
- deepdoctection/dataflow/__init__.py +9 -6
- deepdoctection/dataflow/base.py +33 -15
- deepdoctection/dataflow/common.py +96 -75
- deepdoctection/dataflow/custom.py +36 -29
- deepdoctection/dataflow/custom_serialize.py +135 -91
- deepdoctection/dataflow/parallel_map.py +33 -31
- deepdoctection/dataflow/serialize.py +15 -10
- deepdoctection/dataflow/stats.py +41 -28
- deepdoctection/datapoint/__init__.py +4 -6
- deepdoctection/datapoint/annotation.py +104 -66
- deepdoctection/datapoint/box.py +190 -130
- deepdoctection/datapoint/convert.py +66 -39
- deepdoctection/datapoint/image.py +151 -95
- deepdoctection/datapoint/view.py +383 -236
- deepdoctection/datasets/__init__.py +2 -6
- deepdoctection/datasets/adapter.py +11 -11
- deepdoctection/datasets/base.py +118 -81
- deepdoctection/datasets/dataflow_builder.py +18 -12
- deepdoctection/datasets/info.py +76 -57
- deepdoctection/datasets/instances/__init__.py +6 -2
- deepdoctection/datasets/instances/doclaynet.py +17 -14
- deepdoctection/datasets/instances/fintabnet.py +16 -22
- deepdoctection/datasets/instances/funsd.py +11 -6
- deepdoctection/datasets/instances/iiitar13k.py +9 -9
- deepdoctection/datasets/instances/layouttest.py +9 -9
- deepdoctection/datasets/instances/publaynet.py +9 -9
- deepdoctection/datasets/instances/pubtables1m.py +13 -13
- deepdoctection/datasets/instances/pubtabnet.py +13 -15
- deepdoctection/datasets/instances/rvlcdip.py +8 -8
- deepdoctection/datasets/instances/xfund.py +11 -9
- deepdoctection/datasets/registry.py +18 -11
- deepdoctection/datasets/save.py +12 -11
- deepdoctection/eval/__init__.py +3 -2
- deepdoctection/eval/accmetric.py +72 -52
- deepdoctection/eval/base.py +29 -10
- deepdoctection/eval/cocometric.py +14 -12
- deepdoctection/eval/eval.py +56 -41
- deepdoctection/eval/registry.py +6 -3
- deepdoctection/eval/tedsmetric.py +24 -9
- deepdoctection/eval/tp_eval_callback.py +13 -12
- deepdoctection/extern/__init__.py +1 -1
- deepdoctection/extern/base.py +176 -97
- deepdoctection/extern/d2detect.py +127 -92
- deepdoctection/extern/deskew.py +19 -10
- deepdoctection/extern/doctrocr.py +157 -106
- deepdoctection/extern/fastlang.py +25 -17
- deepdoctection/extern/hfdetr.py +137 -60
- deepdoctection/extern/hflayoutlm.py +329 -248
- deepdoctection/extern/hflm.py +67 -33
- deepdoctection/extern/model.py +108 -762
- deepdoctection/extern/pdftext.py +37 -12
- deepdoctection/extern/pt/nms.py +15 -1
- deepdoctection/extern/pt/ptutils.py +13 -9
- deepdoctection/extern/tessocr.py +87 -54
- deepdoctection/extern/texocr.py +29 -14
- deepdoctection/extern/tp/tfutils.py +36 -8
- deepdoctection/extern/tp/tpcompat.py +54 -16
- deepdoctection/extern/tp/tpfrcnn/config/config.py +20 -4
- deepdoctection/extern/tpdetect.py +4 -2
- deepdoctection/mapper/__init__.py +1 -1
- deepdoctection/mapper/cats.py +117 -76
- deepdoctection/mapper/cocostruct.py +35 -17
- deepdoctection/mapper/d2struct.py +56 -29
- deepdoctection/mapper/hfstruct.py +32 -19
- deepdoctection/mapper/laylmstruct.py +221 -185
- deepdoctection/mapper/maputils.py +71 -35
- deepdoctection/mapper/match.py +76 -62
- deepdoctection/mapper/misc.py +68 -44
- deepdoctection/mapper/pascalstruct.py +13 -12
- deepdoctection/mapper/prodigystruct.py +33 -19
- deepdoctection/mapper/pubstruct.py +42 -32
- deepdoctection/mapper/tpstruct.py +39 -19
- deepdoctection/mapper/xfundstruct.py +20 -13
- deepdoctection/pipe/__init__.py +1 -2
- deepdoctection/pipe/anngen.py +104 -62
- deepdoctection/pipe/base.py +226 -107
- deepdoctection/pipe/common.py +206 -123
- deepdoctection/pipe/concurrency.py +74 -47
- deepdoctection/pipe/doctectionpipe.py +108 -47
- deepdoctection/pipe/language.py +41 -24
- deepdoctection/pipe/layout.py +45 -18
- deepdoctection/pipe/lm.py +146 -78
- deepdoctection/pipe/order.py +196 -113
- deepdoctection/pipe/refine.py +111 -63
- deepdoctection/pipe/registry.py +1 -1
- deepdoctection/pipe/segment.py +213 -142
- deepdoctection/pipe/sub_layout.py +76 -46
- deepdoctection/pipe/text.py +52 -33
- deepdoctection/pipe/transform.py +8 -6
- deepdoctection/train/d2_frcnn_train.py +87 -69
- deepdoctection/train/hf_detr_train.py +72 -40
- deepdoctection/train/hf_layoutlm_train.py +85 -46
- deepdoctection/train/tp_frcnn_train.py +56 -28
- deepdoctection/utils/concurrency.py +59 -16
- deepdoctection/utils/context.py +40 -19
- deepdoctection/utils/develop.py +25 -17
- deepdoctection/utils/env_info.py +85 -36
- deepdoctection/utils/error.py +16 -10
- deepdoctection/utils/file_utils.py +246 -62
- deepdoctection/utils/fs.py +162 -43
- deepdoctection/utils/identifier.py +29 -16
- deepdoctection/utils/logger.py +49 -32
- deepdoctection/utils/metacfg.py +83 -21
- deepdoctection/utils/pdf_utils.py +119 -62
- deepdoctection/utils/settings.py +24 -10
- deepdoctection/utils/tqdm.py +10 -5
- deepdoctection/utils/transform.py +182 -46
- deepdoctection/utils/utils.py +61 -28
- deepdoctection/utils/viz.py +150 -104
- deepdoctection-0.43.dist-info/METADATA +376 -0
- deepdoctection-0.43.dist-info/RECORD +149 -0
- {deepdoctection-0.42.0.dist-info → deepdoctection-0.43.dist-info}/WHEEL +1 -1
- deepdoctection/analyzer/_config.py +0 -146
- deepdoctection-0.42.0.dist-info/METADATA +0 -431
- deepdoctection-0.42.0.dist-info/RECORD +0 -148
- {deepdoctection-0.42.0.dist-info → deepdoctection-0.43.dist-info}/licenses/LICENSE +0 -0
- {deepdoctection-0.42.0.dist-info → deepdoctection-0.43.dist-info}/top_level.txt +0 -0
deepdoctection/extern/hfdetr.py
CHANGED
|
@@ -16,14 +16,15 @@
|
|
|
16
16
|
# limitations under the License.
|
|
17
17
|
|
|
18
18
|
"""
|
|
19
|
-
HF Detr
|
|
19
|
+
HF Detr and DeformableDetr models.
|
|
20
20
|
"""
|
|
21
|
+
|
|
21
22
|
from __future__ import annotations
|
|
22
23
|
|
|
23
24
|
import os
|
|
24
25
|
from abc import ABC
|
|
25
26
|
from pathlib import Path
|
|
26
|
-
from typing import Literal, Mapping, Optional, Sequence, Union
|
|
27
|
+
from typing import TYPE_CHECKING, Literal, Mapping, Optional, Sequence, Union
|
|
27
28
|
|
|
28
29
|
from lazy_imports import try_import
|
|
29
30
|
|
|
@@ -39,13 +40,17 @@ with try_import() as pt_import_guard:
|
|
|
39
40
|
|
|
40
41
|
with try_import() as tr_import_guard:
|
|
41
42
|
from transformers import ( # pylint: disable=W0611
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
43
|
+
DeformableDetrForObjectDetection,
|
|
44
|
+
DeformableDetrImageProcessorFast,
|
|
45
|
+
DetrImageProcessorFast,
|
|
45
46
|
PretrainedConfig,
|
|
46
47
|
TableTransformerForObjectDetection,
|
|
47
48
|
)
|
|
48
49
|
|
|
50
|
+
if TYPE_CHECKING:
|
|
51
|
+
EligibleDetrModel = Union[TableTransformerForObjectDetection, DeformableDetrForObjectDetection]
|
|
52
|
+
DetrImageProcessor = Union[DetrImageProcessorFast, DeformableDetrImageProcessorFast]
|
|
53
|
+
|
|
49
54
|
|
|
50
55
|
def _detr_post_processing(
|
|
51
56
|
boxes: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor, nms_thresh: float
|
|
@@ -55,24 +60,27 @@ def _detr_post_processing(
|
|
|
55
60
|
|
|
56
61
|
def detr_predict_image(
|
|
57
62
|
np_img: PixelValues,
|
|
58
|
-
predictor:
|
|
63
|
+
predictor: EligibleDetrModel,
|
|
59
64
|
feature_extractor: DetrImageProcessor,
|
|
60
65
|
device: torch.device,
|
|
61
66
|
threshold: float,
|
|
62
67
|
nms_threshold: float,
|
|
63
68
|
) -> list[DetectionResult]:
|
|
64
69
|
"""
|
|
65
|
-
Calling predictor. Before
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
70
|
+
Calling predictor. Before, tensors must be transferred to the device where the model is loaded.
|
|
71
|
+
|
|
72
|
+
Args:
|
|
73
|
+
np_img: Image as `np.array`.
|
|
74
|
+
predictor: `TableTransformerForObjectDetection` instance.
|
|
75
|
+
feature_extractor: Feature extractor instance.
|
|
76
|
+
device: Device where the model is loaded.
|
|
77
|
+
threshold: Will filter all predictions with confidence score less threshold.
|
|
78
|
+
nms_threshold: Threshold to perform NMS on prediction outputs.
|
|
79
|
+
Note:
|
|
80
|
+
NMS does not belong to canonical Detr inference processing.
|
|
81
|
+
|
|
82
|
+
Returns:
|
|
83
|
+
List of `DetectionResult` after running prediction.
|
|
76
84
|
"""
|
|
77
85
|
target_sizes = [np_img.shape[:2]]
|
|
78
86
|
inputs = feature_extractor(images=np_img, return_tensors="pt")
|
|
@@ -101,10 +109,10 @@ class HFDetrDerivedDetectorMixin(ObjectDetector, ABC):
|
|
|
101
109
|
|
|
102
110
|
def __init__(self, categories: Mapping[int, TypeOrStr], filter_categories: Optional[Sequence[TypeOrStr]] = None):
|
|
103
111
|
"""
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
112
|
+
Args:
|
|
113
|
+
categories: A dict with key (indices) and values (category names).
|
|
114
|
+
filter_categories: The model might return objects that are not supposed to be predicted and that should
|
|
115
|
+
be filtered. Pass a list of category names that must not be returned.
|
|
108
116
|
"""
|
|
109
117
|
self.categories = ModelCategories(init_categories=categories)
|
|
110
118
|
if filter_categories:
|
|
@@ -112,10 +120,13 @@ class HFDetrDerivedDetectorMixin(ObjectDetector, ABC):
|
|
|
112
120
|
|
|
113
121
|
def _map_category_names(self, detection_results: list[DetectionResult]) -> list[DetectionResult]:
|
|
114
122
|
"""
|
|
115
|
-
Populating category names to
|
|
123
|
+
Populating category names to `DetectionResult`. Will also filter categories.
|
|
124
|
+
|
|
125
|
+
Args:
|
|
126
|
+
detection_results: List of `DetectionResult`s.
|
|
116
127
|
|
|
117
|
-
:
|
|
118
|
-
|
|
128
|
+
Returns:
|
|
129
|
+
List of `DetectionResult`s with `class_name`.
|
|
119
130
|
"""
|
|
120
131
|
filtered_detection_result: list[DetectionResult] = []
|
|
121
132
|
shifted_categories = self.categories.shift_category_ids(shift_by=-1)
|
|
@@ -132,24 +143,39 @@ class HFDetrDerivedDetectorMixin(ObjectDetector, ABC):
|
|
|
132
143
|
|
|
133
144
|
@staticmethod
|
|
134
145
|
def get_name(path_weights: PathLikeOrStr) -> str:
|
|
135
|
-
"""
|
|
146
|
+
"""
|
|
147
|
+
Returns the name of the model.
|
|
148
|
+
|
|
149
|
+
Args:
|
|
150
|
+
path_weights: Path to the model weights.
|
|
151
|
+
|
|
152
|
+
Returns:
|
|
153
|
+
Model name string.
|
|
154
|
+
"""
|
|
136
155
|
return "Transformers_Tatr_" + "_".join(Path(path_weights).parts[-2:])
|
|
137
156
|
|
|
138
157
|
def get_category_names(self) -> tuple[ObjectTypes, ...]:
|
|
158
|
+
"""
|
|
159
|
+
Returns:
|
|
160
|
+
Tuple of `category_name`s.
|
|
161
|
+
"""
|
|
139
162
|
return self.categories.get_categories(as_dict=False)
|
|
140
163
|
|
|
141
164
|
|
|
142
165
|
class HFDetrDerivedDetector(HFDetrDerivedDetectorMixin):
|
|
143
166
|
"""
|
|
144
|
-
Model wrapper for TableTransformerForObjectDetection that again is based on
|
|
145
|
-
|
|
146
|
-
https://github.com/microsoft/table-transformer .
|
|
167
|
+
Model wrapper for `TableTransformerForObjectDetection` that again is based on
|
|
168
|
+
<https://github.com/microsoft/table-transformer>.
|
|
147
169
|
|
|
148
170
|
The wrapper can be used to load pre-trained models for table detection and table structure recognition. Running Detr
|
|
149
|
-
models trained from scratch on custom datasets is possible as well.
|
|
150
|
-
|
|
151
|
-
|
|
171
|
+
models trained from scratch on custom datasets is possible as well.
|
|
172
|
+
|
|
173
|
+
Note:
|
|
174
|
+
This wrapper will load `TableTransformerForObjectDetection` that is slightly different compared to
|
|
175
|
+
`DetrForObjectDetection` that can be found in the transformer library as well.
|
|
152
176
|
|
|
177
|
+
Example:
|
|
178
|
+
```python
|
|
153
179
|
config_path = ModelCatalog.
|
|
154
180
|
get_full_path_configs("microsoft/table-transformer-structure-recognition/pytorch_model.bin")
|
|
155
181
|
weights_path = ModelDownloadManager.
|
|
@@ -162,6 +188,7 @@ class HFDetrDerivedDetector(HFDetrDerivedDetectorMixin):
|
|
|
162
188
|
detr_predictor = HFDetrDerivedDetector(config_path,weights_path,feature_extractor_config_path,categories)
|
|
163
189
|
|
|
164
190
|
detection_result = detr_predictor.predict(bgr_image_np_array)
|
|
191
|
+
```
|
|
165
192
|
"""
|
|
166
193
|
|
|
167
194
|
def __init__(
|
|
@@ -175,13 +202,15 @@ class HFDetrDerivedDetector(HFDetrDerivedDetectorMixin):
|
|
|
175
202
|
):
|
|
176
203
|
"""
|
|
177
204
|
Set up the predictor.
|
|
178
|
-
|
|
179
|
-
:
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
205
|
+
|
|
206
|
+
Args:
|
|
207
|
+
path_config_json: The path to the json config.
|
|
208
|
+
path_weights: The path to the model checkpoint.
|
|
209
|
+
path_feature_extractor_config_json: The path to the feature extractor config.
|
|
210
|
+
categories: A dict with key (indices) and values (category names).
|
|
211
|
+
device: "cpu" or "cuda". If not specified will auto select depending on what is available.
|
|
212
|
+
filter_categories: The model might return objects that are not supposed to be predicted and that should
|
|
213
|
+
be filtered. Pass a list of category names that must not be returned.
|
|
185
214
|
"""
|
|
186
215
|
super().__init__(categories, filter_categories)
|
|
187
216
|
|
|
@@ -195,12 +224,21 @@ class HFDetrDerivedDetector(HFDetrDerivedDetectorMixin):
|
|
|
195
224
|
self.config = self.get_config(path_config_json)
|
|
196
225
|
|
|
197
226
|
self.hf_detr_predictor = self.get_model(self.path_weights, self.config)
|
|
198
|
-
self.feature_extractor = self.get_pre_processor(self.path_feature_extractor_config)
|
|
227
|
+
self.feature_extractor = self.get_pre_processor(self.path_feature_extractor_config, self.config)
|
|
199
228
|
|
|
200
229
|
self.device = get_torch_device(device)
|
|
201
230
|
self.hf_detr_predictor.to(self.device)
|
|
202
231
|
|
|
203
232
|
def predict(self, np_img: PixelValues) -> list[DetectionResult]:
|
|
233
|
+
"""
|
|
234
|
+
Predicts objects in an image.
|
|
235
|
+
|
|
236
|
+
Args:
|
|
237
|
+
np_img: Image as `np.array`.
|
|
238
|
+
|
|
239
|
+
Returns:
|
|
240
|
+
List of `DetectionResult`.
|
|
241
|
+
"""
|
|
204
242
|
results = detr_predict_image(
|
|
205
243
|
np_img,
|
|
206
244
|
self.hf_detr_predictor,
|
|
@@ -212,36 +250,71 @@ class HFDetrDerivedDetector(HFDetrDerivedDetectorMixin):
|
|
|
212
250
|
return self._map_category_names(results)
|
|
213
251
|
|
|
214
252
|
@staticmethod
|
|
215
|
-
def get_model(path_weights: PathLikeOrStr, config: PretrainedConfig) ->
|
|
253
|
+
def get_model(path_weights: PathLikeOrStr, config: PretrainedConfig) -> EligibleDetrModel:
|
|
216
254
|
"""
|
|
217
|
-
Builds the Detr model
|
|
255
|
+
Builds the Detr model.
|
|
218
256
|
|
|
219
|
-
:
|
|
220
|
-
|
|
221
|
-
|
|
257
|
+
Args:
|
|
258
|
+
path_weights: The path to the model checkpoint.
|
|
259
|
+
config: `PretrainedConfig` instance.
|
|
260
|
+
|
|
261
|
+
Returns:
|
|
262
|
+
`TableTransformerForObjectDetection` instance.
|
|
263
|
+
|
|
264
|
+
Raises:
|
|
265
|
+
ValueError: If model architecture is not eligible.
|
|
222
266
|
"""
|
|
223
|
-
|
|
224
|
-
|
|
267
|
+
if "TableTransformerForObjectDetection" in config.architectures:
|
|
268
|
+
return TableTransformerForObjectDetection.from_pretrained(
|
|
269
|
+
pretrained_model_name_or_path=os.fspath(path_weights), config=config
|
|
270
|
+
)
|
|
271
|
+
if "DeformableDetrForObjectDetection" in config.architectures:
|
|
272
|
+
return DeformableDetrForObjectDetection.from_pretrained(
|
|
273
|
+
pretrained_model_name_or_path=os.fspath(path_weights), config=config
|
|
274
|
+
)
|
|
275
|
+
raise ValueError(
|
|
276
|
+
f"Model architecture {config.architectures} not eligible. Please use either "
|
|
277
|
+
"TableTransformerForObjectDetection or DeformableDetrForObjectDetection."
|
|
225
278
|
)
|
|
226
279
|
|
|
227
280
|
@staticmethod
|
|
228
|
-
def get_pre_processor(path_feature_extractor_config: PathLikeOrStr) -> DetrImageProcessor:
|
|
281
|
+
def get_pre_processor(path_feature_extractor_config: PathLikeOrStr, config: PretrainedConfig) -> DetrImageProcessor:
|
|
229
282
|
"""
|
|
230
|
-
Builds the feature extractor
|
|
283
|
+
Builds the feature extractor.
|
|
284
|
+
|
|
285
|
+
Args:
|
|
286
|
+
path_feature_extractor_config: Path to feature extractor config.
|
|
287
|
+
config: Model configuration.
|
|
231
288
|
|
|
232
|
-
:
|
|
289
|
+
Returns:
|
|
290
|
+
`DetrImageProcessor` instance.
|
|
291
|
+
|
|
292
|
+
Raises:
|
|
293
|
+
ValueError: If model architecture is not eligible.
|
|
233
294
|
"""
|
|
234
|
-
|
|
235
|
-
|
|
295
|
+
if "TableTransformerForObjectDetection" in config.architectures:
|
|
296
|
+
return DetrImageProcessorFast.from_pretrained(
|
|
297
|
+
pretrained_model_name_or_path=os.fspath(path_feature_extractor_config),
|
|
298
|
+
)
|
|
299
|
+
if "DeformableDetrForObjectDetection" in config.architectures:
|
|
300
|
+
return DeformableDetrImageProcessorFast.from_pretrained(
|
|
301
|
+
pretrained_model_name_or_path=os.fspath(path_feature_extractor_config),
|
|
302
|
+
)
|
|
303
|
+
raise ValueError(
|
|
304
|
+
f"Model architecture {config.architectures} not eligible. Please use either "
|
|
305
|
+
"TableTransformerForObjectDetection or DeformableDetrForObjectDetection."
|
|
236
306
|
)
|
|
237
307
|
|
|
238
308
|
@staticmethod
|
|
239
309
|
def get_config(path_config: PathLikeOrStr) -> PretrainedConfig:
|
|
240
310
|
"""
|
|
241
|
-
Builds the config
|
|
311
|
+
Builds the config.
|
|
312
|
+
|
|
313
|
+
Args:
|
|
314
|
+
path_config: The path to the config.
|
|
242
315
|
|
|
243
|
-
:
|
|
244
|
-
|
|
316
|
+
Returns:
|
|
317
|
+
`PretrainedConfig` instance.
|
|
245
318
|
"""
|
|
246
319
|
config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=os.fspath(path_config))
|
|
247
320
|
config.use_timm_backbone = True
|
|
@@ -270,17 +343,21 @@ class HFDetrDerivedDetector(HFDetrDerivedDetectorMixin):
|
|
|
270
343
|
device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
|
|
271
344
|
) -> TableTransformerForObjectDetection:
|
|
272
345
|
"""
|
|
273
|
-
Get the wrapped model
|
|
346
|
+
Get the wrapped model.
|
|
347
|
+
|
|
348
|
+
Args:
|
|
349
|
+
path_config_json: The path to the json config.
|
|
350
|
+
path_weights: The path to the model checkpoint.
|
|
351
|
+
device: "cpu" or "cuda". If not specified will auto select depending on what is available.
|
|
274
352
|
|
|
275
|
-
:
|
|
276
|
-
|
|
277
|
-
:param device: "cpu" or "cuda". If not specified will auto select depending on what is available
|
|
278
|
-
:return: TableTransformerForObjectDetection instance
|
|
353
|
+
Returns:
|
|
354
|
+
`TableTransformerForObjectDetection` instance.
|
|
279
355
|
"""
|
|
280
356
|
config = HFDetrDerivedDetector.get_config(path_config_json)
|
|
281
357
|
hf_detr_predictor = HFDetrDerivedDetector.get_model(path_weights, config)
|
|
282
358
|
device = get_torch_device(device)
|
|
283
|
-
|
|
359
|
+
hf_detr_predictor.to(device)
|
|
360
|
+
return hf_detr_predictor
|
|
284
361
|
|
|
285
362
|
def clear_model(self) -> None:
|
|
286
363
|
self.hf_detr_predictor = None
|