valor-lite 0.36.6__py3-none-any.whl → 0.37.5__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
Files changed (38)
  1. valor_lite/cache/__init__.py +11 -0
  2. valor_lite/cache/compute.py +211 -0
  3. valor_lite/cache/ephemeral.py +302 -0
  4. valor_lite/cache/persistent.py +536 -0
  5. valor_lite/classification/__init__.py +5 -10
  6. valor_lite/classification/annotation.py +4 -0
  7. valor_lite/classification/computation.py +233 -251
  8. valor_lite/classification/evaluator.py +882 -0
  9. valor_lite/classification/loader.py +97 -0
  10. valor_lite/classification/metric.py +141 -4
  11. valor_lite/classification/shared.py +184 -0
  12. valor_lite/classification/utilities.py +221 -118
  13. valor_lite/exceptions.py +5 -0
  14. valor_lite/object_detection/__init__.py +5 -4
  15. valor_lite/object_detection/annotation.py +13 -1
  16. valor_lite/object_detection/computation.py +368 -299
  17. valor_lite/object_detection/evaluator.py +804 -0
  18. valor_lite/object_detection/loader.py +292 -0
  19. valor_lite/object_detection/metric.py +152 -3
  20. valor_lite/object_detection/shared.py +206 -0
  21. valor_lite/object_detection/utilities.py +182 -100
  22. valor_lite/semantic_segmentation/__init__.py +5 -4
  23. valor_lite/semantic_segmentation/annotation.py +7 -0
  24. valor_lite/semantic_segmentation/computation.py +20 -110
  25. valor_lite/semantic_segmentation/evaluator.py +414 -0
  26. valor_lite/semantic_segmentation/loader.py +205 -0
  27. valor_lite/semantic_segmentation/shared.py +149 -0
  28. valor_lite/semantic_segmentation/utilities.py +6 -23
  29. {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/METADATA +3 -1
  30. valor_lite-0.37.5.dist-info/RECORD +49 -0
  31. {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/WHEEL +1 -1
  32. valor_lite/classification/manager.py +0 -545
  33. valor_lite/object_detection/manager.py +0 -864
  34. valor_lite/profiling.py +0 -374
  35. valor_lite/semantic_segmentation/benchmark.py +0 -237
  36. valor_lite/semantic_segmentation/manager.py +0 -446
  37. valor_lite-0.36.6.dist-info/RECORD +0 -41
  38. {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/top_level.txt +0 -0
valor_lite/object_detection/evaluator.py (new file)
@@ -0,0 +1,804 @@
+ import json
+ from pathlib import Path
+
+ import numpy as np
+ import pyarrow as pa
+ import pyarrow.compute as pc
+
+ from valor_lite.cache import (
+     FileCacheReader,
+     FileCacheWriter,
+     MemoryCacheReader,
+     MemoryCacheWriter,
+     compute,
+ )
+ from valor_lite.exceptions import EmptyCacheError
+ from valor_lite.object_detection.computation import (
+     compute_average_precision,
+     compute_average_recall,
+     compute_confusion_matrix,
+     compute_counts,
+     compute_pair_classifications,
+     compute_precision_recall_f1,
+     rank_table,
+ )
+ from valor_lite.object_detection.metric import Metric, MetricType
+ from valor_lite.object_detection.shared import (
+     EvaluatorInfo,
+     decode_metadata_fields,
+     encode_metadata_fields,
+     extract_counts,
+     extract_groundtruth_count_per_label,
+     extract_labels,
+     generate_detailed_cache_path,
+     generate_detailed_schema,
+     generate_metadata_path,
+     generate_ranked_cache_path,
+     generate_ranked_schema,
+ )
+ from valor_lite.object_detection.utilities import (
+     create_empty_confusion_matrix_with_examples,
+     create_mapping,
+     unpack_confusion_matrix,
+     unpack_confusion_matrix_with_examples,
+     unpack_examples,
+     unpack_precision_recall_into_metric_lists,
+ )
+
+
+ class Builder:
+     def __init__(
+         self,
+         detailed_writer: MemoryCacheWriter | FileCacheWriter,
+         ranked_writer: MemoryCacheWriter | FileCacheWriter,
+         metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+     ):
+         self._detailed_writer = detailed_writer
+         self._ranked_writer = ranked_writer
+         self._metadata_fields = metadata_fields
+
+     @classmethod
+     def in_memory(
+         cls,
+         batch_size: int = 10_000,
+         metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+     ):
+         """
+         Create an in-memory evaluator cache.
+
+         Parameters
+         ----------
+         batch_size : int, default=10_000
+             The target number of rows to buffer before writing to the cache.
+         metadata_fields : list[tuple[str, str | pa.DataType]], optional
+             Optional datum metadata field definitions.
+         """
+         # create cache
+         detailed_writer = MemoryCacheWriter.create(
+             schema=generate_detailed_schema(metadata_fields),
+             batch_size=batch_size,
+         )
+         ranked_writer = MemoryCacheWriter.create(
+             schema=generate_ranked_schema(metadata_fields),
+             batch_size=batch_size,
+         )
+
+         return cls(
+             detailed_writer=detailed_writer,
+             ranked_writer=ranked_writer,
+             metadata_fields=metadata_fields,
+         )
+
+     @classmethod
+     def persistent(
+         cls,
+         path: str | Path,
+         batch_size: int = 10_000,
+         rows_per_file: int = 100_000,
+         compression: str = "snappy",
+         metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+     ):
+         """
+         Create a persistent file-based evaluator cache.
+
+         Parameters
+         ----------
+         path : str | Path
+             Where to store the file-based cache.
+         batch_size : int, default=10_000
+             The target number of rows to buffer before writing to the cache.
+         rows_per_file : int, default=100_000
+             The target number of rows to store per cache file.
+         compression : str, default="snappy"
+             The compression method used when writing cache files.
+         metadata_fields : list[tuple[str, str | pa.DataType]], optional
+             Optional metadata field definitions.
+         """
+         path = Path(path)
+
+         # create caches
+         detailed_writer = FileCacheWriter.create(
+             path=generate_detailed_cache_path(path),
+             schema=generate_detailed_schema(metadata_fields),
+             batch_size=batch_size,
+             rows_per_file=rows_per_file,
+             compression=compression,
+         )
+         ranked_writer = FileCacheWriter.create(
+             path=generate_ranked_cache_path(path),
+             schema=generate_ranked_schema(metadata_fields),
+             batch_size=batch_size,
+             rows_per_file=rows_per_file,
+             compression=compression,
+         )
+
+         # write metadata
+         metadata_path = generate_metadata_path(path)
+         with open(metadata_path, "w") as f:
+             encoded_types = encode_metadata_fields(metadata_fields)
+             json.dump(encoded_types, f, indent=2)
+
+         return cls(
+             detailed_writer=detailed_writer,
+             ranked_writer=ranked_writer,
+             metadata_fields=metadata_fields,
+         )
+
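Both constructors are shown above with their defaults. A minimal usage sketch (the "camera" metadata field is illustrative, not part of the package):

    import pyarrow as pa

    from valor_lite.object_detection.evaluator import Builder

    # ephemeral cache held in memory
    builder = Builder.in_memory(batch_size=10_000)

    # durable file-based cache written under ./det_cache
    builder = Builder.persistent(
        path="./det_cache",
        rows_per_file=100_000,
        compression="snappy",
        metadata_fields=[("camera", pa.string())],  # illustrative metadata column
    )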
+     def _rank(self, batch_size: int = 1_000):
+         """Perform pair ranking over the detailed cache."""
+
+         detailed_reader = self._detailed_writer.to_reader()
+         compute.sort(
+             source=detailed_reader,
+             sink=self._ranked_writer,
+             batch_size=batch_size,
+             sorting=[
+                 ("pd_score", "descending"),
+                 ("iou", "descending"),
+             ],
+             columns=[
+                 field.name
+                 for field in self._ranked_writer.schema
+                 if field.name != "iou_prev"
+             ],
+             table_sort_override=rank_table,
+         )
+         self._ranked_writer.flush()
+
+     def finalize(
+         self,
+         batch_size: int = 1_000,
+         index_to_label_override: dict[int, str] | None = None,
+     ):
+         """
+         Performs data finalization and preprocessing.
+
+         Parameters
+         ----------
+         batch_size : int, default=1_000
+             Sets the batch size for reading.
+         index_to_label_override : dict[int, str], optional
+             Pre-configures the label mapping. Used when operating over filtered subsets.
+         """
+         self._detailed_writer.flush()
+         if self._detailed_writer.count_rows() == 0:
+             raise EmptyCacheError()
+
+         self._detailed_writer.sort_by(
+             [
+                 ("pd_score", "descending"),
+                 ("iou", "descending"),
+             ]
+         )
+         detailed_reader = self._detailed_writer.to_reader()
+
+         # extract labels
+         index_to_label = extract_labels(
+             reader=detailed_reader,
+             index_to_label_override=index_to_label_override,
+         )
+
+         # populate ranked cache
+         self._rank(batch_size)
+
+         ranked_reader = self._ranked_writer.to_reader()
+         return Evaluator(
+             detailed_reader=detailed_reader,
+             ranked_reader=ranked_reader,
+             index_to_label=index_to_label,
+             metadata_fields=self._metadata_fields,
+         )
+
+
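finalize() flushes and sorts the detailed cache, populates the ranked cache via _rank(), and returns an Evaluator; calling it on an empty cache raises EmptyCacheError, as shown above. Continuing the sketch, and assuming detection pairs have already been written through the package's Loader:

    from valor_lite.exceptions import EmptyCacheError

    try:
        evaluator = builder.finalize(batch_size=1_000)
    except EmptyCacheError:
        evaluator = None  # nothing was written to the detailed cache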
+ class Evaluator:
+     def __init__(
+         self,
+         detailed_reader: MemoryCacheReader | FileCacheReader,
+         ranked_reader: MemoryCacheReader | FileCacheReader,
+         index_to_label: dict[int, str],
+         metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+     ):
+         self._detailed_reader = detailed_reader
+         self._ranked_reader = ranked_reader
+         self._index_to_label = index_to_label
+         self._metadata_fields = metadata_fields
+
+     @property
+     def info(self) -> EvaluatorInfo:
+         return self.get_info()
+
+     def get_info(
+         self,
+         datums: pc.Expression | None = None,
+         groundtruths: pc.Expression | None = None,
+         predictions: pc.Expression | None = None,
+     ) -> EvaluatorInfo:
+         info = EvaluatorInfo()
+         info.number_of_rows = self._detailed_reader.count_rows()
+         info.metadata_fields = self._metadata_fields
+         info.number_of_labels = len(self._index_to_label)
+         (
+             info.number_of_datums,
+             info.number_of_groundtruth_annotations,
+             info.number_of_prediction_annotations,
+         ) = extract_counts(
+             reader=self._detailed_reader,
+             datums=datums,
+             groundtruths=groundtruths,
+             predictions=predictions,
+         )
+         return info
+
+     @classmethod
+     def load(
+         cls,
+         path: str | Path,
+         index_to_label_override: dict[int, str] | None = None,
+     ):
+         # validate path
+         path = Path(path)
+         if not path.exists():
+             raise FileNotFoundError(f"Directory does not exist: {path}")
+         elif not path.is_dir():
+             raise NotADirectoryError(
+                 f"Path exists but is not a directory: {path}"
+             )
+
+         detailed_reader = FileCacheReader.load(
+             generate_detailed_cache_path(path)
+         )
+         ranked_reader = FileCacheReader.load(generate_ranked_cache_path(path))
+
+         # extract labels from cache
+         index_to_label = extract_labels(
+             reader=detailed_reader,
+             index_to_label_override=index_to_label_override,
+         )
+
+         # read config
+         metadata_path = generate_metadata_path(path)
+         metadata_fields = None
+         with open(metadata_path, "r") as f:
+             encoded_metadata_types = json.load(f)
+             metadata_fields = decode_metadata_fields(encoded_metadata_types)
+
+         return cls(
+             detailed_reader=detailed_reader,
+             ranked_reader=ranked_reader,
+             index_to_label=index_to_label,
+             metadata_fields=metadata_fields,
+         )
+
+     def filter(
+         self,
+         datums: pc.Expression | None = None,
+         groundtruths: pc.Expression | None = None,
+         predictions: pc.Expression | None = None,
+         batch_size: int = 1_000,
+         path: str | Path | None = None,
+     ) -> "Evaluator":
+         """
+         Filter the evaluator cache.
+
+         Parameters
+         ----------
+         datums : pc.Expression, optional
+             A filter expression used to filter datums.
+         groundtruths : pc.Expression, optional
+             A filter expression used to filter ground truth annotations.
+         predictions : pc.Expression, optional
+             A filter expression used to filter predictions.
+         batch_size : int, default=1_000
+             The maximum number of rows read into memory per file.
+         path : str | Path, optional
+             Where to store the filtered cache if storing on disk.
+
+         Returns
+         -------
+         Evaluator
+             A new evaluator object containing the filtered cache.
+         """
+         from valor_lite.object_detection.loader import Loader
+
+         if isinstance(self._detailed_reader, FileCacheReader):
+             if not path:
+                 raise ValueError(
+                     "expected path to be defined for file-based loader"
+                 )
+             loader = Loader.persistent(
+                 path=path,
+                 batch_size=self._detailed_reader.batch_size,
+                 rows_per_file=self._detailed_reader.rows_per_file,
+                 compression=self._detailed_reader.compression,
+                 metadata_fields=self._metadata_fields,
+             )
+         else:
+             loader = Loader.in_memory(
+                 batch_size=self._detailed_reader.batch_size,
+                 metadata_fields=self._metadata_fields,
+             )
+
+         for tbl in self._detailed_reader.iterate_tables(filter=datums):
+             columns = (
+                 "datum_id",
+                 "gt_id",
+                 "pd_id",
+                 "gt_label_id",
+                 "pd_label_id",
+                 "iou",
+                 "pd_score",
+             )
+             pairs = np.column_stack([tbl[col].to_numpy() for col in columns])
+
+             n_pairs = pairs.shape[0]
+             gt_ids = pairs[:, (0, 1)].astype(np.int64)
+             pd_ids = pairs[:, (0, 2)].astype(np.int64)
+
+             if groundtruths is not None:
+                 mask_valid_gt = np.zeros(n_pairs, dtype=np.bool_)
+                 gt_tbl = tbl.filter(groundtruths)
+                 gt_pairs = np.column_stack(
+                     [gt_tbl[col].to_numpy() for col in ("datum_id", "gt_id")]
+                 ).astype(np.int64)
+                 for gt in np.unique(gt_pairs, axis=0):
+                     mask_valid_gt |= (gt_ids == gt).all(axis=1)
+             else:
+                 mask_valid_gt = np.ones(n_pairs, dtype=np.bool_)
+
+             if predictions is not None:
+                 mask_valid_pd = np.zeros(n_pairs, dtype=np.bool_)
+                 pd_tbl = tbl.filter(predictions)
+                 pd_pairs = np.column_stack(
+                     [pd_tbl[col].to_numpy() for col in ("datum_id", "pd_id")]
+                 ).astype(np.int64)
+                 for pd in np.unique(pd_pairs, axis=0):
+                     mask_valid_pd |= (pd_ids == pd).all(axis=1)
+             else:
+                 mask_valid_pd = np.ones(n_pairs, dtype=np.bool_)
+
+             mask_valid = mask_valid_gt | mask_valid_pd
+             mask_valid_gt &= mask_valid
+             mask_valid_pd &= mask_valid
+
+             # filter out invalid gt_id, gt_label_id by setting to -1.0
+             pairs[np.ix_(~mask_valid_gt, (1, 3))] = -1.0  # type: ignore[reportArgumentType]
+             # filter out invalid pd_id, pd_label_id, pd_score by setting to -1.0
+             pairs[np.ix_(~mask_valid_pd, (2, 4, 6))] = -1.0  # type: ignore[reportArgumentType]
+             # filter out invalid iou by setting to 0.0
+             pairs[~mask_valid_pd | ~mask_valid_gt, 5] = 0.0
+
+             for idx, col in enumerate(columns):
+                 column = pairs[:, idx]
+                 if col not in {"iou", "pd_score"}:
+                     column = column.astype(np.int64)
+
+                 col_idx = tbl.schema.names.index(col)
+                 tbl = tbl.set_column(
+                     col_idx, tbl.schema[col_idx], pa.array(column)
+                 )
+
+             mask_invalid = ~mask_valid | (pairs[:, (1, 2)] < 0).all(axis=1)
+             filtered_tbl = tbl.filter(pa.array(~mask_invalid))
+             loader._detailed_writer.write_table(filtered_tbl)
+
+         return loader.finalize(
+             batch_size=batch_size,
+             index_to_label_override=self._index_to_label,
+         )
+
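The filter arguments are ordinary pyarrow compute expressions evaluated against the cache's columns (datum_uid, gt_label_id, and so on, per the schema used above). A sketch with an illustrative UID; a file-backed evaluator additionally requires a destination path for the filtered copy:

    import pyarrow.compute as pc

    # in-memory cache: the filtered copy stays in memory
    subset = evaluator.filter(datums=pc.field("datum_uid") == "image_0001")

    # file-backed cache: the filtered copy is written to disk
    subset = evaluator.filter(
        datums=pc.field("datum_uid") == "image_0001",
        path="./det_cache_filtered",
    )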
+     def compute_precision_recall(
+         self,
+         iou_thresholds: list[float],
+         score_thresholds: list[float],
+         datums: pc.Expression | None = None,
+     ) -> dict[MetricType, list[Metric]]:
+         """
+         Computes all metrics except for ConfusionMatrix.
+
+         Parameters
+         ----------
+         iou_thresholds : list[float]
+             A list of IOU thresholds to compute metrics over.
+         score_thresholds : list[float]
+             A list of score thresholds to compute metrics over.
+         datums : pyarrow.compute.Expression, optional
+             Option to filter datums by an expression.
+
+         Returns
+         -------
+         dict[MetricType, list[Metric]]
+             A dictionary mapping MetricType enumerations to lists of computed metrics.
+         """
+         if not iou_thresholds:
+             raise ValueError("At least one IOU threshold must be passed.")
+         elif not score_thresholds:
+             raise ValueError("At least one score threshold must be passed.")
+
+         n_ious = len(iou_thresholds)
+         n_scores = len(score_thresholds)
+         n_labels = len(self._index_to_label)
+         n_gts_per_lbl = extract_groundtruth_count_per_label(
+             reader=self._detailed_reader,
+             number_of_labels=n_labels,
+             datums=datums,
+         )
+
+         counts = np.zeros((n_ious, n_scores, 3, n_labels), dtype=np.uint64)
+         pr_curve = np.zeros((n_ious, n_labels, 101, 2), dtype=np.float64)
+         running_counts = np.zeros((n_ious, n_labels, 2), dtype=np.uint64)
+
+         for pairs in self._ranked_reader.iterate_arrays(
+             numeric_columns=[
+                 "datum_id",
+                 "gt_id",
+                 "pd_id",
+                 "gt_label_id",
+                 "pd_label_id",
+                 "iou",
+                 "pd_score",
+                 "iou_prev",
+             ],
+             filter=datums,
+         ):
+             if pairs.size == 0:
+                 continue
+
+             batch_counts = compute_counts(
+                 ranked_pairs=pairs,
+                 iou_thresholds=np.array(iou_thresholds),
+                 score_thresholds=np.array(score_thresholds),
+                 number_of_groundtruths_per_label=n_gts_per_lbl,
+                 number_of_labels=n_labels,
+                 running_counts=running_counts,
+                 pr_curve=pr_curve,
+             )
+             counts += batch_counts
+
+         # fn count
+         counts[:, :, 2, :] = n_gts_per_lbl - counts[:, :, 0, :]
+
+         precision_recall_f1 = compute_precision_recall_f1(
+             counts=counts,
+             number_of_groundtruths_per_label=n_gts_per_lbl,
+         )
+         (
+             average_precision,
+             mean_average_precision,
+             pr_curve,
+         ) = compute_average_precision(pr_curve=pr_curve)
+         average_recall, mean_average_recall = compute_average_recall(
+             prec_rec_f1=precision_recall_f1
+         )
+
+         return unpack_precision_recall_into_metric_lists(
+             counts=counts,
+             precision_recall_f1=precision_recall_f1,
+             average_precision=average_precision,
+             mean_average_precision=mean_average_precision,
+             average_recall=average_recall,
+             mean_average_recall=mean_average_recall,
+             pr_curve=pr_curve,
+             iou_thresholds=iou_thresholds,
+             score_thresholds=score_thresholds,
+             index_to_label=self._index_to_label,
+         )
+
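A sketch of a typical call; the thresholds are illustrative, and the result is a dictionary keyed by MetricType:

    metrics = evaluator.compute_precision_recall(
        iou_thresholds=[0.5, 0.75],
        score_thresholds=[0.25, 0.5],
    )
    for metric_type, metric_list in metrics.items():
        print(metric_type, len(metric_list))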
+     def compute_confusion_matrix(
+         self,
+         iou_thresholds: list[float],
+         score_thresholds: list[float],
+         datums: pc.Expression | None = None,
+     ) -> list[Metric]:
+         """
+         Computes confusion matrices at various thresholds.
+
+         Parameters
+         ----------
+         iou_thresholds : list[float]
+             A list of IOU thresholds to compute metrics over.
+         score_thresholds : list[float]
+             A list of score thresholds to compute metrics over.
+         datums : pyarrow.compute.Expression, optional
+             Option to filter datums by an expression.
+
+         Returns
+         -------
+         list[Metric]
+             A list of confusion matrices, one per threshold pair.
+         """
+         if not iou_thresholds:
+             raise ValueError("At least one IOU threshold must be passed.")
+         elif not score_thresholds:
+             raise ValueError("At least one score threshold must be passed.")
+
+         n_ious = len(iou_thresholds)
+         n_scores = len(score_thresholds)
+         n_labels = len(self._index_to_label)
+
+         confusion_matrices = np.zeros(
+             (n_ious, n_scores, n_labels, n_labels), dtype=np.uint64
+         )
+         unmatched_groundtruths = np.zeros(
+             (n_ious, n_scores, n_labels), dtype=np.uint64
+         )
+         unmatched_predictions = np.zeros_like(unmatched_groundtruths)
+
+         for pairs in self._detailed_reader.iterate_arrays(
+             numeric_columns=[
+                 "datum_id",
+                 "gt_id",
+                 "pd_id",
+                 "gt_label_id",
+                 "pd_label_id",
+                 "iou",
+                 "pd_score",
+             ],
+             filter=datums,
+         ):
+             if pairs.size == 0:
+                 continue
+
+             (
+                 batch_mask_tp,
+                 batch_mask_fp_fn_misclf,
+                 batch_mask_fp_unmatched,
+                 batch_mask_fn_unmatched,
+             ) = compute_pair_classifications(
+                 detailed_pairs=pairs,
+                 iou_thresholds=np.array(iou_thresholds),
+                 score_thresholds=np.array(score_thresholds),
+             )
+             (
+                 batch_confusion_matrices,
+                 batch_unmatched_groundtruths,
+                 batch_unmatched_predictions,
+             ) = compute_confusion_matrix(
+                 detailed_pairs=pairs,
+                 mask_tp=batch_mask_tp,
+                 mask_fp_fn_misclf=batch_mask_fp_fn_misclf,
+                 mask_fp_unmatched=batch_mask_fp_unmatched,
+                 mask_fn_unmatched=batch_mask_fn_unmatched,
+                 number_of_labels=n_labels,
+                 iou_thresholds=np.array(iou_thresholds),
+                 score_thresholds=np.array(score_thresholds),
+             )
+             confusion_matrices += batch_confusion_matrices
+             unmatched_groundtruths += batch_unmatched_groundtruths
+             unmatched_predictions += batch_unmatched_predictions
+
+         return unpack_confusion_matrix(
+             confusion_matrices=confusion_matrices,
+             unmatched_groundtruths=unmatched_groundtruths,
+             unmatched_predictions=unmatched_predictions,
+             index_to_label=self._index_to_label,
+             iou_thresholds=iou_thresholds,
+             score_thresholds=score_thresholds,
+         )
+
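Unlike compute_precision_recall(), this method streams the detailed cache rather than the ranked cache. A sketch with illustrative thresholds:

    confusion_matrices = evaluator.compute_confusion_matrix(
        iou_thresholds=[0.5],
        score_thresholds=[0.5],
    )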
+     def compute_examples(
+         self,
+         iou_thresholds: list[float],
+         score_thresholds: list[float],
+         datums: pc.Expression | None = None,
+         limit: int | None = None,
+         offset: int = 0,
+     ) -> list[Metric]:
+         """
+         Computes examples at various thresholds.
+
+         This function can use a lot of memory with larger or high-density datasets. Please use it with filters.
+
+         Parameters
+         ----------
+         iou_thresholds : list[float]
+             A list of IOU thresholds to compute metrics over.
+         score_thresholds : list[float]
+             A list of score thresholds to compute metrics over.
+         datums : pyarrow.compute.Expression, optional
+             Option to filter datums by an expression.
+         limit : int, optional
+             Option to limit the number of returned datum examples.
+         offset : int, default=0
+             Option to offset where examples begin within the datum index.
+
+         Returns
+         -------
+         list[Metric]
+             A list of example metrics per threshold pair.
+         """
+         if not iou_thresholds:
+             raise ValueError("At least one IOU threshold must be passed.")
+         elif not score_thresholds:
+             raise ValueError("At least one score threshold must be passed.")
+
+         metrics = []
+         numeric_columns = [
+             "datum_id",
+             "gt_id",
+             "pd_id",
+             "gt_label_id",
+             "pd_label_id",
+             "iou",
+             "pd_score",
+         ]
+         for tbl in compute.paginate_index(
+             source=self._detailed_reader,
+             column_key="datum_id",
+             modifier=datums,
+             limit=limit,
+             offset=offset,
+         ):
+             if tbl.num_rows == 0:
+                 continue
+
+             pairs = np.column_stack(
+                 [tbl[col].to_numpy() for col in numeric_columns]
+             )
+
+             # extract external identifiers
+             index_to_datum_id = create_mapping(
+                 tbl, pairs, 0, "datum_id", "datum_uid"
+             )
+             index_to_groundtruth_id = create_mapping(
+                 tbl, pairs, 1, "gt_id", "gt_uid"
+             )
+             index_to_prediction_id = create_mapping(
+                 tbl, pairs, 2, "pd_id", "pd_uid"
+             )
+
+             (
+                 mask_tp,
+                 mask_fp_fn_misclf,
+                 mask_fp_unmatched,
+                 mask_fn_unmatched,
+             ) = compute_pair_classifications(
+                 detailed_pairs=pairs,
+                 iou_thresholds=np.array(iou_thresholds),
+                 score_thresholds=np.array(score_thresholds),
+             )
+
+             mask_fn = mask_fp_fn_misclf | mask_fn_unmatched
+             mask_fp = mask_fp_fn_misclf | mask_fp_unmatched
+
+             batch_examples = unpack_examples(
+                 detailed_pairs=pairs,
+                 mask_tp=mask_tp,
+                 mask_fp=mask_fp,
+                 mask_fn=mask_fn,
+                 index_to_datum_id=index_to_datum_id,
+                 index_to_groundtruth_id=index_to_groundtruth_id,
+                 index_to_prediction_id=index_to_prediction_id,
+                 iou_thresholds=iou_thresholds,
+                 score_thresholds=score_thresholds,
+             )
+             metrics.extend(batch_examples)
+
+         return metrics
+
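Because examples can be memory-hungry, limit and offset paginate over the datum index. A sketch that pulls a single page of 100 datums:

    examples = evaluator.compute_examples(
        iou_thresholds=[0.5],
        score_thresholds=[0.5],
        limit=100,
        offset=0,
    )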
+     def compute_confusion_matrix_with_examples(
+         self,
+         iou_thresholds: list[float],
+         score_thresholds: list[float],
+         datums: pc.Expression | None = None,
+     ) -> list[Metric]:
+         """
+         Computes confusion matrices with examples at various thresholds.
+
+         This function can use a lot of memory with larger or high-density datasets. Please use it with filters.
+
+         Parameters
+         ----------
+         iou_thresholds : list[float]
+             A list of IOU thresholds to compute metrics over.
+         score_thresholds : list[float]
+             A list of score thresholds to compute metrics over.
+         datums : pyarrow.compute.Expression, optional
+             Option to filter datums by an expression.
+
+         Returns
+         -------
+         list[Metric]
+             A list of confusion matrices with examples, one per threshold pair.
+         """
+         if not iou_thresholds:
+             raise ValueError("At least one IOU threshold must be passed.")
+         elif not score_thresholds:
+             raise ValueError("At least one score threshold must be passed.")
+
+         metrics = {
+             iou_idx: {
+                 score_idx: create_empty_confusion_matrix_with_examples(
+                     iou_threshold=iou_thresh,
+                     score_threshold=score_thresh,
+                     index_to_label=self._index_to_label,
+                 )
+                 for score_idx, score_thresh in enumerate(score_thresholds)
+             }
+             for iou_idx, iou_thresh in enumerate(iou_thresholds)
+         }
+         tbl_columns = [
+             "datum_uid",
+             "gt_uid",
+             "pd_uid",
+         ]
+         numeric_columns = [
+             "datum_id",
+             "gt_id",
+             "pd_id",
+             "gt_label_id",
+             "pd_label_id",
+             "iou",
+             "pd_score",
+         ]
+         for tbl, pairs in self._detailed_reader.iterate_tables_with_arrays(
+             columns=tbl_columns + numeric_columns,
+             numeric_columns=numeric_columns,
+             filter=datums,
+         ):
+             if pairs.size == 0:
+                 continue
+
+             # extract external identifiers
+             index_to_datum_id = create_mapping(
+                 tbl, pairs, 0, "datum_id", "datum_uid"
+             )
+             index_to_groundtruth_id = create_mapping(
+                 tbl, pairs, 1, "gt_id", "gt_uid"
+             )
+             index_to_prediction_id = create_mapping(
+                 tbl, pairs, 2, "pd_id", "pd_uid"
+             )
+
+             (
+                 mask_tp,
+                 mask_fp_fn_misclf,
+                 mask_fp_unmatched,
+                 mask_fn_unmatched,
+             ) = compute_pair_classifications(
+                 detailed_pairs=pairs,
+                 iou_thresholds=np.array(iou_thresholds),
+                 score_thresholds=np.array(score_thresholds),
+             )
+
+             unpack_confusion_matrix_with_examples(
+                 metrics=metrics,
+                 detailed_pairs=pairs,
+                 mask_tp=mask_tp,
+                 mask_fp_fn_misclf=mask_fp_fn_misclf,
+                 mask_fp_unmatched=mask_fp_unmatched,
+                 mask_fn_unmatched=mask_fn_unmatched,
+                 index_to_datum_id=index_to_datum_id,
+                 index_to_groundtruth_id=index_to_groundtruth_id,
+                 index_to_prediction_id=index_to_prediction_id,
+                 index_to_label=self._index_to_label,
+             )
+
+         return [m for inner in metrics.values() for m in inner.values()]
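Taken together, a minimal sketch of the read path this new module provides, assuming a cache previously written by Builder.persistent (the path and UID are illustrative):

    import pyarrow.compute as pc

    from valor_lite.object_detection.evaluator import Evaluator

    evaluator = Evaluator.load("./det_cache")
    info = evaluator.get_info()
    print(info.number_of_datums, info.number_of_labels)

    # metrics over a filtered view, without rewriting the cache
    metrics = evaluator.compute_precision_recall(
        iou_thresholds=[0.5, 0.75],
        score_thresholds=[0.5],
        datums=pc.field("datum_uid") == "image_0001",
    )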