pixeltable 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (139) hide show
  1. pixeltable/__init__.py +34 -6
  2. pixeltable/catalog/__init__.py +13 -0
  3. pixeltable/catalog/catalog.py +159 -0
  4. pixeltable/catalog/column.py +200 -0
  5. pixeltable/catalog/dir.py +32 -0
  6. pixeltable/catalog/globals.py +33 -0
  7. pixeltable/catalog/insertable_table.py +191 -0
  8. pixeltable/catalog/named_function.py +36 -0
  9. pixeltable/catalog/path.py +58 -0
  10. pixeltable/catalog/path_dict.py +139 -0
  11. pixeltable/catalog/schema_object.py +39 -0
  12. pixeltable/catalog/table.py +581 -0
  13. pixeltable/catalog/table_version.py +749 -0
  14. pixeltable/catalog/table_version_path.py +133 -0
  15. pixeltable/catalog/view.py +203 -0
  16. pixeltable/client.py +520 -30
  17. pixeltable/dataframe.py +540 -349
  18. pixeltable/env.py +373 -45
  19. pixeltable/exceptions.py +12 -21
  20. pixeltable/exec/__init__.py +9 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +113 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +95 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +69 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +225 -0
  31. pixeltable/exprs/__init__.py +24 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +105 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +187 -0
  39. pixeltable/exprs/expr.py +586 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +380 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +115 -0
  44. pixeltable/exprs/image_similarity_predicate.py +58 -0
  45. pixeltable/exprs/inline_array.py +107 -0
  46. pixeltable/exprs/inline_dict.py +101 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +54 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +355 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/type_cast.py +53 -0
  56. pixeltable/exprs/variable.py +45 -0
  57. pixeltable/func/__init__.py +9 -0
  58. pixeltable/func/aggregate_function.py +194 -0
  59. pixeltable/func/batched_function.py +53 -0
  60. pixeltable/func/callable_function.py +69 -0
  61. pixeltable/func/expr_template_function.py +82 -0
  62. pixeltable/func/function.py +110 -0
  63. pixeltable/func/function_registry.py +227 -0
  64. pixeltable/func/globals.py +36 -0
  65. pixeltable/func/nos_function.py +202 -0
  66. pixeltable/func/signature.py +166 -0
  67. pixeltable/func/udf.py +163 -0
  68. pixeltable/functions/__init__.py +52 -103
  69. pixeltable/functions/eval.py +216 -0
  70. pixeltable/functions/fireworks.py +61 -0
  71. pixeltable/functions/huggingface.py +120 -0
  72. pixeltable/functions/image.py +16 -0
  73. pixeltable/functions/openai.py +88 -0
  74. pixeltable/functions/pil/image.py +148 -7
  75. pixeltable/functions/string.py +13 -0
  76. pixeltable/functions/together.py +27 -0
  77. pixeltable/functions/util.py +41 -0
  78. pixeltable/functions/video.py +62 -0
  79. pixeltable/iterators/__init__.py +3 -0
  80. pixeltable/iterators/base.py +48 -0
  81. pixeltable/iterators/document.py +311 -0
  82. pixeltable/iterators/video.py +89 -0
  83. pixeltable/metadata/__init__.py +54 -0
  84. pixeltable/metadata/converters/convert_10.py +18 -0
  85. pixeltable/metadata/schema.py +211 -0
  86. pixeltable/plan.py +656 -0
  87. pixeltable/store.py +413 -182
  88. pixeltable/tests/conftest.py +143 -87
  89. pixeltable/tests/test_audio.py +65 -0
  90. pixeltable/tests/test_catalog.py +27 -0
  91. pixeltable/tests/test_client.py +14 -14
  92. pixeltable/tests/test_component_view.py +372 -0
  93. pixeltable/tests/test_dataframe.py +433 -0
  94. pixeltable/tests/test_dirs.py +78 -62
  95. pixeltable/tests/test_document.py +117 -0
  96. pixeltable/tests/test_exprs.py +591 -135
  97. pixeltable/tests/test_function.py +297 -67
  98. pixeltable/tests/test_functions.py +283 -1
  99. pixeltable/tests/test_migration.py +43 -0
  100. pixeltable/tests/test_nos.py +54 -0
  101. pixeltable/tests/test_snapshot.py +208 -0
  102. pixeltable/tests/test_table.py +1085 -262
  103. pixeltable/tests/test_transactional_directory.py +42 -0
  104. pixeltable/tests/test_types.py +5 -11
  105. pixeltable/tests/test_video.py +149 -34
  106. pixeltable/tests/test_view.py +530 -0
  107. pixeltable/tests/utils.py +186 -45
  108. pixeltable/tool/create_test_db_dump.py +149 -0
  109. pixeltable/type_system.py +490 -126
  110. pixeltable/utils/__init__.py +17 -46
  111. pixeltable/utils/clip.py +12 -15
  112. pixeltable/utils/coco.py +136 -0
  113. pixeltable/utils/documents.py +39 -0
  114. pixeltable/utils/filecache.py +195 -0
  115. pixeltable/utils/help.py +11 -0
  116. pixeltable/utils/media_store.py +76 -0
  117. pixeltable/utils/parquet.py +126 -0
  118. pixeltable/utils/pytorch.py +172 -0
  119. pixeltable/utils/s3.py +13 -0
  120. pixeltable/utils/sql.py +17 -0
  121. pixeltable/utils/transactional_directory.py +35 -0
  122. pixeltable-0.2.0.dist-info/LICENSE +18 -0
  123. pixeltable-0.2.0.dist-info/METADATA +117 -0
  124. pixeltable-0.2.0.dist-info/RECORD +125 -0
  125. {pixeltable-0.1.1.dist-info → pixeltable-0.2.0.dist-info}/WHEEL +1 -1
  126. pixeltable/catalog.py +0 -1421
  127. pixeltable/exprs.py +0 -1745
  128. pixeltable/function.py +0 -269
  129. pixeltable/functions/clip.py +0 -10
  130. pixeltable/functions/pil/__init__.py +0 -23
  131. pixeltable/functions/tf.py +0 -21
  132. pixeltable/index.py +0 -57
  133. pixeltable/tests/test_dict.py +0 -24
  134. pixeltable/tests/test_tf.py +0 -69
  135. pixeltable/tf.py +0 -33
  136. pixeltable/utils/tf.py +0 -33
  137. pixeltable/utils/video.py +0 -32
  138. pixeltable-0.1.1.dist-info/METADATA +0 -31
  139. pixeltable-0.1.1.dist-info/RECORD +0 -36
pixeltable/exprs.py DELETED
@@ -1,1745 +0,0 @@
1
- import abc
2
- import copy
3
- import datetime
4
- import enum
5
- import inspect
6
- import typing
7
- from typing import Union, Optional, List, Callable, Any, Dict, Tuple, Set, Generator, Iterator
8
- import operator
9
- import json
10
- import io
11
-
12
- import PIL.Image
13
- import jmespath
14
- import numpy as np
15
- import sqlalchemy as sql
16
-
17
- from pixeltable import catalog
18
- from pixeltable.type_system import \
19
- ColumnType, InvalidType, StringType, IntType, FloatType, BoolType, TimestampType, ImageType, JsonType, ArrayType
20
- from pixeltable.function import Function
21
- from pixeltable import exceptions as exc
22
- from pixeltable.utils import clip
23
-
24
# The Python value types that are accepted wherever a literal can be used in an expression
# (see Expr._make_comparison() and Expr._make_arithmetic_expr()).
LiteralPythonTypes = Union[
    str,
    int,
    float,
    bool,
    datetime.datetime,
    datetime.date,
]
26
-
27
-
28
class ComparisonOperator(enum.Enum):
    """Binary comparison operators; str() yields the corresponding Python operator symbol."""
    LT = 0
    LE = 1
    EQ = 2
    NE = 3
    GT = 4
    GE = 5

    def __str__(self) -> str:
        """
        Return the operator symbol for this member (e.g., '<=' for LE).

        Every member has a symbol, including NE: returning None from __str__() would make str() raise a TypeError.
        """
        # keyed by member name and built inside the method: a dict assigned in the class body would itself be
        # turned into an enum member
        symbols = {'LT': '<', 'LE': '<=', 'EQ': '==', 'NE': '!=', 'GT': '>', 'GE': '>='}
        return symbols[self.name]
47
-
48
-
49
class LogicalOperator(enum.Enum):
    """Logical connectives: conjunction, disjunction and negation."""
    AND = 0  # conjunction
    OR = 1  # disjunction
    NOT = 2  # negation
53
-
54
-
55
class ArithmeticOperator(enum.Enum):
    """Binary arithmetic operators; str() yields the corresponding Python operator symbol."""
    ADD = 0
    SUB = 1
    MUL = 2
    DIV = 3
    MOD = 4

    def __str__(self) -> str:
        """Return the operator symbol for this member (e.g., '+' for ADD)."""
        # the member values are 0..4, so they double as indices into the symbol table
        symbols = ('+', '-', '*', '/', '%')
        return symbols[self.value]
73
-
74
-
75
class ExprScope:
    """
    The scope in which an Expr needs to be evaluated; scopes form a chain via 'parent', which is used to
    determine how scopes are nested.
    A scope without a parent (parent is None) is the outermost scope.
    """
    def __init__(self, parent: Optional['ExprScope']):
        self.parent = parent

    def is_contained_in(self, other: 'ExprScope') -> bool:
        """Return True if 'other' is this scope or one of its ancestors."""
        # walk up the parent chain, starting with ourselves
        current: Optional['ExprScope'] = self
        while current is not None:
            if current == other:
                return True
            current = current.parent
        return False


# the outermost scope, which has no parent
_GLOBAL_SCOPE = ExprScope(None)
92
-
93
-
94
class Expr(abc.ABC):
    """
    Abstract base class of all expressions.

    Rules for using state in subclasses:
    - all state except for components and data/sql_row_idx is shared between copies of an Expr
    - data/sql_row_idx is set during analysis (DataFrame.show())
    - during eval(), components can only be accessed via self.components; any Exprs outside of that won't
      have data_row_idx set

    The comparison operators (including == and !=) and the arithmetic operators are overloaded to construct
    Comparison/ArithmeticExpr instances; use equals() to test two Exprs for structural equality.
    """
    def __init__(self, col_type: ColumnType):
        self.col_type = col_type
        # index of the expr's value in the data row; set for all materialized exprs; -1: invalid
        # not set for subexprs that don't need to be materialized because the parent can be materialized via SQL
        self.data_row_idx = -1
        # index of the expr's value in the SQL row; only set for exprs that can be materialized in SQL; -1: invalid
        self.sql_row_idx = -1
        self.components: List[Expr] = []  # the subexprs that are needed to construct this expr

    def dependencies(self) -> List['Expr']:
        """
        Returns all exprs that need to have been evaluated before eval() can be called on this one.
        """
        return self.components

    def scope(self) -> ExprScope:
        """
        Returns the scope in which this expr needs to be evaluated.
        By default this is the innermost scope of any of our components.
        """
        result = _GLOBAL_SCOPE
        for c in self.components:
            c_scope = c.scope()
            if c_scope.is_contained_in(result):
                result = c_scope
        return result

    def bind_rel_paths(self, mapper: Optional['JsonMapper'] = None) -> None:
        """
        Binds relative JsonPaths to mapper.
        This needs to be done in a separate phase after __init__(), because RelativeJsonPath()(-1) cannot be resolved
        by the immediately containing JsonMapper during initialization.
        """
        for c in self.components:
            c.bind_rel_paths(mapper)

    def display_name(self) -> str:
        """
        Displayed column name in DataFrame. '': assigned by DataFrame
        """
        return ''

    def equals(self, other: 'Expr') -> bool:
        """
        Subclass-specific comparison. Implemented as a function because __eq__() is needed to construct Comparisons.

        Two exprs are equal if they are instances of the exact same class, their components are pairwise equal
        and the subclass-specific _equals() holds.
        """
        if type(self) is not type(other):
            return False
        if not self.list_equals(self.components, other.components):
            return False
        return self._equals(other)

    @classmethod
    def list_equals(cls, a: List['Expr'], b: List['Expr']) -> bool:
        """Returns True if both lists have the same length and their elements are pairwise equals()."""
        if len(a) != len(b):
            return False
        return all(x.equals(y) for x, y in zip(a, b))

    def copy(self) -> 'Expr':
        """
        Creates a copy that can be evaluated separately: it doesn't share any eval context (data/sql_row_idx)
        but shares everything else (catalog objects, etc.)
        """
        cls = self.__class__
        # bypass __init__(): all state is carried over from self and then selectively reset
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        result.data_row_idx = -1
        result.sql_row_idx = -1
        result.components = [c.copy() for c in self.components]
        return result

    @classmethod
    def copy_list(cls, expr_list: List['Expr']) -> List['Expr']:
        """Returns a list containing a copy() of each expr in expr_list."""
        return [e.copy() for e in expr_list]

    def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> 'Expr':
        """
        Hook for copy.deepcopy(), which supplies the memo dict.
        memo defaults to None rather than {}: a mutable default would be shared by all direct calls and keep
        every copy made through it alive.
        """
        # we don't need to create an actual deep copy because all state other than execution state is read-only
        result = self.copy()
        if memo is not None:
            memo[id(self)] = result
        return result

    def subexprs(self) -> Generator['Expr', None, None]:
        """
        Iterate over all subexprs, including self (components are produced before the expr that contains them).
        """
        for c in self.components:
            yield from c.subexprs()
        yield self

    @classmethod
    def list_subexprs(cls, expr_list: List['Expr']) -> Generator['Expr', None, None]:
        """
        Produce subexprs for all exprs in list.
        """
        for e in expr_list:
            yield from e.subexprs()

    @classmethod
    def from_object(cls, o: object) -> Optional['Expr']:
        """
        Try to turn a literal object into an Expr.
        Exprs are returned as-is, dicts become InlineDicts and lists become InlineArrays; returns None for
        everything else.
        """
        if isinstance(o, Expr):
            return o
        if isinstance(o, dict):
            return InlineDict(o)
        if isinstance(o, list):
            return InlineArray(tuple(o))
        return None

    @abc.abstractmethod
    def _equals(self, other: 'Expr') -> bool:
        """Subclass-specific part of equals(); only called with an 'other' of the same class as self."""
        pass

    @abc.abstractmethod
    def sql_expr(self) -> Optional[sql.sql.expression.ClauseElement]:
        """
        If this expr can be materialized directly in SQL:
        - returns a ClauseElement
        - eval() will not be called (exception: Literal)
        Otherwise
        - returns None
        - eval() will be called
        """
        pass

    @abc.abstractmethod
    def eval(self, data_row: List[Any]) -> None:
        """
        Compute the expr value for data_row and store the result in data_row[data_row_idx].
        Not called if sql_expr() != None (exception: Literal).
        """
        pass

    def serialize(self) -> str:
        """Returns the JSON representation of as_dict()."""
        return json.dumps(self.as_dict())

    def as_dict(self) -> Dict:
        """
        Turn Expr object into a dict that can be passed to json.dumps().
        Subclasses override _as_dict().
        """
        return {
            '_classname': self.__class__.__name__,
            **self._as_dict(),
        }

    @classmethod
    def as_dict_list(cls, expr_list: List['Expr']) -> List[Dict]:
        """Returns as_dict() for each expr in expr_list."""
        return [e.as_dict() for e in expr_list]

    def _as_dict(self) -> Dict:
        """Subclass-specific part of as_dict(); the default records the components, if there are any."""
        if len(self.components) > 0:
            return {'components': [c.as_dict() for c in self.components]}
        return {}

    @classmethod
    def deserialize(cls, dict_str: str, t: catalog.Table) -> 'Expr':
        """Inverse of serialize(): parses dict_str as JSON and passes the result to from_dict()."""
        return cls.from_dict(json.loads(dict_str), t)

    @classmethod
    def from_dict(cls, d: Dict, t: catalog.Table) -> 'Expr':
        """
        Turn dict that was produced by calling Expr.as_dict() into an instance of the correct Expr subclass.
        """
        assert '_classname' in d
        # NOTE(review): the class is looked up by name among this module's globals, so d must come from a
        # trusted source (i.e., it must have been produced by as_dict())
        type_class = globals()[d['_classname']]
        components: List[Expr] = []
        if 'components' in d:
            components = [cls.from_dict(component_dict, t) for component_dict in d['components']]
        return type_class._from_dict(d, components, t)

    @classmethod
    def from_dict_list(cls, dict_list: List[Dict], t: catalog.Table) -> List['Expr']:
        """Returns from_dict() for each dict in dict_list."""
        return [cls.from_dict(d, t) for d in dict_list]

    @classmethod
    def _from_dict(cls, d: Dict, components: List['Expr'], t: catalog.Table) -> 'Expr':
        """Subclass-specific part of from_dict(); components are the already-deserialized d['components']."""
        assert False, 'not implemented'

    def __getitem__(self, index: object) -> 'Expr':
        """Subscripting: creates a JsonPath for json-typed exprs and an ArraySlice for array-typed exprs."""
        if self.col_type.is_json_type():
            return JsonPath(self).__getitem__(index)
        if self.col_type.is_array_type():
            return ArraySlice(self, index)
        raise exc.Error(f'Type {self.col_type} is not subscriptable')

    def __getattr__(self, name: str) -> 'ImageMemberAccess':
        """
        ex.: <img col>.rotate(60)
        """
        if name == 'col_type':
            # __getattr__() is only invoked for col_type when the instance doesn't have one yet (e.g., the blank
            # instance that copy.copy() probes for hooks); reading self.col_type below would then recurse forever
            raise AttributeError(name)
        if not self.col_type.is_image_type():
            raise exc.OperationalError(f'Member access not supported on type {self.col_type}: {name}')
        return ImageMemberAccess(name, self)

    def __lt__(self, other: object) -> 'Comparison':
        return self._make_comparison(ComparisonOperator.LT, other)

    def __le__(self, other: object) -> 'Comparison':
        return self._make_comparison(ComparisonOperator.LE, other)

    def __eq__(self, other: object) -> 'Comparison':
        return self._make_comparison(ComparisonOperator.EQ, other)

    def __ne__(self, other: object) -> 'Comparison':
        return self._make_comparison(ComparisonOperator.NE, other)

    def __gt__(self, other: object) -> 'Comparison':
        return self._make_comparison(ComparisonOperator.GT, other)

    def __ge__(self, other: object) -> 'Comparison':
        return self._make_comparison(ComparisonOperator.GE, other)

    def _make_comparison(self, op: ComparisonOperator, other: object) -> 'Comparison':
        """
        Creates the Comparison 'self <op> other'.
        other: Union[Expr, LiteralPythonTypes]; literals are wrapped in a Literal expr.
        Raises TypeError if other is neither an Expr nor a literal.
        """
        # TODO: check for compatibility
        if isinstance(other, Expr):
            return Comparison(op, self, other)
        if isinstance(other, typing.get_args(LiteralPythonTypes)):
            return Comparison(op, self, Literal(other))  # type: ignore[arg-type]
        raise TypeError(f'Other must be Expr or literal: {type(other)}')

    def __add__(self, other: object) -> 'ArithmeticExpr':
        return self._make_arithmetic_expr(ArithmeticOperator.ADD, other)

    def __sub__(self, other: object) -> 'ArithmeticExpr':
        return self._make_arithmetic_expr(ArithmeticOperator.SUB, other)

    def __mul__(self, other: object) -> 'ArithmeticExpr':
        return self._make_arithmetic_expr(ArithmeticOperator.MUL, other)

    def __truediv__(self, other: object) -> 'ArithmeticExpr':
        return self._make_arithmetic_expr(ArithmeticOperator.DIV, other)

    def __mod__(self, other: object) -> 'ArithmeticExpr':
        return self._make_arithmetic_expr(ArithmeticOperator.MOD, other)

    def _make_arithmetic_expr(self, op: ArithmeticOperator, other: object) -> 'ArithmeticExpr':
        """
        Creates the ArithmeticExpr 'self <op> other'.
        other: Union[Expr, LiteralPythonTypes]; literals are wrapped in a Literal expr.
        Raises TypeError if other is neither an Expr nor a literal.
        """
        # TODO: check for compatibility
        if isinstance(other, Expr):
            return ArithmeticExpr(op, self, other)
        if isinstance(other, typing.get_args(LiteralPythonTypes)):
            return ArithmeticExpr(op, self, Literal(other))  # type: ignore[arg-type]
        raise TypeError(f'Other must be Expr or literal: {type(other)}')
354
-
355
-
356
class ColumnRef(Expr):
    """Reference to a table column; the expr has the type of the column and is materialized via SQL."""
    def __init__(self, col: catalog.Column):
        super().__init__(col.col_type)
        self.col = col

    def __getattr__(self, name: str) -> Expr:
        """Attribute access on a json column creates a JsonPath; everything else is handled by Expr."""
        if not self.col_type.is_json_type():
            return super().__getattr__(name)
        path = JsonPath(self)
        return path.__getattr__(name)

    def display_name(self) -> str:
        """The column's name."""
        return self.col.name

    def _equals(self, other: 'ColumnRef') -> bool:
        """Two ColumnRefs are equal if they reference the same column."""
        return self.col == other.col

    def sql_expr(self) -> Optional[sql.sql.expression.ClauseElement]:
        """The column's SQLAlchemy column."""
        return self.col.sa_col

    def eval(self, data_row: List[Any]) -> None:
        """Intentionally a no-op."""
        # we get called while materializing computed cols
        pass

    def _as_dict(self) -> Dict:
        """Only the column id is recorded; _from_dict() resolves it against the table."""
        return {'col_id': self.col.id}

    @classmethod
    def _from_dict(cls, d: Dict, components: List[Expr], t: catalog.Table) -> 'Expr':
        """Recreates the ColumnRef by looking up d['col_id'] in t.cols_by_id."""
        assert 'col_id' in d
        col = t.cols_by_id[d['col_id']]
        return cls(col)
386
-
387
-
388
class FunctionCall(Expr):
    """
    Call of a Function with a list of arguments, each of which is either an Expr or a plain Python value.

    Expr arguments are stored in self.components; self.args has one entry per argument, which is the plain
    value or None as a placeholder for an Expr argument (a plain argument value of None is therefore
    indistinguishable from an Expr placeholder).
    For aggregate functions, window() turns the call into a window function call.
    """
    def __init__(self, fn: Function, args: Optional[Tuple[Any, ...]] = None):
        """
        Raises exc.OperationalError if fn declares parameter types and the number of args doesn't match or
        an Expr arg cannot be converted to the declared type.
        """
        super().__init__(fn.return_type)
        self.fn = fn
        if args is None:
            # the declared default: call without arguments
            args = ()

        if fn.param_types is not None:
            # check if arg types match param types and convert values, if necessary
            if len(args) != len(fn.param_types):
                raise exc.OperationalError(
                    f"Number of arguments doesn't match parameter list: {args} vs {fn.param_types}")
            args = list(args)
            for i, (arg, param_type) in enumerate(zip(args, fn.param_types)):
                if not isinstance(arg, Expr):
                    # TODO: check non-Expr args
                    continue
                if arg.col_type == param_type:
                    # nothing to do
                    continue
                converter = arg.col_type.conversion_fn(param_type)
                if converter is None:
                    raise exc.OperationalError(f'Cannot convert {arg.col_type} to {param_type}')
                if converter == ColumnType.no_conversion:
                    # nothing to do
                    continue
                # wrap the arg in a call to the conversion function
                convert_fn = Function(param_type, [arg.col_type], eval_fn=converter)
                args[i] = FunctionCall(convert_fn, (arg,))

        self.components = [arg for arg in args if isinstance(arg, Expr)]
        self.args = [arg if not isinstance(arg, Expr) else None for arg in args]

        # window function state
        self.partition_by_idx = -1  # self.components[self.pb_index:] contains partition_by exprs
        self.order_by: List[Expr] = []
        # execution state for window functions
        self.aggregator: Optional[Any] = self.fn.init_fn() if self.fn.is_aggregate else None
        self.current_partition_vals: Optional[List[Any]] = None

    @property
    def _eval_fn(self) -> Optional[Callable]:
        return self.fn.eval_fn

    def _equals(self, other: 'FunctionCall') -> bool:
        """Compares function, plain argument values and window clause (components are compared by equals())."""
        if self.fn != other.fn:
            return False
        if len(self.args) != len(other.args):
            return False
        if any(a != b for a, b in zip(self.args, other.args)):
            return False
        if self.partition_by_idx != other.partition_by_idx:
            return False
        return self.list_equals(self.order_by, other.order_by)

    def window(
            self, partition_by: Optional[Union[Expr, List[Expr]]] = None,
            order_by: Optional[Union[Expr, List[Expr]]] = None
    ) -> 'FunctionCall':
        """
        Adds a window clause to this call (in place) and returns self.
        Raises exc.Error if the function isn't an aggregate function or if both parameters are None.
        """
        if not self.fn.is_aggregate:
            raise exc.Error(f'The window() clause is only allowed for aggregate functions')
        if partition_by is None and order_by is None:
            raise exc.Error('The window() clause requires at least one parameter not to be None')
        if partition_by is not None and not isinstance(partition_by, list):
            partition_by = [partition_by]
        if order_by is not None:
            self.order_by = order_by if isinstance(order_by, list) else [order_by]
        # we only need to record the partition_by exprs in self.components, because the order_by values aren't
        # used during evaluation (the SQL store will return rows in that order)
        if partition_by is not None:
            self.partition_by_idx = len(self.components)
            self.components.extend(partition_by)
        return self

    @property
    def partition_by(self) -> List[Expr]:
        """The partition_by exprs of the window clause; empty if there are none."""
        if self.partition_by_idx == -1:
            return []
        return self.components[self.partition_by_idx:]

    @property
    def is_window_fn_call(self) -> bool:
        """True for an aggregate function with a window clause."""
        return self.fn.is_aggregate and (self.partition_by_idx != -1 or len(self.order_by) > 0)

    def get_window_sort_exprs(self) -> List[Expr]:
        """The partition_by exprs followed by the order_by exprs."""
        return [*self.partition_by, *self.order_by]

    @property
    def is_agg_fn_call(self) -> bool:
        """True for an aggregate function without a window clause."""
        return self.fn.is_aggregate and self.partition_by_idx == -1 and len(self.order_by) == 0

    def get_agg_order_by(self) -> List[Expr]:
        """Returns the Expr arguments at the positions listed in self.fn.order_by."""
        assert self.is_agg_fn_call
        result: List[Expr] = []
        # components only contains the Expr args, so it needs its own index
        component_idx = 0
        for arg_idx, arg in enumerate(self.args):
            if arg_idx in self.fn.order_by:
                assert arg is None  # this is an Expr, not something else
                result.append(self.components[component_idx])
            if arg is None:
                component_idx += 1
        return result

    def sql_expr(self) -> Optional[sql.sql.expression.ClauseElement]:
        # TODO: implement for standard aggregate functions
        return None

    def reset_agg(self) -> None:
        """
        Init agg state
        """
        assert self.is_agg_fn_call
        self.aggregator = self.fn.init_fn()

    def update(self, data_row: List[Any]) -> None:
        """
        Update agg state
        """
        assert self.is_agg_fn_call
        args = self._make_args(data_row)
        self.fn.update_fn(self.aggregator, *args)

    def _make_args(self, data_row: List[Any]) -> List[Any]:
        """Returns the argument values for data_row: Expr placeholders are replaced by the exprs' values."""
        args = copy.copy(self.args)
        # fill in missing child values
        component_idx = 0
        for arg_idx, arg in enumerate(args):
            if arg is None:
                args[arg_idx] = data_row[self.components[component_idx].data_row_idx]
                component_idx += 1
        return args

    def eval(self, data_row: List[Any]) -> None:
        """
        Stores the result of the call in data_row[self.data_row_idx]:
        - regular function: the value returned by the eval function
        - window function: the aggregate value after updating it with the current row; the aggregator is reset
          whenever the partition_by values change
        - aggregate function: the current aggregate value (the state is maintained via reset_agg()/update())
        """
        args = self._make_args(data_row)
        if not self.fn.is_aggregate:
            data_row[self.data_row_idx] = self.fn.eval_fn(*args)
        elif self.is_window_fn_call:
            if self.partition_by_idx != -1:
                if self.current_partition_vals is None:
                    self.current_partition_vals = [None] * len(self.partition_by)
                partition_vals = [data_row[e.data_row_idx] for e in self.partition_by]
                if partition_vals != self.current_partition_vals:
                    # new partition
                    self.aggregator = self.fn.init_fn()
                    self.current_partition_vals = partition_vals
            elif self.aggregator is None:
                self.aggregator = self.fn.init_fn()
            self.fn.update_fn(self.aggregator, *args)
            data_row[self.data_row_idx] = self.fn.value_fn(self.aggregator)
        else:
            assert self.is_agg_fn_call
            data_row[self.data_row_idx] = self.fn.value_fn(self.aggregator)

    def _as_dict(self) -> Dict:
        result = {'fn': self.fn.as_dict(), 'args': self.args, **super()._as_dict()}
        if self.fn.is_aggregate:
            result.update({'partition_by_idx': self.partition_by_idx, 'order_by': Expr.as_dict_list(self.order_by)})
        return result

    @classmethod
    def _from_dict(cls, d: Dict, components: List[Expr], t: catalog.Table) -> 'Expr':
        """Recreates the call from the output of _as_dict() and the already-deserialized components."""
        assert 'fn' in d
        assert 'args' in d
        # reassemble args: components only contains the Expr args (followed by the partition_by exprs), so the
        # n-th placeholder in d['args'] corresponds to components[n], not to components[<position of the arg>]
        args = []
        component_idx = 0
        for arg in d['args']:
            if arg is None:
                args.append(components[component_idx])
                component_idx += 1
            else:
                args.append(arg)
        fn_call = cls(Function.from_dict(d['fn']), args)
        if fn_call.fn.is_aggregate:
            fn_call.partition_by_idx = d['partition_by_idx']
            if fn_call.partition_by_idx != -1:
                # without this check, components[-1:] would add the last component a second time
                fn_call.components.extend(components[fn_call.partition_by_idx:])
            fn_call.order_by = Expr.from_dict_list(d['order_by'], t)
        return fn_call
559
-
560
-
561
def _caller_return_type(caller: Expr, *args: object, **kwargs: object) -> ColumnType:
    """Return type of a method whose result has the same type as the expr it is called on."""
    # the call arguments don't influence the type
    result_type = caller.col_type
    return result_type
563
-
564
def _convert_return_type(caller: Expr, *args: object, **kwargs: object) -> ColumnType:
    """Return type of convert(): the caller's image type with the mode replaced by the requested one."""
    # the first argument is the PIL mode string
    mode_str = args[0]
    assert isinstance(mode_str, str)
    img_type = caller.col_type
    assert isinstance(img_type, ImageType)
    new_mode = ImageType.Mode.from_pil(mode_str)
    return ImageType(width=img_type.width, height=img_type.height, mode=new_mode)
570
-
571
def _crop_return_type(caller: Expr, *args: object, **kwargs: object) -> ColumnType:
    """Return type of crop(): an image with the dimensions of the crop box and the caller's mode."""
    # the first argument is the box (left, upper, right, lower)
    box = args[0]
    left, upper, right, lower = box
    img_type = caller.col_type
    assert isinstance(img_type, ImageType)
    crop_width = right - left
    crop_height = lower - upper
    return ImageType(width=crop_width, height=crop_height, mode=img_type.mode)
575
-
576
def _resize_return_type(caller: Expr, *args: object, **kwargs: object) -> ColumnType:
    """Return type of resize(): an image with the requested dimensions and the caller's mode."""
    # the first argument is the target size (width, height)
    size = args[0]
    w, h = size
    img_type = caller.col_type
    assert isinstance(img_type, ImageType)
    return ImageType(width=w, height=h, mode=img_type.mode)
580
-
581
# The PIL.Image.Image methods that can be called on image exprs.
# This only includes methods that return something that can be displayed in pixeltable
# and that make sense to call (counterexample: copy() doesn't make sense to call).
# This is hardcoded here instead of being dynamically extracted from the PIL type stubs because
# doing that is messy and it's unclear whether it has any advantages.
# Maps the method name to its return type, which is either
# - a ColumnType, if the return type is fixed
# - a function that computes the return type from the caller and the call arguments
# TODO: how to capture return values like List[Tuple[int, int]]?
# TODO: JsonTypes() where it should be ArrayType(): need to determine the shape and base type
_PIL_METHOD_INFO: Dict[str, Union[ColumnType, Callable]] = {
    'convert': _convert_return_type,
    'crop': _crop_return_type,
    'effect_spread': _caller_return_type,
    'entropy': FloatType(),
    'filter': _caller_return_type,
    'getbands': ArrayType((None,), ColumnType.Type.STRING),
    'getbbox': ArrayType((4,), ColumnType.Type.INT),
    'getchannel': _caller_return_type,
    'getcolors': JsonType(),
    'getextrema': JsonType(),
    'getpalette': JsonType(),
    'getpixel': JsonType(),
    'getprojection': JsonType(),
    'histogram': JsonType(),
    # TODO: what to do with this? it modifies the img in-place
    # paste: <ast.Constant object at 0x7f9e9a9be3a0>
    'point': _caller_return_type,
    'quantize': _caller_return_type,
    'reduce': _caller_return_type,  # TODO: this is incorrect
    'remap_palette': _caller_return_type,
    'resize': _resize_return_type,
    'rotate': _caller_return_type,  # TODO: this is incorrect
    # TODO: this returns a Tuple[Image], which we can't express
    # split: <ast.Subscript object at 0x7f9e9a9cc9d0>
    'transform': _caller_return_type,  # TODO: this is incorrect
    'transpose': _caller_return_type,  # TODO: this is incorrect
}
616
-
617
-
618
- # TODO: this doesn't dig up all attrs for actual jpeg images
619
def _create_pil_attr_info() -> Dict[str, ColumnType]:
    """
    Returns a dict from the name of each public data attribute (including properties) of PIL.Image.Image to
    its pixeltable type.

    The attributes are discovered by inspecting a newly-created image; attributes that are None for that image
    or whose value is neither a str, an int nor a dict are left out.
    """
    # create random Image to inspect for attrs
    img = PIL.Image.new('RGB', (100, 100))
    result: Dict[str, ColumnType] = {}
    for name in dir(img):
        # we're only interested in public attrs (including properties)
        if name.startswith('_'):
            continue
        # fetch the value once: properties are re-evaluated on every access
        val = getattr(img, name)
        if callable(val) or val is None:
            continue
        if isinstance(val, str):
            result[name] = StringType()
        elif isinstance(val, int):
            # this also covers bool values, because bool is a subclass of int
            result[name] = IntType()
        elif isinstance(val, dict):
            # an instance check is needed here: 'val is dict' is only true for the dict type itself
            result[name] = JsonType()
    return result
634
-
635
-
636
class ImageMemberAccess(Expr):
    """
    Access of a member of PIL.Image.Image, which is either an attribute or a method.
    Ex.: tbl.img_col_ref.rotate(90), tbl.img_col_ref.width

    Attribute accesses can be evaluated directly; method accesses need to be called (see __call__()) in order
    to produce an expr that returns a value.
    """
    attr_info = _create_pil_attr_info()
    special_img_predicates = ['nearest', 'matches']

    def __init__(self, member_name: str, caller: Expr):
        """Raises exc.OperationalError if member_name is not a known member."""
        is_method = member_name in self.special_img_predicates or member_name in _PIL_METHOD_INFO
        if is_method:
            super().__init__(InvalidType())  # requires FunctionCall to return value
        elif member_name in self.attr_info:
            super().__init__(self.attr_info[member_name])
        else:
            raise exc.OperationalError(f'Unknown Image member: {member_name}')
        self.member_name = member_name
        self.components = [caller]

    def display_name(self) -> str:
        """The name of the accessed member."""
        return self.member_name

    @property
    def _caller(self) -> Expr:
        """The expr whose member is accessed."""
        return self.components[0]

    def _as_dict(self) -> Dict:
        base_dict = super()._as_dict()
        return {'member_name': self.member_name, **base_dict}

    @classmethod
    def _from_dict(cls, d: Dict, components: List[Expr], t: catalog.Table) -> Expr:
        assert 'member_name' in d
        assert len(components) == 1
        caller = components[0]
        return cls(d['member_name'], caller)

    # TODO: correct signature?
    def __call__(self, *args, **kwargs) -> Union['ImageMethodCall', 'ImageSimilarityPredicate']:
        """
        Turns the member access into a call.
        'nearest' and 'matches' create an ImageSimilarityPredicate and raise exc.OperationalError if the caller
        isn't a ColumnRef or the arguments don't match the required signature; every other member creates an
        ImageMethodCall.
        """
        caller = self._caller
        arg_type_names = [type(arg).__name__ for arg in args]
        call_signature = f'({",".join(arg_type_names)})'
        is_col_ref = isinstance(caller, ColumnRef)
        has_single_arg = len(args) == 1

        if self.member_name == 'nearest':
            # - caller must be ColumnRef
            # - signature is (PIL.Image.Image)
            if not is_col_ref:
                raise exc.OperationalError(f'nearest(): caller must be an IMAGE column')
            if not (has_single_arg and isinstance(args[0], PIL.Image.Image)):
                raise exc.OperationalError(
                    f'nearest(): required signature is (PIL.Image.Image) (passed: {call_signature})')
            return ImageSimilarityPredicate(caller, img=args[0])

        if self.member_name == 'matches':
            # - caller must be ColumnRef
            # - signature is (str)
            if not is_col_ref:
                raise exc.OperationalError(f'matches(): caller must be an IMAGE column')
            if not (has_single_arg and isinstance(args[0], str)):
                raise exc.OperationalError(f"matches(): required signature is (str) (passed: {call_signature})")
            return ImageSimilarityPredicate(caller, text=args[0])

        # TODO: verify signature
        return ImageMethodCall(self.member_name, caller, *args, **kwargs)

    def _equals(self, other: 'ImageMemberAccess') -> bool:
        """The caller is compared as a component; here we only need to compare the member."""
        return self.member_name == other.member_name

    def sql_expr(self) -> Optional[sql.sql.expression.ClauseElement]:
        """Always None: member access is not materialized in SQL."""
        return None

    def eval(self, data_row: List[Any]) -> None:
        """Stores the value of the caller's attribute in data_row; None if the value has no such attribute."""
        caller_val = data_row[self._caller.data_row_idx]
        # getattr() with a default only suppresses AttributeError
        data_row[self.data_row_idx] = getattr(caller_val, self.member_name, None)
710
-
711
-
712
class ImageMethodCall(FunctionCall):
    """
    Call of a PIL.Image.Image method on an image-valued Expr.
    Ex.: tbl.img_col_ref.rotate(90)
    """
    def __init__(self, method_name: str, caller: Expr, *args: object, **kwargs: object):
        """
        Args:
            method_name: name of the method; needs to be registered in _PIL_METHOD_INFO
            caller: Expr the method is called on; passed to FunctionCall as the first argument
            args: remaining positional arguments of the call
            kwargs: only used to determine the return type; not passed on to FunctionCall
        """
        assert method_name in _PIL_METHOD_INFO
        self.method_name = method_name
        # the registry entry is either the return type itself or a callable that computes it from the call's arguments
        info = _PIL_METHOD_INFO[self.method_name]
        return_type = info if isinstance(info, ColumnType) else info(caller, *args, **kwargs)
        # TODO: register correct parameters
        fn = Function(return_type, None, module_name='PIL.Image', eval_symbol=f'Image.{method_name}')
        call_args = (caller, *args)
        super().__init__(fn, call_args)
        # TODO: deal with kwargs

    def display_name(self) -> str:
        return self.method_name

    def _as_dict(self) -> Dict:
        serialized = {'method_name': self.method_name, 'args': self.args}
        serialized.update(super()._as_dict())
        return serialized

    @classmethod
    def _from_dict(cls, d: Dict, components: List[Expr], t: catalog.Table) -> 'Expr':
        """
        Implemented here, rather than left to FunctionCall, so that deserialization yields an ImageMethodCall and
        not a plain FunctionCall; that makes it possible to test that a serialize()/deserialize() roundtrip ends up
        with the same Expr.
        """
        assert 'method_name' in d
        assert 'args' in d
        # reassemble args, but skip args[0], which is our caller;
        # a None entry is replaced by the component at the same position
        call_args = []
        for i, arg in enumerate(d['args'][1:]):
            if arg is None:
                call_args.append(components[i + 1])
            else:
                call_args.append(arg)
        return cls(d['method_name'], components[0], *call_args)
747
-
748
-
749
class JsonPath(Expr):
    """
    Path into a json value, ex.: t.json_col.a.b[0] or t.json_col.a['*'].

    An absolute path has an anchor, stored as components[0]. A relative path has no anchor until set_anchor() is
    called, which bind_rel_paths() does with the scope anchor of the enclosing JsonMapper.
    """
    def __init__(self, anchor: Optional[ColumnRef], path_elements: List[str] = [], scope_idx: int = 0):
        """
        anchor can be None, in which case this is a relative JsonPath and the anchor is set later via set_anchor().
        scope_idx: for relative paths, index of referenced JsonMapper
            (0: indicates the immediately preceding JsonMapper, -1: the parent of the immediately preceding mapper, ...)
        """
        super().__init__(JsonType())
        if anchor is not None:
            self.components = [anchor]
        # Keep a private copy: the default argument is a single list object shared by every call that omits
        # path_elements (including RELATIVE_PATH_ROOT), and callers may also hold on to the list they pass in.
        self.path_elements: List[Union[str, int]] = list(path_elements)
        # an empty path selects the anchor's value as is, so there is nothing to compile
        self.compiled_path = jmespath.compile(self._json_path()) if len(self.path_elements) > 0 else None
        self.scope_idx = scope_idx

    def _as_dict(self) -> Dict:
        return {'path_elements': self.path_elements, 'scope_idx': self.scope_idx, **super()._as_dict()}

    @classmethod
    def _from_dict(cls, d: Dict, components: List[Expr], t: catalog.Table) -> Expr:
        assert 'path_elements' in d
        assert 'scope_idx' in d
        assert len(components) <= 1
        # a relative path is serialized without components
        anchor = components[0] if len(components) == 1 else None
        return cls(anchor, d['path_elements'], d['scope_idx'])

    @property
    def _anchor(self) -> Optional[Expr]:
        """The Expr this path is applied to, or None for a relative path that hasn't been bound yet."""
        return None if len(self.components) == 0 else self.components[0]

    def set_anchor(self, anchor: Expr) -> None:
        """Binds a relative path to anchor; must not be called on a path that already has an anchor."""
        assert len(self.components) == 0
        self.components = [anchor]

    def is_relative_path(self) -> bool:
        return self._anchor is None

    def bind_rel_paths(self, mapper: Optional['JsonMapper'] = None) -> None:
        """Anchors a relative path at mapper's scope anchor; does nothing for an absolute path."""
        if not self.is_relative_path():
            return
        # TODO: take scope_idx into account
        self.set_anchor(mapper.scope_anchor)

    def __call__(self, *args: object, **kwargs: object) -> 'JsonPath':
        """
        Construct a relative path that references an ancestor of the immediately enclosing JsonMapper.

        Requires exactly one positional argument, a negative int, which becomes the scope_idx of the new path.

        Raises:
            exc.OperationalError: if this is an absolute path or the argument isn't a single negative int
        """
        if not self.is_relative_path():
            raise exc.OperationalError('() for an absolute path is invalid')
        if len(args) != 1 or not isinstance(args[0], int) or args[0] >= 0:
            raise exc.OperationalError('R() requires a negative index')
        return JsonPath(None, [], args[0])

    def __getattr__(self, name: str) -> 'JsonPath':
        """
        Returns a new path that extends this one by the field 'name'.
        Python only calls this for attributes that aren't found through regular lookup.
        """
        assert isinstance(name, str)
        return JsonPath(self._anchor, self.path_elements + [name])

    def __getitem__(self, index: object) -> 'JsonPath':
        """
        Returns a new path that extends this one by a list index; the only str accepted as an index is '*'.

        Raises:
            exc.OperationalError: if index is a str other than '*'
        """
        if isinstance(index, str) and index != '*':
            raise exc.OperationalError(f'Invalid json list index: {index}')
        return JsonPath(self._anchor, self.path_elements + [index])

    def __rshift__(self, other: object) -> 'JsonMapper':
        """
        Implements '<path> >> <expr>': creates a JsonMapper with this path as the source and other as the target.

        Raises:
            exc.OperationalError: if other can't be converted into an Expr
        """
        rhs_expr = Expr.from_object(other)
        if rhs_expr is None:
            raise exc.OperationalError(f'>> requires an expression on the right-hand side, found {type(other)}')
        return JsonMapper(self, rhs_expr)

    def display_name(self) -> str:
        anchor_name = self._anchor.display_name() if self._anchor is not None else ''
        return f'{anchor_name}.{self._json_path()}'

    def _equals(self, other: 'JsonPath') -> bool:
        return self.path_elements == other.path_elements

    def sql_expr(self) -> Optional[sql.sql.expression.ClauseElement]:
        """
        Postgres appears to have a bug: jsonb_path_query('{a: [{b: 0}, {b: 1}]}', '$.a.b') returns
        *two* rows (each containing col val 0), not a single row with [0, 0].
        We need to use a workaround: retrieve the entire dict, then use jmespath to extract the path correctly.
        """
        #path_str = '$.' + '.'.join(self.path_elements)
        #assert isinstance(self._anchor(), ColumnRef)
        #return sql.func.jsonb_path_query(self._anchor().col.sa_col, path_str)
        return None

    def _json_path(self) -> str:
        """
        Returns path_elements as a single path string: '*' becomes '[*]', an int i becomes '[i]' and any other str
        is a field name, which is preceded by '.' unless it's the first element.
        """
        assert len(self.path_elements) > 0
        result: List[str] = []
        for element in self.path_elements:
            if element == '*':
                result.append('[*]')
            elif isinstance(element, str):
                result.append(f'{"." if len(result) > 0 else ""}{element}')
            elif isinstance(element, int):
                result.append(f'[{element}]')
        return ''.join(result)

    def eval(self, data_row: List[Any]) -> None:
        """Applies the compiled path to the anchor's value; an empty path passes the value through unchanged."""
        val = data_row[self._anchor.data_row_idx]
        if self.compiled_path is not None:
            val = self.compiled_path.search(val)
        data_row[self.data_row_idx] = val


# empty relative path: has no anchor and no path elements
RELATIVE_PATH_ROOT = JsonPath(None)
854
-
855
-
856
class Literal(Expr):
    """
    Constant value of type str, int, float, bool, datetime.datetime or datetime.date.
    """
    def __init__(self, val: LiteralPythonTypes):
        """
        Args:
            val: the constant; its Python type determines the column type of the expr

        Raises:
            exc.OperationalError: if val has none of the supported types
        """
        # bool needs to be tested before int: bool is a subclass of int, and we want to initialize the
        # base class exactly once, with the most specific type
        if isinstance(val, bool):
            col_type = BoolType()
        elif isinstance(val, str):
            col_type = StringType()
        elif isinstance(val, int):
            col_type = IntType()
        elif isinstance(val, float):
            col_type = FloatType()
        elif isinstance(val, (datetime.datetime, datetime.date)):
            col_type = TimestampType()
        else:
            # fail here instead of handing out an expr whose base class was never initialized
            raise exc.OperationalError(f'Unsupported literal type: {type(val)}')
        super().__init__(col_type)
        self.val = val

    def display_name(self) -> str:
        return 'Literal'

    def _equals(self, other: 'Literal') -> bool:
        return self.val == other.val

    def sql_expr(self) -> Optional[sql.sql.expression.ClauseElement]:
        # we need to return something here so that we can generate a Where clause for predicates
        # that involve literals (like Where c > 0)
        return sql.sql.expression.literal(self.val)

    def eval(self, data_row: List[Any]) -> None:
        # this will be called, even though sql_expr() does not return None
        data_row[self.data_row_idx] = self.val

    def _as_dict(self) -> Dict:
        return {'val': self.val, **super()._as_dict()}

    @classmethod
    def _from_dict(cls, d: Dict, components: List[Expr], t: catalog.Table) -> Expr:
        assert 'val' in d
        return cls(d['val'])
892
-
893
-
894
class InlineDict(Expr):
    """
    Dictionary 'literal' whose values can be Exprs; evaluates to a dict.
    """
    def __init__(self, d: Dict):
        """
        Args:
            d: dict with str keys; values are deep-copied, and nested dicts are turned into InlineDicts

        Raises:
            exc.OperationalError: if a key isn't a str
        """
        super().__init__(JsonType())  # we need to call this in order to populate self.components
        # dict_items contains
        # - for Expr fields: (key, index into components, None)
        # - for non-Expr fields: (key, -1, value)
        self.dict_items: List[Tuple[str, int, Any]] = []
        for key, raw_val in d.items():
            if not isinstance(key, str):
                raise exc.OperationalError(f'Dictionary requires string keys, {key} has type {type(key)}')
            item = copy.deepcopy(raw_val)
            if isinstance(item, dict):
                item = InlineDict(item)
            if not isinstance(item, Expr):
                self.dict_items.append((key, -1, item))
                continue
            self.dict_items.append((key, len(self.components), None))
            self.components.append(item)

        # we only have a type spec if every value is an Expr
        # TODO: implement type inference for values
        self.type_spec: Optional[Dict[str, ColumnType]]
        if any(idx == -1 for _, idx, _ in self.dict_items):
            self.type_spec = None
        else:
            self.type_spec = {key: self.components[idx].col_type for key, idx, _ in self.dict_items}
        self.col_type = JsonType(self.type_spec)

    def _equals(self, other: 'InlineDict') -> bool:
        return other.dict_items == self.dict_items

    def sql_expr(self) -> Optional[sql.sql.expression.ClauseElement]:
        # no SQL equivalent
        return None

    def eval(self, data_row: List[Any]) -> None:
        """Assembles the dict from the values of our components and deep copies of the constant values."""
        assembled = {}
        for key, idx, const_val in self.dict_items:
            assert isinstance(key, str)
            assembled[key] = data_row[self.components[idx].data_row_idx] if idx >= 0 else copy.deepcopy(const_val)
        data_row[self.data_row_idx] = assembled

    def _as_dict(self) -> Dict:
        serialized = {'dict_items': self.dict_items}
        serialized.update(super()._as_dict())
        return serialized

    @classmethod
    def _from_dict(cls, d: Dict, components: List[Expr], t: catalog.Table) -> Expr:
        assert 'dict_items' in d
        # recreate the constructor argument: Expr fields come from components, everything else from the stored value
        arg: Dict[str, Any] = {
            key: (components[idx] if idx >= 0 else val) for key, idx, val in d['dict_items']
        }
        return cls(arg)
955
-
956
-
957
class InlineArray(Expr):
    """
    Array 'literal' which can use Exprs as values.

    The column type is derived from the elements: an array type if they have a common scalar or array type,
    json otherwise.
    """
    def __init__(self, elements: Tuple):
        """
        Args:
            elements: the array elements; they are deep-copied, and nested lists are turned into InlineArrays
        """
        # we need to call this in order to populate self.components
        super().__init__(ArrayType((len(elements),), ColumnType.Type.INT))

        # elements contains
        # - for Expr elements: (index into components, None)
        # - for non-Expr elements: (-1, value)
        self.elements: List[Tuple[int, Any]] = []
        for el in elements:
            el = copy.deepcopy(el)
            if isinstance(el, list):
                el = InlineArray(tuple(el))
            if isinstance(el, Expr):
                self.elements.append((len(self.components), None))
                self.components.append(el)
            else:
                self.elements.append((-1, el))

        # determine the common type of all elements
        element_type = InvalidType()
        for idx, val in self.elements:
            if idx >= 0:
                element_type = ColumnType.supertype(element_type, self.components[idx].col_type)
            else:
                element_type = ColumnType.supertype(element_type, ColumnType.get_value_type(val))
            if element_type is None:
                # there is no common element type: this is a json value, not an array
                # TODO: make sure this doesn't contain Images
                self.col_type = JsonType()
                return

        if element_type.is_scalar_type():
            self.col_type = ArrayType((len(self.elements),), element_type.type_enum)
        elif element_type.is_array_type():
            assert isinstance(element_type, ArrayType)
            # nested arrays: our length becomes the leading dimension
            self.col_type = ArrayType((len(self.elements), *element_type.shape), element_type.dtype)
        elif element_type.is_json_type():
            self.col_type = JsonType()

    def _equals(self, other: 'InlineArray') -> bool:
        return self.elements == other.elements

    def sql_expr(self) -> Optional[sql.sql.expression.ClauseElement]:
        return None

    def eval(self, data_row: List[Any]) -> None:
        """
        Assembles the element values and stores them as a numpy array if our type is an array type,
        and as a plain list otherwise.
        """
        result = [
            data_row[self.components[child_idx].data_row_idx] if child_idx >= 0 else copy.deepcopy(val)
            for child_idx, val in self.elements
        ]
        if self.col_type.is_array_type():
            data_row[self.data_row_idx] = np.array(result)
        else:
            # json value: np.array() would coerce mixed elements to a common dtype (eg, [1, 'a'] turns into an
            # array of strings) or produce an object array, neither of which is the value that was specified
            data_row[self.data_row_idx] = result

    def _as_dict(self) -> Dict:
        return {'elements': self.elements, **super()._as_dict()}

    @classmethod
    def _from_dict(cls, d: Dict, components: List[Expr], t: catalog.Table) -> Expr:
        assert 'elements' in d
        # recreate the constructor argument: Expr elements come from components, everything else from the stored value
        arg: List[Any] = [components[idx] if idx >= 0 else val for idx, val in d['elements']]
        return cls(tuple(arg))
1028
-
1029
-
1030
class ArraySlice(Expr):
    """
    Indexing/slicing of an array-typed Expr, eg, t.array_col[:, 1:2].
    """
    def __init__(self, arr: Expr, index: Tuple):
        """
        Args:
            arr: the Expr being sliced; its type needs to be an array type
            index: the index that is applied to the array value in eval()
        """
        assert arr.col_type.is_array_type()
        # determine result type
        super().__init__(arr.col_type)
        self.components = [arr]
        self.index = index

    @property
    def _array(self) -> Expr:
        """The Expr being sliced."""
        return self.components[0]

    def _equals(self, other: 'ArraySlice') -> bool:
        return other.index == self.index

    def sql_expr(self) -> Optional[sql.sql.expression.ClauseElement]:
        # no SQL equivalent
        return None

    def eval(self, data_row: List[Any]) -> None:
        arr_val = data_row[self._array.data_row_idx]
        data_row[self.data_row_idx] = arr_val[self.index]

    def _as_dict(self) -> Dict:
        """slice objects in the index are stored as [start, stop, step] lists; other elements are stored as is."""
        stored_index = [
            [el.start, el.stop, el.step] if isinstance(el, slice) else el
            for el in self.index
        ]
        serialized = {'index': stored_index}
        serialized.update(super()._as_dict())
        return serialized

    @classmethod
    def _from_dict(cls, d: Dict, components: List[Expr], t: catalog.Table) -> Expr:
        """Inverse of _as_dict(): lists in the stored index are converted back into slice objects."""
        assert 'index' in d
        restored_index = tuple(
            slice(el[0], el[1], el[2]) if isinstance(el, list) else el
            for el in d['index']
        )
        return cls(components[0], restored_index)
1074
-
1075
-
1076
class Predicate(Expr):
    """
    Base class of boolean-valued Exprs; supports composition with &, | and ~.
    """
    def __init__(self) -> None:
        super().__init__(BoolType())

    def extract_sql_predicate(self) -> Tuple[Optional[sql.sql.expression.ClauseElement], Optional['Predicate']]:
        """
        Splits the predicate into a ClauseElement for the part that can be evaluated in SQL and a predicate for the
        remainder, which needs to be evaluated in Python. Exactly one of the two is None here: the predicate is
        either translated as a whole or not at all.
        Needed for predicate push-down into SQL.
        """
        sql_pred = self.sql_expr()
        if sql_pred is None:
            return None, self
        return sql_pred, None

    def split_conjuncts(
            self, condition: Callable[['Predicate'], bool]) -> Tuple[List['Predicate'], Optional['Predicate']]:
        """
        Returns clauses of a conjunction that meet condition in the first element.
        The second element contains remaining clauses, rolled into a conjunction.
        A predicate that isn't a conjunction is treated as a single clause.
        """
        return ([self], None) if condition(self) else ([], self)

    def _combine(self, op: LogicalOperator, other: object) -> 'CompoundPredicate':
        """
        Creates the CompoundPredicate '<self> op <other>'.

        Raises:
            TypeError: if other isn't a boolean-valued Expr
        """
        if not isinstance(other, Expr):
            raise TypeError(f'Other needs to be an expression: {type(other)}')
        if not other.col_type.is_bool_type():
            raise TypeError(f'Other needs to be an expression that returns a boolean: {other.col_type}')
        return CompoundPredicate(op, [self, other])

    def __and__(self, other: object) -> 'CompoundPredicate':
        return self._combine(LogicalOperator.AND, other)

    def __or__(self, other: object) -> 'CompoundPredicate':
        return self._combine(LogicalOperator.OR, other)

    def __invert__(self) -> 'CompoundPredicate':
        return CompoundPredicate(LogicalOperator.NOT, [self])
1116
-
1117
-
1118
class CompoundPredicate(Predicate):
    """
    Combination of predicates via AND, OR or NOT.

    The operands are stored in self.components. Operands of AND/OR that are themselves CompoundPredicates with
    the same operator are flattened into this predicate.
    """
    def __init__(self, operator: LogicalOperator, operands: List[Predicate]):
        """
        Args:
            operator: the logical operator
            operands: exactly one operand for NOT, at least two for AND/OR
        """
        super().__init__()
        self.operator = operator
        # operands are stored in self.components
        if self.operator == LogicalOperator.NOT:
            assert len(operands) == 1
            self.components = operands
        else:
            assert len(operands) > 1
            # kept for backward compatibility; nothing in this class reads or fills it
            self.operands: List[Predicate] = []
            for operand in operands:
                self._merge_operand(operand)

    @classmethod
    def make_conjunction(cls, operands: List[Predicate]) -> Optional[Predicate]:
        """Returns the AND of operands: None for an empty list and the operand itself for a single operand."""
        if len(operands) == 0:
            return None
        if len(operands) == 1:
            return operands[0]
        return CompoundPredicate(LogicalOperator.AND, operands)

    def _merge_operand(self, operand: Predicate) -> None:
        """
        Merge this operand, if possible, otherwise simply record it.
        """
        if isinstance(operand, CompoundPredicate) and operand.operator == self.operator:
            # this can be merged
            for child_op in operand.components:
                self._merge_operand(child_op)
        else:
            self.components.append(operand)

    def _equals(self, other: 'CompoundPredicate') -> bool:
        return self.operator == other.operator

    def extract_sql_predicate(self) -> Tuple[Optional[sql.sql.expression.ClauseElement], Optional[Predicate]]:
        """
        Returns (what can be evaluated in SQL, what remains to be evaluated in Python).

        NOT and OR are pushed down either completely or not at all. A conjunction is split into the clauses that
        have a SQL translation and the conjunction of those that don't.
        """
        if self.operator == LogicalOperator.NOT:
            # sql_expr() includes the negation; using the operand's sql_expr() directly would push down the
            # un-negated predicate
            e = self.sql_expr()
            return (None, self) if e is None else (e, None)

        sql_exprs = [op.sql_expr() for op in self.components]
        if self.operator == LogicalOperator.OR and any(e is None for e in sql_exprs):
            # if any clause of a | can't be evaluated in SQL, we need to evaluate everything in Python
            return None, self
        if not any(e is None for e in sql_exprs):
            # we can do everything in SQL
            return self.sql_expr(), None

        assert self.operator == LogicalOperator.AND
        sql_preds = [e for e in sql_exprs if e is not None]
        if len(sql_preds) == 0:
            # there's nothing that can be done in SQL
            return None, self

        other_preds = [self.components[i] for i, e in enumerate(sql_exprs) if e is None]
        combined_sql_pred = sql.and_(*sql_preds)
        combined_other = self.make_conjunction(other_preds)
        return combined_sql_pred, combined_other

    def split_conjuncts(
            self, condition: Callable[['Predicate'], bool]) -> Tuple[List['Predicate'], Optional['Predicate']]:
        """
        Returns the clauses of a conjunction that meet condition in the first element and the conjunction of the
        remaining clauses in the second. OR and NOT are treated as a single clause.
        """
        if self.operator == LogicalOperator.OR or self.operator == LogicalOperator.NOT:
            return super().split_conjuncts(condition)
        matches = [op for op in self.components if condition(op)]
        non_matches = [op for op in self.components if not condition(op)]
        return (matches, self.make_conjunction(non_matches))

    def sql_expr(self) -> Optional[sql.sql.expression.ClauseElement]:
        """Returns the SQL translation of the entire predicate, or None if any operand has no SQL translation."""
        sql_exprs = [op.sql_expr() for op in self.components]
        if any(e is None for e in sql_exprs):
            return None
        if self.operator == LogicalOperator.NOT:
            assert len(sql_exprs) == 1
            return sql.not_(sql_exprs[0])
        assert len(sql_exprs) > 1
        combine = sql.and_ if self.operator == LogicalOperator.AND else sql.or_
        return combine(*sql_exprs)

    def eval(self, data_row: List[Any]) -> None:
        """Combines the operand values from data_row with the operator and stores the result in our slot."""
        if self.operator == LogicalOperator.NOT:
            data_row[self.data_row_idx] = not data_row[self.components[0].data_row_idx]
        else:
            # start with the neutral element of the operator
            val = True if self.operator == LogicalOperator.AND else False
            op_function = operator.and_ if self.operator == LogicalOperator.AND else operator.or_
            for op in self.components:
                val = op_function(val, data_row[op.data_row_idx])
            data_row[self.data_row_idx] = val

    def _as_dict(self) -> Dict:
        return {'operator': self.operator.value, **super()._as_dict()}

    @classmethod
    def _from_dict(cls, d: Dict, components: List[Expr], t: catalog.Table) -> Expr:
        assert 'operator' in d
        return cls(LogicalOperator(d['operator']), components)
1216
-
1217
-
1218
class Comparison(Predicate):
    """
    Comparison of two Exprs with one of <, <=, ==, !=, >, >=.
    The operands are stored as components[0] and components[1].
    """
    def __init__(self, operator: ComparisonOperator, op1: Expr, op2: Expr):
        super().__init__()
        self.operator = operator
        self.components = [op1, op2]

    def _equals(self, other: 'Comparison') -> bool:
        return other.operator == self.operator

    @property
    def _op1(self) -> Expr:
        """Left operand."""
        return self.components[0]

    @property
    def _op2(self) -> Expr:
        """Right operand."""
        return self.components[1]

    def _compare_fn(self) -> Optional[Callable[[Any, Any], Any]]:
        """
        Returns the function of Python's operator module that corresponds to self.operator,
        or None if self.operator is none of the six comparison operators.
        """
        candidates = (
            (ComparisonOperator.LT, operator.lt),
            (ComparisonOperator.LE, operator.le),
            (ComparisonOperator.EQ, operator.eq),
            (ComparisonOperator.NE, operator.ne),
            (ComparisonOperator.GT, operator.gt),
            (ComparisonOperator.GE, operator.ge),
        )
        for candidate, fn in candidates:
            if self.operator == candidate:
                return fn
        return None

    def sql_expr(self) -> Optional[sql.sql.expression.ClauseElement]:
        """Returns the comparison of the operands' SQL exprs, or None if either operand has no SQL translation."""
        lhs = self._op1.sql_expr()
        rhs = self._op2.sql_expr()
        if lhs is None or rhs is None:
            return None
        compare = self._compare_fn()
        return None if compare is None else compare(lhs, rhs)

    def eval(self, data_row: List[Any]) -> None:
        """Compares the operand values from data_row and stores the result in our slot."""
        compare = self._compare_fn()
        if compare is None:
            return
        data_row[self.data_row_idx] = compare(data_row[self._op1.data_row_idx], data_row[self._op2.data_row_idx])

    def _as_dict(self) -> Dict:
        serialized = {'operator': self.operator.value}
        serialized.update(super()._as_dict())
        return serialized

    @classmethod
    def _from_dict(cls, d: Dict, components: List[Expr], t: catalog.Table) -> Expr:
        assert 'operator' in d
        return cls(ComparisonOperator(d['operator']), components[0], components[1])
1274
-
1275
-
1276
class ImageSimilarityPredicate(Predicate):
    """
    Similarity of an image column to either a given image or a given text.
    This predicate can't be evaluated via sql_expr() or eval(); embedding() provides the embedding of the
    image/text it was created with.
    """
    def __init__(self, img_col: ColumnRef, img: Optional[PIL.Image.Image] = None, text: Optional[str] = None):
        """
        Args:
            img_col: the image column
            img: image to compare against
            text: text to compare against

        Exactly one of img and text needs to be specified.
        """
        has_img, has_text = img is not None, text is not None
        assert has_img != has_text
        super().__init__()
        self.img_col_ref = img_col
        self.components = [img_col]
        self.img = img
        self.text = text

    def embedding(self) -> np.ndarray:
        """Returns the clip encoding of the text, if we have one, and of the image otherwise."""
        return clip.encode_text(self.text) if self.text is not None else clip.encode_image(self.img)

    def _equals(self, other: 'ImageSimilarityPredicate') -> bool:
        # never considered equal to another predicate
        return False

    def sql_expr(self) -> Optional[sql.sql.expression.ClauseElement]:
        # no SQL equivalent
        return None

    def eval(self, data_row: List[Any]) -> None:
        # not supposed to be called
        assert False

    def _as_dict(self) -> Dict:
        assert False, 'not implemented'
        # TODO: convert self.img into a serializable string
        serialized = {'img': self.img, 'text': self.text}
        serialized.update(super()._as_dict())
        return serialized

    @classmethod
    def _from_dict(cls, d: Dict, components: List[Expr], t: catalog.Table) -> Expr:
        assert 'img' in d
        assert 'text' in d
        assert len(components) == 1
        (img_col,) = components
        return cls(img_col, d['img'], d['text'])
1311
-
1312
-
1313
class ArithmeticExpr(Expr):
    """
    Binary arithmetic (+, -, *, /, %) on numeric Exprs.
    Json-typed operands are allowed as well (this is what allows arithmetic exprs on json paths); their values
    are checked when the expr is evaluated.
    """
    def __init__(self, operator: ArithmeticOperator, op1: Expr, op2: Expr):
        """
        Raises:
            exc.OperationalError: if an operand's type is neither numeric nor json
        """
        for operand in (op1, op2):
            if not (operand.col_type.is_numeric_type() or operand.col_type.is_json_type()):
                raise exc.OperationalError(
                    f'{operator} requires numeric type: {operand} has type {operand.col_type}')
        # TODO: determine most specific common supertype
        if op1.col_type.is_json_type() or op2.col_type.is_json_type():
            # we assume it's a float
            result_type = FloatType()
        else:
            result_type = ColumnType.supertype(op1.col_type, op2.col_type)
        super().__init__(result_type)
        self.operator = operator
        self.components = [op1, op2]

    def _equals(self, other: 'ArithmeticExpr') -> bool:
        return other.operator == self.operator

    @property
    def _op1(self) -> Expr:
        """Left operand."""
        return self.components[0]

    @property
    def _op2(self) -> Expr:
        """Right operand."""
        return self.components[1]

    def _arithmetic_fn(self) -> Optional[Callable[[Any, Any], Any]]:
        """
        Returns the function of Python's operator module that corresponds to self.operator,
        or None if self.operator is none of the five arithmetic operators.
        """
        candidates = (
            (ArithmeticOperator.ADD, operator.add),
            (ArithmeticOperator.SUB, operator.sub),
            (ArithmeticOperator.MUL, operator.mul),
            (ArithmeticOperator.DIV, operator.truediv),
            (ArithmeticOperator.MOD, operator.mod),
        )
        for candidate, fn in candidates:
            if self.operator == candidate:
                return fn
        return None

    def sql_expr(self) -> Optional[sql.sql.expression.ClauseElement]:
        """Returns the operation applied to the operands' SQL exprs, or None if either has no SQL translation."""
        lhs = self._op1.sql_expr()
        rhs = self._op2.sql_expr()
        if lhs is None or rhs is None:
            return None
        apply_op = self._arithmetic_fn()
        return None if apply_op is None else apply_op(lhs, rhs)

    def eval(self, data_row: List[Any]) -> None:
        """
        Applies the operation to the operand values from data_row and stores the result in our slot.

        Raises:
            exc.OperationalError: if the value of a json-typed operand is neither an int nor a float
        """
        op1_val = data_row[self._op1.data_row_idx]
        op2_val = data_row[self._op2.data_row_idx]
        # check types if we couldn't do that prior to execution
        if self._op1.col_type.is_json_type() and not isinstance(op1_val, (int, float)):
            raise exc.OperationalError(f'{self.operator} requires numeric type: {self._op1} has type {type(op1_val)}')
        if self._op2.col_type.is_json_type() and not isinstance(op2_val, (int, float)):
            raise exc.OperationalError(f'{self.operator} requires numeric type: {self._op2} has type {type(op2_val)}')

        apply_op = self._arithmetic_fn()
        if apply_op is not None:
            data_row[self.data_row_idx] = apply_op(op1_val, op2_val)

    def _as_dict(self) -> Dict:
        serialized = {'operator': self.operator.value}
        serialized.update(super()._as_dict())
        return serialized

    @classmethod
    def _from_dict(cls, d: Dict, components: List[Expr], t: catalog.Table) -> Expr:
        assert 'operator' in d
        assert len(components) == 2
        lhs, rhs = components
        return cls(ArithmeticOperator(d['operator']), lhs, rhs)
1386
-
1387
-
1388
class ObjectRef(Expr):
    """
    Placeholder for an intermediate result, such as the "scope variable" produced by a JsonMapper.
    The referenced object is generated/materialized elsewhere; the ObjectRef establishes a new scope.
    """
    def __init__(self, scope: ExprScope, owner: 'JsonMapper'):
        """
        Args:
            scope: the scope reported by scope()
            owner: the JsonMapper this reference belongs to
        """
        # TODO: do we need an Unknown type after all?
        # JsonType: this could be anything
        super().__init__(JsonType())
        self.owner = owner
        self._scope = scope

    def scope(self) -> ExprScope:
        return self._scope

    def _equals(self, other: 'ObjectRef') -> bool:
        # two refs are the same only if they belong to the very same mapper
        return other.owner is self.owner

    def sql_expr(self) -> Optional[sql.sql.expression.ClauseElement]:
        # no SQL equivalent
        return None

    def eval(self, data_row: List[Any]) -> None:
        # this will be called, but the value has already been materialized elsewhere
        return None
1411
-
1412
-
1413
- class JsonMapper(Expr):
1414
- """
1415
- JsonMapper transforms the list output of a JsonPath by applying a target expr to every element of the list.
1416
- The target expr would typically contain relative JsonPaths, which are bound to an ObjectRef, which in turn
1417
- is populated by JsonMapper.eval(). The JsonMapper effectively creates a new scope for its target expr.
1418
- """
1419
- def __init__(self, src_expr: Expr, target_expr: Expr):
1420
- # TODO: type spec should be List[target_expr.col_type]
1421
- super().__init__(JsonType())
1422
-
1423
- # we're creating a new scope, but we don't know yet whether this is nested within another JsonMapper;
1424
- # this gets resolved in bind_rel_paths(); for now we assume we're in the global scope
1425
- self.target_expr_scope = ExprScope(_GLOBAL_SCOPE)
1426
-
1427
- scope_anchor = ObjectRef(self.target_expr_scope, self)
1428
- self.components = [src_expr, target_expr, scope_anchor]
1429
- self.parent_mapper: Optional[JsonMapper] = None
1430
- self.evaluator: Optional[ExprEvaluator] = None
1431
-
1432
- def bind_rel_paths(self, mapper: Optional['JsonMapper']) -> None:
1433
- self._src_expr.bind_rel_paths(mapper)
1434
- self._target_expr.bind_rel_paths(self)
1435
- self.parent_mapper = mapper
1436
- parent_scope = _GLOBAL_SCOPE if mapper is None else mapper.target_expr_scope
1437
- self.target_expr_scope.parent = parent_scope
1438
-
1439
- def scope(self) -> ExprScope:
1440
- # need to ignore target_expr
1441
- return self._src_expr.scope()
1442
-
1443
- def dependencies(self) -> List['Expr']:
1444
- result = [self._src_expr]
1445
- result.extend(self._target_dependencies(self._target_expr))
1446
- return result
1447
-
1448
- def _target_dependencies(self, e: Expr) -> List['Expr']:
1449
- """
1450
- Return all subexprs of e of which the scope isn't contained in target_expr_scope.
1451
- Those need to be evaluated before us.
1452
- """
1453
- expr_scope = e.scope()
1454
- if not expr_scope.is_contained_in(self.target_expr_scope):
1455
- return [e]
1456
- result: List[Expr] = []
1457
- for c in e.components:
1458
- result.extend(self._target_dependencies(c))
1459
- return result
1460
-
1461
- def equals(self, other: 'Expr') -> bool:
1462
- """
1463
- We override equals() because we need to avoid comparing our scope anchor.
1464
- """
1465
- if type(self) != type(other):
1466
- return False
1467
- return self._src_expr.equals(other._src_expr) and self._target_expr.equals(other._target_expr)
1468
-
1469
- @property
1470
- def _src_expr(self) -> Expr:
1471
- return self.components[0]
1472
-
1473
- @property
1474
- def _target_expr(self) -> Expr:
1475
- return self.components[1]
1476
-
1477
- @property
1478
- def scope_anchor(self) -> Expr:
1479
- return self.components[2]
1480
-
1481
- def _equals(self, other: 'JsonMapper') -> bool:
1482
- return True
1483
-
1484
def sql_expr(self) -> Optional[sql.sql.expression.ClauseElement]:
    """Return None: this expr provides no SQL expression."""
    return None
1486
-
1487
def eval(self, data_row: List[Any]) -> None:
    """
    Apply the target expr to every element of the source list and store the resulting list.

    If the source value is not a list (invalid/non-list src path), None is stored in this expr's
    slot instead.
    """
    src = data_row[self._src_expr.data_row_idx]
    if isinstance(src, list):
        # the evaluator for the target expr is created lazily on first use
        if self.evaluator is None:
            self.evaluator = ExprEvaluator([self._target_expr], None)
        mapped = []
        for element in src:
            # expose the current element through the scope anchor's slot
            data_row[self.scope_anchor.data_row_idx] = element
            # materialize target_expr
            self.evaluator.eval((), data_row)
            mapped.append(data_row[self._target_expr.data_row_idx])
        data_row[self.data_row_idx] = mapped
    else:
        data_row[self.data_row_idx] = None
1504
-
1505
- def _as_dict(self) -> Dict:
1506
- """
1507
- We need to avoid serializing component[2], which is an ObjectRef.
1508
- """
1509
- return {'components': [c.as_dict() for c in self.components[0:2]]}
1510
-
1511
@classmethod
def _from_dict(cls, d: Dict, components: List[Expr], t: catalog.Table) -> Expr:
    """
    Rebuild an instance from its two deserialized components (source expr, target expr).

    Neither d nor t is consulted.
    """
    assert len(components) == 2
    src_expr = components[0]
    target_expr = components[1]
    return cls(src_expr, target_expr)
1515
-
1516
-
1517
class ExprEvaluator:
    """
    Materializes values for a list of output exprs, subject to passing a filter.

    ex.: the select list [<img col 1>.alpha_composite(<img col 2>), <text col 3>]
    - sql row composition: [<file path col 1>, <file path col 2>, <text col 3>]
    - data row composition: [Image, str, Image, Image]
    - copy_exprs: [
        ColumnRef(data_row_idx: 2, sql_row_idx: 0, col: <col 1>)
        ColumnRef(data_row_idx: 3, sql_row_idx: 1, col: <col 2>)
        ColumnRef(data_row_idx: 1, sql_row_idx: 2, col: <col 3>)
      ]
    - eval_exprs: [ImageMethodCall(data_row_idx: 0, sql_row_id: -1)]
    """

    def __init__(self, output_exprs: List[Expr], filter: Optional[Predicate], with_sql: bool = True):
        """
        Split the filter and the output exprs into exprs copied from the SQL row and exprs that need eval().

        Args:
            output_exprs: exprs whose data row slots get populated by eval()
            filter: optional predicate that is evaluated before the output exprs
            with_sql: accepted for interface compatibility; not consulted by this class
        """
        # TODO: add self.literal_exprs so that we don't need to retrieve those from SQL
        # exprs that are materialized directly via SQL query and for which results can be copied from sql row
        # into data row
        self.filter_copy_exprs: List[Expr] = []
        self.output_copy_exprs: List[Expr] = []
        # exprs for which we need to call eval() to compute the value; must be called in the order stored here
        self.filter_eval_exprs: List[Expr] = []
        self.output_eval_exprs: List[Expr] = []
        self.filter = filter

        # data_row_idx values already assigned to one of the lists above; shared between filter and output
        # analysis, so that exprs needed by the filter aren't materialized a second time for the output
        unique_ids: Set[int] = set()
        # analyze filter first, so that it can be evaluated before output_exprs
        if filter is not None:
            self._analyze_expr(filter, filter.scope(), self.filter_copy_exprs, self.filter_eval_exprs, unique_ids)
        for expr in output_exprs:
            self._analyze_expr(expr, expr.scope(), self.output_copy_exprs, self.output_eval_exprs, unique_ids)

    def _analyze_expr(
            self, expr: Expr, scope: ExprScope, copy_exprs: List[Expr], eval_exprs: List[Expr], unique_ids: Set[int]
    ) -> None:
        """
        Determine unique dependencies of expr and accumulate those in copy_exprs and eval_exprs.
        Dependencies that are not in 'scope' are assumed to have been materialized already and are ignored.
        """
        if expr.data_row_idx in unique_ids:
            return
        unique_ids.add(expr.data_row_idx)

        if expr.sql_row_idx >= 0:
            # this can be copied, no need to look at its dependencies
            copy_exprs.append(expr)
            return

        for d in expr.dependencies():
            if d.scope() != scope:
                continue
            self._analyze_expr(d, scope, copy_exprs, eval_exprs, unique_ids)
        # make sure to eval() this after its dependencies
        eval_exprs.append(expr)

    def eval(self, sql_row: Tuple[Any], data_row: List[Any]) -> bool:
        """
        If the filter predicate evaluates to True, populates the data_row slots of the output_exprs.

        Returns:
            False if there is a filter and its slot in data_row is falsy after evaluation, True otherwise.
        """
        if self.filter is not None:
            # we need to evaluate the remaining filter predicate first
            self._copy_to_data_row(self.filter_copy_exprs, sql_row, data_row)
            for expr in self.filter_eval_exprs:
                expr.eval(data_row)
            if not data_row[self.filter.data_row_idx]:
                return False

        # materialize output_exprs
        self._copy_to_data_row(self.output_copy_exprs, sql_row, data_row)
        for expr in self.output_eval_exprs:
            expr.eval(data_row)
        return True

    def _copy_to_data_row(self, exprs: List[Expr], sql_row: Tuple[Any], data_row: List[Any]) -> None:
        """
        Copy expr values from sql to data row.

        Image columns are stored as file paths and get opened with PIL; array columns are stored as
        serialized numpy arrays and get loaded with np.load(); everything else is copied as-is.

        Raises:
            exc.OperationalError: if an image file cannot be opened; the underlying exception is
                attached as the cause.
        """
        for expr in exprs:
            assert expr.sql_row_idx != -1
            if expr.col_type.is_image_type():
                # column value is a file path that we need to open
                file_path = sql_row[expr.sql_row_idx]
                # keep the try body minimal so that only failures to open the file are reported as such
                try:
                    img = PIL.Image.open(file_path)
                except Exception as e:
                    raise exc.OperationalError(f'Error reading image file: {file_path}') from e
                data_row[expr.data_row_idx] = img
            elif expr.col_type.is_array_type():
                # column value is a saved numpy array
                array_data = sql_row[expr.sql_row_idx]
                data_row[expr.data_row_idx] = np.load(io.BytesIO(array_data))
            else:
                data_row[expr.data_row_idx] = sql_row[expr.sql_row_idx]
1612
-
1613
-
1614
class UniqueExprSet:
    """
    Ordered collection of exprs that are pairwise distinct under Expr.equals().

    Used to avoid duplicate expr evaluation: duplicates share the data_row_idx of the first-recorded
    equal expr. A list is used instead of a set because __eq__() doesn't work for sets here.
    """

    def __init__(self):
        self.unique_exprs: List[Expr] = []

    def add(self, expr: Expr) -> bool:
        """
        Record expr unless an equal expr has been recorded before.

        If expr is a duplicate, its data_row_idx and sql_row_idx are set to those of the
        already-recorded expr and False is returned; otherwise expr is appended and True is returned.
        """
        for known in self.unique_exprs:
            if known.equals(expr):
                expr.data_row_idx = known.data_row_idx
                expr.sql_row_idx = known.sql_row_idx
                return False
        self.unique_exprs.append(expr)
        return True

    def __iter__(self) -> Iterator[Expr]:
        """Iterate over the recorded exprs in insertion order."""
        return iter(self.unique_exprs)
1638
-
1639
-
1640
class ExprEvalCtx:
    """
    Assigns execution state necessary to materialize a list of Exprs into a data row:
    - Expr.sql_/data_row_idx

    Data row:
    - List[Any]
    - contains slots for all materialized component exprs (ie, not for predicates that turn into the SQL Where clause):
      a) every DataFrame.select_list expr
      b) the parts of the where clause predicate that cannot be evaluated in SQL
      c) every component expr of a) and b), recursively
    - IMAGE columns are materialized immediately as a PIL.Image.Image

    ex.: the select list [<img col 1>.alpha_composite(<img col 2>), <text col 3>]
    - sql row composition: [<file path col 1>, <file path col 2>, <text col 3>]
    - data row composition: [Image, str, Image, Image]
    """

    def __init__(self, output_exprs: List[Expr], filter: Optional[Predicate]):
        """
        Init for list of materialized exprs and a possible filter.

        If an expr e has an e.sql_expr() (and is not a Literal), its components do not need to be
        materialized (and consequently also don't get data_row_idx assigned) and the expr value is
        produced via a Select stmt. The filter, if present, is analyzed before the output exprs.
        """
        # objects needed to materialize the SQL result row
        self.sql_exprs: List[sql.sql.expression.ClauseElement] = []
        self.unique_exprs = UniqueExprSet()
        self.next_data_row_idx = 0

        if filter is not None:
            self._analyze_expr(filter)
        for output_expr in output_exprs:
            self._analyze_expr(output_expr)

    @property
    def num_materialized(self) -> int:
        """Number of data row slots assigned so far."""
        return self.next_data_row_idx

    def _analyze_expr(self, expr: Expr) -> None:
        """
        Assign Expr.data_row_idx and Expr.sql_row_idx.
        """
        if not self.unique_exprs.add(expr):
            # duplicate: add() already copied the indices of the recorded expr
            return

        sql_expr = expr.sql_expr()
        # if this can be materialized via SQL we don't need to look at its components;
        # we special-case Literals because we don't want to have to materialize them via SQL
        via_sql = sql_expr is not None and not isinstance(expr, Literal)

        if not via_sql:
            # expr value needs to be computed via Expr.eval(), which requires its components
            for component in expr.components:
                self._analyze_expr(component)

        assert expr.data_row_idx < 0
        expr.data_row_idx = self.next_data_row_idx
        self.next_data_row_idx += 1

        if via_sql:
            expr.sql_row_idx = len(self.sql_exprs)
            self.sql_exprs.append(sql_expr)
1704
-
1705
-
1706
class ComputedColEvalCtx:
    """
    EvalCtx for computed cols:
    - referenced inputs are not supplied via SQL
    - a col's ColumnRef and value_expr need to share the same data_row_idx
    """

    def __init__(self, computed_col_info: List[Tuple[ColumnRef, Expr]]):
        """
        computed_col_info: list of (ref to col, value_expr of col)
        """
        # duplicates share the same data_row_idx, which avoids evaluating equal exprs more than once
        self.unique_exprs = UniqueExprSet()
        self.next_data_row_idx = 0

        for col_ref, value_expr in computed_col_info:
            self._analyze_expr(value_expr)
            # the value expr materializes the value of that column, so the ref reuses its slot
            col_ref.data_row_idx = value_expr.data_row_idx
            # future references to that column will use the already-assigned data_row_idx
            self.unique_exprs.add(col_ref)

    @property
    def num_materialized(self) -> int:
        """Number of data row slots assigned so far."""
        return self.next_data_row_idx

    def _analyze_expr(self, expr: Expr) -> None:
        """
        Assign Expr.data_row_idx.

        Components are assigned their slots before the expr itself; duplicates of already-recorded
        exprs are left to UniqueExprSet.add(), which copies the recorded indices.
        """
        is_new = self.unique_exprs.add(expr)
        if is_new:
            for component in expr.components:
                self._analyze_expr(component)
            assert expr.data_row_idx < 0
            expr.data_row_idx = self.next_data_row_idx
            self.next_data_row_idx += 1