pixeltable 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pixeltable might be problematic.

pixeltable/type_system.py CHANGED
@@ -6,9 +6,8 @@ import enum
 import json
 import typing
 import urllib.parse
-from copy import copy
 from pathlib import Path
-from typing import Any, Optional, Tuple, Dict, Callable, List, Union
+from typing import Any, Optional, Tuple, Dict, Callable, List, Union, Sequence, Mapping
 
 import PIL.Image
 import av
@@ -240,19 +239,38 @@ class ColumnType:
 
     @classmethod
     def from_python_type(cls, t: type) -> Optional[ColumnType]:
-        if t in _python_type_to_column_type:
-            return _python_type_to_column_type[t]
-        elif isinstance(t, typing._UnionGenericAlias) and t.__args__[1] is type(None):
-            # `t` is a type of the form Optional[T] (equivalently, Union[T, None]).
-            # We treat it as the underlying type but with nullable=True.
-            if t.__args__[0] in _python_type_to_column_type:
-                underlying = copy(_python_type_to_column_type[t.__args__[0]])
-                underlying.nullable = True
-                return underlying
-
+        if typing.get_origin(t) is typing.Union:
+            union_args = typing.get_args(t)
+            if union_args[1] is type(None):
+                # `t` is a type of the form Optional[T] (equivalently, Union[T, None]).
+                # We treat it as the underlying type but with nullable=True.
+                underlying = cls.from_python_type(union_args[0])
+                if underlying is not None:
+                    underlying.nullable = True
+                    return underlying
+        else:
+            # Discard type parameters to ensure that parameterized types such as `list[T]`
+            # are correctly mapped to Pixeltable types.
+            base = typing.get_origin(t)
+            if base is None:
+                # No type parameters; the base type is just `t` itself
+                base = t
+            if base is str:
+                return StringType()
+            if base is int:
+                return IntType()
+            if base is float:
+                return FloatType()
+            if base is bool:
+                return BoolType()
+            if base is datetime.date or base is datetime.datetime:
+                return TimestampType()
+            if issubclass(base, Sequence) or issubclass(base, Mapping):
+                return JsonType()
+            if issubclass(base, PIL.Image.Image):
+                return ImageType()
         return None
 
-
     def validate_literal(self, val: Any) -> None:
         """Raise TypeError if val is not a valid literal for this type"""
         if val is None:
@@ -383,10 +401,6 @@ class ColumnType:
             return sql.VARBINARY
        assert False
 
-    @abc.abstractmethod
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        assert False, f'Have not implemented {self.__class__.__name__} to Arrow'
-
    @staticmethod
    def no_conversion(v: Any) -> Any:
        """
@@ -413,9 +427,6 @@ class InvalidType(ColumnType):
    def to_sa_type(self) -> Any:
        assert False
 
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        assert False
-
    def print_value(self, val: Any) -> str:
        assert False
 
@@ -442,10 +453,6 @@ class StringType(ColumnType):
 
    def to_sa_type(self) -> str:
        return sql.String
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.string()
 
    def print_value(self, val: Any) -> str:
        return f"'{val}'"
@@ -454,6 +461,13 @@ class StringType(ColumnType):
        if not isinstance(val, str):
            raise TypeError(f'Expected string, got {val.__class__.__name__}')
 
+    def _create_literal(self, val: Any) -> Any:
+        # Replace null byte within python string with space to avoid issues with Postgres.
+        # Use a space to avoid merging words.
+        # TODO(orm): this will also be an issue with JSON inputs, would space still be a good replacement?
+        if isinstance(val, str) and '\x00' in val:
+            return val.replace('\x00', ' ')
+        return val
 
 class IntType(ColumnType):
    def __init__(self, nullable: bool = False):
@@ -464,10 +478,6 @@ class IntType(ColumnType):
 
    def to_sa_type(self) -> str:
        return sql.BigInteger
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.int64()  # to be consistent with bigint above
 
    def _validate_literal(self, val: Any) -> None:
        if not isinstance(val, int):
@@ -483,10 +493,6 @@ class FloatType(ColumnType):
 
    def to_sa_type(self) -> str:
        return sql.Float
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa
-        return pa.float32()
 
    def _validate_literal(self, val: Any) -> None:
        if not isinstance(val, float):
@@ -506,10 +512,6 @@ class BoolType(ColumnType):
 
    def to_sa_type(self) -> str:
        return sql.Boolean
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.bool_()
 
    def _validate_literal(self, val: Any) -> None:
        if not isinstance(val, bool):
@@ -529,10 +531,6 @@ class TimestampType(ColumnType):
 
    def to_sa_type(self) -> str:
        return sql.TIMESTAMP
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.timestamp('us')  # postgres timestamp is microseconds
 
    def _validate_literal(self, val: Any) -> None:
        if not isinstance(val, datetime.datetime) and not isinstance(val, datetime.date):
@@ -570,10 +568,6 @@ class JsonType(ColumnType):
 
    def to_sa_type(self) -> str:
        return sql.dialects.postgresql.JSONB
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.string()  # TODO: weight advantage of pa.struct type.
 
    def print_value(self, val: Any) -> str:
        val_type = self.infer_literal_type(val)
@@ -669,7 +663,9 @@ class ArrayType(ColumnType):
 
    def _create_literal(self, val: Any) -> Any:
        if isinstance(val, (list,tuple)):
-            return np.array(val)
+            # map python float to whichever numpy float is
+            # declared for this type, rather than assume float64
+            return np.array(val, dtype=self.numpy_dtype())
        return val
 
    def to_sql(self) -> str:
@@ -677,12 +673,6 @@
 
    def to_sa_type(self) -> str:
        return sql.LargeBinary
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        if any([n is None for n in self.shape]):
-            raise TypeError(f'Cannot convert array with unknown shape to Arrow')
-        return pa.fixed_shape_tensor(pa.from_numpy_dtype(self.numpy_dtype()), self.shape)
 
    def numpy_dtype(self) -> np.dtype:
        if self.dtype == self.Type.INT:
@@ -788,10 +778,6 @@ class ImageType(ColumnType):
 
    def to_sa_type(self) -> str:
        return sql.String
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.binary()
 
    def _validate_literal(self, val: Any) -> None:
        if isinstance(val, PIL.Image.Image):
@@ -815,10 +801,6 @@ class VideoType(ColumnType):
 
    def to_sa_type(self) -> str:
        return sql.String
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.string()
 
    def _validate_literal(self, val: Any) -> None:
        self._validate_file_path(val)
@@ -854,10 +836,6 @@ class AudioType(ColumnType):
    def to_sa_type(self) -> str:
        return sql.String
 
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.string()
-
    def _validate_literal(self, val: Any) -> None:
        self._validate_file_path(val)
 
@@ -901,10 +879,6 @@ class DocumentType(ColumnType):
    def to_sa_type(self) -> str:
        return sql.String
 
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.string()
-
    def _validate_literal(self, val: Any) -> None:
        self._validate_file_path(val)
 
@@ -919,20 +893,3 @@ class DocumentType(ColumnType):
            raise excs.Error(f'Not a recognized document format: {val}')
        except Exception as e:
            raise excs.Error(f'Not a recognized document format: {val}') from None
-
-
-# A dictionary mapping various Python types to their respective ColumnTypes.
-# This can be used to infer Pixeltable ColumnTypes from type hints on Python
-# functions. (Since Python functions do not necessarily have type hints, this
-# should always be an optional/convenience inference.)
-_python_type_to_column_type: dict[type, ColumnType] = {
-    str: StringType(),
-    int: IntType(),
-    float: FloatType(),
-    bool: BoolType(),
-    datetime.datetime: TimestampType(),
-    datetime.date: TimestampType(),
-    list: JsonType(),
-    dict: JsonType(),
-    PIL.Image.Image: ImageType()
-}
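
The rewritten `from_python_type` first unwraps `Optional[T]` (recursing on `T` and marking the result nullable), then dispatches on `typing.get_origin`, so parameterized generics map by their base type instead of missing the old dict lookup. A minimal sketch of the resulting behavior (return values in comments are illustrative):

import typing
from pixeltable.type_system import ColumnType

ColumnType.from_python_type(str)                   # StringType()
ColumnType.from_python_type(typing.Optional[int])  # IntType with nullable=True
ColumnType.from_python_type(list[float])           # JsonType(): origin `list` is a Sequence
ColumnType.from_python_type(dict[str, int])        # JsonType(): origin `dict` is a Mapping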
pixeltable/utils/arrow.py ADDED

@@ -0,0 +1,98 @@
+import logging
+from typing import Any, Dict, Iterable, Iterator, Optional
+
+import pyarrow as pa
+
+import pixeltable.type_system as ts
+
+_logger = logging.getLogger(__name__)
+
+_pa_to_pt: Dict[pa.DataType, ts.ColumnType] = {
+    pa.string(): ts.StringType(nullable=True),
+    pa.timestamp('us'): ts.TimestampType(nullable=True),
+    pa.bool_(): ts.BoolType(nullable=True),
+    pa.uint8(): ts.IntType(nullable=True),
+    pa.int8(): ts.IntType(nullable=True),
+    pa.uint32(): ts.IntType(nullable=True),
+    pa.uint64(): ts.IntType(nullable=True),
+    pa.int32(): ts.IntType(nullable=True),
+    pa.int64(): ts.IntType(nullable=True),
+    pa.float32(): ts.FloatType(nullable=True),
+}
+
+_pt_to_pa: Dict[ts.ColumnType, pa.DataType] = {
+    ts.StringType: pa.string(),
+    ts.TimestampType: pa.timestamp('us'),  # postgres timestamp is microseconds
+    ts.BoolType: pa.bool_(),
+    ts.IntType: pa.int64(),
+    ts.FloatType: pa.float32(),
+    ts.JsonType: pa.string(),  # TODO(orm) pa.struct() is possible
+    ts.ImageType: pa.binary(),  # inline image
+    ts.AudioType: pa.string(),  # path
+    ts.VideoType: pa.string(),  # path
+    ts.DocumentType: pa.string(),  # path
+}
+
+
+def to_pixeltable_type(arrow_type: pa.DataType) -> Optional[ts.ColumnType]:
+    """Convert a pyarrow DataType to a pixeltable ColumnType if one is defined.
+    Returns None if no conversion is currently implemented.
+    """
+    if arrow_type in _pa_to_pt:
+        return _pa_to_pt[arrow_type]
+    elif isinstance(arrow_type, pa.FixedShapeTensorType):
+        dtype = to_pixeltable_type(arrow_type.value_type)
+        if dtype is None:
+            return None
+        return ts.ArrayType(shape=arrow_type.shape, dtype=dtype)
+    else:
+        return None
+
+
+def to_arrow_type(pixeltable_type: ts.ColumnType) -> Optional[pa.DataType]:
+    """Convert a pixeltable DataType to a pyarrow datatype if one is defined.
+    Returns None if no conversion is currently implemented.
+    """
+    if pixeltable_type.__class__ in _pt_to_pa:
+        return _pt_to_pa[pixeltable_type.__class__]
+    elif isinstance(pixeltable_type, ts.ArrayType):
+        return pa.fixed_shape_tensor(pa.from_numpy_dtype(pixeltable_type.numpy_dtype()), pixeltable_type.shape)
+    else:
+        return None
+
+
+def to_pixeltable_schema(arrow_schema: pa.Schema) -> Dict[str, ts.ColumnType]:
+    return {field.name: to_pixeltable_type(field.type) for field in arrow_schema}
+
+
+def to_arrow_schema(pixeltable_schema: Dict[str, Any]) -> pa.Schema:
+    return pa.schema((name, to_arrow_type(typ)) for name, typ in pixeltable_schema.items())
+
+
+def to_pydict(batch: pa.RecordBatch) -> Dict[str, Iterable[Any]]:
+    """Convert a RecordBatch to a dictionary of lists. Unlike pa.lib.RecordBatch.to_pydict,
+    this function will not convert numpy arrays to lists, and will preserve the original numpy dtype.
+    """
+    out = {}
+    for k, name in enumerate(batch.schema.names):
+        col = batch.column(k)
+        if isinstance(col.type, pa.FixedShapeTensorType):
+            # treat array columns as numpy arrays to easily preserve numpy type
+            out[name] = col.to_numpy(zero_copy_only=False)
+        else:
+            # for the rest, use pydict to preserve python types
+            out[name] = col.to_pylist()
+
+    return out
+
+
+def iter_tuples(batch: pa.RecordBatch) -> Iterator[Dict[str, Any]]:
+    """Convert a RecordBatch to an iterator of dictionaries. Also works with pa.Table and pa.RowGroup."""
+    pydict = to_pydict(batch)
+    assert len(pydict) > 0, 'empty record batch'
+    for _, v in pydict.items():
+        batch_size = len(v)
+        break
+
+    for i in range(batch_size):
+        yield {col_name: values[i] for col_name, values in pydict.items()}
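
The module above replaces the per-type `to_arrow_type` methods removed from type_system.py with a single table-driven mapping in both directions. A short usage sketch (values in comments are illustrative):

import pyarrow as pa
from pixeltable.utils.arrow import iter_tuples, to_arrow_type, to_pixeltable_type

pt_type = to_pixeltable_type(pa.int64())  # IntType(nullable=True)
pa_type = to_arrow_type(pt_type)          # pa.int64()

batch = pa.RecordBatch.from_pydict({'a': [1, 2], 'b': ['x', 'y']})
for row in iter_tuples(batch):
    print(row)  # {'a': 1, 'b': 'x'}, then {'a': 2, 'b': 'y'}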
pixeltable/utils/hf_datasets.py ADDED

@@ -0,0 +1,157 @@
+import datasets
+from typing import Union, Optional, List, Dict, Any
+import pixeltable.type_system as ts
+from pixeltable import exceptions as excs
+import math
+import logging
+import pixeltable
+import random
+
+_logger = logging.getLogger(__name__)
+
+# use 100MB as the batch size limit for loading a huggingface dataset into pixeltable.
+# The primary goal is to bound memory use, regardless of dataset size.
+# Second goal is to limit overhead. 100MB is presumed to be reasonable for a lot of storage systems.
+_K_BATCH_SIZE_BYTES = 100_000_000
+
+# note, there are many more types. we allow overrides in the schema_override parameter
+# to handle cases where the appropriate type is not yet mapped, or to override this mapping.
+# https://huggingface.co/docs/datasets/v2.17.0/en/package_reference/main_classes#datasets.Value
+_hf_to_pxt: Dict[str, ts.ColumnType] = {
+    'int32': ts.IntType(nullable=True),  # pixeltable widens to big int
+    'int64': ts.IntType(nullable=True),
+    'bool': ts.BoolType(nullable=True),
+    'float32': ts.FloatType(nullable=True),
+    'string': ts.StringType(nullable=True),
+    'timestamp[s]': ts.TimestampType(nullable=True),
+    'timestamp[ms]': ts.TimestampType(nullable=True),  # HF dataset iterator converts timestamps to datetime.datetime
+}
+
+def _to_pixeltable_type(
+    feature_type: Union[datasets.ClassLabel, datasets.Value, datasets.Sequence],
+) -> Optional[ts.ColumnType]:
+    """Convert a huggingface feature type to a pixeltable ColumnType if one is defined."""
+    if isinstance(feature_type, datasets.ClassLabel):
+        # enum, example: ClassLabel(names=['neg', 'pos'], id=None)
+        return ts.StringType(nullable=True)
+    elif isinstance(feature_type, datasets.Value):
+        # example: Value(dtype='int64', id=None)
+        return _hf_to_pxt.get(feature_type.dtype, None)
+    elif isinstance(feature_type, datasets.Sequence):
+        # example: cohere wiki. Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None)
+        dtype = _to_pixeltable_type(feature_type.feature)
+        length = feature_type.length if feature_type.length != -1 else None
+        return ts.ArrayType(shape=(length,), dtype=dtype)
+    else:
+        return None
+
+def _get_hf_schema(dataset: Union[datasets.Dataset, datasets.DatasetDict]) -> datasets.Features:
+    """Get the schema of a huggingface dataset as a dictionary."""
+    first_dataset = dataset if isinstance(dataset, datasets.Dataset) else next(iter(dataset.values()))
+    return first_dataset.features
+
+def huggingface_schema_to_pixeltable_schema(
+    hf_dataset: Union[datasets.Dataset, datasets.DatasetDict],
+) -> Dict[str, Optional[ts.ColumnType]]:
+    """Generate a pixeltable schema from a huggingface dataset schema.
+    Columns without a known mapping are mapped to None
+    """
+    hf_schema = _get_hf_schema(hf_dataset)
+    pixeltable_schema = {
+        column_name: _to_pixeltable_type(feature_type) for column_name, feature_type in hf_schema.items()
+    }
+    return pixeltable_schema
+
+def import_huggingface_dataset(
+    cl: 'pixeltable.Client',
+    table_path: str,
+    dataset: Union[datasets.Dataset, datasets.DatasetDict],
+    *,
+    column_name_for_split: Optional[str],
+    schema_override: Optional[Dict[str, Any]],
+    **kwargs,
+) -> 'pixeltable.InsertableTable':
+    """See `pixeltable.Client.import_huggingface_dataset` for documentation"""
+    if table_path in cl.list_tables():
+        raise excs.Error(f'table {table_path} already exists')
+
+    if not isinstance(dataset, (datasets.Dataset, datasets.DatasetDict)):
+        raise excs.Error(f'`type(dataset)` must be `datasets.Dataset` or `datasets.DatasetDict`. Got {type(dataset)=}')
+
+    if isinstance(dataset, datasets.Dataset):
+        # when loading an hf dataset partially, dataset.split._name is sometimes the form "train[0:1000]"
+        raw_name = dataset.split._name
+        split_name = raw_name.split('[')[0] if raw_name is not None else None
+        dataset_dict = {split_name: dataset}
+    else:
+        dataset_dict = dataset
+
+    pixeltable_schema = huggingface_schema_to_pixeltable_schema(dataset)
+    if schema_override is not None:
+        pixeltable_schema.update(schema_override)
+
+    if column_name_for_split is not None:
+        if column_name_for_split in pixeltable_schema:
+            raise excs.Error(
+                f'Column name `{column_name_for_split}` already exists in dataset schema; provide a different `column_name_for_split`'
+            )
+        pixeltable_schema[column_name_for_split] = ts.StringType(nullable=True)
+
+    for field, column_type in pixeltable_schema.items():
+        if column_type is None:
+            raise excs.Error(f'Could not infer pixeltable type for feature `{field}` in huggingface dataset')
+
+    if isinstance(dataset, datasets.Dataset):
+        # when loading an hf dataset partially, dataset.split._name is sometimes the form "train[0:1000]"
+        raw_name = dataset.split._name
+        split_name = raw_name.split('[')[0] if raw_name is not None else None
+        dataset_dict = {split_name: dataset}
+    elif isinstance(dataset, datasets.DatasetDict):
+        dataset_dict = dataset
+    else:
+        raise excs.Error(f'`type(dataset)` must be `datasets.Dataset` or `datasets.DatasetDict`. Got {type(dataset)=}')
+
+    # extract all class labels from the dataset to translate category ints to strings
+    hf_schema = _get_hf_schema(dataset)
+    categorical_features = {
+        feature_name: feature_type.names
+        for (feature_name, feature_type) in hf_schema.items()
+        if isinstance(feature_type, datasets.ClassLabel)
+    }
+
+    try:
+        # random tmp name
+        tmp_name = f'{table_path}_tmp_{random.randint(0, 100000000)}'
+        tab = cl.create_table(tmp_name, pixeltable_schema, **kwargs)
+
+        def _translate_row(row: Dict[str, Any], split_name: str) -> Dict[str, Any]:
+            output_row = row.copy()
+            # map all class labels to strings
+            for field, values in categorical_features.items():
+                output_row[field] = values[row[field]]
+            # add split name to row
+            if column_name_for_split is not None:
+                output_row[column_name_for_split] = split_name
+            return output_row
+
+        for split_name, split_dataset in dataset_dict.items():
+            num_batches = split_dataset.size_in_bytes / _K_BATCH_SIZE_BYTES
+            tuples_per_batch = math.ceil(split_dataset.num_rows / num_batches)
+            assert tuples_per_batch > 0
+
+            batch = []
+            for row in split_dataset:
+                batch.append(_translate_row(row, split_name))
+                if len(batch) >= tuples_per_batch:
+                    tab.insert(batch)
+                    batch = []
+            # last batch
+            if len(batch) > 0:
+                tab.insert(batch)
+
+    except Exception as e:
+        _logger.error(f'Error while inserting dataset into table: {tmp_name}')
+        raise e
+
+    cl.move(tmp_name, table_path)
+    return cl.get_table(table_path)
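
Putting it together, importing a dataset looks roughly like this (the dataset name is illustrative, and the module path assumes the file header above):

import datasets
import pixeltable as pxt
from pixeltable.utils.hf_datasets import import_huggingface_dataset

cl = pxt.Client()
ds = datasets.load_dataset('rotten_tomatoes')  # DatasetDict with train/validation/test splits
tab = import_huggingface_dataset(
    cl, 'rotten_tomatoes', ds,
    column_name_for_split='split', schema_override=None,
)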
pixeltable/utils/parquet.py CHANGED

@@ -1,14 +1,24 @@
+import io
 import json
+import logging
+from collections import deque
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+import numpy as np
 import PIL.Image
-import io
-import pyarrow.parquet as pq
 import pyarrow as pa
-import numpy as np
-from pathlib import Path
-from collections import deque
+import pyarrow.parquet
 
-from typing import List, Tuple, Any, Dict
+import pixeltable.type_system as ts
+from pixeltable.utils.arrow import iter_tuples, to_arrow_schema, to_pixeltable_schema
 from pixeltable.utils.transactional_directory import transactional_directory
+import pixeltable.exceptions as exc
+
+import random
+
+_logger = logging.getLogger(__name__)
+
 
 def _write_batch(value_batch : Dict[str, deque], schema : pa.Schema, output_path : Path) -> None:
     pydict = {}
@@ -20,18 +30,18 @@ def _write_batch(value_batch : Dict[str, deque], schema : pa.Schema, output_path : Path) -> None:
         pydict[field.name] = value_batch[field.name]
 
     tab = pa.Table.from_pydict(pydict, schema=schema)
-    pq.write_table(tab, output_path)
+    pa.parquet.write_table(tab, output_path)
 
 def save_parquet(df: 'pixeltable.DataFrame', dest_path: Path, partition_size_bytes: int = 100_000_000) -> None:
     """
     Internal method to stream dataframe data to parquet format.
     Does not materialize the dataset to memory.
 
-    It preserves pixeltable type metadata in a json file, which would otherwise 
+    It preserves pixeltable type metadata in a json file, which would otherwise
     not be available in the parquet format.
 
     Images are stored inline in a compressed format in their parquet file.
-    
+
     Args:
         df : dataframe to save.
         dest_path : path to directory to save the parquet files to.
@@ -39,10 +49,9 @@ def save_parquet(df: 'pixeltable.DataFrame', dest_path: Path, partition_size_bytes: int = 100_000_000) -> None:
    """
    column_names = df.get_column_names()
    column_types = df.get_column_types()
-
    type_dict = {k: v.as_dict() for k, v in zip(column_names, column_types)}
+    arrow_schema = to_arrow_schema(dict(zip(column_names, column_types)))
 
-    arrow_schema = pa.schema([pa.field(name, column_types[i].to_arrow_type()) for i, name in enumerate(column_names)])
    # store the changes atomically
    with transactional_directory(dest_path) as temp_path:
        # dump metadata json file so we can inspect what was the source of the parquet file later on.
@@ -65,8 +74,9 @@ def save_parquet(df: 'pixeltable.DataFrame', dest_path: Path, partition_size_bytes: int = 100_000_000) -> None:
                # images get inlined into the parquet file
                if data_row.file_paths is not None and data_row.file_paths[e.slot_idx] is not None:
                    # if there is a file, read directly to preserve information
-                    val = open(data_row.file_paths[e.slot_idx], 'rb').read()
-                elif isinstance(val, PIL.Image.Image):
+                    with open(data_row.file_paths[e.slot_idx], 'rb') as f:
+                        val = f.read()
+                elif isinstance(val, PIL.Image.Image):
                    # if no file available, eg. bc it is computed, convert to png
                    buf = io.BytesIO()
                    val.save(buf, format='PNG')
@@ -109,18 +119,49 @@ def save_parquet(df: 'pixeltable.DataFrame', dest_path: Path, partition_size_bytes: int = 100_000_000) -> None:
 
        _write_batch(current_value_batch, arrow_schema, temp_path / f'part-{batch_num:05d}.parquet')
 
-def get_part_metadata(path : Path) -> List[Tuple[str, int]]:
-    """
-    Args:
-        path: path to directory containing parquet files.
-    Returns:
-        A list of (file_name, num_rows) tuples for the parquet files in file name order.
-    """
-    parts = sorted([f for f in path.iterdir() if f.suffix == '.parquet'])
-    rows_per_file = {}
-
-    for part in parts:
-        parquet_file = pq.ParquetFile(str(part))
-        rows_per_file[part] = parquet_file.metadata.num_rows
 
-    return [(file, num_rows) for file, num_rows in rows_per_file.items()]
+def parquet_schema_to_pixeltable_schema(parquet_path: str) -> Dict[str, Optional[ts.ColumnType]]:
+    """Generate a default pixeltable schema for the given parquet file. Returns None for unknown types."""
+
+    input_path = Path(parquet_path).expanduser()
+    parquet_dataset = pa.parquet.ParquetDataset(input_path)
+    return to_pixeltable_schema(parquet_dataset.schema)
+
+
+def import_parquet(
+    cl: 'pixeltable.Client',
+    table_path: str,
+    *,
+    parquet_path: str,
+    schema_override: Optional[Dict[str, ts.ColumnType]],
+    **kwargs,
+) -> 'catalog.InsertableTable':
+    """See `pixeltable.Client.import_parquet` for documentation"""
+    input_path = Path(parquet_path).expanduser()
+    parquet_dataset = pa.parquet.ParquetDataset(input_path)
+
+    schema = parquet_schema_to_pixeltable_schema(parquet_path)
+    if schema_override is None:
+        schema_override = {}
+
+    schema.update(schema_override)
+    for k, v in schema.items():
+        if v is None:
+            raise exc.Error(f'Could not infer pixeltable type for column {k} from parquet file')
+
+    if table_path in cl.list_tables():
+        raise exc.Error(f'Table {table_path} already exists')
+
+    try:
+        tmp_name = f'{table_path}_tmp_{random.randint(0, 100000000)}'
+        tab = cl.create_table(tmp_name, schema, **kwargs)
+        for fragment in parquet_dataset.fragments:
+            for batch in fragment.to_batches():
+                dict_batch = list(iter_tuples(batch))
+                tab.insert(dict_batch)
+    except Exception as e:
+        _logger.error(f'Error while inserting Parquet file into table: {e}')
+        raise e
+
+    cl.move(tmp_name, table_path)
+    return cl.get_table(table_path)
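
And the equivalent sketch for the new Parquet import path (the path and table name are hypothetical; the path should point at a directory of .parquet files):

import pixeltable as pxt
from pixeltable.utils.parquet import import_parquet

cl = pxt.Client()
tab = import_parquet(
    cl, 'my_parquet_table',
    parquet_path='~/data/my_dataset',  # hypothetical directory of parquet parts
    schema_override=None,
)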