valor-lite 0.37.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of valor-lite might be problematic. Click here for more details.

Files changed (49) hide show
  1. valor_lite/LICENSE +21 -0
  2. valor_lite/__init__.py +0 -0
  3. valor_lite/cache/__init__.py +11 -0
  4. valor_lite/cache/compute.py +154 -0
  5. valor_lite/cache/ephemeral.py +302 -0
  6. valor_lite/cache/persistent.py +529 -0
  7. valor_lite/classification/__init__.py +14 -0
  8. valor_lite/classification/annotation.py +45 -0
  9. valor_lite/classification/computation.py +378 -0
  10. valor_lite/classification/evaluator.py +879 -0
  11. valor_lite/classification/loader.py +97 -0
  12. valor_lite/classification/metric.py +535 -0
  13. valor_lite/classification/numpy_compatibility.py +13 -0
  14. valor_lite/classification/shared.py +184 -0
  15. valor_lite/classification/utilities.py +314 -0
  16. valor_lite/exceptions.py +20 -0
  17. valor_lite/object_detection/__init__.py +17 -0
  18. valor_lite/object_detection/annotation.py +238 -0
  19. valor_lite/object_detection/computation.py +841 -0
  20. valor_lite/object_detection/evaluator.py +805 -0
  21. valor_lite/object_detection/loader.py +292 -0
  22. valor_lite/object_detection/metric.py +850 -0
  23. valor_lite/object_detection/shared.py +185 -0
  24. valor_lite/object_detection/utilities.py +396 -0
  25. valor_lite/schemas.py +11 -0
  26. valor_lite/semantic_segmentation/__init__.py +15 -0
  27. valor_lite/semantic_segmentation/annotation.py +123 -0
  28. valor_lite/semantic_segmentation/computation.py +165 -0
  29. valor_lite/semantic_segmentation/evaluator.py +414 -0
  30. valor_lite/semantic_segmentation/loader.py +205 -0
  31. valor_lite/semantic_segmentation/metric.py +275 -0
  32. valor_lite/semantic_segmentation/shared.py +149 -0
  33. valor_lite/semantic_segmentation/utilities.py +88 -0
  34. valor_lite/text_generation/__init__.py +15 -0
  35. valor_lite/text_generation/annotation.py +56 -0
  36. valor_lite/text_generation/computation.py +611 -0
  37. valor_lite/text_generation/llm/__init__.py +0 -0
  38. valor_lite/text_generation/llm/exceptions.py +14 -0
  39. valor_lite/text_generation/llm/generation.py +903 -0
  40. valor_lite/text_generation/llm/instructions.py +814 -0
  41. valor_lite/text_generation/llm/integrations.py +226 -0
  42. valor_lite/text_generation/llm/utilities.py +43 -0
  43. valor_lite/text_generation/llm/validators.py +68 -0
  44. valor_lite/text_generation/manager.py +697 -0
  45. valor_lite/text_generation/metric.py +381 -0
  46. valor_lite-0.37.1.dist-info/METADATA +174 -0
  47. valor_lite-0.37.1.dist-info/RECORD +49 -0
  48. valor_lite-0.37.1.dist-info/WHEEL +5 -0
  49. valor_lite-0.37.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,123 @@
1
+ import warnings
2
+ from dataclasses import dataclass, field
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+ from numpy.typing import NDArray
7
+
8
+
9
+ @dataclass
10
+ class Bitmask:
11
+ """
12
+ Represents a binary mask with an associated semantic label.
13
+
14
+ Parameters
15
+ ----------
16
+ mask : NDArray[np.bool_]
17
+ A NumPy array of boolean values representing the mask.
18
+ label : str
19
+ The semantic label associated with the mask.
20
+ metadata : dict[str, Any], optional
21
+ A dictionary containing any metadata to be used within filtering operations.
22
+
23
+ Examples
24
+ --------
25
+ >>> import numpy as np
26
+ >>> mask = np.array([[True, False], [False, True]], dtype=np.bool_)
27
+ >>> bitmask = Bitmask(mask=mask, label='ocean')
28
+ """
29
+
30
+ mask: NDArray[np.bool_]
31
+ label: str
32
+ metadata: dict[str, Any] | None = None
33
+
34
+ def __post_init__(self):
35
+ if self.mask.dtype != np.bool_:
36
+ raise ValueError(
37
+ f"Bitmask recieved mask with dtype '{self.mask.dtype}'."
38
+ )
39
+
40
+
41
@dataclass
class Segmentation:
    """
    Segmentation data structure holding ground truth and prediction bitmasks for semantic segmentation tasks.

    Parameters
    ----------
    uid : str
        Unique identifier for the image or sample.
    groundtruths : list[Bitmask]
        List of ground truth bitmasks.
    predictions : list[Bitmask]
        List of predicted bitmasks.
    shape : tuple of int
        The (height, width) shape shared by every mask. Required; must be
        2-dimensional with positive dimensions.
    size : int, optional
        The total number of pixels in the masks. This is set automatically after initialization.
    metadata : dict[str, Any], optional
        A dictionary containing any metadata to be used within filtering operations.

    Raises
    ------
    ValueError
        If `shape` is not a valid 2-D shape, if any element is not a
        `Bitmask`, or if any mask's shape does not match `shape`.

    Examples
    --------
    >>> import numpy as np
    >>> mask1 = np.array([[True, False], [False, True]], dtype=np.bool_)
    >>> groundtruth = Bitmask(mask=mask1, label='object')
    >>> mask2 = np.array([[False, True], [True, False]], dtype=np.bool_)
    >>> prediction = Bitmask(mask=mask2, label='object')
    >>> segmentation = Segmentation(
    ...     uid='123',
    ...     groundtruths=[groundtruth],
    ...     predictions=[prediction],
    ...     shape=(2, 2),
    ... )
    """

    uid: str
    groundtruths: list[Bitmask]
    predictions: list[Bitmask]
    shape: tuple[int, ...]
    size: int = field(default=0)
    metadata: dict[str, Any] | None = None

    def __post_init__(self):

        if len(self.shape) != 2 or self.shape[0] <= 0 or self.shape[1] <= 0:
            raise ValueError(
                # BUGFIX: corrected misspelling "Recieved" -> "Received".
                f"segmentations must be 2-dimensional and have non-zero dimensions. Received shape '{self.shape}'"
            )
        self.size = self.shape[0] * self.shape[1]

        self._validate_bitmasks(self.groundtruths, "ground truth")
        self._validate_bitmasks(self.predictions, "prediction")

    def _validate_bitmasks(self, bitmasks: list[Bitmask], key: str):
        """
        Validate bitmask types and shapes, and make masks mutually exclusive.

        Overlapping pixels are removed **in place** from the later mask; a
        warning reporting the total number of overlapped pixels is emitted
        when any overlap was found.

        Parameters
        ----------
        bitmasks : list[Bitmask]
            The masks to validate (mutated in place when overlaps exist).
        key : str
            Either "ground truth" or "prediction"; used in messages only.
        """
        mask_accumulation = None
        mask_overlap_accumulation = None
        for idx, bitmask in enumerate(bitmasks):
            if not isinstance(bitmask, Bitmask):
                raise ValueError(f"expected 'Bitmask', got '{bitmask}'")
            if self.shape != bitmask.mask.shape:
                raise ValueError(
                    f"{key} masks for datum '{self.uid}' should have shape '{self.shape}'. Received mask with shape '{bitmask.mask.shape}'"
                )

            if mask_accumulation is None:
                mask_accumulation = bitmask.mask.copy()
                mask_overlap_accumulation = np.zeros_like(mask_accumulation)
            elif np.logical_and(mask_accumulation, bitmask.mask).any():
                mask_overlap = np.logical_and(mask_accumulation, bitmask.mask)
                # Trim the overlapping pixels out of the later mask so that
                # masks are mutually exclusive after validation.
                bitmasks[idx].mask[mask_overlap] = False
                mask_overlap_accumulation = (
                    mask_overlap_accumulation | mask_overlap
                )
                # BUGFIX: previously the trimmed mask was never merged into
                # the accumulator, so an overlap between this mask's
                # remaining pixels and a *later* mask went undetected.
                mask_accumulation = mask_accumulation | bitmasks[idx].mask
            else:
                mask_accumulation = mask_accumulation | bitmask.mask
        if (
            mask_overlap_accumulation is not None
            and mask_overlap_accumulation.any()
        ):
            count = mask_overlap_accumulation.sum()
            total = mask_overlap_accumulation.size
            warnings.warn(
                f"{key} masks for datum '{self.uid}' had {count} / {total} pixels overlapped."
            )
@@ -0,0 +1,165 @@
1
+ import numpy as np
2
+ from numpy.typing import NDArray
3
+
4
+
5
def compute_intermediates(
    groundtruths: NDArray[np.bool_],
    predictions: NDArray[np.bool_],
    groundtruth_labels: NDArray[np.int64],
    prediction_labels: NDArray[np.int64],
    n_labels: int,
) -> NDArray[np.uint64]:
    """
    Build an intermediate confusion matrix of pixel counts.

    Row/column 0 holds the "unmatched" bucket: row 0 counts prediction
    pixels with no ground truth, column 0 counts ground truth pixels with
    no prediction, and cell (0, 0) counts pixels covered by neither.

    Parameters
    ----------
    groundtruths : NDArray[np.bool_]
        A 2-D array containing flattened bitmasks for each label.
    predictions : NDArray[np.bool_]
        A 2-D array containing flattened bitmasks for each label.
    groundtruth_labels : NDArray[np.int64]
        A 1-D array containing ground truth label indices.
    prediction_labels : NDArray[np.int64]
        A 1-D array containing prediction label indices.
    n_labels : int
        The number of unique labels.

    Returns
    -------
    NDArray[np.uint64]
        A 2-D confusion matrix with shape (n_labels + 1, n_labels + 1).
    """
    # Pixels per ground-truth / prediction mask.
    gt_pixel_totals = groundtruths.sum(axis=1)
    pd_pixel_totals = predictions.sum(axis=1)

    # Pixels covered by neither annotation source form the background bucket.
    covered = groundtruths.any(axis=0) | predictions.any(axis=0)
    background_total = np.logical_not(covered).sum()

    # Pairwise overlap between every ground-truth and prediction mask.
    overlap = np.logical_and(
        groundtruths[:, None, :],
        predictions[None, :, :],
    ).sum(axis=2)
    gt_overlap_totals = overlap.sum(axis=1)
    pd_overlap_totals = overlap.sum(axis=0)

    matrix = np.zeros((n_labels + 1, n_labels + 1), dtype=np.uint64)
    matrix[0, 0] = background_total

    # Shift label indices by one to make room for the unmatched bucket.
    gt_rows = groundtruth_labels + 1
    pd_cols = prediction_labels + 1
    matrix[np.ix_(gt_rows, pd_cols)] = overlap
    matrix[0, pd_cols] = pd_pixel_totals - pd_overlap_totals
    matrix[gt_rows, 0] = gt_pixel_totals - gt_overlap_totals
    return matrix
60
+
61
+
62
def compute_metrics(
    confusion_matrix: NDArray[np.uint64],
) -> tuple[
    NDArray[np.float64],
    NDArray[np.float64],
    NDArray[np.float64],
    float,
    NDArray[np.float64],
    NDArray[np.float64],
    NDArray[np.float64],
]:
    """
    Computes semantic segmentation metrics from a confusion matrix.

    Parameters
    ----------
    confusion_matrix : NDArray[np.uint64]
        A 2-D confusion matrix with shape (n_labels + 1, n_labels + 1),
        where row/column 0 is the unmatched (background) bucket.

    Returns
    -------
    NDArray[np.float64]
        Precision.
    NDArray[np.float64]
        Recall.
    NDArray[np.float64]
        F1 Score.
    float
        Accuracy
    NDArray[np.float64]
        Confusion matrix containing IOU values.
    NDArray[np.float64]
        Unmatched prediction ratios.
    NDArray[np.float64]
        Unmatched ground truth ratios.
    """
    n_labels = confusion_matrix.shape[0] - 1
    total_pixel_count = confusion_matrix.sum()

    # Per-label totals; exclude the unmatched row/column from the axis
    # being summed over so the background bucket is not double counted.
    gt_totals = confusion_matrix[1:, :].sum(axis=1)
    pd_totals = confusion_matrix[:, 1:].sum(axis=0)

    # IOU per (ground truth label, prediction label) pair. Zero-filled
    # outputs combined with `where=` guards keep divisions by zero at 0.0.
    overlap = confusion_matrix[1:, 1:]
    union = gt_totals[:, np.newaxis] + pd_totals[np.newaxis, :] - overlap
    ious = np.zeros((n_labels, n_labels), dtype=np.float64)
    np.divide(overlap, union, where=union > 1e-9, out=ious)

    # Ratio of prediction pixels with no matching ground truth.
    unmatched_prediction_ratio = np.zeros(n_labels, dtype=np.float64)
    np.divide(
        confusion_matrix[0, 1:],
        pd_totals,
        where=pd_totals > 1e-9,
        out=unmatched_prediction_ratio,
    )

    # Ratio of ground truth pixels with no matching prediction.
    unmatched_ground_truth_ratio = np.zeros(n_labels, dtype=np.float64)
    np.divide(
        confusion_matrix[1:, 0],
        gt_totals,
        where=gt_totals > 1e-9,
        out=unmatched_ground_truth_ratio,
    )

    # Precision / recall / F1 from the diagonal (true positives).
    true_positives = confusion_matrix.diagonal()[1:]

    precision = np.zeros(n_labels, dtype=np.float64)
    np.divide(
        true_positives, pd_totals, where=pd_totals > 1e-9, out=precision
    )

    recall = np.zeros_like(precision)
    np.divide(true_positives, gt_totals, where=gt_totals > 1e-9, out=recall)

    f1_score = np.zeros_like(precision)
    denominator = precision + recall
    np.divide(
        2.0 * precision * recall,
        denominator,
        where=denominator > 0,
        out=f1_score,
    )

    # Accuracy counts correctly labeled pixels plus true background.
    background_count = confusion_matrix[0, 0]
    accuracy = (
        (true_positives.sum() + background_count) / total_pixel_count
        if total_pixel_count > 0
        else 0.0
    )

    return (
        precision,
        recall,
        f1_score,
        accuracy,
        ious,
        unmatched_prediction_ratio,
        unmatched_ground_truth_ratio,
    )
@@ -0,0 +1,414 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pyarrow as pa
8
+ import pyarrow.compute as pc
9
+ from numpy.typing import NDArray
10
+
11
+ from valor_lite.cache import (
12
+ FileCacheReader,
13
+ FileCacheWriter,
14
+ MemoryCacheReader,
15
+ MemoryCacheWriter,
16
+ )
17
+ from valor_lite.exceptions import EmptyCacheError
18
+ from valor_lite.semantic_segmentation.computation import compute_metrics
19
+ from valor_lite.semantic_segmentation.metric import MetricType
20
+ from valor_lite.semantic_segmentation.shared import (
21
+ EvaluatorInfo,
22
+ decode_metadata_fields,
23
+ encode_metadata_fields,
24
+ extract_counts,
25
+ extract_labels,
26
+ generate_cache_path,
27
+ generate_metadata_path,
28
+ generate_schema,
29
+ )
30
+ from valor_lite.semantic_segmentation.utilities import (
31
+ unpack_precision_recall_iou_into_metric_lists,
32
+ )
33
+
34
+
35
class Builder:
    """
    Accumulates semantic segmentation rows into a cache and turns the
    finished cache into an `Evaluator`.
    """

    def __init__(
        self,
        writer: MemoryCacheWriter | FileCacheWriter,
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        # Writer backing the cache (in-memory or file-based).
        self._writer = writer
        # Optional metadata field definitions carried through to the evaluator.
        self._metadata_fields = metadata_fields

    @classmethod
    def in_memory(
        cls,
        batch_size: int = 10_000,
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        """
        Create an in-memory evaluator cache.

        Parameters
        ----------
        batch_size : int, default=10_000
            The target number of rows to buffer before writing to the cache. Defaults to 10_000.
        metadata_fields : list[tuple[str, str | pa.DataType]], optional
            Optional metadata field definitions.
        """
        memory_writer = MemoryCacheWriter.create(
            schema=generate_schema(metadata_fields),
            batch_size=batch_size,
        )
        return cls(writer=memory_writer, metadata_fields=metadata_fields)

    @classmethod
    def persistent(
        cls,
        path: str | Path,
        batch_size: int = 10_000,
        rows_per_file: int = 100_000,
        compression: str = "snappy",
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        """
        Create a persistent file-based evaluator cache.

        Parameters
        ----------
        path : str | Path
            Where to store the file-based cache.
        batch_size : int, default=10_000
            The target number of rows to buffer before writing to the cache. Defaults to 10_000.
        rows_per_file : int, default=100_000
            The target number of rows to store per cache file. Defaults to 100_000.
        compression : str, default="snappy"
            The compression methods used when writing cache files.
        metadata_fields : list[tuple[str, str | pa.DataType]], optional
            Optional metadata field definitions.
        """
        root = Path(path)

        # Create the on-disk cache first, then persist the metadata field
        # definitions next to it so a later `Evaluator.load` can recover them.
        file_writer = FileCacheWriter.create(
            path=generate_cache_path(root),
            schema=generate_schema(metadata_fields),
            batch_size=batch_size,
            rows_per_file=rows_per_file,
            compression=compression,
        )

        with open(generate_metadata_path(root), "w") as f:
            json.dump(encode_metadata_fields(metadata_fields), f, indent=2)

        return cls(writer=file_writer, metadata_fields=metadata_fields)

    def finalize(
        self,
        index_to_label_override: dict[int, str] | None = None,
    ):
        """
        Performs data finalization and some preprocessing steps.

        Parameters
        ----------
        index_to_label_override : dict[int, str], optional
            Pre-configures label mapping. Used when operating over filtered subsets.

        Returns
        -------
        Evaluator
            A ready-to-use evaluator object.

        Raises
        ------
        EmptyCacheError
            If no rows were written to the cache.
        """
        self._writer.flush()
        if not self._writer.count_rows():
            raise EmptyCacheError()

        cache_reader = self._writer.to_reader()

        # Derive the label mapping from the cache unless one was provided.
        labels = extract_labels(
            reader=cache_reader,
            index_to_label_override=index_to_label_override,
        )

        return Evaluator(
            reader=cache_reader,
            index_to_label=labels,
            metadata_fields=self._metadata_fields,
        )
151
+
152
+
153
class Evaluator:
    """
    Computes semantic segmentation metrics from a pre-built row cache.

    Rows are read through a `MemoryCacheReader` or `FileCacheReader` and
    carry at least the columns ``datum_id``, ``gt_label_id``, ``pd_label_id``
    and ``count``; a label id of ``-1`` denotes the unmatched/background
    bucket.
    """

    def __init__(
        self,
        reader: MemoryCacheReader | FileCacheReader,
        index_to_label: dict[int, str],
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        # Cache reader supplying row batches.
        self._reader = reader
        # Maps label index (as stored in the cache) to label string.
        self._index_to_label = index_to_label
        # Optional metadata field definitions (used when re-building caches).
        self._metadata_fields = metadata_fields

    @property
    def info(self) -> EvaluatorInfo:
        # Convenience accessor: unfiltered evaluator statistics.
        return self.get_info()

    def get_info(
        self,
        datums: pc.Expression | None = None,
        groundtruths: pc.Expression | None = None,
        predictions: pc.Expression | None = None,
    ) -> EvaluatorInfo:
        """
        Collect summary statistics about the cache, optionally filtered.

        Parameters
        ----------
        datums : pc.Expression, optional
            Filter expression applied to datums.
        groundtruths : pc.Expression, optional
            Filter expression applied to ground truth annotations.
        predictions : pc.Expression, optional
            Filter expression applied to predictions.

        Returns
        -------
        EvaluatorInfo
            Populated statistics object.
        """
        info = EvaluatorInfo()
        # NOTE: number_of_rows is always unfiltered; only the pixel/datum
        # counts below honor the filter expressions.
        info.number_of_rows = self._reader.count_rows()
        info.number_of_labels = len(self._index_to_label)
        info.metadata_fields = self._metadata_fields
        (
            info.number_of_datums,
            info.number_of_pixels,
            info.number_of_groundtruth_pixels,
            info.number_of_prediction_pixels,
        ) = extract_counts(
            reader=self._reader,
            datums=datums,
            groundtruths=groundtruths,
            predictions=predictions,
        )
        return info

    @classmethod
    def load(
        cls,
        path: str | Path,
        index_to_label_override: dict[int, str] | None = None,
    ):
        """
        Load from an existing semantic segmentation cache.

        Parameters
        ----------
        path : str | Path
            Path to the existing cache.
        index_to_label_override : dict[int, str], optional
            Option to preset index to label dictionary. Used when loading from filtered caches.

        Raises
        ------
        FileNotFoundError
            If `path` does not exist.
        NotADirectoryError
            If `path` exists but is not a directory.
        """
        # validate path
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Directory does not exist: {path}")
        elif not path.is_dir():
            raise NotADirectoryError(
                f"Path exists but is not a directory: {path}"
            )

        # load cache
        reader = FileCacheReader.load(generate_cache_path(path))

        # extract labels
        index_to_label = extract_labels(
            reader=reader,
            index_to_label_override=index_to_label_override,
        )

        # read config — metadata field definitions written by
        # Builder.persistent alongside the cache files.
        metadata_path = generate_metadata_path(path)
        metadata_fields = None
        with open(metadata_path, "r") as f:
            metadata_types = json.load(f)
            metadata_fields = decode_metadata_fields(metadata_types)

        return cls(
            reader=reader,
            index_to_label=index_to_label,
            metadata_fields=metadata_fields,
        )

    def filter(
        self,
        datums: pc.Expression | None = None,
        groundtruths: pc.Expression | None = None,
        predictions: pc.Expression | None = None,
        path: str | Path | None = None,
    ) -> Evaluator:
        """
        Filter evaluator cache.

        Rows whose ground truth (or prediction) does not survive the filter
        have the corresponding label id rewritten to -1 (unmatched) rather
        than being dropped, so pixel counts are preserved.

        Parameters
        ----------
        datums : pc.Expression | None = None
            A filter expression used to filter datums.
        groundtruths : pc.Expression | None = None
            A filter expression used to filter ground truth annotations.
        predictions : pc.Expression | None = None
            A filter expression used to filter predictions.
        path : str | Path, optional
            Where to store the filtered cache if storing on disk.

        Returns
        -------
        Evaluator
            A new evaluator object containing the filtered cache.

        Raises
        ------
        ValueError
            If the cache is file-based and no `path` was supplied.
        """
        # Mirror the source cache type: file-based caches are filtered into
        # a new on-disk cache, in-memory caches into a new in-memory one.
        if isinstance(self._reader, FileCacheReader):
            if not path:
                raise ValueError(
                    "expected path to be defined for file-based cache"
                )
            builder = Builder.persistent(
                path=path,
                batch_size=self._reader.batch_size,
                rows_per_file=self._reader.rows_per_file,
                compression=self._reader.compression,
                metadata_fields=self.info.metadata_fields,
            )
        else:
            builder = Builder.in_memory(
                batch_size=self._reader.batch_size,
                metadata_fields=self.info.metadata_fields,
            )

        for tbl in self._reader.iterate_tables(filter=datums):
            columns = (
                "datum_id",
                "gt_label_id",
                "pd_label_id",
            )
            # One row per (datum, gt label, pd label) pairing.
            pairs = np.column_stack([tbl[col].to_numpy() for col in columns])

            n_pairs = pairs.shape[0]
            gt_ids = pairs[:, (0, 1)].astype(np.int64)
            pd_ids = pairs[:, (0, 2)].astype(np.int64)

            if groundtruths is not None:
                # Mark rows whose (datum, gt label) pair survives the filter.
                mask_valid_gt = np.zeros(n_pairs, dtype=np.bool_)
                gt_tbl = tbl.filter(groundtruths)
                gt_pairs = np.column_stack(
                    [
                        gt_tbl[col].to_numpy()
                        for col in ("datum_id", "gt_label_id")
                    ]
                ).astype(np.int64)
                for gt in np.unique(gt_pairs, axis=0):
                    mask_valid_gt |= (gt_ids == gt).all(axis=1)
            else:
                mask_valid_gt = np.ones(n_pairs, dtype=np.bool_)

            if predictions is not None:
                # Same marking for (datum, pd label) pairs.
                mask_valid_pd = np.zeros(n_pairs, dtype=np.bool_)
                pd_tbl = tbl.filter(predictions)
                pd_pairs = np.column_stack(
                    [
                        pd_tbl[col].to_numpy()
                        for col in ("datum_id", "pd_label_id")
                    ]
                ).astype(np.int64)
                for pd in np.unique(pd_pairs, axis=0):
                    mask_valid_pd |= (pd_ids == pd).all(axis=1)
            else:
                mask_valid_pd = np.ones(n_pairs, dtype=np.bool_)

            # NOTE(review): `x &= (x | y)` is an identity, so the two `&=`
            # lines below never change either mask — possibly a leftover
            # from an earlier scheme for dropping fully-invalid rows.
            # Confirm the intent before removing.
            mask_valid = mask_valid_gt | mask_valid_pd
            mask_valid_gt &= mask_valid
            mask_valid_pd &= mask_valid

            # Rewrite filtered-out annotations to the unmatched bucket (-1).
            pairs[~mask_valid_gt, 1] = -1
            pairs[~mask_valid_pd, 2] = -1

            # Write the rewritten id columns back into the table.
            # NOTE(review): pa.array infers its type from the numpy dtype of
            # `pairs`; verify this matches the cache schema's column types.
            for idx, col in enumerate(columns):
                tbl = tbl.set_column(
                    tbl.schema.names.index(col), col, pa.array(pairs[:, idx])
                )
            builder._writer.write_table(tbl)

        # Keep the original label mapping so indices stay stable after filtering.
        return builder.finalize(index_to_label_override=self._index_to_label)

    def _compute_confusion_matrix_intermediate(
        self, datums: pc.Expression | None = None
    ) -> NDArray[np.uint64]:
        """
        Accumulate a pixel-count confusion matrix from the cached rows.

        Parameters
        ----------
        datums : pyarrow.compute.Expression, optional
            Option to filter datums by an expression.

        Returns
        -------
        NDArray[np.uint64]
            A (n_labels + 1, n_labels + 1) confusion matrix where
            row/column 0 is the unmatched (label id -1) bucket.
        """
        n_labels = len(self._index_to_label)
        confusion_matrix = np.zeros(
            (n_labels + 1, n_labels + 1), dtype=np.uint64
        )
        for tbl in self._reader.iterate_tables(filter=datums):
            columns = (
                "datum_id",
                "gt_label_id",
                "pd_label_id",
            )
            ids = np.column_stack(
                [tbl[col].to_numpy() for col in columns]
            ).astype(np.int64)
            # Pixel count associated with each (datum, gt, pd) row.
            counts = tbl["count"].to_numpy()

            # Label id -1 marks the unmatched/background bucket.
            mask_null_gts = ids[:, 1] == -1
            mask_null_pds = ids[:, 2] == -1
            confusion_matrix[0, 0] += counts[
                mask_null_gts & mask_null_pds
            ].sum()
            for idx in range(n_labels):
                mask_gts = ids[:, 1] == idx
                for pidx in range(n_labels):
                    mask_pds = ids[:, 2] == pidx
                    confusion_matrix[idx + 1, pidx + 1] += counts[
                        mask_gts & mask_pds
                    ].sum()

                # Ground truth pixels with no matching prediction.
                mask_unmatched_gts = mask_gts & mask_null_pds
                confusion_matrix[idx + 1, 0] += counts[
                    mask_unmatched_gts
                ].sum()
                # Prediction pixels with no matching ground truth.
                mask_unmatched_pds = mask_null_gts & (ids[:, 2] == idx)
                confusion_matrix[0, idx + 1] += counts[
                    mask_unmatched_pds
                ].sum()
        return confusion_matrix

    def compute_precision_recall_iou(
        self, datums: pc.Expression | None = None
    ) -> dict[MetricType, list]:
        """
        Performs an evaluation and returns metrics.

        Parameters
        ----------
        datums : pyarrow.compute.Expression, optional
            Option to filter datums by an expression.

        Returns
        -------
        dict[MetricType, list]
            A dictionary mapping MetricType enumerations to lists of computed metrics.
        """
        confusion_matrix = self._compute_confusion_matrix_intermediate(
            datums=datums
        )
        results = compute_metrics(confusion_matrix=confusion_matrix)
        return unpack_precision_recall_iou_into_metric_lists(
            results=results,
            index_to_label=self._index_to_label,
        )