deepdoctection 0.31__py3-none-any.whl → 0.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of deepdoctection might be problematic. See the registry's notice for this release for more details.
- deepdoctection/__init__.py +35 -28
- deepdoctection/analyzer/dd.py +30 -24
- deepdoctection/configs/conf_dd_one.yaml +34 -31
- deepdoctection/datapoint/annotation.py +2 -1
- deepdoctection/datapoint/box.py +2 -1
- deepdoctection/datapoint/image.py +13 -7
- deepdoctection/datapoint/view.py +95 -24
- deepdoctection/datasets/__init__.py +1 -4
- deepdoctection/datasets/adapter.py +5 -2
- deepdoctection/datasets/base.py +5 -3
- deepdoctection/datasets/info.py +2 -2
- deepdoctection/datasets/instances/doclaynet.py +3 -2
- deepdoctection/datasets/instances/fintabnet.py +2 -1
- deepdoctection/datasets/instances/funsd.py +2 -1
- deepdoctection/datasets/instances/iiitar13k.py +5 -2
- deepdoctection/datasets/instances/layouttest.py +2 -1
- deepdoctection/datasets/instances/publaynet.py +2 -2
- deepdoctection/datasets/instances/pubtables1m.py +6 -3
- deepdoctection/datasets/instances/pubtabnet.py +2 -1
- deepdoctection/datasets/instances/rvlcdip.py +2 -1
- deepdoctection/datasets/instances/xfund.py +2 -1
- deepdoctection/eval/__init__.py +1 -4
- deepdoctection/eval/cocometric.py +2 -1
- deepdoctection/eval/eval.py +17 -13
- deepdoctection/eval/tedsmetric.py +14 -11
- deepdoctection/eval/tp_eval_callback.py +9 -3
- deepdoctection/extern/__init__.py +2 -7
- deepdoctection/extern/d2detect.py +24 -32
- deepdoctection/extern/deskew.py +4 -2
- deepdoctection/extern/doctrocr.py +75 -81
- deepdoctection/extern/fastlang.py +4 -2
- deepdoctection/extern/hfdetr.py +22 -28
- deepdoctection/extern/hflayoutlm.py +335 -103
- deepdoctection/extern/hflm.py +225 -0
- deepdoctection/extern/model.py +56 -47
- deepdoctection/extern/pdftext.py +8 -4
- deepdoctection/extern/pt/__init__.py +1 -3
- deepdoctection/extern/pt/nms.py +6 -2
- deepdoctection/extern/pt/ptutils.py +27 -19
- deepdoctection/extern/texocr.py +4 -2
- deepdoctection/extern/tp/tfutils.py +43 -9
- deepdoctection/extern/tp/tpcompat.py +10 -7
- deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
- deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/config/config.py +9 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +17 -7
- deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +9 -4
- deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
- deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +16 -11
- deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +17 -10
- deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +14 -8
- deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
- deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
- deepdoctection/extern/tp/tpfrcnn/preproc.py +7 -3
- deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
- deepdoctection/extern/tpdetect.py +5 -8
- deepdoctection/mapper/__init__.py +3 -8
- deepdoctection/mapper/d2struct.py +8 -6
- deepdoctection/mapper/hfstruct.py +6 -1
- deepdoctection/mapper/laylmstruct.py +163 -20
- deepdoctection/mapper/maputils.py +3 -1
- deepdoctection/mapper/misc.py +6 -3
- deepdoctection/mapper/tpstruct.py +2 -2
- deepdoctection/pipe/__init__.py +1 -1
- deepdoctection/pipe/common.py +11 -9
- deepdoctection/pipe/concurrency.py +2 -1
- deepdoctection/pipe/layout.py +3 -1
- deepdoctection/pipe/lm.py +32 -64
- deepdoctection/pipe/order.py +142 -35
- deepdoctection/pipe/refine.py +8 -14
- deepdoctection/pipe/{cell.py → sub_layout.py} +1 -1
- deepdoctection/train/__init__.py +6 -12
- deepdoctection/train/d2_frcnn_train.py +21 -16
- deepdoctection/train/hf_detr_train.py +18 -11
- deepdoctection/train/hf_layoutlm_train.py +118 -101
- deepdoctection/train/tp_frcnn_train.py +21 -19
- deepdoctection/utils/env_info.py +41 -117
- deepdoctection/utils/logger.py +1 -0
- deepdoctection/utils/mocks.py +93 -0
- deepdoctection/utils/settings.py +1 -0
- deepdoctection/utils/viz.py +4 -3
- {deepdoctection-0.31.dist-info → deepdoctection-0.32.dist-info}/METADATA +27 -18
- deepdoctection-0.32.dist-info/RECORD +146 -0
- deepdoctection-0.31.dist-info/RECORD +0 -144
- {deepdoctection-0.31.dist-info → deepdoctection-0.32.dist-info}/LICENSE +0 -0
- {deepdoctection-0.31.dist-info → deepdoctection-0.32.dist-info}/WHEEL +0 -0
- {deepdoctection-0.31.dist-info → deepdoctection-0.32.dist-info}/top_level.txt +0 -0
deepdoctection/pipe/order.py
CHANGED
|
@@ -18,7 +18,10 @@
|
|
|
18
18
|
"""
|
|
19
19
|
Module for ordering text and layout segments pipeline components
|
|
20
20
|
"""
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
21
23
|
import os
|
|
24
|
+
from abc import ABC
|
|
22
25
|
from copy import copy
|
|
23
26
|
from itertools import chain
|
|
24
27
|
from logging import DEBUG
|
|
@@ -349,10 +352,11 @@ class TextLineGenerator:
|
|
|
349
352
|
self, make_sub_lines: bool, line_category_id: Union[int, str], paragraph_break: Optional[float] = None
|
|
350
353
|
):
|
|
351
354
|
"""
|
|
352
|
-
:param make_sub_lines: Whether to build sub lines from lines
|
|
355
|
+
:param make_sub_lines: Whether to build sub lines from lines.
|
|
353
356
|
:param line_category_id: category_id to give a text line
|
|
354
|
-
:param paragraph_break: threshold of two consecutive words. If distance is larger than threshold, two
|
|
355
|
-
will be built
|
|
357
|
+
:param paragraph_break: threshold of two consecutive words. If distance is larger than threshold, two sub-lines
|
|
358
|
+
will be built. We use relative coordinates to calculate the distance between two
|
|
359
|
+
consecutive words. A reasonable value is 0.035
|
|
356
360
|
"""
|
|
357
361
|
if make_sub_lines and paragraph_break is None:
|
|
358
362
|
raise ValueError("You must specify paragraph_break when setting make_sub_lines to True")
|
|
@@ -375,6 +379,7 @@ class TextLineGenerator:
|
|
|
375
379
|
image_width: float,
|
|
376
380
|
image_height: float,
|
|
377
381
|
image_id: Optional[str] = None,
|
|
382
|
+
highest_level: bool = True,
|
|
378
383
|
) -> Sequence[DetectionResult]:
|
|
379
384
|
"""
|
|
380
385
|
Creating detecting result of lines (or sub lines) from given word type `ImageAnnotation`.
|
|
@@ -392,6 +397,8 @@ class TextLineGenerator:
|
|
|
392
397
|
# list of (word index, text line, word annotation_id)
|
|
393
398
|
word_order_list = OrderGenerator.group_words_into_lines(word_anns, image_id)
|
|
394
399
|
number_rows = max(word[1] for word in word_order_list)
|
|
400
|
+
if number_rows == 1 and not highest_level:
|
|
401
|
+
return []
|
|
395
402
|
detection_result_list = []
|
|
396
403
|
for number_row in range(1, number_rows + 1):
|
|
397
404
|
# list of (word index, text line, word annotation_id) for text line equal to number_row
|
|
@@ -423,29 +430,141 @@ class TextLineGenerator:
|
|
|
423
430
|
if current_box.absolute_coords:
|
|
424
431
|
current_box = current_box.transform(image_width, image_height)
|
|
425
432
|
|
|
426
|
-
# If distance between boxes is lower than paragraph break, same sub
|
|
433
|
+
# If distance between boxes is lower than paragraph break, same sub-line
|
|
427
434
|
if current_box.ulx - prev_box.lrx < self.paragraph_break: # type: ignore
|
|
428
435
|
sub_line.append(ann)
|
|
429
436
|
sub_line_ann_ids.append(ann.annotation_id)
|
|
430
437
|
else:
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
438
|
+
# We need to iterate maybe more than one time, because sub-lines may have more than one line
|
|
439
|
+
# if having been split. Take fore example a multi-column layout where a sub-line has
|
|
440
|
+
# two lines because of a column break and fonts twice as large as the other column.
|
|
441
|
+
detection_results = self.create_detection_result(
|
|
442
|
+
sub_line, image_width, image_height, image_id, False
|
|
443
|
+
)
|
|
444
|
+
if detection_results:
|
|
445
|
+
detection_result_list.extend(detection_results)
|
|
446
|
+
else:
|
|
447
|
+
boxes = [ann.get_bounding_box(image_id) for ann in sub_line]
|
|
448
|
+
merge_box = merge_boxes(*boxes)
|
|
449
|
+
detection_result = self._make_detect_result(merge_box, {"child": sub_line_ann_ids})
|
|
450
|
+
detection_result_list.append(detection_result)
|
|
451
|
+
sub_line = [ann]
|
|
452
|
+
sub_line_ann_ids = [ann.annotation_id]
|
|
437
453
|
|
|
438
454
|
if idx == len(anns_per_row) - 1:
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
455
|
+
detection_results = self.create_detection_result(
|
|
456
|
+
sub_line, image_width, image_height, image_id, False
|
|
457
|
+
)
|
|
458
|
+
if detection_results:
|
|
459
|
+
detection_result_list.extend(detection_results)
|
|
460
|
+
else:
|
|
461
|
+
boxes = [ann.get_bounding_box(image_id) for ann in sub_line]
|
|
462
|
+
merge_box = merge_boxes(*boxes)
|
|
463
|
+
detection_result = self._make_detect_result(merge_box, {"child": sub_line_ann_ids})
|
|
464
|
+
detection_result_list.append(detection_result)
|
|
443
465
|
|
|
444
466
|
return detection_result_list
|
|
445
467
|
|
|
446
468
|
|
|
469
|
+
class TextLineServiceMixin(PipelineComponent, ABC):
|
|
470
|
+
"""
|
|
471
|
+
This class is used to create text lines similar to TextOrderService.
|
|
472
|
+
It uses the logic of the TextOrderService but modifies it to suit its needs.
|
|
473
|
+
It specifically uses the _create_lines_for_words method and modifies the serve method.
|
|
474
|
+
"""
|
|
475
|
+
|
|
476
|
+
def __init__(
|
|
477
|
+
self,
|
|
478
|
+
name: str,
|
|
479
|
+
line_category_id: int = 1,
|
|
480
|
+
include_residual_text_container: bool = True,
|
|
481
|
+
paragraph_break: Optional[float] = None,
|
|
482
|
+
):
|
|
483
|
+
"""
|
|
484
|
+
Initialize the TextLineService with a line_category_id and a TextLineGenerator instance.
|
|
485
|
+
"""
|
|
486
|
+
self.line_category_id = line_category_id
|
|
487
|
+
self.include_residual_text_container = include_residual_text_container
|
|
488
|
+
self.text_line_generator = TextLineGenerator(
|
|
489
|
+
self.include_residual_text_container, self.line_category_id, paragraph_break
|
|
490
|
+
)
|
|
491
|
+
super().__init__(name)
|
|
492
|
+
|
|
493
|
+
def _create_lines_for_words(self, word_anns: Sequence[ImageAnnotation]) -> Sequence[ImageAnnotation]:
|
|
494
|
+
"""
|
|
495
|
+
This method creates lines for words using the TextLineGenerator instance.
|
|
496
|
+
"""
|
|
497
|
+
detection_result_list = self.text_line_generator.create_detection_result(
|
|
498
|
+
word_anns,
|
|
499
|
+
self.dp_manager.datapoint.width,
|
|
500
|
+
self.dp_manager.datapoint.height,
|
|
501
|
+
self.dp_manager.datapoint.image_id,
|
|
502
|
+
)
|
|
503
|
+
line_anns = []
|
|
504
|
+
for detect_result in detection_result_list:
|
|
505
|
+
ann_id = self.dp_manager.set_image_annotation(detect_result)
|
|
506
|
+
if ann_id:
|
|
507
|
+
line_ann = self.dp_manager.get_annotation(ann_id)
|
|
508
|
+
child_ann_id_list = detect_result.relationships["child"] # type: ignore
|
|
509
|
+
for child_ann_id in child_ann_id_list:
|
|
510
|
+
line_ann.dump_relationship(Relationships.child, child_ann_id)
|
|
511
|
+
line_anns.append(line_ann)
|
|
512
|
+
return line_anns
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
class TextLineService(TextLineServiceMixin):
|
|
516
|
+
"""
|
|
517
|
+
Some OCR systems do not identify lines of text but only provide text boxes for words. This is not sufficient
|
|
518
|
+
for certain applications. This service determines rule-based text lines based on word boxes. One difficulty is
|
|
519
|
+
that text lines are not continuous but are interrupted, for example in multi-column layouts.
|
|
520
|
+
These interruptions are taken into account insofar as the gap between two words on almost the same page height
|
|
521
|
+
must not be too large.
|
|
522
|
+
|
|
523
|
+
The service constructs new ImageAnnotation of the category `LayoutType.line` and forms relations between the
|
|
524
|
+
text lines and the words contained in the text lines. The reading order is not arranged.
|
|
525
|
+
"""
|
|
526
|
+
|
|
527
|
+
def __init__(self, line_category_id: int = 1, paragraph_break: Optional[float] = None):
|
|
528
|
+
"""
|
|
529
|
+
Initialize `TextLineService`
|
|
530
|
+
|
|
531
|
+
:param line_category_id: category_id to give a text line
|
|
532
|
+
:param paragraph_break: threshold of two consecutive words. If distance is larger than threshold, two sublines
|
|
533
|
+
will be built
|
|
534
|
+
"""
|
|
535
|
+
super().__init__(
|
|
536
|
+
name="text_line",
|
|
537
|
+
line_category_id=line_category_id,
|
|
538
|
+
include_residual_text_container=True,
|
|
539
|
+
paragraph_break=paragraph_break,
|
|
540
|
+
)
|
|
541
|
+
|
|
542
|
+
def clone(self) -> PipelineComponent:
|
|
543
|
+
"""
|
|
544
|
+
This method returns a new instance of the class with the same configuration.
|
|
545
|
+
"""
|
|
546
|
+
return self.__class__(self.line_category_id, self.text_line_generator.paragraph_break)
|
|
547
|
+
|
|
548
|
+
def serve(self, dp: Image) -> None:
|
|
549
|
+
text_container_anns = dp.get_annotation(category_names=LayoutType.word)
|
|
550
|
+
self._create_lines_for_words(text_container_anns)
|
|
551
|
+
|
|
552
|
+
def get_meta_annotation(self) -> JsonDict:
|
|
553
|
+
"""
|
|
554
|
+
This method returns metadata about the annotations created by this pipeline component.
|
|
555
|
+
"""
|
|
556
|
+
return dict(
|
|
557
|
+
[
|
|
558
|
+
("image_annotations", [LayoutType.line]),
|
|
559
|
+
("sub_categories", {LayoutType.line: {Relationships.child}}),
|
|
560
|
+
("relationships", {}),
|
|
561
|
+
("summaries", []),
|
|
562
|
+
]
|
|
563
|
+
)
|
|
564
|
+
|
|
565
|
+
|
|
447
566
|
@pipeline_component_registry.register("TextOrderService")
|
|
448
|
-
class TextOrderService(
|
|
567
|
+
class TextOrderService(TextLineServiceMixin):
|
|
449
568
|
"""
|
|
450
569
|
Reading order of words within floating text blocks as well as reading order of blocks within simple text blocks.
|
|
451
570
|
To understand the difference between floating text blocks and simple text blocks consider a page containing an
|
|
@@ -470,7 +589,8 @@ class TextOrderService(PipelineComponent):
|
|
|
470
589
|
A category annotation per word is generated, which fixes the order per word in the block, as well as a category
|
|
471
590
|
annotation per block, which saves the reading order of the block per page.
|
|
472
591
|
|
|
473
|
-
The blocks are defined in `
|
|
592
|
+
The blocks are defined in `text_block_categories` and text blocks that should be considered when generating
|
|
593
|
+
narrative text must be added in `floating_text_block_categories`.
|
|
474
594
|
|
|
475
595
|
order = TextOrderService(text_container="word",
|
|
476
596
|
text_block_categories=["title", "text", "list", "cell",
|
|
@@ -533,7 +653,12 @@ class TextOrderService(PipelineComponent):
|
|
|
533
653
|
self.text_line_generator = TextLineGenerator(
|
|
534
654
|
self.include_residual_text_container, line_category_id, paragraph_break
|
|
535
655
|
)
|
|
536
|
-
super().__init__(
|
|
656
|
+
super().__init__(
|
|
657
|
+
name="text_order",
|
|
658
|
+
line_category_id=line_category_id,
|
|
659
|
+
include_residual_text_container=include_residual_text_container,
|
|
660
|
+
paragraph_break=paragraph_break,
|
|
661
|
+
)
|
|
537
662
|
self._init_sanity_checks()
|
|
538
663
|
|
|
539
664
|
def serve(self, dp: Image) -> None:
|
|
@@ -567,24 +692,6 @@ class TextOrderService(PipelineComponent):
|
|
|
567
692
|
Relationships.reading_order, idx, Relationships.reading_order, annotation_id
|
|
568
693
|
)
|
|
569
694
|
|
|
570
|
-
def _create_lines_for_words(self, word_anns: Sequence[ImageAnnotation]) -> Sequence[ImageAnnotation]:
|
|
571
|
-
detection_result_list = self.text_line_generator.create_detection_result(
|
|
572
|
-
word_anns,
|
|
573
|
-
self.dp_manager.datapoint.width,
|
|
574
|
-
self.dp_manager.datapoint.height,
|
|
575
|
-
self.dp_manager.datapoint.image_id,
|
|
576
|
-
)
|
|
577
|
-
line_anns = []
|
|
578
|
-
for detect_result in detection_result_list:
|
|
579
|
-
ann_id = self.dp_manager.set_image_annotation(detect_result)
|
|
580
|
-
if ann_id:
|
|
581
|
-
line_ann = self.dp_manager.get_annotation(ann_id)
|
|
582
|
-
child_ann_id_list = detect_result.relationships["child"] # type: ignore
|
|
583
|
-
for child_ann_id in child_ann_id_list:
|
|
584
|
-
line_ann.dump_relationship(Relationships.child, child_ann_id)
|
|
585
|
-
line_anns.append(line_ann)
|
|
586
|
-
return line_anns
|
|
587
|
-
|
|
588
695
|
def order_text_in_text_block(self, text_block_ann: ImageAnnotation) -> None:
|
|
589
696
|
"""
|
|
590
697
|
Order text within a text block. It will take all child-like text containers (determined by a
|
deepdoctection/pipe/refine.py
CHANGED
|
@@ -23,7 +23,7 @@ from collections import defaultdict
|
|
|
23
23
|
from copy import copy
|
|
24
24
|
from dataclasses import asdict
|
|
25
25
|
from itertools import chain, product
|
|
26
|
-
from typing import DefaultDict, List, Optional, Set, Tuple, Union
|
|
26
|
+
from typing import DefaultDict, List, Optional, Sequence, Set, Tuple, Union
|
|
27
27
|
|
|
28
28
|
import networkx as nx # type: ignore
|
|
29
29
|
|
|
@@ -34,7 +34,7 @@ from ..extern.base import DetectionResult
|
|
|
34
34
|
from ..mapper.maputils import MappingContextManager
|
|
35
35
|
from ..utils.detection_types import JsonDict
|
|
36
36
|
from ..utils.error import AnnotationError, ImageError
|
|
37
|
-
from ..utils.settings import CellType, LayoutType, Relationships, TableType, get_type
|
|
37
|
+
from ..utils.settings import CellType, LayoutType, ObjectTypes, Relationships, TableType, get_type
|
|
38
38
|
from .base import PipelineComponent
|
|
39
39
|
from .registry import pipeline_component_registry
|
|
40
40
|
|
|
@@ -398,19 +398,13 @@ class TableSegmentationRefinementService(PipelineComponent):
|
|
|
398
398
|
|
|
399
399
|
"""
|
|
400
400
|
|
|
401
|
-
def __init__(self) -> None:
|
|
402
|
-
self.
|
|
403
|
-
self.
|
|
404
|
-
LayoutType.cell,
|
|
405
|
-
CellType.column_header,
|
|
406
|
-
CellType.projected_row_header,
|
|
407
|
-
CellType.spanning,
|
|
408
|
-
CellType.row_header,
|
|
409
|
-
]
|
|
401
|
+
def __init__(self, table_name: Sequence[ObjectTypes], cell_names: Sequence[ObjectTypes]) -> None:
|
|
402
|
+
self.table_name = table_name
|
|
403
|
+
self.cell_names = cell_names
|
|
410
404
|
super().__init__("table_segment_refine")
|
|
411
405
|
|
|
412
406
|
def serve(self, dp: Image) -> None:
|
|
413
|
-
tables = dp.get_annotation(category_names=self.
|
|
407
|
+
tables = dp.get_annotation(category_names=self.table_name)
|
|
414
408
|
for table in tables:
|
|
415
409
|
if table.image is None:
|
|
416
410
|
raise ImageError("table.image cannot be None")
|
|
@@ -458,7 +452,7 @@ class TableSegmentationRefinementService(PipelineComponent):
|
|
|
458
452
|
for cell in cells:
|
|
459
453
|
cell.deactivate()
|
|
460
454
|
|
|
461
|
-
cells = table.image.get_annotation(category_names=self.
|
|
455
|
+
cells = table.image.get_annotation(category_names=self.cell_names)
|
|
462
456
|
number_of_rows = max(int(cell.get_sub_category(CellType.row_number).category_id) for cell in cells)
|
|
463
457
|
number_of_cols = max(int(cell.get_sub_category(CellType.column_number).category_id) for cell in cells)
|
|
464
458
|
max_row_span = max(int(cell.get_sub_category(CellType.row_span).category_id) for cell in cells)
|
|
@@ -500,7 +494,7 @@ class TableSegmentationRefinementService(PipelineComponent):
|
|
|
500
494
|
self.dp_manager.set_container_annotation(TableType.html, -1, TableType.html, table.annotation_id, html)
|
|
501
495
|
|
|
502
496
|
def clone(self) -> PipelineComponent:
|
|
503
|
-
return self.__class__()
|
|
497
|
+
return self.__class__(self.table_name, self.cell_names)
|
|
504
498
|
|
|
505
499
|
def get_meta_annotation(self) -> JsonDict:
|
|
506
500
|
return dict(
|
deepdoctection/train/__init__.py
CHANGED
|
@@ -19,20 +19,14 @@
|
|
|
19
19
|
Init module for train package
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
|
-
from ..utils.file_utils import
|
|
23
|
-
detectron2_available,
|
|
24
|
-
pytorch_available,
|
|
25
|
-
tensorpack_available,
|
|
26
|
-
tf_available,
|
|
27
|
-
transformers_available,
|
|
28
|
-
)
|
|
22
|
+
from ..utils.file_utils import detectron2_available, tensorpack_available, transformers_available
|
|
29
23
|
|
|
30
|
-
if
|
|
31
|
-
from .tp_frcnn_train import train_faster_rcnn
|
|
32
|
-
|
|
33
|
-
if pytorch_available() and detectron2_available():
|
|
24
|
+
if detectron2_available():
|
|
34
25
|
from .d2_frcnn_train import train_d2_faster_rcnn
|
|
35
26
|
|
|
36
|
-
if
|
|
27
|
+
if transformers_available():
|
|
37
28
|
from .hf_detr_train import train_hf_detr
|
|
38
29
|
from .hf_layoutlm_train import train_hf_layoutlm
|
|
30
|
+
|
|
31
|
+
if tensorpack_available():
|
|
32
|
+
from .tp_frcnn_train import train_faster_rcnn
|
|
@@ -18,19 +18,12 @@
|
|
|
18
18
|
"""
|
|
19
19
|
Module for training Detectron2 `GeneralizedRCNN`
|
|
20
20
|
"""
|
|
21
|
-
|
|
21
|
+
from __future__ import annotations
|
|
22
22
|
|
|
23
23
|
import copy
|
|
24
24
|
from typing import Any, Dict, List, Mapping, Optional, Sequence, Type, Union
|
|
25
25
|
|
|
26
|
-
from
|
|
27
|
-
from detectron2.data import DatasetMapper, build_detection_train_loader
|
|
28
|
-
from detectron2.data.transforms import RandomFlip, ResizeShortestEdge
|
|
29
|
-
from detectron2.engine import DefaultTrainer, HookBase, default_writers, hooks
|
|
30
|
-
from detectron2.utils import comm
|
|
31
|
-
from detectron2.utils.events import EventWriter, get_event_storage
|
|
32
|
-
from fvcore.nn.precise_bn import get_bn_modules # type: ignore
|
|
33
|
-
from torch.utils.data import DataLoader, IterableDataset
|
|
26
|
+
from lazy_imports import try_import
|
|
34
27
|
|
|
35
28
|
from ..datasets.adapter import DatasetAdapter
|
|
36
29
|
from ..datasets.base import DatasetBase
|
|
@@ -39,7 +32,6 @@ from ..eval.base import MetricBase
|
|
|
39
32
|
from ..eval.eval import Evaluator
|
|
40
33
|
from ..eval.registry import metric_registry
|
|
41
34
|
from ..extern.d2detect import D2FrcnnDetector
|
|
42
|
-
from ..extern.pt.ptutils import get_num_gpu
|
|
43
35
|
from ..mapper.d2struct import image_to_d2_frcnn_training
|
|
44
36
|
from ..pipe.base import PredictorPipelineComponent
|
|
45
37
|
from ..pipe.registry import pipeline_component_registry
|
|
@@ -48,7 +40,20 @@ from ..utils.file_utils import get_wandb_requirement, wandb_available
|
|
|
48
40
|
from ..utils.logger import LoggingRecord, logger
|
|
49
41
|
from ..utils.utils import string_to_dict
|
|
50
42
|
|
|
51
|
-
|
|
43
|
+
with try_import() as d2_import_guard:
|
|
44
|
+
from detectron2.config import CfgNode, get_cfg
|
|
45
|
+
from detectron2.data import DatasetMapper, build_detection_train_loader
|
|
46
|
+
from detectron2.data.transforms import RandomFlip, ResizeShortestEdge
|
|
47
|
+
from detectron2.engine import DefaultTrainer, HookBase, default_writers, hooks
|
|
48
|
+
from detectron2.utils import comm
|
|
49
|
+
from detectron2.utils.events import EventWriter, get_event_storage
|
|
50
|
+
from fvcore.nn.precise_bn import get_bn_modules # type: ignore
|
|
51
|
+
|
|
52
|
+
with try_import() as pt_import_guard:
|
|
53
|
+
from torch import cuda
|
|
54
|
+
from torch.utils.data import DataLoader, IterableDataset
|
|
55
|
+
|
|
56
|
+
with try_import() as wb_import_guard:
|
|
52
57
|
import wandb
|
|
53
58
|
|
|
54
59
|
|
|
@@ -112,7 +117,7 @@ class WandbWriter(EventWriter):
|
|
|
112
117
|
config = {}
|
|
113
118
|
self._window_size = window_size
|
|
114
119
|
self._run = wandb.init(project=project, config=config, **kwargs) if not wandb.run else wandb.run
|
|
115
|
-
self._run._label(repo=repo)
|
|
120
|
+
self._run._label(repo=repo)
|
|
116
121
|
|
|
117
122
|
def write(self) -> None:
|
|
118
123
|
storage = get_event_storage()
|
|
@@ -121,10 +126,10 @@ class WandbWriter(EventWriter):
|
|
|
121
126
|
for key, (val, _) in storage.latest_with_smoothing_hint(self._window_size).items():
|
|
122
127
|
log_dict[key] = val
|
|
123
128
|
|
|
124
|
-
self._run.log(log_dict)
|
|
129
|
+
self._run.log(log_dict)
|
|
125
130
|
|
|
126
131
|
def close(self) -> None:
|
|
127
|
-
self._run.finish()
|
|
132
|
+
self._run.finish()
|
|
128
133
|
|
|
129
134
|
|
|
130
135
|
class D2Trainer(DefaultTrainer):
|
|
@@ -259,7 +264,7 @@ class D2Trainer(DefaultTrainer):
|
|
|
259
264
|
dataset_val,
|
|
260
265
|
pipeline_component,
|
|
261
266
|
metric,
|
|
262
|
-
num_threads=
|
|
267
|
+
num_threads=cuda.device_count() * 2,
|
|
263
268
|
run=run,
|
|
264
269
|
)
|
|
265
270
|
if build_val_dict:
|
|
@@ -335,7 +340,7 @@ def train_d2_faster_rcnn(
|
|
|
335
340
|
:param pipeline_component_name: A pipeline component name to use for validation.
|
|
336
341
|
"""
|
|
337
342
|
|
|
338
|
-
assert
|
|
343
|
+
assert cuda.device_count() > 0, "Has to train with GPU!"
|
|
339
344
|
|
|
340
345
|
build_train_dict: Dict[str, str] = {}
|
|
341
346
|
if build_train_config is not None:
|
|
@@ -19,20 +19,12 @@
|
|
|
19
19
|
Module for training Hugging Face Detr implementation. Note, that this scripts only trans Tabletransformer like Detr
|
|
20
20
|
models that are a slightly different from the plain Detr model that are provided by the transformer library.
|
|
21
21
|
"""
|
|
22
|
+
from __future__ import annotations
|
|
22
23
|
|
|
23
24
|
import copy
|
|
24
25
|
from typing import Any, Dict, List, Optional, Sequence, Type, Union
|
|
25
26
|
|
|
26
|
-
from
|
|
27
|
-
from torch.utils.data import Dataset
|
|
28
|
-
from transformers import (
|
|
29
|
-
AutoFeatureExtractor,
|
|
30
|
-
IntervalStrategy,
|
|
31
|
-
PretrainedConfig,
|
|
32
|
-
PreTrainedModel,
|
|
33
|
-
TableTransformerForObjectDetection,
|
|
34
|
-
)
|
|
35
|
-
from transformers.trainer import Trainer, TrainingArguments
|
|
27
|
+
from lazy_imports import try_import
|
|
36
28
|
|
|
37
29
|
from ..datasets.adapter import DatasetAdapter
|
|
38
30
|
from ..datasets.base import DatasetBase
|
|
@@ -47,6 +39,21 @@ from ..pipe.registry import pipeline_component_registry
|
|
|
47
39
|
from ..utils.logger import LoggingRecord, logger
|
|
48
40
|
from ..utils.utils import string_to_dict
|
|
49
41
|
|
|
42
|
+
with try_import() as pt_import_guard:
|
|
43
|
+
from torch import nn
|
|
44
|
+
from torch.utils.data import Dataset
|
|
45
|
+
|
|
46
|
+
with try_import() as hf_import_guard:
|
|
47
|
+
from transformers import (
|
|
48
|
+
AutoFeatureExtractor,
|
|
49
|
+
IntervalStrategy,
|
|
50
|
+
PretrainedConfig,
|
|
51
|
+
PreTrainedModel,
|
|
52
|
+
TableTransformerForObjectDetection,
|
|
53
|
+
Trainer,
|
|
54
|
+
TrainingArguments,
|
|
55
|
+
)
|
|
56
|
+
|
|
50
57
|
|
|
51
58
|
class DetrDerivedTrainer(Trainer):
|
|
52
59
|
"""
|
|
@@ -61,7 +68,7 @@ class DetrDerivedTrainer(Trainer):
|
|
|
61
68
|
|
|
62
69
|
def __init__(
|
|
63
70
|
self,
|
|
64
|
-
model: Union[PreTrainedModel, Module],
|
|
71
|
+
model: Union[PreTrainedModel, nn.Module],
|
|
65
72
|
args: TrainingArguments,
|
|
66
73
|
data_collator: DetrDataCollator,
|
|
67
74
|
train_dataset: Dataset[Any],
|