deepdoctection 0.45.0__py3-none-any.whl → 0.46.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of deepdoctection was flagged as potentially problematic by the registry.
@@ -25,7 +25,7 @@ from .utils.logger import LoggingRecord, logger
 
 # pylint: enable=wrong-import-position
 
-__version__ = "0.45.0"
+__version__ = "0.46.1"
 
 _IMPORT_STRUCTURE = {
     "analyzer": ["config_sanity_checks", "get_dd_analyzer", "ServiceFactory", "update_cfg_from_defaults"],
@@ -271,6 +271,7 @@ _IMPORT_STRUCTURE = {
         "MultiThreadPipelineComponent",
         "DoctectionPipe",
         "LanguageDetectionService",
+        "skip_if_category_or_service_extracted",
         "ImageLayoutService",
         "LMTokenClassifierService",
         "LMSequenceClassifierService",
@@ -310,12 +311,14 @@ _IMPORT_STRUCTURE = {
         "get_tensorpack_requirement",
         "pytorch_available",
         "get_pytorch_requirement",
+        "pyzmq_available",
         "lxml_available",
         "get_lxml_requirement",
         "apted_available",
         "get_apted_requirement",
         "distance_available",
         "get_distance_requirement",
+        "networkx_available",
         "numpy_v1_available",
         "get_numpy_v1_requirement",
         "transformers_available",
@@ -526,6 +526,9 @@ cfg.USE_LM_SEQUENCE_CLASS = False
 # Enables a token classification pipeline component, e.g. a LayoutLM or Bert-like model
 cfg.USE_LM_TOKEN_CLASS = False
 
+# Specifies the selection of the rotation model. There are two models available: A rotation estimator
+# based on Tesseract ('tesseract'), and a rotation estimator based on DocTr ('doctr').
+cfg.ROTATOR.MODEL = "tesseract"
 
 # Relevant when LIB = TF. Specifies the layout detection model.
 # This model should detect multiple or single objects across an entire page.
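
The new `ROTATOR.MODEL` option is an ordinary analyzer config value, so it can be set via `config_overwrite` when building the analyzer. A minimal sketch, assuming the usual "KEY=VALUE" overwrite strings accepted by `get_dd_analyzer`; the rotator component is only added when `USE_ROTATOR` is enabled:

    import deepdoctection as dd

    # Enable the rotation component and switch the estimator to DocTr.
    analyzer = dd.get_dd_analyzer(
        config_overwrite=["USE_ROTATOR=True", "ROTATOR.MODEL=doctr"]
    )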
@@ -22,13 +22,13 @@
 from __future__ import annotations
 
 from os import environ
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Literal, Union
 
 from lazy_imports import try_import
 
 from ..extern.base import ImageTransformer, ObjectDetector, PdfMiner
 from ..extern.d2detect import D2FrcnnDetector, D2FrcnnTracingDetector
-from ..extern.doctrocr import DoctrTextlineDetector, DoctrTextRecognizer
+from ..extern.doctrocr import DocTrRotationTransformer, DoctrTextlineDetector, DoctrTextRecognizer
 from ..extern.hfdetr import HFDetrDerivedDetector
 from ..extern.hflayoutlm import (
     HFLayoutLmSequenceClassifier,
@@ -78,6 +78,7 @@ if TYPE_CHECKING:
     from ..extern.hflayoutlm import LayoutSequenceModels, LayoutTokenModels
     from ..extern.hflm import LmSequenceModels, LmTokenModels
 
+    RotationTransformer = Union[TesseractRotationTransformer, DocTrRotationTransformer]
 
 __all__ = [
     "ServiceFactory",
@@ -190,24 +191,32 @@ class ServiceFactory:
         return ServiceFactory._build_layout_detector(config, mode)
 
     @staticmethod
-    def _build_rotation_detector() -> TesseractRotationTransformer:
+    def _build_rotation_detector(rotator_name: Literal["tesseract", "doctr"]) -> RotationTransformer:
         """
         Building a rotation detector.
 
         Returns:
             TesseractRotationTransformer: Rotation detector instance.
         """
-        return TesseractRotationTransformer()
+
+        if rotator_name == "tesseract":
+            return TesseractRotationTransformer()
+        if rotator_name == "doctr":
+            return DocTrRotationTransformer()
+        raise ValueError(
+            f"You have chosen rotator_name: {rotator_name} which is not allowed. Only tesseract or "
+            f"doctr are allowed."
+        )
 
     @staticmethod
-    def build_rotation_detector() -> TesseractRotationTransformer:
+    def build_rotation_detector(rotator_name: Literal["tesseract", "doctr"]) -> RotationTransformer:
         """
         Building a rotation detector.
 
         Returns:
             TesseractRotationTransformer: Rotation detector instance.
         """
-        return ServiceFactory._build_rotation_detector()
+        return ServiceFactory._build_rotation_detector(rotator_name)
 
     @staticmethod
     def _build_transform_service(transform_predictor: ImageTransformer) -> SimpleTransformService:
@@ -1123,7 +1132,7 @@ class ServiceFactory:
         pipe_component_list: list[PipelineComponent] = []
 
         if config.USE_ROTATOR:
-            rotation_detector = ServiceFactory.build_rotation_detector()
+            rotation_detector = ServiceFactory.build_rotation_detector(config.ROTATOR.MODEL)
             transform_service = ServiceFactory.build_transform_service(transform_predictor=rotation_detector)
             pipe_component_list.append(transform_service)
 
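With this change the factory dispatches on the configured model name. A short sketch of the new entry point, assuming `ServiceFactory` is imported from the analyzer package as listed in the exports above:

    from deepdoctection.analyzer import ServiceFactory

    rotator = ServiceFactory.build_rotation_detector("doctr")      # DocTrRotationTransformer
    rotator = ServiceFactory.build_rotation_detector("tesseract")  # TesseractRotationTransformer
    ServiceFactory.build_rotation_detector("foo")                  # raises ValueError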
@@ -24,15 +24,19 @@ from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Callable, Iterator, no_type_check
 
-import zmq
+from lazy_imports import try_import
 
 from ..utils.concurrency import StoppableThread, enable_death_signal, start_proc_mask_signal
 from ..utils.error import DataFlowTerminatedError
+from ..utils.file_utils import pyzmq_available
 from ..utils.logger import LoggingRecord, logger
 from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow
 from .common import RepeatedData
 from .serialize import PickleSerializer
 
+with try_import() as import_guard:
+    import zmq
+
 
 @no_type_check
 def del_weakref(x):
@@ -77,6 +81,8 @@ def _get_pipe_name(name):
 
 class _ParallelMapData(ProxyDataFlow, ABC):
     def __init__(self, df: DataFlow, buffer_size: int, strict: bool = False) -> None:
+        if not pyzmq_available():
+            raise ModuleNotFoundError("pyzmq is required for running parallel dataflows (multiprocess/multithread).")
         super().__init__(df)
         if buffer_size <= 0:
             raise ValueError(f"buffer_size must be a positive number, got {buffer_size}")
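The implementation of `pyzmq_available` is not part of this diff. A plausible minimal sketch in the style of the other availability helpers in `file_utils`, shown purely as a hypothetical illustration:

    from importlib.util import find_spec

    def pyzmq_available() -> bool:
        # The pyzmq distribution is importable as `zmq`.
        return find_spec("zmq") is not None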
@@ -284,7 +284,7 @@ class BoundingBox:
             raise BoundingBoxError(
                 f"bounding box must have height and width >0. Check coords "
                 f"ulx: {self.ulx}, uly: {self.uly}, lrx: {self.lrx}, "
-                f"lry: {self.lry}."
+                f"lry: {self.lry}, absolute_coords: {self.absolute_coords}"
             )
         if not self.absolute_coords and not (
             0 <= self.ulx <= 1 and 0 <= self.uly <= 1 and 0 <= self.lrx <= 1 and 0 <= self.lry <= 1
@@ -505,10 +505,10 @@ class BoundingBox:
         if self.absolute_coords:
             transformed_box = BoundingBox(
                 absolute_coords=not self.absolute_coords,
-                ulx=max(self.ulx / image_width, 0.0),
-                uly=max(self.uly / image_height, 0.0),
-                lrx=min(self.lrx / image_width, 1.0),
-                lry=min(self.lry / image_height, 1.0),
+                ulx=min(max(self.ulx / image_width, 0.0), 1.0),
+                uly=min(max(self.uly / image_height, 0.0), 1.0),
+                lrx=max(min(self.lrx / image_width, 1.0), 0.0),
+                lry=max(min(self.lry / image_height, 1.0), 0.0),
             )
         else:
             transformed_box = BoundingBox(
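The old one-sided clamping could still yield relative coordinates outside [0, 1] when an absolute box lay beyond the image border. A worked example of the new two-sided clamping, plain arithmetic mirroring the lines above:

    # Box starting right of a 500 x 100 image: (ulx, uly, lrx, lry) = (550, 20, 600, 40)
    ulx = min(max(550 / 500, 0.0), 1.0)  # old: max(1.1, 0.0) = 1.1 (invalid); new: 1.0
    uly = min(max(20 / 100, 0.0), 1.0)   # 0.2
    lrx = max(min(600 / 500, 1.0), 0.0)  # 1.0
    lry = max(min(40 / 100, 1.0), 0.0)   # 0.4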
@@ -36,7 +36,7 @@ from ..utils.logger import LoggingRecord, logger
 from ..utils.settings import ObjectTypes, SummaryType, get_type
 from ..utils.types import ImageDict, PathLikeOrStr, PixelValues
 from .annotation import Annotation, AnnotationMap, BoundingBox, CategoryAnnotation, ImageAnnotation
-from .box import crop_box_from_image, global_to_local_coords, intersection_box
+from .box import BoxCoordinate, crop_box_from_image, global_to_local_coords, intersection_box
 from .convert import as_dict, convert_b64_to_np_array, convert_np_array_to_b64, convert_pdf_bytes_to_np_array_v2
 
 
@@ -318,7 +318,7 @@ class Image:
         return _Img(self.image)
 
     @property
-    def width(self) -> float:
+    def width(self) -> BoxCoordinate:
         """
         `width`
         """
@@ -327,7 +327,7 @@ class Image:
         return self._bbox.width
 
     @property
-    def height(self) -> float:
+    def height(self) -> BoxCoordinate:
         """
         `height`
         """
@@ -335,7 +335,7 @@ class Image:
             raise ImageError("Height not available. Call set_width_height first")
         return self._bbox.height
 
-    def set_width_height(self, width: float, height: float) -> None:
+    def set_width_height(self, width: BoxCoordinate, height: BoxCoordinate) -> None:
         """
         Defines bounding box of the image if not already set. Use this, if you do not want to keep the image separated
         for memory reasons.
@@ -345,7 +345,7 @@ class Image:
             height: height of image
         """
         if self._bbox is None:
-            self._bbox = BoundingBox(ulx=0.0, uly=0.0, height=height, width=width, absolute_coords=True)
+            self._bbox = BoundingBox(ulx=0, uly=0, height=height, width=width, absolute_coords=True)
         self._self_embedding()
 
     def set_embedding(self, image_id: str, bounding_box: BoundingBox) -> None:
@@ -428,6 +428,8 @@ class List(Layout):
         A list of words order by reading order. Words with no `reading_order` will not be returned"""
         try:
             list_items = self.list_items
+            if not list_items:
+                return super().get_ordered_words()
             all_words = []
             list_items.sort(key=lambda x: x.bbox[1])
             for list_item in list_items:
@@ -755,6 +757,8 @@ class Table(Layout):
         """
         try:
             cells = self.cells
+            if not cells:
+                return super().get_ordered_words()
             all_words = []
             cells.sort(key=lambda x: (x.ROW_NUMBER, x.COLUMN_NUMBER))
             for cell in cells:
@@ -1054,6 +1058,8 @@ class Page(Image):
         Returns:
             A `Page` instance with all annotations as `ImageAnnotationBaseView` subclasses.
         """
+        if isinstance(image_orig, Page):
+            raise ImageError("Page.from_image() cannot be called on a Page instance.")
 
         if text_container is None:
             text_container = IMAGE_DEFAULTS.TEXT_CONTAINER
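The guard makes accidental double conversion fail fast. A sketch, assuming the top-level `Page` export:

    import deepdoctection as dd

    page = dd.Page.from_image(image)  # `image` is a dd.Image: fine
    dd.Page.from_image(page)          # passing a Page back in now raises ImageError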
@@ -1310,7 +1316,7 @@ class Page(Image):
         If `interactive=False` will return a `np.array`.
         """
 
-        category_names_list: list[Union[str, None]] = []
+        category_names_list: list[Tuple[Union[str, None], Union[str, None]]] = []
         box_stack = []
         cells_found = False
 
@@ -1323,22 +1329,23 @@ class Page(Image):
             anns = self.get_annotation(category_names=list(debug_kwargs.keys()))
             for ann in anns:
                 box_stack.append(self._ann_viz_bbox(ann))
-                category_names_list.append(str(getattr(ann, debug_kwargs[ann.category_name])))
+                val = str(getattr(ann, debug_kwargs[ann.category_name]))
+                category_names_list.append((val, val))
 
         if show_layouts and not debug_kwargs:
             for item in self.layouts:
                 box_stack.append(self._ann_viz_bbox(item))
-                category_names_list.append(item.category_name.value)
+                category_names_list.append((item.category_name.value, item.category_name.value))
 
         if show_figures and not debug_kwargs:
             for item in self.figures:
                 box_stack.append(self._ann_viz_bbox(item))
-                category_names_list.append(item.category_name.value)
+                category_names_list.append((item.category_name.value, item.category_name.value))
 
         if show_tables and not debug_kwargs:
             for table in self.tables:
                 box_stack.append(self._ann_viz_bbox(table))
-                category_names_list.append(LayoutType.TABLE.value)
+                category_names_list.append((LayoutType.TABLE.value, LayoutType.TABLE.value))
                 if show_cells:
                     for cell in table.cells:
                         if cell.category_name in {
@@ -1347,21 +1354,21 @@ class Page(Image):
                         }:
                             cells_found = True
                             box_stack.append(self._ann_viz_bbox(cell))
-                            category_names_list.append(None)
+                            category_names_list.append((None, cell.category_name.value))
                 if show_table_structure:
                     rows = table.rows
                     cols = table.columns
                     for row in rows:
                         box_stack.append(self._ann_viz_bbox(row))
-                        category_names_list.append(None)
+                        category_names_list.append((None, row.category_name.value))
                     for col in cols:
                         box_stack.append(self._ann_viz_bbox(col))
-                        category_names_list.append(None)
+                        category_names_list.append((None, col.category_name.value))
 
         if show_cells and not cells_found and not debug_kwargs:
             for ann in self.get_annotation(category_names=[LayoutType.CELL, CellType.SPANNING]):
                 box_stack.append(self._ann_viz_bbox(ann))
-                category_names_list.append(None)
+                category_names_list.append((None, ann.category_name.value))
 
         if show_words and not debug_kwargs:
             all_words = []
@@ -1379,22 +1386,36 @@ class Page(Image):
             for word in all_words:
                 box_stack.append(self._ann_viz_bbox(word))
                 if show_token_class:
-                    category_names_list.append(word.token_class.value if word.token_class is not None else None)
+                    category_names_list.append(
+                        (word.token_class.value, word.token_class.value)
+                        if word.token_class is not None
+                        else (None, None)
+                    )
                 else:
-                    category_names_list.append(word.token_tag.value if word.token_tag is not None else None)
+                    category_names_list.append(
+                        (word.token_tag.value, word.token_tag.value) if word.token_tag is not None else (None, None)
+                    )
         else:
             for word in all_words:
                 if word.token_class is not None and word.token_class != TokenClasses.OTHER:
                     box_stack.append(self._ann_viz_bbox(word))
                     if show_token_class:
-                        category_names_list.append(word.token_class.value if word.token_class is not None else None)
+                        category_names_list.append(
+                            (word.token_class.value, word.token_class.value)
+                            if word.token_class is not None
+                            else (None, None)
+                        )
                     else:
-                        category_names_list.append(word.token_tag.value if word.token_tag is not None else None)
+                        category_names_list.append(
+                            (word.token_tag.value, word.token_tag.value)
+                            if word.token_tag is not None
+                            else (None, None)
+                        )
 
         if show_residual_layouts and not debug_kwargs:
             for item in self.residual_layouts:
                 box_stack.append(item.bbox)
-                category_names_list.append(item.category_name.value)
+                category_names_list.append((item.category_name.value, item.category_name.value))
 
         if self.image is not None:
             scale_fx = scaled_width / self.width
@@ -275,6 +275,7 @@ class CocoMetric(MetricBase):
                 get the ultimate F1-score.
             f1_iou: Use with `f1_score=True` and reset the f1 iou threshold
             per_category: Whether to calculate metrics per category
+            per_category: If set to True, f1 score will be returned by each category.
         """
         if max_detections is not None:
             assert len(max_detections) == 3, max_detections
@@ -263,7 +263,7 @@ class PredictorBase(ABC):
         requirements = cls.get_requirements()
         name = cls.__name__ if hasattr(cls, "__name__") else cls.__class__.__name__
         if not all(requirement[1] for requirement in requirements):
-            raise ImportError(
+            raise ModuleNotFoundError(
                 "\n".join(
                     [f"{name} has the following dependencies:"]
                     + [requirement[2] for requirement in requirements if not requirement[1]]
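Since `ModuleNotFoundError` is a subclass of `ImportError`, existing handlers keep working. A sketch with a hypothetical predictor whose dependencies are not installed:

    try:
        predictor = SomePredictor()  # hypothetical predictor with an unmet dependency
    except ImportError as err:       # also catches the new ModuleNotFoundError
        print(f"Missing dependency:\n{err}")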
@@ -334,6 +334,11 @@ class DetectionResult:
         block: block number. For reading order from some ocr predictors
         line: line number. For reading order from some ocr predictors
         uuid: uuid. For assigning detection result (e.g. text to image annotations)
+        relationships: A dictionary of relationships. Each key is a relationship type and each value is a list of
+            uuids of the related annotations.
+        angle: angle of rotation in degrees. Only used for text detection.
+        image_width: image width
+        image_height: image height
     """
 
     box: Optional[list[float]] = None
@@ -348,6 +353,8 @@ class DetectionResult:
     uuid: Optional[str] = None
     relationships: Optional[dict[str, Any]] = None
    angle: Optional[float] = None
+    image_width: Optional[Union[int, float]] = None
+    image_height: Optional[Union[int, float]] = None
 
 
 class ObjectDetector(PredictorBase, ABC):
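The two new fields carry the rotation context that the transformers' `transform_coords` methods read back. A sketch of a `DetectionResult` populated for coordinate transformation; the concrete values are illustrative:

    from deepdoctection.extern.base import DetectionResult

    dr = DetectionResult(
        box=[100.0, 50.0, 300.0, 90.0],
        angle=90.0,        # rotation applied to the page
        image_width=1654,  # size of the original, untransformed image
        image_height=2339,
    )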
@@ -24,9 +24,10 @@ from __future__ import annotations
 import os
 from abc import ABC
 from pathlib import Path
-from typing import Any, Literal, Mapping, Optional, Union
+from typing import Any, Literal, Mapping, Optional, Sequence, Union
 from zipfile import ZipFile
 
+import numpy as np
 from lazy_imports import try_import
 
 from ..utils.env_info import ENV_VARS_TRUE
@@ -39,6 +40,7 @@ from ..utils.file_utils import (
 )
 from ..utils.fs import load_json
 from ..utils.settings import LayoutType, ObjectTypes, PageType, TypeOrStr
+from ..utils.transform import RotationTransform
 from ..utils.types import PathLikeOrStr, PixelValues, Requirement
 from ..utils.viz import viz_handler
 from .base import DetectionResult, ImageTransformer, ModelCategories, ObjectDetector, TextRecognizer
@@ -558,12 +560,13 @@ class DocTrRotationTransformer(ImageTransformer):
         """
         Args:
             number_contours: the number of contours used for the orientation estimation
-            ratio_threshold_for_lines: this is the ratio w/h used to discriminates lines
+            ratio_threshold_for_lines: this is the ratio w/h used to discriminate lines
         """
         self.number_contours = number_contours
         self.ratio_threshold_for_lines = ratio_threshold_for_lines
         self.name = "doctr_rotation_transformer"
         self.model_id = self.get_model_id()
+        self.rotator = RotationTransform(360)
 
     def transform_image(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
         """
@@ -579,6 +582,19 @@ class DocTrRotationTransformer(ImageTransformer):
         """
         return viz_handler.rotate_image(np_img, specification.angle)  # type: ignore
 
+    def transform_coords(self, detect_results: Sequence[DetectionResult]) -> Sequence[DetectionResult]:
+        if detect_results:
+            if detect_results[0].angle:
+                self.rotator.set_angle(detect_results[0].angle)  # type: ignore
+                self.rotator.set_image_width(detect_results[0].image_width)  # type: ignore
+                self.rotator.set_image_height(detect_results[0].image_height)  # type: ignore
+            transformed_coords = self.rotator.apply_coords(
+                np.asarray([detect_result.box for detect_result in detect_results], dtype=float)
+            )
+            for idx, detect_result in enumerate(detect_results):
+                detect_result.box = transformed_coords[idx, :].tolist()
+        return detect_results
+
     def predict(self, np_img: PixelValues) -> DetectionResult:
         angle = estimate_orientation(
             np_img, n_ct=self.number_contours, ratio_threshold_for_lines=self.ratio_threshold_for_lines
@@ -1024,12 +1024,9 @@ class HFLayoutLmv2SequenceClassifier(HFLayoutLmSequenceClassifierBase):
         else:
             raise ValueError(f"images must be list but is {type(images)}")
 
-        result = predict_sequence_classes_from_layoutlm(input_ids,
-                                                        attention_mask,
-                                                        token_type_ids,
-                                                        boxes,
-                                                        self.model,
-                                                        images)
+        result = predict_sequence_classes_from_layoutlm(
+            input_ids, attention_mask, token_type_ids, boxes, self.model, images
+        )
 
         result.class_id += 1
         result.class_name = self.categories.categories[result.class_id]
@@ -1123,12 +1120,9 @@ class HFLayoutLmv3SequenceClassifier(HFLayoutLmSequenceClassifierBase):
         else:
             raise ValueError(f"images must be list but is {type(images)}")
 
-        result = predict_sequence_classes_from_layoutlm(input_ids,
-                                                        attention_mask,
-                                                        token_type_ids,
-                                                        boxes,
-                                                        self.model,
-                                                        images)
+        result = predict_sequence_classes_from_layoutlm(
+            input_ids, attention_mask, token_type_ids, boxes, self.model, images
+        )
 
         result.class_id += 1
         result.class_name = self.categories.categories[result.class_id]
@@ -28,8 +28,9 @@ from errno import ENOENT
 from itertools import groupby
 from os import environ, fspath
 from pathlib import Path
-from typing import Any, Mapping, Optional, Union
+from typing import Any, Mapping, Optional, Sequence, Union
 
+import numpy as np
 from packaging.version import InvalidVersion, Version, parse
 
 from ..utils.context import save_tmp_file, timeout_manager
@@ -37,6 +38,7 @@ from ..utils.error import DependencyError, TesseractError
 from ..utils.file_utils import _TESS_PATH, get_tesseract_requirement
 from ..utils.metacfg import config_to_cli_str, set_config_by_yaml
 from ..utils.settings import LayoutType, ObjectTypes, PageType
+from ..utils.transform import RotationTransform
 from ..utils.types import PathLikeOrStr, PixelValues, Requirement
 from ..utils.viz import viz_handler
 from .base import DetectionResult, ImageTransformer, ModelCategories, ObjectDetector
@@ -450,6 +452,7 @@ class TesseractRotationTransformer(ImageTransformer):
         self.name = fspath(_TESS_PATH) + "-rotation"
         self.categories = ModelCategories(init_categories={1: PageType.ANGLE})
         self.model_id = self.get_model_id()
+        self.rotator = RotationTransform(360)
 
     def transform_image(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
         """
@@ -465,6 +468,19 @@ class TesseractRotationTransformer(ImageTransformer):
         """
         return viz_handler.rotate_image(np_img, specification.angle)  # type: ignore
 
+    def transform_coords(self, detect_results: Sequence[DetectionResult]) -> Sequence[DetectionResult]:
+        if detect_results:
+            if detect_results[0].angle:
+                self.rotator.set_angle(detect_results[0].angle)  # type: ignore
+                self.rotator.set_image_width(detect_results[0].image_width)  # type: ignore
+                self.rotator.set_image_height(detect_results[0].image_height)  # type: ignore
+            transformed_coords = self.rotator.apply_coords(
+                np.asarray([detect_result.box for detect_result in detect_results], dtype=float)
+            )
+            for idx, detect_result in enumerate(detect_results):
+                detect_result.box = transformed_coords[idx, :].tolist()
+        return detect_results
+
     def predict(self, np_img: PixelValues) -> DetectionResult:
         """
         Determines the angle of the rotated image. It can only handle angles that are multiples of 90 degrees.
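Both rotation transformers now expose `transform_coords`, which uses the shared `RotationTransform` helper to map box coordinates into the frame of the rotated image. A rough sketch of the flow with illustrative values; note that `image_width`/`image_height` must be those of the original image, as the `SimpleTransformService` hunk at the end of this diff points out:

    rotator = TesseractRotationTransformer()

    spec = rotator.predict(np_img)                   # DetectionResult carrying .angle
    rotated = rotator.transform_image(np_img, spec)  # de-rotated page

    h, w = np_img.shape[:2]
    results = [DetectionResult(box=[10.0, 20.0, 110.0, 60.0],
                               angle=spec.angle, image_width=w, image_height=h)]
    results = rotator.transform_coords(results)      # boxes mapped into the rotated frame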
@@ -228,8 +228,8 @@ class OrderGenerator:
         columns: list[BoundingBox] = []
         anns.sort(
             key=lambda x: (
-                x.bounding_box.transform(image_width, image_height).cy,  # type: ignore
-                x.bounding_box.transform(image_width, image_height).cx,  # type: ignore
+                x.get_bounding_box(image_id).transform(image_width, image_height).cy,
+                x.get_bounding_box(image_id).transform(image_width, image_height).cx,
             )
         )
         for ann in anns:
@@ -309,7 +309,9 @@ class OrderGenerator:
         filtered_blocks: Sequence[tuple[int, str]]
         for idx in range(max_block_number + 1):
             filtered_blocks = list(filter(lambda x: x[0] == idx, blocks))  # type: ignore # pylint: disable=W0640
-            sorted_blocks.extend(self._sort_anns_grouped_by_blocks(filtered_blocks, anns, image_width, image_height))
+            sorted_blocks.extend(
+                self._sort_anns_grouped_by_blocks(filtered_blocks, anns, image_width, image_height, image_id)
+            )
         reading_blocks = [(idx + 1, block[1]) for idx, block in enumerate(sorted_blocks)]
 
         if logger.isEnabledFor(DEBUG):
@@ -346,7 +348,11 @@ class OrderGenerator:
 
     @staticmethod
     def _sort_anns_grouped_by_blocks(
-        block: Sequence[tuple[int, str]], anns: Sequence[ImageAnnotation], image_width: float, image_height: float
+        block: Sequence[tuple[int, str]],
+        anns: Sequence[ImageAnnotation],
+        image_width: float,
+        image_height: float,
+        image_id: Optional[str] = None,
     ) -> list[tuple[int, str]]:
         if not block:
             return []
@@ -356,8 +362,8 @@ class OrderGenerator:
         block_anns = [ann for ann in anns if ann.annotation_id in ann_ids]
         block_anns.sort(
             key=lambda x: (
-                round(x.bounding_box.transform(image_width, image_height).uly, 2),  # type: ignore
-                round(x.bounding_box.transform(image_width, image_height).ulx, 2),  # type: ignore
+                round(x.get_bounding_box(image_id).transform(image_width, image_height).uly, 2),
+                round(x.get_bounding_box(image_id).transform(image_width, image_height).ulx, 2),
             )
         )
         return [(block_number, ann.annotation_id) for ann in block_anns]
@@ -27,7 +27,7 @@ from dataclasses import asdict
 from itertools import chain, product
 from typing import DefaultDict, Optional, Sequence, Union
 
-import networkx as nx  # type: ignore
+from lazy_imports import try_import
 
 from ..datapoint.annotation import ImageAnnotation
 from ..datapoint.box import merge_boxes
@@ -35,10 +35,15 @@ from ..datapoint.image import Image, MetaAnnotation
 from ..extern.base import DetectionResult
 from ..mapper.maputils import MappingContextManager
 from ..utils.error import ImageError
+from ..utils.file_utils import networkx_available
 from ..utils.settings import CellType, LayoutType, ObjectTypes, Relationships, TableType, get_type
 from .base import PipelineComponent
 from .registry import pipeline_component_registry
 
+with try_import() as import_guard:
+    import networkx as nx  # type: ignore
+
+
 __all__ = ["TableSegmentationRefinementService", "generate_html_string"]
 
 
@@ -441,6 +446,10 @@ class TableSegmentationRefinementService(PipelineComponent):
             table_names: Sequence of table object types.
             cell_names: Sequence of cell object types.
         """
+        if not networkx_available():
+            raise ModuleNotFoundError(
+                "TableSegmentationRefinementService requires networkx. Please install separately."
+            )
         self.table_name = table_names
         self.cell_names = cell_names
         super().__init__("table_segment_refine")
@@ -129,6 +129,12 @@ class TextExtractionService(PipelineComponent):
             width, height = self.predictor.get_width_height(predictor_input)  # type: ignore
 
         for detect_result in detect_result_list:
+            if width is not None and height is not None:
+                box = detect_result.box
+                if box:
+                    if box[0] >= width or box[1] >= height or box[2] >= width or box[3] >= height:
+                        continue
+
             if isinstance(self.predictor, TextRecognizer):
                 detect_ann_id = detect_result.uuid
             else:
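The new guard skips detections whose boxes reach the image border or beyond, before they are turned into annotations. The condition in plain form, with illustrative numbers:

    width, height = 1000, 800
    boxes = [
        [10, 10, 100, 40],     # kept
        [1005, 10, 1100, 40],  # skipped: box[0] >= width
        [10, 10, 100, 820],    # skipped: box[3] >= height
    ]
    kept = [b for b in boxes
            if not (b[0] >= width or b[1] >= height or b[2] >= width or b[3] >= height)]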
@@ -77,6 +77,9 @@ class SimpleTransformService(PipelineComponent):
                         score=ann.score,
                         class_id=ann.category_id,
                         uuid=ann.annotation_id,
+                        angle=detection_result.angle,
+                        image_width=dp.width,  # we need the original width, not the transformed width
+                        image_height=dp.height,  # same with height
                     )
                 )
         output_detect_results = self.transform_predictor.transform_coords(detect_results)