valor_lite-0.37.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of valor-lite might be problematic.

Files changed (49)
  1. valor_lite/LICENSE +21 -0
  2. valor_lite/__init__.py +0 -0
  3. valor_lite/cache/__init__.py +11 -0
  4. valor_lite/cache/compute.py +154 -0
  5. valor_lite/cache/ephemeral.py +302 -0
  6. valor_lite/cache/persistent.py +529 -0
  7. valor_lite/classification/__init__.py +14 -0
  8. valor_lite/classification/annotation.py +45 -0
  9. valor_lite/classification/computation.py +378 -0
  10. valor_lite/classification/evaluator.py +879 -0
  11. valor_lite/classification/loader.py +97 -0
  12. valor_lite/classification/metric.py +535 -0
  13. valor_lite/classification/numpy_compatibility.py +13 -0
  14. valor_lite/classification/shared.py +184 -0
  15. valor_lite/classification/utilities.py +314 -0
  16. valor_lite/exceptions.py +20 -0
  17. valor_lite/object_detection/__init__.py +17 -0
  18. valor_lite/object_detection/annotation.py +238 -0
  19. valor_lite/object_detection/computation.py +841 -0
  20. valor_lite/object_detection/evaluator.py +805 -0
  21. valor_lite/object_detection/loader.py +292 -0
  22. valor_lite/object_detection/metric.py +850 -0
  23. valor_lite/object_detection/shared.py +185 -0
  24. valor_lite/object_detection/utilities.py +396 -0
  25. valor_lite/schemas.py +11 -0
  26. valor_lite/semantic_segmentation/__init__.py +15 -0
  27. valor_lite/semantic_segmentation/annotation.py +123 -0
  28. valor_lite/semantic_segmentation/computation.py +165 -0
  29. valor_lite/semantic_segmentation/evaluator.py +414 -0
  30. valor_lite/semantic_segmentation/loader.py +205 -0
  31. valor_lite/semantic_segmentation/metric.py +275 -0
  32. valor_lite/semantic_segmentation/shared.py +149 -0
  33. valor_lite/semantic_segmentation/utilities.py +88 -0
  34. valor_lite/text_generation/__init__.py +15 -0
  35. valor_lite/text_generation/annotation.py +56 -0
  36. valor_lite/text_generation/computation.py +611 -0
  37. valor_lite/text_generation/llm/__init__.py +0 -0
  38. valor_lite/text_generation/llm/exceptions.py +14 -0
  39. valor_lite/text_generation/llm/generation.py +903 -0
  40. valor_lite/text_generation/llm/instructions.py +814 -0
  41. valor_lite/text_generation/llm/integrations.py +226 -0
  42. valor_lite/text_generation/llm/utilities.py +43 -0
  43. valor_lite/text_generation/llm/validators.py +68 -0
  44. valor_lite/text_generation/manager.py +697 -0
  45. valor_lite/text_generation/metric.py +381 -0
  46. valor_lite-0.37.1.dist-info/METADATA +174 -0
  47. valor_lite-0.37.1.dist-info/RECORD +49 -0
  48. valor_lite-0.37.1.dist-info/WHEEL +5 -0
  49. valor_lite-0.37.1.dist-info/top_level.txt +1 -0
valor_lite/classification/evaluator.py
@@ -0,0 +1,879 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+import numpy as np
+import pyarrow as pa
+import pyarrow.compute as pc
+
+from valor_lite.cache import (
+    FileCacheReader,
+    FileCacheWriter,
+    MemoryCacheReader,
+    MemoryCacheWriter,
+    compute,
+)
+from valor_lite.classification.computation import (
+    compute_accuracy,
+    compute_confusion_matrix,
+    compute_counts,
+    compute_f1_score,
+    compute_pair_classifications,
+    compute_precision,
+    compute_recall,
+    compute_rocauc,
+)
+from valor_lite.classification.metric import Metric, MetricType
+from valor_lite.classification.shared import (
+    EvaluatorInfo,
+    decode_metadata_fields,
+    encode_metadata_fields,
+    extract_counts,
+    extract_groundtruth_count_per_label,
+    extract_labels,
+    generate_cache_path,
+    generate_intermediate_cache_path,
+    generate_intermediate_schema,
+    generate_metadata_path,
+    generate_roc_curve_cache_path,
+    generate_roc_curve_schema,
+    generate_schema,
+)
+from valor_lite.classification.utilities import (
+    create_empty_confusion_matrix_with_examples,
+    create_mapping,
+    unpack_confusion_matrix,
+    unpack_confusion_matrix_with_examples,
+    unpack_examples,
+    unpack_precision_recall,
+    unpack_rocauc,
+)
+from valor_lite.exceptions import EmptyCacheError
+
+
+class Builder:
+    def __init__(
+        self,
+        writer: MemoryCacheWriter | FileCacheWriter,
+        roc_curve_writer: MemoryCacheWriter | FileCacheWriter,
+        intermediate_writer: MemoryCacheWriter | FileCacheWriter,
+        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+    ):
+        self._writer = writer
+        self._roc_curve_writer = roc_curve_writer
+        self._intermediate_writer = intermediate_writer
+        self._metadata_fields = metadata_fields
+
+    @classmethod
+    def in_memory(
+        cls,
+        batch_size: int = 10_000,
+        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+    ):
+        """
+        Create an in-memory evaluator cache.
+
+        Parameters
+        ----------
+        batch_size : int, default=10_000
+            The target number of rows to buffer before writing to the cache.
+        metadata_fields : list[tuple[str, str | pa.DataType]], optional
+            Optional metadata field definitions.
+        """
+        writer = MemoryCacheWriter.create(
+            schema=generate_schema(metadata_fields),
+            batch_size=batch_size,
+        )
+        intermediate_writer = MemoryCacheWriter.create(
+            schema=generate_intermediate_schema(),
+            batch_size=batch_size,
+        )
+        roc_curve_writer = MemoryCacheWriter.create(
+            schema=generate_roc_curve_schema(),
+            batch_size=batch_size,
+        )
+        return cls(
+            writer=writer,
+            roc_curve_writer=roc_curve_writer,
+            intermediate_writer=intermediate_writer,
+            metadata_fields=metadata_fields,
+        )
+
+    @classmethod
+    def persistent(
+        cls,
+        path: str | Path,
+        batch_size: int = 10_000,
+        rows_per_file: int = 100_000,
+        compression: str = "snappy",
+        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+    ):
+        """
+        Create a persistent file-based evaluator cache.
+
+        Parameters
+        ----------
+        path : str | Path
+            Where to store the file-based cache.
+        batch_size : int, default=10_000
+            Sets the batch size for writing to file.
+        rows_per_file : int, default=100_000
+            Sets the maximum number of rows per file. This may be exceeded as files are datum-aligned.
+        compression : str, default="snappy"
+            Sets the pyarrow compression method.
+        metadata_fields : list[tuple[str, str | pa.DataType]], optional
+            Optionally sets the metadata description for use in filtering.
+        """
+        path = Path(path)
+
+        # create cache
+        writer = FileCacheWriter.create(
+            path=generate_cache_path(path),
+            schema=generate_schema(metadata_fields),
+            batch_size=batch_size,
+            rows_per_file=rows_per_file,
+            compression=compression,
+        )
+        intermediate_writer = FileCacheWriter.create(
+            path=generate_intermediate_cache_path(path),
+            schema=generate_intermediate_schema(),
+            batch_size=batch_size,
+            rows_per_file=rows_per_file,
+            compression=compression,
+        )
+        roc_curve_writer = FileCacheWriter.create(
+            path=generate_roc_curve_cache_path(path),
+            schema=generate_roc_curve_schema(),
+            batch_size=batch_size,
+            rows_per_file=rows_per_file,
+            compression=compression,
+        )
+
+        # write metadata config
+        metadata_path = generate_metadata_path(path)
+        with open(metadata_path, "w") as f:
+            encoded_types = encode_metadata_fields(metadata_fields)
+            json.dump(encoded_types, f, indent=2)
+
+        return cls(
+            writer=writer,
+            roc_curve_writer=roc_curve_writer,
+            intermediate_writer=intermediate_writer,
+            metadata_fields=metadata_fields,
+        )
+
+    def _create_rocauc_intermediate(
+        self,
+        reader: MemoryCacheReader | FileCacheReader,
+        batch_size: int,
+        index_to_label: dict[int, str],
+    ):
+        n_labels = len(index_to_label)
+        compute.sort(
+            source=reader,
+            sink=self._intermediate_writer,
+            batch_size=batch_size,
+            sorting=[
+                ("pd_score", "descending"),
+                # ("pd_label_id", "ascending"),
+                ("match", "ascending"),
+            ],
+            columns=[
+                "pd_label_id",
+                "pd_score",
+                "match",
+            ],
+        )
+        intermediate = self._intermediate_writer.to_reader()
+
+        running_max_fp = np.zeros(n_labels, dtype=np.uint64)
+        running_max_tp = np.zeros(n_labels, dtype=np.uint64)
+        running_max_scores = np.zeros(n_labels, dtype=np.float64)
+
+        last_pair = np.zeros((n_labels, 2), dtype=np.uint64)
+        for tbl in intermediate.iterate_tables():
+            pd_label_ids = tbl["pd_label_id"].to_numpy()
+            tps = tbl["match"].to_numpy()
+            scores = tbl["pd_score"].to_numpy()
+            fps = ~tps
+
+            for idx in index_to_label.keys():
+                mask = pd_label_ids == idx
+                if mask.sum() == 0:
+                    continue
+
+                cumulative_fp = np.r_[
+                    running_max_fp[idx],
+                    np.cumsum(fps[mask]) + running_max_fp[idx],
+                ]
+                cumulative_tp = np.r_[
+                    running_max_tp[idx],
+                    np.cumsum(tps[mask]) + running_max_tp[idx],
+                ]
+                pd_scores = np.r_[running_max_scores[idx], scores[mask]]
+
+                indices = (
+                    np.where(
+                        np.diff(np.r_[running_max_scores[idx], pd_scores])
+                    )[0]
+                    - 1
+                )
+
+                running_max_fp[idx] = cumulative_fp[-1]
+                running_max_tp[idx] = cumulative_tp[-1]
+                running_max_scores[idx] = pd_scores[-1]
+
+                for fp, tp in zip(
+                    cumulative_fp[indices],
+                    cumulative_tp[indices],
+                ):
+                    last_pair[idx, 0] = fp
+                    last_pair[idx, 1] = tp
+                    self._roc_curve_writer.write_rows(
+                        [
+                            {
+                                "pd_label_id": idx,
+                                "cumulative_fp": fp,
+                                "cumulative_tp": tp,
+                            }
+                        ]
+                    )
+
+        # ensure any remaining values are ingested
+        for idx in range(n_labels):
+            last_fp = last_pair[idx, 0]
+            last_tp = last_pair[idx, 1]
+            if (
+                last_fp != running_max_fp[idx]
+                or last_tp != running_max_tp[idx]
+            ):
+                self._roc_curve_writer.write_rows(
+                    [
+                        {
+                            "pd_label_id": idx,
+                            "cumulative_fp": running_max_fp[idx],
+                            "cumulative_tp": running_max_tp[idx],
+                        }
+                    ]
+                )
+
+    def finalize(
+        self,
+        batch_size: int = 1_000,
+        index_to_label_override: dict[int, str] | None = None,
+    ):
+        """
+        Performs data finalization and preprocessing steps.
+
+        Parameters
+        ----------
+        batch_size : int, default=1_000
+            Sets the maximum number of elements read into memory per file when performing the merge sort.
+        index_to_label_override : dict[int, str], optional
+            Pre-configures the label mapping. Used when operating over filtered subsets.
+
+        Returns
+        -------
+        Evaluator
+            A ready-to-use evaluator object.
+        """
+        self._writer.flush()
+        if self._writer.count_rows() == 0:
+            raise EmptyCacheError()
+        elif self._roc_curve_writer.count_rows() > 0:
+            raise RuntimeError("data already finalized")
+
+        # sort in-place and locally
+        self._writer.sort_by(
+            [
+                ("pd_score", "descending"),
+                ("datum_id", "ascending"),
+                ("gt_label_id", "ascending"),
+                ("pd_label_id", "ascending"),
+            ]
+        )
+
+        # post-process into sorted writer
+        reader = self._writer.to_reader()
+
+        # extract labels
+        index_to_label = extract_labels(
+            reader=reader,
+            index_to_label_override=index_to_label_override,
+        )
+
+        self._create_rocauc_intermediate(
+            reader=reader,
+            batch_size=batch_size,
+            index_to_label=index_to_label,
+        )
+        roc_curve_reader = self._roc_curve_writer.to_reader()
+
+        return Evaluator(
+            reader=reader,
+            roc_curve_reader=roc_curve_reader,
+            index_to_label=index_to_label,
+            metadata_fields=self._metadata_fields,
+        )
+
+
+class Evaluator:
+    def __init__(
+        self,
+        reader: MemoryCacheReader | FileCacheReader,
+        roc_curve_reader: MemoryCacheReader | FileCacheReader,
+        index_to_label: dict[int, str],
+        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+    ):
+        self._reader = reader
+        self._roc_curve_reader = roc_curve_reader
+        self._index_to_label = index_to_label
+        self._metadata_fields = metadata_fields
+
+    @property
+    def info(self) -> EvaluatorInfo:
+        return self.get_info()
+
+    def get_info(
+        self,
+        datums: pc.Expression | None = None,
+    ) -> EvaluatorInfo:
+        info = EvaluatorInfo()
+        info.metadata_fields = self._metadata_fields
+        info.number_of_rows = self._reader.count_rows()
+        info.number_of_labels = len(self._index_to_label)
+        info.number_of_datums = extract_counts(
+            reader=self._reader,
+            datums=datums,
+        )
+        return info
+
+    @classmethod
+    def load(
+        cls,
+        path: str | Path,
+        index_to_label_override: dict[int, str] | None = None,
+    ):
+        """
+        Load from an existing classification cache.
+
+        Parameters
+        ----------
+        path : str | Path
+            Path to the existing cache.
+        index_to_label_override : dict[int, str], optional
+            Option to preset the index-to-label dictionary. Used when loading from filtered caches.
+        """
+
+        # validate path
+        path = Path(path)
+        if not path.exists():
+            raise FileNotFoundError(f"Directory does not exist: {path}")
+        elif not path.is_dir():
+            raise NotADirectoryError(
+                f"Path exists but is not a directory: {path}"
+            )
+
+        # load cache
+        reader = FileCacheReader.load(generate_cache_path(path))
+        roc_curve_reader = FileCacheReader.load(
+            generate_roc_curve_cache_path(path)
+        )
+
+        # extract labels
+        index_to_label = extract_labels(
+            reader=reader,
+            index_to_label_override=index_to_label_override,
+        )
+
+        # read config
+        metadata_path = generate_metadata_path(path)
+        metadata_fields = None
+        with open(metadata_path, "r") as f:
+            encoded_types = json.load(f)
+            metadata_fields = decode_metadata_fields(encoded_types)
+
+        return cls(
+            reader=reader,
+            roc_curve_reader=roc_curve_reader,
+            index_to_label=index_to_label,
+            metadata_fields=metadata_fields,
+        )
+
+    def filter(
+        self,
+        datums: pc.Expression | None = None,
+        groundtruths: pc.Expression | None = None,
+        predictions: pc.Expression | None = None,
+        path: str | Path | None = None,
+    ) -> Evaluator:
+        """
+        Filter the evaluator cache.
+
+        Parameters
+        ----------
+        datums : pc.Expression, optional
+            A filter expression used to filter datums.
+        groundtruths : pc.Expression, optional
+            A filter expression used to filter ground truth annotations.
+        predictions : pc.Expression, optional
+            A filter expression used to filter predictions.
+        path : str | Path, optional
+            Where to store the filtered cache if storing on disk.
+
+        Returns
+        -------
+        Evaluator
+            A new evaluator object containing the filtered cache.
+        """
+        from valor_lite.classification.loader import Loader
+
+        if isinstance(self._reader, FileCacheReader):
+            if not path:
+                raise ValueError(
+                    "expected path to be defined for file-based loader"
+                )
+            loader = Loader.persistent(
+                path=path,
+                batch_size=self._reader.batch_size,
+                rows_per_file=self._reader.rows_per_file,
+                compression=self._reader.compression,
+                metadata_fields=self.info.metadata_fields,
+            )
+        else:
+            loader = Loader.in_memory(
+                batch_size=self._reader.batch_size,
+                metadata_fields=self.info.metadata_fields,
+            )
+
+        for tbl in self._reader.iterate_tables(filter=datums):
+            columns = (
+                "datum_id",
+                "gt_label_id",
+                "pd_label_id",
+            )
+            pairs = np.column_stack([tbl[col].to_numpy() for col in columns])
+
+            n_pairs = pairs.shape[0]
+            gt_ids = pairs[:, (0, 1)].astype(np.int64)
+            pd_ids = pairs[:, (0, 2)].astype(np.int64)
+
+            if groundtruths is not None:
+                mask_valid_gt = np.zeros(n_pairs, dtype=np.bool_)
+                gt_tbl = tbl.filter(groundtruths)
+                gt_pairs = np.column_stack(
+                    [
+                        gt_tbl[col].to_numpy()
+                        for col in ("datum_id", "gt_label_id")
+                    ]
+                ).astype(np.int64)
+                for gt in np.unique(gt_pairs, axis=0):
+                    mask_valid_gt |= (gt_ids == gt).all(axis=1)
+            else:
+                mask_valid_gt = np.ones(n_pairs, dtype=np.bool_)
+
+            if predictions is not None:
+                mask_valid_pd = np.zeros(n_pairs, dtype=np.bool_)
+                pd_tbl = tbl.filter(predictions)
+                pd_pairs = np.column_stack(
+                    [
+                        pd_tbl[col].to_numpy()
+                        for col in ("datum_id", "pd_label_id")
+                    ]
+                ).astype(np.int64)
+                for pd in np.unique(pd_pairs, axis=0):
+                    mask_valid_pd |= (pd_ids == pd).all(axis=1)
+            else:
+                mask_valid_pd = np.ones(n_pairs, dtype=np.bool_)
+
+            mask_valid = mask_valid_gt | mask_valid_pd
+            mask_valid_gt &= mask_valid
+            mask_valid_pd &= mask_valid
+
+            pairs[~mask_valid_gt, 1] = -1
+            pairs[~mask_valid_pd, 2] = -1
+
+            for idx, col in enumerate(columns):
+                tbl = tbl.set_column(
+                    tbl.schema.names.index(col), col, pa.array(pairs[:, idx])
+                )
+            # TODO (c.zaloom) - improve write strategy, filtered data could be small
+            loader._writer.write_table(tbl)
+
+        return loader.finalize(index_to_label_override=self._index_to_label)
+
+    def iterate_values(self, datums: pc.Expression | None = None):
+        columns = [
+            "datum_id",
+            "gt_label_id",
+            "pd_label_id",
+            "pd_score",
+            "pd_winner",
+            "match",
+        ]
+        for tbl in self._reader.iterate_tables(columns=columns, filter=datums):
+            ids = np.column_stack(
+                [
+                    tbl[col].to_numpy()
+                    for col in [
+                        "datum_id",
+                        "gt_label_id",
+                        "pd_label_id",
+                    ]
+                ]
+            )
+            scores = tbl["pd_score"].to_numpy()
+            winners = tbl["pd_winner"].to_numpy()
+            matches = tbl["match"].to_numpy()
+            yield ids, scores, winners, matches
+
+    def iterate_values_with_tables(self, datums: pc.Expression | None = None):
+        for tbl in self._reader.iterate_tables(filter=datums):
+            ids = np.column_stack(
+                [
+                    tbl[col].to_numpy()
+                    for col in [
+                        "datum_id",
+                        "gt_label_id",
+                        "pd_label_id",
+                    ]
+                ]
+            )
+            scores = tbl["pd_score"].to_numpy()
+            winners = tbl["pd_winner"].to_numpy()
+            matches = tbl["match"].to_numpy()
+            yield ids, scores, winners, matches, tbl
+
+    def compute_rocauc(
+        self, datums: pc.Expression | None = None
+    ) -> dict[MetricType, list[Metric]]:
+        """
+        Compute ROCAUC.
+
+        Parameters
+        ----------
+        datums : pyarrow.compute.Expression, optional
+            Option to filter datums by an expression.
+
+        Returns
+        -------
+        dict[MetricType, list[Metric]]
+            A dictionary mapping MetricType enumerations to lists of computed metrics.
+        """
+        n_labels = self.info.number_of_labels
+
+        rocauc = np.zeros(n_labels, dtype=np.float64)
+        label_counts = extract_groundtruth_count_per_label(
+            reader=self._reader,
+            number_of_labels=len(self._index_to_label),
+            datums=datums,
+        )
+
+        prev = np.zeros((n_labels, 2), dtype=np.uint64)
+        for array in self._roc_curve_reader.iterate_arrays(
+            numeric_columns=[
+                "pd_label_id",
+                "cumulative_fp",
+                "cumulative_tp",
+            ],
+            filter=datums,
+        ):
+            rocauc, prev = compute_rocauc(
+                rocauc=rocauc,
+                array=array,
+                gt_count_per_label=label_counts[:, 0],
+                pd_count_per_label=label_counts[:, 1],
+                n_labels=self.info.number_of_labels,
+                prev=prev,
+            )
+
+        mean_rocauc = rocauc.mean()
+
+        return unpack_rocauc(
+            rocauc=rocauc,
+            mean_rocauc=mean_rocauc,
+            index_to_label=self._index_to_label,
+        )
+
+    def compute_precision_recall(
+        self,
+        score_thresholds: list[float] = [0.0],
+        hardmax: bool = True,
+        datums: pc.Expression | None = None,
+    ) -> dict[MetricType, list]:
+        """
+        Performs an evaluation and returns precision, recall, accuracy, and F1 metrics.
+
+        Parameters
+        ----------
+        score_thresholds : list[float]
+            A list of score thresholds to compute metrics over.
+        hardmax : bool
+            Toggles whether a hardmax is applied to predictions.
+        datums : pyarrow.compute.Expression, optional
+            Option to filter datums by an expression.
+
+        Returns
+        -------
+        dict[MetricType, list]
+            A dictionary mapping MetricType enumerations to lists of computed metrics.
+        """
+        if not score_thresholds:
+            raise ValueError("At least one score threshold must be passed.")
+
+        n_scores = len(score_thresholds)
+        n_datums = self.info.number_of_datums
+        n_labels = self.info.number_of_labels
+
+        # intermediates
+        counts = np.zeros((n_scores, n_labels, 4), dtype=np.uint64)
+
+        for ids, scores, winners, _ in self.iterate_values(datums=datums):
+            batch_counts = compute_counts(
+                ids=ids,
+                scores=scores,
+                winners=winners,
+                score_thresholds=np.array(score_thresholds),
+                hardmax=hardmax,
+                n_labels=n_labels,
+            )
+            counts += batch_counts
+
+        precision = compute_precision(counts)
+        recall = compute_recall(counts)
+        f1_score = compute_f1_score(precision, recall)
+        accuracy = compute_accuracy(counts, n_datums=n_datums)
+
+        return unpack_precision_recall(
+            counts=counts,
+            precision=precision,
+            recall=recall,
+            accuracy=accuracy,
+            f1_score=f1_score,
+            score_thresholds=score_thresholds,
+            hardmax=hardmax,
+            index_to_label=self._index_to_label,
+        )
+
+    def compute_confusion_matrix(
+        self,
+        score_thresholds: list[float] = [0.0],
+        hardmax: bool = True,
+        datums: pc.Expression | None = None,
+    ) -> list[Metric]:
+        """
+        Compute a confusion matrix.
+
+        Parameters
+        ----------
+        score_thresholds : list[float]
+            A list of score thresholds to compute metrics over.
+        hardmax : bool
+            Toggles whether a hardmax is applied to predictions.
+        datums : pyarrow.compute.Expression, optional
+            Option to filter datums by an expression.
+
+        Returns
+        -------
+        list[Metric]
+            A list of confusion matrices.
+        """
+        if not score_thresholds:
+            raise ValueError("At least one score threshold must be passed.")
+
+        n_scores = len(score_thresholds)
+        n_labels = len(self._index_to_label)
+        confusion_matrices = np.zeros(
+            (n_scores, n_labels, n_labels), dtype=np.uint64
+        )
+        unmatched_groundtruths = np.zeros(
+            (n_scores, n_labels), dtype=np.uint64
+        )
+        for ids, scores, winners, matches in self.iterate_values(
+            datums=datums
+        ):
+            (
+                mask_tp,
+                mask_fp_fn_misclf,
+                mask_fn_unmatched,
+            ) = compute_pair_classifications(
+                ids=ids,
+                scores=scores,
+                winners=winners,
+                score_thresholds=np.array(score_thresholds),
+                hardmax=hardmax,
+            )
+
+            batch_cm, batch_ugt = compute_confusion_matrix(
+                ids=ids,
+                mask_tp=mask_tp,
+                mask_fp_fn_misclf=mask_fp_fn_misclf,
+                mask_fn_unmatched=mask_fn_unmatched,
+                score_thresholds=np.array(score_thresholds),
+                n_labels=n_labels,
+            )
+            confusion_matrices += batch_cm
+            unmatched_groundtruths += batch_ugt
+
+        return unpack_confusion_matrix(
+            confusion_matrices=confusion_matrices,
+            unmatched_groundtruths=unmatched_groundtruths,
+            index_to_label=self._index_to_label,
+            score_thresholds=score_thresholds,
+            hardmax=hardmax,
+        )
+
+    def compute_examples(
+        self,
+        score_thresholds: list[float] = [0.0],
+        hardmax: bool = True,
+        datums: pc.Expression | None = None,
+    ) -> list[Metric]:
+        """
+        Compute examples per datum.
+
+        Note: This function should be used with filtering to reduce response size.
+
+        Parameters
+        ----------
+        score_thresholds : list[float]
+            A list of score thresholds to compute metrics over.
+        hardmax : bool
+            Toggles whether a hardmax is applied to predictions.
+        datums : pyarrow.compute.Expression, optional
+            Option to filter datums by an expression.
+
+        Returns
+        -------
+        list[Metric]
+            A list of example metrics.
+        """
+        if not score_thresholds:
+            raise ValueError("At least one score threshold must be passed.")
+
+        metrics = []
+        for (
+            ids,
+            scores,
+            winners,
+            _,
+            tbl,
+        ) in self.iterate_values_with_tables(datums=datums):
+            if ids.size == 0:
+                continue
+
+            # extract external identifiers
+            index_to_datum_id = create_mapping(
+                tbl, ids, 0, "datum_id", "datum_uid"
+            )
+
+            (
+                mask_tp,
+                mask_fp_fn_misclf,
+                mask_fn_unmatched,
+            ) = compute_pair_classifications(
+                ids=ids,
+                scores=scores,
+                winners=winners,
+                score_thresholds=np.array(score_thresholds),
+                hardmax=hardmax,
+            )
+
+            mask_fn = mask_fp_fn_misclf | mask_fn_unmatched
+            mask_fp = mask_fp_fn_misclf
+
+            batch_examples = unpack_examples(
+                ids=ids,
+                mask_tp=mask_tp,
+                mask_fp=mask_fp,
+                mask_fn=mask_fn,
+                index_to_datum_id=index_to_datum_id,
+                score_thresholds=score_thresholds,
+                hardmax=hardmax,
+                index_to_label=self._index_to_label,
+            )
+            metrics.extend(batch_examples)
+
+        return metrics
+
+    def compute_confusion_matrix_with_examples(
+        self,
+        score_thresholds: list[float] = [0.0],
+        hardmax: bool = True,
+        datums: pc.Expression | None = None,
+    ) -> list[Metric]:
+        """
+        Compute a confusion matrix with examples.
+
+        Note: This function should be used with filtering to reduce response size.
+
+        Parameters
+        ----------
+        score_thresholds : list[float]
+            A list of score thresholds to compute metrics over.
+        hardmax : bool
+            Toggles whether a hardmax is applied to predictions.
+        datums : pyarrow.compute.Expression, optional
+            Option to filter datums by an expression.
+
+        Returns
+        -------
+        list[Metric]
+            A list of confusion matrices.
+        """
+        if not score_thresholds:
+            raise ValueError("At least one score threshold must be passed.")
+
+        metrics = {
+            score_idx: create_empty_confusion_matrix_with_examples(
+                score_threshold=score_thresh,
+                hardmax=hardmax,
+                index_to_label=self._index_to_label,
+            )
+            for score_idx, score_thresh in enumerate(score_thresholds)
+        }
+        for (
+            ids,
+            scores,
+            winners,
+            _,
+            tbl,
+        ) in self.iterate_values_with_tables(datums=datums):
+            if ids.size == 0:
+                continue
+
+            # extract external identifiers
+            index_to_datum_id = create_mapping(
+                tbl, ids, 0, "datum_id", "datum_uid"
+            )
+
+            (
+                mask_tp,
+                mask_fp_fn_misclf,
+                mask_fn_unmatched,
+            ) = compute_pair_classifications(
+                ids=ids,
+                scores=scores,
+                winners=winners,
+                score_thresholds=np.array(score_thresholds),
+                hardmax=hardmax,
+            )
+
+            mask_matched = mask_tp | mask_fp_fn_misclf
+            mask_unmatched_fn = mask_fn_unmatched
+
+            unpack_confusion_matrix_with_examples(
+                metrics=metrics,
+                ids=ids,
+                scores=scores,
+                winners=winners,
+                mask_matched=mask_matched,
+                mask_unmatched_fn=mask_unmatched_fn,
+                index_to_datum_id=index_to_datum_id,
+                index_to_label=self._index_to_label,
+            )
+
+        return list(metrics.values())
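
The file above defines a two-stage workflow: a Builder accumulates rows into a cache (in memory or on disk), and finalize() sorts the cache, derives the label mapping, precomputes the ROC-curve intermediates, and returns an Evaluator whose compute_* methods stream over the cached Arrow tables in batches. Below is a minimal usage sketch. It writes rows directly through the builder's private writer purely for illustration, and the row fields (datum_id, datum_uid, gt_label_id, pd_label_id, pd_score, pd_winner, match) are inferred from the columns this file reads; in practice ingestion most likely goes through valor_lite.classification.loader.Loader, which this hunk only imports inside filter().

    import pyarrow.compute as pc

    from valor_lite.classification.evaluator import Builder

    # Build an in-memory cache; Builder.persistent(path=...) mirrors this on disk.
    builder = Builder.in_memory(batch_size=1_000)

    # One row per (ground truth, prediction) label pairing of a datum.
    # Field names are inferred from this file and may not match the real schema.
    builder._writer.write_rows(
        [
            {
                "datum_id": 0,
                "datum_uid": "img_000",
                "gt_label_id": 1,   # ground truth label index
                "pd_label_id": 1,   # predicted label index
                "pd_score": 0.92,   # prediction confidence
                "pd_winner": True,  # highest-scoring prediction for this datum
                "match": True,      # prediction label equals ground truth label
            },
            # ... more rows ...
        ]
    )

    # Sorts the cache, extracts labels, and precomputes ROC intermediates.
    evaluator = builder.finalize()

    # Metric computations stream over the cache in batches.
    pr_metrics = evaluator.compute_precision_recall(score_thresholds=[0.25, 0.75])
    rocauc = evaluator.compute_rocauc()
    cm = evaluator.compute_confusion_matrix(score_thresholds=[0.5])

    # Filters take pyarrow compute expressions; file-backed caches also
    # require a `path` argument for the filtered copy.
    subset = evaluator.filter(datums=pc.field("datum_uid") == "img_000")

For a file-backed cache, Evaluator.load(path) reopens the cache files, the ROC-curve cache, and the metadata JSON written by Builder.persistent, so metrics can be recomputed later without re-ingesting predictions.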