scale-nucleus 0.1.24__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (73)
  1. cli/client.py +14 -0
  2. cli/datasets.py +77 -0
  3. cli/helpers/__init__.py +0 -0
  4. cli/helpers/nucleus_url.py +10 -0
  5. cli/helpers/web_helper.py +40 -0
  6. cli/install_completion.py +33 -0
  7. cli/jobs.py +42 -0
  8. cli/models.py +35 -0
  9. cli/nu.py +42 -0
  10. cli/reference.py +8 -0
  11. cli/slices.py +62 -0
  12. cli/tests.py +121 -0
  13. nucleus/__init__.py +446 -710
  14. nucleus/annotation.py +405 -85
  15. nucleus/autocurate.py +9 -0
  16. nucleus/connection.py +87 -0
  17. nucleus/constants.py +5 -1
  18. nucleus/data_transfer_object/__init__.py +0 -0
  19. nucleus/data_transfer_object/dataset_details.py +9 -0
  20. nucleus/data_transfer_object/dataset_info.py +26 -0
  21. nucleus/data_transfer_object/dataset_size.py +5 -0
  22. nucleus/data_transfer_object/scenes_list.py +18 -0
  23. nucleus/dataset.py +1137 -212
  24. nucleus/dataset_item.py +130 -26
  25. nucleus/dataset_item_uploader.py +297 -0
  26. nucleus/deprecation_warning.py +32 -0
  27. nucleus/errors.py +9 -0
  28. nucleus/job.py +71 -3
  29. nucleus/logger.py +9 -0
  30. nucleus/metadata_manager.py +45 -0
  31. nucleus/metrics/__init__.py +10 -0
  32. nucleus/metrics/base.py +117 -0
  33. nucleus/metrics/categorization_metrics.py +197 -0
  34. nucleus/metrics/errors.py +7 -0
  35. nucleus/metrics/filters.py +40 -0
  36. nucleus/metrics/geometry.py +198 -0
  37. nucleus/metrics/metric_utils.py +28 -0
  38. nucleus/metrics/polygon_metrics.py +480 -0
  39. nucleus/metrics/polygon_utils.py +299 -0
  40. nucleus/model.py +121 -15
  41. nucleus/model_run.py +34 -57
  42. nucleus/payload_constructor.py +29 -19
  43. nucleus/prediction.py +259 -17
  44. nucleus/pydantic_base.py +26 -0
  45. nucleus/retry_strategy.py +4 -0
  46. nucleus/scene.py +204 -19
  47. nucleus/slice.py +230 -67
  48. nucleus/upload_response.py +20 -9
  49. nucleus/url_utils.py +4 -0
  50. nucleus/utils.py +134 -37
  51. nucleus/validate/__init__.py +24 -0
  52. nucleus/validate/client.py +168 -0
  53. nucleus/validate/constants.py +20 -0
  54. nucleus/validate/data_transfer_objects/__init__.py +0 -0
  55. nucleus/validate/data_transfer_objects/eval_function.py +81 -0
  56. nucleus/validate/data_transfer_objects/scenario_test.py +19 -0
  57. nucleus/validate/data_transfer_objects/scenario_test_evaluations.py +11 -0
  58. nucleus/validate/data_transfer_objects/scenario_test_metric.py +12 -0
  59. nucleus/validate/errors.py +6 -0
  60. nucleus/validate/eval_functions/__init__.py +0 -0
  61. nucleus/validate/eval_functions/available_eval_functions.py +212 -0
  62. nucleus/validate/eval_functions/base_eval_function.py +60 -0
  63. nucleus/validate/scenario_test.py +143 -0
  64. nucleus/validate/scenario_test_evaluation.py +114 -0
  65. nucleus/validate/scenario_test_metric.py +14 -0
  66. nucleus/validate/utils.py +8 -0
  67. {scale_nucleus-0.1.24.dist-info → scale_nucleus-0.6.4.dist-info}/LICENSE +0 -0
  68. scale_nucleus-0.6.4.dist-info/METADATA +213 -0
  69. scale_nucleus-0.6.4.dist-info/RECORD +71 -0
  70. {scale_nucleus-0.1.24.dist-info → scale_nucleus-0.6.4.dist-info}/WHEEL +1 -1
  71. scale_nucleus-0.6.4.dist-info/entry_points.txt +3 -0
  72. scale_nucleus-0.1.24.dist-info/METADATA +0 -85
  73. scale_nucleus-0.1.24.dist-info/RECORD +0 -21
nucleus/job.py CHANGED
@@ -1,7 +1,9 @@
- from dataclasses import dataclass
  import time
+ from dataclasses import dataclass
  from typing import Dict, List
+
  import requests
+
  from nucleus.constants import (
      JOB_CREATION_TIME_KEY,
      JOB_ID_KEY,
@@ -9,12 +11,33 @@ from nucleus.constants import (
      JOB_TYPE_KEY,
      STATUS_KEY,
  )
+ from nucleus.utils import replace_double_slashes

  JOB_POLLING_INTERVAL = 5


  @dataclass
  class AsyncJob:
+     """Object used to check the status or errors of a long running asynchronous operation.
+
+     ::
+
+         import nucleus
+
+         client = nucleus.NucleusClient(YOUR_SCALE_API_KEY)
+         dataset = client.get_dataset("ds_bwkezj6g5c4g05gqp1eg")
+
+         # When kicking off an asynchronous job, store the return value as a variable
+         job = dataset.append(items=YOUR_DATASET_ITEMS, asynchronous=True)
+
+         # Poll for status or errors
+         print(job.status())
+         print(job.errors())
+
+         # Block until job finishes
+         job.sleep_until_complete()
+     """
+
      job_id: str
      job_last_known_status: str
      job_type: str
@@ -22,6 +45,23 @@ class AsyncJob:
      client: "NucleusClient"  # type: ignore # noqa: F821

      def status(self) -> Dict[str, str]:
+         """Fetches status of the job and an informative message on job progress.
+
+         Returns:
+             A dict of the job ID, status (one of Running, Completed, or Errored),
+             an informative message on the job progress, and number of both completed
+             and total steps.
+             ::
+
+                 {
+                     "job_id": "job_c19xcf9mkws46gah0000",
+                     "status": "Completed",
+                     "message": "Job completed successfully.",
+                     "job_progress": "0.33",
+                     "completed_steps": "1",
+                     "total_steps:": "3",
+                 }
+         """
          response = self.client.make_request(
              payload={},
              route=f"job/{self.job_id}",
@@ -31,30 +71,57 @@ class AsyncJob:
          return response

      def errors(self) -> List[str]:
-         return self.client.make_request(
+         """Fetches a list of the latest errors generated by the asynchronous job.
+
+         Useful for debugging failed or partially successful jobs.
+
+         Returns:
+             A list of strings containing the 10,000 most recently generated errors.
+             ::
+
+                 [
+                     '{"annotation":{"label":"car","type":"box","geometry":{"x":50,"y":60,"width":70,"height":80},"referenceId":"bad_ref_id","annotationId":"attempted_annot_upload","metadata":{}},"error":"Item with id bad_ref_id doesn\'t exist."}'
+                 ]
+         """
+         errors = self.client.make_request(
              payload={},
              route=f"job/{self.job_id}/errors",
              requests_command=requests.get,
          )
+         return [replace_double_slashes(error) for error in errors]

      def sleep_until_complete(self, verbose_std_out=True):
+         """Blocks until the job completes or errors.
+
+         Parameters:
+             verbose_std_out (Optional[bool]): Whether or not to verbosely log while
+                 sleeping. Defaults to True.
+         """
+         start_time = time.perf_counter()
          while 1:
              status = self.status()
              time.sleep(JOB_POLLING_INTERVAL)

              if verbose_std_out:
-                 print(f"Status at {time.ctime()}: {status}")
+                 print(
+                     f"Status at {time.perf_counter() - start_time} s: {status}"
+                 )
              if status["status"] == "Running":
                  continue

              break

+         if verbose_std_out:
+             print(
+                 f"Finished at {time.perf_counter() - start_time} s: {status}"
+             )
          final_status = status
          if final_status["status"] == "Errored":
              raise JobError(final_status, self)

      @classmethod
      def from_json(cls, payload: dict, client):
+         # TODO: make private
          return cls(
              job_id=payload[JOB_ID_KEY],
              job_last_known_status=payload[JOB_LAST_KNOWN_STATUS_KEY],
@@ -74,4 +141,5 @@ class JobError(Exception):
              f"The final status message was: {final_status_message} \n"
              f"For more detailed error messages you can call {str(job)}.errors()"
          )
+         message = replace_double_slashes(message)
          super().__init__(message)
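
With these changes a failed job can be inspected without re-running it: sleep_until_complete raises JobError on failure, and errors() returns the already-normalized error strings. A minimal sketch of that flow, where the API key, the dataset ID, and YOUR_DATASET_ITEMS are placeholders taken from the docstring above::

    import nucleus
    from nucleus.job import JobError

    client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
    dataset = client.get_dataset("YOUR_DATASET_ID")

    job = dataset.append(items=YOUR_DATASET_ITEMS, asynchronous=True)
    try:
        job.sleep_until_complete(verbose_std_out=False)
    except JobError:
        # errors() returns up to the 10,000 most recent error strings,
        # already passed through replace_double_slashes.
        for error in job.errors():
            print(error)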
nucleus/logger.py ADDED
@@ -0,0 +1,9 @@
+ import logging
+
+ import requests
+
+ logger = logging.getLogger(__name__)
+ logging.basicConfig()
+ logging.getLogger(
+     requests.packages.urllib3.__package__  # pylint: disable=no-member
+ ).setLevel(logging.ERROR)
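
The shared logger is intended to be imported by the other nucleus modules. A small sketch of its effect; the warning text here is illustrative only::

    from nucleus.logger import logger

    # nucleus/logger.py caps urllib3 at ERROR, so routine HTTP connection
    # chatter is suppressed, while messages sent through this logger still appear.
    logger.warning("example warning emitted through the shared nucleus logger")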
nucleus/metadata_manager.py ADDED
@@ -0,0 +1,45 @@
+ from enum import Enum
+ from typing import TYPE_CHECKING, Dict
+
+ if TYPE_CHECKING:
+     from . import NucleusClient
+
+
+ # Wording set to match with backend enum
+ class ExportMetadataType(Enum):
+     SCENES = "scene"
+     DATASET_ITEMS = "item"
+
+
+ class MetadataManager:
+     """
+     Helper class for managing metadata updates on a scene or dataset item.
+     Do not call directly, use the dataset class methods: `update_scene_metadata` or `update_item_metadata`
+     """
+
+     def __init__(
+         self,
+         dataset_id: str,
+         client: "NucleusClient",
+         raw_mappings: Dict[str, dict],
+         level: ExportMetadataType,
+     ):
+         self.dataset_id = dataset_id
+         self._client = client
+         self.raw_mappings = raw_mappings
+         self.level = level
+
+         self._payload = self._format_mappings()
+
+     def _format_mappings(self):
+         payload = []
+         for ref_id, meta in self.raw_mappings.items():
+             payload.append({"reference_id": ref_id, "metadata": meta})
+         return payload
+
+     def update(self):
+         payload = {"metadata": self._payload, "level": self.level.value}
+         resp = self._client.make_request(
+             payload=payload, route=f"dataset/{self.dataset_id}/metadata"
+         )
+         return resp
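
The docstring says to reach this class through the Dataset methods, but the payload it builds is easy to see in isolation. A sketch with a placeholder dataset ID; client=None is acceptable here only because update() is never called, so nothing is sent::

    from nucleus.metadata_manager import ExportMetadataType, MetadataManager

    manager = MetadataManager(
        dataset_id="ds_placeholder",
        client=None,  # no request is made in this sketch
        raw_mappings={"ref_1": {"weather": "rainy"}},
        level=ExportMetadataType.DATASET_ITEMS,
    )
    # _format_mappings() reshapes the dict into the list that update() would
    # POST to dataset/<dataset_id>/metadata along with level="item".
    print(manager._payload)
    # [{'reference_id': 'ref_1', 'metadata': {'weather': 'rainy'}}]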
nucleus/metrics/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from .base import Metric, ScalarResult
+ from .categorization_metrics import CategorizationF1
+ from .polygon_metrics import (
+     PolygonAveragePrecision,
+     PolygonIOU,
+     PolygonMAP,
+     PolygonMetric,
+     PolygonPrecision,
+     PolygonRecall,
+ )
nucleus/metrics/base.py ADDED
@@ -0,0 +1,117 @@
+ import sys
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import Iterable, List
+
+ from nucleus.annotation import AnnotationList
+ from nucleus.prediction import PredictionList
+
+
+ class MetricResult(ABC):
+     """Base MetricResult class"""
+
+
+ @dataclass
+ class ScalarResult(MetricResult):
+     """A scalar result contains the value of an evaluation, as well as its weight.
+     The weight is useful when aggregating metrics where each dataset item may hold a
+     different relative weight. For example, when calculating precision over a dataset,
+     the denominator of the precision is the number of annotations, and therefore the weight
+     can be set as the number of annotations.
+
+     Attributes:
+         value (float): The value of the evaluation result
+         weight (float): The weight of the evaluation result.
+     """
+
+     value: float
+     weight: float = 1.0
+
+     @staticmethod
+     def aggregate(results: Iterable["ScalarResult"]) -> "ScalarResult":
+         """Aggregates results using a weighted average."""
+         results = list(filter(lambda x: x.weight != 0, results))
+         total_weight = sum([result.weight for result in results])
+         total_value = sum([result.value * result.weight for result in results])
+         value = total_value / max(total_weight, sys.float_info.epsilon)
+         return ScalarResult(value, total_weight)
+
+
+ class Metric(ABC):
+     """Abstract class for defining a metric, which takes a list of annotations
+     and predictions and returns a scalar.
+
+     To create a new concrete Metric, override the `__call__` function
+     with logic to define a metric between annotations and predictions. ::
+
+         from nucleus import BoxAnnotation, CuboidPrediction, Point3D
+         from nucleus.annotation import AnnotationList
+         from nucleus.prediction import PredictionList
+         from nucleus.metrics import Metric, MetricResult
+         from nucleus.metrics.polygon_utils import BoxOrPolygonAnnotation, BoxOrPolygonPrediction
+
+         class MyMetric(Metric):
+             def __call__(
+                 self, annotations: AnnotationList, predictions: PredictionList
+             ) -> MetricResult:
+                 value = (len(annotations) - len(predictions)) ** 2
+                 weight = len(annotations)
+                 return MetricResult(value, weight)
+
+         box = BoxAnnotation(
+             label="car",
+             x=0,
+             y=0,
+             width=10,
+             height=10,
+             reference_id="image_1",
+             annotation_id="image_1_car_box_1",
+             metadata={"vehicle_color": "red"}
+         )
+
+         cuboid = CuboidPrediction(
+             label="car",
+             position=Point3D(100, 100, 10),
+             dimensions=Point3D(5, 10, 5),
+             yaw=0,
+             reference_id="pointcloud_1",
+             confidence=0.8,
+             annotation_id="pointcloud_1_car_cuboid_1",
+             metadata={"vehicle_color": "green"}
+         )
+
+         metric = MyMetric()
+         annotations = AnnotationList(box_annotations=[box])
+         predictions = PredictionList(cuboid_predictions=[cuboid])
+         metric(annotations, predictions)
+     """
+
+     @abstractmethod
+     def __call__(
+         self, annotations: AnnotationList, predictions: PredictionList
+     ) -> MetricResult:
+         """A metric must override this method and return a metric result, given annotations and predictions."""
+
+     @abstractmethod
+     def aggregate_score(self, results: List[MetricResult]) -> ScalarResult:
+         """A metric must define how to aggregate results from single items to a single ScalarResult.
+
+         E.g. to calculate an R2 score with sklearn you could define a custom metric class ::
+
+             class R2Result(MetricResult):
+                 y_true: float
+                 y_pred: float
+
+
+         And then define an aggregate_score ::
+
+             def aggregate_score(self, results: List[MetricResult]) -> ScalarResult:
+                 y_trues = []
+                 y_preds = []
+                 for result in results:
+                     y_trues.append(result.y_true)
+                     y_preds.append(result.y_pred)
+                 r2_score = sklearn.metrics.r2_score(y_trues, y_preds)
+                 return ScalarResult(r2_score)
+
+         """
nucleus/metrics/categorization_metrics.py ADDED
@@ -0,0 +1,197 @@
+ from abc import abstractmethod
+ from dataclasses import dataclass
+ from typing import List, Set, Tuple, Union
+
+ from sklearn.metrics import f1_score
+
+ from nucleus.annotation import AnnotationList, CategoryAnnotation
+ from nucleus.metrics.base import Metric, MetricResult, ScalarResult
+ from nucleus.metrics.filters import confidence_filter
+ from nucleus.prediction import CategoryPrediction, PredictionList
+
+ F1_METHODS = {"micro", "macro", "samples", "weighted", "binary"}
+
+
+ def to_taxonomy_labels(
+     anns_or_preds: Union[List[CategoryAnnotation], List[CategoryPrediction]]
+ ) -> Set[str]:
+     """Transforms annotation or prediction lists to taxonomy labels by joining them with a separator (->)"""
+     labels = set()
+     for item in anns_or_preds:
+         taxonomy_label = (
+             f"{item.taxonomy_name}->{item.label}"
+             if item.taxonomy_name
+             else item.label
+         )
+         labels.add(taxonomy_label)
+     return labels
+
+
+ @dataclass
+ class CategorizationResult(MetricResult):
+     annotations: List[CategoryAnnotation]
+     predictions: List[CategoryPrediction]
+
+     @property
+     def value(self):
+         annotation_labels = to_taxonomy_labels(self.annotations)
+         prediction_labels = to_taxonomy_labels(self.predictions)
+
+         # TODO: Change task.py interface such that we can return label matching
+         # NOTE: Returning 1 if all taxonomy labels match else 0
+         value = f1_score(list(annotation_labels), list(prediction_labels), average="macro")
+         return value
+
+
+ class CategorizationMetric(Metric):
+     """Abstract class for metrics related to Categorization
+
+     The Categorization class automatically filters incoming annotations and
+     predictions for only categorization annotations. It also filters
+     predictions whose confidence is less than the provided confidence_threshold.
+     """
+
+     def __init__(
+         self,
+         confidence_threshold: float = 0.0,
+     ):
+         """Initializes CategorizationMetric abstract object.
+
+         Args:
+             confidence_threshold: minimum confidence threshold for predictions to be taken into account for evaluation. Must be in [0, 1]. Default 0.0
+         """
+         assert 0 <= confidence_threshold <= 1
+         self.confidence_threshold = confidence_threshold
+
+     @abstractmethod
+     def eval(
+         self,
+         annotations: List[
+             CategoryAnnotation
+         ],  # TODO(gunnar): List to conform with other APIs or single instance?
+         predictions: List[CategoryPrediction],
+     ) -> CategorizationResult:
+         # Main evaluation function that subclasses must override.
+         # TODO(gunnar): Allow passing multiple predictions and selecting highest confidence? Allows us to show next
+         #  contender. Are top-5 scores something that we care about?
+         # TODO(gunnar): How do we handle multi-head classification?
+         pass
+
+     @abstractmethod
+     def aggregate_score(self, results: List[CategorizationResult]) -> ScalarResult:  # type: ignore[override]
+         pass
+
+     def __call__(
+         self, annotations: AnnotationList, predictions: PredictionList
+     ) -> CategorizationResult:
+         if self.confidence_threshold > 0:
+             predictions = confidence_filter(
+                 predictions, self.confidence_threshold
+             )
+
+         cat_annotations, cat_predictions = self._filter_common_taxonomies(
+             annotations.category_annotations, predictions.category_predictions
+         )
+
+         result = self.eval(
+             cat_annotations,
+             cat_predictions,
+         )
+         return result
+
+     def _filter_common_taxonomies(
+         self,
+         annotations: List[CategoryAnnotation],
+         predictions: List[CategoryPrediction],
+     ) -> Tuple[List[CategoryAnnotation], List[CategoryPrediction]]:
+         annotated_taxonomies = {ann.taxonomy_name for ann in annotations}
+         matching_predictions, matching_taxonomies = self._filter_in_taxonomies(
+             predictions, annotated_taxonomies
+         )
+         matching_annotations, _ = self._filter_in_taxonomies(
+             annotations, matching_taxonomies
+         )
+
+         return matching_annotations, matching_predictions  # type: ignore
+
+     def _filter_in_taxonomies(
+         self,
+         anns_or_preds: Union[
+             List[CategoryAnnotation], List[CategoryPrediction]
+         ],
+         filter_on_taxonomies: Set[Union[None, str]],
+     ) -> Tuple[
+         Union[List[CategoryAnnotation], List[CategoryPrediction]],
+         Set[Union[None, str]],
+     ]:
+         matching_predictions = []
+         matching_taxonomies = set()
+         for pred in anns_or_preds:
+             if pred.taxonomy_name in filter_on_taxonomies:
+                 matching_predictions.append(pred)
+                 matching_taxonomies.add(pred.taxonomy_name)
+         return matching_predictions, matching_taxonomies
+
+
+ class CategorizationF1(CategorizationMetric):
+     """Evaluation method that matches categories and returns a CategorizationF1Result that aggregates to the F1 score"""
+
+     def __init__(
+         self, confidence_threshold: float = 0.0, f1_method: str = "macro"
+     ):
+         """
+         Args:
+             confidence_threshold: minimum confidence threshold for predictions to be taken into account for evaluation. Must be in [0, 1]. Default 0.0
+             f1_method: {'micro', 'macro', 'samples', 'weighted', 'binary'}, \
+                 default='macro'
+                 This parameter is required for multiclass/multilabel targets.
+                 If ``None``, the scores for each class are returned. Otherwise, this
+                 determines the type of averaging performed on the data:
+
+                 ``'binary'``:
+                     Only report results for the class specified by ``pos_label``.
+                     This is applicable only if targets (``y_{true,pred}``) are binary.
+                 ``'micro'``:
+                     Calculate metrics globally by counting the total true positives,
+                     false negatives and false positives.
+                 ``'macro'``:
+                     Calculate metrics for each label, and find their unweighted
+                     mean. This does not take label imbalance into account.
+                 ``'weighted'``:
+                     Calculate metrics for each label, and find their average weighted
+                     by support (the number of true instances for each label). This
+                     alters 'macro' to account for label imbalance; it can result in an
+                     F-score that is not between precision and recall.
+                 ``'samples'``:
+                     Calculate metrics for each instance, and find their average (only
+                     meaningful for multilabel classification where this differs from
+                     :func:`accuracy_score`).
+         """
+         super().__init__(confidence_threshold)
+         assert (
+             f1_method in F1_METHODS
+         ), f"Invalid f1_method {f1_method}, expected one of {F1_METHODS}"
+         self.f1_method = f1_method
+
+     def eval(
+         self,
+         annotations: List[CategoryAnnotation],
+         predictions: List[CategoryPrediction],
+     ) -> CategorizationResult:
+         """
+         Notes: This eval function is a bit unusual. It essentially only matches annotations to labels;
+         the actual metric computation happens in the aggregate step, since an F1 score only makes sense on a collection.
+         """
+
+         return CategorizationResult(
+             annotations=annotations, predictions=predictions
+         )
+
+     def aggregate_score(self, results: List[CategorizationResult]) -> ScalarResult:  # type: ignore[override]
+         gt = []
+         predicted = []
+         for result in results:
+             gt.extend(list(to_taxonomy_labels(result.annotations)))
+             predicted.extend(list(to_taxonomy_labels(result.predictions)))
+         value = f1_score(gt, predicted, average=self.f1_method)
+         return ScalarResult(value)
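
The taxonomy join used throughout CategorizationF1 is plain string concatenation. A sketch with a namedtuple stand-in, since the helper only reads .label and .taxonomy_name; real code would pass CategoryAnnotation/CategoryPrediction objects, and importing this module assumes scikit-learn is installed::

    from collections import namedtuple

    from nucleus.metrics.categorization_metrics import to_taxonomy_labels

    # Stand-in object exposing only the two attributes the helper uses.
    Fake = namedtuple("Fake", ["label", "taxonomy_name"])

    labels = to_taxonomy_labels(
        [Fake("sedan", "vehicle_type"), Fake("night", None)]
    )
    print(labels)  # {'vehicle_type->sedan', 'night'}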
nucleus/metrics/errors.py ADDED
@@ -0,0 +1,7 @@
+ class PolygonAnnotationTypeError(Exception):
+     def __init__(
+         self,
+         message="Annotation was expected to be of type 'BoxAnnotation' or 'PolygonAnnotation'.",
+     ):
+         self.message = message
+         super().__init__(self.message)
nucleus/metrics/filters.py ADDED
@@ -0,0 +1,40 @@
+ from typing import List
+
+ from nucleus.prediction import PredictionList
+
+ from .polygon_utils import (
+     BoxOrPolygonAnnoOrPred,
+     polygon_annotation_to_geometry,
+ )
+
+
+ def polygon_area_filter(
+     polygons: List[BoxOrPolygonAnnoOrPred], min_area: float, max_area: float
+ ) -> List[BoxOrPolygonAnnoOrPred]:
+     filter_fn = (
+         lambda polygon: min_area
+         <= polygon_annotation_to_geometry(polygon).signed_area
+         <= max_area
+     )
+     return list(filter(filter_fn, polygons))
+
+
+ def confidence_filter(
+     predictions: PredictionList, min_confidence: float
+ ) -> PredictionList:
+     predictions_copy = PredictionList()
+     filter_fn = (
+         lambda prediction: not hasattr(prediction, "confidence")
+         or prediction.confidence >= min_confidence
+     )
+     for attr in predictions.__dict__:
+         predictions_copy.__dict__[attr] = list(
+             filter(filter_fn, predictions.__dict__[attr])
+         )
+     return predictions_copy
+
+
+ def polygon_label_filter(
+     polygons: List[BoxOrPolygonAnnoOrPred], label: str
+ ) -> List[BoxOrPolygonAnnoOrPred]:
+     return list(filter(lambda polygon: polygon.label == label, polygons))
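
Note that confidence_filter builds a fresh PredictionList rather than mutating its input. A sketch, assuming BoxPrediction accepts the same keyword arguments shown in the base.py docstring plus confidence, and that PredictionList has a box_predictions field::

    from nucleus import BoxPrediction
    from nucleus.metrics.filters import confidence_filter
    from nucleus.prediction import PredictionList

    high = BoxPrediction(label="car", x=0, y=0, width=10, height=10,
                         reference_id="image_1", confidence=0.9)
    low = BoxPrediction(label="car", x=5, y=5, width=10, height=10,
                        reference_id="image_1", confidence=0.2)

    predictions = PredictionList(box_predictions=[high, low])
    filtered = confidence_filter(predictions, min_confidence=0.5)
    # filtered.box_predictions keeps only the 0.9-confidence prediction
    # (and anything without a confidence attribute); predictions is unchanged.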