valor-lite 0.36.6__py3-none-any.whl → 0.37.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- valor_lite/cache/__init__.py +11 -0
- valor_lite/cache/compute.py +211 -0
- valor_lite/cache/ephemeral.py +302 -0
- valor_lite/cache/persistent.py +536 -0
- valor_lite/classification/__init__.py +5 -10
- valor_lite/classification/annotation.py +4 -0
- valor_lite/classification/computation.py +233 -251
- valor_lite/classification/evaluator.py +882 -0
- valor_lite/classification/loader.py +97 -0
- valor_lite/classification/metric.py +141 -4
- valor_lite/classification/shared.py +184 -0
- valor_lite/classification/utilities.py +221 -118
- valor_lite/exceptions.py +5 -0
- valor_lite/object_detection/__init__.py +5 -4
- valor_lite/object_detection/annotation.py +13 -1
- valor_lite/object_detection/computation.py +368 -299
- valor_lite/object_detection/evaluator.py +804 -0
- valor_lite/object_detection/loader.py +292 -0
- valor_lite/object_detection/metric.py +152 -3
- valor_lite/object_detection/shared.py +206 -0
- valor_lite/object_detection/utilities.py +182 -100
- valor_lite/semantic_segmentation/__init__.py +5 -4
- valor_lite/semantic_segmentation/annotation.py +7 -0
- valor_lite/semantic_segmentation/computation.py +20 -110
- valor_lite/semantic_segmentation/evaluator.py +414 -0
- valor_lite/semantic_segmentation/loader.py +205 -0
- valor_lite/semantic_segmentation/shared.py +149 -0
- valor_lite/semantic_segmentation/utilities.py +6 -23
- {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/METADATA +3 -1
- valor_lite-0.37.5.dist-info/RECORD +49 -0
- {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/WHEEL +1 -1
- valor_lite/classification/manager.py +0 -545
- valor_lite/object_detection/manager.py +0 -864
- valor_lite/profiling.py +0 -374
- valor_lite/semantic_segmentation/benchmark.py +0 -237
- valor_lite/semantic_segmentation/manager.py +0 -446
- valor_lite-0.36.6.dist-info/RECORD +0 -41
- {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,414 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pyarrow as pa
|
|
8
|
+
import pyarrow.compute as pc
|
|
9
|
+
from numpy.typing import NDArray
|
|
10
|
+
|
|
11
|
+
from valor_lite.cache import (
|
|
12
|
+
FileCacheReader,
|
|
13
|
+
FileCacheWriter,
|
|
14
|
+
MemoryCacheReader,
|
|
15
|
+
MemoryCacheWriter,
|
|
16
|
+
)
|
|
17
|
+
from valor_lite.exceptions import EmptyCacheError
|
|
18
|
+
from valor_lite.semantic_segmentation.computation import compute_metrics
|
|
19
|
+
from valor_lite.semantic_segmentation.metric import MetricType
|
|
20
|
+
from valor_lite.semantic_segmentation.shared import (
|
|
21
|
+
EvaluatorInfo,
|
|
22
|
+
decode_metadata_fields,
|
|
23
|
+
encode_metadata_fields,
|
|
24
|
+
extract_counts,
|
|
25
|
+
extract_labels,
|
|
26
|
+
generate_cache_path,
|
|
27
|
+
generate_metadata_path,
|
|
28
|
+
generate_schema,
|
|
29
|
+
)
|
|
30
|
+
from valor_lite.semantic_segmentation.utilities import (
|
|
31
|
+
unpack_precision_recall_iou_into_metric_lists,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Builder:
    """
    Writes semantic segmentation cache rows and produces an `Evaluator`.

    The `in_memory` and `persistent` classmethods create the underlying cache
    writer together with the builder; `finalize` turns the written cache into
    an `Evaluator`.
    """

    def __init__(
        self,
        writer: MemoryCacheWriter | FileCacheWriter,
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        """
        Parameters
        ----------
        writer : MemoryCacheWriter | FileCacheWriter
            Cache writer that rows are written to.
        metadata_fields : list[tuple[str, str | pa.DataType]], optional
            Optional metadata field definitions, forwarded to the evaluator.
        """
        self._writer = writer
        self._metadata_fields = metadata_fields

    @classmethod
    def in_memory(
        cls,
        batch_size: int = 10_000,
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        """
        Create an in-memory evaluator cache.

        Parameters
        ----------
        batch_size : int, default=10_000
            The target number of rows to buffer before writing to the cache. Defaults to 10_000.
        metadata_fields : list[tuple[str, str | pa.DataType]], optional
            Optional metadata field definitions.

        Returns
        -------
        An instance of the calling class wrapping a new `MemoryCacheWriter`.
        """
        # create cache
        writer = MemoryCacheWriter.create(
            schema=generate_schema(metadata_fields),
            batch_size=batch_size,
        )
        return cls(
            writer=writer,
            metadata_fields=metadata_fields,
        )

    @classmethod
    def persistent(
        cls,
        path: str | Path,
        batch_size: int = 10_000,
        rows_per_file: int = 100_000,
        compression: str = "snappy",
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        """
        Create a persistent file-based evaluator cache.

        Parameters
        ----------
        path : str | Path
            Where to store the file-based cache.
        batch_size : int, default=10_000
            The target number of rows to buffer before writing to the cache. Defaults to 10_000.
        rows_per_file : int, default=100_000
            The target number of rows to store per cache file. Defaults to 100_000.
        compression : str, default="snappy"
            The compression methods used when writing cache files.
        metadata_fields : list[tuple[str, str | pa.DataType]], optional
            Optional metadata field definitions.

        Returns
        -------
        An instance of the calling class wrapping a new `FileCacheWriter`.
        The metadata field definitions are also written as JSON to the
        metadata path derived from `path`.
        """
        path = Path(path)

        # Serialize the metadata definitions before touching the filesystem.
        # Opening the file first (and letting json.dump stream into it) would
        # leave an empty or truncated metadata file behind if encoding or
        # serialization fails.
        encoded_metadata = json.dumps(
            encode_metadata_fields(metadata_fields), indent=2
        )

        # create cache
        writer = FileCacheWriter.create(
            path=generate_cache_path(path),
            schema=generate_schema(metadata_fields),
            batch_size=batch_size,
            rows_per_file=rows_per_file,
            compression=compression,
        )

        # write metadata (after the cache exists, as before)
        metadata_path = generate_metadata_path(path)
        with open(metadata_path, "w") as f:
            f.write(encoded_metadata)

        return cls(
            writer=writer,
            metadata_fields=metadata_fields,
        )

    def finalize(
        self,
        index_to_label_override: dict[int, str] | None = None,
    ):
        """
        Performs data finalization and some preprocessing steps.

        Parameters
        ----------
        index_to_label_override : dict[int, str], optional
            Pre-configures label mapping. Used when operating over filtered subsets.

        Returns
        -------
        Evaluator
            A ready-to-use evaluator object.

        Raises
        ------
        EmptyCacheError
            If no rows were written to the cache.
        """
        # push any buffered rows out before counting
        self._writer.flush()
        if self._writer.count_rows() == 0:
            raise EmptyCacheError()

        reader = self._writer.to_reader()

        # extract labels
        index_to_label = extract_labels(
            reader=reader,
            index_to_label_override=index_to_label_override,
        )

        return Evaluator(
            reader=reader,
            index_to_label=index_to_label,
            metadata_fields=self._metadata_fields,
        )
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class Evaluator:
    """
    Computes semantic segmentation metrics from a cache of label-pair counts.

    Each cache row pairs a ground truth label id (`gt_label_id`) with a
    prediction label id (`pd_label_id`) for one datum (`datum_id`) and stores
    an associated `count`; a label id of -1 denotes "no label".
    """

    def __init__(
        self,
        reader: MemoryCacheReader | FileCacheReader,
        index_to_label: dict[int, str],
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        """
        Parameters
        ----------
        reader : MemoryCacheReader | FileCacheReader
            Reader over the finalized cache.
        index_to_label : dict[int, str]
            Mapping from label id to label string.
        metadata_fields : list[tuple[str, str | pa.DataType]], optional
            Metadata field definitions the cache was created with.
        """
        self._reader = reader
        self._index_to_label = index_to_label
        self._metadata_fields = metadata_fields

    @property
    def info(self) -> EvaluatorInfo:
        """Unfiltered summary; calls `get_info()` on every access (not cached)."""
        return self.get_info()

    def get_info(
        self,
        datums: pc.Expression | None = None,
        groundtruths: pc.Expression | None = None,
        predictions: pc.Expression | None = None,
    ) -> EvaluatorInfo:
        """
        Summarize the cache.

        Row and label counts and metadata fields describe the whole cache; the
        datum and pixel counts come from `extract_counts`, which receives the
        optional filter expressions.

        Parameters
        ----------
        datums : pc.Expression, optional
            Filter expression forwarded to `extract_counts`.
        groundtruths : pc.Expression, optional
            Filter expression forwarded to `extract_counts`.
        predictions : pc.Expression, optional
            Filter expression forwarded to `extract_counts`.

        Returns
        -------
        EvaluatorInfo
        """
        info = EvaluatorInfo()
        info.number_of_rows = self._reader.count_rows()
        info.number_of_labels = len(self._index_to_label)
        info.metadata_fields = self._metadata_fields
        (
            info.number_of_datums,
            info.number_of_pixels,
            info.number_of_groundtruth_pixels,
            info.number_of_prediction_pixels,
        ) = extract_counts(
            reader=self._reader,
            datums=datums,
            groundtruths=groundtruths,
            predictions=predictions,
        )
        return info

    @classmethod
    def load(
        cls,
        path: str | Path,
        index_to_label_override: dict[int, str] | None = None,
    ):
        """
        Load from an existing semantic segmentation cache.

        Parameters
        ----------
        path : str | Path
            Path to the existing cache.
        index_to_label_override : dict[int, str], optional
            Option to preset index to label dictionary. Used when loading from filtered caches.

        Raises
        ------
        FileNotFoundError
            If `path` does not exist (or the metadata file is missing).
        NotADirectoryError
            If `path` exists but is not a directory.
        """
        # validate path
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Directory does not exist: {path}")
        elif not path.is_dir():
            raise NotADirectoryError(
                f"Path exists but is not a directory: {path}"
            )

        # load cache
        reader = FileCacheReader.load(generate_cache_path(path))

        # extract labels
        index_to_label = extract_labels(
            reader=reader,
            index_to_label_override=index_to_label_override,
        )

        # read config
        metadata_path = generate_metadata_path(path)
        with open(metadata_path, "r") as f:
            metadata_types = json.load(f)
        metadata_fields = decode_metadata_fields(metadata_types)

        return cls(
            reader=reader,
            index_to_label=index_to_label,
            metadata_fields=metadata_fields,
        )

    @staticmethod
    def _mask_rows_in(
        rows: NDArray[np.int64], candidates: NDArray[np.int64]
    ) -> NDArray[np.bool_]:
        """
        Flag each (datum_id, label_id) row of `rows` that also appears in `candidates`.

        Both arrays have two columns. Returns a boolean array of length `len(rows)`.
        """
        if rows.shape[0] == 0 or candidates.shape[0] == 0:
            return np.zeros(rows.shape[0], dtype=np.bool_)
        # Pack each pair into one int64 key. Offsetting the label by the
        # smallest label seen keeps it in [0, stride), so two pairs share a key
        # only when both components match. Assumes datum_id * stride stays
        # within int64 range.
        low = min(rows[:, 1].min(), candidates[:, 1].min())
        high = max(rows[:, 1].max(), candidates[:, 1].max())
        stride = high - low + 1
        row_keys = rows[:, 0] * stride + (rows[:, 1] - low)
        candidate_keys = candidates[:, 0] * stride + (candidates[:, 1] - low)
        return np.isin(row_keys, candidate_keys)

    def filter(
        self,
        datums: pc.Expression | None = None,
        groundtruths: pc.Expression | None = None,
        predictions: pc.Expression | None = None,
        path: str | Path | None = None,
    ) -> Evaluator:
        """
        Filter evaluator cache.

        Rows are kept per the `datums` expression. Within a kept row, the
        ground truth (or prediction) side is replaced by -1 unless some row of
        the same table with the same (datum_id, label_id) pair satisfies the
        `groundtruths` (or `predictions`) expression.

        Parameters
        ----------
        datums : pc.Expression | None = None
            A filter expression used to filter datums.
        groundtruths : pc.Expression | None = None
            A filter expression used to filter ground truth annotations.
        predictions : pc.Expression | None = None
            A filter expression used to filter predictions.
        path : str | Path, optional
            Where to store the filtered cache if storing on disk.

        Returns
        -------
        Evaluator
            A new evaluator object containing the filtered cache.

        Raises
        ------
        ValueError
            If this evaluator is file-based and no `path` is given.
        """
        # self._metadata_fields is exactly what `self.info.metadata_fields`
        # returns, without running extract_counts over the cache.
        if isinstance(self._reader, FileCacheReader):
            if not path:
                raise ValueError(
                    "expected path to be defined for file-based cache"
                )
            builder = Builder.persistent(
                path=path,
                batch_size=self._reader.batch_size,
                rows_per_file=self._reader.rows_per_file,
                compression=self._reader.compression,
                metadata_fields=self._metadata_fields,
            )
        else:
            builder = Builder.in_memory(
                batch_size=self._reader.batch_size,
                metadata_fields=self._metadata_fields,
            )

        columns = (
            "datum_id",
            "gt_label_id",
            "pd_label_id",
        )
        for tbl in self._reader.iterate_tables(filter=datums):
            pairs = np.column_stack([tbl[col].to_numpy() for col in columns])

            n_pairs = pairs.shape[0]
            gt_ids = pairs[:, (0, 1)].astype(np.int64)
            pd_ids = pairs[:, (0, 2)].astype(np.int64)

            if groundtruths is not None:
                gt_tbl = tbl.filter(groundtruths)
                gt_pairs = np.column_stack(
                    [
                        gt_tbl[col].to_numpy()
                        for col in ("datum_id", "gt_label_id")
                    ]
                ).astype(np.int64)
                mask_valid_gt = self._mask_rows_in(gt_ids, gt_pairs)
            else:
                mask_valid_gt = np.ones(n_pairs, dtype=np.bool_)

            if predictions is not None:
                pd_tbl = tbl.filter(predictions)
                pd_pairs = np.column_stack(
                    [
                        pd_tbl[col].to_numpy()
                        for col in ("datum_id", "pd_label_id")
                    ]
                ).astype(np.int64)
                mask_valid_pd = self._mask_rows_in(pd_ids, pd_pairs)
            else:
                mask_valid_pd = np.ones(n_pairs, dtype=np.bool_)

            # null out whichever side of each pair was filtered away
            pairs[~mask_valid_gt, 1] = -1
            pairs[~mask_valid_pd, 2] = -1

            # NOTE(review): only the id columns are rewritten; the gt_label /
            # pd_label string columns keep their original values on nulled
            # sides. Confirm downstream consumers key on the ids.
            for idx, col in enumerate(columns):
                tbl = tbl.set_column(
                    tbl.schema.names.index(col), col, pa.array(pairs[:, idx])
                )
            builder._writer.write_table(tbl)

        # keep the original label ids even if some labels no longer appear
        return builder.finalize(index_to_label_override=self._index_to_label)

    def _compute_confusion_matrix_intermediate(
        self, datums: pc.Expression | None = None
    ) -> NDArray[np.uint64]:
        """
        Accumulate cached label-pair counts into a confusion matrix.

        Parameters
        ----------
        datums : pyarrow.compute.Expression, optional
            Option to filter datums by an expression.

        Returns
        -------
        NDArray[np.uint64]
            Matrix of shape (n_labels + 1, n_labels + 1). Row `i + 1` is ground
            truth label `i` and column `j + 1` is prediction label `j`; row and
            column 0 collect counts with no ground truth / no prediction label
            (id -1).
        """
        n_labels = len(self._index_to_label)
        confusion_matrix = np.zeros(
            (n_labels + 1, n_labels + 1), dtype=np.uint64
        )
        for tbl in self._reader.iterate_tables(filter=datums):
            gt_ids = tbl["gt_label_id"].to_numpy().astype(np.int64)
            pd_ids = tbl["pd_label_id"].to_numpy().astype(np.int64)
            counts = tbl["count"].to_numpy().astype(np.uint64)

            # Shifting ids by one maps "no label" (-1) to index 0. Ids outside
            # [-1, n_labels) have no cell in the matrix and are skipped.
            in_range = (
                (gt_ids >= -1)
                & (gt_ids < n_labels)
                & (pd_ids >= -1)
                & (pd_ids < n_labels)
            )
            # np.add.at accumulates repeated (row, column) pairs, unlike a
            # fancy-indexed `+=`, which would keep only one write per cell.
            np.add.at(
                confusion_matrix,
                (gt_ids[in_range] + 1, pd_ids[in_range] + 1),
                counts[in_range],
            )
        return confusion_matrix

    def compute_precision_recall_iou(
        self, datums: pc.Expression | None = None
    ) -> dict[MetricType, list]:
        """
        Performs an evaluation and returns metrics.

        Parameters
        ----------
        datums : pyarrow.compute.Expression, optional
            Option to filter datums by an expression.

        Returns
        -------
        dict[MetricType, list]
            A dictionary mapping MetricType enumerations to lists of computed metrics.
        """
        confusion_matrix = self._compute_confusion_matrix_intermediate(
            datums=datums
        )
        results = compute_metrics(confusion_matrix=confusion_matrix)
        return unpack_precision_recall_iou_into_metric_lists(
            results=results,
            index_to_label=self._index_to_label,
        )
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pyarrow as pa
|
|
3
|
+
from tqdm import tqdm
|
|
4
|
+
|
|
5
|
+
from valor_lite.cache import FileCacheWriter, MemoryCacheWriter
|
|
6
|
+
from valor_lite.semantic_segmentation.annotation import Segmentation
|
|
7
|
+
from valor_lite.semantic_segmentation.computation import compute_intermediates
|
|
8
|
+
from valor_lite.semantic_segmentation.evaluator import Builder
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Loader(Builder):
    """
    Populates a semantic segmentation cache from `Segmentation` objects.

    Labels receive integer ids in order of first appearance, and each added
    segmentation receives the next sequential datum id.
    """

    def __init__(
        self,
        writer: MemoryCacheWriter | FileCacheWriter,
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        super().__init__(
            writer=writer,
            metadata_fields=metadata_fields,
        )

        # internal state
        # label string -> label id, assigned in order of first appearance
        self._labels: dict[str, int] = {}
        # inverse of self._labels
        self._index_to_label: dict[int, str] = {}
        # datum id assigned to the next added segmentation
        self._datum_count = 0

    def _add_label(self, value: str) -> int:
        """Return the id of label `value`, registering it first if unseen."""
        idx = self._labels.get(value, None)
        if idx is None:
            idx = len(self._labels)
            self._labels[value] = idx
            self._index_to_label[idx] = value
        return idx

    def add_data(
        self,
        segmentations: list[Segmentation],
        show_progress: bool = False,
    ):
        """
        Adds segmentations to the cache.

        For each segmentation this writes one row per ground truth / prediction
        label pair with a positive count, an unmatched-ground-truth row and an
        unmatched-prediction row for every label registered so far, and one
        row for the count with neither label.

        Parameters
        ----------
        segmentations : list[Segmentation]
            A list of Segmentation objects.
        show_progress : bool, default=False
            Toggle for tqdm progress bar.
        """

        disable_tqdm = not show_progress
        for segmentation in tqdm(segmentations, disable=disable_tqdm):

            # register labels (ground truths first, then predictions)
            groundtruth_labels = np.array(
                [self._add_label(gt.label) for gt in segmentation.groundtruths],
                dtype=np.int64,
            )
            prediction_labels = np.array(
                [self._add_label(pd.label) for pd in segmentation.predictions],
                dtype=np.int64,
            )

            if segmentation.groundtruths:
                combined_groundtruths = np.stack(
                    [
                        groundtruth.mask.flatten()
                        for groundtruth in segmentation.groundtruths
                    ],
                    axis=0,
                )
            else:
                # no ground truths: a single all-False mask row
                combined_groundtruths = np.zeros(
                    (1, segmentation.shape[0] * segmentation.shape[1]),
                    dtype=np.bool_,
                )

            if segmentation.predictions:
                combined_predictions = np.stack(
                    [
                        prediction.mask.flatten()
                        for prediction in segmentation.predictions
                    ],
                    axis=0,
                )
            else:
                # no predictions: a single all-False mask row
                combined_predictions = np.zeros(
                    (1, segmentation.shape[0] * segmentation.shape[1]),
                    dtype=np.bool_,
                )

            n_labels = len(self._labels)
            # counts[i + 1, j + 1] pairs label i with label j; index 0 is "no label"
            counts = compute_intermediates(
                groundtruths=combined_groundtruths,
                predictions=combined_predictions,
                groundtruth_labels=groundtruth_labels,
                prediction_labels=prediction_labels,
                n_labels=n_labels,
            )

            # prepare metadata
            datum_metadata = (
                segmentation.metadata if segmentation.metadata else {}
            )
            gt_metadata = {
                self._labels[gt.label]: gt.metadata
                for gt in segmentation.groundtruths
                if gt.metadata
            }
            pd_metadata = {
                self._labels[pd.label]: pd.metadata
                for pd in segmentation.predictions
                if pd.metadata
            }

            # cache formatting
            datum_id = self._datum_count
            rows = []
            for idx in range(n_labels):
                label = self._index_to_label[idx]
                # write non-zero intersections to cache; only visit nonzero
                # cells instead of scanning every prediction label
                nonzero_pidxs = np.flatnonzero(
                    counts[idx + 1, 1 : n_labels + 1] > 0
                ).tolist()
                for pidx in nonzero_pidxs:
                    plabel = self._index_to_label[pidx]
                    rows.append(
                        {
                            # metadata
                            **datum_metadata,
                            **gt_metadata.get(idx, {}),
                            **pd_metadata.get(pidx, {}),
                            # datum
                            "datum_uid": segmentation.uid,
                            "datum_id": datum_id,
                            # groundtruth
                            "gt_label": label,
                            "gt_label_id": idx,
                            # prediction
                            "pd_label": plabel,
                            "pd_label_id": pidx,
                            # pair
                            "count": counts[idx + 1, pidx + 1],
                        }
                    )
                # write all unmatched to preserve labels
                rows.extend(
                    [
                        {
                            # metadata (ground truth side only)
                            **datum_metadata,
                            **gt_metadata.get(idx, {}),
                            # datum
                            "datum_uid": segmentation.uid,
                            "datum_id": datum_id,
                            # groundtruth
                            "gt_label": label,
                            "gt_label_id": idx,
                            # prediction
                            "pd_label": None,
                            "pd_label_id": -1,
                            # pair
                            "count": counts[idx + 1, 0],
                        },
                        {
                            # metadata (prediction side only; this row has no
                            # ground truth, so ground truth metadata must not
                            # be attached to it)
                            **datum_metadata,
                            **pd_metadata.get(idx, {}),
                            # datum
                            "datum_uid": segmentation.uid,
                            "datum_id": datum_id,
                            # groundtruth
                            "gt_label": None,
                            "gt_label_id": -1,
                            # prediction
                            "pd_label": label,
                            "pd_label_id": idx,
                            # pair
                            "count": counts[0, idx + 1],
                        },
                    ]
                )
            rows.append(
                {
                    # metadata
                    **datum_metadata,
                    # datum
                    "datum_uid": segmentation.uid,
                    "datum_id": datum_id,
                    # groundtruth
                    "gt_label": None,
                    "gt_label_id": -1,
                    # prediction
                    "pd_label": None,
                    "pd_label_id": -1,
                    # pair
                    "count": counts[0, 0],
                }
            )
            self._writer.write_rows(rows)

            # update datum count
            self._datum_count += 1
|