valor_lite-0.37.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (49)
  1. valor_lite/LICENSE +21 -0
  2. valor_lite/__init__.py +0 -0
  3. valor_lite/cache/__init__.py +11 -0
  4. valor_lite/cache/compute.py +154 -0
  5. valor_lite/cache/ephemeral.py +302 -0
  6. valor_lite/cache/persistent.py +529 -0
  7. valor_lite/classification/__init__.py +14 -0
  8. valor_lite/classification/annotation.py +45 -0
  9. valor_lite/classification/computation.py +378 -0
  10. valor_lite/classification/evaluator.py +879 -0
  11. valor_lite/classification/loader.py +97 -0
  12. valor_lite/classification/metric.py +535 -0
  13. valor_lite/classification/numpy_compatibility.py +13 -0
  14. valor_lite/classification/shared.py +184 -0
  15. valor_lite/classification/utilities.py +314 -0
  16. valor_lite/exceptions.py +20 -0
  17. valor_lite/object_detection/__init__.py +17 -0
  18. valor_lite/object_detection/annotation.py +238 -0
  19. valor_lite/object_detection/computation.py +841 -0
  20. valor_lite/object_detection/evaluator.py +805 -0
  21. valor_lite/object_detection/loader.py +292 -0
  22. valor_lite/object_detection/metric.py +850 -0
  23. valor_lite/object_detection/shared.py +185 -0
  24. valor_lite/object_detection/utilities.py +396 -0
  25. valor_lite/schemas.py +11 -0
  26. valor_lite/semantic_segmentation/__init__.py +15 -0
  27. valor_lite/semantic_segmentation/annotation.py +123 -0
  28. valor_lite/semantic_segmentation/computation.py +165 -0
  29. valor_lite/semantic_segmentation/evaluator.py +414 -0
  30. valor_lite/semantic_segmentation/loader.py +205 -0
  31. valor_lite/semantic_segmentation/metric.py +275 -0
  32. valor_lite/semantic_segmentation/shared.py +149 -0
  33. valor_lite/semantic_segmentation/utilities.py +88 -0
  34. valor_lite/text_generation/__init__.py +15 -0
  35. valor_lite/text_generation/annotation.py +56 -0
  36. valor_lite/text_generation/computation.py +611 -0
  37. valor_lite/text_generation/llm/__init__.py +0 -0
  38. valor_lite/text_generation/llm/exceptions.py +14 -0
  39. valor_lite/text_generation/llm/generation.py +903 -0
  40. valor_lite/text_generation/llm/instructions.py +814 -0
  41. valor_lite/text_generation/llm/integrations.py +226 -0
  42. valor_lite/text_generation/llm/utilities.py +43 -0
  43. valor_lite/text_generation/llm/validators.py +68 -0
  44. valor_lite/text_generation/manager.py +697 -0
  45. valor_lite/text_generation/metric.py +381 -0
  46. valor_lite-0.37.1.dist-info/METADATA +174 -0
  47. valor_lite-0.37.1.dist-info/RECORD +49 -0
  48. valor_lite-0.37.1.dist-info/WHEEL +5 -0
  49. valor_lite-0.37.1.dist-info/top_level.txt +1 -0
valor_lite/object_detection/evaluator.py
@@ -0,0 +1,805 @@
+ import json
+ from pathlib import Path
+
+ import numpy as np
+ import pyarrow as pa
+ import pyarrow.compute as pc
+
+ from valor_lite.cache import (
+     FileCacheReader,
+     FileCacheWriter,
+     MemoryCacheReader,
+     MemoryCacheWriter,
+     compute,
+ )
+ from valor_lite.exceptions import EmptyCacheError
+ from valor_lite.object_detection.computation import (
+     compute_average_precision,
+     compute_average_recall,
+     compute_confusion_matrix,
+     compute_counts,
+     compute_pair_classifications,
+     compute_precision_recall_f1,
+     rank_table,
+ )
+ from valor_lite.object_detection.metric import Metric, MetricType
+ from valor_lite.object_detection.shared import (
+     EvaluatorInfo,
+     decode_metadata_fields,
+     encode_metadata_fields,
+     extract_counts,
+     extract_groundtruth_count_per_label,
+     extract_labels,
+     generate_detailed_cache_path,
+     generate_detailed_schema,
+     generate_metadata_path,
+     generate_ranked_cache_path,
+     generate_ranked_schema,
+ )
+ from valor_lite.object_detection.utilities import (
+     create_empty_confusion_matrix_with_examples,
+     create_mapping,
+     unpack_confusion_matrix,
+     unpack_confusion_matrix_with_examples,
+     unpack_examples,
+     unpack_precision_recall_into_metric_lists,
+ )
+
+
+ class Builder:
+     def __init__(
+         self,
+         detailed_writer: MemoryCacheWriter | FileCacheWriter,
+         ranked_writer: MemoryCacheWriter | FileCacheWriter,
+         metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+     ):
+         self._detailed_writer = detailed_writer
+         self._ranked_writer = ranked_writer
+         self._metadata_fields = metadata_fields
+
+     @classmethod
+     def in_memory(
+         cls,
+         batch_size: int = 10_000,
+         metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+     ):
+         """
+         Create an in-memory evaluator cache.
+
+         Parameters
+         ----------
+         batch_size : int, default=10_000
+             The target number of rows to buffer before writing to the cache. Defaults to 10_000.
+         metadata_fields : list[tuple[str, str | pa.DataType]], optional
+             Optional datum metadata field definitions.
+         """
+         # create cache
+         detailed_writer = MemoryCacheWriter.create(
+             schema=generate_detailed_schema(metadata_fields),
+             batch_size=batch_size,
+         )
+         ranked_writer = MemoryCacheWriter.create(
+             schema=generate_ranked_schema(),
+             batch_size=batch_size,
+         )
+
+         return cls(
+             detailed_writer=detailed_writer,
+             ranked_writer=ranked_writer,
+             metadata_fields=metadata_fields,
+         )
+
+     @classmethod
+     def persistent(
+         cls,
+         path: str | Path,
+         batch_size: int = 10_000,
+         rows_per_file: int = 100_000,
+         compression: str = "snappy",
+         metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+     ):
+         """
+         Create a persistent file-based evaluator cache.
+
+         Parameters
+         ----------
+         path : str | Path
+             Where to store the file-based cache.
+         batch_size : int, default=10_000
+             The target number of rows to buffer before writing to the cache. Defaults to 10_000.
+         rows_per_file : int, default=100_000
+             The target number of rows to store per cache file. Defaults to 100_000.
+         compression : str, default="snappy"
+             The compression method used when writing cache files.
+         metadata_fields : list[tuple[str, str | pa.DataType]], optional
+             Optional metadata field definitions.
+         """
+         path = Path(path)
+
+         # create caches
+         detailed_writer = FileCacheWriter.create(
+             path=generate_detailed_cache_path(path),
+             schema=generate_detailed_schema(metadata_fields),
+             batch_size=batch_size,
+             rows_per_file=rows_per_file,
+             compression=compression,
+         )
+         ranked_writer = FileCacheWriter.create(
+             path=generate_ranked_cache_path(path),
+             schema=generate_ranked_schema(),
+             batch_size=batch_size,
+             rows_per_file=rows_per_file,
+             compression=compression,
+         )
+
+         # write metadata
+         metadata_path = generate_metadata_path(path)
+         with open(metadata_path, "w") as f:
+             encoded_types = encode_metadata_fields(metadata_fields)
+             json.dump(encoded_types, f, indent=2)
+
+         return cls(
+             detailed_writer=detailed_writer,
+             ranked_writer=ranked_writer,
+             metadata_fields=metadata_fields,
+         )
+
+     def _rank(
+         self,
+         n_labels: int,
+         batch_size: int = 1_000,
+     ):
+         """Perform pair ranking over the detailed cache."""
+
+         detailed_reader = self._detailed_writer.to_reader()
+         compute.sort(
+             source=detailed_reader,
+             sink=self._ranked_writer,
+             batch_size=batch_size,
+             sorting=[
+                 ("pd_score", "descending"),
+                 ("iou", "descending"),
+             ],
+             columns=[
+                 field.name
+                 for field in self._ranked_writer.schema
+                 if field.name not in {"high_score", "iou_prev"}
+             ],
+             table_sort_override=lambda tbl: rank_table(
+                 tbl, number_of_labels=n_labels
+             ),
+         )
+         self._ranked_writer.flush()
+
+     def finalize(
+         self,
+         batch_size: int = 1_000,
+         index_to_label_override: dict[int, str] | None = None,
+     ):
+         """
+         Performs data finalization and preprocessing.
+
+         Parameters
+         ----------
+         batch_size : int, default=1_000
+             Sets the batch size for reading. Defaults to 1_000.
+         index_to_label_override : dict[int, str], optional
+             Pre-configures the label mapping. Used when operating over filtered subsets.
+         """
+         self._detailed_writer.flush()
+         if self._detailed_writer.count_rows() == 0:
+             raise EmptyCacheError()
+
+         self._detailed_writer.sort_by(
+             [
+                 ("pd_score", "descending"),
+                 ("iou", "descending"),
+             ]
+         )
+         detailed_reader = self._detailed_writer.to_reader()
+
+         # extract labels
+         index_to_label = extract_labels(
+             reader=detailed_reader,
+             index_to_label_override=index_to_label_override,
+         )
+
+         # populate ranked cache
+         self._rank(
+             n_labels=len(index_to_label),
+             batch_size=batch_size,
+         )
+
+         ranked_reader = self._ranked_writer.to_reader()
+         return Evaluator(
+             detailed_reader=detailed_reader,
+             ranked_reader=ranked_reader,
+             index_to_label=index_to_label,
+             metadata_fields=self._metadata_fields,
+         )
+
+
+ class Evaluator:
+     def __init__(
+         self,
+         detailed_reader: MemoryCacheReader | FileCacheReader,
+         ranked_reader: MemoryCacheReader | FileCacheReader,
+         index_to_label: dict[int, str],
+         metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+     ):
+         self._detailed_reader = detailed_reader
+         self._ranked_reader = ranked_reader
+         self._index_to_label = index_to_label
+         self._metadata_fields = metadata_fields
+
+     @property
+     def info(self) -> EvaluatorInfo:
+         return self.get_info()
+
+     def get_info(
+         self,
+         datums: pc.Expression | None = None,
+         groundtruths: pc.Expression | None = None,
+         predictions: pc.Expression | None = None,
+     ) -> EvaluatorInfo:
+         info = EvaluatorInfo()
+         info.number_of_rows = self._detailed_reader.count_rows()
+         info.metadata_fields = self._metadata_fields
+         info.number_of_labels = len(self._index_to_label)
+         (
+             info.number_of_datums,
+             info.number_of_groundtruth_annotations,
+             info.number_of_prediction_annotations,
+         ) = extract_counts(
+             reader=self._detailed_reader,
+             datums=datums,
+             groundtruths=groundtruths,
+             predictions=predictions,
+         )
+         return info
+
+     @classmethod
+     def load(
+         cls,
+         path: str | Path,
+         index_to_label_override: dict[int, str] | None = None,
+     ):
+         # validate path
+         path = Path(path)
+         if not path.exists():
+             raise FileNotFoundError(f"Directory does not exist: {path}")
+         elif not path.is_dir():
+             raise NotADirectoryError(
+                 f"Path exists but is not a directory: {path}"
+             )
+
+         detailed_reader = FileCacheReader.load(
+             generate_detailed_cache_path(path)
+         )
+         ranked_reader = FileCacheReader.load(generate_ranked_cache_path(path))
+
+         # extract labels from cache
+         index_to_label = extract_labels(
+             reader=detailed_reader,
+             index_to_label_override=index_to_label_override,
+         )
+
+         # read config
+         metadata_path = generate_metadata_path(path)
+         metadata_fields = None
+         with open(metadata_path, "r") as f:
+             encoded_metadata_types = json.load(f)
+             metadata_fields = decode_metadata_fields(encoded_metadata_types)
+
+         return cls(
+             detailed_reader=detailed_reader,
+             ranked_reader=ranked_reader,
+             index_to_label=index_to_label,
+             metadata_fields=metadata_fields,
+         )
+
+     def filter(
+         self,
+         datums: pc.Expression | None = None,
+         groundtruths: pc.Expression | None = None,
+         predictions: pc.Expression | None = None,
+         batch_size: int = 1_000,
+         path: str | Path | None = None,
+     ) -> "Evaluator":
+         """
+         Filter the evaluator cache.
+
+         Parameters
+         ----------
+         datums : pc.Expression, optional
+             A filter expression used to filter datums.
+         groundtruths : pc.Expression, optional
+             A filter expression used to filter ground truth annotations.
+         predictions : pc.Expression, optional
+             A filter expression used to filter predictions.
+         batch_size : int
+             The maximum number of rows read into memory per file.
+         path : str | Path, optional
+             Where to store the filtered cache if storing on disk.
+
+         Returns
+         -------
+         Evaluator
+             A new evaluator object containing the filtered cache.
+         """
+         from valor_lite.object_detection.loader import Loader
+
+         if isinstance(self._detailed_reader, FileCacheReader):
+             if not path:
+                 raise ValueError(
+                     "expected path to be defined for file-based loader"
+                 )
+             loader = Loader.persistent(
+                 path=path,
+                 batch_size=self._detailed_reader.batch_size,
+                 rows_per_file=self._detailed_reader.rows_per_file,
+                 compression=self._detailed_reader.compression,
+                 metadata_fields=self._metadata_fields,
+             )
+         else:
+             loader = Loader.in_memory(
+                 batch_size=self._detailed_reader.batch_size,
+                 metadata_fields=self._metadata_fields,
+             )
+
+         for tbl in self._detailed_reader.iterate_tables(filter=datums):
+             columns = (
+                 "datum_id",
+                 "gt_id",
+                 "pd_id",
+                 "gt_label_id",
+                 "pd_label_id",
+                 "iou",
+                 "pd_score",
+             )
+             pairs = np.column_stack([tbl[col].to_numpy() for col in columns])
+
+             n_pairs = pairs.shape[0]
+             gt_ids = pairs[:, (0, 1)].astype(np.int64)
+             pd_ids = pairs[:, (0, 2)].astype(np.int64)
+
+             if groundtruths is not None:
+                 mask_valid_gt = np.zeros(n_pairs, dtype=np.bool_)
+                 gt_tbl = tbl.filter(groundtruths)
+                 gt_pairs = np.column_stack(
+                     [gt_tbl[col].to_numpy() for col in ("datum_id", "gt_id")]
+                 ).astype(np.int64)
+                 for gt in np.unique(gt_pairs, axis=0):
+                     mask_valid_gt |= (gt_ids == gt).all(axis=1)
+             else:
+                 mask_valid_gt = np.ones(n_pairs, dtype=np.bool_)
+
+             if predictions is not None:
+                 mask_valid_pd = np.zeros(n_pairs, dtype=np.bool_)
+                 pd_tbl = tbl.filter(predictions)
+                 pd_pairs = np.column_stack(
+                     [pd_tbl[col].to_numpy() for col in ("datum_id", "pd_id")]
+                 ).astype(np.int64)
+                 for pd in np.unique(pd_pairs, axis=0):
+                     mask_valid_pd |= (pd_ids == pd).all(axis=1)
+             else:
+                 mask_valid_pd = np.ones(n_pairs, dtype=np.bool_)
+
+             mask_valid = mask_valid_gt | mask_valid_pd
+             mask_valid_gt &= mask_valid
+             mask_valid_pd &= mask_valid
+
+             # filter out invalid gt_id, gt_label_id by setting to -1.0
+             pairs[np.ix_(~mask_valid_gt, (1, 3))] = -1.0  # type: ignore[reportArgumentType]
+             # filter out invalid pd_id, pd_label_id, pd_score by setting to -1.0
+             pairs[np.ix_(~mask_valid_pd, (2, 4, 6))] = -1.0  # type: ignore[reportArgumentType]
+             # filter out invalid iou by setting to 0.0
+             pairs[~mask_valid_pd | ~mask_valid_gt, 5] = 0.0
+
+             for idx, col in enumerate(columns):
+                 column = pairs[:, idx]
+                 if col not in {"iou", "pd_score"}:
+                     column = column.astype(np.int64)
+
+                 col_idx = tbl.schema.names.index(col)
+                 tbl = tbl.set_column(
+                     col_idx, tbl.schema[col_idx], pa.array(column)
+                 )
+
+             mask_invalid = ~mask_valid | (pairs[:, (1, 2)] < 0).all(axis=1)
+             filtered_tbl = tbl.filter(pa.array(~mask_invalid))
+             loader._detailed_writer.write_table(filtered_tbl)
+
+         return loader.finalize(
+             batch_size=batch_size,
+             index_to_label_override=self._index_to_label,
+         )
+
+     def compute_precision_recall(
+         self,
+         iou_thresholds: list[float],
+         score_thresholds: list[float],
+         datums: pc.Expression | None = None,
+     ) -> dict[MetricType, list[Metric]]:
+         """
+         Computes all metrics except for ConfusionMatrix.
+
+         Parameters
+         ----------
+         iou_thresholds : list[float]
+             A list of IOU thresholds to compute metrics over.
+         score_thresholds : list[float]
+             A list of score thresholds to compute metrics over.
+         datums : pyarrow.compute.Expression, optional
+             Option to filter datums by an expression.
+
+         Returns
+         -------
+         dict[MetricType, list[Metric]]
+             A dictionary mapping MetricType enumerations to lists of computed metrics.
+         """
+         if not iou_thresholds:
+             raise ValueError("At least one IOU threshold must be passed.")
+         elif not score_thresholds:
+             raise ValueError("At least one score threshold must be passed.")
+
+         n_ious = len(iou_thresholds)
+         n_scores = len(score_thresholds)
+         n_labels = len(self._index_to_label)
+         n_gts_per_lbl = extract_groundtruth_count_per_label(
+             reader=self._detailed_reader,
+             number_of_labels=n_labels,
+             datums=datums,
+         )
+
+         counts = np.zeros((n_ious, n_scores, 3, n_labels), dtype=np.uint64)
+         pr_curve = np.zeros((n_ious, n_labels, 101, 2), dtype=np.float64)
+         running_counts = np.ones((n_ious, n_labels, 2), dtype=np.uint64)
+         for pairs in self._ranked_reader.iterate_arrays(
+             numeric_columns=[
+                 "datum_id",
+                 "gt_id",
+                 "pd_id",
+                 "gt_label_id",
+                 "pd_label_id",
+                 "iou",
+                 "pd_score",
+                 "iou_prev",
+                 "high_score",
+             ],
+             filter=datums,
+         ):
+             if pairs.size == 0:
+                 continue
+
+             (batch_counts, batch_pr_curve) = compute_counts(
+                 ranked_pairs=pairs,
+                 iou_thresholds=np.array(iou_thresholds),
+                 score_thresholds=np.array(score_thresholds),
+                 number_of_groundtruths_per_label=n_gts_per_lbl,
+                 number_of_labels=len(self._index_to_label),
+                 running_counts=running_counts,
+             )
+             counts += batch_counts
+             pr_curve = np.maximum(batch_pr_curve, pr_curve)
+
+         # fn count
+         counts[:, :, 2, :] = n_gts_per_lbl - counts[:, :, 0, :]
+
+         precision_recall_f1 = compute_precision_recall_f1(
+             counts=counts,
+             number_of_groundtruths_per_label=n_gts_per_lbl,
+         )
+         (
+             average_precision,
+             mean_average_precision,
+             pr_curve,
+         ) = compute_average_precision(pr_curve=pr_curve)
+         average_recall, mean_average_recall = compute_average_recall(
+             prec_rec_f1=precision_recall_f1
+         )
+
+         return unpack_precision_recall_into_metric_lists(
+             counts=counts,
+             precision_recall_f1=precision_recall_f1,
+             average_precision=average_precision,
+             mean_average_precision=mean_average_precision,
+             average_recall=average_recall,
+             mean_average_recall=mean_average_recall,
+             pr_curve=pr_curve,
+             iou_thresholds=iou_thresholds,
+             score_thresholds=score_thresholds,
+             index_to_label=self._index_to_label,
+         )
+
+     def compute_confusion_matrix(
+         self,
+         iou_thresholds: list[float],
+         score_thresholds: list[float],
+         datums: pc.Expression | None = None,
+     ) -> list[Metric]:
+         """
+         Computes confusion matrices at various thresholds.
+
+         Parameters
+         ----------
+         iou_thresholds : list[float]
+             A list of IOU thresholds to compute metrics over.
+         score_thresholds : list[float]
+             A list of score thresholds to compute metrics over.
+         datums : pyarrow.compute.Expression, optional
+             Option to filter datums by an expression.
+
+         Returns
+         -------
+         list[Metric]
+             List of confusion matrices per threshold pair.
+         """
+         if not iou_thresholds:
+             raise ValueError("At least one IOU threshold must be passed.")
+         elif not score_thresholds:
+             raise ValueError("At least one score threshold must be passed.")
+
+         n_ious = len(iou_thresholds)
+         n_scores = len(score_thresholds)
+         n_labels = len(self._index_to_label)
+
+         confusion_matrices = np.zeros(
+             (n_ious, n_scores, n_labels, n_labels), dtype=np.uint64
+         )
+         unmatched_groundtruths = np.zeros(
+             (n_ious, n_scores, n_labels), dtype=np.uint64
+         )
+         unmatched_predictions = np.zeros_like(unmatched_groundtruths)
+         for pairs in self._detailed_reader.iterate_arrays(
+             numeric_columns=[
+                 "datum_id",
+                 "gt_id",
+                 "pd_id",
+                 "gt_label_id",
+                 "pd_label_id",
+                 "iou",
+                 "pd_score",
+             ],
+             filter=datums,
+         ):
+             if pairs.size == 0:
+                 continue
+
+             (
+                 batch_mask_tp,
+                 batch_mask_fp_fn_misclf,
+                 batch_mask_fp_unmatched,
+                 batch_mask_fn_unmatched,
+             ) = compute_pair_classifications(
+                 detailed_pairs=pairs,
+                 iou_thresholds=np.array(iou_thresholds),
+                 score_thresholds=np.array(score_thresholds),
+             )
+             (
+                 batch_confusion_matrices,
+                 batch_unmatched_groundtruths,
+                 batch_unmatched_predictions,
+             ) = compute_confusion_matrix(
+                 detailed_pairs=pairs,
+                 mask_tp=batch_mask_tp,
+                 mask_fp_fn_misclf=batch_mask_fp_fn_misclf,
+                 mask_fp_unmatched=batch_mask_fp_unmatched,
+                 mask_fn_unmatched=batch_mask_fn_unmatched,
+                 number_of_labels=n_labels,
+                 iou_thresholds=np.array(iou_thresholds),
+                 score_thresholds=np.array(score_thresholds),
+             )
+             confusion_matrices += batch_confusion_matrices
+             unmatched_groundtruths += batch_unmatched_groundtruths
+             unmatched_predictions += batch_unmatched_predictions
+
+         return unpack_confusion_matrix(
+             confusion_matrices=confusion_matrices,
+             unmatched_groundtruths=unmatched_groundtruths,
+             unmatched_predictions=unmatched_predictions,
+             index_to_label=self._index_to_label,
+             iou_thresholds=iou_thresholds,
+             score_thresholds=score_thresholds,
+         )
+
+     def compute_examples(
+         self,
+         iou_thresholds: list[float],
+         score_thresholds: list[float],
+         datums: pc.Expression | None = None,
+     ) -> list[Metric]:
+         """
+         Computes examples at various thresholds.
+
+         This function can use a lot of memory with larger or high-density datasets. Please use it with filters.
+
+         Parameters
+         ----------
+         iou_thresholds : list[float]
+             A list of IOU thresholds to compute metrics over.
+         score_thresholds : list[float]
+             A list of score thresholds to compute metrics over.
+         datums : pyarrow.compute.Expression, optional
+             Option to filter datums by an expression.
+
+         Returns
+         -------
+         list[Metric]
+             List of examples per threshold pair.
+         """
+         if not iou_thresholds:
+             raise ValueError("At least one IOU threshold must be passed.")
+         elif not score_thresholds:
+             raise ValueError("At least one score threshold must be passed.")
+
+         metrics = []
+         tbl_columns = [
+             "datum_uid",
+             "gt_uid",
+             "pd_uid",
+         ]
+         numeric_columns = [
+             "datum_id",
+             "gt_id",
+             "pd_id",
+             "gt_label_id",
+             "pd_label_id",
+             "iou",
+             "pd_score",
+         ]
+         for tbl, pairs in self._detailed_reader.iterate_tables_with_arrays(
+             columns=tbl_columns + numeric_columns,
+             numeric_columns=numeric_columns,
+             filter=datums,
+         ):
+             if pairs.size == 0:
+                 continue
+
+             index_to_datum_id = {}
+             index_to_groundtruth_id = {}
+             index_to_prediction_id = {}
+
+             # extract external identifiers
+             index_to_datum_id = create_mapping(
+                 tbl, pairs, 0, "datum_id", "datum_uid"
+             )
+             index_to_groundtruth_id = create_mapping(
+                 tbl, pairs, 1, "gt_id", "gt_uid"
+             )
+             index_to_prediction_id = create_mapping(
+                 tbl, pairs, 2, "pd_id", "pd_uid"
+             )
+
+             (
+                 mask_tp,
+                 mask_fp_fn_misclf,
+                 mask_fp_unmatched,
+                 mask_fn_unmatched,
+             ) = compute_pair_classifications(
+                 detailed_pairs=pairs,
+                 iou_thresholds=np.array(iou_thresholds),
+                 score_thresholds=np.array(score_thresholds),
+             )
+
+             mask_fn = mask_fp_fn_misclf | mask_fn_unmatched
+             mask_fp = mask_fp_fn_misclf | mask_fp_unmatched
+
+             batch_examples = unpack_examples(
+                 detailed_pairs=pairs,
+                 mask_tp=mask_tp,
+                 mask_fp=mask_fp,
+                 mask_fn=mask_fn,
+                 index_to_datum_id=index_to_datum_id,
+                 index_to_groundtruth_id=index_to_groundtruth_id,
+                 index_to_prediction_id=index_to_prediction_id,
+                 iou_thresholds=iou_thresholds,
+                 score_thresholds=score_thresholds,
+             )
+             metrics.extend(batch_examples)
+
+         return metrics
+
+     def compute_confusion_matrix_with_examples(
+         self,
+         iou_thresholds: list[float],
+         score_thresholds: list[float],
+         datums: pc.Expression | None = None,
+     ) -> list[Metric]:
+         """
+         Computes confusion matrices with examples at various thresholds.
+
+         This function can use a lot of memory with larger or high-density datasets. Please use it with filters.
+
+         Parameters
+         ----------
+         iou_thresholds : list[float]
+             A list of IOU thresholds to compute metrics over.
+         score_thresholds : list[float]
+             A list of score thresholds to compute metrics over.
+         datums : pyarrow.compute.Expression, optional
+             Option to filter datums by an expression.
+
+         Returns
+         -------
+         list[Metric]
+             List of confusion matrices per threshold pair.
+         """
+         if not iou_thresholds:
+             raise ValueError("At least one IOU threshold must be passed.")
+         elif not score_thresholds:
+             raise ValueError("At least one score threshold must be passed.")
+
+         metrics = {
+             iou_idx: {
+                 score_idx: create_empty_confusion_matrix_with_examples(
+                     iou_threhsold=iou_thresh,
+                     score_threshold=score_thresh,
+                     index_to_label=self._index_to_label,
+                 )
+                 for score_idx, score_thresh in enumerate(score_thresholds)
+             }
+             for iou_idx, iou_thresh in enumerate(iou_thresholds)
+         }
+         tbl_columns = [
+             "datum_uid",
+             "gt_uid",
+             "pd_uid",
+         ]
+         numeric_columns = [
+             "datum_id",
+             "gt_id",
+             "pd_id",
+             "gt_label_id",
+             "pd_label_id",
+             "iou",
+             "pd_score",
+         ]
+         for tbl, pairs in self._detailed_reader.iterate_tables_with_arrays(
+             columns=tbl_columns + numeric_columns,
+             numeric_columns=numeric_columns,
+             filter=datums,
+         ):
+             if pairs.size == 0:
+                 continue
+
+             index_to_datum_id = {}
+             index_to_groundtruth_id = {}
+             index_to_prediction_id = {}
+
+             # extract external identifiers
+             index_to_datum_id = create_mapping(
+                 tbl, pairs, 0, "datum_id", "datum_uid"
+             )
+             index_to_groundtruth_id = create_mapping(
+                 tbl, pairs, 1, "gt_id", "gt_uid"
+             )
+             index_to_prediction_id = create_mapping(
+                 tbl, pairs, 2, "pd_id", "pd_uid"
+             )
+
+             (
+                 mask_tp,
+                 mask_fp_fn_misclf,
+                 mask_fp_unmatched,
+                 mask_fn_unmatched,
+             ) = compute_pair_classifications(
+                 detailed_pairs=pairs,
+                 iou_thresholds=np.array(iou_thresholds),
+                 score_thresholds=np.array(score_thresholds),
+             )
+
+             unpack_confusion_matrix_with_examples(
+                 metrics=metrics,
+                 detailed_pairs=pairs,
+                 mask_tp=mask_tp,
+                 mask_fp_fn_misclf=mask_fp_fn_misclf,
+                 mask_fp_unmatched=mask_fp_unmatched,
+                 mask_fn_unmatched=mask_fn_unmatched,
+                 index_to_datum_id=index_to_datum_id,
+                 index_to_groundtruth_id=index_to_groundtruth_id,
+                 index_to_prediction_id=index_to_prediction_id,
+                 index_to_label=self._index_to_label,
+             )
+
+         return [m for inner in metrics.values() for m in inner.values()]
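
For orientation, here is a minimal usage sketch of the Evaluator API added in this file. It is not taken from the package documentation: the "cache_dir" path and the "camera" metadata field are illustrative placeholders, and it assumes a detailed cache was previously built and finalized on disk via the persistent Builder path shown above.

import pyarrow.compute as pc

from valor_lite.object_detection.evaluator import Evaluator

# load a previously finalized file-based cache ("cache_dir" is a placeholder)
evaluator = Evaluator.load("cache_dir")
print(evaluator.info)  # row, datum, annotation, and label counts

# precision/recall-family metrics over a grid of IOU and score thresholds
metrics = evaluator.compute_precision_recall(
    iou_thresholds=[0.5, 0.75],
    score_thresholds=[0.25, 0.5],
)

# confusion matrices restricted to a subset of datums via a pyarrow
# expression; the "camera" metadata field is an assumed example, not
# something defined by the package itself
confusion = evaluator.compute_confusion_matrix(
    iou_thresholds=[0.5],
    score_thresholds=[0.5],
    datums=pc.field("camera") == "front",
)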