deepdoctection-0.46-py3-none-any.whl → deepdoctection-0.46.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of deepdoctection might be problematic. Click here for more details.
- deepdoctection/__init__.py +1 -1
- deepdoctection/analyzer/factory.py +711 -205
- deepdoctection/utils/viz.py +34 -130
- {deepdoctection-0.46.dist-info → deepdoctection-0.46.2.dist-info}/METADATA +1 -1
- {deepdoctection-0.46.dist-info → deepdoctection-0.46.2.dist-info}/RECORD +8 -8
- {deepdoctection-0.46.dist-info → deepdoctection-0.46.2.dist-info}/WHEEL +0 -0
- {deepdoctection-0.46.dist-info → deepdoctection-0.46.2.dist-info}/licenses/LICENSE +0 -0
- {deepdoctection-0.46.dist-info → deepdoctection-0.46.2.dist-info}/top_level.txt +0 -0
|
@@ -22,7 +22,7 @@
|
|
|
22
22
|
from __future__ import annotations
|
|
23
23
|
|
|
24
24
|
from os import environ
|
|
25
|
-
from typing import TYPE_CHECKING, Literal, Union
|
|
25
|
+
from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional, Sequence, Union
|
|
26
26
|
|
|
27
27
|
from lazy_imports import try_import
|
|
28
28
|
|
|
@@ -42,7 +42,7 @@ from ..extern.hflayoutlm import (
|
|
|
42
42
|
get_tokenizer_from_model_class,
|
|
43
43
|
)
|
|
44
44
|
from ..extern.hflm import HFLmSequenceClassifier, HFLmTokenClassifier
|
|
45
|
-
from ..extern.model import ModelCatalog, ModelDownloadManager
|
|
45
|
+
from ..extern.model import ModelCatalog, ModelDownloadManager, ModelProfile
|
|
46
46
|
from ..extern.pdftext import PdfPlumberTextDetector
|
|
47
47
|
from ..extern.tessocr import TesseractOcrDetector, TesseractRotationTransformer
|
|
48
48
|
from ..extern.texocr import TextractOcrDetector
|
|
@@ -68,7 +68,7 @@ from ..pipe.transform import SimpleTransformService
|
|
|
68
68
|
from ..utils.error import DependencyError
|
|
69
69
|
from ..utils.fs import get_configs_dir_path
|
|
70
70
|
from ..utils.metacfg import AttrDict
|
|
71
|
-
from ..utils.settings import CellType, LayoutType, Relationships
|
|
71
|
+
from ..utils.settings import CellType, LayoutType, ObjectTypes, Relationships
|
|
72
72
|
from ..utils.transform import PadTransform
|
|
73
73
|
|
|
74
74
|
with try_import() as image_guard:
|
|
@@ -104,20 +104,14 @@ class ServiceFactory:
|
|
|
104
104
|
"""
|
|
105
105
|
|
|
106
106
|
@staticmethod
|
|
107
|
-
def
|
|
108
|
-
config: AttrDict,
|
|
109
|
-
mode: str,
|
|
110
|
-
) -> Union[D2FrcnnDetector, TPFrcnnDetector, HFDetrDerivedDetector, D2FrcnnTracingDetector]:
|
|
107
|
+
def _get_layout_detector_kwargs_from_config(config: AttrDict, mode: str) -> dict[str, Any]:
|
|
111
108
|
"""
|
|
112
|
-
|
|
113
|
-
the config.
|
|
109
|
+
Extracting layout detector kwargs from config.
|
|
114
110
|
|
|
115
111
|
Args:
|
|
116
112
|
config: Configuration object.
|
|
117
113
|
mode: Either `LAYOUT`, `CELL`, or `ITEM`.
|
|
118
114
|
"""
|
|
119
|
-
if config.LIB is None:
|
|
120
|
-
raise DependencyError("At least one of the env variables DD_USE_TF or DD_USE_TORCH must be set.")
|
|
121
115
|
|
|
122
116
|
weights = (
|
|
123
117
|
getattr(config.TF, mode).WEIGHTS
|
|
@@ -128,16 +122,52 @@ class ServiceFactory:
|
|
|
128
122
|
else getattr(config.PT, mode).WEIGHTS_TS
|
|
129
123
|
)
|
|
130
124
|
)
|
|
125
|
+
|
|
131
126
|
filter_categories = (
|
|
132
127
|
getattr(getattr(config.TF, mode), "FILTER")
|
|
133
128
|
if config.LIB == "TF"
|
|
134
129
|
else getattr(getattr(config.PT, mode), "FILTER")
|
|
135
130
|
)
|
|
136
|
-
|
|
137
|
-
weights_path = ModelDownloadManager.maybe_download_weights_and_configs(weights)
|
|
131
|
+
|
|
138
132
|
profile = ModelCatalog.get_profile(weights)
|
|
133
|
+
|
|
139
134
|
if config.LIB == "PT" and profile.padding is not None:
|
|
140
135
|
getattr(config.PT, mode).PADDING = profile.padding
|
|
136
|
+
|
|
137
|
+
device = config.DEVICE
|
|
138
|
+
|
|
139
|
+
return {
|
|
140
|
+
"weights": weights,
|
|
141
|
+
"filter_categories": filter_categories,
|
|
142
|
+
"profile": profile,
|
|
143
|
+
"device": device,
|
|
144
|
+
"lib": config.LIB,
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
@staticmethod
|
|
148
|
+
def _build_layout_detector(
|
|
149
|
+
weights: str,
|
|
150
|
+
filter_categories: list[str],
|
|
151
|
+
profile: ModelProfile,
|
|
152
|
+
device: Literal["cpu", "cuda"],
|
|
153
|
+
lib: Literal["TF", "PT", None],
|
|
154
|
+
) -> Union[D2FrcnnDetector, TPFrcnnDetector, HFDetrDerivedDetector, D2FrcnnTracingDetector]:
|
|
155
|
+
"""
|
|
156
|
+
Building a D2-Detector, a TP-Detector as Detr-Detector or a D2-Torch Tracing Detector according to
|
|
157
|
+
the config.
|
|
158
|
+
|
|
159
|
+
Args:
|
|
160
|
+
weights: Weights for the layout detector.
|
|
161
|
+
filter_categories: Categories to filter during detection.
|
|
162
|
+
profile: Model profile for the layout detector.
|
|
163
|
+
device: Device to use for computation.
|
|
164
|
+
lib: Deep learning library to use.
|
|
165
|
+
"""
|
|
166
|
+
if lib is None:
|
|
167
|
+
raise DependencyError("At least one of the env variables DD_USE_TF or DD_USE_TORCH must be set.")
|
|
168
|
+
|
|
169
|
+
config_path = ModelCatalog.get_full_path_configs(weights)
|
|
170
|
+
weights_path = ModelDownloadManager.maybe_download_weights_and_configs(weights)
|
|
141
171
|
categories = profile.categories if profile.categories is not None else {}
|
|
142
172
|
|
|
143
173
|
if profile.model_wrapper in ("TPFrcnnDetector",):
|
|
@@ -152,7 +182,7 @@ class ServiceFactory:
|
|
|
152
182
|
path_yaml=config_path,
|
|
153
183
|
path_weights=weights_path,
|
|
154
184
|
categories=categories,
|
|
155
|
-
device=
|
|
185
|
+
device=device,
|
|
156
186
|
filter_categories=filter_categories,
|
|
157
187
|
)
|
|
158
188
|
if profile.model_wrapper in ("D2FrcnnTracingDetector",):
|
|
@@ -169,7 +199,7 @@ class ServiceFactory:
|
|
|
169
199
|
path_weights=weights_path,
|
|
170
200
|
path_feature_extractor_config_json=preprocessor_config,
|
|
171
201
|
categories=categories,
|
|
172
|
-
device=
|
|
202
|
+
device=device,
|
|
173
203
|
filter_categories=filter_categories,
|
|
174
204
|
)
|
|
175
205
|
raise TypeError(
|
|
@@ -188,7 +218,8 @@ class ServiceFactory:
|
|
|
188
218
|
config: Configuration object.
|
|
189
219
|
mode: Either `LAYOUT`, `CELL`, or `ITEM`.
|
|
190
220
|
"""
|
|
191
|
-
|
|
221
|
+
layout_detector_kwargs = ServiceFactory._get_layout_detector_kwargs_from_config(config, mode)
|
|
222
|
+
return ServiceFactory._build_layout_detector(**layout_detector_kwargs)
|
|
192
223
|
|
|
193
224
|
@staticmethod
|
|
194
225
|
def _build_rotation_detector(rotator_name: Literal["tesseract", "doctr"]) -> RotationTransformer:
|
|
@@ -245,24 +276,36 @@ class ServiceFactory:
|
|
|
245
276
|
return ServiceFactory._build_transform_service(transform_predictor)
|
|
246
277
|
|
|
247
278
|
@staticmethod
|
|
248
|
-
def
|
|
279
|
+
def _get_padder_kwargs_from_config(config: AttrDict, mode: str) -> dict[str, Any]:
|
|
249
280
|
"""
|
|
250
|
-
|
|
281
|
+
Extracting padder kwargs from config.
|
|
251
282
|
|
|
252
283
|
Args:
|
|
253
284
|
config: Configuration object.
|
|
254
285
|
mode: Either `LAYOUT`, `CELL`, or `ITEM`.
|
|
286
|
+
"""
|
|
287
|
+
return {
|
|
288
|
+
"top": getattr(config.PT, mode).PAD.TOP,
|
|
289
|
+
"right": getattr(config.PT, mode).PAD.RIGHT,
|
|
290
|
+
"bottom": getattr(config.PT, mode).PAD.BOTTOM,
|
|
291
|
+
"left": getattr(config.PT, mode).PAD.LEFT,
|
|
292
|
+
}
|
|
293
|
+
|
|
294
|
+
@staticmethod
|
|
295
|
+
def _build_padder(top: int, right: int, bottom: int, left: int) -> PadTransform:
|
|
296
|
+
"""
|
|
297
|
+
Building a padder according to the config.
|
|
298
|
+
|
|
299
|
+
Args:
|
|
300
|
+
top: Padding on the top side.
|
|
301
|
+
right: Padding on the right side.
|
|
302
|
+
bottom: Padding on the bottom side.
|
|
303
|
+
left: Padding on the left side.
|
|
255
304
|
|
|
256
305
|
Returns:
|
|
257
306
|
PadTransform: `PadTransform` instance.
|
|
258
307
|
"""
|
|
259
|
-
top, right, bottom, left
|
|
260
|
-
getattr(config.PT, mode).PAD.TOP,
|
|
261
|
-
getattr(config.PT, mode).PAD.RIGHT,
|
|
262
|
-
getattr(config.PT, mode).PAD.BOTTOM,
|
|
263
|
-
getattr(config.PT, mode).PAD.LEFT,
|
|
264
|
-
)
|
|
265
|
-
return PadTransform(pad_top=top, pad_right=right, pad_bottom=bottom, pad_left=left) #
|
|
308
|
+
return PadTransform(pad_top=top, pad_right=right, pad_bottom=bottom, pad_left=left)
|
|
266
309
|
|
|
267
310
|
@staticmethod
|
|
268
311
|
def build_padder(config: AttrDict, mode: str) -> PadTransform:
|
|
@@ -276,24 +319,37 @@ class ServiceFactory:
|
|
|
276
319
|
Returns:
|
|
277
320
|
PadTransform: `PadTransform` instance.
|
|
278
321
|
"""
|
|
279
|
-
|
|
322
|
+
padder_kwargs = ServiceFactory._get_padder_kwargs_from_config(config, mode)
|
|
323
|
+
return ServiceFactory._build_padder(**padder_kwargs)
|
|
280
324
|
|
|
281
325
|
@staticmethod
|
|
282
|
-
def
|
|
326
|
+
def _get_layout_service_kwargs_from_config(config: AttrDict, mode: str) -> dict[str, Any]:
|
|
283
327
|
"""
|
|
284
|
-
|
|
328
|
+
Extracting layout service kwargs from config.
|
|
285
329
|
|
|
286
330
|
Args:
|
|
287
331
|
config: Configuration object.
|
|
288
|
-
detector: Will be passed to the `ImageLayoutService`.
|
|
289
332
|
mode: Either `LAYOUT`, `CELL`, or `ITEM`.
|
|
290
|
-
|
|
291
|
-
Returns:
|
|
292
|
-
ImageLayoutService: `ImageLayoutService` instance.
|
|
293
333
|
"""
|
|
294
334
|
padder = None
|
|
295
335
|
if getattr(config.PT, mode).PADDING:
|
|
296
336
|
padder = ServiceFactory.build_padder(config, mode=mode)
|
|
337
|
+
return {
|
|
338
|
+
"padder": padder,
|
|
339
|
+
}
|
|
340
|
+
|
|
341
|
+
@staticmethod
|
|
342
|
+
def _build_layout_service(detector: ObjectDetector, padder: PadTransform) -> ImageLayoutService:
|
|
343
|
+
"""
|
|
344
|
+
Building a layout service with a given detector.
|
|
345
|
+
|
|
346
|
+
Args:
|
|
347
|
+
detector: Will be passed to the `ImageLayoutService`.
|
|
348
|
+
padder: PadTransform instance.
|
|
349
|
+
|
|
350
|
+
Returns:
|
|
351
|
+
ImageLayoutService: `ImageLayoutService` instance.
|
|
352
|
+
"""
|
|
297
353
|
return ImageLayoutService(layout_detector=detector, to_image=True, crop_image=True, padder=padder)
|
|
298
354
|
|
|
299
355
|
@staticmethod
|
|
@@ -309,27 +365,51 @@ class ServiceFactory:
|
|
|
309
365
|
Returns:
|
|
310
366
|
ImageLayoutService: `ImageLayoutService` instance.
|
|
311
367
|
"""
|
|
312
|
-
|
|
368
|
+
layout_service_kwargs = ServiceFactory._get_layout_service_kwargs_from_config(config, mode)
|
|
369
|
+
return ServiceFactory._build_layout_service(detector, **layout_service_kwargs)
|
|
313
370
|
|
|
314
371
|
@staticmethod
|
|
315
|
-
def
|
|
372
|
+
def _get_layout_nms_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
|
|
316
373
|
"""
|
|
317
|
-
|
|
374
|
+
Extracting layout NMS service kwargs from config.
|
|
318
375
|
|
|
319
376
|
Args:
|
|
320
377
|
config: Configuration object.
|
|
321
|
-
|
|
322
|
-
Returns:
|
|
323
|
-
AnnotationNmsService: NMS service instance.
|
|
324
378
|
"""
|
|
379
|
+
|
|
325
380
|
if not isinstance(config.LAYOUT_NMS_PAIRS.COMBINATIONS, list) and not isinstance(
|
|
326
381
|
config.LAYOUT_NMS_PAIRS.COMBINATIONS[0], list
|
|
327
382
|
):
|
|
328
383
|
raise ValueError("LAYOUT_NMS_PAIRS must be a list of lists")
|
|
384
|
+
|
|
385
|
+
return {
|
|
386
|
+
"nms_pairs": config.LAYOUT_NMS_PAIRS.COMBINATIONS,
|
|
387
|
+
"thresholds": config.LAYOUT_NMS_PAIRS.THRESHOLDS,
|
|
388
|
+
"priority": config.LAYOUT_NMS_PAIRS.PRIORITY,
|
|
389
|
+
}
|
|
390
|
+
|
|
391
|
+
@staticmethod
|
|
392
|
+
def _build_layout_nms_service(
|
|
393
|
+
nms_pairs: Sequence[Sequence[Union[ObjectTypes, str]]],
|
|
394
|
+
thresholds: Union[float, Sequence[float]],
|
|
395
|
+
priority: Sequence[Union[ObjectTypes, str, None]],
|
|
396
|
+
) -> AnnotationNmsService:
|
|
397
|
+
"""
|
|
398
|
+
Building a NMS service for layout annotations.
|
|
399
|
+
|
|
400
|
+
Args:
|
|
401
|
+
nms_pairs: Pairs of categories for NMS.
|
|
402
|
+
thresholds: NMS thresholds.
|
|
403
|
+
priority: Priority of categories.
|
|
404
|
+
|
|
405
|
+
Returns:
|
|
406
|
+
AnnotationNmsService: NMS service instance.
|
|
407
|
+
"""
|
|
408
|
+
|
|
329
409
|
return AnnotationNmsService(
|
|
330
|
-
nms_pairs=
|
|
331
|
-
thresholds=
|
|
332
|
-
priority=
|
|
410
|
+
nms_pairs=nms_pairs,
|
|
411
|
+
thresholds=thresholds,
|
|
412
|
+
priority=priority,
|
|
333
413
|
)
|
|
334
414
|
|
|
335
415
|
@staticmethod
|
|
@@ -343,29 +423,41 @@ class ServiceFactory:
|
|
|
343
423
|
Returns:
|
|
344
424
|
AnnotationNmsService: NMS service instance.
|
|
345
425
|
"""
|
|
346
|
-
|
|
426
|
+
nms_service_kwargs = ServiceFactory._get_layout_nms_service_kwargs_from_config(config)
|
|
427
|
+
return ServiceFactory._build_layout_nms_service(**nms_service_kwargs)
|
|
347
428
|
|
|
348
429
|
@staticmethod
|
|
349
|
-
def
|
|
430
|
+
def _get_sub_image_layout_service_kwargs_from_config(detector: ObjectDetector, mode: str) -> dict[str, Any]:
|
|
350
431
|
"""
|
|
351
|
-
|
|
432
|
+
Extracting sub image service kwargs from config.
|
|
352
433
|
|
|
353
434
|
Args:
|
|
354
|
-
config: Configuration object.
|
|
355
|
-
detector: Will be passed to the `SubImageLayoutService`.
|
|
356
435
|
mode: Either `LAYOUT`, `CELL`, or `ITEM`.
|
|
357
|
-
|
|
358
|
-
Returns:
|
|
359
|
-
SubImageLayoutService: `SubImageLayoutService` instance.
|
|
360
436
|
"""
|
|
437
|
+
|
|
361
438
|
exclude_category_names = []
|
|
362
|
-
padder = None
|
|
363
439
|
if mode == "ITEM":
|
|
364
440
|
if detector.__class__.__name__ in ("HFDetrDerivedDetector",):
|
|
365
441
|
exclude_category_names.extend(
|
|
366
442
|
[LayoutType.TABLE, CellType.COLUMN_HEADER, CellType.PROJECTED_ROW_HEADER, CellType.SPANNING]
|
|
367
443
|
)
|
|
368
|
-
|
|
444
|
+
return {"exclude_category_names": exclude_category_names}
|
|
445
|
+
|
|
446
|
+
@staticmethod
|
|
447
|
+
def _build_sub_image_service(
|
|
448
|
+
detector: ObjectDetector, padder: Optional[PadTransform], exclude_category_names: list[ObjectTypes]
|
|
449
|
+
) -> SubImageLayoutService:
|
|
450
|
+
"""
|
|
451
|
+
Building a sub image layout service with a given detector.
|
|
452
|
+
|
|
453
|
+
Args:
|
|
454
|
+
detector: Will be passed to the `SubImageLayoutService`.
|
|
455
|
+
padder: PadTransform instance.
|
|
456
|
+
exclude_category_names: Category names to exclude during detection.
|
|
457
|
+
|
|
458
|
+
Returns:
|
|
459
|
+
SubImageLayoutService: `SubImageLayoutService` instance.
|
|
460
|
+
"""
|
|
369
461
|
detect_result_generator = DetectResultGenerator(
|
|
370
462
|
categories_name_as_key=detector.categories.get_categories(as_dict=True, name_as_key=True),
|
|
371
463
|
exclude_category_names=exclude_category_names,
|
|
@@ -390,26 +482,37 @@ class ServiceFactory:
|
|
|
390
482
|
Returns:
|
|
391
483
|
SubImageLayoutService: `SubImageLayoutService` instance.
|
|
392
484
|
"""
|
|
393
|
-
|
|
485
|
+
padder = None
|
|
486
|
+
if mode == "ITEM":
|
|
487
|
+
padder = ServiceFactory.build_padder(config, mode)
|
|
488
|
+
sub_image_layout_service_kwargs = ServiceFactory._get_sub_image_layout_service_kwargs_from_config(
|
|
489
|
+
detector, mode
|
|
490
|
+
)
|
|
491
|
+
return ServiceFactory._build_sub_image_service(detector, padder, **sub_image_layout_service_kwargs)
|
|
394
492
|
|
|
395
493
|
@staticmethod
|
|
396
|
-
def
|
|
494
|
+
def _get_ocr_detector_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
|
|
397
495
|
"""
|
|
398
|
-
|
|
496
|
+
Extracting OCR detector kwargs from config.
|
|
399
497
|
|
|
400
498
|
Args:
|
|
401
499
|
config: Configuration object.
|
|
402
|
-
|
|
403
|
-
Returns:
|
|
404
|
-
Union[TesseractOcrDetector, DoctrTextRecognizer, TextractOcrDetector]: OCR detector instance.
|
|
405
500
|
"""
|
|
501
|
+
ocr_config_path = None
|
|
502
|
+
weights = None
|
|
503
|
+
languages = None
|
|
504
|
+
credentials_kwargs = None
|
|
505
|
+
use_tesseract = False
|
|
506
|
+
use_doctr = False
|
|
507
|
+
use_textract = False
|
|
508
|
+
|
|
406
509
|
if config.OCR.USE_TESSERACT:
|
|
510
|
+
use_tesseract = True
|
|
407
511
|
ocr_config_path = get_configs_dir_path() / config.OCR.CONFIG.TESSERACT
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
config_overwrite=[f"LANGUAGES={config.LANGUAGE}"] if config.LANGUAGE is not None else None,
|
|
411
|
-
)
|
|
512
|
+
languages = [f"LANGUAGES={config.LANGUAGE}"] if config.LANGUAGE is not None else None
|
|
513
|
+
|
|
412
514
|
if config.OCR.USE_DOCTR:
|
|
515
|
+
use_doctr = True
|
|
413
516
|
if config.LIB is None:
|
|
414
517
|
raise DependencyError("At least one of the env variables DD_USE_TF or DD_USE_TORCH must be set.")
|
|
415
518
|
weights = (
|
|
@@ -417,6 +520,63 @@ class ServiceFactory:
|
|
|
417
520
|
if config.LIB == "TF"
|
|
418
521
|
else (config.OCR.WEIGHTS.DOCTR_RECOGNITION.PT)
|
|
419
522
|
)
|
|
523
|
+
if config.OCR.USE_TEXTRACT:
|
|
524
|
+
use_textract = True
|
|
525
|
+
credentials_kwargs = {
|
|
526
|
+
"aws_access_key_id": environ.get("AWS_ACCESS_KEY", None),
|
|
527
|
+
"aws_secret_access_key": environ.get("AWS_SECRET_KEY", None),
|
|
528
|
+
"config": Config(region_name=environ.get("AWS_REGION", None)),
|
|
529
|
+
}
|
|
530
|
+
|
|
531
|
+
return {
|
|
532
|
+
"use_tesseract": use_tesseract,
|
|
533
|
+
"use_doctr": use_doctr,
|
|
534
|
+
"use_textract": use_textract,
|
|
535
|
+
"ocr_config_path": ocr_config_path,
|
|
536
|
+
"languages": languages,
|
|
537
|
+
"weights": weights,
|
|
538
|
+
"credentials_kwargs": credentials_kwargs,
|
|
539
|
+
"lib": config.LIB,
|
|
540
|
+
"device": config.DEVICE,
|
|
541
|
+
}
|
|
542
|
+
|
|
543
|
+
@staticmethod
|
|
544
|
+
def _build_ocr_detector(
|
|
545
|
+
use_tesseract: bool,
|
|
546
|
+
use_doctr: bool,
|
|
547
|
+
use_textract: bool,
|
|
548
|
+
ocr_config_path: str,
|
|
549
|
+
languages: Union[list[str], None],
|
|
550
|
+
weights: str,
|
|
551
|
+
credentials_kwargs: dict[str, Any],
|
|
552
|
+
lib: Literal["TF", "PT", None],
|
|
553
|
+
device: Literal["cuda", "cpu"],
|
|
554
|
+
) -> Union[TesseractOcrDetector, DoctrTextRecognizer, TextractOcrDetector]:
|
|
555
|
+
"""
|
|
556
|
+
Building OCR predictor.
|
|
557
|
+
|
|
558
|
+
Args:
|
|
559
|
+
use_tesseract: Whether to use Tesseract OCR.
|
|
560
|
+
use_doctr: Whether to use Doctr OCR.
|
|
561
|
+
use_textract: Whether to use Textract OCR.
|
|
562
|
+
ocr_config_path: Path to OCR config.
|
|
563
|
+
languages: Languages for OCR.
|
|
564
|
+
weights: Weights for Doctr OCR.
|
|
565
|
+
credentials_kwargs: Credentials for Textract OCR.
|
|
566
|
+
lib: Deep learning library to use.
|
|
567
|
+
device: Device to use for computation.
|
|
568
|
+
|
|
569
|
+
Returns:
|
|
570
|
+
Union[TesseractOcrDetector, DoctrTextRecognizer, TextractOcrDetector]: OCR detector instance.
|
|
571
|
+
"""
|
|
572
|
+
if use_tesseract:
|
|
573
|
+
return TesseractOcrDetector(
|
|
574
|
+
ocr_config_path,
|
|
575
|
+
config_overwrite=languages,
|
|
576
|
+
)
|
|
577
|
+
if use_doctr:
|
|
578
|
+
if lib is None:
|
|
579
|
+
raise DependencyError("At least one of the env variables DD_USE_TF or DD_USE_TORCH must be set.")
|
|
420
580
|
weights_path = ModelDownloadManager.maybe_download_weights_and_configs(weights)
|
|
421
581
|
profile = ModelCatalog.get_profile(weights)
|
|
422
582
|
# get_full_path_configs will complete the path even if the model is not registered
|
|
@@ -426,16 +586,11 @@ class ServiceFactory:
|
|
|
426
586
|
return DoctrTextRecognizer(
|
|
427
587
|
architecture=profile.architecture,
|
|
428
588
|
path_weights=weights_path,
|
|
429
|
-
device=
|
|
430
|
-
lib=
|
|
589
|
+
device=device,
|
|
590
|
+
lib=lib,
|
|
431
591
|
path_config_json=config_path,
|
|
432
592
|
)
|
|
433
|
-
if
|
|
434
|
-
credentials_kwargs = {
|
|
435
|
-
"aws_access_key_id": environ.get("AWS_ACCESS_KEY", None),
|
|
436
|
-
"aws_secret_access_key": environ.get("AWS_SECRET_KEY", None),
|
|
437
|
-
"config": Config(region_name=environ.get("AWS_REGION", None)),
|
|
438
|
-
}
|
|
593
|
+
if use_textract:
|
|
439
594
|
return TextractOcrDetector(**credentials_kwargs)
|
|
440
595
|
raise ValueError("You have set USE_OCR=True but any of USE_TESSERACT, USE_DOCTR, USE_TEXTRACT is set to False")
|
|
441
596
|
|
|
@@ -450,36 +605,114 @@ class ServiceFactory:
|
|
|
450
605
|
Returns:
|
|
451
606
|
Union[TesseractOcrDetector, DoctrTextRecognizer, TextractOcrDetector]: OCR detector instance.
|
|
452
607
|
"""
|
|
453
|
-
|
|
608
|
+
ocr_detector_kwargs = ServiceFactory._get_ocr_detector_kwargs_from_config(config)
|
|
609
|
+
return ServiceFactory._build_ocr_detector(**ocr_detector_kwargs)
|
|
454
610
|
|
|
455
611
|
@staticmethod
|
|
456
|
-
def
|
|
612
|
+
def _get_doctr_word_detector_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
|
|
457
613
|
"""
|
|
458
|
-
|
|
614
|
+
Extracting Doctr word detector kwargs from config.
|
|
459
615
|
|
|
460
616
|
Args:
|
|
461
617
|
config: Configuration object.
|
|
618
|
+
"""
|
|
619
|
+
weights = config.OCR.WEIGHTS.DOCTR_WORD.TF if config.LIB == "TF" else config.OCR.WEIGHTS.DOCTR_WORD.PT
|
|
620
|
+
profile = ModelCatalog.get_profile(weights)
|
|
621
|
+
return {
|
|
622
|
+
"weights": weights,
|
|
623
|
+
"profile": profile,
|
|
624
|
+
"device": config.DEVICE,
|
|
625
|
+
"lib": config.LIB,
|
|
626
|
+
}
|
|
627
|
+
|
|
628
|
+
@staticmethod
|
|
629
|
+
def _build_doctr_word_detector(
|
|
630
|
+
weights: str, profile: ModelProfile, device: Literal["cuda", "cpu"], lib: Literal["PT", "TF"]
|
|
631
|
+
) -> DoctrTextlineDetector:
|
|
632
|
+
"""
|
|
633
|
+
Building `DoctrTextlineDetector` instance.
|
|
634
|
+
|
|
635
|
+
Args:
|
|
636
|
+
weights: Weights for Doctr word detector.
|
|
637
|
+
profile: Model profile for Doctr word detector.
|
|
638
|
+
device: Device to use for computation.
|
|
639
|
+
lib: Deep learning library to use.
|
|
462
640
|
|
|
463
641
|
Returns:
|
|
464
642
|
DoctrTextlineDetector: Textline detector instance.
|
|
465
643
|
"""
|
|
466
|
-
if
|
|
644
|
+
if lib is None:
|
|
467
645
|
raise DependencyError("At least one of the env variables DD_USE_TF or DD_USE_TORCH must be set.")
|
|
468
|
-
weights = config.OCR.WEIGHTS.DOCTR_WORD.TF if config.LIB == "TF" else config.OCR.WEIGHTS.DOCTR_WORD.PT
|
|
469
646
|
weights_path = ModelDownloadManager.maybe_download_weights_and_configs(weights)
|
|
470
|
-
profile = ModelCatalog.get_profile(weights)
|
|
471
647
|
if profile.architecture is None:
|
|
472
648
|
raise ValueError("model profile.architecture must be specified")
|
|
473
649
|
if profile.categories is None:
|
|
474
650
|
raise ValueError("model profile.categories must be specified")
|
|
475
|
-
return DoctrTextlineDetector(
|
|
476
|
-
|
|
477
|
-
|
|
651
|
+
return DoctrTextlineDetector(profile.architecture, weights_path, profile.categories, device, lib=lib)
|
|
652
|
+
|
|
653
|
+
@staticmethod
|
|
654
|
+
def build_doctr_word_detector(config: AttrDict) -> DoctrTextlineDetector:
|
|
655
|
+
"""
|
|
656
|
+
Building `DoctrTextlineDetector` instance.
|
|
657
|
+
|
|
658
|
+
Args:
|
|
659
|
+
config: Configuration object.
|
|
660
|
+
|
|
661
|
+
Returns:
|
|
662
|
+
DoctrTextlineDetector: Textline detector instance.
|
|
663
|
+
"""
|
|
664
|
+
doctr_word_detector_kwargs = ServiceFactory._get_doctr_word_detector_kwargs_from_config(config)
|
|
665
|
+
return ServiceFactory._build_doctr_word_detector(**doctr_word_detector_kwargs)
|
|
666
|
+
|
|
667
|
+
@staticmethod
|
|
668
|
+
def _get_table_segmentation_service_kwargs_from_config(config: AttrDict, detector_name: str) -> dict[str, Any]:
|
|
669
|
+
"""
|
|
670
|
+
Extracting table segmentation service kwargs from config.
|
|
671
|
+
|
|
672
|
+
Args:
|
|
673
|
+
config: Configuration object.
|
|
674
|
+
detector_name: An instance name of `ObjectDetector`.
|
|
675
|
+
"""
|
|
676
|
+
return {
|
|
677
|
+
"segment_rule": config.SEGMENTATION.ASSIGNMENT_RULE,
|
|
678
|
+
"threshold_rows": config.SEGMENTATION.THRESHOLD_ROWS,
|
|
679
|
+
"threshold_cols": config.SEGMENTATION.THRESHOLD_COLS,
|
|
680
|
+
"tile_table_with_items": config.SEGMENTATION.FULL_TABLE_TILING,
|
|
681
|
+
"remove_iou_threshold_rows": config.SEGMENTATION.REMOVE_IOU_THRESHOLD_ROWS,
|
|
682
|
+
"remove_iou_threshold_cols": config.SEGMENTATION.REMOVE_IOU_THRESHOLD_COLS,
|
|
683
|
+
"table_name": config.SEGMENTATION.TABLE_NAME,
|
|
684
|
+
"cell_names": config.SEGMENTATION.PUBTABLES_CELL_NAMES
|
|
685
|
+
if detector_name in ("HFDetrDerivedDetector",)
|
|
686
|
+
else config.SEGMENTATION.CELL_NAMES,
|
|
687
|
+
"spanning_cell_names": config.SEGMENTATION.PUBTABLES_SPANNING_CELL_NAMES,
|
|
688
|
+
"item_names": config.SEGMENTATION.PUBTABLES_ITEM_NAMES
|
|
689
|
+
if detector_name in ("HFDetrDerivedDetector",)
|
|
690
|
+
else config.SEGMENTATION.ITEM_NAMES,
|
|
691
|
+
"sub_item_names": config.SEGMENTATION.PUBTABLES_SUB_ITEM_NAMES
|
|
692
|
+
if detector_name in ("HFDetrDerivedDetector",)
|
|
693
|
+
else config.SEGMENTATION.SUB_ITEM_NAMES,
|
|
694
|
+
"item_header_cell_names": config.SEGMENTATION.PUBTABLES_ITEM_HEADER_CELL_NAMES,
|
|
695
|
+
"item_header_thresholds": config.SEGMENTATION.PUBTABLES_ITEM_HEADER_THRESHOLDS,
|
|
696
|
+
"stretch_rule": config.SEGMENTATION.STRETCH_RULE,
|
|
697
|
+
}
|
|
478
698
|
|
|
479
699
|
@staticmethod
|
|
480
700
|
def _build_table_segmentation_service(
|
|
481
|
-
config: AttrDict,
|
|
482
701
|
detector: ObjectDetector,
|
|
702
|
+
segment_rule: Literal["iou", "ioa"],
|
|
703
|
+
threshold_rows: float,
|
|
704
|
+
threshold_cols: float,
|
|
705
|
+
tile_table_with_items: bool,
|
|
706
|
+
remove_iou_threshold_rows: float,
|
|
707
|
+
remove_iou_threshold_cols: float,
|
|
708
|
+
table_name: Union[ObjectTypes, str],
|
|
709
|
+
cell_names: Sequence[Union[ObjectTypes, str]],
|
|
710
|
+
spanning_cell_names: Sequence[Union[ObjectTypes, str]],
|
|
711
|
+
item_names: Sequence[Union[ObjectTypes, str]],
|
|
712
|
+
sub_item_names: Sequence[Union[ObjectTypes, str]],
|
|
713
|
+
item_header_cell_names: Sequence[Union[ObjectTypes, str]],
|
|
714
|
+
item_header_thresholds: Sequence[float],
|
|
715
|
+
stretch_rule: Literal["left", "equal"],
|
|
483
716
|
) -> Union[PubtablesSegmentationService, TableSegmentationService]:
|
|
484
717
|
"""
|
|
485
718
|
Build and return a table segmentation service based on the provided detector.
|
|
@@ -495,8 +728,21 @@ class ServiceFactory:
|
|
|
495
728
|
configuration parameters from the `cfg` object but is tailored for different segmentation needs.
|
|
496
729
|
|
|
497
730
|
Args:
|
|
498
|
-
config: Configuration object.
|
|
499
731
|
detector: An instance of `ObjectDetector` used to determine the type of table segmentation service to build.
|
|
732
|
+
segment_rule: Rule for segmenting tables.
|
|
733
|
+
threshold_rows: Threshold for row segmentation.
|
|
734
|
+
threshold_cols: Threshold for column segmentation.
|
|
735
|
+
tile_table_with_items: Whether to tile the table with items.
|
|
736
|
+
remove_iou_threshold_rows: IOU threshold for removing rows.
|
|
737
|
+
remove_iou_threshold_cols: IOU threshold for removing columns.
|
|
738
|
+
table_name: Name of the table object type.
|
|
739
|
+
cell_names: Names of the cell object types.
|
|
740
|
+
spanning_cell_names: Names of the spanning cell object types.
|
|
741
|
+
item_names: Names of the item object types.
|
|
742
|
+
sub_item_names: Names of the sub-item object types.
|
|
743
|
+
item_header_cell_names: Names of the item header cell object types.
|
|
744
|
+
item_header_thresholds: Thresholds for item header segmentation.
|
|
745
|
+
stretch_rule: Rule for stretching cells.
|
|
500
746
|
|
|
501
747
|
Returns:
|
|
502
748
|
Table segmentation service instance.
|
|
@@ -504,35 +750,35 @@ class ServiceFactory:
|
|
|
504
750
|
table_segmentation: Union[PubtablesSegmentationService, TableSegmentationService]
|
|
505
751
|
if detector.__class__.__name__ in ("HFDetrDerivedDetector",):
|
|
506
752
|
table_segmentation = PubtablesSegmentationService(
|
|
507
|
-
segment_rule=
|
|
508
|
-
threshold_rows=
|
|
509
|
-
threshold_cols=
|
|
510
|
-
tile_table_with_items=
|
|
511
|
-
remove_iou_threshold_rows=
|
|
512
|
-
remove_iou_threshold_cols=
|
|
513
|
-
table_name=
|
|
514
|
-
cell_names=
|
|
515
|
-
spanning_cell_names=
|
|
516
|
-
item_names=
|
|
517
|
-
sub_item_names=
|
|
518
|
-
item_header_cell_names=
|
|
519
|
-
item_header_thresholds=
|
|
520
|
-
stretch_rule=
|
|
753
|
+
segment_rule=segment_rule,
|
|
754
|
+
threshold_rows=threshold_rows,
|
|
755
|
+
threshold_cols=threshold_cols,
|
|
756
|
+
tile_table_with_items=tile_table_with_items,
|
|
757
|
+
remove_iou_threshold_rows=remove_iou_threshold_rows,
|
|
758
|
+
remove_iou_threshold_cols=remove_iou_threshold_cols,
|
|
759
|
+
table_name=table_name,
|
|
760
|
+
cell_names=cell_names,
|
|
761
|
+
spanning_cell_names=spanning_cell_names,
|
|
762
|
+
item_names=item_names,
|
|
763
|
+
sub_item_names=sub_item_names,
|
|
764
|
+
item_header_cell_names=item_header_cell_names,
|
|
765
|
+
item_header_thresholds=item_header_thresholds,
|
|
766
|
+
stretch_rule=stretch_rule,
|
|
521
767
|
)
|
|
522
768
|
|
|
523
769
|
else:
|
|
524
770
|
table_segmentation = TableSegmentationService(
|
|
525
|
-
segment_rule=
|
|
526
|
-
threshold_rows=
|
|
527
|
-
threshold_cols=
|
|
528
|
-
tile_table_with_items=
|
|
529
|
-
remove_iou_threshold_rows=
|
|
530
|
-
remove_iou_threshold_cols=
|
|
531
|
-
table_name=
|
|
532
|
-
cell_names=
|
|
533
|
-
item_names=
|
|
534
|
-
sub_item_names=
|
|
535
|
-
stretch_rule=
|
|
771
|
+
segment_rule=segment_rule,
|
|
772
|
+
threshold_rows=threshold_rows,
|
|
773
|
+
threshold_cols=threshold_cols,
|
|
774
|
+
tile_table_with_items=tile_table_with_items,
|
|
775
|
+
remove_iou_threshold_rows=remove_iou_threshold_rows,
|
|
776
|
+
remove_iou_threshold_cols=remove_iou_threshold_cols,
|
|
777
|
+
table_name=table_name,
|
|
778
|
+
cell_names=cell_names,
|
|
779
|
+
item_names=item_names,
|
|
780
|
+
sub_item_names=sub_item_names,
|
|
781
|
+
stretch_rule=stretch_rule,
|
|
536
782
|
)
|
|
537
783
|
return table_segmentation
|
|
538
784
|
|
|
@@ -561,10 +807,29 @@ class ServiceFactory:
|
|
|
561
807
|
Returns:
|
|
562
808
|
Table segmentation service instance.
|
|
563
809
|
"""
|
|
564
|
-
|
|
810
|
+
table_segmentation_service_kwargs = ServiceFactory._get_table_segmentation_service_kwargs_from_config(
|
|
811
|
+
config, detector.__class__.__name__
|
|
812
|
+
)
|
|
813
|
+
return ServiceFactory._build_table_segmentation_service(detector, **table_segmentation_service_kwargs)
|
|
565
814
|
|
|
566
815
|
@staticmethod
|
|
567
|
-
def
|
|
816
|
+
def _get_table_refinement_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
|
|
817
|
+
"""
|
|
818
|
+
Extracting table segmentation refinement service kwargs from config.
|
|
819
|
+
|
|
820
|
+
Args:
|
|
821
|
+
config: Configuration object.
|
|
822
|
+
"""
|
|
823
|
+
|
|
824
|
+
return {
|
|
825
|
+
"table_names": [config.SEGMENTATION.TABLE_NAME],
|
|
826
|
+
"cell_names": config.SEGMENTATION.PUBTABLES_CELL_NAMES,
|
|
827
|
+
}
|
|
828
|
+
|
|
829
|
+
@staticmethod
|
|
830
|
+
def _build_table_refinement_service(
|
|
831
|
+
table_names: Sequence[ObjectTypes], cell_names: Sequence[ObjectTypes]
|
|
832
|
+
) -> TableSegmentationRefinementService:
|
|
568
833
|
"""
|
|
569
834
|
Building a table segmentation refinement service.
|
|
570
835
|
|
|
@@ -574,10 +839,7 @@ class ServiceFactory:
|
|
|
574
839
|
Returns:
|
|
575
840
|
TableSegmentationRefinementService: Refinement service instance.
|
|
576
841
|
"""
|
|
577
|
-
return TableSegmentationRefinementService(
|
|
578
|
-
[config.SEGMENTATION.TABLE_NAME],
|
|
579
|
-
config.SEGMENTATION.PUBTABLES_CELL_NAMES,
|
|
580
|
-
)
|
|
842
|
+
return TableSegmentationRefinementService(table_names=table_names, cell_names=cell_names)
|
|
581
843
|
|
|
582
844
|
@staticmethod
|
|
583
845
|
def build_table_refinement_service(config: AttrDict) -> TableSegmentationRefinementService:
|
|
@@ -590,22 +852,35 @@ class ServiceFactory:
|
|
|
590
852
|
Returns:
|
|
591
853
|
TableSegmentationRefinementService: Refinement service instance.
|
|
592
854
|
"""
|
|
593
|
-
|
|
855
|
+
table_refinement_service_kwargs = ServiceFactory._get_table_refinement_service_kwargs_from_config(config)
|
|
856
|
+
return ServiceFactory._build_table_refinement_service(**table_refinement_service_kwargs)
|
|
594
857
|
|
|
595
858
|
@staticmethod
|
|
596
|
-
def
|
|
859
|
+
def _get_pdf_text_detector_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
|
|
597
860
|
"""
|
|
598
|
-
|
|
861
|
+
Extracting PDF text detector kwargs from config.
|
|
599
862
|
|
|
600
863
|
Args:
|
|
601
864
|
config: Configuration object.
|
|
865
|
+
"""
|
|
866
|
+
return {
|
|
867
|
+
"x_tolerance": config.PDF_MINER.X_TOLERANCE,
|
|
868
|
+
"y_tolerance": config.PDF_MINER.Y_TOLERANCE,
|
|
869
|
+
}
|
|
870
|
+
|
|
871
|
+
@staticmethod
|
|
872
|
+
def _build_pdf_text_detector(x_tolerance: int, y_tolerance: int) -> PdfPlumberTextDetector:
|
|
873
|
+
"""
|
|
874
|
+
Building a PDF text detector.
|
|
875
|
+
|
|
876
|
+
Args:
|
|
877
|
+
x_tolerance: X tolerance for text extraction.
|
|
878
|
+
y_tolerance: Y tolerance for text extraction.
|
|
602
879
|
|
|
603
880
|
Returns:
|
|
604
881
|
PdfPlumberTextDetector: PDF text detector instance.
|
|
605
882
|
"""
|
|
606
|
-
return PdfPlumberTextDetector(
|
|
607
|
-
x_tolerance=config.PDF_MINER.X_TOLERANCE, y_tolerance=config.PDF_MINER.Y_TOLERANCE
|
|
608
|
-
)
|
|
883
|
+
return PdfPlumberTextDetector(x_tolerance=x_tolerance, y_tolerance=y_tolerance)
|
|
609
884
|
|
|
610
885
|
@staticmethod
|
|
611
886
|
def build_pdf_text_detector(config: AttrDict) -> PdfPlumberTextDetector:
|
|
@@ -618,7 +893,8 @@ class ServiceFactory:
|
|
|
618
893
|
Returns:
|
|
619
894
|
PdfPlumberTextDetector: PDF text detector instance.
|
|
620
895
|
"""
|
|
621
|
-
|
|
896
|
+
pdf_text_detector_kwargs = ServiceFactory._get_pdf_text_detector_kwargs_from_config(config)
|
|
897
|
+
return ServiceFactory._build_pdf_text_detector(**pdf_text_detector_kwargs)
|
|
622
898
|
|
|
623
899
|
@staticmethod
|
|
624
900
|
def _build_pdf_miner_text_service(detector: PdfMiner) -> TextExtractionService:
|
|
@@ -672,24 +948,34 @@ class ServiceFactory:
|
|
|
672
948
|
"""
|
|
673
949
|
return ServiceFactory._build_doctr_word_detector_service(detector)
|
|
674
950
|
|
|
951
|
+
@staticmethod
|
|
952
|
+
def _get_text_extraction_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
|
|
953
|
+
"""
|
|
954
|
+
Extracting text extraction service kwargs from config.
|
|
955
|
+
|
|
956
|
+
Args:
|
|
957
|
+
config: Configuration object.
|
|
958
|
+
"""
|
|
959
|
+
return {
|
|
960
|
+
"extract_from_roi": config.TEXT_CONTAINER if config.OCR.USE_DOCTR else None,
|
|
961
|
+
}
|
|
962
|
+
|
|
675
963
|
@staticmethod
|
|
676
964
|
def _build_text_extraction_service(
|
|
677
|
-
|
|
965
|
+
detector: Union[TesseractOcrDetector, DoctrTextRecognizer, TextractOcrDetector],
|
|
966
|
+
extract_from_roi: Union[Sequence[ObjectTypes], ObjectTypes, None] = None,
|
|
678
967
|
) -> TextExtractionService:
|
|
679
968
|
"""
|
|
680
969
|
Building a text extraction service.
|
|
681
970
|
|
|
682
971
|
Args:
|
|
683
|
-
config: Configuration object.
|
|
684
972
|
detector: OCR detector instance.
|
|
973
|
+
extract_from_roi: ROI categories to extract text from.
|
|
685
974
|
|
|
686
975
|
Returns:
|
|
687
976
|
TextExtractionService: Text extraction service instance.
|
|
688
977
|
"""
|
|
689
|
-
return TextExtractionService(
|
|
690
|
-
detector,
|
|
691
|
-
extract_from_roi=config.TEXT_CONTAINER if config.OCR.USE_DOCTR else None,
|
|
692
|
-
)
|
|
978
|
+
return TextExtractionService(detector, extract_from_roi=extract_from_roi)
|
|
693
979
|
|
|
694
980
|
@staticmethod
|
|
695
981
|
def build_text_extraction_service(
|
|
@@ -705,28 +991,55 @@ class ServiceFactory:
|
|
|
705
991
|
Returns:
|
|
706
992
|
TextExtractionService: Text extraction service instance.
|
|
707
993
|
"""
|
|
708
|
-
|
|
994
|
+
text_extraction_service_kwargs = ServiceFactory._get_text_extraction_service_kwargs_from_config(config)
|
|
995
|
+
return ServiceFactory._build_text_extraction_service(detector, **text_extraction_service_kwargs)
|
|
709
996
|
|
|
710
997
|
@staticmethod
|
|
711
|
-
def
|
|
998
|
+
def _get_word_matching_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
|
|
712
999
|
"""
|
|
713
|
-
|
|
1000
|
+
Extracting word matching service kwargs from config.
|
|
714
1001
|
|
|
715
1002
|
Args:
|
|
716
1003
|
config: Configuration object.
|
|
1004
|
+
"""
|
|
1005
|
+
return {
|
|
1006
|
+
"matching_rule": config.WORD_MATCHING.RULE,
|
|
1007
|
+
"threshold": config.WORD_MATCHING.THRESHOLD,
|
|
1008
|
+
"max_parent_only": config.WORD_MATCHING.MAX_PARENT_ONLY,
|
|
1009
|
+
"parental_categories": config.WORD_MATCHING.PARENTAL_CATEGORIES,
|
|
1010
|
+
"text_container": config.TEXT_CONTAINER,
|
|
1011
|
+
}
|
|
1012
|
+
|
|
1013
|
+
@staticmethod
|
|
1014
|
+
def _build_word_matching_service(
|
|
1015
|
+
matching_rule: Literal["iou", "ioa"],
|
|
1016
|
+
threshold: float,
|
|
1017
|
+
max_parent_only: bool,
|
|
1018
|
+
parental_categories: Union[Sequence[ObjectTypes], ObjectTypes, None],
|
|
1019
|
+
text_container: Union[Sequence[ObjectTypes], ObjectTypes, None],
|
|
1020
|
+
) -> MatchingService:
|
|
1021
|
+
"""
|
|
1022
|
+
Building a word matching service.
|
|
1023
|
+
|
|
1024
|
+
Args:
|
|
1025
|
+
matching_rule: Matching rule for intersection matcher.
|
|
1026
|
+
threshold: Threshold for intersection matcher.
|
|
1027
|
+
max_parent_only: Whether to use max parent only.
|
|
1028
|
+
parental_categories: Parent categories for matching.
|
|
1029
|
+
text_container: Text container categories.
|
|
717
1030
|
|
|
718
1031
|
Returns:
|
|
719
1032
|
MatchingService: Word matching service instance.
|
|
720
1033
|
"""
|
|
721
1034
|
matcher = IntersectionMatcher(
|
|
722
|
-
matching_rule=
|
|
723
|
-
threshold=
|
|
724
|
-
max_parent_only=
|
|
1035
|
+
matching_rule=matching_rule,
|
|
1036
|
+
threshold=threshold,
|
|
1037
|
+
max_parent_only=max_parent_only,
|
|
725
1038
|
)
|
|
726
1039
|
family_compounds = [
|
|
727
1040
|
FamilyCompound(
|
|
728
|
-
parent_categories=
|
|
729
|
-
child_categories=
|
|
1041
|
+
parent_categories=parental_categories,
|
|
1042
|
+
child_categories=text_container,
|
|
730
1043
|
relationship_key=Relationships.CHILD,
|
|
731
1044
|
),
|
|
732
1045
|
FamilyCompound(
|
|
@@ -753,15 +1066,33 @@ class ServiceFactory:
|
|
|
753
1066
|
Returns:
|
|
754
1067
|
MatchingService: Word matching service instance.
|
|
755
1068
|
"""
|
|
756
|
-
|
|
1069
|
+
word_matching_service_kwargs = ServiceFactory._get_word_matching_service_kwargs_from_config(config)
|
|
1070
|
+
return ServiceFactory._build_word_matching_service(**word_matching_service_kwargs)
|
|
757
1071
|
|
|
758
1072
|
@staticmethod
|
|
759
|
-
def
|
|
1073
|
+
def _get_layout_link_matching_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
|
|
760
1074
|
"""
|
|
761
|
-
|
|
1075
|
+
Extracting layout link matching service kwargs from config.
|
|
762
1076
|
|
|
763
1077
|
Args:
|
|
764
1078
|
config: Configuration object.
|
|
1079
|
+
"""
|
|
1080
|
+
return {
|
|
1081
|
+
"parental_categories": config.LAYOUT_LINK.PARENTAL_CATEGORIES,
|
|
1082
|
+
"child_categories": config.LAYOUT_LINK.CHILD_CATEGORIES,
|
|
1083
|
+
}
|
|
1084
|
+
|
|
1085
|
+
@staticmethod
|
|
1086
|
+
def _build_layout_link_matching_service(
|
|
1087
|
+
parental_categories: Union[Sequence[ObjectTypes], ObjectTypes, None],
|
|
1088
|
+
child_categories: Union[Sequence[ObjectTypes], ObjectTypes, None],
|
|
1089
|
+
) -> MatchingService:
|
|
1090
|
+
"""
|
|
1091
|
+
Building a layout link matching service.
|
|
1092
|
+
|
|
1093
|
+
Args:
|
|
1094
|
+
parental_categories: Parent categories for layout linking.
|
|
1095
|
+
child_categories: Child categories for layout linking.
|
|
765
1096
|
|
|
766
1097
|
Returns:
|
|
767
1098
|
MatchingService: Layout link matching service instance.
|
|
@@ -769,8 +1100,8 @@ class ServiceFactory:
|
|
|
769
1100
|
neighbor_matcher = NeighbourMatcher()
|
|
770
1101
|
family_compounds = [
|
|
771
1102
|
FamilyCompound(
|
|
772
|
-
parent_categories=
|
|
773
|
-
child_categories=
|
|
1103
|
+
parent_categories=parental_categories,
|
|
1104
|
+
child_categories=child_categories,
|
|
774
1105
|
relationship_key=Relationships.LAYOUT_LINK,
|
|
775
1106
|
)
|
|
776
1107
|
]
|
|
@@ -790,23 +1121,44 @@ class ServiceFactory:
|
|
|
790
1121
|
Returns:
|
|
791
1122
|
MatchingService: Layout link matching service instance.
|
|
792
1123
|
"""
|
|
793
|
-
|
|
1124
|
+
layout_link_matching_service_kwargs = ServiceFactory._get_layout_link_matching_service_kwargs_from_config(
|
|
1125
|
+
config
|
|
1126
|
+
)
|
|
1127
|
+
return ServiceFactory._build_layout_link_matching_service(**layout_link_matching_service_kwargs)
|
|
794
1128
|
|
|
795
1129
|
@staticmethod
|
|
796
|
-
def
|
|
1130
|
+
def _get_line_matching_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
|
|
797
1131
|
"""
|
|
798
|
-
|
|
1132
|
+
Extracting line matching service kwargs from config.
|
|
799
1133
|
|
|
800
1134
|
Args:
|
|
801
1135
|
config: Configuration object.
|
|
1136
|
+
"""
|
|
1137
|
+
return {
|
|
1138
|
+
"matching_rule": config.WORD_MATCHING.RULE,
|
|
1139
|
+
"threshold": config.WORD_MATCHING.THRESHOLD,
|
|
1140
|
+
"max_parent_only": config.WORD_MATCHING.MAX_PARENT_ONLY,
|
|
1141
|
+
}
|
|
1142
|
+
|
|
1143
|
+
@staticmethod
|
|
1144
|
+
def _build_line_matching_service(
|
|
1145
|
+
matching_rule: Literal["iou", "ioa"], threshold: float, max_parent_only: bool
|
|
1146
|
+
) -> MatchingService:
|
|
1147
|
+
"""
|
|
1148
|
+
Building a line matching service.
|
|
1149
|
+
|
|
1150
|
+
Args:
|
|
1151
|
+
matching_rule: Matching rule for intersection matcher.
|
|
1152
|
+
threshold: Threshold for intersection matcher.
|
|
1153
|
+
max_parent_only: Whether to use max parent only.
|
|
802
1154
|
|
|
803
1155
|
Returns:
|
|
804
1156
|
MatchingService: Line matching service instance.
|
|
805
1157
|
"""
|
|
806
1158
|
matcher = IntersectionMatcher(
|
|
807
|
-
matching_rule=
|
|
808
|
-
threshold=
|
|
809
|
-
max_parent_only=
|
|
1159
|
+
matching_rule=matching_rule,
|
|
1160
|
+
threshold=threshold,
|
|
1161
|
+
max_parent_only=max_parent_only,
|
|
810
1162
|
)
|
|
811
1163
|
family_compounds = [
|
|
812
1164
|
FamilyCompound(
|
|
@@ -831,28 +1183,64 @@ class ServiceFactory:
|
|
|
831
1183
|
Returns:
|
|
832
1184
|
MatchingService: Line matching service instance.
|
|
833
1185
|
"""
|
|
834
|
-
|
|
1186
|
+
line_matching_service_kwargs = ServiceFactory._get_line_matching_service_kwargs_from_config(config)
|
|
1187
|
+
return ServiceFactory._build_line_matching_service(**line_matching_service_kwargs)
|
|
835
1188
|
|
|
836
1189
|
@staticmethod
|
|
837
|
-
def
|
|
1190
|
+
def _get_text_order_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
|
|
838
1191
|
"""
|
|
839
|
-
|
|
1192
|
+
Extracting text order service kwargs from config.
|
|
840
1193
|
|
|
841
1194
|
Args:
|
|
842
1195
|
config: Configuration object.
|
|
1196
|
+
"""
|
|
1197
|
+
return {
|
|
1198
|
+
"text_container": config.TEXT_CONTAINER,
|
|
1199
|
+
"text_block_categories": config.TEXT_ORDERING.TEXT_BLOCK_CATEGORIES,
|
|
1200
|
+
"floating_text_block_categories": config.TEXT_ORDERING.FLOATING_TEXT_BLOCK_CATEGORIES,
|
|
1201
|
+
"include_residual_text_container": config.TEXT_ORDERING.INCLUDE_RESIDUAL_TEXT_CONTAINER,
|
|
1202
|
+
"starting_point_tolerance": config.TEXT_ORDERING.STARTING_POINT_TOLERANCE,
|
|
1203
|
+
"broken_line_tolerance": config.TEXT_ORDERING.BROKEN_LINE_TOLERANCE,
|
|
1204
|
+
"height_tolerance": config.TEXT_ORDERING.HEIGHT_TOLERANCE,
|
|
1205
|
+
"paragraph_break": config.TEXT_ORDERING.PARAGRAPH_BREAK,
|
|
1206
|
+
}
|
|
1207
|
+
|
|
1208
|
+
@staticmethod
|
|
1209
|
+
def _build_text_order_service(
|
|
1210
|
+
text_container: str,
|
|
1211
|
+
text_block_categories: Sequence[str],
|
|
1212
|
+
floating_text_block_categories: Sequence[str],
|
|
1213
|
+
include_residual_text_container: bool,
|
|
1214
|
+
starting_point_tolerance: float,
|
|
1215
|
+
broken_line_tolerance: float,
|
|
1216
|
+
height_tolerance: float,
|
|
1217
|
+
paragraph_break: float,
|
|
1218
|
+
) -> TextOrderService:
|
|
1219
|
+
"""
|
|
1220
|
+
Building a text order service.
|
|
1221
|
+
|
|
1222
|
+
Args:
|
|
1223
|
+
text_container: Text container categories.
|
|
1224
|
+
text_block_categories: Text block categories for ordering.
|
|
1225
|
+
floating_text_block_categories: Floating text block categories.
|
|
1226
|
+
include_residual_text_container: Whether to include residual text container.
|
|
1227
|
+
starting_point_tolerance: Starting point tolerance for text ordering.
|
|
1228
|
+
broken_line_tolerance: Broken line tolerance for text ordering.
|
|
1229
|
+
height_tolerance: Height tolerance for text ordering.
|
|
1230
|
+
paragraph_break: Paragraph break threshold.
|
|
843
1231
|
|
|
844
1232
|
Returns:
|
|
845
1233
|
TextOrderService: Text order service instance.
|
|
846
1234
|
"""
|
|
847
1235
|
return TextOrderService(
|
|
848
|
-
text_container=
|
|
849
|
-
text_block_categories=
|
|
850
|
-
floating_text_block_categories=
|
|
851
|
-
include_residual_text_container=
|
|
852
|
-
starting_point_tolerance=
|
|
853
|
-
broken_line_tolerance=
|
|
854
|
-
height_tolerance=
|
|
855
|
-
paragraph_break=
|
|
1236
|
+
text_container=text_container,
|
|
1237
|
+
text_block_categories=text_block_categories,
|
|
1238
|
+
floating_text_block_categories=floating_text_block_categories,
|
|
1239
|
+
include_residual_text_container=include_residual_text_container,
|
|
1240
|
+
starting_point_tolerance=starting_point_tolerance,
|
|
1241
|
+
broken_line_tolerance=broken_line_tolerance,
|
|
1242
|
+
height_tolerance=height_tolerance,
|
|
1243
|
+
paragraph_break=paragraph_break,
|
|
856
1244
|
)
|
|
857
1245
|
|
|
858
1246
|
@staticmethod
|
|
@@ -866,18 +1254,16 @@ class ServiceFactory:
|
|
|
866
1254
|
Returns:
|
|
867
1255
|
TextOrderService: Text order service instance.
|
|
868
1256
|
"""
|
|
869
|
-
|
|
1257
|
+
text_order_service_kwargs = ServiceFactory._get_text_order_service_kwargs_from_config(config)
|
|
1258
|
+
return ServiceFactory._build_text_order_service(**text_order_service_kwargs)
|
|
870
1259
|
|
|
871
1260
|
@staticmethod
|
|
872
|
-
def
|
|
1261
|
+
def _get_sequence_classifier_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
|
|
873
1262
|
"""
|
|
874
|
-
|
|
1263
|
+
Extracting sequence classifier kwargs from config.
|
|
875
1264
|
|
|
876
1265
|
Args:
|
|
877
|
-
config: Configuration object
|
|
878
|
-
|
|
879
|
-
Returns:
|
|
880
|
-
A sequence classifier instance constructed according to the specified configuration.
|
|
1266
|
+
config: Configuration object.
|
|
881
1267
|
"""
|
|
882
1268
|
config_path = ModelCatalog.get_full_path_configs(config.LM_SEQUENCE_CLASS.WEIGHTS)
|
|
883
1269
|
weights_path = ModelDownloadManager.maybe_download_weights_and_configs(config.LM_SEQUENCE_CLASS.WEIGHTS)
|
|
@@ -885,47 +1271,79 @@ class ServiceFactory:
|
|
|
885
1271
|
categories = profile.categories if profile.categories is not None else {}
|
|
886
1272
|
use_xlm_tokenizer = "xlm_tokenizer" == profile.architecture
|
|
887
1273
|
|
|
888
|
-
|
|
1274
|
+
return {
|
|
1275
|
+
"config_path": config_path,
|
|
1276
|
+
"weights_path": weights_path,
|
|
1277
|
+
"categories": categories,
|
|
1278
|
+
"device": config.DEVICE,
|
|
1279
|
+
"use_xlm_tokenizer": use_xlm_tokenizer,
|
|
1280
|
+
"model_wrapper": profile.model_wrapper,
|
|
1281
|
+
}
|
|
1282
|
+
|
|
1283
|
+
@staticmethod
|
|
1284
|
+
def _build_sequence_classifier(
|
|
1285
|
+
config_path: str,
|
|
1286
|
+
weights_path: str,
|
|
1287
|
+
categories: Mapping[int, Union[ObjectTypes, str]],
|
|
1288
|
+
device: Literal["cuda", "cpu"],
|
|
1289
|
+
use_xlm_tokenizer: bool,
|
|
1290
|
+
model_wrapper: str,
|
|
1291
|
+
) -> Union[LayoutSequenceModels, LmSequenceModels]:
|
|
1292
|
+
"""
|
|
1293
|
+
Builds and returns a sequence classifier instance.
|
|
1294
|
+
|
|
1295
|
+
Args:
|
|
1296
|
+
config_path: Path to model configuration.
|
|
1297
|
+
weights_path: Path to model weights.
|
|
1298
|
+
categories: Model categories mapping.
|
|
1299
|
+
device: Device to run model on.
|
|
1300
|
+
use_xlm_tokenizer: Whether to use XLM tokenizer.
|
|
1301
|
+
model_wrapper: Model wrapper class name.
|
|
1302
|
+
|
|
1303
|
+
Returns:
|
|
1304
|
+
A sequence classifier instance constructed according to the specified configuration.
|
|
1305
|
+
"""
|
|
1306
|
+
if model_wrapper in ("HFLayoutLmSequenceClassifier",):
|
|
889
1307
|
return HFLayoutLmSequenceClassifier(
|
|
890
1308
|
path_config_json=config_path,
|
|
891
1309
|
path_weights=weights_path,
|
|
892
1310
|
categories=categories,
|
|
893
|
-
device=
|
|
1311
|
+
device=device,
|
|
894
1312
|
use_xlm_tokenizer=use_xlm_tokenizer,
|
|
895
1313
|
)
|
|
896
|
-
if
|
|
1314
|
+
if model_wrapper in ("HFLayoutLmv2SequenceClassifier",):
|
|
897
1315
|
return HFLayoutLmv2SequenceClassifier(
|
|
898
1316
|
path_config_json=config_path,
|
|
899
1317
|
path_weights=weights_path,
|
|
900
1318
|
categories=categories,
|
|
901
|
-
device=
|
|
1319
|
+
device=device,
|
|
902
1320
|
use_xlm_tokenizer=use_xlm_tokenizer,
|
|
903
1321
|
)
|
|
904
|
-
if
|
|
1322
|
+
if model_wrapper in ("HFLayoutLmv3SequenceClassifier",):
|
|
905
1323
|
return HFLayoutLmv3SequenceClassifier(
|
|
906
1324
|
path_config_json=config_path,
|
|
907
1325
|
path_weights=weights_path,
|
|
908
1326
|
categories=categories,
|
|
909
|
-
device=
|
|
1327
|
+
device=device,
|
|
910
1328
|
use_xlm_tokenizer=use_xlm_tokenizer,
|
|
911
1329
|
)
|
|
912
|
-
if
|
|
1330
|
+
if model_wrapper in ("HFLiltSequenceClassifier",):
|
|
913
1331
|
return HFLiltSequenceClassifier(
|
|
914
1332
|
path_config_json=config_path,
|
|
915
1333
|
path_weights=weights_path,
|
|
916
1334
|
categories=categories,
|
|
917
|
-
device=
|
|
1335
|
+
device=device,
|
|
918
1336
|
use_xlm_tokenizer=use_xlm_tokenizer,
|
|
919
1337
|
)
|
|
920
|
-
if
|
|
1338
|
+
if model_wrapper in ("HFLmSequenceClassifier",):
|
|
921
1339
|
return HFLmSequenceClassifier(
|
|
922
1340
|
path_config_json=config_path,
|
|
923
1341
|
path_weights=weights_path,
|
|
924
1342
|
categories=categories,
|
|
925
|
-
device=
|
|
1343
|
+
device=device,
|
|
926
1344
|
use_xlm_tokenizer=use_xlm_tokenizer,
|
|
927
1345
|
)
|
|
928
|
-
raise ValueError(f"Unsupported model wrapper: {
|
|
1346
|
+
raise ValueError(f"Unsupported model wrapper: {model_wrapper}")
|
|
929
1347
|
|
|
930
1348
|
@staticmethod
|
|
931
1349
|
def build_sequence_classifier(config: AttrDict) -> Union[LayoutSequenceModels, LmSequenceModels]:
|
|
@@ -938,18 +1356,31 @@ class ServiceFactory:
|
|
|
938
1356
|
Returns:
|
|
939
1357
|
A sequence classifier instance constructed according to the specified configuration.
|
|
940
1358
|
"""
|
|
941
|
-
|
|
1359
|
+
sequence_classifier_kwargs = ServiceFactory._get_sequence_classifier_kwargs_from_config(config)
|
|
1360
|
+
return ServiceFactory._build_sequence_classifier(**sequence_classifier_kwargs)
|
|
1361
|
+
|
|
1362
|
+
@staticmethod
|
|
1363
|
+
def _get_sequence_classifier_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
|
|
1364
|
+
"""
|
|
1365
|
+
Extracting sequence classifier service kwargs from config.
|
|
1366
|
+
|
|
1367
|
+
Args:
|
|
1368
|
+
config: Configuration object.
|
|
1369
|
+
"""
|
|
1370
|
+
return {
|
|
1371
|
+
"use_other_as_default_category": config.LM_SEQUENCE_CLASS.USE_OTHER_AS_DEFAULT_CATEGORY,
|
|
1372
|
+
}
|
|
942
1373
|
|
|
943
1374
|
@staticmethod
|
|
944
1375
|
def _build_sequence_classifier_service(
|
|
945
|
-
|
|
1376
|
+
sequence_classifier: Union[LayoutSequenceModels, LmSequenceModels], use_other_as_default_category: bool
|
|
946
1377
|
) -> LMSequenceClassifierService:
|
|
947
1378
|
"""
|
|
948
1379
|
Building a sequence classifier service.
|
|
949
1380
|
|
|
950
1381
|
Args:
|
|
951
|
-
config: Configuration object.
|
|
952
1382
|
sequence_classifier: Sequence classifier instance.
|
|
1383
|
+
use_other_as_default_category: Whether to use other as default category.
|
|
953
1384
|
|
|
954
1385
|
Returns:
|
|
955
1386
|
LMSequenceClassifierService: Text order service instance.
|
|
@@ -961,7 +1392,7 @@ class ServiceFactory:
|
|
|
961
1392
|
return LMSequenceClassifierService(
|
|
962
1393
|
tokenizer=tokenizer_fast,
|
|
963
1394
|
language_model=sequence_classifier,
|
|
964
|
-
use_other_as_default_category=
|
|
1395
|
+
use_other_as_default_category=use_other_as_default_category,
|
|
965
1396
|
)
|
|
966
1397
|
|
|
967
1398
|
@staticmethod
|
|
@@ -978,60 +1409,93 @@ class ServiceFactory:
|
|
|
978
1409
|
Returns:
|
|
979
1410
|
LMSequenceClassifierService: Text order service instance.
|
|
980
1411
|
"""
|
|
981
|
-
|
|
1412
|
+
sequence_classifier_service_kwargs = ServiceFactory._get_sequence_classifier_service_kwargs_from_config(config)
|
|
1413
|
+
return ServiceFactory._build_sequence_classifier_service(
|
|
1414
|
+
sequence_classifier, **sequence_classifier_service_kwargs
|
|
1415
|
+
)
|
|
982
1416
|
|
|
983
1417
|
@staticmethod
|
|
984
|
-
def
|
|
1418
|
+
def _get_token_classifier_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
|
|
985
1419
|
"""
|
|
986
|
-
|
|
1420
|
+
Extracting token classifier kwargs from config.
|
|
987
1421
|
|
|
988
1422
|
Args:
|
|
989
1423
|
config: Configuration object.
|
|
990
|
-
|
|
991
|
-
Returns:
|
|
992
|
-
The instantiated token classifier model.
|
|
993
1424
|
"""
|
|
994
1425
|
config_path = ModelCatalog.get_full_path_configs(config.LM_TOKEN_CLASS.WEIGHTS)
|
|
995
1426
|
weights_path = ModelDownloadManager.maybe_download_weights_and_configs(config.LM_TOKEN_CLASS.WEIGHTS)
|
|
996
1427
|
profile = ModelCatalog.get_profile(config.LM_TOKEN_CLASS.WEIGHTS)
|
|
997
1428
|
categories = profile.categories if profile.categories is not None else {}
|
|
998
1429
|
use_xlm_tokenizer = "xlm_tokenizer" == profile.architecture
|
|
999
|
-
|
|
1430
|
+
|
|
1431
|
+
return {
|
|
1432
|
+
"config_path": config_path,
|
|
1433
|
+
"weights_path": weights_path,
|
|
1434
|
+
"categories": categories,
|
|
1435
|
+
"device": config.DEVICE,
|
|
1436
|
+
"use_xlm_tokenizer": use_xlm_tokenizer,
|
|
1437
|
+
"model_wrapper": profile.model_wrapper,
|
|
1438
|
+
}
|
|
1439
|
+
|
|
1440
|
+
@staticmethod
|
|
1441
|
+
def _build_token_classifier(
|
|
1442
|
+
config_path: str,
|
|
1443
|
+
weights_path: str,
|
|
1444
|
+
categories: Mapping[int, Union[ObjectTypes, str]],
|
|
1445
|
+
device: Literal["cpu", "cuda"],
|
|
1446
|
+
use_xlm_tokenizer: bool,
|
|
1447
|
+
model_wrapper: str,
|
|
1448
|
+
) -> Union[LayoutTokenModels, LmTokenModels]:
|
|
1449
|
+
"""
|
|
1450
|
+
Builds and returns a token classifier model.
|
|
1451
|
+
|
|
1452
|
+
Args:
|
|
1453
|
+
config_path: Path to model configuration.
|
|
1454
|
+
weights_path: Path to model weights.
|
|
1455
|
+
categories: Model categories mapping.
|
|
1456
|
+
device: Device to run model on.
|
|
1457
|
+
use_xlm_tokenizer: Whether to use XLM tokenizer.
|
|
1458
|
+
model_wrapper: Model wrapper class name.
|
|
1459
|
+
|
|
1460
|
+
Returns:
|
|
1461
|
+
The instantiated token classifier model.
|
|
1462
|
+
"""
|
|
1463
|
+
if model_wrapper in ("HFLayoutLmTokenClassifier",):
|
|
1000
1464
|
return HFLayoutLmTokenClassifier(
|
|
1001
1465
|
path_config_json=config_path,
|
|
1002
1466
|
path_weights=weights_path,
|
|
1003
1467
|
categories=categories,
|
|
1004
|
-
device=
|
|
1468
|
+
device=device,
|
|
1005
1469
|
use_xlm_tokenizer=use_xlm_tokenizer,
|
|
1006
1470
|
)
|
|
1007
|
-
if
|
|
1471
|
+
if model_wrapper in ("HFLayoutLmv2TokenClassifier",):
|
|
1008
1472
|
return HFLayoutLmv2TokenClassifier(
|
|
1009
1473
|
path_config_json=config_path,
|
|
1010
1474
|
path_weights=weights_path,
|
|
1011
1475
|
categories=categories,
|
|
1012
|
-
device=
|
|
1476
|
+
device=device,
|
|
1013
1477
|
)
|
|
1014
|
-
if
|
|
1478
|
+
if model_wrapper in ("HFLayoutLmv3TokenClassifier",):
|
|
1015
1479
|
return HFLayoutLmv3TokenClassifier(
|
|
1016
1480
|
path_config_json=config_path,
|
|
1017
1481
|
path_weights=weights_path,
|
|
1018
1482
|
categories=categories,
|
|
1019
|
-
device=
|
|
1483
|
+
device=device,
|
|
1020
1484
|
)
|
|
1021
|
-
if
|
|
1485
|
+
if model_wrapper in ("HFLiltTokenClassifier",):
|
|
1022
1486
|
return HFLiltTokenClassifier(
|
|
1023
1487
|
path_config_json=config_path,
|
|
1024
1488
|
path_weights=weights_path,
|
|
1025
1489
|
categories=categories,
|
|
1026
|
-
device=
|
|
1490
|
+
device=device,
|
|
1027
1491
|
)
|
|
1028
|
-
if
|
|
1492
|
+
if model_wrapper in ("HFLmTokenClassifier",):
|
|
1029
1493
|
return HFLmTokenClassifier(
|
|
1030
1494
|
path_config_json=config_path,
|
|
1031
1495
|
path_weights=weights_path,
|
|
1032
1496
|
categories=categories,
|
|
1033
1497
|
)
|
|
1034
|
-
raise ValueError(f"Unsupported model wrapper: {
|
|
1498
|
+
raise ValueError(f"Unsupported model wrapper: {model_wrapper}")
|
|
1035
1499
|
|
|
1036
1500
|
@staticmethod
|
|
1037
1501
|
def build_token_classifier(config: AttrDict) -> Union[LayoutTokenModels, LmTokenModels]:
|
|
@@ -1044,18 +1508,38 @@ class ServiceFactory:
|
|
|
1044
1508
|
Returns:
|
|
1045
1509
|
The instantiated token classifier model.
|
|
1046
1510
|
"""
|
|
1047
|
-
|
|
1511
|
+
token_classifier_kwargs = ServiceFactory._get_token_classifier_kwargs_from_config(config)
|
|
1512
|
+
return ServiceFactory._build_token_classifier(**token_classifier_kwargs)
|
|
1513
|
+
|
|
1514
|
+
@staticmethod
|
|
1515
|
+
def _get_token_classifier_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
|
|
1516
|
+
"""
|
|
1517
|
+
Extracting token classifier service kwargs from config.
|
|
1518
|
+
|
|
1519
|
+
Args:
|
|
1520
|
+
config: Configuration object.
|
|
1521
|
+
"""
|
|
1522
|
+
return {
|
|
1523
|
+
"use_other_as_default_category": config.LM_TOKEN_CLASS.USE_OTHER_AS_DEFAULT_CATEGORY,
|
|
1524
|
+
"segment_positions": config.LM_TOKEN_CLASS.SEGMENT_POSITIONS,
|
|
1525
|
+
"sliding_window_stride": config.LM_TOKEN_CLASS.SLIDING_WINDOW_STRIDE,
|
|
1526
|
+
}
|
|
1048
1527
|
|
|
1049
1528
|
@staticmethod
|
|
1050
1529
|
def _build_token_classifier_service(
|
|
1051
|
-
|
|
1530
|
+
token_classifier: Union[LayoutTokenModels, LmTokenModels],
|
|
1531
|
+
use_other_as_default_category: bool,
|
|
1532
|
+
segment_positions: Union[LayoutType, Sequence[LayoutType], None],
|
|
1533
|
+
sliding_window_stride: int,
|
|
1052
1534
|
) -> LMTokenClassifierService:
|
|
1053
1535
|
"""
|
|
1054
1536
|
Building a token classifier service.
|
|
1055
1537
|
|
|
1056
1538
|
Args:
|
|
1057
|
-
config: Configuration object.
|
|
1058
1539
|
token_classifier: Token classifier instance.
|
|
1540
|
+
use_other_as_default_category: Whether to use other as default category.
|
|
1541
|
+
segment_positions: Segment positions configuration.
|
|
1542
|
+
sliding_window_stride: Sliding window stride.
|
|
1059
1543
|
|
|
1060
1544
|
Returns:
|
|
1061
1545
|
A LMTokenClassifierService instance.
|
|
@@ -1067,9 +1551,9 @@ class ServiceFactory:
|
|
|
1067
1551
|
return LMTokenClassifierService(
|
|
1068
1552
|
tokenizer=tokenizer_fast,
|
|
1069
1553
|
language_model=token_classifier,
|
|
1070
|
-
use_other_as_default_category=
|
|
1071
|
-
segment_positions=
|
|
1072
|
-
sliding_window_stride=
|
|
1554
|
+
use_other_as_default_category=use_other_as_default_category,
|
|
1555
|
+
segment_positions=segment_positions,
|
|
1556
|
+
sliding_window_stride=sliding_window_stride,
|
|
1073
1557
|
)
|
|
1074
1558
|
|
|
1075
1559
|
@staticmethod
|
|
@@ -1086,23 +1570,44 @@ class ServiceFactory:
|
|
|
1086
1570
|
Returns:
|
|
1087
1571
|
A LMTokenClassifierService instance.
|
|
1088
1572
|
"""
|
|
1089
|
-
|
|
1573
|
+
token_classifier_service_kwargs = ServiceFactory._get_token_classifier_service_kwargs_from_config(config)
|
|
1574
|
+
return ServiceFactory._build_token_classifier_service(token_classifier, **token_classifier_service_kwargs)
|
|
1090
1575
|
|
|
1091
1576
|
@staticmethod
|
|
1092
|
-
def
|
|
1577
|
+
def _get_page_parsing_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
|
|
1093
1578
|
"""
|
|
1094
|
-
|
|
1579
|
+
Extracting page parsing service kwargs from config.
|
|
1095
1580
|
|
|
1096
1581
|
Args:
|
|
1097
1582
|
config: Configuration object.
|
|
1583
|
+
"""
|
|
1584
|
+
return {
|
|
1585
|
+
"text_container": config.TEXT_CONTAINER,
|
|
1586
|
+
"floating_text_block_categories": config.TEXT_ORDERING.FLOATING_TEXT_BLOCK_CATEGORIES,
|
|
1587
|
+
"include_residual_text_container": config.TEXT_ORDERING.INCLUDE_RESIDUAL_TEXT_CONTAINER,
|
|
1588
|
+
}
|
|
1589
|
+
|
|
1590
|
+
@staticmethod
|
|
1591
|
+
def _build_page_parsing_service(
|
|
1592
|
+
text_container: Union[ObjectTypes, str],
|
|
1593
|
+
floating_text_block_categories: Sequence[str],
|
|
1594
|
+
include_residual_text_container: bool,
|
|
1595
|
+
) -> PageParsingService:
|
|
1596
|
+
"""
|
|
1597
|
+
Building a page parsing service.
|
|
1598
|
+
|
|
1599
|
+
Args:
|
|
1600
|
+
text_container: Text container categories.
|
|
1601
|
+
floating_text_block_categories: Floating text block categories.
|
|
1602
|
+
include_residual_text_container: Whether to include residual text container.
|
|
1098
1603
|
|
|
1099
1604
|
Returns:
|
|
1100
1605
|
PageParsingService: Page parsing service instance.
|
|
1101
1606
|
"""
|
|
1102
1607
|
return PageParsingService(
|
|
1103
|
-
text_container=
|
|
1104
|
-
floating_text_block_categories=
|
|
1105
|
-
include_residual_text_container=
|
|
1608
|
+
text_container=text_container,
|
|
1609
|
+
floating_text_block_categories=floating_text_block_categories,
|
|
1610
|
+
include_residual_text_container=include_residual_text_container,
|
|
1106
1611
|
)
|
|
1107
1612
|
|
|
1108
1613
|
@staticmethod
|
|
@@ -1116,7 +1621,8 @@ class ServiceFactory:
|
|
|
1116
1621
|
Returns:
|
|
1117
1622
|
PageParsingService: Page parsing service instance.
|
|
1118
1623
|
"""
|
|
1119
|
-
|
|
1624
|
+
page_parsing_service_kwargs = ServiceFactory._get_page_parsing_service_kwargs_from_config(config)
|
|
1625
|
+
return ServiceFactory._build_page_parsing_service(**page_parsing_service_kwargs)
|
|
1120
1626
|
|
|
1121
1627
|
@staticmethod
|
|
1122
1628
|
def build_analyzer(config: AttrDict) -> DoctectionPipe:
|