valor-lite 0.36.6__py3-none-any.whl → 0.37.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- valor_lite/cache/__init__.py +11 -0
- valor_lite/cache/compute.py +211 -0
- valor_lite/cache/ephemeral.py +302 -0
- valor_lite/cache/persistent.py +536 -0
- valor_lite/classification/__init__.py +5 -10
- valor_lite/classification/annotation.py +4 -0
- valor_lite/classification/computation.py +233 -251
- valor_lite/classification/evaluator.py +882 -0
- valor_lite/classification/loader.py +97 -0
- valor_lite/classification/metric.py +141 -4
- valor_lite/classification/shared.py +184 -0
- valor_lite/classification/utilities.py +221 -118
- valor_lite/exceptions.py +5 -0
- valor_lite/object_detection/__init__.py +5 -4
- valor_lite/object_detection/annotation.py +13 -1
- valor_lite/object_detection/computation.py +368 -299
- valor_lite/object_detection/evaluator.py +804 -0
- valor_lite/object_detection/loader.py +292 -0
- valor_lite/object_detection/metric.py +152 -3
- valor_lite/object_detection/shared.py +206 -0
- valor_lite/object_detection/utilities.py +182 -100
- valor_lite/semantic_segmentation/__init__.py +5 -4
- valor_lite/semantic_segmentation/annotation.py +7 -0
- valor_lite/semantic_segmentation/computation.py +20 -110
- valor_lite/semantic_segmentation/evaluator.py +414 -0
- valor_lite/semantic_segmentation/loader.py +205 -0
- valor_lite/semantic_segmentation/shared.py +149 -0
- valor_lite/semantic_segmentation/utilities.py +6 -23
- {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/METADATA +3 -1
- valor_lite-0.37.5.dist-info/RECORD +49 -0
- {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/WHEEL +1 -1
- valor_lite/classification/manager.py +0 -545
- valor_lite/object_detection/manager.py +0 -864
- valor_lite/profiling.py +0 -374
- valor_lite/semantic_segmentation/benchmark.py +0 -237
- valor_lite/semantic_segmentation/manager.py +0 -446
- valor_lite-0.36.6.dist-info/RECORD +0 -41
- {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,804 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pyarrow as pa
|
|
6
|
+
import pyarrow.compute as pc
|
|
7
|
+
|
|
8
|
+
from valor_lite.cache import (
|
|
9
|
+
FileCacheReader,
|
|
10
|
+
FileCacheWriter,
|
|
11
|
+
MemoryCacheReader,
|
|
12
|
+
MemoryCacheWriter,
|
|
13
|
+
compute,
|
|
14
|
+
)
|
|
15
|
+
from valor_lite.exceptions import EmptyCacheError
|
|
16
|
+
from valor_lite.object_detection.computation import (
|
|
17
|
+
compute_average_precision,
|
|
18
|
+
compute_average_recall,
|
|
19
|
+
compute_confusion_matrix,
|
|
20
|
+
compute_counts,
|
|
21
|
+
compute_pair_classifications,
|
|
22
|
+
compute_precision_recall_f1,
|
|
23
|
+
rank_table,
|
|
24
|
+
)
|
|
25
|
+
from valor_lite.object_detection.metric import Metric, MetricType
|
|
26
|
+
from valor_lite.object_detection.shared import (
|
|
27
|
+
EvaluatorInfo,
|
|
28
|
+
decode_metadata_fields,
|
|
29
|
+
encode_metadata_fields,
|
|
30
|
+
extract_counts,
|
|
31
|
+
extract_groundtruth_count_per_label,
|
|
32
|
+
extract_labels,
|
|
33
|
+
generate_detailed_cache_path,
|
|
34
|
+
generate_detailed_schema,
|
|
35
|
+
generate_metadata_path,
|
|
36
|
+
generate_ranked_cache_path,
|
|
37
|
+
generate_ranked_schema,
|
|
38
|
+
)
|
|
39
|
+
from valor_lite.object_detection.utilities import (
|
|
40
|
+
create_empty_confusion_matrix_with_examples,
|
|
41
|
+
create_mapping,
|
|
42
|
+
unpack_confusion_matrix,
|
|
43
|
+
unpack_confusion_matrix_with_examples,
|
|
44
|
+
unpack_examples,
|
|
45
|
+
unpack_precision_recall_into_metric_lists,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class Builder:
    """
    Collects object detection pairs into a detailed cache and, when
    finalized, derives a ranked cache and returns an `Evaluator` over both.
    """

    # Sort order applied to the detailed cache and used when ranking pairs.
    _PAIR_ORDER = (
        ("pd_score", "descending"),
        ("iou", "descending"),
    )

    def __init__(
        self,
        detailed_writer: MemoryCacheWriter | FileCacheWriter,
        ranked_writer: MemoryCacheWriter | FileCacheWriter,
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        """
        Parameters
        ----------
        detailed_writer : MemoryCacheWriter | FileCacheWriter
            Writer that receives the detailed (unranked) pairs.
        ranked_writer : MemoryCacheWriter | FileCacheWriter
            Writer that receives the ranked pairs during finalization.
        metadata_fields : list[tuple[str, str | pa.DataType]], optional
            Optional datum metadata field definitions.
        """
        self._detailed_writer = detailed_writer
        self._ranked_writer = ranked_writer
        self._metadata_fields = metadata_fields

    @classmethod
    def in_memory(
        cls,
        batch_size: int = 10_000,
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        """
        Build an evaluator cache that lives entirely in memory.

        Parameters
        ----------
        batch_size : int, default=10_000
            Number of rows buffered before each write into the cache.
        metadata_fields : list[tuple[str, str | pa.DataType]], optional
            Optional datum metadata field definitions.
        """
        return cls(
            detailed_writer=MemoryCacheWriter.create(
                schema=generate_detailed_schema(metadata_fields),
                batch_size=batch_size,
            ),
            ranked_writer=MemoryCacheWriter.create(
                schema=generate_ranked_schema(metadata_fields),
                batch_size=batch_size,
            ),
            metadata_fields=metadata_fields,
        )

    @classmethod
    def persistent(
        cls,
        path: str | Path,
        batch_size: int = 10_000,
        rows_per_file: int = 100_000,
        compression: str = "snappy",
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        """
        Build an evaluator cache backed by files under `path`.

        Parameters
        ----------
        path : str | Path
            Directory that will hold the detailed cache, the ranked cache
            and the metadata-field definitions.
        batch_size : int, default=10_000
            Number of rows buffered before each write into the cache.
        rows_per_file : int, default=100_000
            Target number of rows stored in each cache file.
        compression : str, default="snappy"
            Compression codec used when writing cache files.
        metadata_fields : list[tuple[str, str | pa.DataType]], optional
            Optional metadata field definitions.
        """
        root = Path(path)
        file_options = {
            "batch_size": batch_size,
            "rows_per_file": rows_per_file,
            "compression": compression,
        }

        detailed = FileCacheWriter.create(
            path=generate_detailed_cache_path(root),
            schema=generate_detailed_schema(metadata_fields),
            **file_options,
        )
        ranked = FileCacheWriter.create(
            path=generate_ranked_cache_path(root),
            schema=generate_ranked_schema(metadata_fields),
            **file_options,
        )

        # Persist the metadata field definitions next to the caches; they are
        # read back with `decode_metadata_fields` when loading from disk.
        with open(generate_metadata_path(root), "w") as fp:
            json.dump(encode_metadata_fields(metadata_fields), fp, indent=2)

        return cls(
            detailed_writer=detailed,
            ranked_writer=ranked,
            metadata_fields=metadata_fields,
        )

    def _rank(self, batch_size: int = 1_000):
        """Fill the ranked cache by sorting the contents of the detailed cache."""
        # "iou_prev" is part of the ranked schema but is not copied from the
        # detailed cache; presumably `rank_table` derives it — TODO confirm.
        copied_columns = [
            fld.name
            for fld in self._ranked_writer.schema
            if fld.name != "iou_prev"
        ]
        compute.sort(
            source=self._detailed_writer.to_reader(),
            sink=self._ranked_writer,
            batch_size=batch_size,
            sorting=list(self._PAIR_ORDER),
            columns=copied_columns,
            table_sort_override=rank_table,
        )
        self._ranked_writer.flush()

    def finalize(
        self,
        batch_size: int = 1_000,
        index_to_label_override: dict[int, str] | None = None,
    ):
        """
        Flush, sort and rank the cached pairs, then return an `Evaluator`.

        Parameters
        ----------
        batch_size : int, default=1_000
            Batch size used while reading during ranking.
        index_to_label_override : dict[int, str], optional
            Pre-configured label mapping, used when operating over filtered
            subsets.

        Raises
        ------
        EmptyCacheError
            If the detailed cache holds no rows after flushing.
        """
        writer = self._detailed_writer
        writer.flush()
        if writer.count_rows() == 0:
            raise EmptyCacheError()

        writer.sort_by(list(self._PAIR_ORDER))
        detailed_reader = writer.to_reader()

        labels = extract_labels(
            reader=detailed_reader,
            index_to_label_override=index_to_label_override,
        )

        self._rank(batch_size)

        return Evaluator(
            detailed_reader=detailed_reader,
            ranked_reader=self._ranked_writer.to_reader(),
            index_to_label=labels,
            metadata_fields=self._metadata_fields,
        )
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class Evaluator:
    """
    Object detection evaluator backed by a detailed pair cache and a ranked
    pair cache.

    Instances are returned by `Builder.finalize`, `Evaluator.load` and
    `Evaluator.filter`.
    """

    # Numeric pair columns in positional order. Positions are used directly
    # below: 0=datum_id, 1=gt_id, 2=pd_id, 3=gt_label_id, 4=pd_label_id,
    # 5=iou, 6=pd_score.
    _PAIR_COLUMNS = (
        "datum_id",
        "gt_id",
        "pd_id",
        "gt_label_id",
        "pd_label_id",
        "iou",
        "pd_score",
    )

    # Pair columns holding floating point values; all others are integer ids.
    _FLOAT_COLUMNS = frozenset({"iou", "pd_score"})

    def __init__(
        self,
        detailed_reader: MemoryCacheReader | FileCacheReader,
        ranked_reader: MemoryCacheReader | FileCacheReader,
        index_to_label: dict[int, str],
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        """
        Parameters
        ----------
        detailed_reader : MemoryCacheReader | FileCacheReader
            Reader over the detailed pair cache.
        ranked_reader : MemoryCacheReader | FileCacheReader
            Reader over the ranked pair cache.
        index_to_label : dict[int, str]
            Mapping from label index to label name.
        metadata_fields : list[tuple[str, str | pa.DataType]], optional
            Optional datum metadata field definitions.
        """
        self._detailed_reader = detailed_reader
        self._ranked_reader = ranked_reader
        self._index_to_label = index_to_label
        self._metadata_fields = metadata_fields

    # ------------------------------------------------------------------ #
    # private helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _validate_thresholds(
        iou_thresholds: list[float],
        score_thresholds: list[float],
    ) -> None:
        """Raise ValueError if either threshold list is empty."""
        if not iou_thresholds:
            raise ValueError("At least one IOU threshold must be passed.")
        elif not score_thresholds:
            raise ValueError("At least one score threshold must be passed.")

    @staticmethod
    def _extract_id_pairs(tbl: pa.Table, id_column: str) -> np.ndarray:
        """Stack (datum_id, `id_column`) from `tbl` into an (n, 2) int64 array."""
        return np.column_stack(
            [tbl[col].to_numpy() for col in ("datum_id", id_column)]
        ).astype(np.int64)

    @staticmethod
    def _mask_rows_in(rows: np.ndarray, candidates: np.ndarray) -> np.ndarray:
        """
        Return a boolean mask marking which rows of `rows` also occur as a row
        of `candidates`.

        Both inputs are 2-D integer arrays with the same number of columns.
        A single hash-set lookup per row keeps this linear in the number of
        rows plus candidates.
        """
        lookup = {tuple(row) for row in candidates.tolist()}
        return np.fromiter(
            (tuple(row) in lookup for row in rows.tolist()),
            dtype=np.bool_,
            count=rows.shape[0],
        )

    @staticmethod
    def _create_id_mappings(
        tbl: pa.Table, pairs: np.ndarray
    ) -> tuple[dict, dict, dict]:
        """Map internal datum / groundtruth / prediction ids to their external uids."""
        return (
            create_mapping(tbl, pairs, 0, "datum_id", "datum_uid"),
            create_mapping(tbl, pairs, 1, "gt_id", "gt_uid"),
            create_mapping(tbl, pairs, 2, "pd_id", "pd_uid"),
        )

    # ------------------------------------------------------------------ #
    # info / construction
    # ------------------------------------------------------------------ #

    @property
    def info(self) -> EvaluatorInfo:
        """Summary information computed without any filters."""
        return self.get_info()

    def get_info(
        self,
        datums: pc.Expression | None = None,
        groundtruths: pc.Expression | None = None,
        predictions: pc.Expression | None = None,
    ) -> EvaluatorInfo:
        """
        Summarize the cache contents.

        `number_of_rows` always counts every row of the detailed cache; the
        datum, groundtruth and prediction counts come from `extract_counts`
        evaluated with the optional filter expressions.

        Parameters
        ----------
        datums : pyarrow.compute.Expression, optional
            Filter applied to datums.
        groundtruths : pyarrow.compute.Expression, optional
            Filter applied to ground truth annotations.
        predictions : pyarrow.compute.Expression, optional
            Filter applied to predictions.

        Returns
        -------
        EvaluatorInfo
        """
        info = EvaluatorInfo()
        info.number_of_rows = self._detailed_reader.count_rows()
        info.metadata_fields = self._metadata_fields
        info.number_of_labels = len(self._index_to_label)
        (
            info.number_of_datums,
            info.number_of_groundtruth_annotations,
            info.number_of_prediction_annotations,
        ) = extract_counts(
            reader=self._detailed_reader,
            datums=datums,
            groundtruths=groundtruths,
            predictions=predictions,
        )
        return info

    @classmethod
    def load(
        cls,
        path: str | Path,
        index_to_label_override: dict[int, str] | None = None,
    ):
        """
        Load a persistent evaluator from a cache directory.

        Parameters
        ----------
        path : str | Path
            Directory containing the detailed cache, the ranked cache and the
            metadata-field definitions.
        index_to_label_override : dict[int, str], optional
            Pre-configured label mapping.

        Raises
        ------
        FileNotFoundError
            If `path` does not exist.
        NotADirectoryError
            If `path` exists but is not a directory.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Directory does not exist: {path}")
        elif not path.is_dir():
            raise NotADirectoryError(
                f"Path exists but is not a directory: {path}"
            )

        detailed_reader = FileCacheReader.load(
            generate_detailed_cache_path(path)
        )
        ranked_reader = FileCacheReader.load(generate_ranked_cache_path(path))

        index_to_label = extract_labels(
            reader=detailed_reader,
            index_to_label_override=index_to_label_override,
        )

        with open(generate_metadata_path(path), "r") as f:
            encoded_metadata_types = json.load(f)
        metadata_fields = decode_metadata_fields(encoded_metadata_types)

        return cls(
            detailed_reader=detailed_reader,
            ranked_reader=ranked_reader,
            index_to_label=index_to_label,
            metadata_fields=metadata_fields,
        )

    def filter(
        self,
        datums: pc.Expression | None = None,
        groundtruths: pc.Expression | None = None,
        predictions: pc.Expression | None = None,
        batch_size: int = 1_000,
        path: str | Path | None = None,
    ) -> "Evaluator":
        """
        Filter evaluator cache.

        Rows are first restricted by `datums`. In the remaining rows, a ground
        truth that fails `groundtruths` has its id and label id set to -1; a
        prediction that fails `predictions` has its id, label id and score set
        to -1. The IOU of any pair with an invalidated side is set to 0. Rows
        where neither side is valid, or that are left with neither a ground
        truth nor a prediction id, are dropped.

        Parameters
        ----------
        datums : pc.Expression | None = None
            A filter expression used to filter datums.
        groundtruths : pc.Expression | None = None
            A filter expression used to filter ground truth annotations.
        predictions : pc.Expression | None = None
            A filter expression used to filter predictions.
        batch_size : int
            The batch size used when finalizing the filtered cache.
        path : str | Path
            Where to store the filtered cache if storing on disk.

        Returns
        -------
        Evaluator
            A new evaluator object containing the filtered cache.

        Raises
        ------
        ValueError
            If the cache is file-based and `path` is not provided.
        """
        # Local import; `Loader` lives in a module that is not imported at
        # the top of this file.
        from valor_lite.object_detection.loader import Loader

        if isinstance(self._detailed_reader, FileCacheReader):
            if not path:
                raise ValueError(
                    "expected path to be defined for file-based loader"
                )
            loader = Loader.persistent(
                path=path,
                batch_size=self._detailed_reader.batch_size,
                rows_per_file=self._detailed_reader.rows_per_file,
                compression=self._detailed_reader.compression,
                metadata_fields=self._metadata_fields,
            )
        else:
            loader = Loader.in_memory(
                batch_size=self._detailed_reader.batch_size,
                metadata_fields=self._metadata_fields,
            )

        columns = self._PAIR_COLUMNS
        for tbl in self._detailed_reader.iterate_tables(filter=datums):
            pairs = np.column_stack([tbl[col].to_numpy() for col in columns])
            n_pairs = pairs.shape[0]

            # a ground truth is valid if its (datum_id, gt_id) survives the filter
            if groundtruths is not None:
                mask_valid_gt = self._mask_rows_in(
                    pairs[:, (0, 1)].astype(np.int64),
                    self._extract_id_pairs(tbl.filter(groundtruths), "gt_id"),
                )
            else:
                mask_valid_gt = np.ones(n_pairs, dtype=np.bool_)

            # a prediction is valid if its (datum_id, pd_id) survives the filter
            if predictions is not None:
                mask_valid_pd = self._mask_rows_in(
                    pairs[:, (0, 2)].astype(np.int64),
                    self._extract_id_pairs(tbl.filter(predictions), "pd_id"),
                )
            else:
                mask_valid_pd = np.ones(n_pairs, dtype=np.bool_)

            # keep a row when at least one side of the pair is still valid
            mask_valid = mask_valid_gt | mask_valid_pd

            # filter out invalid gt_id, gt_label_id by setting to -1.0
            pairs[np.ix_(~mask_valid_gt, (1, 3))] = -1.0  # type: ignore[reportArgumentType]
            # filter out invalid pd_id, pd_label_id, pd_score by setting to -1.0
            pairs[np.ix_(~mask_valid_pd, (2, 4, 6))] = -1.0  # type: ignore[reportArgumentType]
            # filter out invalid iou by setting to 0.0
            pairs[~mask_valid_pd | ~mask_valid_gt, 5] = 0.0

            # write the modified numeric columns back into the table
            for idx, col in enumerate(columns):
                values = pairs[:, idx]
                if col not in self._FLOAT_COLUMNS:
                    values = values.astype(np.int64)
                col_idx = tbl.schema.names.index(col)
                tbl = tbl.set_column(
                    col_idx, tbl.schema[col_idx], pa.array(values)
                )

            mask_invalid = ~mask_valid | (pairs[:, (1, 2)] < 0).all(axis=1)
            loader._detailed_writer.write_table(
                tbl.filter(pa.array(~mask_invalid))
            )

        return loader.finalize(
            batch_size=batch_size,
            index_to_label_override=self._index_to_label,
        )

    # ------------------------------------------------------------------ #
    # metrics
    # ------------------------------------------------------------------ #

    def compute_precision_recall(
        self,
        iou_thresholds: list[float],
        score_thresholds: list[float],
        datums: pc.Expression | None = None,
    ) -> dict[MetricType, list[Metric]]:
        """
        Computes all metrics except for ConfusionMatrix

        Parameters
        ----------
        iou_thresholds : list[float]
            A list of IOU thresholds to compute metrics over.
        score_thresholds : list[float]
            A list of score thresholds to compute metrics over.
        datums : pyarrow.compute.Expression, optional
            Option to filter datums by an expression.

        Returns
        -------
        dict[MetricType, list]
            A dictionary mapping MetricType enumerations to lists of computed metrics.

        Raises
        ------
        ValueError
            If either threshold list is empty.
        """
        self._validate_thresholds(iou_thresholds, score_thresholds)

        n_ious = len(iou_thresholds)
        n_scores = len(score_thresholds)
        n_labels = len(self._index_to_label)
        n_gts_per_lbl = extract_groundtruth_count_per_label(
            reader=self._detailed_reader,
            number_of_labels=n_labels,
            datums=datums,
        )

        # threshold arrays are loop-invariant; build them once
        ious = np.array(iou_thresholds)
        scores = np.array(score_thresholds)

        counts = np.zeros((n_ious, n_scores, 3, n_labels), dtype=np.uint64)
        # `pr_curve` and `running_counts` are handed to every batch and never
        # reassigned in the loop, so presumably `compute_counts` accumulates
        # into them in place — TODO confirm.
        pr_curve = np.zeros((n_ious, n_labels, 101, 2), dtype=np.float64)
        running_counts = np.zeros((n_ious, n_labels, 2), dtype=np.uint64)

        for pairs in self._ranked_reader.iterate_arrays(
            numeric_columns=[*self._PAIR_COLUMNS, "iou_prev"],
            filter=datums,
        ):
            if pairs.size == 0:
                continue

            counts += compute_counts(
                ranked_pairs=pairs,
                iou_thresholds=ious,
                score_thresholds=scores,
                number_of_groundtruths_per_label=n_gts_per_lbl,
                number_of_labels=n_labels,
                running_counts=running_counts,
                pr_curve=pr_curve,
            )

        # fn count: ground truths per label minus the slot-0 counts
        # (slot 0 is presumably the true-positive count)
        counts[:, :, 2, :] = n_gts_per_lbl - counts[:, :, 0, :]

        precision_recall_f1 = compute_precision_recall_f1(
            counts=counts,
            number_of_groundtruths_per_label=n_gts_per_lbl,
        )
        (
            average_precision,
            mean_average_precision,
            pr_curve,
        ) = compute_average_precision(pr_curve=pr_curve)
        average_recall, mean_average_recall = compute_average_recall(
            prec_rec_f1=precision_recall_f1
        )

        return unpack_precision_recall_into_metric_lists(
            counts=counts,
            precision_recall_f1=precision_recall_f1,
            average_precision=average_precision,
            mean_average_precision=mean_average_precision,
            average_recall=average_recall,
            mean_average_recall=mean_average_recall,
            pr_curve=pr_curve,
            iou_thresholds=iou_thresholds,
            score_thresholds=score_thresholds,
            index_to_label=self._index_to_label,
        )

    def compute_confusion_matrix(
        self,
        iou_thresholds: list[float],
        score_thresholds: list[float],
        datums: pc.Expression | None = None,
    ) -> list[Metric]:
        """
        Computes confusion matrices at various thresholds.

        Parameters
        ----------
        iou_thresholds : list[float]
            A list of IOU thresholds to compute metrics over.
        score_thresholds : list[float]
            A list of score thresholds to compute metrics over.
        datums : pyarrow.compute.Expression, optional
            Option to filter datums by an expression.

        Returns
        -------
        list[Metric]
            List of confusion matrices per threshold pair.

        Raises
        ------
        ValueError
            If either threshold list is empty.
        """
        self._validate_thresholds(iou_thresholds, score_thresholds)

        n_ious = len(iou_thresholds)
        n_scores = len(score_thresholds)
        n_labels = len(self._index_to_label)

        # threshold arrays are loop-invariant; build them once
        ious = np.array(iou_thresholds)
        scores = np.array(score_thresholds)

        confusion_matrices = np.zeros(
            (n_ious, n_scores, n_labels, n_labels), dtype=np.uint64
        )
        unmatched_groundtruths = np.zeros(
            (n_ious, n_scores, n_labels), dtype=np.uint64
        )
        unmatched_predictions = np.zeros_like(unmatched_groundtruths)

        for pairs in self._detailed_reader.iterate_arrays(
            numeric_columns=list(self._PAIR_COLUMNS),
            filter=datums,
        ):
            if pairs.size == 0:
                continue

            (
                batch_mask_tp,
                batch_mask_fp_fn_misclf,
                batch_mask_fp_unmatched,
                batch_mask_fn_unmatched,
            ) = compute_pair_classifications(
                detailed_pairs=pairs,
                iou_thresholds=ious,
                score_thresholds=scores,
            )
            (
                batch_confusion_matrices,
                batch_unmatched_groundtruths,
                batch_unmatched_predictions,
            ) = compute_confusion_matrix(
                detailed_pairs=pairs,
                mask_tp=batch_mask_tp,
                mask_fp_fn_misclf=batch_mask_fp_fn_misclf,
                mask_fp_unmatched=batch_mask_fp_unmatched,
                mask_fn_unmatched=batch_mask_fn_unmatched,
                number_of_labels=n_labels,
                iou_thresholds=ious,
                score_thresholds=scores,
            )
            confusion_matrices += batch_confusion_matrices
            unmatched_groundtruths += batch_unmatched_groundtruths
            unmatched_predictions += batch_unmatched_predictions

        return unpack_confusion_matrix(
            confusion_matrices=confusion_matrices,
            unmatched_groundtruths=unmatched_groundtruths,
            unmatched_predictions=unmatched_predictions,
            index_to_label=self._index_to_label,
            iou_thresholds=iou_thresholds,
            score_thresholds=score_thresholds,
        )

    def compute_examples(
        self,
        iou_thresholds: list[float],
        score_thresholds: list[float],
        datums: pc.Expression | None = None,
        limit: int | None = None,
        offset: int = 0,
    ) -> list[Metric]:
        """
        Computes examples at various thresholds.

        This function can use a lot of memory with larger or high density datasets. Please use it with filters.

        Parameters
        ----------
        iou_thresholds : list[float]
            A list of IOU thresholds to compute metrics over.
        score_thresholds : list[float]
            A list of score thresholds to compute metrics over.
        datums : pyarrow.compute.Expression, optional
            Option to filter datums by an expression.
        limit : int, optional
            Option to set a limit to the number of returned datum examples.
        offset : int, default=0
            Option to offset where examples are being created in the datum index.

        Returns
        -------
        list[Metric]
            Example metrics produced by `unpack_examples` for every page of
            datums selected by `datums`, `limit` and `offset`.

        Raises
        ------
        ValueError
            If either threshold list is empty.
        """
        self._validate_thresholds(iou_thresholds, score_thresholds)

        # threshold arrays are loop-invariant; build them once
        ious = np.array(iou_thresholds)
        scores = np.array(score_thresholds)

        metrics = []
        for tbl in compute.paginate_index(
            source=self._detailed_reader,
            column_key="datum_id",
            modifier=datums,
            limit=limit,
            offset=offset,
        ):
            if tbl.num_rows == 0:
                continue

            pairs = np.column_stack(
                [tbl[col].to_numpy() for col in self._PAIR_COLUMNS]
            )

            # extract external identifiers
            (
                index_to_datum_id,
                index_to_groundtruth_id,
                index_to_prediction_id,
            ) = self._create_id_mappings(tbl, pairs)

            (
                mask_tp,
                mask_fp_fn_misclf,
                mask_fp_unmatched,
                mask_fn_unmatched,
            ) = compute_pair_classifications(
                detailed_pairs=pairs,
                iou_thresholds=ious,
                score_thresholds=scores,
            )

            metrics.extend(
                unpack_examples(
                    detailed_pairs=pairs,
                    mask_tp=mask_tp,
                    mask_fp=mask_fp_fn_misclf | mask_fp_unmatched,
                    mask_fn=mask_fp_fn_misclf | mask_fn_unmatched,
                    index_to_datum_id=index_to_datum_id,
                    index_to_groundtruth_id=index_to_groundtruth_id,
                    index_to_prediction_id=index_to_prediction_id,
                    iou_thresholds=iou_thresholds,
                    score_thresholds=score_thresholds,
                )
            )

        return metrics

    def compute_confusion_matrix_with_examples(
        self,
        iou_thresholds: list[float],
        score_thresholds: list[float],
        datums: pc.Expression | None = None,
    ) -> list[Metric]:
        """
        Computes confusion matrix with examples at various thresholds.

        This function can use a lot of memory with larger or high density datasets. Please use it with filters.

        Parameters
        ----------
        iou_thresholds : list[float]
            A list of IOU thresholds to compute metrics over.
        score_thresholds : list[float]
            A list of score thresholds to compute metrics over.
        datums : pyarrow.compute.Expression, optional
            Option to filter datums by an expression.

        Returns
        -------
        list[Metric]
            List of confusion matrices per threshold pair.

        Raises
        ------
        ValueError
            If either threshold list is empty.
        """
        self._validate_thresholds(iou_thresholds, score_thresholds)

        # threshold arrays are loop-invariant; build them once
        ious = np.array(iou_thresholds)
        scores = np.array(score_thresholds)

        # one (initially empty) metric per (iou, score) threshold pair, filled
        # in place by `unpack_confusion_matrix_with_examples` below
        metrics = {
            iou_idx: {
                # NOTE(review): keyword is spelled `iou_threhsold`; presumably
                # it matches the helper's parameter name — verify before renaming.
                score_idx: create_empty_confusion_matrix_with_examples(
                    iou_threhsold=iou_thresh,
                    score_threshold=score_thresh,
                    index_to_label=self._index_to_label,
                )
                for score_idx, score_thresh in enumerate(score_thresholds)
            }
            for iou_idx, iou_thresh in enumerate(iou_thresholds)
        }

        numeric_columns = list(self._PAIR_COLUMNS)
        for tbl, pairs in self._detailed_reader.iterate_tables_with_arrays(
            columns=["datum_uid", "gt_uid", "pd_uid", *numeric_columns],
            numeric_columns=numeric_columns,
            filter=datums,
        ):
            if pairs.size == 0:
                continue

            # extract external identifiers
            (
                index_to_datum_id,
                index_to_groundtruth_id,
                index_to_prediction_id,
            ) = self._create_id_mappings(tbl, pairs)

            (
                mask_tp,
                mask_fp_fn_misclf,
                mask_fp_unmatched,
                mask_fn_unmatched,
            ) = compute_pair_classifications(
                detailed_pairs=pairs,
                iou_thresholds=ious,
                score_thresholds=scores,
            )

            unpack_confusion_matrix_with_examples(
                metrics=metrics,
                detailed_pairs=pairs,
                mask_tp=mask_tp,
                mask_fp_fn_misclf=mask_fp_fn_misclf,
                mask_fp_unmatched=mask_fp_unmatched,
                mask_fn_unmatched=mask_fn_unmatched,
                index_to_datum_id=index_to_datum_id,
                index_to_groundtruth_id=index_to_groundtruth_id,
                index_to_prediction_id=index_to_prediction_id,
                index_to_label=self._index_to_label,
            )

        return [m for inner in metrics.values() for m in inner.values()]
|