deepdoctection 0.31__py3-none-any.whl → 0.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. Click here for more details.

Files changed (91)
  1. deepdoctection/__init__.py +35 -28
  2. deepdoctection/analyzer/dd.py +30 -24
  3. deepdoctection/configs/conf_dd_one.yaml +34 -31
  4. deepdoctection/datapoint/annotation.py +2 -1
  5. deepdoctection/datapoint/box.py +2 -1
  6. deepdoctection/datapoint/image.py +13 -7
  7. deepdoctection/datapoint/view.py +95 -24
  8. deepdoctection/datasets/__init__.py +1 -4
  9. deepdoctection/datasets/adapter.py +5 -2
  10. deepdoctection/datasets/base.py +5 -3
  11. deepdoctection/datasets/info.py +2 -2
  12. deepdoctection/datasets/instances/doclaynet.py +3 -2
  13. deepdoctection/datasets/instances/fintabnet.py +2 -1
  14. deepdoctection/datasets/instances/funsd.py +2 -1
  15. deepdoctection/datasets/instances/iiitar13k.py +5 -2
  16. deepdoctection/datasets/instances/layouttest.py +2 -1
  17. deepdoctection/datasets/instances/publaynet.py +2 -2
  18. deepdoctection/datasets/instances/pubtables1m.py +6 -3
  19. deepdoctection/datasets/instances/pubtabnet.py +2 -1
  20. deepdoctection/datasets/instances/rvlcdip.py +2 -1
  21. deepdoctection/datasets/instances/xfund.py +2 -1
  22. deepdoctection/eval/__init__.py +1 -4
  23. deepdoctection/eval/cocometric.py +2 -1
  24. deepdoctection/eval/eval.py +17 -13
  25. deepdoctection/eval/tedsmetric.py +14 -11
  26. deepdoctection/eval/tp_eval_callback.py +9 -3
  27. deepdoctection/extern/__init__.py +2 -7
  28. deepdoctection/extern/d2detect.py +24 -32
  29. deepdoctection/extern/deskew.py +4 -2
  30. deepdoctection/extern/doctrocr.py +75 -81
  31. deepdoctection/extern/fastlang.py +4 -2
  32. deepdoctection/extern/hfdetr.py +22 -28
  33. deepdoctection/extern/hflayoutlm.py +335 -103
  34. deepdoctection/extern/hflm.py +225 -0
  35. deepdoctection/extern/model.py +56 -47
  36. deepdoctection/extern/pdftext.py +8 -4
  37. deepdoctection/extern/pt/__init__.py +1 -3
  38. deepdoctection/extern/pt/nms.py +6 -2
  39. deepdoctection/extern/pt/ptutils.py +27 -19
  40. deepdoctection/extern/texocr.py +4 -2
  41. deepdoctection/extern/tp/tfutils.py +43 -9
  42. deepdoctection/extern/tp/tpcompat.py +10 -7
  43. deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
  44. deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
  45. deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
  46. deepdoctection/extern/tp/tpfrcnn/config/config.py +9 -6
  47. deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
  48. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +17 -7
  49. deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
  50. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +9 -4
  51. deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
  52. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +16 -11
  53. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +17 -10
  54. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +14 -8
  55. deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
  56. deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
  57. deepdoctection/extern/tp/tpfrcnn/preproc.py +7 -3
  58. deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
  59. deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
  60. deepdoctection/extern/tpdetect.py +5 -8
  61. deepdoctection/mapper/__init__.py +3 -8
  62. deepdoctection/mapper/d2struct.py +8 -6
  63. deepdoctection/mapper/hfstruct.py +6 -1
  64. deepdoctection/mapper/laylmstruct.py +163 -20
  65. deepdoctection/mapper/maputils.py +3 -1
  66. deepdoctection/mapper/misc.py +6 -3
  67. deepdoctection/mapper/tpstruct.py +2 -2
  68. deepdoctection/pipe/__init__.py +1 -1
  69. deepdoctection/pipe/common.py +11 -9
  70. deepdoctection/pipe/concurrency.py +2 -1
  71. deepdoctection/pipe/layout.py +3 -1
  72. deepdoctection/pipe/lm.py +32 -64
  73. deepdoctection/pipe/order.py +142 -35
  74. deepdoctection/pipe/refine.py +8 -14
  75. deepdoctection/pipe/{cell.py → sub_layout.py} +1 -1
  76. deepdoctection/train/__init__.py +6 -12
  77. deepdoctection/train/d2_frcnn_train.py +21 -16
  78. deepdoctection/train/hf_detr_train.py +18 -11
  79. deepdoctection/train/hf_layoutlm_train.py +118 -101
  80. deepdoctection/train/tp_frcnn_train.py +21 -19
  81. deepdoctection/utils/env_info.py +41 -117
  82. deepdoctection/utils/logger.py +1 -0
  83. deepdoctection/utils/mocks.py +93 -0
  84. deepdoctection/utils/settings.py +1 -0
  85. deepdoctection/utils/viz.py +4 -3
  86. {deepdoctection-0.31.dist-info → deepdoctection-0.32.dist-info}/METADATA +27 -18
  87. deepdoctection-0.32.dist-info/RECORD +146 -0
  88. deepdoctection-0.31.dist-info/RECORD +0 -144
  89. {deepdoctection-0.31.dist-info → deepdoctection-0.32.dist-info}/LICENSE +0 -0
  90. {deepdoctection-0.31.dist-info → deepdoctection-0.32.dist-info}/WHEEL +0 -0
  91. {deepdoctection-0.31.dist-info → deepdoctection-0.32.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,10 @@
18
18
  """
19
19
  Module for ordering text and layout segments pipeline components
20
20
  """
21
+ from __future__ import annotations
22
+
21
23
  import os
24
+ from abc import ABC
22
25
  from copy import copy
23
26
  from itertools import chain
24
27
  from logging import DEBUG
@@ -349,10 +352,11 @@ class TextLineGenerator:
349
352
  self, make_sub_lines: bool, line_category_id: Union[int, str], paragraph_break: Optional[float] = None
350
353
  ):
351
354
  """
352
- :param make_sub_lines: Whether to build sub lines from lines
355
+ :param make_sub_lines: Whether to build sub lines from lines.
353
356
  :param line_category_id: category_id to give a text line
354
- :param paragraph_break: threshold of two consecutive words. If distance is larger than threshold, two sublines
355
- will be built
357
+ :param paragraph_break: threshold of two consecutive words. If distance is larger than threshold, two sub-lines
358
+ will be built. We use relative coordinates to calculate the distance between two
359
+ consecutive words. A reasonable value is 0.035
356
360
  """
357
361
  if make_sub_lines and paragraph_break is None:
358
362
  raise ValueError("You must specify paragraph_break when setting make_sub_lines to True")
@@ -375,6 +379,7 @@ class TextLineGenerator:
375
379
  image_width: float,
376
380
  image_height: float,
377
381
  image_id: Optional[str] = None,
382
+ highest_level: bool = True,
378
383
  ) -> Sequence[DetectionResult]:
379
384
  """
380
385
  Creating detecting result of lines (or sub lines) from given word type `ImageAnnotation`.
@@ -392,6 +397,8 @@ class TextLineGenerator:
392
397
  # list of (word index, text line, word annotation_id)
393
398
  word_order_list = OrderGenerator.group_words_into_lines(word_anns, image_id)
394
399
  number_rows = max(word[1] for word in word_order_list)
400
+ if number_rows == 1 and not highest_level:
401
+ return []
395
402
  detection_result_list = []
396
403
  for number_row in range(1, number_rows + 1):
397
404
  # list of (word index, text line, word annotation_id) for text line equal to number_row
@@ -423,29 +430,141 @@ class TextLineGenerator:
423
430
  if current_box.absolute_coords:
424
431
  current_box = current_box.transform(image_width, image_height)
425
432
 
426
- # If distance between boxes is lower than paragraph break, same sub line
433
+ # If distance between boxes is lower than paragraph break, same sub-line
427
434
  if current_box.ulx - prev_box.lrx < self.paragraph_break: # type: ignore
428
435
  sub_line.append(ann)
429
436
  sub_line_ann_ids.append(ann.annotation_id)
430
437
  else:
431
- boxes = [ann.get_bounding_box(image_id) for ann in sub_line]
432
- merge_box = merge_boxes(*boxes)
433
- detection_result = self._make_detect_result(merge_box, {"child": sub_line_ann_ids})
434
- detection_result_list.append(detection_result)
435
- sub_line = [ann]
436
- sub_line_ann_ids = [ann.annotation_id]
438
+ # We need to iterate maybe more than one time, because sub-lines may have more than one line
439
+ # if having been split. Take fore example a multi-column layout where a sub-line has
440
+ # two lines because of a column break and fonts twice as large as the other column.
441
+ detection_results = self.create_detection_result(
442
+ sub_line, image_width, image_height, image_id, False
443
+ )
444
+ if detection_results:
445
+ detection_result_list.extend(detection_results)
446
+ else:
447
+ boxes = [ann.get_bounding_box(image_id) for ann in sub_line]
448
+ merge_box = merge_boxes(*boxes)
449
+ detection_result = self._make_detect_result(merge_box, {"child": sub_line_ann_ids})
450
+ detection_result_list.append(detection_result)
451
+ sub_line = [ann]
452
+ sub_line_ann_ids = [ann.annotation_id]
437
453
 
438
454
  if idx == len(anns_per_row) - 1:
439
- boxes = [ann.get_bounding_box(image_id) for ann in sub_line]
440
- merge_box = merge_boxes(*boxes)
441
- detection_result = self._make_detect_result(merge_box, {"child": sub_line_ann_ids})
442
- detection_result_list.append(detection_result)
455
+ detection_results = self.create_detection_result(
456
+ sub_line, image_width, image_height, image_id, False
457
+ )
458
+ if detection_results:
459
+ detection_result_list.extend(detection_results)
460
+ else:
461
+ boxes = [ann.get_bounding_box(image_id) for ann in sub_line]
462
+ merge_box = merge_boxes(*boxes)
463
+ detection_result = self._make_detect_result(merge_box, {"child": sub_line_ann_ids})
464
+ detection_result_list.append(detection_result)
443
465
 
444
466
  return detection_result_list
445
467
 
446
468
 
469
+ class TextLineServiceMixin(PipelineComponent, ABC):
470
+ """
471
+ This class is used to create text lines similar to TextOrderService.
472
+ It uses the logic of the TextOrderService but modifies it to suit its needs.
473
+ It specifically uses the _create_lines_for_words method and modifies the serve method.
474
+ """
475
+
476
+ def __init__(
477
+ self,
478
+ name: str,
479
+ line_category_id: int = 1,
480
+ include_residual_text_container: bool = True,
481
+ paragraph_break: Optional[float] = None,
482
+ ):
483
+ """
484
+ Initialize the TextLineService with a line_category_id and a TextLineGenerator instance.
485
+ """
486
+ self.line_category_id = line_category_id
487
+ self.include_residual_text_container = include_residual_text_container
488
+ self.text_line_generator = TextLineGenerator(
489
+ self.include_residual_text_container, self.line_category_id, paragraph_break
490
+ )
491
+ super().__init__(name)
492
+
493
+ def _create_lines_for_words(self, word_anns: Sequence[ImageAnnotation]) -> Sequence[ImageAnnotation]:
494
+ """
495
+ This method creates lines for words using the TextLineGenerator instance.
496
+ """
497
+ detection_result_list = self.text_line_generator.create_detection_result(
498
+ word_anns,
499
+ self.dp_manager.datapoint.width,
500
+ self.dp_manager.datapoint.height,
501
+ self.dp_manager.datapoint.image_id,
502
+ )
503
+ line_anns = []
504
+ for detect_result in detection_result_list:
505
+ ann_id = self.dp_manager.set_image_annotation(detect_result)
506
+ if ann_id:
507
+ line_ann = self.dp_manager.get_annotation(ann_id)
508
+ child_ann_id_list = detect_result.relationships["child"] # type: ignore
509
+ for child_ann_id in child_ann_id_list:
510
+ line_ann.dump_relationship(Relationships.child, child_ann_id)
511
+ line_anns.append(line_ann)
512
+ return line_anns
513
+
514
+
515
+ class TextLineService(TextLineServiceMixin):
516
+ """
517
+ Some OCR systems do not identify lines of text but only provide text boxes for words. This is not sufficient
518
+ for certain applications. This service determines rule-based text lines based on word boxes. One difficulty is
519
+ that text lines are not continuous but are interrupted, for example in multi-column layouts.
520
+ These interruptions are taken into account insofar as the gap between two words on almost the same page height
521
+ must not be too large.
522
+
523
+ The service constructs new ImageAnnotation of the category `LayoutType.line` and forms relations between the
524
+ text lines and the words contained in the text lines. The reading order is not arranged.
525
+ """
526
+
527
+ def __init__(self, line_category_id: int = 1, paragraph_break: Optional[float] = None):
528
+ """
529
+ Initialize `TextLineService`
530
+
531
+ :param line_category_id: category_id to give a text line
532
+ :param paragraph_break: threshold of two consecutive words. If distance is larger than threshold, two sublines
533
+ will be built
534
+ """
535
+ super().__init__(
536
+ name="text_line",
537
+ line_category_id=line_category_id,
538
+ include_residual_text_container=True,
539
+ paragraph_break=paragraph_break,
540
+ )
541
+
542
+ def clone(self) -> PipelineComponent:
543
+ """
544
+ This method returns a new instance of the class with the same configuration.
545
+ """
546
+ return self.__class__(self.line_category_id, self.text_line_generator.paragraph_break)
547
+
548
+ def serve(self, dp: Image) -> None:
549
+ text_container_anns = dp.get_annotation(category_names=LayoutType.word)
550
+ self._create_lines_for_words(text_container_anns)
551
+
552
+ def get_meta_annotation(self) -> JsonDict:
553
+ """
554
+ This method returns metadata about the annotations created by this pipeline component.
555
+ """
556
+ return dict(
557
+ [
558
+ ("image_annotations", [LayoutType.line]),
559
+ ("sub_categories", {LayoutType.line: {Relationships.child}}),
560
+ ("relationships", {}),
561
+ ("summaries", []),
562
+ ]
563
+ )
564
+
565
+
447
566
  @pipeline_component_registry.register("TextOrderService")
448
- class TextOrderService(PipelineComponent):
567
+ class TextOrderService(TextLineServiceMixin):
449
568
  """
450
569
  Reading order of words within floating text blocks as well as reading order of blocks within simple text blocks.
451
570
  To understand the difference between floating text blocks and simple text blocks consider a page containing an
@@ -470,7 +589,8 @@ class TextOrderService(PipelineComponent):
470
589
  A category annotation per word is generated, which fixes the order per word in the block, as well as a category
471
590
  annotation per block, which saves the reading order of the block per page.
472
591
 
473
- The blocks are defined in `_floating_text_block_names` and text blocks in `_floating_text_block_names`.
592
+ The blocks are defined in `text_block_categories` and text blocks that should be considered when generating
593
+ narrative text must be added in `floating_text_block_categories`.
474
594
 
475
595
  order = TextOrderService(text_container="word",
476
596
  text_block_categories=["title", "text", "list", "cell",
@@ -533,7 +653,12 @@ class TextOrderService(PipelineComponent):
533
653
  self.text_line_generator = TextLineGenerator(
534
654
  self.include_residual_text_container, line_category_id, paragraph_break
535
655
  )
536
- super().__init__("text_order")
656
+ super().__init__(
657
+ name="text_order",
658
+ line_category_id=line_category_id,
659
+ include_residual_text_container=include_residual_text_container,
660
+ paragraph_break=paragraph_break,
661
+ )
537
662
  self._init_sanity_checks()
538
663
 
539
664
  def serve(self, dp: Image) -> None:
@@ -567,24 +692,6 @@ class TextOrderService(PipelineComponent):
567
692
  Relationships.reading_order, idx, Relationships.reading_order, annotation_id
568
693
  )
569
694
 
570
- def _create_lines_for_words(self, word_anns: Sequence[ImageAnnotation]) -> Sequence[ImageAnnotation]:
571
- detection_result_list = self.text_line_generator.create_detection_result(
572
- word_anns,
573
- self.dp_manager.datapoint.width,
574
- self.dp_manager.datapoint.height,
575
- self.dp_manager.datapoint.image_id,
576
- )
577
- line_anns = []
578
- for detect_result in detection_result_list:
579
- ann_id = self.dp_manager.set_image_annotation(detect_result)
580
- if ann_id:
581
- line_ann = self.dp_manager.get_annotation(ann_id)
582
- child_ann_id_list = detect_result.relationships["child"] # type: ignore
583
- for child_ann_id in child_ann_id_list:
584
- line_ann.dump_relationship(Relationships.child, child_ann_id)
585
- line_anns.append(line_ann)
586
- return line_anns
587
-
588
695
  def order_text_in_text_block(self, text_block_ann: ImageAnnotation) -> None:
589
696
  """
590
697
  Order text within a text block. It will take all child-like text containers (determined by a
@@ -23,7 +23,7 @@ from collections import defaultdict
23
23
  from copy import copy
24
24
  from dataclasses import asdict
25
25
  from itertools import chain, product
26
- from typing import DefaultDict, List, Optional, Set, Tuple, Union
26
+ from typing import DefaultDict, List, Optional, Sequence, Set, Tuple, Union
27
27
 
28
28
  import networkx as nx # type: ignore
29
29
 
@@ -34,7 +34,7 @@ from ..extern.base import DetectionResult
34
34
  from ..mapper.maputils import MappingContextManager
35
35
  from ..utils.detection_types import JsonDict
36
36
  from ..utils.error import AnnotationError, ImageError
37
- from ..utils.settings import CellType, LayoutType, Relationships, TableType, get_type
37
+ from ..utils.settings import CellType, LayoutType, ObjectTypes, Relationships, TableType, get_type
38
38
  from .base import PipelineComponent
39
39
  from .registry import pipeline_component_registry
40
40
 
@@ -398,19 +398,13 @@ class TableSegmentationRefinementService(PipelineComponent):
398
398
 
399
399
  """
400
400
 
401
- def __init__(self) -> None:
402
- self._table_name = [LayoutType.table, LayoutType.table_rotated]
403
- self._cell_names = [
404
- LayoutType.cell,
405
- CellType.column_header,
406
- CellType.projected_row_header,
407
- CellType.spanning,
408
- CellType.row_header,
409
- ]
401
+ def __init__(self, table_name: Sequence[ObjectTypes], cell_names: Sequence[ObjectTypes]) -> None:
402
+ self.table_name = table_name
403
+ self.cell_names = cell_names
410
404
  super().__init__("table_segment_refine")
411
405
 
412
406
  def serve(self, dp: Image) -> None:
413
- tables = dp.get_annotation(category_names=self._table_name)
407
+ tables = dp.get_annotation(category_names=self.table_name)
414
408
  for table in tables:
415
409
  if table.image is None:
416
410
  raise ImageError("table.image cannot be None")
@@ -458,7 +452,7 @@ class TableSegmentationRefinementService(PipelineComponent):
458
452
  for cell in cells:
459
453
  cell.deactivate()
460
454
 
461
- cells = table.image.get_annotation(category_names=self._cell_names)
455
+ cells = table.image.get_annotation(category_names=self.cell_names)
462
456
  number_of_rows = max(int(cell.get_sub_category(CellType.row_number).category_id) for cell in cells)
463
457
  number_of_cols = max(int(cell.get_sub_category(CellType.column_number).category_id) for cell in cells)
464
458
  max_row_span = max(int(cell.get_sub_category(CellType.row_span).category_id) for cell in cells)
@@ -500,7 +494,7 @@ class TableSegmentationRefinementService(PipelineComponent):
500
494
  self.dp_manager.set_container_annotation(TableType.html, -1, TableType.html, table.annotation_id, html)
501
495
 
502
496
  def clone(self) -> PipelineComponent:
503
- return self.__class__()
497
+ return self.__class__(self.table_name, self.cell_names)
504
498
 
505
499
  def get_meta_annotation(self) -> JsonDict:
506
500
  return dict(
@@ -1,5 +1,5 @@
1
1
  # -*- coding: utf-8 -*-
2
- # File: cell.py
2
+ # File: sub_layout.py
3
3
 
4
4
  # Copyright 2021 Dr. Janis Meyer. All rights reserved.
5
5
  #
@@ -19,20 +19,14 @@
19
19
  Init module for train package
20
20
  """
21
21
 
22
- from ..utils.file_utils import (
23
- detectron2_available,
24
- pytorch_available,
25
- tensorpack_available,
26
- tf_available,
27
- transformers_available,
28
- )
22
+ from ..utils.file_utils import detectron2_available, tensorpack_available, transformers_available
29
23
 
30
- if tf_available() and tensorpack_available():
31
- from .tp_frcnn_train import train_faster_rcnn
32
-
33
- if pytorch_available() and detectron2_available():
24
+ if detectron2_available():
34
25
  from .d2_frcnn_train import train_d2_faster_rcnn
35
26
 
36
- if pytorch_available() and transformers_available():
27
+ if transformers_available():
37
28
  from .hf_detr_train import train_hf_detr
38
29
  from .hf_layoutlm_train import train_hf_layoutlm
30
+
31
+ if tensorpack_available():
32
+ from .tp_frcnn_train import train_faster_rcnn
@@ -18,19 +18,12 @@
18
18
  """
19
19
  Module for training Detectron2 `GeneralizedRCNN`
20
20
  """
21
-
21
+ from __future__ import annotations
22
22
 
23
23
  import copy
24
24
  from typing import Any, Dict, List, Mapping, Optional, Sequence, Type, Union
25
25
 
26
- from detectron2.config import CfgNode, get_cfg
27
- from detectron2.data import DatasetMapper, build_detection_train_loader
28
- from detectron2.data.transforms import RandomFlip, ResizeShortestEdge
29
- from detectron2.engine import DefaultTrainer, HookBase, default_writers, hooks
30
- from detectron2.utils import comm
31
- from detectron2.utils.events import EventWriter, get_event_storage
32
- from fvcore.nn.precise_bn import get_bn_modules # type: ignore
33
- from torch.utils.data import DataLoader, IterableDataset
26
+ from lazy_imports import try_import
34
27
 
35
28
  from ..datasets.adapter import DatasetAdapter
36
29
  from ..datasets.base import DatasetBase
@@ -39,7 +32,6 @@ from ..eval.base import MetricBase
39
32
  from ..eval.eval import Evaluator
40
33
  from ..eval.registry import metric_registry
41
34
  from ..extern.d2detect import D2FrcnnDetector
42
- from ..extern.pt.ptutils import get_num_gpu
43
35
  from ..mapper.d2struct import image_to_d2_frcnn_training
44
36
  from ..pipe.base import PredictorPipelineComponent
45
37
  from ..pipe.registry import pipeline_component_registry
@@ -48,7 +40,20 @@ from ..utils.file_utils import get_wandb_requirement, wandb_available
48
40
  from ..utils.logger import LoggingRecord, logger
49
41
  from ..utils.utils import string_to_dict
50
42
 
51
- if wandb_available():
43
+ with try_import() as d2_import_guard:
44
+ from detectron2.config import CfgNode, get_cfg
45
+ from detectron2.data import DatasetMapper, build_detection_train_loader
46
+ from detectron2.data.transforms import RandomFlip, ResizeShortestEdge
47
+ from detectron2.engine import DefaultTrainer, HookBase, default_writers, hooks
48
+ from detectron2.utils import comm
49
+ from detectron2.utils.events import EventWriter, get_event_storage
50
+ from fvcore.nn.precise_bn import get_bn_modules # type: ignore
51
+
52
+ with try_import() as pt_import_guard:
53
+ from torch import cuda
54
+ from torch.utils.data import DataLoader, IterableDataset
55
+
56
+ with try_import() as wb_import_guard:
52
57
  import wandb
53
58
 
54
59
 
@@ -112,7 +117,7 @@ class WandbWriter(EventWriter):
112
117
  config = {}
113
118
  self._window_size = window_size
114
119
  self._run = wandb.init(project=project, config=config, **kwargs) if not wandb.run else wandb.run
115
- self._run._label(repo=repo) # type:ignore
120
+ self._run._label(repo=repo)
116
121
 
117
122
  def write(self) -> None:
118
123
  storage = get_event_storage()
@@ -121,10 +126,10 @@ class WandbWriter(EventWriter):
121
126
  for key, (val, _) in storage.latest_with_smoothing_hint(self._window_size).items():
122
127
  log_dict[key] = val
123
128
 
124
- self._run.log(log_dict) # type:ignore
129
+ self._run.log(log_dict)
125
130
 
126
131
  def close(self) -> None:
127
- self._run.finish() # type:ignore
132
+ self._run.finish()
128
133
 
129
134
 
130
135
  class D2Trainer(DefaultTrainer):
@@ -259,7 +264,7 @@ class D2Trainer(DefaultTrainer):
259
264
  dataset_val,
260
265
  pipeline_component,
261
266
  metric,
262
- num_threads=get_num_gpu() * 2,
267
+ num_threads=cuda.device_count() * 2,
263
268
  run=run,
264
269
  )
265
270
  if build_val_dict:
@@ -335,7 +340,7 @@ def train_d2_faster_rcnn(
335
340
  :param pipeline_component_name: A pipeline component name to use for validation.
336
341
  """
337
342
 
338
- assert get_num_gpu() > 0, "Has to train with GPU!"
343
+ assert cuda.device_count() > 0, "Has to train with GPU!"
339
344
 
340
345
  build_train_dict: Dict[str, str] = {}
341
346
  if build_train_config is not None:
@@ -19,20 +19,12 @@
19
19
  Module for training Hugging Face Detr implementation. Note, that this scripts only trans Tabletransformer like Detr
20
20
  models that are a slightly different from the plain Detr model that are provided by the transformer library.
21
21
  """
22
+ from __future__ import annotations
22
23
 
23
24
  import copy
24
25
  from typing import Any, Dict, List, Optional, Sequence, Type, Union
25
26
 
26
- from torch.nn import Module
27
- from torch.utils.data import Dataset
28
- from transformers import (
29
- AutoFeatureExtractor,
30
- IntervalStrategy,
31
- PretrainedConfig,
32
- PreTrainedModel,
33
- TableTransformerForObjectDetection,
34
- )
35
- from transformers.trainer import Trainer, TrainingArguments
27
+ from lazy_imports import try_import
36
28
 
37
29
  from ..datasets.adapter import DatasetAdapter
38
30
  from ..datasets.base import DatasetBase
@@ -47,6 +39,21 @@ from ..pipe.registry import pipeline_component_registry
47
39
  from ..utils.logger import LoggingRecord, logger
48
40
  from ..utils.utils import string_to_dict
49
41
 
42
+ with try_import() as pt_import_guard:
43
+ from torch import nn
44
+ from torch.utils.data import Dataset
45
+
46
+ with try_import() as hf_import_guard:
47
+ from transformers import (
48
+ AutoFeatureExtractor,
49
+ IntervalStrategy,
50
+ PretrainedConfig,
51
+ PreTrainedModel,
52
+ TableTransformerForObjectDetection,
53
+ Trainer,
54
+ TrainingArguments,
55
+ )
56
+
50
57
 
51
58
  class DetrDerivedTrainer(Trainer):
52
59
  """
@@ -61,7 +68,7 @@ class DetrDerivedTrainer(Trainer):
61
68
 
62
69
  def __init__(
63
70
  self,
64
- model: Union[PreTrainedModel, Module],
71
+ model: Union[PreTrainedModel, nn.Module],
65
72
  args: TrainingArguments,
66
73
  data_collator: DetrDataCollator,
67
74
  train_dataset: Dataset[Any],