valor-lite 0.36.6__py3-none-any.whl → 0.37.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. valor_lite/cache/__init__.py +11 -0
  2. valor_lite/cache/compute.py +211 -0
  3. valor_lite/cache/ephemeral.py +302 -0
  4. valor_lite/cache/persistent.py +536 -0
  5. valor_lite/classification/__init__.py +5 -10
  6. valor_lite/classification/annotation.py +4 -0
  7. valor_lite/classification/computation.py +233 -251
  8. valor_lite/classification/evaluator.py +882 -0
  9. valor_lite/classification/loader.py +97 -0
  10. valor_lite/classification/metric.py +141 -4
  11. valor_lite/classification/shared.py +184 -0
  12. valor_lite/classification/utilities.py +221 -118
  13. valor_lite/exceptions.py +5 -0
  14. valor_lite/object_detection/__init__.py +5 -4
  15. valor_lite/object_detection/annotation.py +13 -1
  16. valor_lite/object_detection/computation.py +368 -299
  17. valor_lite/object_detection/evaluator.py +804 -0
  18. valor_lite/object_detection/loader.py +292 -0
  19. valor_lite/object_detection/metric.py +152 -3
  20. valor_lite/object_detection/shared.py +206 -0
  21. valor_lite/object_detection/utilities.py +182 -100
  22. valor_lite/semantic_segmentation/__init__.py +5 -4
  23. valor_lite/semantic_segmentation/annotation.py +7 -0
  24. valor_lite/semantic_segmentation/computation.py +20 -110
  25. valor_lite/semantic_segmentation/evaluator.py +414 -0
  26. valor_lite/semantic_segmentation/loader.py +205 -0
  27. valor_lite/semantic_segmentation/shared.py +149 -0
  28. valor_lite/semantic_segmentation/utilities.py +6 -23
  29. {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/METADATA +3 -1
  30. valor_lite-0.37.5.dist-info/RECORD +49 -0
  31. {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/WHEEL +1 -1
  32. valor_lite/classification/manager.py +0 -545
  33. valor_lite/object_detection/manager.py +0 -864
  34. valor_lite/profiling.py +0 -374
  35. valor_lite/semantic_segmentation/benchmark.py +0 -237
  36. valor_lite/semantic_segmentation/manager.py +0 -446
  37. valor_lite-0.36.6.dist-info/RECORD +0 -41
  38. {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/top_level.txt +0 -0
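Judging from the file list, the headline change in this release is the removal of the per-task `manager.py` entry points (deleted alongside `profiling.py` and `benchmark.py`) in favor of split `loader.py`/`evaluator.py`/`shared.py` modules backed by the new `valor_lite/cache` package. As a quick orientation, a sketch of the new classification import surface, using only names that appear in this diff:

    from valor_lite.cache import (
        FileCacheReader,
        FileCacheWriter,
        MemoryCacheReader,
        MemoryCacheWriter,
    )
    from valor_lite.classification.evaluator import Builder, Evaluator
    from valor_lite.classification.loader import Loader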
valor_lite/classification/evaluator.py
@@ -0,0 +1,882 @@
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+
+ import numpy as np
+ import pyarrow as pa
+ import pyarrow.compute as pc
+
+ from valor_lite.cache import (
+     FileCacheReader,
+     FileCacheWriter,
+     MemoryCacheReader,
+     MemoryCacheWriter,
+     compute,
+ )
+ from valor_lite.classification.computation import (
+     compute_accuracy,
+     compute_confusion_matrix,
+     compute_counts,
+     compute_f1_score,
+     compute_pair_classifications,
+     compute_precision,
+     compute_recall,
+     compute_rocauc,
+ )
+ from valor_lite.classification.metric import Metric, MetricType
+ from valor_lite.classification.shared import (
+     EvaluatorInfo,
+     decode_metadata_fields,
+     encode_metadata_fields,
+     extract_counts,
+     extract_groundtruth_count_per_label,
+     extract_labels,
+     generate_cache_path,
+     generate_intermediate_cache_path,
+     generate_intermediate_schema,
+     generate_metadata_path,
+     generate_roc_curve_cache_path,
+     generate_roc_curve_schema,
+     generate_schema,
+ )
+ from valor_lite.classification.utilities import (
+     create_empty_confusion_matrix_with_examples,
+     create_mapping,
+     unpack_confusion_matrix,
+     unpack_confusion_matrix_with_examples,
+     unpack_examples,
+     unpack_precision_recall,
+     unpack_rocauc,
+ )
+ from valor_lite.exceptions import EmptyCacheError
+
+
+ class Builder:
+     def __init__(
+         self,
+         writer: MemoryCacheWriter | FileCacheWriter,
+         roc_curve_writer: MemoryCacheWriter | FileCacheWriter,
+         intermediate_writer: MemoryCacheWriter | FileCacheWriter,
+         metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+     ):
+         self._writer = writer
+         self._roc_curve_writer = roc_curve_writer
+         self._intermediate_writer = intermediate_writer
+         self._metadata_fields = metadata_fields
+
+     @classmethod
+     def in_memory(
+         cls,
+         batch_size: int = 10_000,
+         metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+     ):
+         """
+         Create an in-memory evaluator cache.
+
+         Parameters
+         ----------
+         batch_size : int, default=10_000
+             The target number of rows to buffer before writing to the cache.
+         metadata_fields : list[tuple[str, str | pa.DataType]], optional
+             Optional metadata field definitions.
+         """
+         writer = MemoryCacheWriter.create(
+             schema=generate_schema(metadata_fields),
+             batch_size=batch_size,
+         )
+         intermediate_writer = MemoryCacheWriter.create(
+             schema=generate_intermediate_schema(),
+             batch_size=batch_size,
+         )
+         roc_curve_writer = MemoryCacheWriter.create(
+             schema=generate_roc_curve_schema(),
+             batch_size=batch_size,
+         )
+         return cls(
+             writer=writer,
+             roc_curve_writer=roc_curve_writer,
+             intermediate_writer=intermediate_writer,
+             metadata_fields=metadata_fields,
+         )
+
+     @classmethod
+     def persistent(
+         cls,
+         path: str | Path,
+         batch_size: int = 10_000,
+         rows_per_file: int = 100_000,
+         compression: str = "snappy",
+         metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+     ):
+         """
+         Create a persistent file-based evaluator cache.
+
+         Parameters
+         ----------
+         path : str | Path
+             Where to store the file-based cache.
+         batch_size : int, default=10_000
+             Sets the batch size for writing to file.
+         rows_per_file : int, default=100_000
+             Sets the maximum number of rows per file. This may be exceeded as files are datum aligned.
+         compression : str, default="snappy"
+             Sets the pyarrow compression method.
+         metadata_fields : list[tuple[str, str | pa.DataType]], optional
+             Optionally sets the metadata description for use in filtering.
+         """
+         path = Path(path)
+
+         # create cache
+         writer = FileCacheWriter.create(
+             path=generate_cache_path(path),
+             schema=generate_schema(metadata_fields),
+             batch_size=batch_size,
+             rows_per_file=rows_per_file,
+             compression=compression,
+         )
+         intermediate_writer = FileCacheWriter.create(
+             path=generate_intermediate_cache_path(path),
+             schema=generate_intermediate_schema(),
+             batch_size=batch_size,
+             rows_per_file=rows_per_file,
+             compression=compression,
+         )
+         roc_curve_writer = FileCacheWriter.create(
+             path=generate_roc_curve_cache_path(path),
+             schema=generate_roc_curve_schema(),
+             batch_size=batch_size,
+             rows_per_file=rows_per_file,
+             compression=compression,
+         )
+
+         # write metadata config
+         metadata_path = generate_metadata_path(path)
+         with open(metadata_path, "w") as f:
+             encoded_types = encode_metadata_fields(metadata_fields)
+             json.dump(encoded_types, f, indent=2)
+
+         return cls(
+             writer=writer,
+             roc_curve_writer=roc_curve_writer,
+             intermediate_writer=intermediate_writer,
+             metadata_fields=metadata_fields,
+         )
+
+     def _create_rocauc_intermediate(
+         self,
+         reader: MemoryCacheReader | FileCacheReader,
+         batch_size: int,
+         index_to_label: dict[int, str],
+     ):
+         n_labels = len(index_to_label)
+         compute.sort(
+             source=reader,
+             sink=self._intermediate_writer,
+             batch_size=batch_size,
+             sorting=[
+                 ("pd_score", "descending"),
+                 # ("pd_label_id", "ascending"),
+                 ("match", "ascending"),
+             ],
+             columns=[
+                 "pd_label_id",
+                 "pd_score",
+                 "match",
+             ],
+         )
+         intermediate = self._intermediate_writer.to_reader()
+
+         running_max_fp = np.zeros(n_labels, dtype=np.uint64)
+         running_max_tp = np.zeros(n_labels, dtype=np.uint64)
+         running_max_scores = np.zeros(n_labels, dtype=np.float64)
+
+         last_pair = np.zeros((n_labels, 2), dtype=np.uint64)
+         for tbl in intermediate.iterate_tables():
+             pd_label_ids = tbl["pd_label_id"].to_numpy()
+             tps = tbl["match"].to_numpy()
+             scores = tbl["pd_score"].to_numpy()
+             fps = ~tps
+
+             for idx in index_to_label.keys():
+                 mask = pd_label_ids == idx
+                 if mask.sum() == 0:
+                     continue
+
+                 cumulative_fp = np.r_[
+                     running_max_fp[idx],
+                     np.cumsum(fps[mask]) + running_max_fp[idx],
+                 ]
+                 cumulative_tp = np.r_[
+                     running_max_tp[idx],
+                     np.cumsum(tps[mask]) + running_max_tp[idx],
+                 ]
+                 pd_scores = np.r_[running_max_scores[idx], scores[mask]]
+
+                 indices = (
+                     np.where(
+                         np.diff(np.r_[running_max_scores[idx], pd_scores])
+                     )[0]
+                     - 1
+                 )
+
+                 running_max_fp[idx] = cumulative_fp[-1]
+                 running_max_tp[idx] = cumulative_tp[-1]
+                 running_max_scores[idx] = pd_scores[-1]
+
+                 for fp, tp in zip(
+                     cumulative_fp[indices],
+                     cumulative_tp[indices],
+                 ):
+                     last_pair[idx, 0] = fp
+                     last_pair[idx, 1] = tp
+                     self._roc_curve_writer.write_rows(
+                         [
+                             {
+                                 "pd_label_id": idx,
+                                 "cumulative_fp": fp,
+                                 "cumulative_tp": tp,
+                             }
+                         ]
+                     )
+
+         # ensure any remaining values are ingested
+         for idx in range(n_labels):
+             last_fp = last_pair[idx, 0]
+             last_tp = last_pair[idx, 1]
+             if (
+                 last_fp != running_max_fp[idx]
+                 or last_tp != running_max_tp[idx]
+             ):
+                 self._roc_curve_writer.write_rows(
+                     [
+                         {
+                             "pd_label_id": idx,
+                             "cumulative_fp": running_max_fp[idx],
+                             "cumulative_tp": running_max_tp[idx],
+                         }
+                     ]
+                 )
+
+     def finalize(
+         self,
+         batch_size: int = 1_000,
+         index_to_label_override: dict[int, str] | None = None,
+     ):
+         """
+         Performs data finalization and some preprocessing steps.
+
+         Parameters
+         ----------
+         batch_size : int, default=1_000
+             Sets the maximum number of elements read into memory per-file when performing merge sort.
+         index_to_label_override : dict[int, str], optional
+             Pre-configures label mapping. Used when operating over filtered subsets.
+
+         Returns
+         -------
+         Evaluator
+             A ready-to-use evaluator object.
+         """
+         self._writer.flush()
+         if self._writer.count_rows() == 0:
+             raise EmptyCacheError()
+         elif self._roc_curve_writer.count_rows() > 0:
+             raise RuntimeError("data already finalized")
+
+         # sort in-place and locally
+         self._writer.sort_by(
+             [
+                 ("pd_score", "descending"),
+                 ("datum_id", "ascending"),
+                 ("gt_label_id", "ascending"),
+                 ("pd_label_id", "ascending"),
+             ]
+         )
+
+         # post-process into sorted writer
+         reader = self._writer.to_reader()
+
+         # extract labels
+         index_to_label = extract_labels(
+             reader=reader,
+             index_to_label_override=index_to_label_override,
+         )
+
+         self._create_rocauc_intermediate(
+             reader=reader,
+             batch_size=batch_size,
+             index_to_label=index_to_label,
+         )
+         roc_curve_reader = self._roc_curve_writer.to_reader()
+
+         return Evaluator(
+             reader=reader,
+             roc_curve_reader=roc_curve_reader,
+             index_to_label=index_to_label,
+             metadata_fields=self._metadata_fields,
+         )
+
+
+ class Evaluator:
+     def __init__(
+         self,
+         reader: MemoryCacheReader | FileCacheReader,
+         roc_curve_reader: MemoryCacheReader | FileCacheReader,
+         index_to_label: dict[int, str],
+         metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+     ):
+         self._reader = reader
+         self._roc_curve_reader = roc_curve_reader
+         self._index_to_label = index_to_label
+         self._metadata_fields = metadata_fields
+
+     @property
+     def info(self) -> EvaluatorInfo:
+         return self.get_info()
+
+     def get_info(
+         self,
+         datums: pc.Expression | None = None,
+     ) -> EvaluatorInfo:
+         info = EvaluatorInfo()
+         info.metadata_fields = self._metadata_fields
+         info.number_of_rows = self._reader.count_rows()
+         info.number_of_labels = len(self._index_to_label)
+         info.number_of_datums = extract_counts(
+             reader=self._reader,
+             datums=datums,
+         )
+         return info
+
+     @classmethod
+     def load(
+         cls,
+         path: str | Path,
+         index_to_label_override: dict[int, str] | None = None,
+     ):
+         """
+         Load from an existing classification cache.
+
+         Parameters
+         ----------
+         path : str | Path
+             Path to the existing cache.
+         index_to_label_override : dict[int, str], optional
+             Option to preset index to label dictionary. Used when loading from filtered caches.
+         """
+
+         # validate path
+         path = Path(path)
+         if not path.exists():
+             raise FileNotFoundError(f"Directory does not exist: {path}")
+         elif not path.is_dir():
+             raise NotADirectoryError(
+                 f"Path exists but is not a directory: {path}"
+             )
+
+         # load cache
+         reader = FileCacheReader.load(generate_cache_path(path))
+         roc_curve_reader = FileCacheReader.load(
+             generate_roc_curve_cache_path(path)
+         )
+
+         # extract labels
+         index_to_label = extract_labels(
+             reader=reader,
+             index_to_label_override=index_to_label_override,
+         )
+
+         # read config
+         metadata_path = generate_metadata_path(path)
+         metadata_fields = None
+         with open(metadata_path, "r") as f:
+             encoded_types = json.load(f)
+             metadata_fields = decode_metadata_fields(encoded_types)
+
+         return cls(
+             reader=reader,
+             roc_curve_reader=roc_curve_reader,
+             index_to_label=index_to_label,
+             metadata_fields=metadata_fields,
+         )
+
+     def filter(
+         self,
+         datums: pc.Expression | None = None,
+         groundtruths: pc.Expression | None = None,
+         predictions: pc.Expression | None = None,
+         path: str | Path | None = None,
+     ) -> Evaluator:
+         """
+         Filter evaluator cache.
+
+         Parameters
+         ----------
+         datums : pc.Expression, optional
+             A filter expression used to filter datums.
+         groundtruths : pc.Expression, optional
+             A filter expression used to filter ground truth annotations.
+         predictions : pc.Expression, optional
+             A filter expression used to filter predictions.
+         path : str | Path, optional
+             Where to store the filtered cache if storing on disk.
+
+         Returns
+         -------
+         Evaluator
+             A new evaluator object containing the filtered cache.
+         """
+         from valor_lite.classification.loader import Loader
+
+         if isinstance(self._reader, FileCacheReader):
+             if not path:
+                 raise ValueError(
+                     "expected path to be defined for file-based loader"
+                 )
+             loader = Loader.persistent(
+                 path=path,
+                 batch_size=self._reader.batch_size,
+                 rows_per_file=self._reader.rows_per_file,
+                 compression=self._reader.compression,
+                 metadata_fields=self.info.metadata_fields,
+             )
+         else:
+             loader = Loader.in_memory(
+                 batch_size=self._reader.batch_size,
+                 metadata_fields=self.info.metadata_fields,
+             )
+
+         for tbl in self._reader.iterate_tables(filter=datums):
+             columns = (
+                 "datum_id",
+                 "gt_label_id",
+                 "pd_label_id",
+             )
+             pairs = np.column_stack([tbl[col].to_numpy() for col in columns])
+
+             n_pairs = pairs.shape[0]
+             gt_ids = pairs[:, (0, 1)].astype(np.int64)
+             pd_ids = pairs[:, (0, 2)].astype(np.int64)
+
+             if groundtruths is not None:
+                 mask_valid_gt = np.zeros(n_pairs, dtype=np.bool_)
+                 gt_tbl = tbl.filter(groundtruths)
+                 gt_pairs = np.column_stack(
+                     [
+                         gt_tbl[col].to_numpy()
+                         for col in ("datum_id", "gt_label_id")
+                     ]
+                 ).astype(np.int64)
+                 for gt in np.unique(gt_pairs, axis=0):
+                     mask_valid_gt |= (gt_ids == gt).all(axis=1)
+             else:
+                 mask_valid_gt = np.ones(n_pairs, dtype=np.bool_)
+
+             if predictions is not None:
+                 mask_valid_pd = np.zeros(n_pairs, dtype=np.bool_)
+                 pd_tbl = tbl.filter(predictions)
+                 pd_pairs = np.column_stack(
+                     [
+                         pd_tbl[col].to_numpy()
+                         for col in ("datum_id", "pd_label_id")
+                     ]
+                 ).astype(np.int64)
+                 for pd in np.unique(pd_pairs, axis=0):
+                     mask_valid_pd |= (pd_ids == pd).all(axis=1)
+             else:
+                 mask_valid_pd = np.ones(n_pairs, dtype=np.bool_)
+
+             mask_valid = mask_valid_gt | mask_valid_pd
+             mask_valid_gt &= mask_valid
+             mask_valid_pd &= mask_valid
+
+             pairs[~mask_valid_gt, 1] = -1
+             pairs[~mask_valid_pd, 2] = -1
+
+             for idx, col in enumerate(columns):
+                 tbl = tbl.set_column(
+                     tbl.schema.names.index(col), col, pa.array(pairs[:, idx])
+                 )
+             # TODO (c.zaloom) - improve write strategy, filtered data could be small
+             loader._writer.write_table(tbl)
+
+         return loader.finalize(index_to_label_override=self._index_to_label)
+
+     def iterate_values(self, datums: pc.Expression | None = None):
+         columns = [
+             "datum_id",
+             "gt_label_id",
+             "pd_label_id",
+             "pd_score",
+             "pd_winner",
+             "match",
+         ]
+         for tbl in self._reader.iterate_tables(columns=columns, filter=datums):
+             ids = np.column_stack(
+                 [
+                     tbl[col].to_numpy()
+                     for col in [
+                         "datum_id",
+                         "gt_label_id",
+                         "pd_label_id",
+                     ]
+                 ]
+             )
+             scores = tbl["pd_score"].to_numpy()
+             winners = tbl["pd_winner"].to_numpy()
+             matches = tbl["match"].to_numpy()
+             yield ids, scores, winners, matches
+
+     def compute_rocauc(self) -> dict[MetricType, list[Metric]]:
+         """
+         Compute ROCAUC.
+
+         This function does not support direct filtering. To perform evaluation over a filtered
+         set you must first create a new evaluator using `Evaluator.filter`.
+
+         Returns
+         -------
+         dict[MetricType, list[Metric]]
+             A dictionary mapping MetricType enumerations to lists of computed metrics.
+         """
+         n_labels = self.info.number_of_labels
+
+         rocauc = np.zeros(n_labels, dtype=np.float64)
+         label_counts = extract_groundtruth_count_per_label(
+             reader=self._reader,
+             number_of_labels=len(self._index_to_label),
+         )
+
+         prev = np.zeros((n_labels, 2), dtype=np.uint64)
+         for array in self._roc_curve_reader.iterate_arrays(
+             numeric_columns=[
+                 "pd_label_id",
+                 "cumulative_fp",
+                 "cumulative_tp",
+             ],
+         ):
+             rocauc, prev = compute_rocauc(
+                 rocauc=rocauc,
+                 array=array,
+                 gt_count_per_label=label_counts[:, 0],
+                 pd_count_per_label=label_counts[:, 1],
+                 n_labels=self.info.number_of_labels,
+                 prev=prev,
+             )
+
+         mean_rocauc = rocauc.mean()
+
+         return unpack_rocauc(
+             rocauc=rocauc,
+             mean_rocauc=mean_rocauc,
+             index_to_label=self._index_to_label,
+         )
+
+     def compute_precision_recall(
+         self,
+         score_thresholds: list[float] = [0.0],
+         hardmax: bool = True,
+         datums: pc.Expression | None = None,
+     ) -> dict[MetricType, list]:
+         """
+         Performs an evaluation and returns metrics.
+
+         Parameters
+         ----------
+         score_thresholds : list[float]
+             A list of score thresholds to compute metrics over.
+         hardmax : bool
+             Toggles whether a hardmax is applied to predictions.
+         datums : pyarrow.compute.Expression, optional
+             Option to filter datums by an expression.
+
+         Returns
+         -------
+         dict[MetricType, list]
+             A dictionary mapping MetricType enumerations to lists of computed metrics.
+         """
+         if not score_thresholds:
+             raise ValueError("At least one score threshold must be passed.")
+
+         n_scores = len(score_thresholds)
+         n_datums = self.info.number_of_datums
+         n_labels = self.info.number_of_labels
+
+         # intermediates
+         counts = np.zeros((n_scores, n_labels, 4), dtype=np.uint64)
+
+         for ids, scores, winners, _ in self.iterate_values(datums=datums):
+             batch_counts = compute_counts(
+                 ids=ids,
+                 scores=scores,
+                 winners=winners,
+                 score_thresholds=np.array(score_thresholds),
+                 hardmax=hardmax,
+                 n_labels=n_labels,
+             )
+             counts += batch_counts
+
+         precision = compute_precision(counts)
+         recall = compute_recall(counts)
+         f1_score = compute_f1_score(precision, recall)
+         accuracy = compute_accuracy(counts, n_datums=n_datums)
+
+         return unpack_precision_recall(
+             counts=counts,
+             precision=precision,
+             recall=recall,
+             accuracy=accuracy,
+             f1_score=f1_score,
+             score_thresholds=score_thresholds,
+             hardmax=hardmax,
+             index_to_label=self._index_to_label,
+         )
+
+     def compute_confusion_matrix(
+         self,
+         score_thresholds: list[float] = [0.0],
+         hardmax: bool = True,
+         datums: pc.Expression | None = None,
+     ) -> list[Metric]:
+         """
+         Compute a confusion matrix.
+
+         Parameters
+         ----------
+         score_thresholds : list[float]
+             A list of score thresholds to compute metrics over.
+         hardmax : bool
+             Toggles whether a hardmax is applied to predictions.
+         datums : pyarrow.compute.Expression, optional
+             Option to filter datums by an expression.
+
+         Returns
+         -------
+         list[Metric]
+             A list of confusion matrices.
+         """
+         if not score_thresholds:
+             raise ValueError("At least one score threshold must be passed.")
+
+         n_scores = len(score_thresholds)
+         n_labels = len(self._index_to_label)
+         confusion_matrices = np.zeros(
+             (n_scores, n_labels, n_labels), dtype=np.uint64
+         )
+         unmatched_groundtruths = np.zeros(
+             (n_scores, n_labels), dtype=np.uint64
+         )
+         for ids, scores, winners, matches in self.iterate_values(
+             datums=datums
+         ):
+             (
+                 mask_tp,
+                 mask_fp_fn_misclf,
+                 mask_fn_unmatched,
+             ) = compute_pair_classifications(
+                 ids=ids,
+                 scores=scores,
+                 winners=winners,
+                 score_thresholds=np.array(score_thresholds),
+                 hardmax=hardmax,
+             )
+
+             batch_cm, batch_ugt = compute_confusion_matrix(
+                 ids=ids,
+                 mask_tp=mask_tp,
+                 mask_fp_fn_misclf=mask_fp_fn_misclf,
+                 mask_fn_unmatched=mask_fn_unmatched,
+                 score_thresholds=np.array(score_thresholds),
+                 n_labels=n_labels,
+             )
+             confusion_matrices += batch_cm
+             unmatched_groundtruths += batch_ugt
+
+         return unpack_confusion_matrix(
+             confusion_matrices=confusion_matrices,
+             unmatched_groundtruths=unmatched_groundtruths,
+             index_to_label=self._index_to_label,
+             score_thresholds=score_thresholds,
+             hardmax=hardmax,
+         )
+
+     def compute_examples(
+         self,
+         score_thresholds: list[float] = [0.0],
+         hardmax: bool = True,
+         datums: pc.Expression | None = None,
+         limit: int | None = None,
+         offset: int = 0,
+     ) -> list[Metric]:
+         """
+         Compute examples per datum.
+
+         Note: This function should be used with filtering to reduce response size.
+
+         Parameters
+         ----------
+         score_thresholds : list[float]
+             A list of score thresholds to compute metrics over.
+         hardmax : bool
+             Toggles whether a hardmax is applied to predictions.
+         datums : pyarrow.compute.Expression, optional
+             Option to filter datums by an expression.
+         limit : int, optional
+             Option to set a limit to the number of returned datum examples.
+         offset : int, default=0
+             Option to offset where examples are being created in the datum index.
+
+         Returns
+         -------
+         list[Metric]
+             A list of datum example metrics.
+ """
735
+ if not score_thresholds:
736
+ raise ValueError("At least one score threshold must be passed.")
737
+
738
+ metrics = []
739
+ for tbl in compute.paginate_index(
740
+ source=self._reader,
741
+ column_key="datum_id",
742
+ modifier=datums,
743
+ limit=limit,
744
+ offset=offset,
745
+ ):
746
+ if tbl.num_rows == 0:
747
+ continue
748
+
749
+ ids = np.column_stack(
750
+ [
751
+ tbl[col].to_numpy()
752
+ for col in [
753
+ "datum_id",
754
+ "gt_label_id",
755
+ "pd_label_id",
756
+ ]
757
+ ]
758
+ )
759
+ scores = tbl["pd_score"].to_numpy()
760
+ winners = tbl["pd_winner"].to_numpy()
761
+
762
+ # extract external identifiers
763
+ index_to_datum_id = create_mapping(
764
+ tbl, ids, 0, "datum_id", "datum_uid"
765
+ )
766
+
767
+ (
768
+ mask_tp,
769
+ mask_fp_fn_misclf,
770
+ mask_fn_unmatched,
771
+ ) = compute_pair_classifications(
772
+ ids=ids,
773
+ scores=scores,
774
+ winners=winners,
775
+ score_thresholds=np.array(score_thresholds),
776
+ hardmax=hardmax,
777
+ )
778
+
779
+ mask_fn = mask_fp_fn_misclf | mask_fn_unmatched
780
+ mask_fp = mask_fp_fn_misclf
781
+
782
+ batch_examples = unpack_examples(
783
+ ids=ids,
784
+ mask_tp=mask_tp,
785
+ mask_fp=mask_fp,
786
+ mask_fn=mask_fn,
787
+ index_to_datum_id=index_to_datum_id,
788
+ score_thresholds=score_thresholds,
789
+ hardmax=hardmax,
790
+ index_to_label=self._index_to_label,
791
+ )
792
+ metrics.extend(batch_examples)
793
+
794
+ return metrics
795
+
796
+ def compute_confusion_matrix_with_examples(
797
+ self,
798
+ score_thresholds: list[float] = [0.0],
799
+ hardmax: bool = True,
800
+ datums: pc.Expression | None = None,
801
+ ) -> list[Metric]:
802
+ """
803
+ Compute confusion matrix with examples.
804
+
805
+ Note: This function should be used with filtering to reduce response size.
806
+
807
+ Parameters
808
+ ----------
809
+         score_thresholds : list[float]
+             A list of score thresholds to compute metrics over.
+         hardmax : bool
+             Toggles whether a hardmax is applied to predictions.
+         datums : pyarrow.compute.Expression, optional
+             Option to filter datums by an expression.
+
+         Returns
+         -------
+         list[Metric]
+             A list of confusion matrices.
+         """
+         if not score_thresholds:
+             raise ValueError("At least one score threshold must be passed.")
+
+         metrics = {
+             score_idx: create_empty_confusion_matrix_with_examples(
+                 score_threshold=score_thresh,
+                 hardmax=hardmax,
+                 index_to_label=self._index_to_label,
+             )
+             for score_idx, score_thresh in enumerate(score_thresholds)
+         }
+         for tbl in self._reader.iterate_tables(filter=datums):
+             if tbl.num_rows == 0:
+                 continue
+
+             ids = np.column_stack(
+                 [
+                     tbl[col].to_numpy()
+                     for col in [
+                         "datum_id",
+                         "gt_label_id",
+                         "pd_label_id",
+                     ]
+                 ]
+             )
+             scores = tbl["pd_score"].to_numpy()
+             winners = tbl["pd_winner"].to_numpy()
+
+             # extract external identifiers
+             index_to_datum_id = create_mapping(
+                 tbl, ids, 0, "datum_id", "datum_uid"
+             )
+
+             (
+                 mask_tp,
+                 mask_fp_fn_misclf,
+                 mask_fn_unmatched,
+             ) = compute_pair_classifications(
+                 ids=ids,
+                 scores=scores,
+                 winners=winners,
+                 score_thresholds=np.array(score_thresholds),
+                 hardmax=hardmax,
+             )
+
+             mask_matched = mask_tp | mask_fp_fn_misclf
+             mask_unmatched_fn = mask_fn_unmatched
+
+             unpack_confusion_matrix_with_examples(
+                 metrics=metrics,
+                 ids=ids,
+                 scores=scores,
+                 winners=winners,
+                 mask_matched=mask_matched,
+                 mask_unmatched_fn=mask_unmatched_fn,
+                 index_to_datum_id=index_to_datum_id,
+                 index_to_label=self._index_to_label,
+             )
+
+         return list(metrics.values())
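For context, here is a hedged usage sketch of the new `Evaluator` API assembled from the signatures above. The cache paths and the filter value are hypothetical; `Evaluator.load`, `filter`, `compute_precision_recall`, `compute_rocauc`, and `compute_confusion_matrix`, plus the `datum_uid` column, are taken from the diff:

    import pyarrow.compute as pc

    from valor_lite.classification.evaluator import Evaluator

    # load a previously finalized file-based cache (path is hypothetical)
    evaluator = Evaluator.load("caches/classification")

    # precision / recall / F1 / accuracy at two score thresholds
    pr_metrics = evaluator.compute_precision_recall(score_thresholds=[0.25, 0.75])

    # ROCAUC reads the pre-computed ROC curve cache and does not accept a
    # datum filter; filter first if a subset is needed
    rocauc_metrics = evaluator.compute_rocauc()

    # filtering returns a new Evaluator; file-based caches require a
    # destination path for the filtered copy (the value "uid123" is hypothetical)
    subset = evaluator.filter(
        datums=pc.field("datum_uid") == "uid123",
        path="caches/classification-subset",
    )
    subset_metrics = subset.compute_confusion_matrix(score_thresholds=[0.5])

Note that `filter` keeps rows datum-aligned rather than dropping them: label ids of filtered-out ground truths and predictions are overwritten with -1 before the table is handed to the new loader.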