valor-lite 0.36.6__py3-none-any.whl → 0.37.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- valor_lite/cache/__init__.py +11 -0
- valor_lite/cache/compute.py +211 -0
- valor_lite/cache/ephemeral.py +302 -0
- valor_lite/cache/persistent.py +536 -0
- valor_lite/classification/__init__.py +5 -10
- valor_lite/classification/annotation.py +4 -0
- valor_lite/classification/computation.py +233 -251
- valor_lite/classification/evaluator.py +882 -0
- valor_lite/classification/loader.py +97 -0
- valor_lite/classification/metric.py +141 -4
- valor_lite/classification/shared.py +184 -0
- valor_lite/classification/utilities.py +221 -118
- valor_lite/exceptions.py +5 -0
- valor_lite/object_detection/__init__.py +5 -4
- valor_lite/object_detection/annotation.py +13 -1
- valor_lite/object_detection/computation.py +368 -299
- valor_lite/object_detection/evaluator.py +804 -0
- valor_lite/object_detection/loader.py +292 -0
- valor_lite/object_detection/metric.py +152 -3
- valor_lite/object_detection/shared.py +206 -0
- valor_lite/object_detection/utilities.py +182 -100
- valor_lite/semantic_segmentation/__init__.py +5 -4
- valor_lite/semantic_segmentation/annotation.py +7 -0
- valor_lite/semantic_segmentation/computation.py +20 -110
- valor_lite/semantic_segmentation/evaluator.py +414 -0
- valor_lite/semantic_segmentation/loader.py +205 -0
- valor_lite/semantic_segmentation/shared.py +149 -0
- valor_lite/semantic_segmentation/utilities.py +6 -23
- {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/METADATA +3 -1
- valor_lite-0.37.5.dist-info/RECORD +49 -0
- {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/WHEEL +1 -1
- valor_lite/classification/manager.py +0 -545
- valor_lite/object_detection/manager.py +0 -864
- valor_lite/profiling.py +0 -374
- valor_lite/semantic_segmentation/benchmark.py +0 -237
- valor_lite/semantic_segmentation/manager.py +0 -446
- valor_lite-0.36.6.dist-info/RECORD +0 -41
- {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/top_level.txt +0 -0
valor_lite/classification/evaluator.py (new file, +882 lines)

@@ -0,0 +1,882 @@
```python
from __future__ import annotations

import json
from pathlib import Path

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc

from valor_lite.cache import (
    FileCacheReader,
    FileCacheWriter,
    MemoryCacheReader,
    MemoryCacheWriter,
    compute,
)
from valor_lite.classification.computation import (
    compute_accuracy,
    compute_confusion_matrix,
    compute_counts,
    compute_f1_score,
    compute_pair_classifications,
    compute_precision,
    compute_recall,
    compute_rocauc,
)
from valor_lite.classification.metric import Metric, MetricType
from valor_lite.classification.shared import (
    EvaluatorInfo,
    decode_metadata_fields,
    encode_metadata_fields,
    extract_counts,
    extract_groundtruth_count_per_label,
    extract_labels,
    generate_cache_path,
    generate_intermediate_cache_path,
    generate_intermediate_schema,
    generate_metadata_path,
    generate_roc_curve_cache_path,
    generate_roc_curve_schema,
    generate_schema,
)
from valor_lite.classification.utilities import (
    create_empty_confusion_matrix_with_examples,
    create_mapping,
    unpack_confusion_matrix,
    unpack_confusion_matrix_with_examples,
    unpack_examples,
    unpack_precision_recall,
    unpack_rocauc,
)
from valor_lite.exceptions import EmptyCacheError


class Builder:
    def __init__(
        self,
        writer: MemoryCacheWriter | FileCacheWriter,
        roc_curve_writer: MemoryCacheWriter | FileCacheWriter,
        intermediate_writer: MemoryCacheWriter | FileCacheWriter,
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        self._writer = writer
        self._roc_curve_writer = roc_curve_writer
        self._intermediate_writer = intermediate_writer
        self._metadata_fields = metadata_fields

    @classmethod
    def in_memory(
        cls,
        batch_size: int = 10_000,
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        """
        Create an in-memory evaluator cache.

        Parameters
        ----------
        batch_size : int, default=10_000
            The target number of rows to buffer before writing to the cache.
        metadata_fields : list[tuple[str, str | pa.DataType]], optional
            Optional metadata field definitions.
        """
        writer = MemoryCacheWriter.create(
            schema=generate_schema(metadata_fields),
            batch_size=batch_size,
        )
        intermediate_writer = MemoryCacheWriter.create(
            schema=generate_intermediate_schema(),
            batch_size=batch_size,
        )
        roc_curve_writer = MemoryCacheWriter.create(
            schema=generate_roc_curve_schema(),
            batch_size=batch_size,
        )
        return cls(
            writer=writer,
            roc_curve_writer=roc_curve_writer,
            intermediate_writer=intermediate_writer,
            metadata_fields=metadata_fields,
        )

    @classmethod
    def persistent(
        cls,
        path: str | Path,
        batch_size: int = 10_000,
        rows_per_file: int = 100_000,
        compression: str = "snappy",
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        """
        Create a persistent file-based evaluator cache.

        Parameters
        ----------
        path : str | Path
            Where to store the file-based cache.
        batch_size : int, default=10_000
            Sets the batch size for writing to file.
        rows_per_file : int, default=100_000
            Sets the maximum number of rows per file. This may be exceeded as files are datum aligned.
        compression : str, default="snappy"
            Sets the pyarrow compression method.
        metadata_fields : list[tuple[str, str | pa.DataType]], optional
            Optionally sets metadata description for use in filtering.
        """
        path = Path(path)

        # create cache
        writer = FileCacheWriter.create(
            path=generate_cache_path(path),
            schema=generate_schema(metadata_fields),
            batch_size=batch_size,
            rows_per_file=rows_per_file,
            compression=compression,
        )
        intermediate_writer = FileCacheWriter.create(
            path=generate_intermediate_cache_path(path),
            schema=generate_intermediate_schema(),
            batch_size=batch_size,
            rows_per_file=rows_per_file,
            compression=compression,
        )
        roc_curve_writer = FileCacheWriter.create(
            path=generate_roc_curve_cache_path(path),
            schema=generate_roc_curve_schema(),
            batch_size=batch_size,
            rows_per_file=rows_per_file,
            compression=compression,
        )

        # write metadata config
        metadata_path = generate_metadata_path(path)
        with open(metadata_path, "w") as f:
            encoded_types = encode_metadata_fields(metadata_fields)
            json.dump(encoded_types, f, indent=2)

        return cls(
            writer=writer,
            roc_curve_writer=roc_curve_writer,
            intermediate_writer=intermediate_writer,
            metadata_fields=metadata_fields,
        )
```
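The two constructors above are the only entry points for creating a cache in this module. A minimal usage sketch, based solely on the signatures shown in this diff; the `./clf_cache` path and the `capture_date` metadata field are illustrative placeholders, and ingestion of ground truths and predictions happens through `valor_lite.classification.loader`, which is not shown here:

```python
import pyarrow as pa

from valor_lite.classification.evaluator import Builder

# small experiments: keep the evaluator cache entirely in memory
mem_builder = Builder.in_memory(batch_size=1_000)

# larger runs: spill the cache to files on disk
disk_builder = Builder.persistent(
    path="./clf_cache",                               # illustrative cache directory
    batch_size=10_000,                                # rows buffered per write
    rows_per_file=100_000,                            # soft cap; files stay datum aligned
    compression="snappy",                             # pyarrow compression codec
    metadata_fields=[("capture_date", pa.string())],  # hypothetical filterable field
)
```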
```python
    def _create_rocauc_intermediate(
        self,
        reader: MemoryCacheReader | FileCacheReader,
        batch_size: int,
        index_to_label: dict[int, str],
    ):
        n_labels = len(index_to_label)
        compute.sort(
            source=reader,
            sink=self._intermediate_writer,
            batch_size=batch_size,
            sorting=[
                ("pd_score", "descending"),
                # ("pd_label_id", "ascending"),
                ("match", "ascending"),
            ],
            columns=[
                "pd_label_id",
                "pd_score",
                "match",
            ],
        )
        intermediate = self._intermediate_writer.to_reader()

        running_max_fp = np.zeros(n_labels, dtype=np.uint64)
        running_max_tp = np.zeros(n_labels, dtype=np.uint64)
        running_max_scores = np.zeros(n_labels, dtype=np.float64)

        last_pair = np.zeros((n_labels, 2), dtype=np.uint64)
        for tbl in intermediate.iterate_tables():
            pd_label_ids = tbl["pd_label_id"].to_numpy()
            tps = tbl["match"].to_numpy()
            scores = tbl["pd_score"].to_numpy()
            fps = ~tps

            for idx in index_to_label.keys():
                mask = pd_label_ids == idx
                if mask.sum() == 0:
                    continue

                cumulative_fp = np.r_[
                    running_max_fp[idx],
                    np.cumsum(fps[mask]) + running_max_fp[idx],
                ]
                cumulative_tp = np.r_[
                    running_max_tp[idx],
                    np.cumsum(tps[mask]) + running_max_tp[idx],
                ]
                pd_scores = np.r_[running_max_scores[idx], scores[mask]]

                indices = (
                    np.where(
                        np.diff(np.r_[running_max_scores[idx], pd_scores])
                    )[0]
                    - 1
                )

                running_max_fp[idx] = cumulative_fp[-1]
                running_max_tp[idx] = cumulative_tp[-1]
                running_max_scores[idx] = pd_scores[-1]

                for fp, tp in zip(
                    cumulative_fp[indices],
                    cumulative_tp[indices],
                ):
                    last_pair[idx, 0] = fp
                    last_pair[idx, 1] = tp
                    self._roc_curve_writer.write_rows(
                        [
                            {
                                "pd_label_id": idx,
                                "cumulative_fp": fp,
                                "cumulative_tp": tp,
                            }
                        ]
                    )

        # ensure any remaining values are ingested
        for idx in range(n_labels):
            last_fp = last_pair[idx, 0]
            last_tp = last_pair[idx, 1]
            if (
                last_fp != running_max_fp[idx]
                or last_tp != running_max_tp[idx]
            ):
                self._roc_curve_writer.write_rows(
                    [
                        {
                            "pd_label_id": idx,
                            "cumulative_fp": running_max_fp[idx],
                            "cumulative_tp": running_max_tp[idx],
                        }
                    ]
                )
```
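`_create_rocauc_intermediate` streams the score-sorted pairs and, for each label, stores cumulative false-positive and true-positive counts at the points where the prediction score changes; those (FP, TP) pairs are the ROC curve in count space and are all the later ROC AUC pass needs. A toy numpy illustration of that bookkeeping for a single label, as a simplified sketch of the idea rather than a reproduction of the batched logic above:

```python
import numpy as np

# five predictions for one label, already sorted by descending score
scores = np.array([0.9, 0.8, 0.8, 0.4, 0.1])
matches = np.array([True, False, True, True, False])  # True => prediction was correct

cumulative_tp = np.cumsum(matches)   # [1, 1, 2, 3, 3]
cumulative_fp = np.cumsum(~matches)  # [0, 1, 1, 1, 2]

# keep only the last entry for each distinct score (tied scores collapse into one ROC point)
last_of_each_score = np.r_[np.diff(scores) != 0, True]
roc_points = np.column_stack(
    [cumulative_fp[last_of_each_score], cumulative_tp[last_of_each_score]]
)
# roc_points == [[0, 1], [1, 2], [1, 3], [2, 3]]
```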
```python
    def finalize(
        self,
        batch_size: int = 1_000,
        index_to_label_override: dict[int, str] | None = None,
    ):
        """
        Performs data finalization and some preprocessing steps.

        Parameters
        ----------
        batch_size : int, default=1_000
            Sets the maximum number of elements read into memory per-file when performing merge sort.
        index_to_label_override : dict[int, str], optional
            Pre-configures label mapping. Used when operating over filtered subsets.

        Returns
        -------
        Evaluator
            A ready-to-use evaluator object.
        """
        self._writer.flush()
        if self._writer.count_rows() == 0:
            raise EmptyCacheError()
        elif self._roc_curve_writer.count_rows() > 0:
            raise RuntimeError("data already finalized")

        # sort in-place and locally
        self._writer.sort_by(
            [
                ("pd_score", "descending"),
                ("datum_id", "ascending"),
                ("gt_label_id", "ascending"),
                ("pd_label_id", "ascending"),
            ]
        )

        # post-process into sorted writer
        reader = self._writer.to_reader()

        # extract labels
        index_to_label = extract_labels(
            reader=reader,
            index_to_label_override=index_to_label_override,
        )

        self._create_rocauc_intermediate(
            reader=reader,
            batch_size=batch_size,
            index_to_label=index_to_label,
        )
        roc_curve_reader = self._roc_curve_writer.to_reader()

        return Evaluator(
            reader=reader,
            roc_curve_reader=roc_curve_reader,
            index_to_label=index_to_label,
            metadata_fields=self._metadata_fields,
        )
```
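`finalize` flushes the main cache, refuses to run on an empty cache or on one whose ROC-curve cache is already populated, sorts the pairs by descending score, builds the ROC intermediate, and returns an `Evaluator`. Continuing the earlier sketch, with error handling shown only to illustrate the two failure modes above:

```python
from valor_lite.exceptions import EmptyCacheError

try:
    evaluator = disk_builder.finalize(batch_size=1_000)
except EmptyCacheError:
    # no rows were written to the cache before finalization
    raise
except RuntimeError:
    # finalize() was already called on this cache ("data already finalized")
    raise
```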
```python
class Evaluator:
    def __init__(
        self,
        reader: MemoryCacheReader | FileCacheReader,
        roc_curve_reader: MemoryCacheReader | FileCacheReader,
        index_to_label: dict[int, str],
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        self._reader = reader
        self._roc_curve_reader = roc_curve_reader
        self._index_to_label = index_to_label
        self._metadata_fields = metadata_fields

    @property
    def info(self) -> EvaluatorInfo:
        return self.get_info()

    def get_info(
        self,
        datums: pc.Expression | None = None,
    ) -> EvaluatorInfo:
        info = EvaluatorInfo()
        info.metadata_fields = self._metadata_fields
        info.number_of_rows = self._reader.count_rows()
        info.number_of_labels = len(self._index_to_label)
        info.number_of_datums = extract_counts(
            reader=self._reader,
            datums=datums,
        )
        return info

    @classmethod
    def load(
        cls,
        path: str | Path,
        index_to_label_override: dict[int, str] | None = None,
    ):
        """
        Load from an existing classification cache.

        Parameters
        ----------
        path : str | Path
            Path to the existing cache.
        index_to_label_override : dict[int, str], optional
            Option to preset index to label dictionary. Used when loading from filtered caches.
        """

        # validate path
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Directory does not exist: {path}")
        elif not path.is_dir():
            raise NotADirectoryError(
                f"Path exists but is not a directory: {path}"
            )

        # load cache
        reader = FileCacheReader.load(generate_cache_path(path))
        roc_curve_reader = FileCacheReader.load(
            generate_roc_curve_cache_path(path)
        )

        # extract labels
        index_to_label = extract_labels(
            reader=reader,
            index_to_label_override=index_to_label_override,
        )

        # read config
        metadata_path = generate_metadata_path(path)
        metadata_fields = None
        with open(metadata_path, "r") as f:
            encoded_types = json.load(f)
            metadata_fields = decode_metadata_fields(encoded_types)

        return cls(
            reader=reader,
            roc_curve_reader=roc_curve_reader,
            index_to_label=index_to_label,
            metadata_fields=metadata_fields,
        )
```
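`Evaluator.load` reopens a persistent cache that was written earlier, so metrics can be recomputed without re-ingesting predictions. A sketch using the same placeholder path as above:

```python
from valor_lite.classification.evaluator import Evaluator

evaluator = Evaluator.load("./clf_cache")
info = evaluator.info
print(info.number_of_datums, info.number_of_labels, info.number_of_rows)
```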
```python
    def filter(
        self,
        datums: pc.Expression | None = None,
        groundtruths: pc.Expression | None = None,
        predictions: pc.Expression | None = None,
        path: str | Path | None = None,
    ) -> Evaluator:
        """
        Filter evaluator cache.

        Parameters
        ----------
        datums : pc.Expression | None = None
            A filter expression used to filter datums.
        groundtruths : pc.Expression | None = None
            A filter expression used to filter ground truth annotations.
        predictions : pc.Expression | None = None
            A filter expression used to filter predictions.
        path : str | Path, optional
            Where to store the filtered cache if storing on disk.

        Returns
        -------
        Evaluator
            A new evaluator object containing the filtered cache.
        """
        from valor_lite.classification.loader import Loader

        if isinstance(self._reader, FileCacheReader):
            if not path:
                raise ValueError(
                    "expected path to be defined for file-based loader"
                )
            loader = Loader.persistent(
                path=path,
                batch_size=self._reader.batch_size,
                rows_per_file=self._reader.rows_per_file,
                compression=self._reader.compression,
                metadata_fields=self.info.metadata_fields,
            )
        else:
            loader = Loader.in_memory(
                batch_size=self._reader.batch_size,
                metadata_fields=self.info.metadata_fields,
            )

        for tbl in self._reader.iterate_tables(filter=datums):
            columns = (
                "datum_id",
                "gt_label_id",
                "pd_label_id",
            )
            pairs = np.column_stack([tbl[col].to_numpy() for col in columns])

            n_pairs = pairs.shape[0]
            gt_ids = pairs[:, (0, 1)].astype(np.int64)
            pd_ids = pairs[:, (0, 2)].astype(np.int64)

            if groundtruths is not None:
                mask_valid_gt = np.zeros(n_pairs, dtype=np.bool_)
                gt_tbl = tbl.filter(groundtruths)
                gt_pairs = np.column_stack(
                    [
                        gt_tbl[col].to_numpy()
                        for col in ("datum_id", "gt_label_id")
                    ]
                ).astype(np.int64)
                for gt in np.unique(gt_pairs, axis=0):
                    mask_valid_gt |= (gt_ids == gt).all(axis=1)
            else:
                mask_valid_gt = np.ones(n_pairs, dtype=np.bool_)

            if predictions is not None:
                mask_valid_pd = np.zeros(n_pairs, dtype=np.bool_)
                pd_tbl = tbl.filter(predictions)
                pd_pairs = np.column_stack(
                    [
                        pd_tbl[col].to_numpy()
                        for col in ("datum_id", "pd_label_id")
                    ]
                ).astype(np.int64)
                for pd in np.unique(pd_pairs, axis=0):
                    mask_valid_pd |= (pd_ids == pd).all(axis=1)
            else:
                mask_valid_pd = np.ones(n_pairs, dtype=np.bool_)

            mask_valid = mask_valid_gt | mask_valid_pd
            mask_valid_gt &= mask_valid
            mask_valid_pd &= mask_valid

            pairs[~mask_valid_gt, 1] = -1
            pairs[~mask_valid_pd, 2] = -1

            for idx, col in enumerate(columns):
                tbl = tbl.set_column(
                    tbl.schema.names.index(col), col, pa.array(pairs[:, idx])
                )
            # TODO (c.zaloom) - improve write strategy, filtered data could be small
            loader._writer.write_table(tbl)

        return loader.finalize(index_to_label_override=self._index_to_label)
```
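`filter` writes the surviving rows into a fresh cache via the `Loader` and finalizes it, so the result is a fully independent `Evaluator`. A hedged sketch; the `datum_uid` column name is inferred from the `create_mapping(...)` calls elsewhere in this file, and a destination path is only required when the source cache is file-based:

```python
import pyarrow.compute as pc

subset = evaluator.filter(
    datums=pc.field("datum_uid").isin(["img_001", "img_002"]),  # assumed column name / uids
    path="./clf_cache_subset",  # required because the source cache is file-based
)
print(subset.info.number_of_datums)
```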
```python
    def iterate_values(self, datums: pc.Expression | None = None):
        columns = [
            "datum_id",
            "gt_label_id",
            "pd_label_id",
            "pd_score",
            "pd_winner",
            "match",
        ]
        for tbl in self._reader.iterate_tables(columns=columns, filter=datums):
            ids = np.column_stack(
                [
                    tbl[col].to_numpy()
                    for col in [
                        "datum_id",
                        "gt_label_id",
                        "pd_label_id",
                    ]
                ]
            )
            scores = tbl["pd_score"].to_numpy()
            winners = tbl["pd_winner"].to_numpy()
            matches = tbl["match"].to_numpy()
            yield ids, scores, winners, matches

    def compute_rocauc(self) -> dict[MetricType, list[Metric]]:
        """
        Compute ROCAUC.

        This function does not support direct filtering. To perform evaluation over a filtered
        set you must first create a new evaluator using `Evaluator.filter`.

        Returns
        -------
        dict[MetricType, list[Metric]]
            A dictionary mapping MetricType enumerations to lists of computed metrics.
        """
        n_labels = self.info.number_of_labels

        rocauc = np.zeros(n_labels, dtype=np.float64)
        label_counts = extract_groundtruth_count_per_label(
            reader=self._reader,
            number_of_labels=len(self._index_to_label),
        )

        prev = np.zeros((n_labels, 2), dtype=np.uint64)
        for array in self._roc_curve_reader.iterate_arrays(
            numeric_columns=[
                "pd_label_id",
                "cumulative_fp",
                "cumulative_tp",
            ],
        ):
            rocauc, prev = compute_rocauc(
                rocauc=rocauc,
                array=array,
                gt_count_per_label=label_counts[:, 0],
                pd_count_per_label=label_counts[:, 1],
                n_labels=self.info.number_of_labels,
                prev=prev,
            )

        mean_rocauc = rocauc.mean()

        return unpack_rocauc(
            rocauc=rocauc,
            mean_rocauc=mean_rocauc,
            index_to_label=self._index_to_label,
        )

    def compute_precision_recall(
        self,
        score_thresholds: list[float] = [0.0],
        hardmax: bool = True,
        datums: pc.Expression | None = None,
    ) -> dict[MetricType, list]:
        """
        Performs an evaluation and returns metrics.

        Parameters
        ----------
        score_thresholds : list[float]
            A list of score thresholds to compute metrics over.
        hardmax : bool
            Toggles whether a hardmax is applied to predictions.
        datums : pyarrow.compute.Expression, optional
            Option to filter datums by an expression.

        Returns
        -------
        dict[MetricType, list]
            A dictionary mapping MetricType enumerations to lists of computed metrics.
        """
        if not score_thresholds:
            raise ValueError("At least one score threshold must be passed.")

        n_scores = len(score_thresholds)
        n_datums = self.info.number_of_datums
        n_labels = self.info.number_of_labels

        # intermediates
        counts = np.zeros((n_scores, n_labels, 4), dtype=np.uint64)

        for ids, scores, winners, _ in self.iterate_values(datums=datums):
            batch_counts = compute_counts(
                ids=ids,
                scores=scores,
                winners=winners,
                score_thresholds=np.array(score_thresholds),
                hardmax=hardmax,
                n_labels=n_labels,
            )
            counts += batch_counts

        precision = compute_precision(counts)
        recall = compute_recall(counts)
        f1_score = compute_f1_score(precision, recall)
        accuracy = compute_accuracy(counts, n_datums=n_datums)

        return unpack_precision_recall(
            counts=counts,
            precision=precision,
            recall=recall,
            accuracy=accuracy,
            f1_score=f1_score,
            score_thresholds=score_thresholds,
            hardmax=hardmax,
            index_to_label=self._index_to_label,
        )
```
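With a finalized evaluator in hand, ROC AUC is derived from the pre-built ROC-curve cache and takes no arguments, while precision, recall, F1, and accuracy are computed per score threshold. A usage sketch based only on the signatures above:

```python
rocauc_metrics = evaluator.compute_rocauc()

pr_metrics = evaluator.compute_precision_recall(
    score_thresholds=[0.25, 0.5, 0.75],  # one set of counts per threshold
    hardmax=True,                        # apply hardmax to predictions (see docstring above)
)
```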
```python
    def compute_confusion_matrix(
        self,
        score_thresholds: list[float] = [0.0],
        hardmax: bool = True,
        datums: pc.Expression | None = None,
    ) -> list[Metric]:
        """
        Compute a confusion matrix.

        Parameters
        ----------
        score_thresholds : list[float]
            A list of score thresholds to compute metrics over.
        hardmax : bool
            Toggles whether a hardmax is applied to predictions.
        datums : pyarrow.compute.Expression, optional
            Option to filter datums by an expression.

        Returns
        -------
        list[Metric]
            A list of confusion matrices.
        """
        if not score_thresholds:
            raise ValueError("At least one score threshold must be passed.")

        n_scores = len(score_thresholds)
        n_labels = len(self._index_to_label)
        confusion_matrices = np.zeros(
            (n_scores, n_labels, n_labels), dtype=np.uint64
        )
        unmatched_groundtruths = np.zeros(
            (n_scores, n_labels), dtype=np.uint64
        )
        for ids, scores, winners, matches in self.iterate_values(
            datums=datums
        ):
            (
                mask_tp,
                mask_fp_fn_misclf,
                mask_fn_unmatched,
            ) = compute_pair_classifications(
                ids=ids,
                scores=scores,
                winners=winners,
                score_thresholds=np.array(score_thresholds),
                hardmax=hardmax,
            )

            batch_cm, batch_ugt = compute_confusion_matrix(
                ids=ids,
                mask_tp=mask_tp,
                mask_fp_fn_misclf=mask_fp_fn_misclf,
                mask_fn_unmatched=mask_fn_unmatched,
                score_thresholds=np.array(score_thresholds),
                n_labels=n_labels,
            )
            confusion_matrices += batch_cm
            unmatched_groundtruths += batch_ugt

        return unpack_confusion_matrix(
            confusion_matrices=confusion_matrices,
            unmatched_groundtruths=unmatched_groundtruths,
            index_to_label=self._index_to_label,
            score_thresholds=score_thresholds,
            hardmax=hardmax,
        )

    def compute_examples(
        self,
        score_thresholds: list[float] = [0.0],
        hardmax: bool = True,
        datums: pc.Expression | None = None,
        limit: int | None = None,
        offset: int = 0,
    ) -> list[Metric]:
        """
        Compute examples per datum.

        Note: This function should be used with filtering to reduce response size.

        Parameters
        ----------
        score_thresholds : list[float]
            A list of score thresholds to compute metrics over.
        hardmax : bool
            Toggles whether a hardmax is applied to predictions.
        datums : pyarrow.compute.Expression, optional
            Option to filter datums by an expression.
        limit : int, optional
            Option to set a limit to the number of returned datum examples.
        offset : int, default=0
            Option to offset where examples are being created in the datum index.

        Returns
        -------
        list[Metric]
            A list of example metrics.
        """
        if not score_thresholds:
            raise ValueError("At least one score threshold must be passed.")

        metrics = []
        for tbl in compute.paginate_index(
            source=self._reader,
            column_key="datum_id",
            modifier=datums,
            limit=limit,
            offset=offset,
        ):
            if tbl.num_rows == 0:
                continue

            ids = np.column_stack(
                [
                    tbl[col].to_numpy()
                    for col in [
                        "datum_id",
                        "gt_label_id",
                        "pd_label_id",
                    ]
                ]
            )
            scores = tbl["pd_score"].to_numpy()
            winners = tbl["pd_winner"].to_numpy()

            # extract external identifiers
            index_to_datum_id = create_mapping(
                tbl, ids, 0, "datum_id", "datum_uid"
            )

            (
                mask_tp,
                mask_fp_fn_misclf,
                mask_fn_unmatched,
            ) = compute_pair_classifications(
                ids=ids,
                scores=scores,
                winners=winners,
                score_thresholds=np.array(score_thresholds),
                hardmax=hardmax,
            )

            mask_fn = mask_fp_fn_misclf | mask_fn_unmatched
            mask_fp = mask_fp_fn_misclf

            batch_examples = unpack_examples(
                ids=ids,
                mask_tp=mask_tp,
                mask_fp=mask_fp,
                mask_fn=mask_fn,
                index_to_datum_id=index_to_datum_id,
                score_thresholds=score_thresholds,
                hardmax=hardmax,
                index_to_label=self._index_to_label,
            )
            metrics.extend(batch_examples)

        return metrics
```
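Because examples are reported per datum, `compute_examples` supports `limit`/`offset` pagination over the datum index via `compute.paginate_index`. A sketch of paging through results, assuming an empty page signals the end of the index:

```python
page_size = 100
offset = 0
while True:
    page = evaluator.compute_examples(
        score_thresholds=[0.5],
        limit=page_size,
        offset=offset,
    )
    if not page:
        break
    # ... consume this page of example metrics ...
    offset += page_size
```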
```python
    def compute_confusion_matrix_with_examples(
        self,
        score_thresholds: list[float] = [0.0],
        hardmax: bool = True,
        datums: pc.Expression | None = None,
    ) -> list[Metric]:
        """
        Compute confusion matrix with examples.

        Note: This function should be used with filtering to reduce response size.

        Parameters
        ----------
        score_thresholds : list[float]
            A list of score thresholds to compute metrics over.
        hardmax : bool
            Toggles whether a hardmax is applied to predictions.
        datums : pyarrow.compute.Expression, optional
            Option to filter datums by an expression.

        Returns
        -------
        list[Metric]
            A list of confusion matrices.
        """
        if not score_thresholds:
            raise ValueError("At least one score threshold must be passed.")

        metrics = {
            score_idx: create_empty_confusion_matrix_with_examples(
                score_threshold=score_thresh,
                hardmax=hardmax,
                index_to_label=self._index_to_label,
            )
            for score_idx, score_thresh in enumerate(score_thresholds)
        }
        for tbl in self._reader.iterate_tables(filter=datums):
            if tbl.num_rows == 0:
                continue

            ids = np.column_stack(
                [
                    tbl[col].to_numpy()
                    for col in [
                        "datum_id",
                        "gt_label_id",
                        "pd_label_id",
                    ]
                ]
            )
            scores = tbl["pd_score"].to_numpy()
            winners = tbl["pd_winner"].to_numpy()

            # extract external identifiers
            index_to_datum_id = create_mapping(
                tbl, ids, 0, "datum_id", "datum_uid"
            )

            (
                mask_tp,
                mask_fp_fn_misclf,
                mask_fn_unmatched,
            ) = compute_pair_classifications(
                ids=ids,
                scores=scores,
                winners=winners,
                score_thresholds=np.array(score_thresholds),
                hardmax=hardmax,
            )

            mask_matched = mask_tp | mask_fp_fn_misclf
            mask_unmatched_fn = mask_fn_unmatched

            unpack_confusion_matrix_with_examples(
                metrics=metrics,
                ids=ids,
                scores=scores,
                winners=winners,
                mask_matched=mask_matched,
                mask_unmatched_fn=mask_unmatched_fn,
                index_to_datum_id=index_to_datum_id,
                index_to_label=self._index_to_label,
            )

        return list(metrics.values())
```
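As the docstring notes, confusion matrices with examples can grow quickly, so they are typically computed over a filtered set of datums. A final sketch combining the pieces above; the column name and datum uids remain illustrative placeholders:

```python
cm_with_examples = evaluator.compute_confusion_matrix_with_examples(
    score_thresholds=[0.5],
    datums=pc.field("datum_uid").isin(["img_001"]),
)
for metric in cm_with_examples:
    print(metric)
```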