pixeltable 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic; see the registry page for more details.

Files changed (63)
  1. pixeltable/catalog/column.py +26 -49
  2. pixeltable/catalog/insertable_table.py +7 -4
  3. pixeltable/catalog/table.py +163 -57
  4. pixeltable/catalog/table_version.py +416 -140
  5. pixeltable/catalog/table_version_path.py +2 -2
  6. pixeltable/client.py +72 -6
  7. pixeltable/dataframe.py +65 -21
  8. pixeltable/env.py +52 -53
  9. pixeltable/exec/cache_prefetch_node.py +1 -1
  10. pixeltable/exec/in_memory_data_node.py +11 -7
  11. pixeltable/exprs/comparison.py +3 -3
  12. pixeltable/exprs/data_row.py +5 -1
  13. pixeltable/exprs/literal.py +16 -4
  14. pixeltable/exprs/row_builder.py +8 -40
  15. pixeltable/ext/__init__.py +5 -0
  16. pixeltable/ext/functions/yolox.py +92 -0
  17. pixeltable/func/aggregate_function.py +15 -15
  18. pixeltable/func/expr_template_function.py +9 -1
  19. pixeltable/func/globals.py +24 -14
  20. pixeltable/func/signature.py +18 -12
  21. pixeltable/func/udf.py +7 -2
  22. pixeltable/functions/__init__.py +9 -9
  23. pixeltable/functions/eval.py +7 -8
  24. pixeltable/functions/fireworks.py +10 -37
  25. pixeltable/functions/huggingface.py +47 -19
  26. pixeltable/functions/openai.py +192 -24
  27. pixeltable/functions/together.py +104 -9
  28. pixeltable/functions/util.py +11 -0
  29. pixeltable/index/__init__.py +2 -0
  30. pixeltable/index/base.py +49 -0
  31. pixeltable/index/embedding_index.py +95 -0
  32. pixeltable/metadata/schema.py +45 -22
  33. pixeltable/plan.py +15 -34
  34. pixeltable/store.py +38 -41
  35. pixeltable/tests/conftest.py +8 -14
  36. pixeltable/tests/ext/test_yolox.py +21 -0
  37. pixeltable/tests/functions/test_fireworks.py +43 -0
  38. pixeltable/tests/functions/test_functions.py +60 -0
  39. pixeltable/tests/{test_functions.py → functions/test_huggingface.py} +7 -143
  40. pixeltable/tests/functions/test_openai.py +162 -0
  41. pixeltable/tests/functions/test_together.py +112 -0
  42. pixeltable/tests/test_component_view.py +14 -5
  43. pixeltable/tests/test_dataframe.py +23 -22
  44. pixeltable/tests/test_exprs.py +99 -102
  45. pixeltable/tests/test_function.py +51 -43
  46. pixeltable/tests/test_index.py +138 -0
  47. pixeltable/tests/test_migration.py +2 -1
  48. pixeltable/tests/test_snapshot.py +24 -1
  49. pixeltable/tests/test_table.py +205 -26
  50. pixeltable/tests/test_types.py +30 -0
  51. pixeltable/tests/test_video.py +16 -16
  52. pixeltable/tests/test_view.py +5 -0
  53. pixeltable/tests/utils.py +171 -14
  54. pixeltable/tool/create_test_db_dump.py +16 -0
  55. pixeltable/type_system.py +77 -128
  56. pixeltable/utils/arrow.py +98 -0
  57. pixeltable/utils/hf_datasets.py +157 -0
  58. pixeltable/utils/parquet.py +68 -27
  59. pixeltable/utils/pytorch.py +16 -97
  60. {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/METADATA +35 -28
  61. {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/RECORD +63 -50
  62. {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/LICENSE +0 -0
  63. {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/WHEEL +0 -0
pixeltable/type_system.py CHANGED
@@ -6,9 +6,10 @@ import enum
6
6
  import json
7
7
  import typing
8
8
  import urllib.parse
9
+ import urllib.request
9
10
  from copy import copy
10
11
  from pathlib import Path
11
- from typing import Any, Optional, Tuple, Dict, Callable, List, Union
12
+ from typing import Any, Optional, Tuple, Dict, Callable, List, Union, Sequence, Mapping
12
13
 
13
14
  import PIL.Image
14
15
  import av
@@ -240,19 +241,38 @@ class ColumnType:
240
241
 
241
242
  @classmethod
242
243
  def from_python_type(cls, t: type) -> Optional[ColumnType]:
243
- if t in _python_type_to_column_type:
244
- return _python_type_to_column_type[t]
245
- elif isinstance(t, typing._UnionGenericAlias) and t.__args__[1] is type(None):
246
- # `t` is a type of the form Optional[T] (equivalently, Union[T, None]).
247
- # We treat it as the underlying type but with nullable=True.
248
- if t.__args__[0] in _python_type_to_column_type:
249
- underlying = copy(_python_type_to_column_type[t.__args__[0]])
250
- underlying.nullable = True
251
- return underlying
252
-
244
+ if typing.get_origin(t) is typing.Union:
245
+ union_args = typing.get_args(t)
246
+ if union_args[1] is type(None):
247
+ # `t` is a type of the form Optional[T] (equivalently, Union[T, None]).
248
+ # We treat it as the underlying type but with nullable=True.
249
+ underlying = cls.from_python_type(union_args[0])
250
+ if underlying is not None:
251
+ underlying.nullable = True
252
+ return underlying
253
+ else:
254
+ # Discard type parameters to ensure that parameterized types such as `list[T]`
255
+ # are correctly mapped to Pixeltable types.
256
+ base = typing.get_origin(t)
257
+ if base is None:
258
+ # No type parameters; the base type is just `t` itself
259
+ base = t
260
+ if base is str:
261
+ return StringType()
262
+ if base is int:
263
+ return IntType()
264
+ if base is float:
265
+ return FloatType()
266
+ if base is bool:
267
+ return BoolType()
268
+ if base is datetime.date or base is datetime.datetime:
269
+ return TimestampType()
270
+ if issubclass(base, Sequence) or issubclass(base, Mapping):
271
+ return JsonType()
272
+ if issubclass(base, PIL.Image.Image):
273
+ return ImageType()
253
274
  return None
254
275
 
255
-
256
276
  def validate_literal(self, val: Any) -> None:
257
277
  """Raise TypeError if val is not a valid literal for this type"""
258
278
  if val is None:
@@ -275,7 +295,7 @@ class ColumnType:
275
295
  parsed = urllib.parse.urlparse(val)
276
296
  if parsed.scheme != '' and parsed.scheme != 'file':
277
297
  return
278
- path = Path(urllib.parse.unquote(parsed.path))
298
+ path = Path(urllib.parse.unquote(urllib.request.url2pathname(parsed.path)))
279
299
  if not path.is_file():
280
300
  raise TypeError(f'File not found: {str(path)}')
281
301
  else:
@@ -358,35 +378,12 @@ class ColumnType:
358
378
  pass
359
379
 
360
380
  @abc.abstractmethod
361
- def to_sa_type(self) -> Any:
381
+ def to_sa_type(self) -> sql.types.TypeEngine:
362
382
  """
363
383
  Return corresponding SQLAlchemy type.
364
- return type Any: there doesn't appear to be a superclass for the sqlalchemy types
365
384
  """
366
- assert self._type != self.Type.INVALID
367
- if self._type == self.Type.STRING:
368
- return sql.String
369
- if self._type == self.Type.INT:
370
- return sql.Integer
371
- if self._type == self.Type.FLOAT:
372
- return sql.Float
373
- if self._type == self.Type.BOOL:
374
- return sql.Boolean
375
- if self._type == self.Type.TIMESTAMP:
376
- return sql.TIMESTAMP
377
- if self._type == self.Type.IMAGE:
378
- # the URL
379
- return sql.String
380
- if self._type == self.Type.JSON:
381
- return sql.dialects.postgresql.JSONB
382
- if self._type == self.Type.ARRAY:
383
- return sql.VARBINARY
384
- assert False
385
+ pass
385
386
 
386
- @abc.abstractmethod
387
- def to_arrow_type(self) -> 'pyarrow.DataType':
388
- assert False, f'Have not implemented {self.__class__.__name__} to Arrow'
389
-
390
387
  @staticmethod
391
388
  def no_conversion(v: Any) -> Any:
392
389
  """
@@ -410,10 +407,7 @@ class InvalidType(ColumnType):
410
407
  def to_sql(self) -> str:
411
408
  assert False
412
409
 
413
- def to_sa_type(self) -> Any:
414
- assert False
415
-
416
- def to_arrow_type(self) -> 'pyarrow.DataType':
410
+ def to_sa_type(self) -> sql.types.TypeEngine:
417
411
  assert False
418
412
 
419
413
  def print_value(self, val: Any) -> str:
@@ -422,6 +416,7 @@ class InvalidType(ColumnType):
422
416
  def _validate_literal(self, val: Any) -> None:
423
417
  assert False
424
418
 
419
+
425
420
  class StringType(ColumnType):
426
421
  def __init__(self, nullable: bool = False):
427
422
  super().__init__(self.Type.STRING, nullable=nullable)
@@ -440,12 +435,8 @@ class StringType(ColumnType):
440
435
  def to_sql(self) -> str:
441
436
  return 'VARCHAR'
442
437
 
443
- def to_sa_type(self) -> str:
444
- return sql.String
445
-
446
- def to_arrow_type(self) -> 'pyarrow.DataType':
447
- import pyarrow as pa # pylint: disable=import-outside-toplevel
448
- return pa.string()
438
+ def to_sa_type(self) -> sql.types.TypeEngine:
439
+ return sql.String()
449
440
 
450
441
  def print_value(self, val: Any) -> str:
451
442
  return f"'{val}'"
@@ -454,6 +445,14 @@ class StringType(ColumnType):
454
445
  if not isinstance(val, str):
455
446
  raise TypeError(f'Expected string, got {val.__class__.__name__}')
456
447
 
448
+ def _create_literal(self, val: Any) -> Any:
449
+ # Replace null byte within python string with space to avoid issues with Postgres.
450
+ # Use a space to avoid merging words.
451
+ # TODO(orm): this will also be an issue with JSON inputs, would space still be a good replacement?
452
+ if isinstance(val, str) and '\x00' in val:
453
+ return val.replace('\x00', ' ')
454
+ return val
455
+
457
456
 
458
457
  class IntType(ColumnType):
459
458
  def __init__(self, nullable: bool = False):
@@ -462,12 +461,8 @@ class IntType(ColumnType):
462
461
  def to_sql(self) -> str:
463
462
  return 'BIGINT'
464
463
 
465
- def to_sa_type(self) -> str:
466
- return sql.BigInteger
467
-
468
- def to_arrow_type(self) -> 'pyarrow.DataType':
469
- import pyarrow as pa # pylint: disable=import-outside-toplevel
470
- return pa.int64() # to be consistent with bigint above
464
+ def to_sa_type(self) -> sql.types.TypeEngine:
465
+ return sql.BigInteger()
471
466
 
472
467
  def _validate_literal(self, val: Any) -> None:
473
468
  if not isinstance(val, int):
@@ -481,12 +476,8 @@ class FloatType(ColumnType):
481
476
  def to_sql(self) -> str:
482
477
  return 'FLOAT'
483
478
 
484
- def to_sa_type(self) -> str:
485
- return sql.Float
486
-
487
- def to_arrow_type(self) -> 'pyarrow.DataType':
488
- import pyarrow as pa
489
- return pa.float32()
479
+ def to_sa_type(self) -> sql.types.TypeEngine:
480
+ return sql.Float()
490
481
 
491
482
  def _validate_literal(self, val: Any) -> None:
492
483
  if not isinstance(val, float):
@@ -497,6 +488,7 @@ class FloatType(ColumnType):
497
488
  return float(val)
498
489
  return val
499
490
 
491
+
500
492
  class BoolType(ColumnType):
501
493
  def __init__(self, nullable: bool = False):
502
494
  super().__init__(self.Type.BOOL, nullable=nullable)
@@ -504,12 +496,8 @@ class BoolType(ColumnType):
504
496
  def to_sql(self) -> str:
505
497
  return 'BOOLEAN'
506
498
 
507
- def to_sa_type(self) -> str:
508
- return sql.Boolean
509
-
510
- def to_arrow_type(self) -> 'pyarrow.DataType':
511
- import pyarrow as pa # pylint: disable=import-outside-toplevel
512
- return pa.bool_()
499
+ def to_sa_type(self) -> sql.types.TypeEngine:
500
+ return sql.Boolean()
513
501
 
514
502
  def _validate_literal(self, val: Any) -> None:
515
503
  if not isinstance(val, bool):
@@ -520,6 +508,7 @@ class BoolType(ColumnType):
520
508
  return bool(val)
521
509
  return val
522
510
 
511
+
523
512
  class TimestampType(ColumnType):
524
513
  def __init__(self, nullable: bool = False):
525
514
  super().__init__(self.Type.TIMESTAMP, nullable=nullable)
@@ -527,12 +516,8 @@ class TimestampType(ColumnType):
527
516
  def to_sql(self) -> str:
528
517
  return 'INTEGER'
529
518
 
530
- def to_sa_type(self) -> str:
531
- return sql.TIMESTAMP
532
-
533
- def to_arrow_type(self) -> 'pyarrow.DataType':
534
- import pyarrow as pa # pylint: disable=import-outside-toplevel
535
- return pa.timestamp('us') # postgres timestamp is microseconds
519
+ def to_sa_type(self) -> sql.types.TypeEngine:
520
+ return sql.TIMESTAMP()
536
521
 
537
522
  def _validate_literal(self, val: Any) -> None:
538
523
  if not isinstance(val, datetime.datetime) and not isinstance(val, datetime.date):
@@ -543,6 +528,7 @@ class TimestampType(ColumnType):
543
528
  return datetime.datetime.fromisoformat(val)
544
529
  return val
545
530
 
531
+
546
532
  class JsonType(ColumnType):
547
533
  # TODO: type_spec also needs to be able to express lists
548
534
  def __init__(self, type_spec: Optional[Dict[str, ColumnType]] = None, nullable: bool = False):
@@ -568,12 +554,8 @@ class JsonType(ColumnType):
568
554
  def to_sql(self) -> str:
569
555
  return 'JSONB'
570
556
 
571
- def to_sa_type(self) -> str:
572
- return sql.dialects.postgresql.JSONB
573
-
574
- def to_arrow_type(self) -> 'pyarrow.DataType':
575
- import pyarrow as pa # pylint: disable=import-outside-toplevel
576
- return pa.string() # TODO: weight advantage of pa.struct type.
557
+ def to_sa_type(self) -> sql.types.TypeEngine:
558
+ return sql.dialects.postgresql.JSONB()
577
559
 
578
560
  def print_value(self, val: Any) -> str:
579
561
  val_type = self.infer_literal_type(val)
@@ -594,6 +576,7 @@ class JsonType(ColumnType):
594
576
  val = list(val)
595
577
  return val
596
578
 
579
+
597
580
  class ArrayType(ColumnType):
598
581
  def __init__(
599
582
  self, shape: Tuple[Union[int, None], ...], dtype: ColumnType, nullable: bool = False):
@@ -669,20 +652,16 @@ class ArrayType(ColumnType):
669
652
 
670
653
  def _create_literal(self, val: Any) -> Any:
671
654
  if isinstance(val, (list,tuple)):
672
- return np.array(val)
655
+ # map python float to whichever numpy float is
656
+ # declared for this type, rather than assume float64
657
+ return np.array(val, dtype=self.numpy_dtype())
673
658
  return val
674
659
 
675
660
  def to_sql(self) -> str:
676
661
  return 'BYTEA'
677
662
 
678
- def to_sa_type(self) -> str:
679
- return sql.LargeBinary
680
-
681
- def to_arrow_type(self) -> 'pyarrow.DataType':
682
- import pyarrow as pa # pylint: disable=import-outside-toplevel
683
- if any([n is None for n in self.shape]):
684
- raise TypeError(f'Cannot convert array with unknown shape to Arrow')
685
- return pa.fixed_shape_tensor(pa.from_numpy_dtype(self.numpy_dtype()), self.shape)
663
+ def to_sa_type(self) -> sql.types.TypeEngine:
664
+ return sql.LargeBinary()
686
665
 
687
666
  def numpy_dtype(self) -> np.dtype:
688
667
  if self.dtype == self.Type.INT:
@@ -786,12 +765,8 @@ class ImageType(ColumnType):
786
765
  def to_sql(self) -> str:
787
766
  return 'VARCHAR'
788
767
 
789
- def to_sa_type(self) -> str:
790
- return sql.String
791
-
792
- def to_arrow_type(self) -> 'pyarrow.DataType':
793
- import pyarrow as pa # pylint: disable=import-outside-toplevel
794
- return pa.binary()
768
+ def to_sa_type(self) -> sql.types.TypeEngine:
769
+ return sql.String()
795
770
 
796
771
  def _validate_literal(self, val: Any) -> None:
797
772
  if isinstance(val, PIL.Image.Image):
@@ -805,6 +780,7 @@ class ImageType(ColumnType):
805
780
  except PIL.UnidentifiedImageError:
806
781
  raise excs.Error(f'Not a valid image: {val}') from None
807
782
 
783
+
808
784
  class VideoType(ColumnType):
809
785
  def __init__(self, nullable: bool = False):
810
786
  super().__init__(self.Type.VIDEO, nullable=nullable)
@@ -813,12 +789,8 @@ class VideoType(ColumnType):
813
789
  # stored as a file path
814
790
  return 'VARCHAR'
815
791
 
816
- def to_sa_type(self) -> str:
817
- return sql.String
818
-
819
- def to_arrow_type(self) -> 'pyarrow.DataType':
820
- import pyarrow as pa # pylint: disable=import-outside-toplevel
821
- return pa.string()
792
+ def to_sa_type(self) -> sql.types.TypeEngine:
793
+ return sql.String()
822
794
 
823
795
  def _validate_literal(self, val: Any) -> None:
824
796
  self._validate_file_path(val)
@@ -843,6 +815,7 @@ class VideoType(ColumnType):
843
815
  except av.AVError:
844
816
  raise excs.Error(f'Not a valid video: {val}') from None
845
817
 
818
+
846
819
  class AudioType(ColumnType):
847
820
  def __init__(self, nullable: bool = False):
848
821
  super().__init__(self.Type.AUDIO, nullable=nullable)
@@ -851,12 +824,8 @@ class AudioType(ColumnType):
851
824
  # stored as a file path
852
825
  return 'VARCHAR'
853
826
 
854
- def to_sa_type(self) -> str:
855
- return sql.String
856
-
857
- def to_arrow_type(self) -> 'pyarrow.DataType':
858
- import pyarrow as pa # pylint: disable=import-outside-toplevel
859
- return pa.string()
827
+ def to_sa_type(self) -> sql.types.TypeEngine:
828
+ return sql.String()
860
829
 
861
830
  def _validate_literal(self, val: Any) -> None:
862
831
  self._validate_file_path(val)
@@ -876,6 +845,7 @@ class AudioType(ColumnType):
876
845
  except av.AVError as e:
877
846
  raise excs.Error(f'Not a valid audio file: {val}\n{e}') from None
878
847
 
848
+
879
849
  class DocumentType(ColumnType):
880
850
  @enum.unique
881
851
  class DocumentFormat(enum.Enum):
@@ -898,12 +868,8 @@ class DocumentType(ColumnType):
898
868
  # stored as a file path
899
869
  return 'VARCHAR'
900
870
 
901
- def to_sa_type(self) -> str:
902
- return sql.String
903
-
904
- def to_arrow_type(self) -> 'pyarrow.DataType':
905
- import pyarrow as pa # pylint: disable=import-outside-toplevel
906
- return pa.string()
871
+ def to_sa_type(self) -> sql.types.TypeEngine:
872
+ return sql.String()
907
873
 
908
874
  def _validate_literal(self, val: Any) -> None:
909
875
  self._validate_file_path(val)
@@ -919,20 +885,3 @@ class DocumentType(ColumnType):
919
885
  raise excs.Error(f'Not a recognized document format: {val}')
920
886
  except Exception as e:
921
887
  raise excs.Error(f'Not a recognized document format: {val}') from None
922
-
923
-
924
- # A dictionary mapping various Python types to their respective ColumnTypes.
925
- # This can be used to infer Pixeltable ColumnTypes from type hints on Python
926
- # functions. (Since Python functions do not necessarily have type hints, this
927
- # should always be an optional/convenience inference.)
928
- _python_type_to_column_type: dict[type, ColumnType] = {
929
- str: StringType(),
930
- int: IntType(),
931
- float: FloatType(),
932
- bool: BoolType(),
933
- datetime.datetime: TimestampType(),
934
- datetime.date: TimestampType(),
935
- list: JsonType(),
936
- dict: JsonType(),
937
- PIL.Image.Image: ImageType()
938
- }
@@ -0,0 +1,98 @@
1
+ import logging
2
+ from typing import Any, Dict, Iterable, Iterator, Optional
3
+
4
+ import pyarrow as pa
5
+
6
+ import pixeltable.type_system as ts
7
+
8
_logger = logging.getLogger(__name__)

# Arrow -> Pixeltable scalar type mapping. All targets are nullable because Arrow fields may hold nulls.
# The values are shared instances: hand out copies rather than these objects if they may be mutated.
# NOTE: every integer width maps to IntType (BIGINT); uint64 values above 2**63 - 1 will not fit.
_pa_to_pt: Dict[pa.DataType, ts.ColumnType] = {
    pa.string(): ts.StringType(nullable=True),
    pa.timestamp('us'): ts.TimestampType(nullable=True),
    pa.bool_(): ts.BoolType(nullable=True),
    pa.uint8(): ts.IntType(nullable=True),
    pa.int8(): ts.IntType(nullable=True),
    pa.uint16(): ts.IntType(nullable=True),
    pa.int16(): ts.IntType(nullable=True),
    pa.uint32(): ts.IntType(nullable=True),
    pa.uint64(): ts.IntType(nullable=True),
    pa.int32(): ts.IntType(nullable=True),
    pa.int64(): ts.IntType(nullable=True),
    pa.float32(): ts.FloatType(nullable=True),
}

# Pixeltable -> Arrow mapping, keyed by ColumnType *subclass* (looked up via `type.__class__`),
# not by ColumnType instance. ArrayType is handled separately because its Arrow type depends on
# shape and dtype.
_pt_to_pa: Dict[type, pa.DataType] = {
    ts.StringType: pa.string(),
    ts.TimestampType: pa.timestamp('us'),  # postgres timestamp is microseconds
    ts.BoolType: pa.bool_(),
    ts.IntType: pa.int64(),
    ts.FloatType: pa.float32(),
    ts.JsonType: pa.string(),  # TODO(orm) pa.struct() is possible
    ts.ImageType: pa.binary(),  # inline image
    ts.AudioType: pa.string(),  # path
    ts.VideoType: pa.string(),  # path
    ts.DocumentType: pa.string(),  # path
}
35
+
36
+
37
def to_pixeltable_type(arrow_type: pa.DataType) -> Optional[ts.ColumnType]:
    """Convert a pyarrow DataType to a pixeltable ColumnType if one is defined.

    Scalar types are looked up in `_pa_to_pt`. A `FixedShapeTensorType` becomes an `ArrayType`
    whose element type is converted recursively.

    Args:
        arrow_type: the Arrow type to convert.

    Returns:
        A ColumnType instance that is not shared with the module-level mapping, so callers may
        mutate it (e.g. its `nullable` flag). Returns None if no conversion is currently implemented.
    """
    from copy import copy  # local import keeps the module's import block unchanged

    if arrow_type in _pa_to_pt:
        # ColumnType objects are mutable (`nullable` gets reassigned by callers such as
        # `ColumnType.from_python_type`); returning the shared entry would let one caller's
        # mutation leak into every later conversion.
        return copy(_pa_to_pt[arrow_type])
    if isinstance(arrow_type, pa.FixedShapeTensorType):
        dtype = to_pixeltable_type(arrow_type.value_type)
        if dtype is None:
            return None
        # ArrayType declares its shape as a tuple; normalize whatever sequence Arrow returns.
        return ts.ArrayType(shape=tuple(arrow_type.shape), dtype=dtype)
    return None
50
+
51
+
52
def to_arrow_type(pixeltable_type: ts.ColumnType) -> Optional[pa.DataType]:
    """Convert a pixeltable ColumnType to a pyarrow DataType if one is defined.

    Scalar/media types are looked up by their exact class in `_pt_to_pa`. An `ArrayType`
    becomes an Arrow fixed-shape tensor with the array's numpy dtype and shape.

    Args:
        pixeltable_type: the Pixeltable column type to convert.

    Returns:
        The Arrow type, or None if no conversion is currently implemented.

    Raises:
        TypeError: if `pixeltable_type` is an ArrayType with an unknown (None) dimension,
            because an Arrow fixed-shape tensor requires every dimension to be known.
    """
    # exact-class lookup: subclasses of the mapped types are not matched here
    if pixeltable_type.__class__ in _pt_to_pa:
        return _pt_to_pa[pixeltable_type.__class__]
    if isinstance(pixeltable_type, ts.ArrayType):
        if any(n is None for n in pixeltable_type.shape):
            raise TypeError('Cannot convert array with unknown shape to Arrow')
        return pa.fixed_shape_tensor(pa.from_numpy_dtype(pixeltable_type.numpy_dtype()), pixeltable_type.shape)
    return None
62
+
63
+
64
def to_pixeltable_schema(arrow_schema: pa.Schema) -> Dict[str, ts.ColumnType]:
    """Map every field of an Arrow schema to its Pixeltable column type, keyed by field name.

    Fields whose type has no conversion map to None (see `to_pixeltable_type`).
    """
    schema: Dict[str, ts.ColumnType] = {}
    for arrow_field in arrow_schema:
        schema[arrow_field.name] = to_pixeltable_type(arrow_field.type)
    return schema
66
+
67
+
68
def to_arrow_schema(pixeltable_schema: Dict[str, Any]) -> pa.Schema:
    """Build an Arrow schema from a mapping of column name to Pixeltable column type.

    Fields appear in the mapping's iteration order; each type is converted with `to_arrow_type`.
    """
    fields = [(col_name, to_arrow_type(col_type)) for col_name, col_type in pixeltable_schema.items()]
    return pa.schema(fields)
70
+
71
+
72
def to_pydict(batch: pa.RecordBatch) -> Dict[str, Iterable[Any]]:
    """Convert a RecordBatch into a dict mapping column name to that column's values.

    Unlike `pa.RecordBatch.to_pydict`, fixed-shape tensor columns are not turned into Python
    lists: they are converted with `to_numpy(zero_copy_only=False)` so the numpy dtype is kept.
    Every other column goes through `to_pylist`, which yields plain Python values.
    """

    def _column_values(column: Any) -> Iterable[Any]:
        # tensor columns: numpy keeps the original element dtype
        if isinstance(column.type, pa.FixedShapeTensorType):
            return column.to_numpy(zero_copy_only=False)
        # everything else: native Python objects
        return column.to_pylist()

    names = batch.schema.names
    return {names[idx]: _column_values(batch.column(idx)) for idx in range(len(names))}
87
+
88
+
89
def iter_tuples(batch: pa.RecordBatch) -> Iterator[Dict[str, Any]]:
    """Yield the rows of a RecordBatch one at a time, each as a dict of column name -> value.

    Also works with any other object `to_pydict` accepts (anything exposing `schema.names` and
    `column(i)`, e.g. pa.Table). Values are converted via `to_pydict`, so tensor columns keep
    their numpy dtype. A batch without any columns yields nothing.
    """
    pydict = to_pydict(batch)
    if not pydict:
        # no columns -> no rows; previously an `assert`, which vanishes under `python -O`
        return
    # all columns of a batch have the same length, so any column gives the row count
    num_rows = len(next(iter(pydict.values())))
    for i in range(num_rows):
        yield {col_name: values[i] for col_name, values in pydict.items()}
@@ -0,0 +1,157 @@
1
+ import datasets
2
+ from typing import Union, Optional, List, Dict, Any
3
+ import pixeltable.type_system as ts
4
+ from pixeltable import exceptions as excs
5
+ import math
6
+ import logging
7
+ import pixeltable
8
+ import random
9
+
10
_logger = logging.getLogger(__name__)

# Use 100MB as the batch size limit for loading a huggingface dataset into pixeltable.
# The primary goal is to bound memory use, regardless of dataset size.
# The secondary goal is to limit per-insert overhead; 100MB is presumed to be reasonable for most storage systems.
_K_BATCH_SIZE_BYTES = 100_000_000

# Mapping from huggingface `Value.dtype` strings to pixeltable column types.
# There are many more HF types; the importer's `schema_override` parameter handles types that are
# not mapped yet, or overrides this mapping.
# https://huggingface.co/docs/datasets/v2.17.0/en/package_reference/main_classes#datasets.Value
# Every integer width that fits in a signed 64-bit value is widened to IntType (BIGINT); 'uint64' is
# intentionally absent because values above 2**63 - 1 would not fit.
# The values are shared instances: copy them before mutating.
_hf_to_pxt: Dict[str, ts.ColumnType] = {
    'int8': ts.IntType(nullable=True),
    'int16': ts.IntType(nullable=True),
    'int32': ts.IntType(nullable=True),
    'int64': ts.IntType(nullable=True),
    'uint8': ts.IntType(nullable=True),
    'uint16': ts.IntType(nullable=True),
    'uint32': ts.IntType(nullable=True),
    'bool': ts.BoolType(nullable=True),
    'float32': ts.FloatType(nullable=True),
    'float64': ts.FloatType(nullable=True),
    'string': ts.StringType(nullable=True),
    'large_string': ts.StringType(nullable=True),
    'timestamp[s]': ts.TimestampType(nullable=True),
    # HF dataset iterator converts timestamps to datetime.datetime
    'timestamp[ms]': ts.TimestampType(nullable=True),
    'timestamp[us]': ts.TimestampType(nullable=True),
}
29
+
30
def _to_pixeltable_type(
    feature_type: Union[datasets.ClassLabel, datasets.Value, datasets.Sequence],
) -> Optional[ts.ColumnType]:
    """Convert a huggingface feature type to a pixeltable ColumnType if one is defined.

    Args:
        feature_type: the HF feature descriptor of a single column.

    Returns:
        A ColumnType instance that is not shared with `_hf_to_pxt`. Returns None if the feature
        has no known mapping, or, for a Sequence, if its element feature has none.
    """
    from copy import copy  # local import keeps the module's import block unchanged

    if isinstance(feature_type, datasets.ClassLabel):
        # enum, example: ClassLabel(names=['neg', 'pos'], id=None); stored as the label's name
        return ts.StringType(nullable=True)
    if isinstance(feature_type, datasets.Value):
        # example: Value(dtype='int64', id=None)
        mapped = _hf_to_pxt.get(feature_type.dtype)
        # copy so a caller mutating the result cannot corrupt the shared mapping entry
        return copy(mapped) if mapped is not None else None
    if isinstance(feature_type, datasets.Sequence):
        # example: cohere wiki. Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None)
        dtype = _to_pixeltable_type(feature_type.feature)
        if dtype is None:
            # Report the column as unmapped instead of constructing an ArrayType with dtype=None.
            return None
        # length == -1 marks a variable-length sequence
        length = feature_type.length if feature_type.length != -1 else None
        return ts.ArrayType(shape=(length,), dtype=dtype)
    return None
47
+
48
def _get_hf_schema(dataset: Union[datasets.Dataset, datasets.DatasetDict]) -> datasets.Features:
    """Return the `Features` (column name -> feature type) of a huggingface dataset.

    For a DatasetDict, the features of its first split are used.
    """
    if isinstance(dataset, datasets.Dataset):
        return dataset.features
    first_split = next(iter(dataset.values()))
    return first_split.features
52
+
53
def huggingface_schema_to_pixeltable_schema(
    hf_dataset: Union[datasets.Dataset, datasets.DatasetDict],
) -> Dict[str, Optional[ts.ColumnType]]:
    """Derive a pixeltable schema from the feature schema of a huggingface dataset.

    Each column maps to the result of `_to_pixeltable_type`. Columns with no known mapping map
    to None, leaving it to the caller to reject or override them.
    """
    result: Dict[str, Optional[ts.ColumnType]] = {}
    for name, feature in _get_hf_schema(hf_dataset).items():
        result[name] = _to_pixeltable_type(feature)
    return result
64
+
65
def _compute_rows_per_batch(num_rows: int, size_in_bytes: Optional[int], fallback_rows: int = 10_000) -> int:
    """Return how many rows to insert per batch so that a batch holds roughly _K_BATCH_SIZE_BYTES.

    When the split's byte size is unknown (None) or zero, a fixed `fallback_rows` is used so that
    memory stays bounded. The result is always at least 1, so empty splits are handled too.
    """
    if not size_in_bytes:
        return max(1, min(num_rows, fallback_rows))
    num_batches = size_in_bytes / _K_BATCH_SIZE_BYTES
    return max(1, math.ceil(num_rows / num_batches))


def import_huggingface_dataset(
    cl: 'pixeltable.Client',
    table_path: str,
    dataset: Union[datasets.Dataset, datasets.DatasetDict],
    *,
    column_name_for_split: Optional[str],
    schema_override: Optional[Dict[str, Any]],
    **kwargs,
) -> 'pixeltable.InsertableTable':
    """See `pixeltable.Client.import_huggingface_dataset` for documentation.

    Outline:
    - infer the schema and apply `schema_override`;
    - create the table under a temporary name;
    - insert every split in size-bounded batches;
    - move the table to `table_path`.

    Raises:
        excs.Error: if `table_path` already exists, `dataset` has an unsupported type,
            `column_name_for_split` collides with a dataset column, or a column type cannot be inferred.
    """
    if table_path in cl.list_tables():
        raise excs.Error(f'table {table_path} already exists')

    if isinstance(dataset, datasets.Dataset):
        # when loading an hf dataset partially, dataset.split._name is sometimes the form "train[0:1000]";
        # datasets built in memory may have no split at all
        raw_name = dataset.split._name if dataset.split is not None else None
        split_name = raw_name.split('[')[0] if raw_name is not None else None
        dataset_dict = {split_name: dataset}
    elif isinstance(dataset, datasets.DatasetDict):
        dataset_dict = dataset
    else:
        raise excs.Error(f'`type(dataset)` must be `datasets.Dataset` or `datasets.DatasetDict`. Got {type(dataset)=}')

    pixeltable_schema = huggingface_schema_to_pixeltable_schema(dataset)
    if schema_override is not None:
        pixeltable_schema.update(schema_override)

    if column_name_for_split is not None:
        if column_name_for_split in pixeltable_schema:
            raise excs.Error(
                f'Column name `{column_name_for_split}` already exists in dataset schema; provide a different `column_name_for_split`'
            )
        pixeltable_schema[column_name_for_split] = ts.StringType(nullable=True)

    for field, column_type in pixeltable_schema.items():
        if column_type is None:
            raise excs.Error(f'Could not infer pixeltable type for feature `{field}` in huggingface dataset')

    # extract all class labels from the dataset to translate category ints to strings
    hf_schema = _get_hf_schema(dataset)
    categorical_features = {
        feature_name: feature_type.names
        for (feature_name, feature_type) in hf_schema.items()
        if isinstance(feature_type, datasets.ClassLabel)
    }

    def _translate_row(row: Dict[str, Any], split_name: Optional[str]) -> Dict[str, Any]:
        output_row = row.copy()
        # Map class-label ints to their names. A negative label (HF uses -1 for "no label") or None
        # becomes None; indexing with -1 would otherwise silently return the *last* label name.
        for field, label_names in categorical_features.items():
            label = row[field]
            output_row[field] = label_names[label] if label is not None and label >= 0 else None
        if column_name_for_split is not None:
            output_row[column_name_for_split] = split_name
        return output_row

    # random tmp name; the table is only moved to `table_path` once all rows are inserted
    tmp_name = f'{table_path}_tmp_{random.randint(0, 100000000)}'
    try:
        tab = cl.create_table(tmp_name, pixeltable_schema, **kwargs)

        for split_name, split_dataset in dataset_dict.items():
            tuples_per_batch = _compute_rows_per_batch(split_dataset.num_rows, split_dataset.size_in_bytes)

            batch = []
            for row in split_dataset:
                batch.append(_translate_row(row, split_name))
                if len(batch) >= tuples_per_batch:
                    tab.insert(batch)
                    batch = []
            # last batch
            if len(batch) > 0:
                tab.insert(batch)

    except Exception:
        _logger.error(f'Error while inserting dataset into table: {tmp_name}')
        raise

    cl.move(tmp_name, table_path)
    return cl.get_table(table_path)