scale-nucleus 0.12b1__py3-none-any.whl → 0.14.14b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. cli/slices.py +14 -28
  2. nucleus/__init__.py +211 -18
  3. nucleus/annotation.py +28 -5
  4. nucleus/connection.py +9 -1
  5. nucleus/constants.py +9 -3
  6. nucleus/dataset.py +197 -59
  7. nucleus/dataset_item.py +11 -1
  8. nucleus/job.py +1 -1
  9. nucleus/metrics/__init__.py +2 -1
  10. nucleus/metrics/base.py +34 -56
  11. nucleus/metrics/categorization_metrics.py +6 -2
  12. nucleus/metrics/cuboid_utils.py +4 -6
  13. nucleus/metrics/errors.py +4 -0
  14. nucleus/metrics/filtering.py +369 -19
  15. nucleus/metrics/polygon_utils.py +3 -3
  16. nucleus/metrics/segmentation_loader.py +30 -0
  17. nucleus/metrics/segmentation_metrics.py +256 -195
  18. nucleus/metrics/segmentation_to_poly_metrics.py +229 -105
  19. nucleus/metrics/segmentation_utils.py +239 -8
  20. nucleus/model.py +66 -10
  21. nucleus/model_run.py +1 -1
  22. nucleus/{shapely_not_installed.py → package_not_installed.py} +3 -3
  23. nucleus/payload_constructor.py +4 -0
  24. nucleus/prediction.py +6 -3
  25. nucleus/scene.py +7 -0
  26. nucleus/slice.py +160 -16
  27. nucleus/utils.py +51 -12
  28. nucleus/validate/__init__.py +1 -0
  29. nucleus/validate/client.py +57 -8
  30. nucleus/validate/constants.py +1 -0
  31. nucleus/validate/data_transfer_objects/eval_function.py +22 -0
  32. nucleus/validate/data_transfer_objects/scenario_test_evaluations.py +13 -5
  33. nucleus/validate/eval_functions/available_eval_functions.py +33 -20
  34. nucleus/validate/eval_functions/config_classes/segmentation.py +2 -46
  35. nucleus/validate/scenario_test.py +71 -13
  36. nucleus/validate/scenario_test_evaluation.py +21 -21
  37. nucleus/validate/utils.py +1 -1
  38. {scale_nucleus-0.12b1.dist-info → scale_nucleus-0.14.14b0.dist-info}/LICENSE +0 -0
  39. {scale_nucleus-0.12b1.dist-info → scale_nucleus-0.14.14b0.dist-info}/METADATA +13 -11
  40. {scale_nucleus-0.12b1.dist-info → scale_nucleus-0.14.14b0.dist-info}/RECORD +42 -41
  41. {scale_nucleus-0.12b1.dist-info → scale_nucleus-0.14.14b0.dist-info}/WHEEL +1 -1
  42. {scale_nucleus-0.12b1.dist-info → scale_nucleus-0.14.14b0.dist-info}/entry_points.txt +0 -0
nucleus/metrics/base.py CHANGED
@@ -4,10 +4,13 @@ from dataclasses import dataclass
 from typing import Iterable, List, Optional, Union
 
 from nucleus.annotation import AnnotationList
+from nucleus.metrics.errors import EverythingFilteredError
 from nucleus.metrics.filtering import (
     ListOfAndFilters,
     ListOfOrAndFilters,
-    apply_filters,
+    compose_helpful_filtering_error,
+    filter_annotation_list,
+    filter_prediction_list,
 )
 from nucleus.prediction import PredictionList
 
@@ -133,64 +136,16 @@ class Metric(ABC):
     def __call__(
         self, annotations: AnnotationList, predictions: PredictionList
     ) -> MetricResult:
-        annotations = self._filter_annotations(annotations)
-        predictions = self._filter_predictions(predictions)
-        return self.call_metric(annotations, predictions)
-
-    def _filter_annotations(self, annotations: AnnotationList):
-        if (
-            self.annotation_filters is None
-            or len(self.annotation_filters) == 0
-        ):
-            return annotations
-        annotations.box_annotations = apply_filters(
-            annotations.box_annotations, self.annotation_filters
+        filtered_anns = filter_annotation_list(
+            annotations, self.annotation_filters
         )
-        annotations.line_annotations = apply_filters(
-            annotations.line_annotations, self.annotation_filters
+        filtered_preds = filter_prediction_list(
+            predictions, self.prediction_filters
         )
-        annotations.polygon_annotations = apply_filters(
-            annotations.polygon_annotations, self.annotation_filters
+        self._raise_if_everything_filtered(
+            annotations, filtered_anns, predictions, filtered_preds
         )
-        annotations.cuboid_annotations = apply_filters(
-            annotations.cuboid_annotations, self.annotation_filters
-        )
-        annotations.category_annotations = apply_filters(
-            annotations.category_annotations, self.annotation_filters
-        )
-        annotations.multi_category_annotations = apply_filters(
-            annotations.multi_category_annotations, self.annotation_filters
-        )
-        annotations.segmentation_annotations = apply_filters(
-            annotations.segmentation_annotations, self.annotation_filters
-        )
-        return annotations
-
-    def _filter_predictions(self, predictions: PredictionList):
-        if (
-            self.prediction_filters is None
-            or len(self.prediction_filters) == 0
-        ):
-            return predictions
-        predictions.box_predictions = apply_filters(
-            predictions.box_predictions, self.prediction_filters
-        )
-        predictions.line_predictions = apply_filters(
-            predictions.line_predictions, self.prediction_filters
-        )
-        predictions.polygon_predictions = apply_filters(
-            predictions.polygon_predictions, self.prediction_filters
-        )
-        predictions.cuboid_predictions = apply_filters(
-            predictions.cuboid_predictions, self.prediction_filters
-        )
-        predictions.category_predictions = apply_filters(
-            predictions.category_predictions, self.prediction_filters
-        )
-        predictions.segmentation_predictions = apply_filters(
-            predictions.segmentation_predictions, self.prediction_filters
-        )
-        return predictions
+        return self.call_metric(filtered_anns, filtered_preds)
 
     @abstractmethod
     def aggregate_score(self, results: List[MetricResult]) -> ScalarResult:
@@ -215,3 +170,26 @@ class Metric(ABC):
                 return ScalarResult(r2_score)
 
         """
+
+    def _raise_if_everything_filtered(
+        self,
+        annotations: AnnotationList,
+        filtered_annotations: AnnotationList,
+        predictions: PredictionList,
+        filtered_predictions: PredictionList,
+    ):
+        msg = []
+        if len(filtered_annotations) == 0:
+            msg.extend(
+                compose_helpful_filtering_error(
+                    annotations, self.annotation_filters
+                )
+            )
+        if len(filtered_predictions) == 0:
+            msg.extend(
+                compose_helpful_filtering_error(
+                    predictions, self.prediction_filters
+                )
+            )
+        if msg:
+            raise EverythingFilteredError("\n".join(msg))
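The net effect of the base.py change: per-type filtering now lives in nucleus/metrics/filtering.py helpers, and a metric call whose filters remove every annotation or prediction raises EverythingFilteredError instead of silently computing on empty lists. A minimal sketch of that behavior follows (MyBoxCountMetric is a hypothetical subclass for illustration; it assumes Metric.__init__ accepts annotation_filters/prediction_filters, which the new __call__ body reads back from self):

from nucleus.annotation import AnnotationList, BoxAnnotation
from nucleus.metrics.base import Metric, ScalarResult
from nucleus.metrics.errors import EverythingFilteredError
from nucleus.metrics.filtering import FieldFilter
from nucleus.prediction import BoxPrediction, PredictionList


class MyBoxCountMetric(Metric):  # hypothetical subclass, not part of the package
    def call_metric(self, annotations, predictions):
        return ScalarResult(len(annotations.box_annotations))

    def aggregate_score(self, results):
        return ScalarResult(sum(r.value for r in results))


metric = MyBoxCountMetric(
    annotation_filters=[FieldFilter("label", "==", "car")]  # assumed constructor kwarg
)
anns = AnnotationList(
    box_annotations=[
        BoxAnnotation(label="tree", x=0, y=0, width=5, height=5, reference_id="img_1")
    ]
)
preds = PredictionList(
    box_predictions=[
        BoxPrediction(
            label="tree", x=0, y=0, width=5, height=5, reference_id="img_1", confidence=0.9
        )
    ]
)
try:
    metric(anns, preds)  # no annotation survives the "label == car" filter
except EverythingFilteredError as err:
    print(err)  # message lists the filters plus a table of the original items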
nucleus/metrics/categorization_metrics.py CHANGED
@@ -2,8 +2,6 @@ from abc import abstractmethod
 from dataclasses import dataclass
 from typing import List, Optional, Set, Tuple, Union
 
-from sklearn.metrics import f1_score
-
 from nucleus.annotation import AnnotationList, CategoryAnnotation
 from nucleus.metrics.base import Metric, MetricResult, ScalarResult
 from nucleus.metrics.filtering import ListOfAndFilters, ListOfOrAndFilters
@@ -35,6 +33,9 @@ class CategorizationResult(MetricResult):
 
     @property
     def value(self):
+        # late import to avoid slow CLI init
+        from sklearn.metrics import f1_score
+
         annotation_labels = to_taxonomy_labels(self.annotations)
         prediction_labels = to_taxonomy_labels(self.predictions)
 
@@ -245,6 +246,9 @@ class CategorizationF1(CategorizationMetric):
         )
 
     def aggregate_score(self, results: List[CategorizationResult]) -> ScalarResult:  # type: ignore[override]
+        # late import to avoid slow CLI init
+        from sklearn.metrics import f1_score
+
         gt = []
         predicted = []
         for result in results:
nucleus/metrics/cuboid_utils.py CHANGED
@@ -5,10 +5,10 @@ import numpy as np
 
 try:
     from shapely.geometry import Polygon
-except ModuleNotFoundError:
-    from ..shapely_not_installed import ShapelyNotInstalled
+except (ModuleNotFoundError, OSError):
+    from ..package_not_installed import PackageNotInstalled
 
-    Polygon = ShapelyNotInstalled
+    Polygon = PackageNotInstalled
 
 
 from nucleus.annotation import CuboidAnnotation
@@ -357,6 +357,4 @@ def detection_iou(
             meter_3d.append(iou_3d[i, j])
             meter_2d.append(iou_2d[i, j])
 
-    meter_3d = np.array(meter_3d)
-    meter_2d = np.array(meter_2d)
-    return meter_3d, meter_2d
+    return np.array(meter_3d), np.array(meter_2d)
nucleus/metrics/errors.py CHANGED
@@ -5,3 +5,7 @@ class PolygonAnnotationTypeError(Exception):
     ):
         self.message = message
         super().__init__(self.message)
+
+
+class EverythingFilteredError(Exception):
+    pass
nucleus/metrics/filtering.py CHANGED
@@ -1,16 +1,32 @@
+import copy
 import enum
 import functools
 import logging
 from enum import Enum
-from typing import Callable, Iterable, List, NamedTuple, Sequence, Set, Union
+from typing import (
+    Callable,
+    Iterable,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
+
+from rich.console import Console
+from rich.table import Table
 
 from nucleus.annotation import (
+    AnnotationList,
     BoxAnnotation,
     CategoryAnnotation,
     CuboidAnnotation,
     LineAnnotation,
     MultiCategoryAnnotation,
     PolygonAnnotation,
+    Segment,
     SegmentationAnnotation,
 )
 from nucleus.prediction import (
@@ -19,6 +35,7 @@ from nucleus.prediction import (
     CuboidPrediction,
     LinePrediction,
     PolygonPrediction,
+    PredictionList,
     SegmentationPrediction,
 )
 
@@ -40,10 +57,14 @@ class FilterType(str, enum.Enum):
     Attributes:
         FIELD: Access the attribute field of an object
         METADATA: Access the metadata dictionary of an object
+        SEGMENT_FIELD: Filter segments of a segmentation mask to be considered on segment fields
+        SEGMENT_METADATA: Filter segments of a segmentation mask based on segment metadata
     """
 
     FIELD = "field"
     METADATA = "metadata"
+    SEGMENT_FIELD = "segment_field"
+    SEGMENT_METADATA = "segment_metadata"
 
 
 FilterableBaseVals = Union[str, float, int, bool]
@@ -100,7 +121,8 @@ class FieldFilter(NamedTuple):
 
     Examples:
         FieldFilter("x", ">", 10) would pass every :class:`BoxAnnotation` with `x` attribute larger than 10
-        FieldFilter("label", "in", [) would pass every :class:`BoxAnnotation` with `x` attribute larger than 10
+        FieldFilter("label", "in", ["car", "truck"]) would pass every :class:`BoxAnnotation` with `label`
+        in ["car", "truck"]
 
     Attributes:
         key: key to compare with value
@@ -129,7 +151,7 @@ class MetadataFilter(NamedTuple):
             with value field
         value: bool, str, float or int to compare the field with key or list of the same values for 'in' and 'not in'
             ops
-        allow_missing: Allow missing metada values. Will REMOVE the object with the missing field from the selection
+        allow_missing: Allow missing metadata values. Will REMOVE the object with the missing field from the selection
         type: DO NOT USE. Internal type for serialization over the wire. Changing this will change the `NamedTuple`
             type as well.
     """
@@ -141,7 +163,60 @@ class MetadataFilter(NamedTuple):
     type: FilterType = FilterType.METADATA
 
 
-Filter = Union[FieldFilter, MetadataFilter, AnnotationOrPredictionFilter]
+class SegmentMetadataFilter(NamedTuple):
+    """Filter on customer provided metadata associated with Segments of a SegmentationAnnotation or
+    SegmentationPrediction
+
+    Attributes:
+        key: key to compare with value
+        op: :class:`FilterOp` or one of [">", ">=", "<", "<=", "=", "==", "!=", "in", "not in"] to define comparison
+            with value field
+        value: bool, str, float or int to compare the field with key or list of the same values for 'in' and 'not in'
+            ops
+        allow_missing: Allow missing metadata values. Will REMOVE the object with the missing field from the selection
+        type: DO NOT USE. Internal type for serialization over the wire. Changing this will change the `NamedTuple`
+            type as well.
+    """
+
+    key: str
+    op: Union[FilterOp, str]
+    value: FilterableTypes
+    allow_missing: bool = False
+    type: FilterType = FilterType.SEGMENT_METADATA
+
+
+class SegmentFieldFilter(NamedTuple):
+    """Filter on standard field of Segment(s) of SegmentationAnnotation and SegmentationPrediction
+
+    Examples:
+        SegmentFieldFilter("label", "in", ["grass", "tree"]) would pass every :class:`Segment` of a
+        :class:`SegmentationAnnotation or :class:`SegmentationPrediction`
+
+    Attributes:
+        key: key to compare with value
+        op: :class:`FilterOp` or one of [">", ">=", "<", "<=", "=", "==", "!=", "in", "not in"] to define comparison
+            with value field
+        value: bool, str, float or int to compare the field with key or list of the same values for 'in' and 'not in'
+            ops
+        allow_missing: Allow missing field values. Will REMOVE the object with the missing field from the selection
+        type: DO NOT USE. Internal type for serialization over the wire. Changing this will change the `NamedTuple`
+            type as well.
+    """
+
+    key: str
+    op: Union[FilterOp, str]
+    value: FilterableTypes
+    allow_missing: bool = False
+    type: FilterType = FilterType.SEGMENT_FIELD
+
+
+Filter = Union[
+    FieldFilter,
+    MetadataFilter,
+    AnnotationOrPredictionFilter,
+    SegmentFieldFilter,
+    SegmentMetadataFilter,
+]
 OrAndDNFFilters = List[List[Filter]]
 OrAndDNFFilters.__doc__ = """\
 Disjunctive normal form (DNF) filters.
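For orientation on how the two new filter classes are meant to be combined with the existing ones, here is an illustrative sketch (made-up labels and metadata keys): the outer list is an OR over inner AND-chains, and the _split_segment_filters step introduced further down routes the SEGMENT_* statements to the nested Segment objects while the remaining filters apply to whole annotations or predictions.

from nucleus.metrics.filtering import (
    FieldFilter,
    MetadataFilter,
    SegmentFieldFilter,
    SegmentMetadataFilter,
)

# Outer list = OR, inner lists = AND (disjunctive normal form).
example_filters = [
    # branch 1: filters evaluated against whole annotations/predictions
    [
        FieldFilter("label", "in", ["car", "truck"]),
        MetadataFilter("occluded", "==", False),
    ],
    # branch 2: filters evaluated against Segments of segmentation masks
    [
        SegmentFieldFilter("label", "==", "grass"),
        SegmentMetadataFilter("region", "==", "foreground"),
    ],
]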
@@ -182,11 +257,22 @@ ListOfAndFilters = Union[
     ListOfAndJSONSerialized,
 ]
 
+DNFFieldOrMetadataFilters = List[
+    List[Union[FieldFilter, MetadataFilter, AnnotationOrPredictionFilter]]
+]
+DNFFieldOrMetadataFilters.__doc__ = """\
+Disjunctive normal form (DNF) filters.
+DNF allows arbitrary boolean logical combinations of single field predicates.
+The innermost structures each describe a single field predicate.
+The list of inner predicates is interpreted as a conjunction (AND), forming a more selective and multiple column
+predicate.
+"""
+
 
 def _attribute_getter(
     field_name: str,
     allow_missing: bool,
-    ann_or_pred: Union[AnnotationTypes, PredictionTypes],
+    ann_or_pred: Union[AnnotationTypes, PredictionTypes, Segment],
 ):
     """Create a function to get object fields"""
     if allow_missing:
@@ -224,7 +310,7 @@ class AlwaysFalseComparison:
 def _metadata_field_getter(
     field_name: str,
     allow_missing: bool,
-    ann_or_pred: Union[AnnotationTypes, PredictionTypes],
+    ann_or_pred: Union[AnnotationTypes, PredictionTypes, Segment],
 ):
     """Create a function to get a metadata field"""
     if isinstance(
@@ -259,7 +345,7 @@ def _metadata_field_getter(
 
 def _filter_to_comparison_function(  # pylint: disable=too-many-return-statements
     filter_def: Filter,
-) -> Callable[[Union[AnnotationTypes, PredictionTypes]], bool]:
+) -> Callable[[Union[AnnotationTypes, PredictionTypes, Segment]], bool]:
     """Creates a comparison function from a filter configuration to apply to annotations or predictions
 
     Parameters:
@@ -276,6 +362,10 @@ def _filter_to_comparison_function(  # pylint: disable=too-many-return-statements
         getter = functools.partial(
             _metadata_field_getter, filter_def.key, filter_def.allow_missing
        )
+    else:
+        raise NotImplementedError(
+            f"Unhandled filter type: {filter_def.type}. NOTE: Segmentation filters are handled elsewhere."
+        )
     op = FilterOp(filter_def.op)
     if op is FilterOp.GT:
         return lambda ann_or_pred: getter(ann_or_pred) > filter_def.value
@@ -303,13 +393,16 @@ def _filter_to_comparison_function(  # pylint: disable=too-many-return-statements
     )
 
 
-def apply_filters(
-    ann_or_pred: Union[Sequence[AnnotationTypes], Sequence[PredictionTypes]],
-    filters: Union[ListOfOrAndFilters, ListOfAndFilters],
+def _apply_field_or_metadata_filters(
+    filterable_sequence: Union[
+        Sequence[AnnotationTypes], Sequence[PredictionTypes], Sequence[Segment]
+    ],
+    filters: DNFFieldOrMetadataFilters,
 ):
-    """Apply filters to list of annotations or list of predictions
+    """Apply filters to list of annotations or list of predictions or to a list of segments
+
     Attributes:
-        ann_or_pred: Prediction or Annotation
+        filterable_sequence: Prediction, Annotation or Segment sequence
         filters: Filter predicates. Allowed formats are:
             ListOfAndFilters where each Filter forms a chain of AND predicates.
             or
@@ -320,11 +413,6 @@
             is interpreted as a conjunction (AND), forming a more selective `and` multiple column predicate.
             Finally, the most outer list combines these filters as a disjunction (OR).
     """
-    if filters is None or len(filters) == 0:
-        return ann_or_pred
-
-    filters = ensureDNFFilters(filters)
-
     dnf_condition_functions = []
     for or_branch in filters:
         and_conditions = [
@@ -333,18 +421,136 @@
         dnf_condition_functions.append(and_conditions)
 
     filtered = []
-    for item in ann_or_pred:
+    for item in filterable_sequence:
         for or_conditions in dnf_condition_functions:
             if all(c(item) for c in or_conditions):
                 filtered.append(item)
                 break
+
+    return filtered
+
+
+def _split_segment_filters(
+    dnf_filters: OrAndDNFFilters,
+) -> Tuple[OrAndDNFFilters, OrAndDNFFilters]:
+    """We treat Segment* filters differently -> this splits filters into two sets, one containing the
+    standard field, metadata branches and the other the segment filters.
+    """
+    normal_or_branches = []
+    segment_or_branches = []
+    for and_branch in dnf_filters:
+        normal_filters = []
+        segment_filters = []
+        for filter_statement in and_branch:
+            if filter_statement.type in {
+                FilterType.SEGMENT_METADATA,
+                FilterType.SEGMENT_FIELD,
+            }:
+                segment_filters.append(filter_statement)
+            else:
+                normal_filters.append(filter_statement)
+        normal_or_branches.append(normal_filters)
+        segment_or_branches.append(segment_filters)
+    return normal_or_branches, segment_or_branches
+
+
+def _filter_segments(
+    anns_or_preds: Union[
+        Sequence[SegmentationAnnotation], Sequence[SegmentationPrediction]
+    ],
+    segment_filters: OrAndDNFFilters,
+):
+    """Filter Segments of a SegmentationAnnotation or Prediction
+
+    We have to treat this differently as metadata and labels are on nested Segment objects
+    """
+    if len(segment_filters) == 0 or len(segment_filters[0]) == 0:
+        return anns_or_preds
+
+    # Transform segment filter types to field and metadata to iterate over annotation sub fields
+    transformed_or_branches = (
+        []
+    )  # type: List[List[Union[MetadataFilter, FieldFilter]]]
+    for and_branch in segment_filters:
+        transformed_and = []  # type: List[Union[MetadataFilter, FieldFilter]]
+        for filter_statement in and_branch:
+            if filter_statement.type == FilterType.SEGMENT_FIELD:
+                transformed_and.append(
+                    FieldFilter(
+                        filter_statement.key,
+                        filter_statement.op,
+                        filter_statement.value,
+                        filter_statement.allow_missing,
+                    )
+                )
+            elif filter_statement.type == FilterType.SEGMENT_METADATA:
+                transformed_and.append(
+                    MetadataFilter(
+                        filter_statement.key,
+                        filter_statement.op,
+                        filter_statement.value,
+                        filter_statement.allow_missing,
+                    )
+                )
+            else:
+                raise RuntimeError("Encountered a non SEGMENT_* filter type")
+
+        transformed_or_branches.append(transformed_and)
+
+    segments_filtered = []
+    for ann_or_pred in anns_or_preds:
+        if isinstance(
+            ann_or_pred, (SegmentationAnnotation, SegmentationPrediction)
+        ):
+            ann_or_pred.annotations = _apply_field_or_metadata_filters(
+                ann_or_pred.annotations, transformed_or_branches  # type: ignore
+            )
+            segments_filtered.append(ann_or_pred)
+
+    return segments_filtered
+
+
+def apply_filters(
+    ann_or_pred: Union[Sequence[AnnotationTypes], Sequence[PredictionTypes]],
+    filters: Union[ListOfOrAndFilters, ListOfAndFilters],
+):
+    """Apply filters to list of annotations or list of predictions
+    Attributes:
+        ann_or_pred: Prediction or Annotation
+        filters: Filter predicates. Allowed formats are:
+            ListOfAndFilters where each Filter forms a chain of AND predicates.
+            or
+            ListOfOrAndFilters where Filters are expressed in disjunctive normal form (DNF), like
+            [[MetadataFilter("short_haired", "==", True), FieldFilter("label", "in", ["cat", "dog"]), ...].
+            DNF allows arbitrary boolean logical combinations of single field
+            predicates. The innermost structures each describe a single column predicate. The list of inner predicates
+            is interpreted as a conjunction (AND), forming a more selective `and` multiple column predicate.
+            Finally, the most outer list combines these filters as a disjunction (OR).
+    """
+    if filters is None or len(filters) == 0:
+        return ann_or_pred
+
+    dnf_filters = ensureDNFFilters(filters)
+    filters, segment_filters = _split_segment_filters(dnf_filters)
+    filtered = _apply_field_or_metadata_filters(ann_or_pred, filters)  # type: ignore
+    filtered = _filter_segments(filtered, segment_filters)
+
     return filtered
 
 
 def ensureDNFFilters(filters) -> OrAndDNFFilters:
     """JSON encoding creates a triple nested lists from the doubly nested tuples. This function creates the
     tuple form again."""
-    if isinstance(filters[0], (MetadataFilter, FieldFilter)):
+    if isinstance(
+        filters[0],
+        (
+            MetadataFilter,
+            FieldFilter,
+            AnnotationOrPredictionFilter,
+            SegmentFieldFilter,
+            SegmentMetadataFilter,
+        ),
+    ):
         # Normalize into DNF
         filters: ListOfOrAndFilters = [filters]  # type: ignore
@@ -369,3 +575,147 @@ def ensureDNFFilters(filters) -> OrAndDNFFilters:
             formatted_filter.append(and_chain)
         filters = formatted_filter
     return filters
+
+
+def pretty_format_filters_with_or_and(
+    filters: Optional[Union[ListOfOrAndFilters, ListOfAndFilters]]
+):
+    if filters is None:
+        return "No filters applied!"
+    dnf_filters = ensureDNFFilters(filters)
+    or_branches = []
+    for or_branch in dnf_filters:
+        and_statements = []
+        for and_branch in or_branch:
+            if and_branch.type == FilterType.FIELD:
+                class_name = "FieldFilter"
+            elif and_branch.type == FilterType.METADATA:
+                class_name = "MetadataFilter"
+            elif and_branch.type == FilterType.SEGMENT_FIELD:
+                class_name = "SegmentFieldFilter"
+            elif and_branch.type == FilterType.SEGMENT_METADATA:
+                class_name = "SegmentMetadataFilter"
+            else:
+                raise RuntimeError(
+                    f"Un-handled filter type: {and_branch.type}"
+                )
+            op = (
+                and_branch.op.value
+                if isinstance(and_branch.op, FilterOp)
+                else and_branch.op
+            )
+            value_formatted = (
+                f'"{and_branch.value}"'
+                if isinstance(and_branch.value, str)
+                else f"{and_branch.value}".replace("'", '"')
+            )
+            statement = (
+                f'{class_name}("{and_branch.key}", "{op}", {value_formatted})'
+            )
+            and_statements.append(statement)
+
+        or_branches.append(and_statements)
+
+    and_to_join = []
+    for and_statements in or_branches:
+        joined_and = " and ".join(and_statements)
+        if len(or_branches) > 1 and len(and_statements) > 1:
+            joined_and = "(" + joined_and + ")"
+        and_to_join.append(joined_and)
+
+    full_statement = " or ".join(and_to_join)
+    return full_statement
+
+
+def compose_helpful_filtering_error(
+    ann_or_pred_list: Union[AnnotationList, PredictionList], filters
+) -> List[str]:
+    prefix = (
+        "Annotations"
+        if isinstance(ann_or_pred_list, AnnotationList)
+        else "Predictions"
+    )
+    msg = []
+    msg.append(f"{prefix}: All items filtered out by:")
+    msg.append(f" {pretty_format_filters_with_or_and(filters)}")
+    msg.append("")
+    console = Console()
+    table = Table(
+        "Type",
+        "Count",
+        "Labels",
+        title=f"Original {prefix}",
+        title_justify="left",
+    )
+    for ann_or_pred_type, items in ann_or_pred_list.items():
+        if items and isinstance(
+            items[-1], (SegmentationAnnotation, SegmentationPrediction)
+        ):
+            labels = set()
+            for seg in items:
+                labels.update(set(s.label for s in seg.annotations))
+        else:
+            labels = set(a.label for a in items)
+        if items:
+            table.add_row(ann_or_pred_type, str(len(items)), str(list(labels)))
+    with console.capture() as capture:
+        console.print(table)
+    msg.append(capture.get())
+    return msg
+
+
+def filter_annotation_list(
+    annotations: AnnotationList, annotation_filters
+) -> AnnotationList:
+    annotations = copy.deepcopy(annotations)
+    if annotation_filters is None or len(annotation_filters) == 0:
+        return annotations
+    annotations.box_annotations = apply_filters(
+        annotations.box_annotations, annotation_filters
+    )
+    annotations.line_annotations = apply_filters(
+        annotations.line_annotations, annotation_filters
+    )
+    annotations.polygon_annotations = apply_filters(
+        annotations.polygon_annotations, annotation_filters
+    )
+    annotations.cuboid_annotations = apply_filters(
+        annotations.cuboid_annotations, annotation_filters
+    )
+    annotations.category_annotations = apply_filters(
+        annotations.category_annotations, annotation_filters
+    )
+    annotations.multi_category_annotations = apply_filters(
+        annotations.multi_category_annotations, annotation_filters
+    )
+    annotations.segmentation_annotations = apply_filters(
+        annotations.segmentation_annotations, annotation_filters
+    )
+    return annotations
+
+
+def filter_prediction_list(
+    predictions: PredictionList, prediction_filters
+) -> PredictionList:
+    predictions = copy.deepcopy(predictions)
+    if prediction_filters is None or len(prediction_filters) == 0:
+        return predictions
+    predictions.box_predictions = apply_filters(
+        predictions.box_predictions, prediction_filters
+    )
+    predictions.line_predictions = apply_filters(
+        predictions.line_predictions, prediction_filters
+    )
+    predictions.polygon_predictions = apply_filters(
+        predictions.polygon_predictions, prediction_filters
+    )
+    predictions.cuboid_predictions = apply_filters(
+        predictions.cuboid_predictions, prediction_filters
+    )
+    predictions.category_predictions = apply_filters(
+        predictions.category_predictions, prediction_filters
+    )
+    predictions.segmentation_predictions = apply_filters(
+        predictions.segmentation_predictions, prediction_filters
+    )
+    return predictions
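A short sketch of the new list-level helpers (hypothetical data): filter_annotation_list deep-copies its input before applying the filters, so the caller's AnnotationList is left untouched, and compose_helpful_filtering_error produces the summary that EverythingFilteredError carries.

from nucleus.annotation import AnnotationList, BoxAnnotation
from nucleus.metrics.filtering import (
    FieldFilter,
    compose_helpful_filtering_error,
    filter_annotation_list,
)

anns = AnnotationList(
    box_annotations=[
        BoxAnnotation(label="car", x=0, y=0, width=10, height=10, reference_id="img_1"),
        BoxAnnotation(label="tree", x=5, y=5, width=4, height=8, reference_id="img_1"),
    ]
)

kept = filter_annotation_list(anns, [FieldFilter("label", "==", "car")])
assert len(kept.box_annotations) == 1   # only the "car" box survives
assert len(anns.box_annotations) == 2   # original list is not mutated

# If a filter removes everything, this is the summary the error message is built from.
person_only = [FieldFilter("label", "==", "person")]
empty = filter_annotation_list(anns, person_only)
if len(empty) == 0:
    print("\n".join(compose_helpful_filtering_error(anns, person_only)))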
nucleus/metrics/polygon_utils.py CHANGED
@@ -12,10 +12,10 @@ from .custom_types import BoxOrPolygonAnnotation, BoxOrPolygonPrediction
 
 try:
     from shapely.geometry import Polygon
-except ModuleNotFoundError:
-    from ..shapely_not_installed import ShapelyNotInstalled
+except (ModuleNotFoundError, OSError):
+    from ..package_not_installed import PackageNotInstalled
 
-    Polygon = ShapelyNotInstalled
+    Polygon = PackageNotInstalled
 
 
 from .base import ScalarResult