pixeltable 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pixeltable might be problematic.
- pixeltable/catalog/column.py +1 -1
- pixeltable/client.py +72 -2
- pixeltable/env.py +36 -52
- pixeltable/functions/__init__.py +1 -1
- pixeltable/functions/fireworks.py +10 -37
- pixeltable/functions/openai.py +192 -24
- pixeltable/functions/together.py +104 -9
- pixeltable/tests/conftest.py +4 -4
- pixeltable/tests/functions/test_fireworks.py +42 -0
- pixeltable/tests/functions/test_functions.py +60 -0
- pixeltable/tests/{test_functions.py → functions/test_huggingface.py} +5 -141
- pixeltable/tests/functions/test_openai.py +152 -0
- pixeltable/tests/functions/test_together.py +111 -0
- pixeltable/tests/test_dataframe.py +4 -4
- pixeltable/tests/test_table.py +105 -2
- pixeltable/tests/utils.py +128 -5
- pixeltable/type_system.py +41 -84
- pixeltable/utils/arrow.py +98 -0
- pixeltable/utils/hf_datasets.py +157 -0
- pixeltable/utils/parquet.py +68 -27
- pixeltable/utils/pytorch.py +16 -97
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.4.dist-info}/METADATA +33 -27
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.4.dist-info}/RECORD +25 -19
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.4.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.4.dist-info}/WHEEL +0 -0
pixeltable/type_system.py
CHANGED
@@ -6,9 +6,8 @@ import enum
 import json
 import typing
 import urllib.parse
-from copy import copy
 from pathlib import Path
-from typing import Any, Optional, Tuple, Dict, Callable, List, Union
+from typing import Any, Optional, Tuple, Dict, Callable, List, Union, Sequence, Mapping

 import PIL.Image
 import av
@@ -240,19 +239,38 @@ class ColumnType:

     @classmethod
     def from_python_type(cls, t: type) -> Optional[ColumnType]:
-        if t ...
-        underlying ...
+        if typing.get_origin(t) is typing.Union:
+            union_args = typing.get_args(t)
+            if union_args[1] is type(None):
+                # `t` is a type of the form Optional[T] (equivalently, Union[T, None]).
+                # We treat it as the underlying type but with nullable=True.
+                underlying = cls.from_python_type(union_args[0])
+                if underlying is not None:
+                    underlying.nullable = True
+                    return underlying
+        else:
+            # Discard type parameters to ensure that parameterized types such as `list[T]`
+            # are correctly mapped to Pixeltable types.
+            base = typing.get_origin(t)
+            if base is None:
+                # No type parameters; the base type is just `t` itself
+                base = t
+            if base is str:
+                return StringType()
+            if base is int:
+                return IntType()
+            if base is float:
+                return FloatType()
+            if base is bool:
+                return BoolType()
+            if base is datetime.date or base is datetime.datetime:
+                return TimestampType()
+            if issubclass(base, Sequence) or issubclass(base, Mapping):
+                return JsonType()
+            if issubclass(base, PIL.Image.Image):
+                return ImageType()
         return None

-
     def validate_literal(self, val: Any) -> None:
         """Raise TypeError if val is not a valid literal for this type"""
         if val is None:
@@ -383,10 +401,6 @@ class ColumnType:
             return sql.VARBINARY
         assert False

-    @abc.abstractmethod
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        assert False, f'Have not implemented {self.__class__.__name__} to Arrow'
-
     @staticmethod
     def no_conversion(v: Any) -> Any:
         """
@@ -413,9 +427,6 @@ class InvalidType(ColumnType):
     def to_sa_type(self) -> Any:
         assert False

-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        assert False
-
     def print_value(self, val: Any) -> str:
         assert False

@@ -442,10 +453,6 @@ class StringType(ColumnType):

     def to_sa_type(self) -> str:
         return sql.String
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.string()

     def print_value(self, val: Any) -> str:
         return f"'{val}'"
@@ -454,6 +461,13 @@ class StringType(ColumnType):
         if not isinstance(val, str):
             raise TypeError(f'Expected string, got {val.__class__.__name__}')

+    def _create_literal(self, val: Any) -> Any:
+        # Replace null byte within python string with space to avoid issues with Postgres.
+        # Use a space to avoid merging words.
+        # TODO(orm): this will also be an issue with JSON inputs, would space still be a good replacement?
+        if isinstance(val, str) and '\x00' in val:
+            return val.replace('\x00', ' ')
+        return val

 class IntType(ColumnType):
     def __init__(self, nullable: bool = False):
@@ -464,10 +478,6 @@ class IntType(ColumnType):

     def to_sa_type(self) -> str:
         return sql.BigInteger
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.int64()  # to be consistent with bigint above

     def _validate_literal(self, val: Any) -> None:
         if not isinstance(val, int):
@@ -483,10 +493,6 @@ class FloatType(ColumnType):

     def to_sa_type(self) -> str:
         return sql.Float
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa
-        return pa.float32()

     def _validate_literal(self, val: Any) -> None:
         if not isinstance(val, float):
@@ -506,10 +512,6 @@ class BoolType(ColumnType):

     def to_sa_type(self) -> str:
         return sql.Boolean
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.bool_()

     def _validate_literal(self, val: Any) -> None:
         if not isinstance(val, bool):
@@ -529,10 +531,6 @@ class TimestampType(ColumnType):

     def to_sa_type(self) -> str:
         return sql.TIMESTAMP
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.timestamp('us')  # postgres timestamp is microseconds

     def _validate_literal(self, val: Any) -> None:
         if not isinstance(val, datetime.datetime) and not isinstance(val, datetime.date):
@@ -570,10 +568,6 @@ class JsonType(ColumnType):

     def to_sa_type(self) -> str:
         return sql.dialects.postgresql.JSONB
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.string()  # TODO: weight advantage of pa.struct type.

     def print_value(self, val: Any) -> str:
         val_type = self.infer_literal_type(val)
@@ -669,7 +663,9 @@ class ArrayType(ColumnType):

     def _create_literal(self, val: Any) -> Any:
         if isinstance(val, (list,tuple)):
-            ...
+            # map python float to whichever numpy float is
+            # declared for this type, rather than assume float64
+            return np.array(val, dtype=self.numpy_dtype())
         return val

     def to_sql(self) -> str:
@@ -677,12 +673,6 @@ class ArrayType(ColumnType):

     def to_sa_type(self) -> str:
         return sql.LargeBinary
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        if any([n is None for n in self.shape]):
-            raise TypeError(f'Cannot convert array with unknown shape to Arrow')
-        return pa.fixed_shape_tensor(pa.from_numpy_dtype(self.numpy_dtype()), self.shape)

     def numpy_dtype(self) -> np.dtype:
         if self.dtype == self.Type.INT:
@@ -788,10 +778,6 @@ class ImageType(ColumnType):

     def to_sa_type(self) -> str:
         return sql.String
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.binary()

     def _validate_literal(self, val: Any) -> None:
         if isinstance(val, PIL.Image.Image):
@@ -815,10 +801,6 @@ class VideoType(ColumnType):

     def to_sa_type(self) -> str:
         return sql.String
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.string()

     def _validate_literal(self, val: Any) -> None:
         self._validate_file_path(val)
@@ -854,10 +836,6 @@ class AudioType(ColumnType):
     def to_sa_type(self) -> str:
         return sql.String

-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.string()
-
     def _validate_literal(self, val: Any) -> None:
         self._validate_file_path(val)

@@ -901,10 +879,6 @@ class DocumentType(ColumnType):
     def to_sa_type(self) -> str:
         return sql.String

-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.string()
-
     def _validate_literal(self, val: Any) -> None:
         self._validate_file_path(val)

@@ -919,20 +893,3 @@ class DocumentType(ColumnType):
             raise excs.Error(f'Not a recognized document format: {val}')
         except Exception as e:
             raise excs.Error(f'Not a recognized document format: {val}') from None
-
-
-# A dictionary mapping various Python types to their respective ColumnTypes.
-# This can be used to infer Pixeltable ColumnTypes from type hints on Python
-# functions. (Since Python functions do not necessarily have type hints, this
-# should always be an optional/convenience inference.)
-_python_type_to_column_type: dict[type, ColumnType] = {
-    str: StringType(),
-    int: IntType(),
-    float: FloatType(),
-    bool: BoolType(),
-    datetime.datetime: TimestampType(),
-    datetime.date: TimestampType(),
-    list: JsonType(),
-    dict: JsonType(),
-    PIL.Image.Image: ImageType()
-}
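The rewritten from_python_type can be exercised directly. A minimal sketch (written for this note, not part of the diff) of the 0.2.4 behavior:

from typing import Dict, List, Optional

import pixeltable.type_system as ts

# Optional[T] (i.e. Union[T, None]) resolves to the underlying type with nullable=True
opt = ts.ColumnType.from_python_type(Optional[int])
assert isinstance(opt, ts.IntType) and opt.nullable

# type parameters are discarded, so parameterized containers map to JsonType
assert isinstance(ts.ColumnType.from_python_type(List[float]), ts.JsonType)
assert isinstance(ts.ColumnType.from_python_type(Dict[str, int]), ts.JsonType)

# types with no mapping still return None
assert ts.ColumnType.from_python_type(complex) is None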
pixeltable/utils/arrow.py
ADDED
@@ -0,0 +1,98 @@
+import logging
+from typing import Any, Dict, Iterable, Iterator, Optional
+
+import pyarrow as pa
+
+import pixeltable.type_system as ts
+
+_logger = logging.getLogger(__name__)
+
+_pa_to_pt: Dict[pa.DataType, ts.ColumnType] = {
+    pa.string(): ts.StringType(nullable=True),
+    pa.timestamp('us'): ts.TimestampType(nullable=True),
+    pa.bool_(): ts.BoolType(nullable=True),
+    pa.uint8(): ts.IntType(nullable=True),
+    pa.int8(): ts.IntType(nullable=True),
+    pa.uint32(): ts.IntType(nullable=True),
+    pa.uint64(): ts.IntType(nullable=True),
+    pa.int32(): ts.IntType(nullable=True),
+    pa.int64(): ts.IntType(nullable=True),
+    pa.float32(): ts.FloatType(nullable=True),
+}
+
+_pt_to_pa: Dict[ts.ColumnType, pa.DataType] = {
+    ts.StringType: pa.string(),
+    ts.TimestampType: pa.timestamp('us'),  # postgres timestamp is microseconds
+    ts.BoolType: pa.bool_(),
+    ts.IntType: pa.int64(),
+    ts.FloatType: pa.float32(),
+    ts.JsonType: pa.string(),  # TODO(orm) pa.struct() is possible
+    ts.ImageType: pa.binary(),  # inline image
+    ts.AudioType: pa.string(),  # path
+    ts.VideoType: pa.string(),  # path
+    ts.DocumentType: pa.string(),  # path
+}
+
+
+def to_pixeltable_type(arrow_type: pa.DataType) -> Optional[ts.ColumnType]:
+    """Convert a pyarrow DataType to a pixeltable ColumnType if one is defined.
+    Returns None if no conversion is currently implemented.
+    """
+    if arrow_type in _pa_to_pt:
+        return _pa_to_pt[arrow_type]
+    elif isinstance(arrow_type, pa.FixedShapeTensorType):
+        dtype = to_pixeltable_type(arrow_type.value_type)
+        if dtype is None:
+            return None
+        return ts.ArrayType(shape=arrow_type.shape, dtype=dtype)
+    else:
+        return None
+
+
+def to_arrow_type(pixeltable_type: ts.ColumnType) -> Optional[pa.DataType]:
+    """Convert a pixeltable DataType to a pyarrow datatype if one is defined.
+    Returns None if no conversion is currently implemented.
+    """
+    if pixeltable_type.__class__ in _pt_to_pa:
+        return _pt_to_pa[pixeltable_type.__class__]
+    elif isinstance(pixeltable_type, ts.ArrayType):
+        return pa.fixed_shape_tensor(pa.from_numpy_dtype(pixeltable_type.numpy_dtype()), pixeltable_type.shape)
+    else:
+        return None
+
+
+def to_pixeltable_schema(arrow_schema: pa.Schema) -> Dict[str, ts.ColumnType]:
+    return {field.name: to_pixeltable_type(field.type) for field in arrow_schema}
+
+
+def to_arrow_schema(pixeltable_schema: Dict[str, Any]) -> pa.Schema:
+    return pa.schema((name, to_arrow_type(typ)) for name, typ in pixeltable_schema.items())
+
+
+def to_pydict(batch: pa.RecordBatch) -> Dict[str, Iterable[Any]]:
+    """Convert a RecordBatch to a dictionary of lists, unlike pa.lib.RecordBatch.to_pydict,
+    this function will not convert numpy arrays to lists, and will preserve the original numpy dtype.
+    """
+    out = {}
+    for k, name in enumerate(batch.schema.names):
+        col = batch.column(k)
+        if isinstance(col.type, pa.FixedShapeTensorType):
+            # treat array columns as numpy arrays to easily preserve numpy type
+            out[name] = col.to_numpy(zero_copy_only=False)
+        else:
+            # for the rest, use pydict to preserve python types
+            out[name] = col.to_pylist()
+
+    return out
+
+
+def iter_tuples(batch: pa.RecordBatch) -> Iterator[Dict[str, Any]]:
+    """Convert a RecordBatch to an iterator of dictionaries. also works with pa.Table and pa.RowGroup"""
+    pydict = to_pydict(batch)
+    assert len(pydict) > 0, 'empty record batch'
+    for _, v in pydict.items():
+        batch_size = len(v)
+        break
+
+    for i in range(batch_size):
+        yield {col_name: values[i] for col_name, values in pydict.items()}
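A short sketch of how the new helpers compose (assumed usage, not shipped in the package): build an Arrow schema from pixeltable types, then walk a RecordBatch row by row with iter_tuples.

import pyarrow as pa

import pixeltable.type_system as ts
from pixeltable.utils.arrow import iter_tuples, to_arrow_schema, to_pixeltable_schema

pxt_schema = {'name': ts.StringType(), 'score': ts.FloatType()}
arrow_schema = to_arrow_schema(pxt_schema)  # name: string, score: float32

batch = pa.RecordBatch.from_pydict(
    {'name': ['a', 'b'], 'score': [0.5, 1.5]}, schema=arrow_schema)
for row in iter_tuples(batch):
    print(row)  # {'name': 'a', 'score': 0.5}, then {'name': 'b', 'score': 1.5}

# the reverse mapping deliberately yields nullable pixeltable types
print(to_pixeltable_schema(arrow_schema))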
pixeltable/utils/hf_datasets.py
ADDED
@@ -0,0 +1,157 @@
+import datasets
+from typing import Union, Optional, List, Dict, Any
+import pixeltable.type_system as ts
+from pixeltable import exceptions as excs
+import math
+import logging
+import pixeltable
+import random
+
+_logger = logging.getLogger(__name__)
+
+# use 100MB as the batch size limit for loading a huggingface dataset into pixeltable.
+# The primary goal is to bound memory use, regardless of dataset size.
+# Second goal is to limit overhead. 100MB is presumed to be reasonable for a lot of storage systems.
+_K_BATCH_SIZE_BYTES = 100_000_000
+
+# note, there are many more types. we allow overrides in the schema_override parameter
+# to handle cases where the appropriate type is not yet mapped, or to override this mapping.
+# https://huggingface.co/docs/datasets/v2.17.0/en/package_reference/main_classes#datasets.Value
+_hf_to_pxt: Dict[str, ts.ColumnType] = {
+    'int32': ts.IntType(nullable=True),  # pixeltable widens to big int
+    'int64': ts.IntType(nullable=True),
+    'bool': ts.BoolType(nullable=True),
+    'float32': ts.FloatType(nullable=True),
+    'string': ts.StringType(nullable=True),
+    'timestamp[s]': ts.TimestampType(nullable=True),
+    'timestamp[ms]': ts.TimestampType(nullable=True),  # HF dataset iterator converts timestamps to datetime.datetime
+}
+
+def _to_pixeltable_type(
+    feature_type: Union[datasets.ClassLabel, datasets.Value, datasets.Sequence],
+) -> Optional[ts.ColumnType]:
+    """Convert a huggingface feature type to a pixeltable ColumnType if one is defined."""
+    if isinstance(feature_type, datasets.ClassLabel):
+        # enum, example: ClassLabel(names=['neg', 'pos'], id=None)
+        return ts.StringType(nullable=True)
+    elif isinstance(feature_type, datasets.Value):
+        # example: Value(dtype='int64', id=None)
+        return _hf_to_pxt.get(feature_type.dtype, None)
+    elif isinstance(feature_type, datasets.Sequence):
+        # example: cohere wiki. Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None)
+        dtype = _to_pixeltable_type(feature_type.feature)
+        length = feature_type.length if feature_type.length != -1 else None
+        return ts.ArrayType(shape=(length,), dtype=dtype)
+    else:
+        return None
+
+def _get_hf_schema(dataset: Union[datasets.Dataset, datasets.DatasetDict]) -> datasets.Features:
+    """Get the schema of a huggingface dataset as a dictionary."""
+    first_dataset = dataset if isinstance(dataset, datasets.Dataset) else next(iter(dataset.values()))
+    return first_dataset.features
+
+def huggingface_schema_to_pixeltable_schema(
+    hf_dataset: Union[datasets.Dataset, datasets.DatasetDict],
+) -> Dict[str, Optional[ts.ColumnType]]:
+    """Generate a pixeltable schema from a huggingface dataset schema.
+    Columns without a known mapping are mapped to None
+    """
+    hf_schema = _get_hf_schema(hf_dataset)
+    pixeltable_schema = {
+        column_name: _to_pixeltable_type(feature_type) for column_name, feature_type in hf_schema.items()
+    }
+    return pixeltable_schema
+
+def import_huggingface_dataset(
+    cl: 'pixeltable.Client',
+    table_path: str,
+    dataset: Union[datasets.Dataset, datasets.DatasetDict],
+    *,
+    column_name_for_split: Optional[str],
+    schema_override: Optional[Dict[str, Any]],
+    **kwargs,
+) -> 'pixeltable.InsertableTable':
+    """See `pixeltable.Client.import_huggingface_dataset` for documentation"""
+    if table_path in cl.list_tables():
+        raise excs.Error(f'table {table_path} already exists')
+
+    if not isinstance(dataset, (datasets.Dataset, datasets.DatasetDict)):
+        raise excs.Error(f'`type(dataset)` must be `datasets.Dataset` or `datasets.DatasetDict`. Got {type(dataset)=}')
+
+    if isinstance(dataset, datasets.Dataset):
+        # when loading an hf dataset partially, dataset.split._name is sometimes the form "train[0:1000]"
+        raw_name = dataset.split._name
+        split_name = raw_name.split('[')[0] if raw_name is not None else None
+        dataset_dict = {split_name: dataset}
+    else:
+        dataset_dict = dataset
+
+    pixeltable_schema = huggingface_schema_to_pixeltable_schema(dataset)
+    if schema_override is not None:
+        pixeltable_schema.update(schema_override)
+
+    if column_name_for_split is not None:
+        if column_name_for_split in pixeltable_schema:
+            raise excs.Error(
+                f'Column name `{column_name_for_split}` already exists in dataset schema; provide a different `column_name_for_split`'
+            )
+        pixeltable_schema[column_name_for_split] = ts.StringType(nullable=True)
+
+    for field, column_type in pixeltable_schema.items():
+        if column_type is None:
+            raise excs.Error(f'Could not infer pixeltable type for feature `{field}` in huggingface dataset')
+
+    if isinstance(dataset, datasets.Dataset):
+        # when loading an hf dataset partially, dataset.split._name is sometimes the form "train[0:1000]"
+        raw_name = dataset.split._name
+        split_name = raw_name.split('[')[0] if raw_name is not None else None
+        dataset_dict = {split_name: dataset}
+    elif isinstance(dataset, datasets.DatasetDict):
+        dataset_dict = dataset
+    else:
+        raise excs.Error(f'`type(dataset)` must be `datasets.Dataset` or `datasets.DatasetDict`. Got {type(dataset)=}')
+
+    # extract all class labels from the dataset to translate category ints to strings
+    hf_schema = _get_hf_schema(dataset)
+    categorical_features = {
+        feature_name: feature_type.names
+        for (feature_name, feature_type) in hf_schema.items()
+        if isinstance(feature_type, datasets.ClassLabel)
+    }
+
+    try:
+        # random tmp name
+        tmp_name = f'{table_path}_tmp_{random.randint(0, 100000000)}'
+        tab = cl.create_table(tmp_name, pixeltable_schema, **kwargs)
+
+        def _translate_row(row: Dict[str, Any], split_name: str) -> Dict[str, Any]:
+            output_row = row.copy()
+            # map all class labels to strings
+            for field, values in categorical_features.items():
+                output_row[field] = values[row[field]]
+            # add split name to row
+            if column_name_for_split is not None:
+                output_row[column_name_for_split] = split_name
+            return output_row
+
+        for split_name, split_dataset in dataset_dict.items():
+            num_batches = split_dataset.size_in_bytes / _K_BATCH_SIZE_BYTES
+            tuples_per_batch = math.ceil(split_dataset.num_rows / num_batches)
+            assert tuples_per_batch > 0
+
+            batch = []
+            for row in split_dataset:
+                batch.append(_translate_row(row, split_name))
+                if len(batch) >= tuples_per_batch:
+                    tab.insert(batch)
+                    batch = []
+            # last batch
+            if len(batch) > 0:
+                tab.insert(batch)
+
+    except Exception as e:
+        _logger.error(f'Error while inserting dataset into table: {tmp_name}')
+        raise e
+
+    cl.move(tmp_name, table_path)
+    return cl.get_table(table_path)
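A hypothetical invocation of the loader (the dataset choice and table name are assumptions, not from the diff), using the Client the docstring references:

import datasets
import pixeltable
from pixeltable.utils.hf_datasets import import_huggingface_dataset

cl = pixeltable.Client()
# rotten_tomatoes has a ClassLabel feature, which lands as a string column
ds = datasets.load_dataset('rotten_tomatoes', split='train[:1000]')

tab = import_huggingface_dataset(
    cl, 'rt_sample', ds,
    column_name_for_split='split',  # adds a StringType column holding 'train'
    schema_override=None,           # keep the inferred schema
)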
pixeltable/utils/parquet.py
CHANGED
@@ -1,14 +1,24 @@
+import io
 import json
+import logging
+from collections import deque
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+import numpy as np
 import PIL.Image
-import io
-import pyarrow.parquet as pq
 import pyarrow as pa
-import ...
-from pathlib import Path
-from collections import deque
+import pyarrow.parquet

-
+import pixeltable.type_system as ts
+from pixeltable.utils.arrow import iter_tuples, to_arrow_schema, to_pixeltable_schema
 from pixeltable.utils.transactional_directory import transactional_directory
+import pixeltable.exceptions as exc
+
+import random
+
+_logger = logging.getLogger(__name__)
+

 def _write_batch(value_batch : Dict[str, deque], schema : pa.Schema, output_path : Path) -> None:
     pydict = {}
@@ -20,18 +30,18 @@ def _write_batch(value_batch : Dict[str, deque], schema : pa.Schema, output_path : Path) -> None:
         pydict[field.name] = value_batch[field.name]

     tab = pa.Table.from_pydict(pydict, schema=schema)
-    ...
+    pa.parquet.write_table(tab, output_path)

 def save_parquet(df: 'pixeltable.DataFrame', dest_path: Path, partition_size_bytes: int = 100_000_000) -> None:
     """
     Internal method to stream dataframe data to parquet format.
     Does not materialize the dataset to memory.

-    It preserves pixeltable type metadata in a json file, which would otherwise
+    It preserves pixeltable type metadata in a json file, which would otherwise
     not be available in the parquet format.

     Images are stored inline in a compressed format in their parquet file.
-
+
     Args:
         df : dataframe to save.
         dest_path : path to directory to save the parquet files to.
@@ -39,10 +49,9 @@ def save_parquet(df: 'pixeltable.DataFrame', dest_path: Path, partition_size_bytes: int = 100_000_000) -> None:
     """
     column_names = df.get_column_names()
     column_types = df.get_column_types()
-
     type_dict = {k: v.as_dict() for k, v in zip(column_names, column_types)}
+    arrow_schema = to_arrow_schema(dict(zip(column_names, column_types)))

-    arrow_schema = pa.schema([pa.field(name, column_types[i].to_arrow_type()) for i, name in enumerate(column_names)])
     # store the changes atomically
     with transactional_directory(dest_path) as temp_path:
         # dump metadata json file so we can inspect what was the source of the parquet file later on.
@@ -65,8 +74,9 @@ def save_parquet(df: 'pixeltable.DataFrame', dest_path: Path, partition_size_bytes: int = 100_000_000) -> None:
                 # images get inlined into the parquet file
                 if data_row.file_paths is not None and data_row.file_paths[e.slot_idx] is not None:
                     # if there is a file, read directly to preserve information
-                    ...
-                    ...
+                    with open(data_row.file_paths[e.slot_idx], 'rb') as f:
+                        val = f.read()
+                elif isinstance(val, PIL.Image.Image):
                     # if no file available, eg. bc it is computed, convert to png
                     buf = io.BytesIO()
                     val.save(buf, format='PNG')
@@ -109,18 +119,49 @@ def save_parquet(df: 'pixeltable.DataFrame', dest_path: Path, partition_size_bytes: int = 100_000_000) -> None:

         _write_batch(current_value_batch, arrow_schema, temp_path / f'part-{batch_num:05d}.parquet')

-def get_part_metadata(path : Path) -> List[Tuple[str, int]]:
-    """
-    Args:
-        path: path to directory containing parquet files.
-    Returns:
-        A list of (file_name, num_rows) tuples for the parquet files in file name order.
-    """
-    parts = sorted([f for f in path.iterdir() if f.suffix == '.parquet'])
-    rows_per_file = {}
-
-    for part in parts:
-        parquet_file = pq.ParquetFile(str(part))
-        rows_per_file[part] = parquet_file.metadata.num_rows

-
+def parquet_schema_to_pixeltable_schema(parquet_path: str) -> Dict[str, Optional[ts.ColumnType]]:
+    """Generate a default pixeltable schema for the given parquet file. Returns None for unknown types."""
+
+    input_path = Path(parquet_path).expanduser()
+    parquet_dataset = pa.parquet.ParquetDataset(input_path)
+    return to_pixeltable_schema(parquet_dataset.schema)
+
+
+def import_parquet(
+    cl: 'pixeltable.Client',
+    table_path: str,
+    *,
+    parquet_path: str,
+    schema_override: Optional[Dict[str, ts.ColumnType]],
+    **kwargs,
+) -> 'catalog.InsertableTable':
+    """See `pixeltable.Client.import_parquet` for documentation"""
+    input_path = Path(parquet_path).expanduser()
+    parquet_dataset = pa.parquet.ParquetDataset(input_path)
+
+    schema = parquet_schema_to_pixeltable_schema(parquet_path)
+    if schema_override is None:
+        schema_override = {}
+
+    schema.update(schema_override)
+    for k, v in schema.items():
+        if v is None:
+            raise exc.Error(f'Could not infer pixeltable type for column {k} from parquet file')
+
+    if table_path in cl.list_tables():
+        raise exc.Error(f'Table {table_path} already exists')
+
+    try:
+        tmp_name = f'{table_path}_tmp_{random.randint(0, 100000000)}'
+        tab = cl.create_table(tmp_name, schema, **kwargs)
+        for fragment in parquet_dataset.fragments:
+            for batch in fragment.to_batches():
+                dict_batch = list(iter_tuples(batch))
+                tab.insert(dict_batch)
+    except Exception as e:
+        _logger.error(f'Error while inserting Parquet file into table: {e}')
+        raise e
+
+    cl.move(tmp_name, table_path)
+    return cl.get_table(table_path)
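A hypothetical invocation (the parquet path and overridden column name are assumptions) that mirrors the huggingface loader:

import pixeltable
import pixeltable.type_system as ts
from pixeltable.utils.parquet import import_parquet, parquet_schema_to_pixeltable_schema

cl = pixeltable.Client()

# inspect the inferred schema first; arrow types without a mapping come back as None
print(parquet_schema_to_pixeltable_schema('~/data/events.parquet'))

tab = import_parquet(
    cl, 'events',
    parquet_path='~/data/events.parquet',
    schema_override={'payload': ts.StringType(nullable=True)},  # force a type for an unmapped column
)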