scale-nucleus 0.12b1__py3-none-any.whl → 0.14.14b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/slices.py +14 -28
- nucleus/__init__.py +211 -18
- nucleus/annotation.py +28 -5
- nucleus/connection.py +9 -1
- nucleus/constants.py +9 -3
- nucleus/dataset.py +197 -59
- nucleus/dataset_item.py +11 -1
- nucleus/job.py +1 -1
- nucleus/metrics/__init__.py +2 -1
- nucleus/metrics/base.py +34 -56
- nucleus/metrics/categorization_metrics.py +6 -2
- nucleus/metrics/cuboid_utils.py +4 -6
- nucleus/metrics/errors.py +4 -0
- nucleus/metrics/filtering.py +369 -19
- nucleus/metrics/polygon_utils.py +3 -3
- nucleus/metrics/segmentation_loader.py +30 -0
- nucleus/metrics/segmentation_metrics.py +256 -195
- nucleus/metrics/segmentation_to_poly_metrics.py +229 -105
- nucleus/metrics/segmentation_utils.py +239 -8
- nucleus/model.py +66 -10
- nucleus/model_run.py +1 -1
- nucleus/{shapely_not_installed.py → package_not_installed.py} +3 -3
- nucleus/payload_constructor.py +4 -0
- nucleus/prediction.py +6 -3
- nucleus/scene.py +7 -0
- nucleus/slice.py +160 -16
- nucleus/utils.py +51 -12
- nucleus/validate/__init__.py +1 -0
- nucleus/validate/client.py +57 -8
- nucleus/validate/constants.py +1 -0
- nucleus/validate/data_transfer_objects/eval_function.py +22 -0
- nucleus/validate/data_transfer_objects/scenario_test_evaluations.py +13 -5
- nucleus/validate/eval_functions/available_eval_functions.py +33 -20
- nucleus/validate/eval_functions/config_classes/segmentation.py +2 -46
- nucleus/validate/scenario_test.py +71 -13
- nucleus/validate/scenario_test_evaluation.py +21 -21
- nucleus/validate/utils.py +1 -1
- {scale_nucleus-0.12b1.dist-info → scale_nucleus-0.14.14b0.dist-info}/LICENSE +0 -0
- {scale_nucleus-0.12b1.dist-info → scale_nucleus-0.14.14b0.dist-info}/METADATA +13 -11
- {scale_nucleus-0.12b1.dist-info → scale_nucleus-0.14.14b0.dist-info}/RECORD +42 -41
- {scale_nucleus-0.12b1.dist-info → scale_nucleus-0.14.14b0.dist-info}/WHEEL +1 -1
- {scale_nucleus-0.12b1.dist-info → scale_nucleus-0.14.14b0.dist-info}/entry_points.txt +0 -0
nucleus/metrics/base.py
CHANGED
@@ -4,10 +4,13 @@ from dataclasses import dataclass
|
|
4
4
|
from typing import Iterable, List, Optional, Union
|
5
5
|
|
6
6
|
from nucleus.annotation import AnnotationList
|
7
|
+
from nucleus.metrics.errors import EverythingFilteredError
|
7
8
|
from nucleus.metrics.filtering import (
|
8
9
|
ListOfAndFilters,
|
9
10
|
ListOfOrAndFilters,
|
10
|
-
|
11
|
+
compose_helpful_filtering_error,
|
12
|
+
filter_annotation_list,
|
13
|
+
filter_prediction_list,
|
11
14
|
)
|
12
15
|
from nucleus.prediction import PredictionList
|
13
16
|
|
@@ -133,64 +136,16 @@ class Metric(ABC):
|
|
133
136
|
def __call__(
|
134
137
|
self, annotations: AnnotationList, predictions: PredictionList
|
135
138
|
) -> MetricResult:
|
136
|
-
|
137
|
-
|
138
|
-
return self.call_metric(annotations, predictions)
|
139
|
-
|
140
|
-
def _filter_annotations(self, annotations: AnnotationList):
|
141
|
-
if (
|
142
|
-
self.annotation_filters is None
|
143
|
-
or len(self.annotation_filters) == 0
|
144
|
-
):
|
145
|
-
return annotations
|
146
|
-
annotations.box_annotations = apply_filters(
|
147
|
-
annotations.box_annotations, self.annotation_filters
|
139
|
+
filtered_anns = filter_annotation_list(
|
140
|
+
annotations, self.annotation_filters
|
148
141
|
)
|
149
|
-
|
150
|
-
|
142
|
+
filtered_preds = filter_prediction_list(
|
143
|
+
predictions, self.prediction_filters
|
151
144
|
)
|
152
|
-
|
153
|
-
annotations
|
145
|
+
self._raise_if_everything_filtered(
|
146
|
+
annotations, filtered_anns, predictions, filtered_preds
|
154
147
|
)
|
155
|
-
|
156
|
-
annotations.cuboid_annotations, self.annotation_filters
|
157
|
-
)
|
158
|
-
annotations.category_annotations = apply_filters(
|
159
|
-
annotations.category_annotations, self.annotation_filters
|
160
|
-
)
|
161
|
-
annotations.multi_category_annotations = apply_filters(
|
162
|
-
annotations.multi_category_annotations, self.annotation_filters
|
163
|
-
)
|
164
|
-
annotations.segmentation_annotations = apply_filters(
|
165
|
-
annotations.segmentation_annotations, self.annotation_filters
|
166
|
-
)
|
167
|
-
return annotations
|
168
|
-
|
169
|
-
def _filter_predictions(self, predictions: PredictionList):
|
170
|
-
if (
|
171
|
-
self.prediction_filters is None
|
172
|
-
or len(self.prediction_filters) == 0
|
173
|
-
):
|
174
|
-
return predictions
|
175
|
-
predictions.box_predictions = apply_filters(
|
176
|
-
predictions.box_predictions, self.prediction_filters
|
177
|
-
)
|
178
|
-
predictions.line_predictions = apply_filters(
|
179
|
-
predictions.line_predictions, self.prediction_filters
|
180
|
-
)
|
181
|
-
predictions.polygon_predictions = apply_filters(
|
182
|
-
predictions.polygon_predictions, self.prediction_filters
|
183
|
-
)
|
184
|
-
predictions.cuboid_predictions = apply_filters(
|
185
|
-
predictions.cuboid_predictions, self.prediction_filters
|
186
|
-
)
|
187
|
-
predictions.category_predictions = apply_filters(
|
188
|
-
predictions.category_predictions, self.prediction_filters
|
189
|
-
)
|
190
|
-
predictions.segmentation_predictions = apply_filters(
|
191
|
-
predictions.segmentation_predictions, self.prediction_filters
|
192
|
-
)
|
193
|
-
return predictions
|
148
|
+
return self.call_metric(filtered_anns, filtered_preds)
|
194
149
|
|
195
150
|
@abstractmethod
|
196
151
|
def aggregate_score(self, results: List[MetricResult]) -> ScalarResult:
|
@@ -215,3 +170,26 @@ class Metric(ABC):
|
|
215
170
|
return ScalarResult(r2_score)
|
216
171
|
|
217
172
|
"""
|
173
|
+
|
174
|
+
def _raise_if_everything_filtered(
|
175
|
+
self,
|
176
|
+
annotations: AnnotationList,
|
177
|
+
filtered_annotations: AnnotationList,
|
178
|
+
predictions: PredictionList,
|
179
|
+
filtered_predictions: PredictionList,
|
180
|
+
):
|
181
|
+
msg = []
|
182
|
+
if len(filtered_annotations) == 0:
|
183
|
+
msg.extend(
|
184
|
+
compose_helpful_filtering_error(
|
185
|
+
annotations, self.annotation_filters
|
186
|
+
)
|
187
|
+
)
|
188
|
+
if len(filtered_predictions) == 0:
|
189
|
+
msg.extend(
|
190
|
+
compose_helpful_filtering_error(
|
191
|
+
predictions, self.prediction_filters
|
192
|
+
)
|
193
|
+
)
|
194
|
+
if msg:
|
195
|
+
raise EverythingFilteredError("\n".join(msg))
|
@@ -2,8 +2,6 @@ from abc import abstractmethod
|
|
2
2
|
from dataclasses import dataclass
|
3
3
|
from typing import List, Optional, Set, Tuple, Union
|
4
4
|
|
5
|
-
from sklearn.metrics import f1_score
|
6
|
-
|
7
5
|
from nucleus.annotation import AnnotationList, CategoryAnnotation
|
8
6
|
from nucleus.metrics.base import Metric, MetricResult, ScalarResult
|
9
7
|
from nucleus.metrics.filtering import ListOfAndFilters, ListOfOrAndFilters
|
@@ -35,6 +33,9 @@ class CategorizationResult(MetricResult):
|
|
35
33
|
|
36
34
|
@property
|
37
35
|
def value(self):
|
36
|
+
# late import to avoid slow CLI init
|
37
|
+
from sklearn.metrics import f1_score
|
38
|
+
|
38
39
|
annotation_labels = to_taxonomy_labels(self.annotations)
|
39
40
|
prediction_labels = to_taxonomy_labels(self.predictions)
|
40
41
|
|
@@ -245,6 +246,9 @@ class CategorizationF1(CategorizationMetric):
|
|
245
246
|
)
|
246
247
|
|
247
248
|
def aggregate_score(self, results: List[CategorizationResult]) -> ScalarResult: # type: ignore[override]
|
249
|
+
# late import to avoid slow CLI init
|
250
|
+
from sklearn.metrics import f1_score
|
251
|
+
|
248
252
|
gt = []
|
249
253
|
predicted = []
|
250
254
|
for result in results:
|
nucleus/metrics/cuboid_utils.py
CHANGED
@@ -5,10 +5,10 @@ import numpy as np
|
|
5
5
|
|
6
6
|
try:
|
7
7
|
from shapely.geometry import Polygon
|
8
|
-
except ModuleNotFoundError:
|
9
|
-
from ..
|
8
|
+
except (ModuleNotFoundError, OSError):
|
9
|
+
from ..package_not_installed import PackageNotInstalled
|
10
10
|
|
11
|
-
Polygon =
|
11
|
+
Polygon = PackageNotInstalled
|
12
12
|
|
13
13
|
|
14
14
|
from nucleus.annotation import CuboidAnnotation
|
@@ -357,6 +357,4 @@ def detection_iou(
|
|
357
357
|
meter_3d.append(iou_3d[i, j])
|
358
358
|
meter_2d.append(iou_2d[i, j])
|
359
359
|
|
360
|
-
|
361
|
-
meter_2d = np.array(meter_2d)
|
362
|
-
return meter_3d, meter_2d
|
360
|
+
return np.array(meter_3d), np.array(meter_2d)
|
nucleus/metrics/errors.py
CHANGED
nucleus/metrics/filtering.py
CHANGED
@@ -1,16 +1,32 @@
|
|
1
|
+
import copy
|
1
2
|
import enum
|
2
3
|
import functools
|
3
4
|
import logging
|
4
5
|
from enum import Enum
|
5
|
-
from typing import
|
6
|
+
from typing import (
|
7
|
+
Callable,
|
8
|
+
Iterable,
|
9
|
+
List,
|
10
|
+
NamedTuple,
|
11
|
+
Optional,
|
12
|
+
Sequence,
|
13
|
+
Set,
|
14
|
+
Tuple,
|
15
|
+
Union,
|
16
|
+
)
|
17
|
+
|
18
|
+
from rich.console import Console
|
19
|
+
from rich.table import Table
|
6
20
|
|
7
21
|
from nucleus.annotation import (
|
22
|
+
AnnotationList,
|
8
23
|
BoxAnnotation,
|
9
24
|
CategoryAnnotation,
|
10
25
|
CuboidAnnotation,
|
11
26
|
LineAnnotation,
|
12
27
|
MultiCategoryAnnotation,
|
13
28
|
PolygonAnnotation,
|
29
|
+
Segment,
|
14
30
|
SegmentationAnnotation,
|
15
31
|
)
|
16
32
|
from nucleus.prediction import (
|
@@ -19,6 +35,7 @@ from nucleus.prediction import (
|
|
19
35
|
CuboidPrediction,
|
20
36
|
LinePrediction,
|
21
37
|
PolygonPrediction,
|
38
|
+
PredictionList,
|
22
39
|
SegmentationPrediction,
|
23
40
|
)
|
24
41
|
|
@@ -40,10 +57,14 @@ class FilterType(str, enum.Enum):
|
|
40
57
|
Attributes:
|
41
58
|
FIELD: Access the attribute field of an object
|
42
59
|
METADATA: Access the metadata dictionary of an object
|
60
|
+
SEGMENT_FIELD: Filter segments of a segmentation mask to be considered on segment fields
|
61
|
+
SEGMENT_METADATA: Filter segments of a segmentation mask based on segment metadata
|
43
62
|
"""
|
44
63
|
|
45
64
|
FIELD = "field"
|
46
65
|
METADATA = "metadata"
|
66
|
+
SEGMENT_FIELD = "segment_field"
|
67
|
+
SEGMENT_METADATA = "segment_metadata"
|
47
68
|
|
48
69
|
|
49
70
|
FilterableBaseVals = Union[str, float, int, bool]
|
@@ -100,7 +121,8 @@ class FieldFilter(NamedTuple):
|
|
100
121
|
|
101
122
|
Examples:
|
102
123
|
FieldFilter("x", ">", 10) would pass every :class:`BoxAnnotation` with `x` attribute larger than 10
|
103
|
-
FieldFilter("label", "in", [) would pass every :class:`BoxAnnotation` with `
|
124
|
+
FieldFilter("label", "in", ["car", "truck"]) would pass every :class:`BoxAnnotation` with `label`
|
125
|
+
in ["car", "truck"]
|
104
126
|
|
105
127
|
Attributes:
|
106
128
|
key: key to compare with value
|
@@ -129,7 +151,7 @@ class MetadataFilter(NamedTuple):
|
|
129
151
|
with value field
|
130
152
|
value: bool, str, float or int to compare the field with key or list of the same values for 'in' and 'not in'
|
131
153
|
ops
|
132
|
-
allow_missing: Allow missing
|
154
|
+
allow_missing: Allow missing metadata values. Will REMOVE the object with the missing field from the selection
|
133
155
|
type: DO NOT USE. Internal type for serialization over the wire. Changing this will change the `NamedTuple`
|
134
156
|
type as well.
|
135
157
|
"""
|
@@ -141,7 +163,60 @@ class MetadataFilter(NamedTuple):
|
|
141
163
|
type: FilterType = FilterType.METADATA
|
142
164
|
|
143
165
|
|
144
|
-
|
166
|
+
class SegmentMetadataFilter(NamedTuple):
|
167
|
+
"""Filter on customer provided metadata associated with Segments of a SegmentationAnnotation or
|
168
|
+
SegmentationPrediction
|
169
|
+
|
170
|
+
Attributes:
|
171
|
+
key: key to compare with value
|
172
|
+
op: :class:`FilterOp` or one of [">", ">=", "<", "<=", "=", "==", "!=", "in", "not in"] to define comparison
|
173
|
+
with value field
|
174
|
+
value: bool, str, float or int to compare the field with key or list of the same values for 'in' and 'not in'
|
175
|
+
ops
|
176
|
+
allow_missing: Allow missing metadata values. Will REMOVE the object with the missing field from the selection
|
177
|
+
type: DO NOT USE. Internal type for serialization over the wire. Changing this will change the `NamedTuple`
|
178
|
+
type as well.
|
179
|
+
"""
|
180
|
+
|
181
|
+
key: str
|
182
|
+
op: Union[FilterOp, str]
|
183
|
+
value: FilterableTypes
|
184
|
+
allow_missing: bool = False
|
185
|
+
type: FilterType = FilterType.SEGMENT_METADATA
|
186
|
+
|
187
|
+
|
188
|
+
class SegmentFieldFilter(NamedTuple):
|
189
|
+
"""Filter on standard field of Segment(s) of SegmentationAnnotation and SegmentationPrediction
|
190
|
+
|
191
|
+
Examples:
|
192
|
+
SegmentFieldFilter("label", "in", ["grass", "tree"]) would pass every :class:`Segment` of a
|
193
|
+
:class:`SegmentationAnnotation or :class:`SegmentationPrediction`
|
194
|
+
|
195
|
+
Attributes:
|
196
|
+
key: key to compare with value
|
197
|
+
op: :class:`FilterOp` or one of [">", ">=", "<", "<=", "=", "==", "!=", "in", "not in"] to define comparison
|
198
|
+
with value field
|
199
|
+
value: bool, str, float or int to compare the field with key or list of the same values for 'in' and 'not in'
|
200
|
+
ops
|
201
|
+
allow_missing: Allow missing field values. Will REMOVE the object with the missing field from the selection
|
202
|
+
type: DO NOT USE. Internal type for serialization over the wire. Changing this will change the `NamedTuple`
|
203
|
+
type as well.
|
204
|
+
"""
|
205
|
+
|
206
|
+
key: str
|
207
|
+
op: Union[FilterOp, str]
|
208
|
+
value: FilterableTypes
|
209
|
+
allow_missing: bool = False
|
210
|
+
type: FilterType = FilterType.SEGMENT_FIELD
|
211
|
+
|
212
|
+
|
213
|
+
Filter = Union[
|
214
|
+
FieldFilter,
|
215
|
+
MetadataFilter,
|
216
|
+
AnnotationOrPredictionFilter,
|
217
|
+
SegmentFieldFilter,
|
218
|
+
SegmentMetadataFilter,
|
219
|
+
]
|
145
220
|
OrAndDNFFilters = List[List[Filter]]
|
146
221
|
OrAndDNFFilters.__doc__ = """\
|
147
222
|
Disjunctive normal form (DNF) filters.
|
@@ -182,11 +257,22 @@ ListOfAndFilters = Union[
|
|
182
257
|
ListOfAndJSONSerialized,
|
183
258
|
]
|
184
259
|
|
260
|
+
DNFFieldOrMetadataFilters = List[
|
261
|
+
List[Union[FieldFilter, MetadataFilter, AnnotationOrPredictionFilter]]
|
262
|
+
]
|
263
|
+
DNFFieldOrMetadataFilters.__doc__ = """\
|
264
|
+
Disjunctive normal form (DNF) filters.
|
265
|
+
DNF allows arbitrary boolean logical combinations of single field predicates.
|
266
|
+
The innermost structures each describe a single field predicate.
|
267
|
+
-The list of inner predicates is interpreted as a conjunction (AND), forming a more selective and multiple column
|
268
|
+
predicate.
|
269
|
+
"""
|
270
|
+
|
185
271
|
|
186
272
|
def _attribute_getter(
|
187
273
|
field_name: str,
|
188
274
|
allow_missing: bool,
|
189
|
-
ann_or_pred: Union[AnnotationTypes, PredictionTypes],
|
275
|
+
ann_or_pred: Union[AnnotationTypes, PredictionTypes, Segment],
|
190
276
|
):
|
191
277
|
"""Create a function to get object fields"""
|
192
278
|
if allow_missing:
|
@@ -224,7 +310,7 @@ class AlwaysFalseComparison:
|
|
224
310
|
def _metadata_field_getter(
|
225
311
|
field_name: str,
|
226
312
|
allow_missing: bool,
|
227
|
-
ann_or_pred: Union[AnnotationTypes, PredictionTypes],
|
313
|
+
ann_or_pred: Union[AnnotationTypes, PredictionTypes, Segment],
|
228
314
|
):
|
229
315
|
"""Create a function to get a metadata field"""
|
230
316
|
if isinstance(
|
@@ -259,7 +345,7 @@ def _metadata_field_getter(
|
|
259
345
|
|
260
346
|
def _filter_to_comparison_function( # pylint: disable=too-many-return-statements
|
261
347
|
filter_def: Filter,
|
262
|
-
) -> Callable[[Union[AnnotationTypes, PredictionTypes]], bool]:
|
348
|
+
) -> Callable[[Union[AnnotationTypes, PredictionTypes, Segment]], bool]:
|
263
349
|
"""Creates a comparison function from a filter configuration to apply to annotations or predictions
|
264
350
|
|
265
351
|
Parameters:
|
@@ -276,6 +362,10 @@ def _filter_to_comparison_function( # pylint: disable=too-many-return-statement
|
|
276
362
|
getter = functools.partial(
|
277
363
|
_metadata_field_getter, filter_def.key, filter_def.allow_missing
|
278
364
|
)
|
365
|
+
else:
|
366
|
+
raise NotImplementedError(
|
367
|
+
f"Unhandled filter type: {filter_def.type}. NOTE: Segmentation filters are handled elsewhere."
|
368
|
+
)
|
279
369
|
op = FilterOp(filter_def.op)
|
280
370
|
if op is FilterOp.GT:
|
281
371
|
return lambda ann_or_pred: getter(ann_or_pred) > filter_def.value
|
@@ -303,13 +393,16 @@ def _filter_to_comparison_function( # pylint: disable=too-many-return-statement
|
|
303
393
|
)
|
304
394
|
|
305
395
|
|
306
|
-
def
|
307
|
-
|
308
|
-
|
396
|
+
def _apply_field_or_metadata_filters(
|
397
|
+
filterable_sequence: Union[
|
398
|
+
Sequence[AnnotationTypes], Sequence[PredictionTypes], Sequence[Segment]
|
399
|
+
],
|
400
|
+
filters: DNFFieldOrMetadataFilters,
|
309
401
|
):
|
310
|
-
"""Apply filters to list of annotations or list of predictions
|
402
|
+
"""Apply filters to list of annotations or list of predictions or to a list of segments
|
403
|
+
|
311
404
|
Attributes:
|
312
|
-
|
405
|
+
filterable_sequence: Prediction, Annotation or Segment sequence
|
313
406
|
filters: Filter predicates. Allowed formats are:
|
314
407
|
ListOfAndFilters where each Filter forms a chain of AND predicates.
|
315
408
|
or
|
@@ -320,11 +413,6 @@ def apply_filters(
|
|
320
413
|
is interpreted as a conjunction (AND), forming a more selective `and` multiple column predicate.
|
321
414
|
Finally, the most outer list combines these filters as a disjunction (OR).
|
322
415
|
"""
|
323
|
-
if filters is None or len(filters) == 0:
|
324
|
-
return ann_or_pred
|
325
|
-
|
326
|
-
filters = ensureDNFFilters(filters)
|
327
|
-
|
328
416
|
dnf_condition_functions = []
|
329
417
|
for or_branch in filters:
|
330
418
|
and_conditions = [
|
@@ -333,18 +421,136 @@ def apply_filters(
|
|
333
421
|
dnf_condition_functions.append(and_conditions)
|
334
422
|
|
335
423
|
filtered = []
|
336
|
-
for item in
|
424
|
+
for item in filterable_sequence:
|
337
425
|
for or_conditions in dnf_condition_functions:
|
338
426
|
if all(c(item) for c in or_conditions):
|
339
427
|
filtered.append(item)
|
340
428
|
break
|
429
|
+
|
430
|
+
return filtered
|
431
|
+
|
432
|
+
|
433
|
+
def _split_segment_filters(
|
434
|
+
dnf_filters: OrAndDNFFilters,
|
435
|
+
) -> Tuple[OrAndDNFFilters, OrAndDNFFilters]:
|
436
|
+
"""We treat Segment* filters differently -> this splits filters into two sets, one containing the
|
437
|
+
standard field, metadata branches and the other the segment filters.
|
438
|
+
"""
|
439
|
+
normal_or_branches = []
|
440
|
+
segment_or_branches = []
|
441
|
+
for and_branch in dnf_filters:
|
442
|
+
normal_filters = []
|
443
|
+
segment_filters = []
|
444
|
+
for filter_statement in and_branch:
|
445
|
+
if filter_statement.type in {
|
446
|
+
FilterType.SEGMENT_METADATA,
|
447
|
+
FilterType.SEGMENT_FIELD,
|
448
|
+
}:
|
449
|
+
segment_filters.append(filter_statement)
|
450
|
+
else:
|
451
|
+
normal_filters.append(filter_statement)
|
452
|
+
normal_or_branches.append(normal_filters)
|
453
|
+
segment_or_branches.append(segment_filters)
|
454
|
+
return normal_or_branches, segment_or_branches
|
455
|
+
|
456
|
+
|
457
|
+
def _filter_segments(
|
458
|
+
anns_or_preds: Union[
|
459
|
+
Sequence[SegmentationAnnotation], Sequence[SegmentationPrediction]
|
460
|
+
],
|
461
|
+
segment_filters: OrAndDNFFilters,
|
462
|
+
):
|
463
|
+
"""Filter Segments of a SegmentationAnnotation or Prediction
|
464
|
+
|
465
|
+
We have to treat this differently as metadata and labels are on nested Segment objects
|
466
|
+
"""
|
467
|
+
if len(segment_filters) == 0 or len(segment_filters[0]) == 0:
|
468
|
+
return anns_or_preds
|
469
|
+
|
470
|
+
# Transform segment filter types to field and metadata to iterate over annotation sub fields
|
471
|
+
transformed_or_branches = (
|
472
|
+
[]
|
473
|
+
) # type: List[List[Union[MetadataFilter, FieldFilter]]]
|
474
|
+
for and_branch in segment_filters:
|
475
|
+
transformed_and = [] # type: List[Union[MetadataFilter, FieldFilter]]
|
476
|
+
for filter_statement in and_branch:
|
477
|
+
if filter_statement.type == FilterType.SEGMENT_FIELD:
|
478
|
+
transformed_and.append(
|
479
|
+
FieldFilter(
|
480
|
+
filter_statement.key,
|
481
|
+
filter_statement.op,
|
482
|
+
filter_statement.value,
|
483
|
+
filter_statement.allow_missing,
|
484
|
+
)
|
485
|
+
)
|
486
|
+
elif filter_statement.type == FilterType.SEGMENT_METADATA:
|
487
|
+
transformed_and.append(
|
488
|
+
MetadataFilter(
|
489
|
+
filter_statement.key,
|
490
|
+
filter_statement.op,
|
491
|
+
filter_statement.value,
|
492
|
+
filter_statement.allow_missing,
|
493
|
+
)
|
494
|
+
)
|
495
|
+
else:
|
496
|
+
raise RuntimeError("Encountered a non SEGMENT_* filter type")
|
497
|
+
|
498
|
+
transformed_or_branches.append(transformed_and)
|
499
|
+
|
500
|
+
segments_filtered = []
|
501
|
+
for ann_or_pred in anns_or_preds:
|
502
|
+
if isinstance(
|
503
|
+
ann_or_pred, (SegmentationAnnotation, SegmentationPrediction)
|
504
|
+
):
|
505
|
+
ann_or_pred.annotations = _apply_field_or_metadata_filters(
|
506
|
+
ann_or_pred.annotations, transformed_or_branches # type: ignore
|
507
|
+
)
|
508
|
+
segments_filtered.append(ann_or_pred)
|
509
|
+
|
510
|
+
return segments_filtered
|
511
|
+
|
512
|
+
|
513
|
+
def apply_filters(
|
514
|
+
ann_or_pred: Union[Sequence[AnnotationTypes], Sequence[PredictionTypes]],
|
515
|
+
filters: Union[ListOfOrAndFilters, ListOfAndFilters],
|
516
|
+
):
|
517
|
+
"""Apply filters to list of annotations or list of predictions
|
518
|
+
Attributes:
|
519
|
+
ann_or_pred: Prediction or Annotation
|
520
|
+
filters: Filter predicates. Allowed formats are:
|
521
|
+
ListOfAndFilters where each Filter forms a chain of AND predicates.
|
522
|
+
or
|
523
|
+
ListOfOrAndFilters where Filters are expressed in disjunctive normal form (DNF), like
|
524
|
+
[[MetadataFilter("short_haired", "==", True), FieldFilter("label", "in", ["cat", "dog"]), ...].
|
525
|
+
DNF allows arbitrary boolean logical combinations of single field
|
526
|
+
predicates. The innermost structures each describe a single column predicate. The list of inner predicates
|
527
|
+
is interpreted as a conjunction (AND), forming a more selective `and` multiple column predicate.
|
528
|
+
Finally, the most outer list combines these filters as a disjunction (OR).
|
529
|
+
"""
|
530
|
+
if filters is None or len(filters) == 0:
|
531
|
+
return ann_or_pred
|
532
|
+
|
533
|
+
dnf_filters = ensureDNFFilters(filters)
|
534
|
+
filters, segment_filters = _split_segment_filters(dnf_filters)
|
535
|
+
filtered = _apply_field_or_metadata_filters(ann_or_pred, filters) # type: ignore
|
536
|
+
filtered = _filter_segments(filtered, segment_filters)
|
537
|
+
|
341
538
|
return filtered
|
342
539
|
|
343
540
|
|
344
541
|
def ensureDNFFilters(filters) -> OrAndDNFFilters:
|
345
542
|
"""JSON encoding creates a triple nested lists from the doubly nested tuples. This function creates the
|
346
543
|
tuple form again."""
|
347
|
-
if isinstance(
|
544
|
+
if isinstance(
|
545
|
+
filters[0],
|
546
|
+
(
|
547
|
+
MetadataFilter,
|
548
|
+
FieldFilter,
|
549
|
+
AnnotationOrPredictionFilter,
|
550
|
+
SegmentFieldFilter,
|
551
|
+
SegmentMetadataFilter,
|
552
|
+
),
|
553
|
+
):
|
348
554
|
# Normalize into DNF
|
349
555
|
filters: ListOfOrAndFilters = [filters] # type: ignore
|
350
556
|
|
@@ -369,3 +575,147 @@ def ensureDNFFilters(filters) -> OrAndDNFFilters:
|
|
369
575
|
formatted_filter.append(and_chain)
|
370
576
|
filters = formatted_filter
|
371
577
|
return filters
|
578
|
+
|
579
|
+
|
580
|
+
def pretty_format_filters_with_or_and(
|
581
|
+
filters: Optional[Union[ListOfOrAndFilters, ListOfAndFilters]]
|
582
|
+
):
|
583
|
+
if filters is None:
|
584
|
+
return "No filters applied!"
|
585
|
+
dnf_filters = ensureDNFFilters(filters)
|
586
|
+
or_branches = []
|
587
|
+
for or_branch in dnf_filters:
|
588
|
+
and_statements = []
|
589
|
+
for and_branch in or_branch:
|
590
|
+
if and_branch.type == FilterType.FIELD:
|
591
|
+
class_name = "FieldFilter"
|
592
|
+
elif and_branch.type == FilterType.METADATA:
|
593
|
+
class_name = "MetadataFilter"
|
594
|
+
elif and_branch.type == FilterType.SEGMENT_FIELD:
|
595
|
+
class_name = "SegmentFieldFilter"
|
596
|
+
elif and_branch.type == FilterType.SEGMENT_METADATA:
|
597
|
+
class_name = "SegmentMetadataFilter"
|
598
|
+
else:
|
599
|
+
raise RuntimeError(
|
600
|
+
f"Un-handled filter type: {and_branch.type}"
|
601
|
+
)
|
602
|
+
op = (
|
603
|
+
and_branch.op.value
|
604
|
+
if isinstance(and_branch.op, FilterOp)
|
605
|
+
else and_branch.op
|
606
|
+
)
|
607
|
+
value_formatted = (
|
608
|
+
f'"{and_branch.value}"'
|
609
|
+
if isinstance(and_branch.value, str)
|
610
|
+
else f"{and_branch.value}".replace("'", '"')
|
611
|
+
)
|
612
|
+
statement = (
|
613
|
+
f'{class_name}("{and_branch.key}", "{op}", {value_formatted})'
|
614
|
+
)
|
615
|
+
and_statements.append(statement)
|
616
|
+
|
617
|
+
or_branches.append(and_statements)
|
618
|
+
|
619
|
+
and_to_join = []
|
620
|
+
for and_statements in or_branches:
|
621
|
+
joined_and = " and ".join(and_statements)
|
622
|
+
if len(or_branches) > 1 and len(and_statements) > 1:
|
623
|
+
joined_and = "(" + joined_and + ")"
|
624
|
+
and_to_join.append(joined_and)
|
625
|
+
|
626
|
+
full_statement = " or ".join(and_to_join)
|
627
|
+
return full_statement
|
628
|
+
|
629
|
+
|
630
|
+
def compose_helpful_filtering_error(
|
631
|
+
ann_or_pred_list: Union[AnnotationList, PredictionList], filters
|
632
|
+
) -> List[str]:
|
633
|
+
prefix = (
|
634
|
+
"Annotations"
|
635
|
+
if isinstance(ann_or_pred_list, AnnotationList)
|
636
|
+
else "Predictions"
|
637
|
+
)
|
638
|
+
msg = []
|
639
|
+
msg.append(f"{prefix}: All items filtered out by:")
|
640
|
+
msg.append(f" {pretty_format_filters_with_or_and(filters)}")
|
641
|
+
msg.append("")
|
642
|
+
console = Console()
|
643
|
+
table = Table(
|
644
|
+
"Type",
|
645
|
+
"Count",
|
646
|
+
"Labels",
|
647
|
+
title=f"Original {prefix}",
|
648
|
+
title_justify="left",
|
649
|
+
)
|
650
|
+
for ann_or_pred_type, items in ann_or_pred_list.items():
|
651
|
+
if items and isinstance(
|
652
|
+
items[-1], (SegmentationAnnotation, SegmentationPrediction)
|
653
|
+
):
|
654
|
+
labels = set()
|
655
|
+
for seg in items:
|
656
|
+
labels.update(set(s.label for s in seg.annotations))
|
657
|
+
else:
|
658
|
+
labels = set(a.label for a in items)
|
659
|
+
if items:
|
660
|
+
table.add_row(ann_or_pred_type, str(len(items)), str(list(labels)))
|
661
|
+
with console.capture() as capture:
|
662
|
+
console.print(table)
|
663
|
+
msg.append(capture.get())
|
664
|
+
return msg
|
665
|
+
|
666
|
+
|
667
|
+
def filter_annotation_list(
|
668
|
+
annotations: AnnotationList, annotation_filters
|
669
|
+
) -> AnnotationList:
|
670
|
+
annotations = copy.deepcopy(annotations)
|
671
|
+
if annotation_filters is None or len(annotation_filters) == 0:
|
672
|
+
return annotations
|
673
|
+
annotations.box_annotations = apply_filters(
|
674
|
+
annotations.box_annotations, annotation_filters
|
675
|
+
)
|
676
|
+
annotations.line_annotations = apply_filters(
|
677
|
+
annotations.line_annotations, annotation_filters
|
678
|
+
)
|
679
|
+
annotations.polygon_annotations = apply_filters(
|
680
|
+
annotations.polygon_annotations, annotation_filters
|
681
|
+
)
|
682
|
+
annotations.cuboid_annotations = apply_filters(
|
683
|
+
annotations.cuboid_annotations, annotation_filters
|
684
|
+
)
|
685
|
+
annotations.category_annotations = apply_filters(
|
686
|
+
annotations.category_annotations, annotation_filters
|
687
|
+
)
|
688
|
+
annotations.multi_category_annotations = apply_filters(
|
689
|
+
annotations.multi_category_annotations, annotation_filters
|
690
|
+
)
|
691
|
+
annotations.segmentation_annotations = apply_filters(
|
692
|
+
annotations.segmentation_annotations, annotation_filters
|
693
|
+
)
|
694
|
+
return annotations
|
695
|
+
|
696
|
+
|
697
|
+
def filter_prediction_list(
|
698
|
+
predictions: PredictionList, prediction_filters
|
699
|
+
) -> PredictionList:
|
700
|
+
predictions = copy.deepcopy(predictions)
|
701
|
+
if prediction_filters is None or len(prediction_filters) == 0:
|
702
|
+
return predictions
|
703
|
+
predictions.box_predictions = apply_filters(
|
704
|
+
predictions.box_predictions, prediction_filters
|
705
|
+
)
|
706
|
+
predictions.line_predictions = apply_filters(
|
707
|
+
predictions.line_predictions, prediction_filters
|
708
|
+
)
|
709
|
+
predictions.polygon_predictions = apply_filters(
|
710
|
+
predictions.polygon_predictions, prediction_filters
|
711
|
+
)
|
712
|
+
predictions.cuboid_predictions = apply_filters(
|
713
|
+
predictions.cuboid_predictions, prediction_filters
|
714
|
+
)
|
715
|
+
predictions.category_predictions = apply_filters(
|
716
|
+
predictions.category_predictions, prediction_filters
|
717
|
+
)
|
718
|
+
predictions.segmentation_predictions = apply_filters(
|
719
|
+
predictions.segmentation_predictions, prediction_filters
|
720
|
+
)
|
721
|
+
return predictions
|
nucleus/metrics/polygon_utils.py
CHANGED
@@ -12,10 +12,10 @@ from .custom_types import BoxOrPolygonAnnotation, BoxOrPolygonPrediction
|
|
12
12
|
|
13
13
|
try:
|
14
14
|
from shapely.geometry import Polygon
|
15
|
-
except ModuleNotFoundError:
|
16
|
-
from ..
|
15
|
+
except (ModuleNotFoundError, OSError):
|
16
|
+
from ..package_not_installed import PackageNotInstalled
|
17
17
|
|
18
|
-
Polygon =
|
18
|
+
Polygon = PackageNotInstalled
|
19
19
|
|
20
20
|
|
21
21
|
from .base import ScalarResult
|