pixeltable 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (119) hide show
  1. pixeltable/__init__.py +53 -0
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/__init__.py +13 -0
  4. pixeltable/catalog/catalog.py +159 -0
  5. pixeltable/catalog/column.py +181 -0
  6. pixeltable/catalog/dir.py +32 -0
  7. pixeltable/catalog/globals.py +33 -0
  8. pixeltable/catalog/insertable_table.py +192 -0
  9. pixeltable/catalog/named_function.py +36 -0
  10. pixeltable/catalog/path.py +58 -0
  11. pixeltable/catalog/path_dict.py +139 -0
  12. pixeltable/catalog/schema_object.py +39 -0
  13. pixeltable/catalog/table.py +695 -0
  14. pixeltable/catalog/table_version.py +1026 -0
  15. pixeltable/catalog/table_version_path.py +133 -0
  16. pixeltable/catalog/view.py +203 -0
  17. pixeltable/dataframe.py +749 -0
  18. pixeltable/env.py +466 -0
  19. pixeltable/exceptions.py +17 -0
  20. pixeltable/exec/__init__.py +10 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +116 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +94 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +73 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +226 -0
  31. pixeltable/exprs/__init__.py +25 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +114 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +199 -0
  39. pixeltable/exprs/expr.py +594 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +382 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +96 -0
  44. pixeltable/exprs/in_predicate.py +96 -0
  45. pixeltable/exprs/inline_array.py +109 -0
  46. pixeltable/exprs/inline_dict.py +103 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +66 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +329 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/similarity_expr.py +65 -0
  56. pixeltable/exprs/type_cast.py +53 -0
  57. pixeltable/exprs/variable.py +45 -0
  58. pixeltable/ext/__init__.py +5 -0
  59. pixeltable/ext/functions/yolox.py +92 -0
  60. pixeltable/func/__init__.py +7 -0
  61. pixeltable/func/aggregate_function.py +197 -0
  62. pixeltable/func/callable_function.py +113 -0
  63. pixeltable/func/expr_template_function.py +99 -0
  64. pixeltable/func/function.py +141 -0
  65. pixeltable/func/function_registry.py +227 -0
  66. pixeltable/func/globals.py +46 -0
  67. pixeltable/func/nos_function.py +202 -0
  68. pixeltable/func/signature.py +162 -0
  69. pixeltable/func/udf.py +164 -0
  70. pixeltable/functions/__init__.py +95 -0
  71. pixeltable/functions/eval.py +215 -0
  72. pixeltable/functions/fireworks.py +34 -0
  73. pixeltable/functions/huggingface.py +167 -0
  74. pixeltable/functions/image.py +16 -0
  75. pixeltable/functions/openai.py +289 -0
  76. pixeltable/functions/pil/image.py +147 -0
  77. pixeltable/functions/string.py +13 -0
  78. pixeltable/functions/together.py +143 -0
  79. pixeltable/functions/util.py +52 -0
  80. pixeltable/functions/video.py +62 -0
  81. pixeltable/globals.py +425 -0
  82. pixeltable/index/__init__.py +2 -0
  83. pixeltable/index/base.py +51 -0
  84. pixeltable/index/embedding_index.py +168 -0
  85. pixeltable/io/__init__.py +3 -0
  86. pixeltable/io/hf_datasets.py +188 -0
  87. pixeltable/io/pandas.py +148 -0
  88. pixeltable/io/parquet.py +192 -0
  89. pixeltable/iterators/__init__.py +3 -0
  90. pixeltable/iterators/base.py +52 -0
  91. pixeltable/iterators/document.py +432 -0
  92. pixeltable/iterators/video.py +88 -0
  93. pixeltable/metadata/__init__.py +58 -0
  94. pixeltable/metadata/converters/convert_10.py +18 -0
  95. pixeltable/metadata/converters/convert_12.py +3 -0
  96. pixeltable/metadata/converters/convert_13.py +41 -0
  97. pixeltable/metadata/schema.py +234 -0
  98. pixeltable/plan.py +620 -0
  99. pixeltable/store.py +424 -0
  100. pixeltable/tool/create_test_db_dump.py +184 -0
  101. pixeltable/tool/create_test_video.py +81 -0
  102. pixeltable/type_system.py +846 -0
  103. pixeltable/utils/__init__.py +17 -0
  104. pixeltable/utils/arrow.py +98 -0
  105. pixeltable/utils/clip.py +18 -0
  106. pixeltable/utils/coco.py +136 -0
  107. pixeltable/utils/documents.py +69 -0
  108. pixeltable/utils/filecache.py +195 -0
  109. pixeltable/utils/help.py +11 -0
  110. pixeltable/utils/http_server.py +70 -0
  111. pixeltable/utils/media_store.py +76 -0
  112. pixeltable/utils/pytorch.py +91 -0
  113. pixeltable/utils/s3.py +13 -0
  114. pixeltable/utils/sql.py +17 -0
  115. pixeltable/utils/transactional_directory.py +35 -0
  116. pixeltable-0.0.0.dist-info/LICENSE +18 -0
  117. pixeltable-0.0.0.dist-info/METADATA +131 -0
  118. pixeltable-0.0.0.dist-info/RECORD +119 -0
  119. pixeltable-0.0.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,846 @@
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ import datetime
5
+ import enum
6
+ import json
7
+ import typing
8
+ import urllib.parse
9
+ import urllib.request
10
+ from copy import copy
11
+ from pathlib import Path
12
+ from typing import Any, Optional, Tuple, Dict, Callable, List, Union, Sequence, Mapping
13
+
14
+ import PIL.Image
15
+ import av
16
+ import numpy as np
17
+ import sqlalchemy as sql
18
+
19
+ from pixeltable import exceptions as excs
20
+
21
+
22
class ColumnType:
    """
    Base class of all Pixeltable column types.

    An instance pairs a `Type` tag with a `nullable` flag. Subclasses may add further instance attributes;
    all instance attributes other than `nullable` take part in `matches()`.
    """

    @enum.unique
    class Type(enum.Enum):
        """Tag that identifies the kind of a column type."""
        STRING = 0
        INT = 1
        FLOAT = 2
        BOOL = 3
        TIMESTAMP = 4
        JSON = 5
        ARRAY = 6
        IMAGE = 7
        VIDEO = 8
        AUDIO = 9
        DOCUMENT = 10

        # exprs that don't evaluate to a computable value in Pixeltable, such as an Image member function
        INVALID = 255

        @classmethod
        def supertype(
                cls, type1: 'Type', type2: 'Type',
                # we need to pass this in because we can't easily append it as a class member
                common_supertypes: Dict[Tuple['Type', 'Type'], 'Type']
        ) -> Optional['Type']:
            """
            Return the common supertype of two type tags, or None if `common_supertypes` has no entry for them.

            The table stores each pair in one orientation only, so both orderings are probed.
            """
            if type1 == type2:
                return type1
            t = common_supertypes.get((type1, type2))
            if t is not None:
                return t
            return common_supertypes.get((type2, type1))

    @enum.unique
    class DType(enum.Enum):
        """
        Base type used in images and arrays
        """
        # plain int values: a trailing comma here would turn the value into a 1-tuple
        BOOL = 0
        INT8 = 1
        INT16 = 2
        INT32 = 3
        INT64 = 4
        UINT8 = 5
        UINT16 = 6
        UINT32 = 7
        UINT64 = 8
        FLOAT16 = 9
        FLOAT32 = 10
        FLOAT64 = 11

    scalar_types = {Type.STRING, Type.INT, Type.FLOAT, Type.BOOL, Type.TIMESTAMP}
    numeric_types = {Type.INT, Type.FLOAT}
    # implicit widening rules between scalar types; looked up in both orientations by Type.supertype()
    common_supertypes: Dict[Tuple[Type, Type], Type] = {
        (Type.BOOL, Type.INT): Type.INT,
        (Type.BOOL, Type.FLOAT): Type.FLOAT,
        (Type.INT, Type.FLOAT): Type.FLOAT,
    }

    def __init__(self, t: Type, nullable: bool = False):
        self._type = t
        self.nullable = nullable

    @property
    def type_enum(self) -> Type:
        return self._type

    def serialize(self) -> str:
        """Return this type as a JSON string (see as_dict())."""
        return json.dumps(self.as_dict())

    @classmethod
    def serialize_list(cls, type_list: List[ColumnType]) -> str:
        """Return a JSON string holding the as_dict() form of every type in `type_list`."""
        return json.dumps([t.as_dict() for t in type_list])

    def as_dict(self) -> Dict:
        """Return a dict with the concrete class name under '_classname' plus the entries of _as_dict()."""
        return {
            '_classname': self.__class__.__name__,
            **self._as_dict(),
        }

    def _as_dict(self) -> Dict:
        """Class-specific part of as_dict(); subclasses extend the returned dict with their own parameters."""
        return {'nullable': self.nullable}

    @classmethod
    def deserialize(cls, type_str: str) -> ColumnType:
        """Inverse of serialize()."""
        type_dict = json.loads(type_str)
        return cls.from_dict(type_dict)

    @classmethod
    def deserialize_list(cls, type_list_str: str) -> List[ColumnType]:
        """Inverse of serialize_list()."""
        type_dict_list = json.loads(type_list_str)
        return [cls.from_dict(type_dict) for type_dict in type_dict_list]

    @classmethod
    def from_dict(cls, type_dict: Dict) -> ColumnType:
        """Rebuild a type from its as_dict() form; the class is looked up by name in this module's globals."""
        assert '_classname' in type_dict
        type_class = globals()[type_dict['_classname']]
        return type_class._from_dict(type_dict)

    @classmethod
    def _from_dict(cls, d: Dict) -> ColumnType:
        """
        Default implementation: simply invoke c'tor
        """
        assert 'nullable' in d
        return cls(nullable=d['nullable'])

    @classmethod
    def make_type(cls, t: Type) -> ColumnType:
        """
        Return a default-constructed (non-nullable, unparameterized) instance for the type tag `t`.

        INVALID and ARRAY are not supported (ArrayType's constructor needs a shape and dtype).
        """
        assert t != cls.Type.INVALID and t != cls.Type.ARRAY
        if t == cls.Type.STRING:
            return StringType()
        if t == cls.Type.INT:
            return IntType()
        if t == cls.Type.FLOAT:
            return FloatType()
        if t == cls.Type.BOOL:
            return BoolType()
        if t == cls.Type.TIMESTAMP:
            return TimestampType()
        if t == cls.Type.JSON:
            return JsonType()
        if t == cls.Type.IMAGE:
            return ImageType()
        if t == cls.Type.VIDEO:
            return VideoType()
        if t == cls.Type.AUDIO:
            return AudioType()
        if t == cls.Type.DOCUMENT:
            return DocumentType()

    def __str__(self) -> str:
        return self._type.name.lower()

    def __eq__(self, other: object) -> bool:
        """Two types are equal if they match and have the same nullability."""
        if not isinstance(other, ColumnType):
            # let Python fall back to its default handling instead of failing inside matches()
            return NotImplemented
        return self.matches(other) and self.nullable == other.nullable

    def is_supertype_of(self, other: ColumnType) -> bool:
        """Return True if `other` has the same class as self and either matches it or passes _is_supertype_of()."""
        if type(self) != type(other):
            return False
        if self.matches(other):
            return True
        return self._is_supertype_of(other)

    @abc.abstractmethod
    def _is_supertype_of(self, other: ColumnType) -> bool:
        """Class-specific supertype check; only reached when `other` has the same class but doesn't match."""
        return False

    def matches(self, other: object) -> bool:
        """Two types match if they're equal, aside from nullability"""
        if not isinstance(other, ColumnType):
            return False
        if type(self) != type(other):
            return False
        # compare every instance attribute except the nullable flag
        for member_var in vars(self):
            if member_var == 'nullable':
                continue
            if getattr(self, member_var) != getattr(other, member_var):
                return False
        return True

    @classmethod
    def supertype(cls, type1: ColumnType, type2: ColumnType) -> Optional[ColumnType]:
        """
        Return a type that both `type1` and `type2` can be widened to, or None if there is none.

        An invalid type defers to the other operand. Scalar types are resolved via `common_supertypes`; the
        result is a default instance from make_type(), so it is non-nullable. Non-scalar types with the same
        tag are delegated to _supertype().
        """
        if type1 == type2:
            return type1

        if type1.is_invalid_type():
            return type2
        if type2.is_invalid_type():
            return type1

        if type1.is_scalar_type() and type2.is_scalar_type():
            t = cls.Type.supertype(type1._type, type2._type, cls.common_supertypes)
            if t is not None:
                return cls.make_type(t)
            return None

        if type1._type == type2._type:
            # NOTE(review): this dispatches on the class supertype() was invoked on, so a call through
            # ColumnType.supertype() reaches the base _supertype() below (which returns None) rather than
            # the operands' own override -- confirm whether dispatching on type(type1) was intended.
            return cls._supertype(type1, type2)

        return None

    @classmethod
    @abc.abstractmethod
    def _supertype(cls, type1: ColumnType, type2: ColumnType) -> Optional[ColumnType]:
        """
        Class-specific implementation of determining the supertype. type1 and type2 are from the same subclass of
        ColumnType.
        """
        pass

    @classmethod
    def infer_literal_type(cls, val: Any) -> Optional[ColumnType]:
        """Return the column type that corresponds to the Python value `val`, or None if there is none."""
        if isinstance(val, str):
            return StringType()
        if isinstance(val, bool):
            # has to precede the int check: bool is a subclass of int
            return BoolType()
        if isinstance(val, int):
            return IntType()
        if isinstance(val, float):
            return FloatType()
        if isinstance(val, datetime.datetime) or isinstance(val, datetime.date):
            return TimestampType()
        if isinstance(val, PIL.Image.Image):
            return ImageType(width=val.width, height=val.height)
        if isinstance(val, np.ndarray):
            # None here means the array's dtype isn't supported; there is no JSON fallback for ndarrays,
            # because JsonType only validates dicts and lists
            return ArrayType.from_literal(val)
        if isinstance(val, (dict, list)):
            # this could still be json-serializable
            try:
                JsonType().validate_literal(val)
                return JsonType()
            except TypeError:
                return None
        return None

    @classmethod
    def from_python_type(cls, t: type) -> Optional[ColumnType]:
        """
        Return the column type for the Python type annotation `t`, or None if it has no Pixeltable equivalent.

        Optional[T] (a two-member union with None, in either order and in either `Union[...]` or `T | None`
        spelling) maps to T's column type with nullable=True; any other union maps to None.
        """
        import types  # local import: only needed to recognize PEP 604 unions (`T | None`)

        origin = typing.get_origin(t)
        pep604_union = getattr(types, 'UnionType', None)  # absent before Python 3.10
        if origin is typing.Union or (pep604_union is not None and origin is pep604_union):
            union_args = typing.get_args(t)
            non_none_args = [arg for arg in union_args if arg is not type(None)]
            if len(union_args) != 2 or len(non_none_args) != 1:
                return None
            # `t` is a type of the form Optional[T] (equivalently, Union[T, None]).
            # We treat it as the underlying type but with nullable=True.
            underlying = cls.from_python_type(non_none_args[0])
            if underlying is not None:
                underlying.nullable = True
            return underlying

        # Discard type parameters to ensure that parameterized types such as `list[T]`
        # are correctly mapped to Pixeltable types.
        # No type parameters means the base type is just `t` itself.
        base = t if origin is None else origin
        if base is str:
            return StringType()
        if base is int:
            return IntType()
        if base is float:
            return FloatType()
        if base is bool:
            return BoolType()
        if base is datetime.date or base is datetime.datetime:
            return TimestampType()
        if not isinstance(base, type):
            # special forms and other non-class annotations would make issubclass() raise
            return None
        if issubclass(base, Sequence) or issubclass(base, Mapping):
            return JsonType()
        if issubclass(base, PIL.Image.Image):
            return ImageType()
        return None

    def validate_literal(self, val: Any) -> None:
        """Raise TypeError if val is not a valid literal for this type"""
        if val is None:
            if not self.nullable:
                raise TypeError('Expected non-None value')
            return
        self._validate_literal(val)

    def validate_media(self, val: Any) -> None:
        """
        Raise TypeError if val is not a path to a valid media file (or a valid in-memory byte sequence) for this type
        """
        # the base implementation is a no-op for non-media types; media types are expected to override it
        if self.is_media_type():
            raise NotImplementedError(f'validate_media() not implemented for {self.__class__.__name__}')

    def _validate_file_path(self, val: Any) -> None:
        """Raises TypeError if not a valid local file path or not a path/byte sequence"""
        if isinstance(val, str):
            parsed = urllib.parse.urlparse(val)
            if parsed.scheme != '' and parsed.scheme != 'file':
                # non-file URLs are accepted without any check
                return
            # NOTE(review): the path is unquoted after url2pathname(), and plain (scheme-less) paths go
            # through URL parsing as well; file names containing '%', '#' or '?' may resolve to a different
            # path than the one given -- confirm against how these paths are opened elsewhere.
            path = Path(urllib.parse.unquote(urllib.request.url2pathname(parsed.path)))
            if not path.is_file():
                raise TypeError(f'File not found: {str(path)}')
        else:
            if not isinstance(val, bytes):
                raise TypeError(f'expected file path or bytes, got {type(val)}')

    @abc.abstractmethod
    def _validate_literal(self, val: Any) -> None:
        """Raise TypeError if val is not a valid literal for this type"""
        pass

    @abc.abstractmethod
    def _create_literal(self, val : Any) -> Any:
        """Create a literal of this type from val, including any needed conversions.
         val is guaranteed to be non-None"""
        return val

    def create_literal(self, val: Any) -> Any:
        """Create a literal of this type from val or raise TypeError if not possible"""
        if val is not None:
            val = self._create_literal(val)

        self.validate_literal(val)
        return val

    def print_value(self, val: Any) -> str:
        return str(val)

    # predicates on the type tag

    def is_scalar_type(self) -> bool:
        return self._type in self.scalar_types

    def is_numeric_type(self) -> bool:
        return self._type in self.numeric_types

    def is_invalid_type(self) -> bool:
        return self._type == self.Type.INVALID

    def is_string_type(self) -> bool:
        return self._type == self.Type.STRING

    def is_int_type(self) -> bool:
        return self._type == self.Type.INT

    def is_float_type(self) -> bool:
        return self._type == self.Type.FLOAT

    def is_bool_type(self) -> bool:
        return self._type == self.Type.BOOL

    def is_timestamp_type(self) -> bool:
        return self._type == self.Type.TIMESTAMP

    def is_json_type(self) -> bool:
        return self._type == self.Type.JSON

    def is_array_type(self) -> bool:
        return self._type == self.Type.ARRAY

    def is_image_type(self) -> bool:
        return self._type == self.Type.IMAGE

    def is_video_type(self) -> bool:
        return self._type == self.Type.VIDEO

    def is_audio_type(self) -> bool:
        return self._type == self.Type.AUDIO

    def is_document_type(self) -> bool:
        return self._type == self.Type.DOCUMENT

    def is_media_type(self) -> bool:
        # types that refer to external media files
        return self.is_image_type() or self.is_video_type() or self.is_audio_type() or self.is_document_type()

    @abc.abstractmethod
    def to_sa_type(self) -> sql.types.TypeEngine:
        """
        Return corresponding SQLAlchemy type.
        """
        pass

    @staticmethod
    def no_conversion(v: Any) -> Any:
        """
        Special return value of conversion_fn() that indicates that no conversion is necessary.
        Should not be called
        """
        assert False

    def conversion_fn(self, target: ColumnType) -> Optional[Callable[[Any], Any]]:
        """
        Return Callable that converts a column value of type self to a value of type 'target'.
        Returns None if conversion isn't possible.
        """
        return None
396
+
397
+
398
class InvalidType(ColumnType):
    """
    Column type tagged Type.INVALID.

    It has no storage type, cannot be printed and accepts no literals: each of those operations
    trips an assertion.
    """

    def __init__(self, nullable: bool = False):
        tag = self.Type.INVALID
        super().__init__(tag, nullable=nullable)

    def to_sa_type(self) -> sql.types.TypeEngine:
        # there is no SQLAlchemy type for an invalid type
        assert False

    def print_value(self, val: Any) -> str:
        # values of an invalid type are never printed
        assert False

    def _validate_literal(self, val: Any) -> None:
        # nothing is a valid literal of an invalid type
        assert False
410
+
411
+
412
class StringType(ColumnType):
    """Column type for Python strings, stored as SQL strings."""

    def __init__(self, nullable: bool = False):
        tag = self.Type.STRING
        super().__init__(tag, nullable=nullable)

    def conversion_fn(self, target: ColumnType) -> Optional[Callable[[Any], Any]]:
        """
        Strings can only be converted to timestamps.

        The returned function parses an ISO-format string and yields None for strings that don't parse.
        """
        if target.is_timestamp_type():
            def convert(val: str) -> Optional[datetime.datetime]:
                try:
                    return datetime.datetime.fromisoformat(val)
                except ValueError:
                    return None
            return convert
        return None

    def to_sa_type(self) -> sql.types.TypeEngine:
        return sql.String()

    def print_value(self, val: Any) -> str:
        """Return val wrapped in single quotes."""
        return f"'{val}'"

    def _validate_literal(self, val: Any) -> None:
        """Raise TypeError unless val is a str."""
        if isinstance(val, str):
            return
        raise TypeError(f'Expected string, got {val.__class__.__name__}')

    def _create_literal(self, val: Any) -> Any:
        """Return val with embedded null bytes replaced; non-strings are passed through unchanged."""
        # Replace null byte within python string with space to avoid issues with Postgres.
        # Use a space to avoid merging words.
        # TODO(orm): this will also be an issue with JSON inputs, would space still be a good replacement?
        has_null_byte = isinstance(val, str) and '\x00' in val
        if not has_null_byte:
            return val
        return val.replace('\x00', ' ')
444
+
445
+
446
class IntType(ColumnType):
    """Column type for Python ints, stored as SQL BigInteger."""

    def __init__(self, nullable: bool = False):
        tag = self.Type.INT
        super().__init__(tag, nullable=nullable)

    def to_sa_type(self) -> sql.types.TypeEngine:
        return sql.BigInteger()

    def _validate_literal(self, val: Any) -> None:
        """Raise TypeError unless val is an int (which, by isinstance(), includes bool)."""
        if isinstance(val, int):
            return
        raise TypeError(f'Expected int, got {val.__class__.__name__}')
456
+
457
+
458
class FloatType(ColumnType):
    """Column type for Python floats, stored as SQL Float."""

    def __init__(self, nullable: bool = False):
        tag = self.Type.FLOAT
        super().__init__(tag, nullable=nullable)

    def to_sa_type(self) -> sql.types.TypeEngine:
        return sql.Float()

    def _validate_literal(self, val: Any) -> None:
        """Raise TypeError unless val is a float."""
        if isinstance(val, float):
            return
        raise TypeError(f'Expected float, got {val.__class__.__name__}')

    def _create_literal(self, val: Any) -> Any:
        """Widen ints to float; every other value is passed through unchanged."""
        return float(val) if isinstance(val, int) else val
473
+
474
+
475
class BoolType(ColumnType):
    """Column type for Python bools, stored as SQL Boolean."""

    def __init__(self, nullable: bool = False):
        tag = self.Type.BOOL
        super().__init__(tag, nullable=nullable)

    def to_sa_type(self) -> sql.types.TypeEngine:
        return sql.Boolean()

    def _validate_literal(self, val: Any) -> None:
        """Raise TypeError unless val is a bool."""
        if isinstance(val, bool):
            return
        raise TypeError(f'Expected bool, got {val.__class__.__name__}')

    def _create_literal(self, val: Any) -> Any:
        """Convert ints to bool by truthiness; every other value is passed through unchanged."""
        return bool(val) if isinstance(val, int) else val
490
+
491
+
492
class TimestampType(ColumnType):
    """Column type for datetime.datetime / datetime.date values, stored as SQL TIMESTAMP."""

    def __init__(self, nullable: bool = False):
        super().__init__(self.Type.TIMESTAMP, nullable=nullable)

    def to_sa_type(self) -> sql.types.TypeEngine:
        return sql.TIMESTAMP()

    def _validate_literal(self, val: Any) -> None:
        """Raise TypeError unless val is a datetime.datetime or datetime.date."""
        if not isinstance(val, (datetime.datetime, datetime.date)):
            raise TypeError(f'Expected datetime.datetime or datetime.date, got {val.__class__.__name__}')

    def _create_literal(self, val: Any) -> Any:
        """
        Parse ISO-format strings into datetime.datetime; every other value is passed through unchanged.

        Raises:
            TypeError: if val is a string that is not in ISO format. create_literal() documents TypeError as
                its failure mode, so the parser's ValueError is translated rather than allowed to escape.
        """
        if isinstance(val, str):
            try:
                return datetime.datetime.fromisoformat(val)
            except ValueError as e:
                raise TypeError(f'Expected ISO-format timestamp string, got {val!r}') from e
        return val
507
+
508
+
509
class JsonType(ColumnType):
    """
    Column type for JSON-serializable dicts and lists, stored as Postgres JSONB.

    `type_spec` optionally maps field names to column types; it is serialized and compared, but not used
    for validating literals.
    """

    # TODO: type_spec also needs to be able to express lists
    def __init__(self, type_spec: Optional[Dict[str, ColumnType]] = None, nullable: bool = False):
        super().__init__(self.Type.JSON, nullable=nullable)
        self.type_spec = type_spec

    def _as_dict(self) -> Dict:
        """Add 'type_spec' (field name -> serialized field type) to the base dict, if a type_spec is set."""
        result = super()._as_dict()
        if self.type_spec is not None:
            type_spec_dict = {field_name: field_type.serialize() for field_name, field_type in self.type_spec.items()}
            result.update({'type_spec': type_spec_dict})
        return result

    @classmethod
    def _from_dict(cls, d: Dict) -> ColumnType:
        """Inverse of _as_dict(): each field type is stored as a serialized (JSON string) column type."""
        type_spec = None
        if 'type_spec' in d:
            type_spec = {
                field_name: cls.deserialize(field_type_str) for field_name, field_type_str in d['type_spec'].items()
            }
        return cls(type_spec, nullable=d['nullable'])

    def to_sa_type(self) -> sql.types.TypeEngine:
        return sql.dialects.postgresql.JSONB()

    def print_value(self, val: Any) -> str:
        """Print val using the printer of its inferred literal type, falling back to str()."""
        val_type = self.infer_literal_type(val)
        if val_type is None:
            return super().print_value(val)
        if val_type == self:
            return str(val)
        return val_type.print_value(val)

    def _validate_literal(self, val: Any) -> None:
        """
        Raise TypeError unless val is a dict or list that json.dumps() can serialize.
        """
        if not isinstance(val, (dict, list)):
            raise TypeError(f'Expected dict or list, got {val.__class__.__name__}')
        try:
            json.dumps(val)
        except (TypeError, ValueError) as e:
            # json.dumps() raises TypeError for unsupported objects and ValueError for circular references;
            # both mean the value isn't a valid JSON literal
            raise TypeError(f'Expected JSON-serializable object, got {val}') from e

    def _create_literal(self, val: Any) -> Any:
        """Convert tuples to lists; every other value is passed through unchanged."""
        if isinstance(val, tuple):
            val = list(val)
        return val
554
+
555
+
556
class ArrayType(ColumnType):
    """
    Column type for numpy arrays with a fixed number of dimensions, stored as SQL LargeBinary.

    `shape` holds one entry per dimension; None is a wildcard that accepts any extent. `dtype` holds the
    Type tag of the element type (int, float, bool or string).
    """

    def __init__(
            self, shape: Tuple[Union[int, None], ...], dtype: ColumnType, nullable: bool = False):
        super().__init__(self.Type.ARRAY, nullable=nullable)
        self.shape = shape
        assert dtype.is_int_type() or dtype.is_float_type() or dtype.is_bool_type() or dtype.is_string_type()
        # only the element type's tag is kept, not the ColumnType instance
        self.dtype = dtype._type

    @classmethod
    def _supertype(cls, type1: ArrayType, type2: ArrayType) -> Optional[ArrayType]:
        """
        Return the narrowest array type that covers both operands, or None if there is none.

        The operands need the same number of dimensions and element types with a common supertype.
        Dimensions on which they disagree become wildcards. The result is nullable if either operand is.
        """
        if len(type1.shape) != len(type2.shape):
            return None
        # self.dtype holds Type tags, so the tag-level supertype lookup is the one that applies here
        base_type = cls.Type.supertype(type1.dtype, type2.dtype, cls.common_supertypes)
        if base_type is None:
            return None
        shape = tuple(n1 if n1 == n2 else None for n1, n2 in zip(type1.shape, type2.shape))
        return cls(shape, cls.make_type(base_type), nullable=type1.nullable or type2.nullable)

    def _as_dict(self) -> Dict:
        """Add 'shape' (as a list) and 'dtype' (the integer value of the element Type tag) to the base dict."""
        result = super()._as_dict()
        result.update(shape=list(self.shape), dtype=self.dtype.value)
        return result

    def __str__(self) -> str:
        return f'{self._type.name.lower()}({self.shape}, dtype={self.dtype.name})'

    @classmethod
    def _from_dict(cls, d: Dict) -> ColumnType:
        """Inverse of _as_dict()."""
        assert 'shape' in d
        assert 'dtype' in d
        shape = tuple(d['shape'])
        dtype = cls.make_type(cls.Type(d['dtype']))
        return cls(shape, dtype, nullable=d['nullable'])

    @classmethod
    def from_literal(cls, val: np.ndarray) -> Optional[ArrayType]:
        """Return a nullable array type with val's exact shape, or None if val's dtype isn't supported."""
        # determine our dtype
        assert isinstance(val, np.ndarray)
        if np.issubdtype(val.dtype, np.integer):
            dtype = IntType()
        elif np.issubdtype(val.dtype, np.floating):
            dtype = FloatType()
        elif val.dtype == np.bool_:
            dtype = BoolType()
        elif np.issubdtype(val.dtype, np.str_):
            # string dtypes carry their width ('<U5'), so they never compare equal to np.str_ itself
            dtype = StringType()
        else:
            return None
        return cls(val.shape, dtype=dtype, nullable=True)

    def is_valid_literal(self, val: np.ndarray) -> bool:
        """Return True if val is an ndarray whose shape and dtype are compatible with this type."""
        if not isinstance(val, np.ndarray):
            return False
        if len(val.shape) != len(self.shape):
            return False
        # check that the shapes are compatible
        for n1, n2 in zip(val.shape, self.shape):
            if n1 is None:
                return False
            if n2 is None:
                # wildcard
                continue
            if n1 != n2:
                return False
        if self.dtype == self.Type.STRING:
            # any string width is acceptable; an exact dtype comparison would require width 0
            return bool(np.issubdtype(val.dtype, np.str_))
        return val.dtype == self.numpy_dtype()

    def _validate_literal(self, val: Any) -> None:
        """Raise TypeError unless val is an ndarray that passes is_valid_literal()."""
        if not isinstance(val, np.ndarray):
            raise TypeError(f'Expected numpy.ndarray, got {val.__class__.__name__}')
        if not self.is_valid_literal(val):
            raise TypeError((
                f'Expected ndarray({self.shape}, dtype={self.numpy_dtype()}), '
                f'got ndarray({val.shape}, dtype={val.dtype})'))

    def _create_literal(self, val: Any) -> Any:
        """Convert lists and tuples to an ndarray of this type's numpy dtype; other values pass through."""
        if isinstance(val, (list,tuple)):
            # map python float to whichever numpy float is
            # declared for this type, rather than assume float64
            return np.array(val, dtype=self.numpy_dtype())
        return val

    def to_sa_type(self) -> sql.types.TypeEngine:
        return sql.LargeBinary()

    def numpy_dtype(self) -> np.dtype:
        """Return the numpy dtype used for this type's elements (int64, float32, bool_ or str_)."""
        if self.dtype == self.Type.INT:
            return np.dtype(np.int64)
        if self.dtype == self.Type.FLOAT:
            return np.dtype(np.float32)
        if self.dtype == self.Type.BOOL:
            return np.dtype(np.bool_)
        if self.dtype == self.Type.STRING:
            return np.dtype(np.str_)
        assert False
649
+
650
+
651
class ImageType(ColumnType):
    """
    Column type for images, stored as SQL strings.

    `width`, `height` and `mode` are optional constraints; None means unspecified.
    """

    def __init__(
            self, width: Optional[int] = None, height: Optional[int] = None, size: Optional[Tuple[int, int]] = None,
            mode: Optional[str] = None, nullable: bool = False
    ):
        """
        `size` is shorthand for (width, height) and cannot be combined with either of them.

        TODO: does it make sense to specify only width or height?
        """
        super().__init__(self.Type.IMAGE, nullable=nullable)
        assert not(width is not None and size is not None)
        assert not(height is not None and size is not None)
        if size is not None:
            self.width = size[0]
            self.height = size[1]
        else:
            self.width = width
            self.height = height
        self.mode = mode

    def __str__(self) -> str:
        """Return 'image', followed by the specified constraints in parentheses (if there are any)."""
        params = []
        if self.width is not None:
            params.append(f'width={self.width}')
        if self.height is not None:
            params.append(f'height={self.height}')
        if self.mode is not None:
            params.append(f'mode={self.mode}')
        params_str = f'({", ".join(params)})' if params else ''
        return f'{self._type.name.lower()}{params_str}'

    def _is_supertype_of(self, other: ImageType) -> bool:
        """
        Return True if self has the same mode as `other` and places no constraint on the size.

        This is only consulted by is_supertype_of() for image types that don't already match.
        """
        if self.mode != other.mode:
            return False
        # NOTE(review): a type that constrains only one dimension is not treated as a supertype of a type
        # that agrees on that dimension -- confirm whether per-dimension wildcards are wanted here.
        return self.width is None and self.height is None

    @property
    def size(self) -> Optional[Tuple[int, int]]:
        """(width, height), or None unless both are specified."""
        if self.width is None or self.height is None:
            return None
        return (self.width, self.height)

    @staticmethod
    def _pil_mode(mode: Any) -> Any:
        """Return the PIL mode name for `mode`, which is either that name itself or an object with to_pil()."""
        return mode.to_pil() if hasattr(mode, 'to_pil') else mode

    @property
    def num_channels(self) -> Optional[int]:
        """Number of bands of this type's mode, or None if the mode is unspecified."""
        if self.mode is None:
            return None
        if hasattr(self.mode, 'num_channels'):
            # mode objects that know their own channel count
            return self.mode.num_channels()
        # mode is a PIL mode name such as 'RGB'
        return PIL.Image.getmodebands(self.mode)

    def _as_dict(self) -> Dict:
        """Add 'width', 'height' and 'mode' to the base dict."""
        result = super()._as_dict()
        result.update(width=self.width, height=self.height, mode=self.mode)
        return result

    @classmethod
    def _from_dict(cls, d: Dict) -> ColumnType:
        """Inverse of _as_dict()."""
        assert 'width' in d
        assert 'height' in d
        assert 'mode' in d
        return cls(width=d['width'], height=d['height'], mode=d['mode'], nullable=d['nullable'])

    def conversion_fn(self, target: ColumnType) -> Optional[Callable[[Any], Any]]:
        """
        Return a function that resizes and/or converts an image to satisfy `target`.

        Returns None if `target` isn't an image type or constrains only one dimension, and no_conversion if
        this type already satisfies every constraint of `target`.
        """
        if not target.is_image_type():
            return None
        assert isinstance(target, ImageType)
        if (target.width is None) != (target.height is None):
            # we can't resize only one dimension
            return None
        if (target.width == self.width or target.width is None) \
                and (target.height == self.height or target.height is None) \
                and (target.mode == self.mode or target.mode is None):
            # nothing to do
            return self.no_conversion

        # an unspecified target size/mode is not a constraint, so it must not trigger a resize/convert
        needs_resize = target.width is not None and (self.width != target.width or self.height != target.height)
        needs_mode_change = target.mode is not None and self.mode != target.mode

        def convert(img: PIL.Image.Image) -> PIL.Image.Image:
            if needs_resize:
                img = img.resize((target.width, target.height))
            if needs_mode_change:
                img = img.convert(self._pil_mode(target.mode))
            return img
        return convert

    def to_sa_type(self) -> sql.types.TypeEngine:
        return sql.String()

    def _validate_literal(self, val: Any) -> None:
        """Accept in-memory PIL images; anything else has to pass the file path check."""
        if isinstance(val, PIL.Image.Image):
            return
        self._validate_file_path(val)

    def validate_media(self, val: Any) -> None:
        """Raise excs.Error if PIL cannot identify the file at `val` as an image."""
        assert isinstance(val, str)
        try:
            # open as a context manager so that the underlying file is closed again
            with PIL.Image.open(val):
                pass
        except PIL.UnidentifiedImageError:
            raise excs.Error(f'Not a valid image: {val}') from None
752
+
753
+
754
class VideoType(ColumnType):
    """Column type for videos."""

    def __init__(self, nullable: bool = False):
        tag = self.Type.VIDEO
        super().__init__(tag, nullable=nullable)

    def to_sa_type(self) -> sql.types.TypeEngine:
        # stored as a file path
        return sql.String()

    def _validate_literal(self, val: Any) -> None:
        self._validate_file_path(val)

    def validate_media(self, val: Any) -> None:
        """
        Raise excs.Error if `val` cannot be opened as a video with at least two decodable frames.
        """
        assert isinstance(val, str)
        max_frames = 10
        try:
            with av.open(val, 'r') as container:
                if len(container.streams.video) == 0:
                    raise excs.Error(f'Not a valid video: {val}')
                # decode a few frames to make sure it's playable
                # TODO: decode all frames? but that's very slow
                num_decoded = 0
                for num_decoded, frame in enumerate(container.decode(video=0), start=1):
                    frame.to_image()
                    if num_decoded == max_frames:
                        break
                if num_decoded < 2:
                    # this is most likely an image file
                    raise excs.Error(f'Not a valid video: {val}')
        except av.AVError:
            raise excs.Error(f'Not a valid video: {val}') from None
784
+
785
+
786
class AudioType(ColumnType):
    """Column type for audio files."""

    def __init__(self, nullable: bool = False):
        tag = self.Type.AUDIO
        super().__init__(tag, nullable=nullable)

    def to_sa_type(self) -> sql.types.TypeEngine:
        # stored as a file path
        return sql.String()

    def _validate_literal(self, val: Any) -> None:
        self._validate_file_path(val)

    def validate_media(self, val: Any) -> None:
        """
        Raise excs.Error if `val` has no audio stream or its first audio stream cannot be decoded completely.
        """
        try:
            with av.open(val) as container:
                audio_streams = container.streams.audio
                if len(audio_streams) == 0:
                    raise excs.Error(f'No audio stream in file: {val}')
                first_stream = audio_streams[0]

                # decode everything to make sure it's playable
                # TODO: is there some way to verify it's a playable audio file other than decoding all of it?
                for packet in container.demux(first_stream):
                    for _frame in packet.decode():
                        continue
        except av.AVError as e:
            raise excs.Error(f'Not a valid audio file: {val}\n{e}') from None
811
+
812
+
813
class DocumentType(ColumnType):
    """Column type for documents, optionally restricted to a subset of DocumentFormat."""

    @enum.unique
    class DocumentFormat(enum.Enum):
        HTML = 0
        MD = 1
        PDF = 2

    def __init__(self, nullable: bool = False, doc_formats: Optional[str] = None):
        """
        Args:
            nullable: whether None is a valid value.
            doc_formats: comma-separated DocumentFormat names (case-insensitive, surrounding whitespace is
                ignored); None selects every format.

        Raises:
            ValueError: if doc_formats contains a name that isn't a DocumentFormat member.
        """
        super().__init__(self.Type.DOCUMENT, nullable=nullable)
        if doc_formats is not None:
            self._doc_formats = []
            for type_str in doc_formats.split(','):
                # look the name up among the enum members only, using the same normalization for the
                # check and the lookup
                doc_format = self.DocumentFormat.__members__.get(type_str.strip().upper())
                if doc_format is None:
                    raise ValueError(f'Invalid document type: {type_str}')
                self._doc_formats.append(doc_format)
        else:
            self._doc_formats = list(self.DocumentFormat)

    def to_sa_type(self) -> sql.types.TypeEngine:
        # stored as a file path
        return sql.String()

    def _validate_literal(self, val: Any) -> None:
        self._validate_file_path(val)

    def validate_media(self, val: Any) -> None:
        """Raise excs.Error if get_document_handle() fails on `val` or doesn't recognize its format."""
        assert isinstance(val, str)
        from pixeltable.utils.documents import get_document_handle
        try:
            dh = get_document_handle(val)
        except Exception:
            # any failure while opening/parsing is reported as an unrecognized format
            raise excs.Error(f'Not a recognized document format: {val}') from None
        if dh is None:
            raise excs.Error(f'Not a recognized document format: {val}')