deepdoctection 0.30__py3-none-any.whl → 0.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. Click here for more details.

Files changed (74) hide show
  1. deepdoctection/__init__.py +4 -2
  2. deepdoctection/analyzer/dd.py +6 -5
  3. deepdoctection/dataflow/base.py +0 -19
  4. deepdoctection/dataflow/custom.py +4 -3
  5. deepdoctection/dataflow/custom_serialize.py +14 -5
  6. deepdoctection/dataflow/parallel_map.py +12 -11
  7. deepdoctection/dataflow/serialize.py +5 -4
  8. deepdoctection/datapoint/annotation.py +33 -12
  9. deepdoctection/datapoint/box.py +1 -4
  10. deepdoctection/datapoint/convert.py +3 -1
  11. deepdoctection/datapoint/image.py +66 -29
  12. deepdoctection/datapoint/view.py +57 -25
  13. deepdoctection/datasets/adapter.py +1 -1
  14. deepdoctection/datasets/base.py +83 -10
  15. deepdoctection/datasets/dataflow_builder.py +1 -1
  16. deepdoctection/datasets/info.py +2 -2
  17. deepdoctection/datasets/instances/layouttest.py +2 -7
  18. deepdoctection/eval/accmetric.py +1 -1
  19. deepdoctection/eval/base.py +5 -4
  20. deepdoctection/eval/eval.py +2 -2
  21. deepdoctection/eval/tp_eval_callback.py +5 -4
  22. deepdoctection/extern/base.py +39 -13
  23. deepdoctection/extern/d2detect.py +164 -64
  24. deepdoctection/extern/deskew.py +32 -7
  25. deepdoctection/extern/doctrocr.py +227 -39
  26. deepdoctection/extern/fastlang.py +45 -7
  27. deepdoctection/extern/hfdetr.py +90 -33
  28. deepdoctection/extern/hflayoutlm.py +109 -22
  29. deepdoctection/extern/pdftext.py +2 -1
  30. deepdoctection/extern/pt/ptutils.py +3 -2
  31. deepdoctection/extern/tessocr.py +134 -22
  32. deepdoctection/extern/texocr.py +2 -0
  33. deepdoctection/extern/tp/tpcompat.py +4 -4
  34. deepdoctection/extern/tp/tpfrcnn/preproc.py +2 -7
  35. deepdoctection/extern/tpdetect.py +50 -23
  36. deepdoctection/mapper/d2struct.py +1 -1
  37. deepdoctection/mapper/hfstruct.py +1 -1
  38. deepdoctection/mapper/laylmstruct.py +1 -1
  39. deepdoctection/mapper/maputils.py +13 -2
  40. deepdoctection/mapper/prodigystruct.py +1 -1
  41. deepdoctection/mapper/pubstruct.py +10 -10
  42. deepdoctection/mapper/tpstruct.py +1 -1
  43. deepdoctection/pipe/anngen.py +35 -8
  44. deepdoctection/pipe/base.py +53 -19
  45. deepdoctection/pipe/cell.py +29 -8
  46. deepdoctection/pipe/common.py +12 -4
  47. deepdoctection/pipe/doctectionpipe.py +2 -2
  48. deepdoctection/pipe/language.py +3 -2
  49. deepdoctection/pipe/layout.py +3 -2
  50. deepdoctection/pipe/lm.py +2 -2
  51. deepdoctection/pipe/refine.py +18 -10
  52. deepdoctection/pipe/segment.py +21 -16
  53. deepdoctection/pipe/text.py +14 -8
  54. deepdoctection/pipe/transform.py +16 -9
  55. deepdoctection/train/d2_frcnn_train.py +15 -12
  56. deepdoctection/train/hf_detr_train.py +8 -6
  57. deepdoctection/train/hf_layoutlm_train.py +16 -11
  58. deepdoctection/utils/__init__.py +3 -0
  59. deepdoctection/utils/concurrency.py +1 -1
  60. deepdoctection/utils/context.py +2 -2
  61. deepdoctection/utils/env_info.py +55 -22
  62. deepdoctection/utils/error.py +84 -0
  63. deepdoctection/utils/file_utils.py +4 -15
  64. deepdoctection/utils/fs.py +7 -7
  65. deepdoctection/utils/pdf_utils.py +5 -4
  66. deepdoctection/utils/settings.py +5 -1
  67. deepdoctection/utils/transform.py +1 -1
  68. deepdoctection/utils/utils.py +0 -6
  69. deepdoctection/utils/viz.py +44 -2
  70. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/METADATA +33 -58
  71. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/RECORD +74 -73
  72. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/WHEEL +1 -1
  73. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/LICENSE +0 -0
  74. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,7 @@ from ..datapoint.image import Image
25
25
  from ..datapoint.view import Page
26
26
  from ..extern.base import LanguageDetector, ObjectDetector
27
27
  from ..utils.detection_types import JsonDict
28
+ from ..utils.error import ImageError
28
29
  from ..utils.settings import PageType, TypeOrStr, get_type
29
30
  from .base import PipelineComponent
30
31
  from .registry import pipeline_component_registry
@@ -86,7 +87,7 @@ class LanguageDetectionService(PipelineComponent):
86
87
  text = page.text_no_line_break
87
88
  else:
88
89
  if dp.image is None:
89
- raise ValueError("dp.image cannot be None")
90
+ raise ImageError("image cannot be None")
90
91
  detect_result_list = self.text_detector.predict(dp.image)
91
92
  # this is a concatenation of all detection result. No reading order
92
93
  text = " ".join([result.text for result in detect_result_list if result.text is not None])
@@ -98,7 +99,7 @@ class LanguageDetectionService(PipelineComponent):
98
99
  def clone(self) -> PipelineComponent:
99
100
  predictor = self.predictor.clone()
100
101
  if not isinstance(predictor, LanguageDetector):
101
- raise ValueError(f"Predictor must be of type LanguageDetector, but is of type {type(predictor)}")
102
+ raise TypeError(f"Predictor must be of type LanguageDetector, but is of type {type(predictor)}")
102
103
  return self.__class__(
103
104
  predictor,
104
105
  copy(self.text_container),
@@ -25,6 +25,7 @@ import numpy as np
25
25
  from ..datapoint.image import Image
26
26
  from ..extern.base import ObjectDetector, PdfMiner
27
27
  from ..utils.detection_types import JsonDict
28
+ from ..utils.error import ImageError
28
29
  from ..utils.transform import PadTransform
29
30
  from .base import PredictorPipelineComponent
30
31
  from .registry import pipeline_component_registry
@@ -79,7 +80,7 @@ class ImageLayoutService(PredictorPipelineComponent):
79
80
  if anns:
80
81
  return
81
82
  if dp.image is None:
82
- raise ValueError("image cannot be None")
83
+ raise ImageError("image cannot be None")
83
84
  np_image = dp.image
84
85
  if self.padder:
85
86
  np_image = self.padder.apply_image(np_image)
@@ -114,5 +115,5 @@ class ImageLayoutService(PredictorPipelineComponent):
114
115
  if self.padder:
115
116
  padder_clone = self.padder.clone()
116
117
  if not isinstance(predictor, ObjectDetector):
117
- raise ValueError(f"predictor must be of type ObjectDetector, but is of type {type(predictor)}")
118
+ raise TypeError(f"predictor must be of type ObjectDetector, but is of type {type(predictor)}")
118
119
  return self.__class__(predictor, self.to_image, self.crop_image, padder_clone, self.skip_if_layout_extracted)
deepdoctection/pipe/lm.py CHANGED
@@ -252,7 +252,7 @@ class LMTokenClassifierService(LanguageModelPipelineComponent):
252
252
  self.language_model.model.__class__.__name__, use_xlm_tokenizer
253
253
  )
254
254
  if not isinstance(self.tokenizer, type(tokenizer_reference)):
255
- raise ValueError(
255
+ raise TypeError(
256
256
  f"You want to use {type(self.tokenizer)} but you should use {type(tokenizer_reference)} "
257
257
  f"in this framework"
258
258
  )
@@ -366,7 +366,7 @@ class LMSequenceClassifierService(LanguageModelPipelineComponent):
366
366
  self.language_model.model.__class__.__name__, use_xlm_tokenizer
367
367
  )
368
368
  if not isinstance(self.tokenizer, type(tokenizer_reference)):
369
- raise ValueError(
369
+ raise TypeError(
370
370
  f"You want to use {type(self.tokenizer)} but you should use {type(tokenizer_reference)} "
371
371
  f"in this framework"
372
372
  )
@@ -33,6 +33,7 @@ from ..datapoint.image import Image
33
33
  from ..extern.base import DetectionResult
34
34
  from ..mapper.maputils import MappingContextManager
35
35
  from ..utils.detection_types import JsonDict
36
+ from ..utils.error import AnnotationError, ImageError
36
37
  from ..utils.settings import CellType, LayoutType, Relationships, TableType, get_type
37
38
  from .base import PipelineComponent
38
39
  from .registry import pipeline_component_registry
@@ -302,7 +303,7 @@ def generate_html_string(table: ImageAnnotation) -> List[str]:
302
303
  :return: HTML representation of the table
303
304
  """
304
305
  if table.image is None:
305
- raise ValueError("table.image cannot be None")
306
+ raise ImageError("table.image cannot be None")
306
307
  table_image = table.image
307
308
  cells = table_image.get_annotation(
308
309
  category_names=[
@@ -412,7 +413,7 @@ class TableSegmentationRefinementService(PipelineComponent):
412
413
  tables = dp.get_annotation(category_names=self._table_name)
413
414
  for table in tables:
414
415
  if table.image is None:
415
- raise ValueError("table.image cannot be None")
416
+ raise ImageError("table.image cannot be None")
416
417
  tiles_to_cells_list = tiles_to_cells(dp, table)
417
418
  connected_components, tile_to_cell_dict = connected_component_tiles(tiles_to_cells_list)
418
419
  rectangle_tiling = generate_rectangle_tiling(connected_components)
@@ -464,14 +465,21 @@ class TableSegmentationRefinementService(PipelineComponent):
464
465
  max_col_span = max(int(cell.get_sub_category(CellType.column_span).category_id) for cell in cells)
465
466
  # TODO: the summaries should be sub categories of the underlying ann
466
467
  if table.image.summary is not None:
467
- if TableType.number_of_rows in table.image.summary.sub_categories:
468
- table.get_summary(TableType.number_of_rows)
469
- if TableType.number_of_columns in table.image.summary.sub_categories:
470
- table.get_summary(TableType.number_of_columns)
471
- if TableType.max_row_span in table.image.summary.sub_categories:
472
- table.get_summary(TableType.max_row_span)
473
- if TableType.max_col_span in table.image.summary.sub_categories:
474
- table.get_summary(TableType.max_col_span)
468
+ if (
469
+ TableType.number_of_rows in table.image.summary.sub_categories
470
+ and TableType.number_of_columns in table.image.summary.sub_categories
471
+ and TableType.max_row_span in table.image.summary.sub_categories
472
+ and TableType.max_col_span in table.image.summary.sub_categories
473
+ ):
474
+ table.image.summary.remove_sub_category(TableType.number_of_rows)
475
+ table.image.summary.remove_sub_category(TableType.number_of_columns)
476
+ table.image.summary.remove_sub_category(TableType.max_row_span)
477
+ table.image.summary.remove_sub_category(TableType.max_col_span)
478
+ else:
479
+ raise AnnotationError(
480
+ "Table summary does not contain sub categories TableType.number_of_rows, "
481
+ "TableType.number_of_columns, TableType.max_row_span, TableType.max_col_span"
482
+ )
475
483
 
476
484
  self.dp_manager.set_summary_annotation(
477
485
  TableType.number_of_rows, TableType.number_of_rows, number_of_rows, annotation_id=table.annotation_id
@@ -33,6 +33,7 @@ from ..extern.base import DetectionResult
33
33
  from ..mapper.maputils import MappingContextManager
34
34
  from ..mapper.match import match_anns_by_intersection
35
35
  from ..utils.detection_types import JsonDict
36
+ from ..utils.error import ImageError
36
37
  from ..utils.settings import CellType, LayoutType, ObjectTypes, Relationships, TableType
37
38
  from .base import PipelineComponent
38
39
  from .refine import generate_html_string
@@ -136,12 +137,12 @@ def stretch_item_per_table(
136
137
 
137
138
  rows = dp.get_annotation(category_names=row_name, annotation_ids=item_ann_ids)
138
139
  if table.image is None:
139
- raise ValueError("table.image cannot be None")
140
+ raise ImageError("table.image cannot be None")
140
141
  table_embedding_box = table.get_bounding_box(dp.image_id)
141
142
 
142
143
  for row in rows:
143
144
  if row.image is None:
144
- raise ValueError("row.image cannot be None")
145
+ raise ImageError("row.image cannot be None")
145
146
  row_embedding_box = row.get_bounding_box(dp.image_id)
146
147
  row_embedding_box.ulx = table_embedding_box.ulx + 1.0
147
148
  row_embedding_box.lrx = table_embedding_box.lrx - 1.0
@@ -166,7 +167,7 @@ def stretch_item_per_table(
166
167
 
167
168
  for col in cols:
168
169
  if col.image is None:
169
- raise ValueError("row.image cannot be None")
170
+ raise ImageError("row.image cannot be None")
170
171
  col_embedding_box = col.get_bounding_box(dp.image_id)
171
172
  col_embedding_box.uly = table_embedding_box.uly + 1.0
172
173
  col_embedding_box.lry = table_embedding_box.lry - 1.0
@@ -194,7 +195,7 @@ def _tile_by_stretching_rows_left_and_rightwise(
194
195
  dp: Image, items: List[ImageAnnotation], table: ImageAnnotation, item_name: str
195
196
  ) -> None:
196
197
  if table.image is None:
197
- raise ValueError("table.image cannot be None")
198
+ raise ImageError("table.image cannot be None")
198
199
  table_embedding_box = table.get_bounding_box(dp.image_id)
199
200
 
200
201
  tmp_item_xy = table_embedding_box.uly + 1.0 if item_name == LayoutType.row else table_embedding_box.ulx + 1.0
@@ -206,7 +207,7 @@ def _tile_by_stretching_rows_left_and_rightwise(
206
207
  image_annotation={"category_name": item.category_name, "annotation_id": item.annotation_id},
207
208
  ):
208
209
  if item.image is None:
209
- raise ValueError("item.image cannot be None")
210
+ raise ImageError("item.image cannot be None")
210
211
  item_embedding_box = item.get_bounding_box(dp.image_id)
211
212
  if idx != len(items) - 1:
212
213
  next_item_embedding_box = items[idx + 1].get_bounding_box(dp.image_id)
@@ -258,7 +259,7 @@ def _tile_by_stretching_rows_leftwise_column_downwise(
258
259
  dp: Image, items: List[ImageAnnotation], table: ImageAnnotation, item_name: str
259
260
  ) -> None:
260
261
  if table.image is None:
261
- raise ValueError("table.image cannot be None")
262
+ raise ImageError("table.image cannot be None")
262
263
  table_embedding_box = table.get_bounding_box(dp.image_id)
263
264
 
264
265
  tmp_item_xy = table_embedding_box.uly + 1.0 if item_name == LayoutType.row else table_embedding_box.ulx + 1.0
@@ -270,7 +271,7 @@ def _tile_by_stretching_rows_leftwise_column_downwise(
270
271
  image_annotation={"category_name": item.category_name, "annotation_id": item.annotation_id},
271
272
  ):
272
273
  if item.image is None:
273
- raise ValueError("item.image cannot be None")
274
+ raise ImageError("item.image cannot be None")
274
275
  item_embedding_box = item.get_bounding_box(dp.image_id)
275
276
  new_embedding_box = BoundingBox(
276
277
  ulx=item_embedding_box.ulx if item_name == LayoutType.row else tmp_item_xy,
@@ -339,9 +340,9 @@ def tile_tables_with_items_per_table(
339
340
  items = dp.get_annotation(category_names=item_name, annotation_ids=item_ann_ids)
340
341
 
341
342
  items.sort(
342
- key=lambda x: x.get_bounding_box(dp.image_id).cx
343
- if item_name == LayoutType.column
344
- else x.get_bounding_box(dp.image_id).cy
343
+ key=lambda x: (
344
+ x.get_bounding_box(dp.image_id).cx if item_name == LayoutType.column else x.get_bounding_box(dp.image_id).cy
345
+ )
345
346
  )
346
347
 
347
348
  if stretch_rule == "left":
@@ -737,9 +738,11 @@ class TableSegmentationService(PipelineComponent):
737
738
 
738
739
  # we will assume that either all or no image attribute has been generated
739
740
  items.sort(
740
- key=lambda x: x.get_bounding_box(dp.image_id).cx # pylint: disable=W0640
741
- if item_name == LayoutType.column # pylint: disable=W0640
742
- else x.get_bounding_box(dp.image_id).cy # pylint: disable=W0640
741
+ key=lambda x: (
742
+ x.get_bounding_box(dp.image_id).cx # pylint: disable=W0640
743
+ if item_name == LayoutType.column # pylint: disable=W0640
744
+ else x.get_bounding_box(dp.image_id).cy # pylint: disable=W0640
745
+ )
743
746
  )
744
747
 
745
748
  for item_number, item in enumerate(items, 1):
@@ -939,9 +942,11 @@ class PubtablesSegmentationService(PipelineComponent):
939
942
 
940
943
  # we will assume that either all or no image attribute has been generated
941
944
  items.sort(
942
- key=lambda x: x.get_bounding_box(dp.image_id).cx
943
- if item_name == LayoutType.column # pylint: disable=W0640
944
- else x.get_bounding_box(dp.image_id).cy
945
+ key=lambda x: (
946
+ x.get_bounding_box(dp.image_id).cx
947
+ if item_name == LayoutType.column # pylint: disable=W0640
948
+ else x.get_bounding_box(dp.image_id).cy
949
+ )
945
950
  )
946
951
 
947
952
  for item_number, item in enumerate(items, 1):
@@ -26,6 +26,7 @@ from ..datapoint.image import Image
26
26
  from ..extern.base import ObjectDetector, PdfMiner, TextRecognizer
27
27
  from ..extern.tessocr import TesseractOcrDetector
28
28
  from ..utils.detection_types import ImageType, JsonDict
29
+ from ..utils.error import ImageError
29
30
  from ..utils.settings import PageType, TypeOrStr, WordType, get_type
30
31
  from .base import PredictorPipelineComponent
31
32
  from .registry import pipeline_component_registry
@@ -89,7 +90,10 @@ class TextExtractionService(PredictorPipelineComponent):
89
90
  super().__init__(self._get_name(text_extract_detector.name), text_extract_detector)
90
91
  if self.extract_from_category:
91
92
  if not isinstance(self.predictor, (ObjectDetector, TextRecognizer)):
92
- raise TypeError("Predicting from a cropped image requires to pass an ObjectDetector or TextRecognizer.")
93
+ raise TypeError(
94
+ f"Predicting from a cropped image requires to pass an ObjectDetector or "
95
+ f"TextRecognizer. Got {type(self.predictor)}"
96
+ )
93
97
  if run_time_ocr_language_selection:
94
98
  assert isinstance(
95
99
  self.predictor, TesseractOcrDetector
@@ -171,13 +175,13 @@ class TextExtractionService(PredictorPipelineComponent):
171
175
 
172
176
  if isinstance(text_roi, ImageAnnotation):
173
177
  if text_roi.image is None:
174
- raise ValueError("text_roi.image cannot be None")
178
+ raise ImageError("text_roi.image cannot be None")
175
179
  if text_roi.image.image is None:
176
- raise ValueError("text_roi.image.image cannot be None")
180
+ raise ImageError("text_roi.image.image cannot be None")
177
181
  return text_roi.image.image
178
182
  if isinstance(self.predictor, ObjectDetector):
179
183
  if not isinstance(text_roi, Image):
180
- raise ValueError("text_roi must be an image")
184
+ raise ImageError("text_roi must be an image")
181
185
  return text_roi.image
182
186
  if isinstance(text_roi, list):
183
187
  assert all(roi.image is not None for roi in text_roi)
@@ -201,9 +205,11 @@ class TextExtractionService(PredictorPipelineComponent):
201
205
  [
202
206
  (
203
207
  "image_annotations",
204
- self.predictor.possible_categories()
205
- if isinstance(self.predictor, (ObjectDetector, PdfMiner))
206
- else [],
208
+ (
209
+ self.predictor.possible_categories()
210
+ if isinstance(self.predictor, (ObjectDetector, PdfMiner))
211
+ else []
212
+ ),
207
213
  ),
208
214
  ("sub_categories", sub_cat_dict),
209
215
  ("relationships", {}),
@@ -218,5 +224,5 @@ class TextExtractionService(PredictorPipelineComponent):
218
224
  def clone(self) -> "PredictorPipelineComponent":
219
225
  predictor = self.predictor.clone()
220
226
  if not isinstance(predictor, (ObjectDetector, PdfMiner, TextRecognizer)):
221
- raise ValueError(f"predictor must be of type ObjectDetector or PdfMiner, but is of type {type(predictor)}")
227
+ raise ImageError(f"predictor must be of type ObjectDetector or PdfMiner, but is of type {type(predictor)}")
222
228
  return self.__class__(predictor, deepcopy(self.extract_from_category), self.run_time_ocr_language_selection)
@@ -23,7 +23,6 @@ on images (e.g. deskew, de-noising or more general GAN like operations.
23
23
  from ..datapoint.image import Image
24
24
  from ..extern.base import ImageTransformer
25
25
  from ..utils.detection_types import JsonDict
26
- from ..utils.logger import LoggingRecord, logger
27
26
  from .base import ImageTransformPipelineComponent
28
27
  from .registry import pipeline_component_registry
29
28
 
@@ -49,16 +48,24 @@ class SimpleTransformService(ImageTransformPipelineComponent):
49
48
 
50
49
  def serve(self, dp: Image) -> None:
51
50
  if dp.annotations:
52
- logger.warning(
53
- LoggingRecord(
54
- f"{self.name} has already received image with image annotations. These annotations "
55
- f"will not be transformed and might cause unexpected output in your pipeline."
56
- )
51
+ raise RuntimeError(
52
+ "SimpleTransformService receives datapoints with ImageAnnotations. This violates the "
53
+ "pipeline building API but this can currently be caught only at runtime. "
54
+ "Please make sure that this component is the first one in the pipeline."
57
55
  )
56
+
58
57
  if dp.image is not None:
59
- np_image_transform = self.transform_predictor.transform(dp.image)
58
+ detection_result = self.transform_predictor.predict(dp.image)
59
+ transformed_image = self.transform_predictor.transform(dp.image, detection_result)
60
60
  self.dp_manager.datapoint.clear_image(True)
61
- self.dp_manager.datapoint.image = np_image_transform
61
+ self.dp_manager.datapoint.image = transformed_image
62
+ self.dp_manager.set_summary_annotation(
63
+ summary_key=self.transform_predictor.possible_category(),
64
+ summary_name=self.transform_predictor.possible_category(),
65
+ summary_number=None,
66
+ summary_value=getattr(detection_result, self.transform_predictor.possible_category().value, None),
67
+ summary_score=detection_result.score,
68
+ )
62
69
 
63
70
  def clone(self) -> "SimpleTransformService":
64
71
  return self.__class__(self.transform_predictor)
@@ -69,7 +76,7 @@ class SimpleTransformService(ImageTransformPipelineComponent):
69
76
  ("image_annotations", []),
70
77
  ("sub_categories", {}),
71
78
  ("relationships", {}),
72
- ("summaries", []),
79
+ ("summaries", [self.transform_predictor.possible_category()]),
73
80
  ]
74
81
  )
75
82
 
@@ -43,6 +43,7 @@ from ..extern.pt.ptutils import get_num_gpu
43
43
  from ..mapper.d2struct import image_to_d2_frcnn_training
44
44
  from ..pipe.base import PredictorPipelineComponent
45
45
  from ..pipe.registry import pipeline_component_registry
46
+ from ..utils.error import DependencyError
46
47
  from ..utils.file_utils import get_wandb_requirement, wandb_available
47
48
  from ..utils.logger import LoggingRecord, logger
48
49
  from ..utils.utils import string_to_dict
@@ -153,16 +154,18 @@ class D2Trainer(DefaultTrainer):
153
154
  ret = [
154
155
  hooks.IterationTimer(),
155
156
  hooks.LRScheduler(),
156
- hooks.PreciseBN(
157
- # Run at the same freq as (but before) evaluation.
158
- cfg.TEST.EVAL_PERIOD,
159
- self.model, # pylint: disable=E1101
160
- # Build a new data loader to not affect training
161
- self.build_train_loader(cfg),
162
- cfg.TEST.PRECISE_BN.NUM_ITER,
163
- )
164
- if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model) # pylint: disable=E1101
165
- else None,
157
+ (
158
+ hooks.PreciseBN(
159
+ # Run at the same freq as (but before) evaluation.
160
+ cfg.TEST.EVAL_PERIOD,
161
+ self.model, # pylint: disable=E1101
162
+ # Build a new data loader to not affect training
163
+ self.build_train_loader(cfg),
164
+ cfg.TEST.PRECISE_BN.NUM_ITER,
165
+ )
166
+ if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model) # pylint: disable=E1101
167
+ else None
168
+ ),
166
169
  ]
167
170
 
168
171
  # Do PreciseBN before checkpointer, because it updates the model and need to
@@ -201,7 +204,7 @@ class D2Trainer(DefaultTrainer):
201
204
  if self.cfg.WANDB.USE_WANDB:
202
205
  _, _wandb_available, err_msg = get_wandb_requirement()
203
206
  if not _wandb_available:
204
- raise ImportError(err_msg)
207
+ raise DependencyError(err_msg)
205
208
  if self.cfg.WANDB.PROJECT is None:
206
209
  raise ValueError("When using W&B, you must specify a project, i.e. WANDB.PROJECT")
207
210
  writers_list.append(WandbWriter(self.cfg.WANDB.PROJECT, self.cfg.WANDB.REPO, self.cfg))
@@ -269,7 +272,7 @@ class D2Trainer(DefaultTrainer):
269
272
 
270
273
  @classmethod
271
274
  def build_evaluator(cls, cfg, dataset_name): # type: ignore
272
- raise NotImplementedError
275
+ raise NotImplementedError()
273
276
 
274
277
 
275
278
  def train_d2_faster_rcnn(
@@ -97,9 +97,9 @@ class DetrDerivedTrainer(Trainer):
97
97
 
98
98
  def evaluate(
99
99
  self,
100
- eval_dataset: Optional[Dataset[Any]] = None,
101
- ignore_keys: Optional[List[str]] = None,
102
- metric_key_prefix: str = "eval",
100
+ eval_dataset: Optional[Dataset[Any]] = None, # pylint: disable=W0613
101
+ ignore_keys: Optional[List[str]] = None, # pylint: disable=W0613
102
+ metric_key_prefix: str = "eval", # pylint: disable=W0613
103
103
  ) -> Dict[str, float]:
104
104
  """
105
105
  Overwritten method from `Trainer`. Arguments will not be used.
@@ -193,9 +193,11 @@ def train_hf_detr(
193
193
  "remove_unused_columns": False,
194
194
  "per_device_train_batch_size": 2,
195
195
  "max_steps": number_samples,
196
- "evaluation_strategy": "steps"
197
- if (dataset_val is not None and metric is not None and pipeline_component_name is not None)
198
- else "no",
196
+ "evaluation_strategy": (
197
+ "steps"
198
+ if (dataset_val is not None and metric is not None and pipeline_component_name is not None)
199
+ else "no"
200
+ ),
199
201
  "eval_steps": 5000,
200
202
  }
201
203
 
@@ -63,6 +63,7 @@ from ..pipe.base import LanguageModelPipelineComponent
63
63
  from ..pipe.lm import get_tokenizer_from_architecture
64
64
  from ..pipe.registry import pipeline_component_registry
65
65
  from ..utils.env_info import get_device
66
+ from ..utils.error import DependencyError
66
67
  from ..utils.file_utils import wandb_available
67
68
  from ..utils.logger import LoggingRecord, logger
68
69
  from ..utils.settings import DatasetType, LayoutType, ObjectTypes, WordType
@@ -180,15 +181,17 @@ class LayoutLMTrainer(Trainer):
180
181
 
181
182
  def evaluate(
182
183
  self,
183
- eval_dataset: Optional[Dataset[Any]] = None,
184
- ignore_keys: Optional[List[str]] = None,
185
- metric_key_prefix: str = "eval",
184
+ eval_dataset: Optional[Dataset[Any]] = None, # pylint: disable=W0613
185
+ ignore_keys: Optional[List[str]] = None, # pylint: disable=W0613
186
+ metric_key_prefix: str = "eval", # pylint: disable=W0613
186
187
  ) -> Dict[str, float]:
187
188
  """
188
189
  Overwritten method from `Trainer`. Arguments will not be used.
189
190
  """
190
- assert self.evaluator is not None
191
- assert self.evaluator.pipe_component is not None
191
+ if self.evaluator is None:
192
+ raise ValueError("Evaluator not set up. Please use `setup_evaluator` before running evaluation")
193
+ if self.evaluator.pipe_component is None:
194
+ raise ValueError("Pipeline component not set up. Please use `setup_evaluator` before running evaluation")
192
195
 
193
196
  # memory metrics - must set up as early as possible
194
197
  self._memory_tracker.start()
@@ -222,7 +225,7 @@ def _get_model_class_and_tokenizer(
222
225
  raise KeyError("model_type and architectures not available in configs")
223
226
 
224
227
  if not model_cls:
225
- raise ValueError("model not eligible to run with this framework")
228
+ raise UserWarning("model not eligible to run with this framework")
226
229
 
227
230
  return config_cls, model_cls, model_wrapper_cls, tokenizer_fast
228
231
 
@@ -347,7 +350,7 @@ def train_hf_layoutlm(
347
350
  name_as_key=True,
348
351
  )[LayoutType.word][WordType.token_class]
349
352
  else:
350
- raise ValueError("Dataset type not supported for training")
353
+ raise UserWarning("Dataset type not supported for training")
351
354
 
352
355
  config_cls, model_cls, model_wrapper_cls, tokenizer_fast = _get_model_class_and_tokenizer(
353
356
  path_config_json, dataset_type, use_xlm_tokenizer
@@ -374,9 +377,11 @@ def train_hf_layoutlm(
374
377
  "remove_unused_columns": False,
375
378
  "per_device_train_batch_size": 8,
376
379
  "max_steps": number_samples,
377
- "evaluation_strategy": "steps"
378
- if (dataset_val is not None and metric is not None and pipeline_component_name is not None)
379
- else "no",
380
+ "evaluation_strategy": (
381
+ "steps"
382
+ if (dataset_val is not None and metric is not None and pipeline_component_name is not None)
383
+ else "no"
384
+ ),
380
385
  "eval_steps": 100,
381
386
  "use_wandb": False,
382
387
  "wandb_project": None,
@@ -416,7 +421,7 @@ def train_hf_layoutlm(
416
421
  run = None
417
422
  if use_wandb:
418
423
  if not wandb_available():
419
- raise ModuleNotFoundError("WandB must be installed separately")
424
+ raise DependencyError("WandB must be installed separately")
420
425
  run = wandb.init(project=wandb_project, config=conf_dict) # type: ignore
421
426
  run._label(repo=wandb_repo) # type: ignore # pylint: disable=W0212
422
427
  else:
@@ -6,7 +6,10 @@ Init file for utils package
6
6
  """
7
7
  from typing import Optional, Tuple, Union, no_type_check
8
8
 
9
+ from .concurrency import *
9
10
  from .context import *
11
+ from .env_info import *
12
+ from .error import *
10
13
  from .file_utils import *
11
14
  from .fs import *
12
15
  from .identifier import *
@@ -109,7 +109,7 @@ def enable_death_signal(_warn: bool = True) -> None:
109
109
  prctl, "set_pdeathsig"
110
110
  ), "prctl.set_pdeathsig does not exist! Note that you need to install 'python-prctl' instead of 'prctl'."
111
111
  # is SIGHUP a good choice?
112
- prctl.set_pdeathsig(signal.SIGHUP)
112
+ prctl.set_pdeathsig(signal.SIGHUP) # pylint: disable=E1101
113
113
 
114
114
 
115
115
  # taken from https://github.com/tensorpack/dataflow/blob/master/dataflow/utils/concurrency.py
@@ -61,7 +61,7 @@ def timeout_manager(proc, seconds: Optional[int] = None) -> Iterator[str]: # ty
61
61
  proc.terminate()
62
62
  proc.kill()
63
63
  proc.returncode = -1
64
- raise RuntimeError("Tesseract process timeout") # pylint: disable=W0707
64
+ raise RuntimeError(f"timeout for process id: {proc.pid}") # pylint: disable=W0707
65
65
  finally:
66
66
  if proc.stdin is not None:
67
67
  proc.stdin.close()
@@ -88,7 +88,7 @@ def save_tmp_file(image: Union[str, ImageType, bytes], prefix: str) -> Iterator[
88
88
  yield file.name, path.realpath(path.normpath(path.normcase(image)))
89
89
  return
90
90
  if isinstance(image, (np.ndarray, np.generic)):
91
- input_file_name = file.name + ".PNG"
91
+ input_file_name = file.name + "_input.PNG"
92
92
  viz_handler.write_image(input_file_name, image)
93
93
  yield file.name, input_file_name
94
94
  if isinstance(image, bytes):