deepdoctection-0.39.7-py3-none-any.whl → deepdoctection-0.41.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. More details are available on the package's page in the public registry.

@@ -22,6 +22,7 @@ from __future__ import annotations
22
22
 
23
23
  import os
24
24
  from copy import deepcopy
25
+ from dataclasses import dataclass, field
25
26
  from typing import Literal, Mapping, Optional, Sequence, Union
26
27
 
27
28
  import numpy as np
@@ -49,24 +50,31 @@ class ImageCroppingService(PipelineComponent):
49
50
  generally not stored.
50
51
  """
51
52
 
52
- def __init__(self, category_names: Union[TypeOrStr, Sequence[TypeOrStr]]):
53
+ def __init__(
54
+ self,
55
+ category_names: Optional[Union[TypeOrStr, Sequence[TypeOrStr]]] = None,
56
+ service_ids: Optional[Sequence[str]] = None,
57
+ ) -> None:
53
58
  """
54
59
  :param category_names: A single name or a list of category names to crop
55
60
  """
56
-
57
- self.category_names = (
58
- (category_names,)
59
- if isinstance(category_names, str)
60
- else tuple(get_type(category_name) for category_name in category_names)
61
- )
61
+ if category_names is None:
62
+ self.category_names = None
63
+ else:
64
+ self.category_names = (
65
+ (category_names,)
66
+ if isinstance(category_names, str)
67
+ else tuple(get_type(category_name) for category_name in category_names)
68
+ )
69
+ self.service_ids = service_ids
62
70
  super().__init__("image_crop")
63
71
 
64
72
  def serve(self, dp: Image) -> None:
65
- for ann in dp.get_annotation(category_names=self.category_names):
73
+ for ann in dp.get_annotation(category_names=self.category_names, service_ids=self.service_ids):
66
74
  dp.image_ann_to_image(ann.annotation_id, crop_image=True)
67
75
 
68
76
  def clone(self) -> ImageCroppingService:
69
- return self.__class__(self.category_names)
77
+ return self.__class__(self.category_names, self.service_ids)
70
78
 
71
79
  def get_meta_annotation(self) -> MetaAnnotation:
72
80
  return MetaAnnotation(image_annotations=(), sub_categories={}, relationships={}, summaries=())
@@ -124,8 +132,10 @@ class IntersectionMatcher:
124
132
  def match(
125
133
  self,
126
134
  dp: Image,
127
- parent_categories: Union[TypeOrStr, Sequence[TypeOrStr]],
128
- child_categories: Union[TypeOrStr, Sequence[TypeOrStr]],
135
+ parent_categories: Optional[Union[TypeOrStr, Sequence[TypeOrStr]]] = None,
136
+ child_categories: Optional[Union[TypeOrStr, Sequence[TypeOrStr]]] = None,
137
+ parent_ann_service_ids: Optional[Union[str, Sequence[str]]] = None,
138
+ child_ann_service_ids: Optional[Union[str, Sequence[str]]] = None,
129
139
  ) -> list[tuple[str, str]]:
130
140
  """
131
141
  The matching algorithm
@@ -133,6 +143,10 @@ class IntersectionMatcher:
133
143
  :param dp: datapoint image
134
144
  :param parent_categories: list of categories to be used a for parent class. Will generate a child-relationship
135
145
  :param child_categories: list of categories to be used for a child class.
146
+ :param parent_ann_service_ids: Additional filter condition. If some ids are selected, it will ignore all other
147
+ parent candidates which are not in the list.
148
+ :param child_ann_service_ids: Additional filter condition. If some ids are selected, it will ignore all other
149
+ children candidates which are not in the list.
136
150
 
137
151
  :return: A list of tuples with parent and child annotation ids
138
152
  """
@@ -144,6 +158,8 @@ class IntersectionMatcher:
144
158
  threshold=self.threshold,
145
159
  use_weighted_intersections=self.use_weighted_intersections,
146
160
  max_parent_only=self.max_parent_only,
161
+ parent_ann_service_ids=parent_ann_service_ids,
162
+ child_ann_service_ids=child_ann_service_ids,
147
163
  )
148
164
 
149
165
  matched_child_anns = np.take(child_anns, child_index) # type: ignore
@@ -174,8 +190,10 @@ class NeighbourMatcher:
174
190
  def match(
175
191
  self,
176
192
  dp: Image,
177
- parent_categories: Union[TypeOrStr, Sequence[TypeOrStr]],
178
- child_categories: Union[TypeOrStr, Sequence[TypeOrStr]],
193
+ parent_categories: Optional[Union[TypeOrStr, Sequence[TypeOrStr]]] = None,
194
+ child_categories: Optional[Union[TypeOrStr, Sequence[TypeOrStr]]] = None,
195
+ parent_ann_service_ids: Optional[Union[str, Sequence[str]]] = None,
196
+ child_ann_service_ids: Optional[Union[str, Sequence[str]]] = None,
179
197
  ) -> list[tuple[str, str]]:
180
198
  """
181
199
  The matching algorithm
@@ -183,16 +201,54 @@ class NeighbourMatcher:
183
201
  :param dp: datapoint image
184
202
  :param parent_categories: list of categories to be used a for parent class. Will generate a child-relationship
185
203
  :param child_categories: list of categories to be used for a child class.
204
+ :param parent_ann_service_ids: Additional filter condition. If some ids are selected, it will ignore all other
205
+ parent candidates which are not in the list.
206
+ :param child_ann_service_ids: Additional filter condition. If some ids are selected, it will ignore all other
207
+ children candidates which are not in the list.
186
208
 
187
209
  :return: A list of tuples with parent and child annotation ids
188
210
  """
189
211
 
190
212
  return [
191
213
  (pair[0].annotation_id, pair[1].annotation_id)
192
- for pair in match_anns_by_distance(dp, parent_categories, child_categories)
214
+ for pair in match_anns_by_distance(
215
+ dp,
216
+ parent_ann_category_names=parent_categories,
217
+ child_ann_category_names=child_categories,
218
+ parent_ann_service_ids=parent_ann_service_ids,
219
+ child_ann_service_ids=child_ann_service_ids,
220
+ )
193
221
  ]
194
222
 
195
223
 
224
+ @dataclass
225
+ class FamilyCompound:
226
+ """
227
+ A family compound is a set of parent and child categories that are related by a relationship key. The parent
228
+ categories will receive a relationship to the child categories.
229
+ """
230
+
231
+ relationship_key: Relationships
232
+ parent_categories: Optional[Union[ObjectTypes, Sequence[ObjectTypes]]] = field(default=None)
233
+ child_categories: Optional[Union[ObjectTypes, Sequence[ObjectTypes]]] = field(default=None)
234
+ parent_ann_service_ids: Optional[Union[str, Sequence[str]]] = field(default=None)
235
+ child_ann_service_ids: Optional[Union[str, Sequence[str]]] = field(default=None)
236
+
237
+ def __post_init__(self) -> None:
238
+ if isinstance(self.parent_categories, str):
239
+ self.parent_categories = (get_type(self.parent_categories),)
240
+ elif self.parent_categories is not None:
241
+ self.parent_categories = tuple(get_type(parent) for parent in self.parent_categories)
242
+ if isinstance(self.child_categories, str):
243
+ self.child_categories = (get_type(self.child_categories),)
244
+ elif self.child_categories is not None:
245
+ self.child_categories = tuple(get_type(child) for child in self.child_categories)
246
+ if isinstance(self.parent_ann_service_ids, str):
247
+ self.parent_ann_service_ids = (self.parent_ann_service_ids,)
248
+ if isinstance(self.child_ann_service_ids, str):
249
+ self.child_ann_service_ids = (self.child_ann_service_ids,)
250
+
251
+
196
252
  @pipeline_component_registry.register("MatchingService")
197
253
  class MatchingService(PipelineComponent):
198
254
  """
@@ -202,28 +258,15 @@ class MatchingService(PipelineComponent):
202
258
 
203
259
  def __init__(
204
260
  self,
205
- parent_categories: Union[TypeOrStr, Sequence[TypeOrStr]],
206
- child_categories: Union[TypeOrStr, Sequence[TypeOrStr]],
261
+ family_compounds: Sequence[FamilyCompound],
207
262
  matcher: Union[IntersectionMatcher, NeighbourMatcher],
208
- relationship_key: Relationships,
209
263
  ) -> None:
210
264
  """
211
- :param parent_categories: list of categories to be used a for parent class. Will generate a child-relationship
212
- :param child_categories: list of categories to be used for a child class.
213
-
265
+ :param family_compounds: A list of FamilyCompounds
266
+ :param matcher: A matcher object
214
267
  """
215
- self.parent_categories = (
216
- (get_type(parent_categories),)
217
- if isinstance(parent_categories, str)
218
- else tuple(get_type(category_name) for category_name in parent_categories)
219
- )
220
- self.child_categories = (
221
- (get_type(child_categories),)
222
- if isinstance(child_categories, str)
223
- else (tuple(get_type(category_name) for category_name in child_categories))
224
- )
268
+ self.family_compounds = family_compounds
225
269
  self.matcher = matcher
226
- self.relationship_key = relationship_key
227
270
  super().__init__("matching")
228
271
 
229
272
  def serve(self, dp: Image) -> None:
@@ -233,20 +276,31 @@ class MatchingService(PipelineComponent):
233
276
 
234
277
  :param dp: datapoint image
235
278
  """
236
-
237
- matched_pairs = self.matcher.match(dp, self.parent_categories, self.child_categories)
238
-
239
- for pair in matched_pairs:
240
- self.dp_manager.set_relationship_annotation(self.relationship_key, pair[0], pair[1])
279
+ for family_compound in self.family_compounds:
280
+ matched_pairs = self.matcher.match(
281
+ dp,
282
+ parent_categories=family_compound.parent_categories,
283
+ child_categories=family_compound.child_categories,
284
+ parent_ann_service_ids=family_compound.parent_ann_service_ids,
285
+ child_ann_service_ids=family_compound.child_ann_service_ids,
286
+ )
287
+
288
+ for pair in matched_pairs:
289
+ self.dp_manager.set_relationship_annotation(family_compound.relationship_key, pair[0], pair[1])
241
290
 
242
291
  def clone(self) -> PipelineComponent:
243
- return self.__class__(self.parent_categories, self.child_categories, self.matcher, self.relationship_key)
292
+ return self.__class__(self.family_compounds, self.matcher)
244
293
 
245
294
  def get_meta_annotation(self) -> MetaAnnotation:
295
+ relationships: dict[ObjectTypes, set[ObjectTypes]] = {}
296
+ for family_compound in self.family_compounds:
297
+ if family_compound.parent_categories is not None:
298
+ for parent_category in family_compound.parent_categories:
299
+ relationships[parent_category] = {family_compound.relationship_key} # type: ignore
246
300
  return MetaAnnotation(
247
301
  image_annotations=(),
248
302
  sub_categories={},
249
- relationships={parent: {Relationships.CHILD} for parent in self.parent_categories},
303
+ relationships=relationships,
250
304
  summaries=(),
251
305
  )
252
306
 
@@ -20,18 +20,41 @@ Module for layout pipeline component
20
20
  """
21
21
  from __future__ import annotations
22
22
 
23
- from typing import Optional
23
+ from typing import Optional, Sequence, Union
24
24
 
25
25
  import numpy as np
26
26
 
27
27
  from ..datapoint.image import Image
28
28
  from ..extern.base import ObjectDetector, PdfMiner
29
+ from ..mapper.misc import curry
29
30
  from ..utils.error import ImageError
31
+ from ..utils.settings import ObjectTypes
30
32
  from ..utils.transform import PadTransform
31
33
  from .base import MetaAnnotation, PipelineComponent
32
34
  from .registry import pipeline_component_registry
33
35
 
34
36
 
37
+ @curry
38
+ def skip_if_category_or_service_extracted(
39
+ dp: Image,
40
+ category_names: Optional[Union[str, Sequence[ObjectTypes]]] = None,
41
+ service_ids: Optional[Union[str, Sequence[str]]] = None,
42
+ ) -> bool:
43
+ """
44
+ Skip the processing of the pipeline component if the category or service is already extracted.
45
+
46
+ **Example**
47
+
48
+ detector = # some detector
49
+ item_component = ImageLayoutService(detector)
50
+ item_component.set_inbound_filter(skip_if_category_or_service_extracted(detector.get_categories(as_dict=False)))
51
+ """
52
+
53
+ if dp.get_annotation(category_names=category_names, service_ids=service_ids):
54
+ return True
55
+ return False
56
+
57
+
35
58
  @pipeline_component_registry.register("ImageLayoutService")
36
59
  class ImageLayoutService(PipelineComponent):
37
60
  """
@@ -45,7 +68,7 @@ class ImageLayoutService(PipelineComponent):
45
68
 
46
69
  **Example**
47
70
 
48
- d_items = TPFrcnnDetector(item_config_path, item_weights_path, {"1": "ROW", "2": "COLUMNS"})
71
+ d_items = TPFrcnnDetector(item_config_path, item_weights_path, {1: 'row', 2: 'column'})
49
72
  item_component = ImageLayoutService(d_items)
50
73
  """
51
74
 
@@ -55,7 +78,6 @@ class ImageLayoutService(PipelineComponent):
55
78
  to_image: bool = False,
56
79
  crop_image: bool = False,
57
80
  padder: Optional[PadTransform] = None,
58
- skip_if_layout_extracted: bool = False,
59
81
  ):
60
82
  """
61
83
  :param layout_detector: object detector
@@ -65,23 +87,14 @@ class ImageLayoutService(PipelineComponent):
65
87
  to its bounding box and populate the resulting sub image to
66
88
  `ImageAnnotation.image.image`.
67
89
  :param padder: If not `None`, will apply the padder to the image before prediction and inverse apply the padder
68
- :param skip_if_layout_extracted: When `True` will check, if there are already `ImageAnnotation` of a category
69
- available that will be predicted by the `layout_detector`. If yes, will skip
70
- the prediction process.
71
90
  """
72
91
  self.to_image = to_image
73
92
  self.crop_image = crop_image
74
93
  self.padder = padder
75
- self.skip_if_layout_extracted = skip_if_layout_extracted
76
94
  self.predictor = layout_detector
77
95
  super().__init__(self._get_name(layout_detector.name), self.predictor.model_id)
78
96
 
79
97
  def serve(self, dp: Image) -> None:
80
- if self.skip_if_layout_extracted:
81
- categories = self.predictor.get_category_names()
82
- anns = dp.get_annotation(category_names=categories)
83
- if anns:
84
- return
85
98
  if dp.image is None:
86
99
  raise ImageError("image cannot be None")
87
100
  np_image = dp.image
@@ -117,7 +130,7 @@ class ImageLayoutService(PipelineComponent):
117
130
  padder_clone = self.padder.clone()
118
131
  if not isinstance(predictor, ObjectDetector):
119
132
  raise TypeError(f"predictor must be of type ObjectDetector, but is of type {type(predictor)}")
120
- return self.__class__(predictor, self.to_image, self.crop_image, padder_clone, self.skip_if_layout_extracted)
133
+ return self.__class__(predictor, self.to_image, self.crop_image, padder_clone)
121
134
 
122
135
  def clear_predictor(self) -> None:
123
136
  self.predictor.clear_model()
@@ -347,19 +347,15 @@ class TextLineGenerator:
347
347
  a paragraph break threshold. This allows to detect a multi column structure just by observing sub lines.
348
348
  """
349
349
 
350
- def __init__(
351
- self, make_sub_lines: bool, line_category_id: Union[int, str], paragraph_break: Optional[float] = None
352
- ):
350
+ def __init__(self, make_sub_lines: bool, paragraph_break: Optional[float] = None):
353
351
  """
354
352
  :param make_sub_lines: Whether to build sub lines from lines.
355
- :param line_category_id: category_id to give a text line
356
353
  :param paragraph_break: threshold of two consecutive words. If distance is larger than threshold, two sub-lines
357
354
  will be built. We use relative coordinates to calculate the distance between two
358
355
  consecutive words. A reasonable value is 0.035
359
356
  """
360
357
  if make_sub_lines and paragraph_break is None:
361
358
  raise ValueError("You must specify paragraph_break when setting make_sub_lines to True")
362
- self.line_category_id = int(line_category_id)
363
359
  self.make_sub_lines = make_sub_lines
364
360
  self.paragraph_break = paragraph_break
365
361
 
@@ -367,7 +363,6 @@ class TextLineGenerator:
367
363
  return DetectionResult(
368
364
  box=box.to_list(mode="xyxy"),
369
365
  class_name=LayoutType.LINE,
370
- class_id=self.line_category_id,
371
366
  absolute_coords=box.absolute_coords,
372
367
  relationships=relationships,
373
368
  )
@@ -475,18 +470,14 @@ class TextLineServiceMixin(PipelineComponent, ABC):
475
470
  def __init__(
476
471
  self,
477
472
  name: str,
478
- line_category_id: int = 1,
479
473
  include_residual_text_container: bool = True,
480
474
  paragraph_break: Optional[float] = None,
481
475
  ):
482
476
  """
483
- Initialize the TextLineService with a line_category_id and a TextLineGenerator instance.
477
+ Initialize the TextLineServiceMixin with a TextLineGenerator instance.
484
478
  """
485
- self.line_category_id = line_category_id
486
479
  self.include_residual_text_container = include_residual_text_container
487
- self.text_line_generator = TextLineGenerator(
488
- self.include_residual_text_container, self.line_category_id, paragraph_break
489
- )
480
+ self.text_line_generator = TextLineGenerator(self.include_residual_text_container, paragraph_break)
490
481
  super().__init__(name)
491
482
 
492
483
  def _create_lines_for_words(self, word_anns: Sequence[ImageAnnotation]) -> Sequence[ImageAnnotation]:
@@ -523,17 +514,15 @@ class TextLineService(TextLineServiceMixin):
523
514
  text lines and the words contained in the text lines. The reading order is not arranged.
524
515
  """
525
516
 
526
- def __init__(self, line_category_id: int = 1, paragraph_break: Optional[float] = None):
517
+ def __init__(self, paragraph_break: Optional[float] = None):
527
518
  """
528
519
  Initialize `TextLineService`
529
520
 
530
- :param line_category_id: category_id to give a text line
531
521
  :param paragraph_break: threshold of two consecutive words. If distance is larger than threshold, two sublines
532
522
  will be built
533
523
  """
534
524
  super().__init__(
535
525
  name="text_line",
536
- line_category_id=line_category_id,
537
526
  include_residual_text_container=True,
538
527
  paragraph_break=paragraph_break,
539
528
  )
@@ -542,7 +531,7 @@ class TextLineService(TextLineServiceMixin):
542
531
  """
543
532
  This method returns a new instance of the class with the same configuration.
544
533
  """
545
- return self.__class__(self.line_category_id, self.text_line_generator.paragraph_break)
534
+ return self.__class__(self.text_line_generator.paragraph_break)
546
535
 
547
536
  def serve(self, dp: Image) -> None:
548
537
  text_container_anns = dp.get_annotation(category_names=LayoutType.WORD)
@@ -605,7 +594,6 @@ class TextOrderService(TextLineServiceMixin):
605
594
  broken_line_tolerance: float = 0.003,
606
595
  height_tolerance: float = 2.0,
607
596
  paragraph_break: Optional[float] = 0.035,
608
- line_category_id: int = 1,
609
597
  ):
610
598
  """
611
599
  :param text_container: name of an image annotation that has a CHARS sub category. These annotations will be
@@ -647,12 +635,9 @@ class TextOrderService(TextLineServiceMixin):
647
635
  self.floating_text_block_categories = self.floating_text_block_categories + (LayoutType.LINE,)
648
636
  self.include_residual_text_container = include_residual_text_container
649
637
  self.order_generator = OrderGenerator(starting_point_tolerance, broken_line_tolerance, height_tolerance)
650
- self.text_line_generator = TextLineGenerator(
651
- self.include_residual_text_container, line_category_id, paragraph_break
652
- )
638
+ self.text_line_generator = TextLineGenerator(self.include_residual_text_container, paragraph_break)
653
639
  super().__init__(
654
640
  name="text_order",
655
- line_category_id=line_category_id,
656
641
  include_residual_text_container=include_residual_text_container,
657
642
  paragraph_break=paragraph_break,
658
643
  )
@@ -763,7 +748,6 @@ class TextOrderService(TextLineServiceMixin):
763
748
  self.order_generator.broken_line_tolerance,
764
749
  self.order_generator.height_tolerance,
765
750
  self.text_line_generator.paragraph_break,
766
- self.text_line_generator.line_category_id,
767
751
  )
768
752
 
769
753
  def clear_predictor(self) -> None:
@@ -436,24 +436,24 @@ def segment_table(
436
436
  child_ann_ids = table.get_relationship(Relationships.CHILD)
437
437
  cell_index_rows, row_index, _, _ = match_anns_by_intersection(
438
438
  dp,
439
- item_names[0],
440
- cell_names,
441
- segment_rule,
442
- threshold_rows,
443
- True,
444
- child_ann_ids,
445
- child_ann_ids,
439
+ parent_ann_category_names=item_names[0],
440
+ child_ann_category_names=cell_names,
441
+ matching_rule=segment_rule,
442
+ threshold=threshold_rows,
443
+ use_weighted_intersections=True,
444
+ parent_ann_ids=child_ann_ids,
445
+ child_ann_ids=child_ann_ids,
446
446
  )
447
447
 
448
448
  cell_index_cols, col_index, _, _ = match_anns_by_intersection(
449
449
  dp,
450
- item_names[1],
451
- cell_names,
452
- segment_rule,
453
- threshold_cols,
454
- True,
455
- child_ann_ids,
456
- child_ann_ids,
450
+ parent_ann_category_names=item_names[1],
451
+ child_ann_category_names=cell_names,
452
+ matching_rule=segment_rule,
453
+ threshold=threshold_cols,
454
+ use_weighted_intersections=True,
455
+ parent_ann_ids=child_ann_ids,
456
+ child_ann_ids=child_ann_ids,
457
457
  )
458
458
 
459
459
  cells = dp.get_annotation(annotation_ids=child_ann_ids, category_names=cell_names)
@@ -499,7 +499,6 @@ def create_intersection_cells(
499
499
  rows: Sequence[ImageAnnotation],
500
500
  cols: Sequence[ImageAnnotation],
501
501
  table_annotation_id: str,
502
- cell_class_id: int,
503
502
  sub_item_names: Sequence[ObjectTypes],
504
503
  ) -> tuple[Sequence[DetectionResult], Sequence[SegmentationResult]]:
505
504
  """
@@ -509,7 +508,6 @@ def create_intersection_cells(
509
508
  :param rows: list of rows
510
509
  :param cols: list of columns
511
510
  :param table_annotation_id: annotation_id of underlying table ImageAnnotation
512
- :param cell_class_id: The class_id to a synthetically generated DetectionResult
513
511
  :param sub_item_names: ObjectTypes for row-/column number
514
512
  :return: Pair of lists of `DetectionResult` and `SegmentationResult`.
515
513
  """
@@ -526,7 +524,6 @@ def create_intersection_cells(
526
524
  detect_result_cells.append(
527
525
  DetectionResult(
528
526
  box=boxes_cells[idx].to_list(mode="xyxy"),
529
- class_id=cell_class_id,
530
527
  absolute_coords=boxes_cells[idx].absolute_coords,
531
528
  class_name=LayoutType.CELL,
532
529
  )
@@ -574,13 +571,13 @@ def header_cell_to_item_detect_result(
574
571
  child_ann_ids = table.get_relationship(Relationships.CHILD)
575
572
  item_index, _, items, _ = match_anns_by_intersection(
576
573
  dp,
577
- item_header_name,
578
- item_name,
579
- segment_rule,
580
- threshold,
581
- True,
582
- child_ann_ids,
583
- child_ann_ids,
574
+ parent_ann_category_names=item_header_name,
575
+ child_ann_category_names=item_name,
576
+ matching_rule=segment_rule,
577
+ threshold=threshold,
578
+ use_weighted_intersections=True,
579
+ parent_ann_ids=child_ann_ids,
580
+ child_ann_ids=child_ann_ids,
584
581
  )
585
582
  item_headers = []
586
583
  for idx, item in enumerate(items):
@@ -622,24 +619,24 @@ def segment_pubtables(
622
619
  child_ann_ids = table.get_relationship(Relationships.CHILD)
623
620
  cell_index_rows, row_index, _, _ = match_anns_by_intersection(
624
621
  dp,
625
- item_names[0],
626
- spanning_cell_names,
627
- segment_rule,
628
- threshold_rows,
629
- True,
630
- child_ann_ids,
631
- child_ann_ids,
622
+ parent_ann_category_names=item_names[0],
623
+ child_ann_category_names=spanning_cell_names,
624
+ matching_rule=segment_rule,
625
+ threshold=threshold_rows,
626
+ use_weighted_intersections=True,
627
+ parent_ann_ids=child_ann_ids,
628
+ child_ann_ids=child_ann_ids,
632
629
  )
633
630
 
634
631
  cell_index_cols, col_index, _, _ = match_anns_by_intersection(
635
632
  dp,
636
- item_names[1],
637
- spanning_cell_names,
638
- segment_rule,
639
- threshold_cols,
640
- True,
641
- child_ann_ids,
642
- child_ann_ids,
633
+ parent_ann_category_names=item_names[1],
634
+ child_ann_category_names=spanning_cell_names,
635
+ matching_rule=segment_rule,
636
+ threshold=threshold_cols,
637
+ use_weighted_intersections=True,
638
+ parent_ann_ids=child_ann_ids,
639
+ child_ann_ids=child_ann_ids,
643
640
  )
644
641
 
645
642
  spanning_cells = dp.get_annotation(annotation_ids=child_ann_ids, category_names=spanning_cell_names)
@@ -976,7 +973,6 @@ class PubtablesSegmentationService(PipelineComponent):
976
973
  tile_table_with_items: bool,
977
974
  remove_iou_threshold_rows: float,
978
975
  remove_iou_threshold_cols: float,
979
- cell_class_id: int,
980
976
  table_name: TypeOrStr,
981
977
  cell_names: Sequence[TypeOrStr],
982
978
  spanning_cell_names: Sequence[TypeOrStr],
@@ -997,7 +993,6 @@ class PubtablesSegmentationService(PipelineComponent):
997
993
  the adjacent row. Will do a similar shifting with columns.
998
994
  :param remove_iou_threshold_rows: iou threshold for removing overlapping rows
999
995
  :param remove_iou_threshold_cols: iou threshold for removing overlapping columns
1000
- :param cell_class_id: 'category_id' for cells to be generated from intersected rows and columns
1001
996
  :param table_name: layout type table
1002
997
  :param cell_names: layout type of cells
1003
998
  :param spanning_cell_names: layout type of spanning cells
@@ -1022,7 +1017,6 @@ class PubtablesSegmentationService(PipelineComponent):
1022
1017
  self.spanning_cell_names = [get_type(cell_name) for cell_name in spanning_cell_names]
1023
1018
  self.remove_iou_threshold_rows = remove_iou_threshold_rows
1024
1019
  self.remove_iou_threshold_cols = remove_iou_threshold_cols
1025
- self.cell_class_id = cell_class_id
1026
1020
  self.cell_to_image = cell_to_image
1027
1021
  self.crop_cell_image = crop_cell_image
1028
1022
  self.item_names = [get_type(item_name) for item_name in item_names] # row names must be before column name
@@ -1089,7 +1083,7 @@ class PubtablesSegmentationService(PipelineComponent):
1089
1083
  rows = dp.get_annotation(category_names=self.item_names[0], annotation_ids=item_ann_ids)
1090
1084
  columns = dp.get_annotation(category_names=self.item_names[1], annotation_ids=item_ann_ids)
1091
1085
  detect_result_cells, segment_result_cells = create_intersection_cells(
1092
- rows, columns, table.annotation_id, self.cell_class_id, self.sub_item_names
1086
+ rows, columns, table.annotation_id, self.sub_item_names
1093
1087
  )
1094
1088
  cell_rn_cn_to_ann_id = {}
1095
1089
  for detect_result, segment_result in zip(detect_result_cells, segment_result_cells):
@@ -1228,7 +1222,6 @@ class PubtablesSegmentationService(PipelineComponent):
1228
1222
  self.tile_table,
1229
1223
  self.remove_iou_threshold_rows,
1230
1224
  self.remove_iou_threshold_cols,
1231
- self.cell_class_id,
1232
1225
  self.table_name,
1233
1226
  self.cell_names,
1234
1227
  self.spanning_cell_names,
@@ -92,7 +92,6 @@ class DetectResultGenerator:
92
92
  detect_result_list.append(
93
93
  DetectionResult(
94
94
  box=[0.0, 0.0, float(self.width), float(self.height)], # type: ignore
95
- class_id=self.categories_name_as_key[category_name],
96
95
  class_name=category_name,
97
96
  score=0.0,
98
97
  absolute_coords=self.absolute_coords,
@@ -154,16 +153,16 @@ class SubImageLayoutService(PipelineComponent):
154
153
  **Example**
155
154
 
156
155
  detect_result_generator = DetectResultGenerator(categories_items)
157
- d_items = TPFrcnnDetector(item_config_path, item_weights_path, {"1": LayoutType.row,
158
- "2": LayoutType.column})
159
- item_component = SubImageLayoutService(d_items, LayoutType.table, {1: 7, 2: 8}, detect_result_generator)
156
+ d_items = TPFrcnnDetector(item_config_path, item_weights_path, {1: LayoutType.row,
157
+ 2: LayoutType.column})
158
+ item_component = SubImageLayoutService(d_items, LayoutType.table, detect_result_generator)
160
159
  """
161
160
 
162
161
  def __init__(
163
162
  self,
164
163
  sub_image_detector: ObjectDetector,
165
164
  sub_image_names: Union[str, Sequence[TypeOrStr]],
166
- category_id_mapping: Optional[dict[int, int]] = None,
165
+ service_ids: Optional[Sequence[str]] = None,
167
166
  detect_result_generator: Optional[DetectResultGenerator] = None,
168
167
  padder: Optional[PadTransform] = None,
169
168
  ):
@@ -172,7 +171,8 @@ class SubImageLayoutService(PipelineComponent):
172
171
  :param sub_image_names: Category names of ImageAnnotations to be presented to the detector.
173
172
  Attention: The selected ImageAnnotations must have: attr:`image` and: attr:`image.image`
174
173
  not None.
175
- :param category_id_mapping: Mapping of category IDs. Usually, the category ids start with 1.
174
+ :param service_ids: List of service ids to be used for filtering the ImageAnnotations. If None, all
175
+ ImageAnnotations will be used.
176
176
  :param detect_result_generator: 'DetectResultGenerator' instance. 'categories' attribute has to be the same as
177
177
  the 'categories' attribute of the 'sub_image_detector'. The generator will be
178
178
  responsible to create 'DetectionResult' for some categories, if they have not
@@ -186,7 +186,7 @@ class SubImageLayoutService(PipelineComponent):
186
186
  if isinstance(sub_image_names, str)
187
187
  else tuple((get_type(cat) for cat in sub_image_names))
188
188
  )
189
- self.category_id_mapping = category_id_mapping
189
+ self.service_ids = service_ids
190
190
  self.detect_result_generator = detect_result_generator
191
191
  self.padder = padder
192
192
  self.predictor = sub_image_detector
@@ -208,7 +208,7 @@ class SubImageLayoutService(PipelineComponent):
208
208
  - Optionally invoke the DetectResultGenerator
209
209
  - Generate ImageAnnotations and dump to parent image and sub image.
210
210
  """
211
- sub_image_anns = dp.get_annotation(category_names=self.sub_image_name)
211
+ sub_image_anns = dp.get_annotation(category_names=self.sub_image_name, service_ids=self.service_ids)
212
212
  for sub_image_ann in sub_image_anns:
213
213
  np_image = self.prepare_np_image(sub_image_ann)
214
214
  detect_result_list = self.predictor.predict(np_image)
@@ -223,11 +223,6 @@ class SubImageLayoutService(PipelineComponent):
223
223
  detect_result_list = self.detect_result_generator.create_detection_result(detect_result_list)
224
224
 
225
225
  for detect_result in detect_result_list:
226
- if self.category_id_mapping:
227
- if detect_result.class_id:
228
- detect_result.class_id = self.category_id_mapping.get(
229
- detect_result.class_id, detect_result.class_id
230
- )
231
226
  self.dp_manager.set_image_annotation(detect_result, sub_image_ann.annotation_id)
232
227
 
233
228
  def get_meta_annotation(self) -> MetaAnnotation:
@@ -254,7 +249,7 @@ class SubImageLayoutService(PipelineComponent):
254
249
  return self.__class__(
255
250
  predictor,
256
251
  self.sub_image_name,
257
- self.category_id_mapping,
252
+ self.service_ids,
258
253
  self.detect_result_generator,
259
254
  padder_clone,
260
255
  )