valor-lite 0.36.5__py3-none-any.whl → 0.37.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- valor_lite/cache/__init__.py +11 -0
- valor_lite/cache/compute.py +211 -0
- valor_lite/cache/ephemeral.py +302 -0
- valor_lite/cache/persistent.py +536 -0
- valor_lite/classification/__init__.py +5 -10
- valor_lite/classification/annotation.py +4 -0
- valor_lite/classification/computation.py +233 -251
- valor_lite/classification/evaluator.py +882 -0
- valor_lite/classification/loader.py +97 -0
- valor_lite/classification/metric.py +141 -4
- valor_lite/classification/shared.py +184 -0
- valor_lite/classification/utilities.py +221 -118
- valor_lite/exceptions.py +5 -0
- valor_lite/object_detection/__init__.py +5 -4
- valor_lite/object_detection/annotation.py +13 -1
- valor_lite/object_detection/computation.py +367 -304
- valor_lite/object_detection/evaluator.py +804 -0
- valor_lite/object_detection/loader.py +292 -0
- valor_lite/object_detection/metric.py +152 -3
- valor_lite/object_detection/shared.py +206 -0
- valor_lite/object_detection/utilities.py +182 -109
- valor_lite/semantic_segmentation/__init__.py +5 -4
- valor_lite/semantic_segmentation/annotation.py +7 -0
- valor_lite/semantic_segmentation/computation.py +20 -110
- valor_lite/semantic_segmentation/evaluator.py +414 -0
- valor_lite/semantic_segmentation/loader.py +205 -0
- valor_lite/semantic_segmentation/shared.py +149 -0
- valor_lite/semantic_segmentation/utilities.py +6 -23
- {valor_lite-0.36.5.dist-info → valor_lite-0.37.5.dist-info}/METADATA +3 -1
- valor_lite-0.37.5.dist-info/RECORD +49 -0
- {valor_lite-0.36.5.dist-info → valor_lite-0.37.5.dist-info}/WHEEL +1 -1
- valor_lite/classification/manager.py +0 -545
- valor_lite/object_detection/manager.py +0 -865
- valor_lite/profiling.py +0 -374
- valor_lite/semantic_segmentation/benchmark.py +0 -237
- valor_lite/semantic_segmentation/manager.py +0 -446
- valor_lite-0.36.5.dist-info/RECORD +0 -41
- {valor_lite-0.36.5.dist-info → valor_lite-0.37.5.dist-info}/top_level.txt +0 -0
valor_lite/cache/persistent.py
@@ -0,0 +1,536 @@
+import base64
+import glob
+import json
+import os
+from collections.abc import Iterator
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import pyarrow as pa
+import pyarrow.compute as pc
+import pyarrow.dataset as ds
+import pyarrow.parquet as pq
+
+
+class FileCache:
+    def __init__(
+        self,
+        path: str | Path,
+        schema: pa.Schema,
+        batch_size: int,
+        rows_per_file: int,
+        compression: str,
+    ):
+        self._path = Path(path)
+        self._schema = schema
+        self._batch_size = batch_size
+        self._rows_per_file = rows_per_file
+        self._compression = compression
+
+    @property
+    def path(self) -> Path:
+        return self._path
+
+    @property
+    def schema(self) -> pa.Schema:
+        return self._schema
+
+    @property
+    def batch_size(self) -> int:
+        return self._batch_size
+
+    @property
+    def rows_per_file(self) -> int:
+        return self._rows_per_file
+
+    @property
+    def compression(self) -> str:
+        return self._compression
+
+    @staticmethod
+    def _generate_config_path(path: str | Path) -> Path:
+        """Generate cache configuration path."""
+        return Path(path) / ".cfg"
+
+    @staticmethod
+    def _encode_schema(schema: pa.Schema) -> str:
+        """Encode schema to b64 string."""
+        schema_bytes = schema.serialize()
+        return base64.b64encode(schema_bytes).decode("utf-8")
+
+    @staticmethod
+    def _decode_schema(encoded_schema: str) -> pa.Schema:
+        """Decode schema from b64 string."""
+        schema_bytes = base64.b64decode(encoded_schema)
+        return pa.ipc.read_schema(pa.BufferReader(schema_bytes))
+
+    def count_rows(self) -> int:
+        """Count the number of rows in the cache."""
+        dataset = ds.dataset(
+            source=self._path,
+            format="parquet",
+        )
+        return dataset.count_rows()
+
+    def count_tables(self) -> int:
+        """Count the number of files in the cache."""
+        return len(self.get_dataset_files())
+
+    def get_files(self) -> list[Path]:
+        """
+        Retrieve all files.
+
+        Returns
+        -------
+        list[Path]
+            A list of paths to files in the cache.
+        """
+        if not self._path.exists():
+            return []
+        files = []
+        for entry in os.listdir(self._path):
+            full_path = os.path.join(self._path, entry)
+            if os.path.isfile(full_path):
+                files.append(Path(full_path))
+        return files
+
+    def get_dataset_files(self) -> list[Path]:
+        """
+        Retrieve all dataset files.
+
+        Returns
+        -------
+        list[Path]
+            A list of paths to dataset files in the cache.
+        """
+        if not self._path.exists():
+            return []
+        return [
+            Path(filepath) for filepath in glob.glob(f"{self._path}/*.parquet")
+        ]
+
+
+class FileCacheReader(FileCache):
+    @classmethod
+    def load(cls, path: str | Path):
+        """
+        Load cache from disk.
+
+        Parameters
+        ----------
+        path : str | Path
+            Where the cache is stored.
+        """
+        path = Path(path)
+        if not path.exists():
+            raise FileNotFoundError(f"Directory does not exist: {path}")
+        elif not path.is_dir():
+            raise NotADirectoryError(
+                f"Path exists but is not a directory: {path}"
+            )
+
+        def _retrieve(config: dict, key: str):
+            if value := config.get(key, None):
+                return value
+            raise KeyError(
+                f"'{key}' is not defined within {cls._generate_config_path(path)}"
+            )
+
+        # read configuration file
+        cfg_path = cls._generate_config_path(path)
+        with open(cfg_path, "r") as f:
+            cfg = json.load(f)
+            batch_size = _retrieve(cfg, "batch_size")
+            rows_per_file = _retrieve(cfg, "rows_per_file")
+            compression = _retrieve(cfg, "compression")
+            schema = cls._decode_schema(_retrieve(cfg, "schema"))
+
+        return cls(
+            schema=schema,
+            path=path,
+            batch_size=batch_size,
+            rows_per_file=rows_per_file,
+            compression=compression,
+        )
+
+    def iterate_ordered_fragments(
+        self, filter: pc.Expression | None = None
+    ) -> Iterator[ds.Fragment]:
+        """
+        Iterate through ordered cache fragments.
+        """
+        dataset = ds.dataset(
+            source=self._path,
+            schema=self._schema,
+            format="parquet",
+        )
+        fragments = list(dataset.get_fragments(filter=filter))
+        for fragment in sorted(
+            fragments, key=lambda x: int(Path(x.path).stem)
+        ):
+            yield fragment
+
+    def iterate_tables(
+        self,
+        columns: list[str] | None = None,
+        filter: pc.Expression | None = None,
+    ) -> Iterator[pa.Table]:
+        """
+        Iterate over tables within the cache.
+
+        Parameters
+        ----------
+        columns : list[str], optional
+            Optionally select columns to be returned.
+        filter : pyarrow.compute.Expression, optional
+            Optionally filter table before returning.
+
+        Returns
+        -------
+        Iterator[pa.Table]
+        """
+        for fragment in self.iterate_ordered_fragments(filter=filter):
+            yield fragment.to_table(columns=columns, filter=filter)
+
+    def iterate_arrays(
+        self,
+        numeric_columns: list[str] | None = None,
+        filter: pc.Expression | None = None,
+    ) -> Iterator[np.ndarray]:
+        """
+        Iterate over chunks within the cache returning arrays.
+
+        Parameters
+        ----------
+        numeric_columns : list[str], optional
+            Optionally select numeric columns to be returned within an array.
+        filter : pyarrow.compute.Expression, optional
+            Optionally filter table before returning.
+
+        Returns
+        -------
+        Iterator[np.ndarray]
+        """
+        for tbl in self.iterate_tables(
+            columns=numeric_columns,
+            filter=filter,
+        ):
+            yield np.column_stack(
+                [tbl.column(i).to_numpy() for i in range(tbl.num_columns)]
+            )
+
+    def iterate_tables_with_arrays(
+        self,
+        columns: list[str] | None = None,
+        filter: pc.Expression | None = None,
+        numeric_columns: list[str] | None = None,
+    ) -> Iterator[tuple[pa.Table, np.ndarray]]:
+        """
+        Iterate over chunks within the cache returning both tables and arrays.
+
+        Parameters
+        ----------
+        columns : list[str], optional
+            Optionally select columns to be returned.
+        filter : pyarrow.compute.Expression, optional
+            Optionally filter table before returning.
+        numeric_columns : list[str], optional
+            Optionally select numeric columns to be returned within an array.
+
+        Returns
+        -------
+        Iterator[tuple[pa.Table, np.ndarray]]
+
+        """
+        _columns = set(columns) if columns else set()
+        _numeric_columns = set(numeric_columns) if numeric_columns else set()
+        columns = list(_columns.union(_numeric_columns))
+        for tbl in self.iterate_tables(
+            columns=columns,
+            filter=filter,
+        ):
+            table_columns = numeric_columns if numeric_columns else tbl.columns
+            yield tbl, np.column_stack(
+                [tbl[col].to_numpy() for col in table_columns]
+            )
+
+    def iterate_fragment_batch_iterators(
+        self, batch_size: int
+    ) -> Iterator[Iterator[pa.RecordBatch]]:
+        """
+        Iterate over fragment batch iterators within the file-based cache.
+
+        Parameters
+        ----------
+        batch_size : int
+            Maximum number of rows allowed to be read into memory per cache file.
+
+        Returns
+        -------
+        Iterator[Iterator[pa.RecordBatch]]
+        """
+        for fragment in self.iterate_ordered_fragments():
+            yield fragment.to_batches(batch_size=batch_size)
+
+
+class FileCacheWriter(FileCache):
+    def __init__(
+        self,
+        path: str | Path,
+        schema: pa.Schema,
+        batch_size: int,
+        rows_per_file: int,
+        compression: str,
+    ):
+        super().__init__(
+            path=path,
+            schema=schema,
+            batch_size=batch_size,
+            rows_per_file=rows_per_file,
+            compression=compression,
+        )
+
+        # internal state
+        self._writer = None
+        self._buffer = []
+        self._count = 0
+
+    @classmethod
+    def create(
+        cls,
+        path: str | Path,
+        schema: pa.Schema,
+        batch_size: int,
+        rows_per_file: int,
+        compression: str = "snappy",
+        delete_if_exists: bool = False,
+    ):
+        """
+        Create an on-disk cache.
+
+        Parameters
+        ----------
+        path : str | Path
+            Where to write the cache.
+        schema : pa.Schema
+            Cache schema.
+        batch_size : int
+            Target batch size when writing chunks.
+        rows_per_file : int
+            Target number of rows to store per file.
+        compression : str, default="snappy"
+            Compression method to use when storing on disk.
+        delete_if_exists : bool, default=False
+            Delete the cache if it already exists.
+        """
+        path = Path(path)
+        if delete_if_exists and path.exists():
+            cls.delete(path)
+        Path(path).mkdir(parents=True, exist_ok=False)
+
+        # write configuration file
+        cfg_path = cls._generate_config_path(path)
+        with open(cfg_path, "w") as f:
+            cfg = dict(
+                batch_size=batch_size,
+                rows_per_file=rows_per_file,
+                compression=compression,
+                schema=cls._encode_schema(schema),
+            )
+            json.dump(cfg, f, indent=2)
+
+        return cls(
+            schema=schema,
+            path=path,
+            batch_size=batch_size,
+            rows_per_file=rows_per_file,
+            compression=compression,
+        )
+
+    @classmethod
+    def delete(cls, path: str | Path):
+        """
+        Delete a cache at path.
+
+        Parameters
+        ----------
+        path : str | Path
+            Where the cache is stored.
+        """
+        path = Path(path)
+        if not path.exists():
+            return
+
+        # delete dataset files
+        reader = FileCacheReader.load(path)
+        for file in reader.get_dataset_files():
+            if file.exists() and file.is_file() and file.suffix == ".parquet":
+                file.unlink()
+
+        # delete config file
+        cfg_path = cls._generate_config_path(path)
+        if cfg_path.exists() and cfg_path.is_file():
+            cfg_path.unlink()
+
+        # delete empty cache directory
+        path.rmdir()
+
+    def write_rows(
+        self,
+        rows: list[dict[str, Any]],
+    ):
+        """
+        Write rows to cache.
+
+        Parameters
+        ----------
+        rows : list[dict[str, Any]]
+            A list of rows represented by dictionaries mapping fields to values.
+        """
+        if not rows:
+            return
+        batch = pa.RecordBatch.from_pylist(rows, schema=self._schema)
+        self.write_batch(batch)
+
+    def write_columns(
+        self,
+        columns: dict[str, list | np.ndarray | pa.Array],
+    ):
+        """
+        Write columnar data to cache.
+
+        Parameters
+        ----------
+        columns : dict[str, list | np.ndarray | pa.Array]
+            A mapping of columnar field names to list of values.
+        """
+        if not columns:
+            return
+        batch = pa.RecordBatch.from_pydict(columns)
+        self.write_batch(batch)
+
+    def write_batch(
+        self,
+        batch: pa.RecordBatch,
+    ):
+        """
+        Write a batch to cache.
+
+        Parameters
+        ----------
+        batch : pa.RecordBatch
+            A batch of columnar data.
+        """
+        size = batch.num_rows
+        if self._buffer:
+            size += sum([b.num_rows for b in self._buffer])
+
+        # check size
+        if size < self.batch_size and self._count < self.rows_per_file:
+            self._buffer.append(batch)
+            return
+
+        if self._buffer:
+            self._buffer.append(batch)
+            batch = pa.concat_batches(self._buffer)
+            self._buffer = []
+
+        # write batch
+        writer = self._get_or_create_writer()
+        writer.write_batch(batch)
+
+        # check file size
+        self._count += size
+        if self._count >= self.rows_per_file:
+            self.flush()
+
+    def write_table(
+        self,
+        table: pa.Table,
+    ):
+        """
+        Write a table directly to cache.
+
+        Parameters
+        ----------
+        table : pa.Table
+            A populated table.
+        """
+        self.flush()
+        pq.write_table(table, where=self._generate_next_filename())
+
+    def flush(self):
+        """Flush the cache buffer."""
+        if self._buffer:
+            combined_arrays = [
+                pa.concat_arrays([b.column(name) for b in self._buffer])
+                for name in self._schema.names
+            ]
+            batch = pa.RecordBatch.from_arrays(
+                combined_arrays, schema=self._schema
+            )
+            writer = self._get_or_create_writer()
+            writer.write_batch(batch)
+            self._buffer = []
+        self._count = 0
+        self._close_writer()
+
+    def sort_by(
+        self,
+        sorting: list[tuple[str, str]],
+    ):
+        """
+        Sort cache files locally and in-place.
+
+        Parameters
+        ----------
+        sorting : list[tuple[str, str]]
+            Sorting arguments in PyArrow format (e.g. [('a', 'ascending'), ('b', 'descending')]).
+        """
+        self.flush()
+        for file in self.get_dataset_files():
+            pf = pq.ParquetFile(file)
+            tbl = pf.read()
+            pf.close()
+            sorted_tbl = tbl.sort_by(sorting)
+            pq.write_table(sorted_tbl, file)
+
+    def _generate_next_filename(self) -> Path:
+        """Generates next dataset filepath."""
+        files = self.get_dataset_files()
+        if not files:
+            next_index = 0
+        else:
+            next_index = max([int(Path(f).stem) for f in files]) + 1
+        return self._path / f"{next_index:06d}.parquet"
+
+    def _get_or_create_writer(self) -> pq.ParquetWriter:
+        """Open a new parquet file for writing."""
+        if self._writer is not None:
+            return self._writer
+        self._writer = pq.ParquetWriter(
+            where=self._generate_next_filename(),
+            schema=self._schema,
+            compression=self._compression,
+        )
+        return self._writer
+
+    def _close_writer(self) -> None:
+        """Close the current parquet file."""
+        if self._writer is not None:
+            self._writer.close()
+            self._writer = None
+
+    def __enter__(self):
+        """Context manager entry."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Context manager exit - ensures data is flushed."""
+        self.flush()
+
+    def to_reader(self) -> FileCacheReader:
+        """Get cache reader."""
+        self.flush()
+        return FileCacheReader.load(path=self.path)
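Taken together, `FileCacheWriter` and `FileCacheReader` give the new `valor_lite.cache` package a file-backed parquet cache: the writer buffers record batches until `batch_size` rows accumulate, rolls over to a new numbered parquet file once `rows_per_file` rows have been written, and persists its configuration (including a base64-encoded schema) in a `.cfg` file that the reader uses to reopen the cache. A minimal usage sketch based only on the code in this hunk; the schema, field names, and cache path below are illustrative, not taken from valor-lite's documentation:

```python
import pyarrow as pa
import pyarrow.compute as pc

from valor_lite.cache.persistent import FileCacheReader, FileCacheWriter

# Illustrative schema and location; any field names would work.
schema = pa.schema([("datum_id", pa.int64()), ("score", pa.float64())])

# Writing: rows are buffered into record batches and spilled to numbered
# parquet files (000000.parquet, 000001.parquet, ...) under the cache path.
with FileCacheWriter.create(
    path="/tmp/valor_cache",
    schema=schema,
    batch_size=1_000,
    rows_per_file=10_000,
    delete_if_exists=True,
) as writer:
    writer.write_rows(
        [
            {"datum_id": 0, "score": 0.92},
            {"datum_id": 1, "score": 0.41},
        ]
    )

# Reading: load() restores the schema from the .cfg file, then tables can be
# iterated fragment-by-fragment with an optional pyarrow filter expression.
reader = FileCacheReader.load("/tmp/valor_cache")
for tbl in reader.iterate_tables(filter=pc.field("score") > 0.5):
    print(tbl.to_pydict())
```

Because `__exit__` calls `flush()`, using the writer as a context manager ensures any buffered rows are written out before the cache is read back.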

valor_lite/classification/__init__.py
@@ -1,19 +1,14 @@
 from .annotation import Classification
-from .
-
-    compute_precision_recall_rocauc,
-)
-from .manager import DataLoader, Evaluator, Filter, Metadata
+from .evaluator import Evaluator
+from .loader import Loader
 from .metric import Metric, MetricType
+from .shared import EvaluatorInfo
 
 __all__ = [
     "Classification",
-    "compute_precision_recall_rocauc",
-    "compute_confusion_matrix",
     "MetricType",
-    "
+    "Loader",
     "Evaluator",
     "Metric",
-    "
-    "Filter",
+    "EvaluatorInfo",
 ]
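For callers, the practical effect of this `__init__.py` change is a new public surface: the computation helpers and the old `manager` module (with `DataLoader`, `Filter`, and `Metadata`) are no longer re-exported, while `Loader`, `Evaluator`, and `EvaluatorInfo` are. A sketch of the imports this implies; whether `Loader` is a drop-in replacement for the removed `DataLoader` workflow is an assumption not confirmed by this diff:

```python
# Names re-exported by valor_lite.classification as of 0.37.5.
from valor_lite.classification import (
    Classification,
    Evaluator,
    EvaluatorInfo,
    Loader,
    Metric,
    MetricType,
)
```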

valor_lite/classification/annotation.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from typing import Any
 
 
 @dataclass
@@ -16,6 +17,8 @@ class Classification:
         List of predicted labels.
     scores : list of float
         Confidence scores corresponding to each predicted label.
+    metadata : dict[str, Any], optional
+        A dictionary containing any metadata to be used within filtering operations.
 
     Examples
     --------
@@ -31,6 +34,7 @@ class Classification:
     groundtruth: str
     predictions: list[str]
     scores: list[float]
+    metadata: dict[str, Any] | None = None
 
     def __post_init__(self):
         if not isinstance(self.groundtruth, str):