scale-nucleus 0.11b2.tar.gz → 0.12b9.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/LICENSE +0 -0
  2. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/PKG-INFO +2 -2
  3. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/README.md +0 -0
  4. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/cli/client.py +0 -0
  5. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/cli/datasets.py +0 -0
  6. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/cli/helpers/__init__.py +0 -0
  7. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/cli/helpers/nucleus_url.py +0 -0
  8. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/cli/helpers/web_helper.py +0 -0
  9. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/cli/install_completion.py +0 -0
  10. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/cli/jobs.py +0 -0
  11. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/cli/models.py +0 -0
  12. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/cli/nu.py +0 -0
  13. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/cli/reference.py +0 -0
  14. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/cli/slices.py +0 -0
  15. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/cli/tests.py +0 -0
  16. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/__init__.py +4 -0
  17. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/annotation.py +0 -0
  18. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/annotation_uploader.py +22 -0
  19. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/async_utils.py +0 -0
  20. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/autocurate.py +0 -0
  21. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/camera_params.py +0 -0
  22. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/connection.py +0 -0
  23. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/constants.py +0 -0
  24. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/data_transfer_object/__init__.py +0 -0
  25. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/data_transfer_object/dataset_details.py +0 -0
  26. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/data_transfer_object/dataset_info.py +0 -0
  27. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/data_transfer_object/dataset_size.py +0 -0
  28. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/data_transfer_object/scenes_list.py +0 -0
  29. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/dataset.py +60 -16
  30. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/dataset_item.py +0 -0
  31. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/dataset_item_uploader.py +0 -0
  32. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/deprecation_warning.py +0 -0
  33. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/errors.py +6 -0
  34. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/job.py +0 -0
  35. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/logger.py +0 -0
  36. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/metadata_manager.py +0 -0
  37. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/metrics/__init__.py +11 -0
  38. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/metrics/base.py +0 -0
  39. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/metrics/categorization_metrics.py +0 -0
  40. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/metrics/cuboid_metrics.py +0 -0
  41. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/metrics/cuboid_utils.py +0 -0
  42. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/metrics/custom_types.py +0 -0
  43. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/metrics/errors.py +0 -0
  44. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/metrics/filtering.py +211 -19
  45. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/metrics/filters.py +0 -0
  46. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/metrics/metric_utils.py +0 -0
  47. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/metrics/polygon_metrics.py +0 -0
  48. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/metrics/polygon_utils.py +0 -0
  49. scale-nucleus-0.12b9/nucleus/metrics/segmentation_metrics.py +621 -0
  50. scale-nucleus-0.11b2/nucleus/metrics/segmentation_metrics.py → scale-nucleus-0.12b9/nucleus/metrics/segmentation_to_poly_metrics.py +19 -11
  51. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/metrics/segmentation_utils.py +0 -0
  52. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/model.py +0 -0
  53. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/model_run.py +6 -3
  54. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/payload_constructor.py +0 -0
  55. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/prediction.py +0 -0
  56. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/pydantic_base.py +0 -0
  57. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/quaternion.py +0 -0
  58. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/retry_strategy.py +0 -0
  59. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/scene.py +0 -0
  60. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/shapely_not_installed.py +0 -0
  61. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/slice.py +0 -0
  62. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/upload_response.py +0 -0
  63. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/url_utils.py +0 -0
  64. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/utils.py +0 -0
  65. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/validate/__init__.py +0 -0
  66. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/validate/client.py +0 -0
  67. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/validate/constants.py +0 -0
  68. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/validate/data_transfer_objects/__init__.py +0 -0
  69. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/validate/data_transfer_objects/eval_function.py +0 -0
  70. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/validate/data_transfer_objects/scenario_test.py +0 -0
  71. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/validate/data_transfer_objects/scenario_test_evaluations.py +0 -0
  72. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/validate/data_transfer_objects/scenario_test_metric.py +0 -0
  73. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/validate/errors.py +0 -0
  74. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/validate/eval_functions/__init__.py +0 -0
  75. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/validate/eval_functions/available_eval_functions.py +85 -0
  76. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/validate/eval_functions/base_eval_function.py +0 -0
  77. scale-nucleus-0.12b9/nucleus/validate/eval_functions/config_classes/__init__.py +0 -0
  78. scale-nucleus-0.12b9/nucleus/validate/eval_functions/config_classes/segmentation.py +319 -0
  79. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/validate/scenario_test.py +0 -0
  80. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/validate/scenario_test_evaluation.py +0 -0
  81. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/validate/scenario_test_metric.py +0 -0
  82. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/nucleus/validate/utils.py +0 -0
  83. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/pyproject.toml +2 -2
  84. {scale-nucleus-0.11b2 → scale-nucleus-0.12b9}/setup.py +4 -3
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: scale-nucleus
-Version: 0.11b2
+Version: 0.12b9
 Summary: The official Python client library for Nucleus, the Data Platform for AI
 Home-page: https://scale.com/nucleus
 License: MIT
@@ -14,7 +14,7 @@ Classifier: Programming Language :: Python :: 3.7
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Programming Language :: Python :: 3.9
 Provides-Extra: shapely
-Requires-Dist: Pillow (>=8.3.1)
+Requires-Dist: Pillow (>=7.1.2)
 Requires-Dist: Shapely (>=1.8.0); extra == "shapely"
 Requires-Dist: aiohttp (>=3.7.4,<4.0.0)
 Requires-Dist: click (>=7.1.2,<9.0)
nucleus/__init__.py
@@ -841,6 +841,10 @@ class NucleusClient:
         if payload is None:
             payload = {}
         if requests_command is requests.get:
+            if payload:
+                print(
+                    "Received defined payload with GET request! Will ignore payload"
+                )
             payload = None
         return self._connection.make_request(payload, route, requests_command)  # type: ignore
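The guard above only warns and then drops the payload; GET requests still go out with payload=None. A minimal sketch of the new behavior (the API key and route below are placeholders, not taken from this diff):

    import requests
    from nucleus import NucleusClient

    client = NucleusClient("test_api_key")
    # Passing a non-empty payload with requests.get now prints
    # "Received defined payload with GET request! Will ignore payload"
    # and the request is sent with payload=None.
    client.make_request({"verbose": True}, "dataset/ds_123/info", requests_command=requests.get)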
nucleus/annotation_uploader.py
@@ -1,4 +1,5 @@
 import json
+from collections import Counter
 from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence

 from nucleus.annotation import Annotation, SegmentationAnnotation
@@ -8,6 +9,7 @@ from nucleus.async_utils import (
     make_many_form_data_requests_concurrently,
 )
 from nucleus.constants import MASK_TYPE, SERIALIZED_REQUEST_KEY
+from nucleus.errors import DuplicateIDError
 from nucleus.payload_constructor import (
     construct_annotation_payload,
     construct_segmentation_payload,
@@ -208,6 +210,26 @@ class AnnotationUploader:

         return fn

+    @staticmethod
+    def check_for_duplicate_ids(annotations: Iterable[Annotation]):
+        """Do not allow annotations to have the same (annotation_id, reference_id) tuple"""
+
+        # some annotations like CategoryAnnotation do not have annotation_id attribute, and as such, we allow duplicates
+        tuple_ids = [
+            (ann.reference_id, ann.annotation_id)  # type: ignore
+            for ann in annotations
+            if hasattr(ann, "annotation_id")
+        ]
+        tuple_count = Counter(tuple_ids)
+        duplicates = {key for key, value in tuple_count.items() if value > 1}
+        if len(duplicates) > 0:
+            raise DuplicateIDError(
+                f"Duplicate annotations with the same (reference_id, annotation_id) properties found.\n"
+                f"Duplicates: {duplicates}\n"
+                f"To fix this, avoid duplicate annotations, or specify a different annotation_id attribute "
+                f"for the failing items."
+            )
+

 class PredictionUploader(AnnotationUploader):
     def __init__(
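As a usage sketch of the new check (all IDs below are invented), two annotations sharing a (reference_id, annotation_id) pair are now rejected up front:

    from nucleus import BoxAnnotation, NucleusClient
    from nucleus.errors import DuplicateIDError

    client = NucleusClient("test_api_key")
    dataset = client.get_dataset("ds_sample")  # placeholder dataset ID
    box = dict(label="car", x=0, y=0, width=10, height=10,
               reference_id="img_1", annotation_id="ann_1")
    try:
        dataset.annotate([BoxAnnotation(**box), BoxAnnotation(**box)])
    except DuplicateIDError as err:
        print(err.message)  # lists the duplicated (reference_id, annotation_id) tuples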
nucleus/dataset.py
@@ -389,6 +389,9 @@ class Dataset:

         Otherwise, returns an :class:`AsyncJob` object.
         """
+        uploader = AnnotationUploader(dataset_id=self.id, client=self._client)
+        uploader.check_for_duplicate_ids(annotations)
+
         if asynchronous:
             check_all_mask_paths_remote(annotations)
             request_id = serialize_and_write_to_presigned_url(
@@ -399,7 +402,7 @@
                 route=f"dataset/{self.id}/annotate?async=1",
             )
             return AsyncJob.from_json(response, self._client)
-        uploader = AnnotationUploader(dataset_id=self.id, client=self._client)
+
         return uploader.upload(
             annotations=annotations,
             update=update,
@@ -1004,6 +1007,45 @@ class Dataset:

         return response

+    def get_image_indexing_status(self):
+        """Gets the primary image index progress for the dataset.
+
+        Returns:
+            Response payload::
+
+                {
+                    "embedding_count": int
+                    "image_count": int
+                    "percent_indexed": float
+                    "additional_context": str
+                }
+        """
+        return self._client.make_request(
+            {"image": True},
+            f"dataset/{self.id}/indexingStatus",
+            requests_command=requests.post,
+        )
+
+    def get_object_indexing_status(self, model_run_id=None):
+        """Gets the primary object index progress of the dataset.
+        If model_run_id is not specified, this endpoint will retrieve the indexing progress of the ground truth objects.
+
+        Returns:
+            Response payload::
+
+                {
+                    "embedding_count": int
+                    "object_count": int
+                    "percent_indexed": float
+                    "additional_context": str
+                }
+        """
+        return self._client.make_request(
+            {"image": False, "model_run_id": model_run_id},
+            f"dataset/{self.id}/indexingStatus",
+            requests_command=requests.post,
+        )
+
     def create_image_index(self):
         """Creates or updates image index by generating embeddings for images that do not already have embeddings.

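A short sketch of the two new status methods (IDs are placeholders); both POST to the same indexingStatus route and return the payloads documented in the docstrings above:

    from nucleus import NucleusClient

    client = NucleusClient("test_api_key")
    dataset = client.get_dataset("ds_sample")
    print(dataset.get_image_indexing_status()["percent_indexed"])
    gt_status = dataset.get_object_indexing_status()            # ground truth objects
    run_status = dataset.get_object_indexing_status("run_abc")  # objects of one model run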
@@ -1405,6 +1447,14 @@ class Dataset:
                 "predictions_ignored": int,
             }
         """
+        uploader = PredictionUploader(
+            model_run_id=None,
+            dataset_id=self.id,
+            model_id=model.id,
+            client=self._client,
+        )
+        uploader.check_for_duplicate_ids(predictions)
+
         if asynchronous:
             check_all_mask_paths_remote(predictions)

@@ -1416,21 +1466,15 @@
                 route=f"dataset/{self.id}/model/{model.id}/uploadPredictions?async=1",
             )
             return AsyncJob.from_json(response, self._client)
-        else:
-            uploader = PredictionUploader(
-                model_run_id=None,
-                dataset_id=self.id,
-                model_id=model.id,
-                client=self._client,
-            )
-            return uploader.upload(
-                annotations=predictions,
-                batch_size=batch_size,
-                update=update,
-                remote_files_per_upload_request=remote_files_per_upload_request,
-                local_files_per_upload_request=local_files_per_upload_request,
-                local_file_upload_concurrency=local_file_upload_concurrency,
-            )
+
+        return uploader.upload(
+            annotations=predictions,
+            batch_size=batch_size,
+            update=update,
+            remote_files_per_upload_request=remote_files_per_upload_request,
+            local_files_per_upload_request=local_files_per_upload_request,
+            local_file_upload_concurrency=local_file_upload_concurrency,
+        )

     def predictions_iloc(self, model, index):
         """Fetches all predictions of a dataset item by its absolute index.
nucleus/errors.py
@@ -72,3 +72,9 @@ class NoAPIKey(Exception):
     ):
         self.message = message
         super().__init__(self.message)
+
+
+class DuplicateIDError(Exception):
+    def __init__(self, message):
+        self.message = message
+        super().__init__(self.message)
nucleus/metrics/__init__.py
@@ -5,6 +5,8 @@ from .filtering import (
     FieldFilter,
     ListOfOrAndFilters,
     MetadataFilter,
+    SegmentFieldFilter,
+    SegmentMetadataFilter,
     apply_filters,
 )
 from .polygon_metrics import (
@@ -16,6 +18,15 @@ from .polygon_metrics import (
     PolygonRecall,
 )
 from .segmentation_metrics import (
+    SegmentationAveragePrecision,
+    SegmentationFWAVACC,
+    SegmentationIOU,
+    SegmentationMAP,
+    SegmentationMaskMetric,
+    SegmentationPrecision,
+    SegmentationRecall,
+)
+from .segmentation_to_poly_metrics import (
     SegmentationMaskToPolyMetric,
     SegmentationToPolyAveragePrecision,
     SegmentationToPolyIOU,
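With these exports, the new mask-based metrics can be imported directly from nucleus.metrics; a quick sketch (default constructor arguments are an assumption, not verified in this diff):

    from nucleus.metrics import SegmentationIOU, SegmentationPrecision, SegmentationRecall

    iou = SegmentationIOU()  # assumption: the constructor accepts defaults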
nucleus/metrics/filtering.py
@@ -2,7 +2,16 @@ import enum
 import functools
 import logging
 from enum import Enum
-from typing import Callable, Iterable, List, NamedTuple, Sequence, Set, Union
+from typing import (
+    Callable,
+    Iterable,
+    List,
+    NamedTuple,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)

 from nucleus.annotation import (
     BoxAnnotation,
@@ -11,6 +20,7 @@ from nucleus.annotation import (
     LineAnnotation,
     MultiCategoryAnnotation,
     PolygonAnnotation,
+    Segment,
     SegmentationAnnotation,
 )
 from nucleus.prediction import (
@@ -40,10 +50,14 @@ class FilterType(str, enum.Enum):
     Attributes:
         FIELD: Access the attribute field of an object
         METADATA: Access the metadata dictionary of an object
+        SEGMENT_FIELD: Filter segments of a segmentation mask to be considered on segment fields
+        SEGMENT_METADATA: Filter segments of a segmentation mask based on segment metadata
     """

     FIELD = "field"
     METADATA = "metadata"
+    SEGMENT_FIELD = "segment_field"
+    SEGMENT_METADATA = "segment_metadata"


 FilterableBaseVals = Union[str, float, int, bool]
@@ -100,7 +114,8 @@ class FieldFilter(NamedTuple):

     Examples:
         FieldFilter("x", ">", 10) would pass every :class:`BoxAnnotation` with `x` attribute larger than 10
-        FieldFilter("label", "in", [) would pass every :class:`BoxAnnotation` with `x` attribute larger than 10
+        FieldFilter("label", "in", ["car", "truck"]) would pass every :class:`BoxAnnotation` with `label`
+            in ["car", "truck"]

     Attributes:
         key: key to compare with value
@@ -129,7 +144,7 @@ class MetadataFilter(NamedTuple):
             with value field
         value: bool, str, float or int to compare the field with key or list of the same values for 'in' and 'not in'
             ops
-        allow_missing: Allow missing metada values. Will REMOVE the object with the missing field from the selection
+        allow_missing: Allow missing metadata values. Will REMOVE the object with the missing field from the selection
         type: DO NOT USE. Internal type for serialization over the wire. Changing this will change the `NamedTuple`
             type as well.
     """
@@ -141,7 +156,60 @@ class MetadataFilter(NamedTuple):
     type: FilterType = FilterType.METADATA


-Filter = Union[FieldFilter, MetadataFilter, AnnotationOrPredictionFilter]
+class SegmentMetadataFilter(NamedTuple):
+    """Filter on customer provided metadata associated with Segments of a SegmentationAnnotation or
+    SegmentationPrediction
+
+    Attributes:
+        key: key to compare with value
+        op: :class:`FilterOp` or one of [">", ">=", "<", "<=", "=", "==", "!=", "in", "not in"] to define comparison
+            with value field
+        value: bool, str, float or int to compare the field with key or list of the same values for 'in' and 'not in'
+            ops
+        allow_missing: Allow missing metadata values. Will REMOVE the object with the missing field from the selection
+        type: DO NOT USE. Internal type for serialization over the wire. Changing this will change the `NamedTuple`
+            type as well.
+    """
+
+    key: str
+    op: Union[FilterOp, str]
+    value: FilterableTypes
+    allow_missing: bool = False
+    type: FilterType = FilterType.SEGMENT_METADATA
+
+
+class SegmentFieldFilter(NamedTuple):
+    """Filter on standard field of Segment(s) of SegmentationAnnotation and SegmentationPrediction
+
+    Examples:
+        SegmentFieldFilter("label", "in", ["grass", "tree"]) would pass every :class:`Segment` of a
+        :class:`SegmentationAnnotation or :class:`SegmentationPrediction`
+
+    Attributes:
+        key: key to compare with value
+        op: :class:`FilterOp` or one of [">", ">=", "<", "<=", "=", "==", "!=", "in", "not in"] to define comparison
+            with value field
+        value: bool, str, float or int to compare the field with key or list of the same values for 'in' and 'not in'
+            ops
+        allow_missing: Allow missing field values. Will REMOVE the object with the missing field from the selection
+        type: DO NOT USE. Internal type for serialization over the wire. Changing this will change the `NamedTuple`
+            type as well.
+    """
+
+    key: str
+    op: Union[FilterOp, str]
+    value: FilterableTypes
+    allow_missing: bool = False
+    type: FilterType = FilterType.SEGMENT_FIELD
+
+
+Filter = Union[
+    FieldFilter,
+    MetadataFilter,
+    AnnotationOrPredictionFilter,
+    SegmentFieldFilter,
+    SegmentMetadataFilter,
+]
 OrAndDNFFilters = List[List[Filter]]
 OrAndDNFFilters.__doc__ = """\
 Disjunctive normal form (DNF) filters.
@@ -182,11 +250,15 @@ ListOfAndFilters = Union[
     ListOfAndJSONSerialized,
 ]

+DNFFieldOrMetadataFilters = List[
+    List[Union[FieldFilter, MetadataFilter, AnnotationOrPredictionFilter]]
+]
+

 def _attribute_getter(
     field_name: str,
     allow_missing: bool,
-    ann_or_pred: Union[AnnotationTypes, PredictionTypes],
+    ann_or_pred: Union[AnnotationTypes, PredictionTypes, Segment],
 ):
     """Create a function to get object fields"""
     if allow_missing:
@@ -224,7 +296,7 @@ class AlwaysFalseComparison:
 def _metadata_field_getter(
     field_name: str,
     allow_missing: bool,
-    ann_or_pred: Union[AnnotationTypes, PredictionTypes],
+    ann_or_pred: Union[AnnotationTypes, PredictionTypes, Segment],
 ):
     """Create a function to get a metadata field"""
     if isinstance(
@@ -259,7 +331,7 @@ def _metadata_field_getter(


 def _filter_to_comparison_function(  # pylint: disable=too-many-return-statements
     filter_def: Filter,
-) -> Callable[[Union[AnnotationTypes, PredictionTypes]], bool]:
+) -> Callable[[Union[AnnotationTypes, PredictionTypes, Segment]], bool]:
     """Creates a comparison function from a filter configuration to apply to annotations or predictions

     Parameters:
@@ -276,6 +348,10 @@ def _filter_to_comparison_function(  # pylint: disable=too-many-return-statemen
         getter = functools.partial(
             _metadata_field_getter, filter_def.key, filter_def.allow_missing
         )
+    else:
+        raise NotImplementedError(
+            f"Unhandled filter type: {filter_def.type}. NOTE: Segmentation filters are handled elsewhere."
+        )
     op = FilterOp(filter_def.op)
     if op is FilterOp.GT:
         return lambda ann_or_pred: getter(ann_or_pred) > filter_def.value
@@ -303,13 +379,16 @@ def _filter_to_comparison_function(  # pylint: disable=too-many-return-statemen
     )


-def apply_filters(
-    ann_or_pred: Union[Sequence[AnnotationTypes], Sequence[PredictionTypes]],
-    filters: Union[ListOfOrAndFilters, ListOfAndFilters],
+def _apply_field_or_metadata_filters(
+    filterable_sequence: Union[
+        Sequence[AnnotationTypes], Sequence[PredictionTypes], Sequence[Segment]
+    ],
+    filters: DNFFieldOrMetadataFilters,
 ):
-    """Apply filters to list of annotations or list of predictions
+    """Apply filters to list of annotations or list of predictions or to a list of segments
+
     Attributes:
-        ann_or_pred: Prediction or Annotation
+        filterable_sequence: Prediction or Annotation or Segment sequence
         filters: Filter predicates. Allowed formats are:
             ListOfAndFilters where each Filter forms a chain of AND predicates.
             or
@@ -320,11 +399,6 @@
             is interpreted as a conjunction (AND), forming a more selective `and` multiple column predicate.
             Finally, the most outer list combines these filters as a disjunction (OR).
     """
-    if filters is None or len(filters) == 0:
-        return ann_or_pred
-
-    filters = ensureDNFFilters(filters)
-
     dnf_condition_functions = []
     for or_branch in filters:
         and_conditions = [
@@ -333,18 +407,136 @@
         dnf_condition_functions.append(and_conditions)

     filtered = []
-    for item in ann_or_pred:
+    for item in filterable_sequence:
         for or_conditions in dnf_condition_functions:
             if all(c(item) for c in or_conditions):
                 filtered.append(item)
                 break
+
+    return filtered
+
+
+def _split_segment_filters(
+    dnf_filters: OrAndDNFFilters,
+) -> Tuple[OrAndDNFFilters, OrAndDNFFilters]:
+    """We treat Segment* filters differently -> this splits filters into two sets, one containing the
+    standard field, metadata branches and the other the segment filters.
+    """
+    normal_or_branches = []
+    segment_or_branches = []
+    for and_branch in dnf_filters:
+        normal_filters = []
+        segment_filters = []
+        for filter_statement in and_branch:
+            if filter_statement.type in {
+                FilterType.SEGMENT_METADATA,
+                FilterType.SEGMENT_FIELD,
+            }:
+                segment_filters.append(filter_statement)
+            else:
+                normal_filters.append(filter_statement)
+        normal_or_branches.append(normal_filters)
+        segment_or_branches.append(segment_filters)
+    return normal_or_branches, segment_or_branches
+
+
+def _filter_segments(
+    anns_or_preds: Union[
+        Sequence[SegmentationAnnotation], Sequence[SegmentationPrediction]
+    ],
+    segment_filters: OrAndDNFFilters,
+):
+    """Filter Segments of a SegmentationAnnotation or Prediction
+
+    We have to treat this differently as metadata and labels are on nested Segment objects
+    """
+    if len(segment_filters) == 0 or len(segment_filters[0]) == 0:
+        return anns_or_preds
+
+    # Transform segment filter types to field and metadata to iterate over annotation sub fields
+    transformed_or_branches = (
+        []
+    )  # type: List[List[Union[MetadataFilter, FieldFilter]]]
+    for and_branch in segment_filters:
+        transformed_and = []  # type: List[Union[MetadataFilter, FieldFilter]]
+        for filter_statement in and_branch:
+            if filter_statement.type == FilterType.SEGMENT_FIELD:
+                transformed_and.append(
+                    FieldFilter(
+                        filter_statement.key,
+                        filter_statement.op,
+                        filter_statement.value,
+                        filter_statement.allow_missing,
+                    )
+                )
+            elif filter_statement.type == FilterType.SEGMENT_METADATA:
+                transformed_and.append(
+                    MetadataFilter(
+                        filter_statement.key,
+                        filter_statement.op,
+                        filter_statement.value,
+                        filter_statement.allow_missing,
+                    )
+                )
+            else:
+                raise RuntimeError("Encountered a non SEGMENT_* filter type")
+
+        transformed_or_branches.append(transformed_and)
+
+    segments_filtered = []
+    for ann_or_pred in anns_or_preds:
+        if isinstance(
+            ann_or_pred, (SegmentationAnnotation, SegmentationPrediction)
+        ):
+            ann_or_pred.annotations = _apply_field_or_metadata_filters(
+                ann_or_pred.annotations, transformed_or_branches  # type: ignore
+            )
+            segments_filtered.append(ann_or_pred)
+
+    return segments_filtered
+
+
+def apply_filters(
+    ann_or_pred: Union[Sequence[AnnotationTypes], Sequence[PredictionTypes]],
+    filters: Union[ListOfOrAndFilters, ListOfAndFilters],
+):
+    """Apply filters to list of annotations or list of predictions
+    Attributes:
+        ann_or_pred: Prediction or Annotation
+        filters: Filter predicates. Allowed formats are:
+            ListOfAndFilters where each Filter forms a chain of AND predicates.
+            or
+            ListOfOrAndFilters where Filters are expressed in disjunctive normal form (DNF), like
+            [[MetadataFilter("short_haired", "==", True), FieldFilter("label", "in", ["cat", "dog"]), ...].
+            DNF allows arbitrary boolean logical combinations of single field
+            predicates. The innermost structures each describe a single column predicate. The list of inner predicates
+            is interpreted as a conjunction (AND), forming a more selective `and` multiple column predicate.
+            Finally, the most outer list combines these filters as a disjunction (OR).
+    """
+    if filters is None or len(filters) == 0:
+        return ann_or_pred
+
+    dnf_filters = ensureDNFFilters(filters)
+    filters, segment_filters = _split_segment_filters(dnf_filters)
+    filtered = _apply_field_or_metadata_filters(ann_or_pred, filters)  # type: ignore
+    filtered = _filter_segments(filtered, segment_filters)
+
     return filtered


 def ensureDNFFilters(filters) -> OrAndDNFFilters:
     """JSON encoding creates a triple nested lists from the doubly nested tuples. This function creates the
     tuple form again."""
-    if isinstance(filters[0], (MetadataFilter, FieldFilter)):
+    if isinstance(
+        filters[0],
+        (
+            MetadataFilter,
+            FieldFilter,
+            AnnotationOrPredictionFilter,
+            SegmentFieldFilter,
+            SegmentMetadataFilter,
+        ),
+    ):
         # Normalize into DNF
         filters: ListOfOrAndFilters = [filters]  # type: ignore
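Putting the filtering changes together, a hedged end-to-end sketch (labels and IDs are invented): apply_filters first splits the DNF filters, applies the field/metadata filters to the annotations themselves, then prunes the nested Segment lists in place:

    from nucleus.annotation import Segment, SegmentationAnnotation
    from nucleus.metrics.filtering import SegmentFieldFilter, apply_filters

    ann = SegmentationAnnotation(
        mask_url="s3://bucket/mask.png",
        annotations=[
            Segment(label="grass", index=1),
            Segment(label="car", index=2),
        ],
        reference_id="img_1",
    )
    # A flat list is normalized by ensureDNFFilters; the segment filter prunes Segments.
    filtered = apply_filters([ann], [SegmentFieldFilter("label", "in", ["grass", "tree"])])
    print([seg.label for seg in filtered[0].annotations])  # ["grass"]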