pixeltable-0.2.3-py3-none-any.whl → pixeltable-0.2.5-py3-none-any.whl
Potentially problematic release.
- pixeltable/catalog/column.py +26 -49
- pixeltable/catalog/insertable_table.py +7 -4
- pixeltable/catalog/table.py +163 -57
- pixeltable/catalog/table_version.py +416 -140
- pixeltable/catalog/table_version_path.py +2 -2
- pixeltable/client.py +72 -6
- pixeltable/dataframe.py +65 -21
- pixeltable/env.py +52 -53
- pixeltable/exec/cache_prefetch_node.py +1 -1
- pixeltable/exec/in_memory_data_node.py +11 -7
- pixeltable/exprs/comparison.py +3 -3
- pixeltable/exprs/data_row.py +5 -1
- pixeltable/exprs/literal.py +16 -4
- pixeltable/exprs/row_builder.py +8 -40
- pixeltable/ext/__init__.py +5 -0
- pixeltable/ext/functions/yolox.py +92 -0
- pixeltable/func/aggregate_function.py +15 -15
- pixeltable/func/expr_template_function.py +9 -1
- pixeltable/func/globals.py +24 -14
- pixeltable/func/signature.py +18 -12
- pixeltable/func/udf.py +7 -2
- pixeltable/functions/__init__.py +9 -9
- pixeltable/functions/eval.py +7 -8
- pixeltable/functions/fireworks.py +10 -37
- pixeltable/functions/huggingface.py +47 -19
- pixeltable/functions/openai.py +192 -24
- pixeltable/functions/together.py +104 -9
- pixeltable/functions/util.py +11 -0
- pixeltable/index/__init__.py +2 -0
- pixeltable/index/base.py +49 -0
- pixeltable/index/embedding_index.py +95 -0
- pixeltable/metadata/schema.py +45 -22
- pixeltable/plan.py +15 -34
- pixeltable/store.py +38 -41
- pixeltable/tests/conftest.py +8 -14
- pixeltable/tests/ext/test_yolox.py +21 -0
- pixeltable/tests/functions/test_fireworks.py +43 -0
- pixeltable/tests/functions/test_functions.py +60 -0
- pixeltable/tests/{test_functions.py → functions/test_huggingface.py} +7 -143
- pixeltable/tests/functions/test_openai.py +162 -0
- pixeltable/tests/functions/test_together.py +112 -0
- pixeltable/tests/test_component_view.py +14 -5
- pixeltable/tests/test_dataframe.py +23 -22
- pixeltable/tests/test_exprs.py +99 -102
- pixeltable/tests/test_function.py +51 -43
- pixeltable/tests/test_index.py +138 -0
- pixeltable/tests/test_migration.py +2 -1
- pixeltable/tests/test_snapshot.py +24 -1
- pixeltable/tests/test_table.py +205 -26
- pixeltable/tests/test_types.py +30 -0
- pixeltable/tests/test_video.py +16 -16
- pixeltable/tests/test_view.py +5 -0
- pixeltable/tests/utils.py +171 -14
- pixeltable/tool/create_test_db_dump.py +16 -0
- pixeltable/type_system.py +77 -128
- pixeltable/utils/arrow.py +98 -0
- pixeltable/utils/hf_datasets.py +157 -0
- pixeltable/utils/parquet.py +68 -27
- pixeltable/utils/pytorch.py +16 -97
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/METADATA +35 -28
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/RECORD +63 -50
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/WHEEL +0 -0
pixeltable/type_system.py
CHANGED
@@ -6,9 +6,10 @@ import enum
 import json
 import typing
 import urllib.parse
+import urllib.request
 from copy import copy
 from pathlib import Path
-from typing import Any, Optional, Tuple, Dict, Callable, List, Union
+from typing import Any, Optional, Tuple, Dict, Callable, List, Union, Sequence, Mapping

 import PIL.Image
 import av
@@ -240,19 +241,38 @@ class ColumnType:

     @classmethod
     def from_python_type(cls, t: type) -> Optional[ColumnType]:
-        if t …
-        …
-        …
-        …
-        …
-        …
-        underlying …
-        …
-        …
-        …
+        if typing.get_origin(t) is typing.Union:
+            union_args = typing.get_args(t)
+            if union_args[1] is type(None):
+                # `t` is a type of the form Optional[T] (equivalently, Union[T, None]).
+                # We treat it as the underlying type but with nullable=True.
+                underlying = cls.from_python_type(union_args[0])
+                if underlying is not None:
+                    underlying.nullable = True
+                return underlying
+        else:
+            # Discard type parameters to ensure that parameterized types such as `list[T]`
+            # are correctly mapped to Pixeltable types.
+            base = typing.get_origin(t)
+            if base is None:
+                # No type parameters; the base type is just `t` itself
+                base = t
+            if base is str:
+                return StringType()
+            if base is int:
+                return IntType()
+            if base is float:
+                return FloatType()
+            if base is bool:
+                return BoolType()
+            if base is datetime.date or base is datetime.datetime:
+                return TimestampType()
+            if issubclass(base, Sequence) or issubclass(base, Mapping):
+                return JsonType()
+            if issubclass(base, PIL.Image.Image):
+                return ImageType()
         return None

-
     def validate_literal(self, val: Any) -> None:
         """Raise TypeError if val is not a valid literal for this type"""
         if val is None:
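The rewritten `from_python_type` supersedes the module-level `_python_type_to_column_type` lookup table that the last hunk of this file removes, adding support for `Optional[T]` and for parameterized generics. A minimal sketch of the standard-library `typing` introspection it relies on:

```python
import typing
from typing import Optional

# Optional[str] is Union[str, None]; get_origin/get_args expose both halves,
# which is how the hunk above detects nullability.
print(typing.get_origin(Optional[str]))  # typing.Union
print(typing.get_args(Optional[str]))    # (<class 'str'>, <class 'NoneType'>)

# get_origin strips type parameters, so list[int] resolves to list and,
# per the mapping above, to JsonType (list is a Sequence).
print(typing.get_origin(list[int]))      # <class 'list'>
print(typing.get_origin(int))            # None: unparameterized types fall through to `base = t`
```

Note that the `base is str` check precedes the `Sequence` check, so strings are not swallowed by the `JsonType` branch.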
@@ -275,7 +295,7 @@ class ColumnType:
         parsed = urllib.parse.urlparse(val)
         if parsed.scheme != '' and parsed.scheme != 'file':
             return
-        path = Path(urllib.parse.unquote(parsed.path))
+        path = Path(urllib.parse.unquote(urllib.request.url2pathname(parsed.path)))
         if not path.is_file():
             raise TypeError(f'File not found: {str(path)}')
         else:
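The added `url2pathname` call is what makes `file://` URL validation portable: on Windows the path component of a file URL keeps a leading slash in front of the drive letter (`/C:/...`), which `url2pathname` strips while converting to a native path. A small illustration mirroring the changed line (POSIX output shown; the path is hypothetical):

```python
from pathlib import Path
import urllib.parse
import urllib.request

url = 'file:///tmp/images/cat%20photo.png'  # hypothetical local file URL
parsed = urllib.parse.urlparse(url)
# Same expression as the new line in the file-path validation above:
path = Path(urllib.parse.unquote(urllib.request.url2pathname(parsed.path)))
print(path)  # /tmp/images/cat photo.png  (on Windows, 'file:///C:/x.png' -> C:\x.png)
```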
@@ -358,35 +378,12 @@ class ColumnType:
         pass

     @abc.abstractmethod
-    def to_sa_type(self) -> Any:
+    def to_sa_type(self) -> sql.types.TypeEngine:
         """
         Return corresponding SQLAlchemy type.
-        return type Any: there doesn't appear to be a superclass for the sqlalchemy types
         """
-
-        if self._type == self.Type.STRING:
-            return sql.String
-        if self._type == self.Type.INT:
-            return sql.Integer
-        if self._type == self.Type.FLOAT:
-            return sql.Float
-        if self._type == self.Type.BOOL:
-            return sql.Boolean
-        if self._type == self.Type.TIMESTAMP:
-            return sql.TIMESTAMP
-        if self._type == self.Type.IMAGE:
-            # the URL
-            return sql.String
-        if self._type == self.Type.JSON:
-            return sql.dialects.postgresql.JSONB
-        if self._type == self.Type.ARRAY:
-            return sql.VARBINARY
-        assert False
+        pass

-    @abc.abstractmethod
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        assert False, f'Have not implemented {self.__class__.__name__} to Arrow'
-
     @staticmethod
     def no_conversion(v: Any) -> Any:
         """
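Two things change here: `to_sa_type` gains a real return annotation (`sql.types.TypeEngine` is the common base of SQLAlchemy's type objects, which the removed docstring line believed didn't exist) and becomes a true abstract method, with the enum dispatch moving into per-subclass overrides that return instances rather than classes. The removed `to_arrow_type` methods reappear as free functions in the new `pixeltable/utils/arrow.py` (see the added file at the end of this diff). A minimal sketch of the new shape, with abbreviated names:

```python
import abc
import sqlalchemy as sql

class ColumnTypeSketch(abc.ABC):
    @abc.abstractmethod
    def to_sa_type(self) -> sql.types.TypeEngine:
        """Return corresponding SQLAlchemy type."""

class StringTypeSketch(ColumnTypeSketch):
    def to_sa_type(self) -> sql.types.TypeEngine:
        return sql.String()  # an instance, not the sql.String class as before
```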
@@ -410,10 +407,7 @@ class InvalidType(ColumnType):
     def to_sql(self) -> str:
         assert False

-    def to_sa_type(self) -> Any:
-        assert False
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
+    def to_sa_type(self) -> sql.types.TypeEngine:
         assert False

     def print_value(self, val: Any) -> str:
@@ -422,6 +416,7 @@ class InvalidType(ColumnType):
     def _validate_literal(self, val: Any) -> None:
         assert False

+
 class StringType(ColumnType):
     def __init__(self, nullable: bool = False):
         super().__init__(self.Type.STRING, nullable=nullable)
@@ -440,12 +435,8 @@ class StringType(ColumnType):
     def to_sql(self) -> str:
         return 'VARCHAR'

-    def to_sa_type(self) -> Any:
-        return sql.String
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.string()
+    def to_sa_type(self) -> sql.types.TypeEngine:
+        return sql.String()

     def print_value(self, val: Any) -> str:
         return f"'{val}'"
@@ -454,6 +445,14 @@ class StringType(ColumnType):
         if not isinstance(val, str):
             raise TypeError(f'Expected string, got {val.__class__.__name__}')

+    def _create_literal(self, val: Any) -> Any:
+        # Replace null byte within python string with space to avoid issues with Postgres.
+        # Use a space to avoid merging words.
+        # TODO(orm): this will also be an issue with JSON inputs, would space still be a good replacement?
+        if isinstance(val, str) and '\x00' in val:
+            return val.replace('\x00', ' ')
+        return val
+

 class IntType(ColumnType):
     def __init__(self, nullable: bool = False):
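For context on the new `_create_literal` override: PostgreSQL rejects text values containing NUL (0x00) bytes, so string literals are sanitized before they reach the database. A hypothetical demonstration of the replacement rule from the hunk above:

```python
# A string literal as it might arrive from scraped or binary-adjacent input.
val = 'hello\x00world'
# Same rule as StringType._create_literal: swap NUL for a space so words
# on either side of the byte don't merge.
if isinstance(val, str) and '\x00' in val:
    val = val.replace('\x00', ' ')
print(repr(val))  # 'hello world'
```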
@@ -462,12 +461,8 @@ class IntType(ColumnType):
     def to_sql(self) -> str:
         return 'BIGINT'

-    def to_sa_type(self) -> Any:
-        return sql.BigInteger
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.int64()  # to be consistent with bigint above
+    def to_sa_type(self) -> sql.types.TypeEngine:
+        return sql.BigInteger()

     def _validate_literal(self, val: Any) -> None:
         if not isinstance(val, int):
@@ -481,12 +476,8 @@ class FloatType(ColumnType):
     def to_sql(self) -> str:
         return 'FLOAT'

-    def to_sa_type(self) -> Any:
-        return sql.Float
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa
-        return pa.float32()
+    def to_sa_type(self) -> sql.types.TypeEngine:
+        return sql.Float()

     def _validate_literal(self, val: Any) -> None:
         if not isinstance(val, float):
@@ -497,6 +488,7 @@ class FloatType(ColumnType):
             return float(val)
         return val

+
 class BoolType(ColumnType):
     def __init__(self, nullable: bool = False):
         super().__init__(self.Type.BOOL, nullable=nullable)
@@ -504,12 +496,8 @@ class BoolType(ColumnType):
     def to_sql(self) -> str:
         return 'BOOLEAN'

-    def to_sa_type(self) -> Any:
-        return sql.Boolean
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.bool_()
+    def to_sa_type(self) -> sql.types.TypeEngine:
+        return sql.Boolean()

     def _validate_literal(self, val: Any) -> None:
         if not isinstance(val, bool):
@@ -520,6 +508,7 @@ class BoolType(ColumnType):
             return bool(val)
         return val

+
 class TimestampType(ColumnType):
     def __init__(self, nullable: bool = False):
         super().__init__(self.Type.TIMESTAMP, nullable=nullable)
@@ -527,12 +516,8 @@ class TimestampType(ColumnType):
     def to_sql(self) -> str:
         return 'INTEGER'

-    def to_sa_type(self) -> Any:
-        return sql.TIMESTAMP
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.timestamp('us')  # postgres timestamp is microseconds
+    def to_sa_type(self) -> sql.types.TypeEngine:
+        return sql.TIMESTAMP()

     def _validate_literal(self, val: Any) -> None:
         if not isinstance(val, datetime.datetime) and not isinstance(val, datetime.date):
@@ -543,6 +528,7 @@ class TimestampType(ColumnType):
             return datetime.datetime.fromisoformat(val)
         return val

+
 class JsonType(ColumnType):
     # TODO: type_spec also needs to be able to express lists
     def __init__(self, type_spec: Optional[Dict[str, ColumnType]] = None, nullable: bool = False):
@@ -568,12 +554,8 @@ class JsonType(ColumnType):
     def to_sql(self) -> str:
         return 'JSONB'

-    def to_sa_type(self) -> Any:
-        return sql.dialects.postgresql.JSONB
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.string()  # TODO: weight advantage of pa.struct type.
+    def to_sa_type(self) -> sql.types.TypeEngine:
+        return sql.dialects.postgresql.JSONB()

     def print_value(self, val: Any) -> str:
         val_type = self.infer_literal_type(val)
@@ -594,6 +576,7 @@ class JsonType(ColumnType):
             val = list(val)
         return val

+
 class ArrayType(ColumnType):
     def __init__(
             self, shape: Tuple[Union[int, None], ...], dtype: ColumnType, nullable: bool = False):
@@ -669,20 +652,16 @@ class ArrayType(ColumnType):

     def _create_literal(self, val: Any) -> Any:
         if isinstance(val, (list,tuple)):
-            …
+            # map python float to whichever numpy float is
+            # declared for this type, rather than assume float64
+            return np.array(val, dtype=self.numpy_dtype())
         return val

     def to_sql(self) -> str:
         return 'BYTEA'

-    def to_sa_type(self) -> Any:
-        return sql.LargeBinary
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        if any([n is None for n in self.shape]):
-            raise TypeError(f'Cannot convert array with unknown shape to Arrow')
-        return pa.fixed_shape_tensor(pa.from_numpy_dtype(self.numpy_dtype()), self.shape)
+    def to_sa_type(self) -> sql.types.TypeEngine:
+        return sql.LargeBinary()

     def numpy_dtype(self) -> np.dtype:
         if self.dtype == self.Type.INT:
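The explicit `dtype=self.numpy_dtype()` matters because NumPy otherwise infers `float64` for Python floats, while a Pixeltable float array is stored as `float32` (compare the `pa.float32()` mapping in the new `utils/arrow.py` below). A quick check of the difference:

```python
import numpy as np

values = [0.1, 0.2, 0.3]
print(np.array(values).dtype)                    # float64: numpy's default inference
print(np.array(values, dtype=np.float32).dtype)  # float32: what numpy_dtype() would request
```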
@@ -786,12 +765,8 @@ class ImageType(ColumnType):
     def to_sql(self) -> str:
         return 'VARCHAR'

-    def to_sa_type(self) -> Any:
-        return sql.String
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.binary()
+    def to_sa_type(self) -> sql.types.TypeEngine:
+        return sql.String()

     def _validate_literal(self, val: Any) -> None:
         if isinstance(val, PIL.Image.Image):
@@ -805,6 +780,7 @@ class ImageType(ColumnType):
         except PIL.UnidentifiedImageError:
             raise excs.Error(f'Not a valid image: {val}') from None

+
 class VideoType(ColumnType):
     def __init__(self, nullable: bool = False):
         super().__init__(self.Type.VIDEO, nullable=nullable)
@@ -813,12 +789,8 @@ class VideoType(ColumnType):
         # stored as a file path
         return 'VARCHAR'

-    def to_sa_type(self) -> Any:
-        return sql.String
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.string()
+    def to_sa_type(self) -> sql.types.TypeEngine:
+        return sql.String()

     def _validate_literal(self, val: Any) -> None:
         self._validate_file_path(val)
@@ -843,6 +815,7 @@ class VideoType(ColumnType):
         except av.AVError:
             raise excs.Error(f'Not a valid video: {val}') from None

+
 class AudioType(ColumnType):
     def __init__(self, nullable: bool = False):
         super().__init__(self.Type.AUDIO, nullable=nullable)
@@ -851,12 +824,8 @@ class AudioType(ColumnType):
         # stored as a file path
         return 'VARCHAR'

-    def to_sa_type(self) -> Any:
-        return sql.String
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.string()
+    def to_sa_type(self) -> sql.types.TypeEngine:
+        return sql.String()

     def _validate_literal(self, val: Any) -> None:
         self._validate_file_path(val)
@@ -876,6 +845,7 @@ class AudioType(ColumnType):
         except av.AVError as e:
             raise excs.Error(f'Not a valid audio file: {val}\n{e}') from None

+
 class DocumentType(ColumnType):
     @enum.unique
     class DocumentFormat(enum.Enum):
@@ -898,12 +868,8 @@ class DocumentType(ColumnType):
         # stored as a file path
         return 'VARCHAR'

-    def to_sa_type(self) -> Any:
-        return sql.String
-
-    def to_arrow_type(self) -> 'pyarrow.DataType':
-        import pyarrow as pa  # pylint: disable=import-outside-toplevel
-        return pa.string()
+    def to_sa_type(self) -> sql.types.TypeEngine:
+        return sql.String()

     def _validate_literal(self, val: Any) -> None:
         self._validate_file_path(val)
@@ -919,20 +885,3 @@ class DocumentType(ColumnType):
             raise excs.Error(f'Not a recognized document format: {val}')
         except Exception as e:
             raise excs.Error(f'Not a recognized document format: {val}') from None
-
-
-# A dictionary mapping various Python types to their respective ColumnTypes.
-# This can be used to infer Pixeltable ColumnTypes from type hints on Python
-# functions. (Since Python functions do not necessarily have type hints, this
-# should always be an optional/convenience inference.)
-_python_type_to_column_type: dict[type, ColumnType] = {
-    str: StringType(),
-    int: IntType(),
-    float: FloatType(),
-    bool: BoolType(),
-    datetime.datetime: TimestampType(),
-    datetime.date: TimestampType(),
-    list: JsonType(),
-    dict: JsonType(),
-    PIL.Image.Image: ImageType()
-}
pixeltable/utils/arrow.py
ADDED
@@ -0,0 +1,98 @@
+import logging
+from typing import Any, Dict, Iterable, Iterator, Optional
+
+import pyarrow as pa
+
+import pixeltable.type_system as ts
+
+_logger = logging.getLogger(__name__)
+
+_pa_to_pt: Dict[pa.DataType, ts.ColumnType] = {
+    pa.string(): ts.StringType(nullable=True),
+    pa.timestamp('us'): ts.TimestampType(nullable=True),
+    pa.bool_(): ts.BoolType(nullable=True),
+    pa.uint8(): ts.IntType(nullable=True),
+    pa.int8(): ts.IntType(nullable=True),
+    pa.uint32(): ts.IntType(nullable=True),
+    pa.uint64(): ts.IntType(nullable=True),
+    pa.int32(): ts.IntType(nullable=True),
+    pa.int64(): ts.IntType(nullable=True),
+    pa.float32(): ts.FloatType(nullable=True),
+}
+
+_pt_to_pa: Dict[ts.ColumnType, pa.DataType] = {
+    ts.StringType: pa.string(),
+    ts.TimestampType: pa.timestamp('us'),  # postgres timestamp is microseconds
+    ts.BoolType: pa.bool_(),
+    ts.IntType: pa.int64(),
+    ts.FloatType: pa.float32(),
+    ts.JsonType: pa.string(),  # TODO(orm) pa.struct() is possible
+    ts.ImageType: pa.binary(),  # inline image
+    ts.AudioType: pa.string(),  # path
+    ts.VideoType: pa.string(),  # path
+    ts.DocumentType: pa.string(),  # path
+}
+
+
+def to_pixeltable_type(arrow_type: pa.DataType) -> Optional[ts.ColumnType]:
+    """Convert a pyarrow DataType to a pixeltable ColumnType if one is defined.
+    Returns None if no conversion is currently implemented.
+    """
+    if arrow_type in _pa_to_pt:
+        return _pa_to_pt[arrow_type]
+    elif isinstance(arrow_type, pa.FixedShapeTensorType):
+        dtype = to_pixeltable_type(arrow_type.value_type)
+        if dtype is None:
+            return None
+        return ts.ArrayType(shape=arrow_type.shape, dtype=dtype)
+    else:
+        return None
+
+
+def to_arrow_type(pixeltable_type: ts.ColumnType) -> Optional[pa.DataType]:
+    """Convert a pixeltable DataType to a pyarrow datatype if one is defined.
+    Returns None if no conversion is currently implemented.
+    """
+    if pixeltable_type.__class__ in _pt_to_pa:
+        return _pt_to_pa[pixeltable_type.__class__]
+    elif isinstance(pixeltable_type, ts.ArrayType):
+        return pa.fixed_shape_tensor(pa.from_numpy_dtype(pixeltable_type.numpy_dtype()), pixeltable_type.shape)
+    else:
+        return None
+
+
+def to_pixeltable_schema(arrow_schema: pa.Schema) -> Dict[str, ts.ColumnType]:
+    return {field.name: to_pixeltable_type(field.type) for field in arrow_schema}
+
+
+def to_arrow_schema(pixeltable_schema: Dict[str, Any]) -> pa.Schema:
+    return pa.schema((name, to_arrow_type(typ)) for name, typ in pixeltable_schema.items())
+
+
+def to_pydict(batch: pa.RecordBatch) -> Dict[str, Iterable[Any]]:
+    """Convert a RecordBatch to a dictionary of lists, unlike pa.lib.RecordBatch.to_pydict,
+    this function will not convert numpy arrays to lists, and will preserve the original numpy dtype.
+    """
+    out = {}
+    for k, name in enumerate(batch.schema.names):
+        col = batch.column(k)
+        if isinstance(col.type, pa.FixedShapeTensorType):
+            # treat array columns as numpy arrays to easily preserve numpy type
+            out[name] = col.to_numpy(zero_copy_only=False)
+        else:
+            # for the rest, use pydict to preserve python types
+            out[name] = col.to_pylist()
+
+    return out
+
+
+def iter_tuples(batch: pa.RecordBatch) -> Iterator[Dict[str, Any]]:
+    """Convert a RecordBatch to an iterator of dictionaries. also works with pa.Table and pa.RowGroup"""
+    pydict = to_pydict(batch)
+    assert len(pydict) > 0, 'empty record batch'
+    for _, v in pydict.items():
+        batch_size = len(v)
+        break
+
+    for i in range(batch_size):
+        yield {col_name: values[i] for col_name, values in pydict.items()}
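A hedged usage sketch of the new module (schema inference and row iteration only; the tensor branch additionally needs a pyarrow version that ships `FixedShapeTensorType`):

```python
import pyarrow as pa
from pixeltable.utils import arrow

# Build a small batch; float32 round-trips to FloatType per _pa_to_pt above.
batch = pa.RecordBatch.from_pydict({
    'name': ['a', 'b'],
    'score': pa.array([1.0, 2.0], type=pa.float32()),
})
print(arrow.to_pixeltable_schema(batch.schema))  # {'name': String, 'score': Float}
for row in arrow.iter_tuples(batch):
    print(row)  # {'name': 'a', 'score': 1.0}, then {'name': 'b', 'score': 2.0}
```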
pixeltable/utils/hf_datasets.py
ADDED
@@ -0,0 +1,157 @@
+import datasets
+from typing import Union, Optional, List, Dict, Any
+import pixeltable.type_system as ts
+from pixeltable import exceptions as excs
+import math
+import logging
+import pixeltable
+import random
+
+_logger = logging.getLogger(__name__)
+
+# use 100MB as the batch size limit for loading a huggingface dataset into pixeltable.
+# The primary goal is to bound memory use, regardless of dataset size.
+# Second goal is to limit overhead. 100MB is presumed to be reasonable for a lot of storage systems.
+_K_BATCH_SIZE_BYTES = 100_000_000
+
+# note, there are many more types. we allow overrides in the schema_override parameter
+# to handle cases where the appropriate type is not yet mapped, or to override this mapping.
+# https://huggingface.co/docs/datasets/v2.17.0/en/package_reference/main_classes#datasets.Value
+_hf_to_pxt: Dict[str, ts.ColumnType] = {
+    'int32': ts.IntType(nullable=True),  # pixeltable widens to big int
+    'int64': ts.IntType(nullable=True),
+    'bool': ts.BoolType(nullable=True),
+    'float32': ts.FloatType(nullable=True),
+    'string': ts.StringType(nullable=True),
+    'timestamp[s]': ts.TimestampType(nullable=True),
+    'timestamp[ms]': ts.TimestampType(nullable=True),  # HF dataset iterator converts timestamps to datetime.datetime
+}
+
+def _to_pixeltable_type(
+    feature_type: Union[datasets.ClassLabel, datasets.Value, datasets.Sequence],
+) -> Optional[ts.ColumnType]:
+    """Convert a huggingface feature type to a pixeltable ColumnType if one is defined."""
+    if isinstance(feature_type, datasets.ClassLabel):
+        # enum, example: ClassLabel(names=['neg', 'pos'], id=None)
+        return ts.StringType(nullable=True)
+    elif isinstance(feature_type, datasets.Value):
+        # example: Value(dtype='int64', id=None)
+        return _hf_to_pxt.get(feature_type.dtype, None)
+    elif isinstance(feature_type, datasets.Sequence):
+        # example: cohere wiki. Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None)
+        dtype = _to_pixeltable_type(feature_type.feature)
+        length = feature_type.length if feature_type.length != -1 else None
+        return ts.ArrayType(shape=(length,), dtype=dtype)
+    else:
+        return None
+
+def _get_hf_schema(dataset: Union[datasets.Dataset, datasets.DatasetDict]) -> datasets.Features:
+    """Get the schema of a huggingface dataset as a dictionary."""
+    first_dataset = dataset if isinstance(dataset, datasets.Dataset) else next(iter(dataset.values()))
+    return first_dataset.features
+
+def huggingface_schema_to_pixeltable_schema(
+    hf_dataset: Union[datasets.Dataset, datasets.DatasetDict],
+) -> Dict[str, Optional[ts.ColumnType]]:
+    """Generate a pixeltable schema from a huggingface dataset schema.
+    Columns without a known mapping are mapped to None
+    """
+    hf_schema = _get_hf_schema(hf_dataset)
+    pixeltable_schema = {
+        column_name: _to_pixeltable_type(feature_type) for column_name, feature_type in hf_schema.items()
+    }
+    return pixeltable_schema
+
+def import_huggingface_dataset(
+    cl: 'pixeltable.Client',
+    table_path: str,
+    dataset: Union[datasets.Dataset, datasets.DatasetDict],
+    *,
+    column_name_for_split: Optional[str],
+    schema_override: Optional[Dict[str, Any]],
+    **kwargs,
+) -> 'pixeltable.InsertableTable':
+    """See `pixeltable.Client.import_huggingface_dataset` for documentation"""
+    if table_path in cl.list_tables():
+        raise excs.Error(f'table {table_path} already exists')
+
+    if not isinstance(dataset, (datasets.Dataset, datasets.DatasetDict)):
+        raise excs.Error(f'`type(dataset)` must be `datasets.Dataset` or `datasets.DatasetDict`. Got {type(dataset)=}')
+
+    if isinstance(dataset, datasets.Dataset):
+        # when loading an hf dataset partially, dataset.split._name is sometimes the form "train[0:1000]"
+        raw_name = dataset.split._name
+        split_name = raw_name.split('[')[0] if raw_name is not None else None
+        dataset_dict = {split_name: dataset}
+    else:
+        dataset_dict = dataset
+
+    pixeltable_schema = huggingface_schema_to_pixeltable_schema(dataset)
+    if schema_override is not None:
+        pixeltable_schema.update(schema_override)
+
+    if column_name_for_split is not None:
+        if column_name_for_split in pixeltable_schema:
+            raise excs.Error(
+                f'Column name `{column_name_for_split}` already exists in dataset schema; provide a different `column_name_for_split`'
+            )
+        pixeltable_schema[column_name_for_split] = ts.StringType(nullable=True)
+
+    for field, column_type in pixeltable_schema.items():
+        if column_type is None:
+            raise excs.Error(f'Could not infer pixeltable type for feature `{field}` in huggingface dataset')
+
+    if isinstance(dataset, datasets.Dataset):
+        # when loading an hf dataset partially, dataset.split._name is sometimes the form "train[0:1000]"
+        raw_name = dataset.split._name
+        split_name = raw_name.split('[')[0] if raw_name is not None else None
+        dataset_dict = {split_name: dataset}
+    elif isinstance(dataset, datasets.DatasetDict):
+        dataset_dict = dataset
+    else:
+        raise excs.Error(f'`type(dataset)` must be `datasets.Dataset` or `datasets.DatasetDict`. Got {type(dataset)=}')
+
+    # extract all class labels from the dataset to translate category ints to strings
+    hf_schema = _get_hf_schema(dataset)
+    categorical_features = {
+        feature_name: feature_type.names
+        for (feature_name, feature_type) in hf_schema.items()
+        if isinstance(feature_type, datasets.ClassLabel)
+    }
+
+    try:
+        # random tmp name
+        tmp_name = f'{table_path}_tmp_{random.randint(0, 100000000)}'
+        tab = cl.create_table(tmp_name, pixeltable_schema, **kwargs)
+
+        def _translate_row(row: Dict[str, Any], split_name: str) -> Dict[str, Any]:
+            output_row = row.copy()
+            # map all class labels to strings
+            for field, values in categorical_features.items():
+                output_row[field] = values[row[field]]
+            # add split name to row
+            if column_name_for_split is not None:
+                output_row[column_name_for_split] = split_name
+            return output_row
+
+        for split_name, split_dataset in dataset_dict.items():
+            num_batches = split_dataset.size_in_bytes / _K_BATCH_SIZE_BYTES
+            tuples_per_batch = math.ceil(split_dataset.num_rows / num_batches)
+            assert tuples_per_batch > 0
+
+            batch = []
+            for row in split_dataset:
+                batch.append(_translate_row(row, split_name))
+                if len(batch) >= tuples_per_batch:
+                    tab.insert(batch)
+                    batch = []
+            # last batch
+            if len(batch) > 0:
+                tab.insert(batch)
+
+    except Exception as e:
+        _logger.error(f'Error while inserting dataset into table: {tmp_name}')
+        raise e
+
+    cl.move(tmp_name, table_path)
+    return cl.get_table(table_path)