pixeltable 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (47)
  1. pixeltable/__version__.py +2 -2
  2. pixeltable/catalog/table_version.py +2 -1
  3. pixeltable/dataframe.py +52 -27
  4. pixeltable/env.py +92 -4
  5. pixeltable/exec/__init__.py +1 -1
  6. pixeltable/exec/aggregation_node.py +3 -3
  7. pixeltable/exec/cache_prefetch_node.py +13 -7
  8. pixeltable/exec/component_iteration_node.py +3 -9
  9. pixeltable/exec/data_row_batch.py +17 -5
  10. pixeltable/exec/exec_node.py +32 -12
  11. pixeltable/exec/expr_eval/__init__.py +1 -0
  12. pixeltable/exec/expr_eval/evaluators.py +245 -0
  13. pixeltable/exec/expr_eval/expr_eval_node.py +404 -0
  14. pixeltable/exec/expr_eval/globals.py +114 -0
  15. pixeltable/exec/expr_eval/row_buffer.py +76 -0
  16. pixeltable/exec/expr_eval/schedulers.py +232 -0
  17. pixeltable/exec/in_memory_data_node.py +2 -2
  18. pixeltable/exec/row_update_node.py +14 -14
  19. pixeltable/exec/sql_node.py +2 -2
  20. pixeltable/exprs/column_ref.py +5 -1
  21. pixeltable/exprs/data_row.py +50 -40
  22. pixeltable/exprs/expr.py +57 -12
  23. pixeltable/exprs/function_call.py +54 -19
  24. pixeltable/exprs/inline_expr.py +12 -21
  25. pixeltable/exprs/literal.py +25 -8
  26. pixeltable/exprs/row_builder.py +23 -0
  27. pixeltable/func/aggregate_function.py +4 -0
  28. pixeltable/func/callable_function.py +54 -4
  29. pixeltable/func/expr_template_function.py +5 -1
  30. pixeltable/func/function.py +48 -7
  31. pixeltable/func/query_template_function.py +16 -7
  32. pixeltable/func/udf.py +7 -1
  33. pixeltable/functions/__init__.py +1 -1
  34. pixeltable/functions/anthropic.py +95 -21
  35. pixeltable/functions/gemini.py +2 -6
  36. pixeltable/functions/openai.py +207 -28
  37. pixeltable/globals.py +1 -1
  38. pixeltable/plan.py +24 -9
  39. pixeltable/store.py +6 -0
  40. pixeltable/type_system.py +3 -3
  41. pixeltable/utils/arrow.py +3 -3
  42. {pixeltable-0.3.0.dist-info → pixeltable-0.3.1.dist-info}/METADATA +3 -1
  43. {pixeltable-0.3.0.dist-info → pixeltable-0.3.1.dist-info}/RECORD +46 -41
  44. pixeltable/exec/expr_eval_node.py +0 -232
  45. {pixeltable-0.3.0.dist-info → pixeltable-0.3.1.dist-info}/LICENSE +0 -0
  46. {pixeltable-0.3.0.dist-info → pixeltable-0.3.1.dist-info}/WHEEL +0 -0
  47. {pixeltable-0.3.0.dist-info → pixeltable-0.3.1.dist-info}/entry_points.txt +0 -0
pixeltable/exprs/expr.py CHANGED
@@ -341,26 +341,67 @@ class Expr(abc.ABC):
341
341
  result.extend(cls.get_refd_columns(component_dict))
342
342
  return result
343
343
 
344
+ def is_constant(self) -> bool:
345
+ """Returns True if this expr is a constant."""
346
+ return all(comp.is_constant() for comp in self.components)
347
+
348
+ def _as_constant(self) -> Any:
349
+ return None
350
+
351
+ def as_constant(self) -> Any:
352
+ """
353
+ If expression is a constant then return the associated value which will be converted to a Literal.
354
+ """
355
+ if self.is_constant():
356
+ return self._as_constant()
357
+ return None
358
+
359
+ @classmethod
360
+ def from_array(cls, elements: Iterable) -> Optional[Expr]:
361
+ from .inline_expr import InlineArray
362
+ inline_array = InlineArray(elements)
363
+ constant_array = inline_array.as_constant()
364
+ if constant_array is not None:
365
+ from .literal import Literal
366
+ return Literal(constant_array, inline_array.col_type)
367
+ else:
368
+ return inline_array
369
+
344
370
  @classmethod
345
371
  def from_object(cls, o: object) -> Optional[Expr]:
346
372
  """
347
373
  Try to turn a literal object into an Expr.
348
374
  """
349
- if isinstance(o, Expr):
350
- return o
351
375
  # Try to create a literal. We need to check for InlineList/InlineDict
352
376
  # first, to prevent them from inappropriately being interpreted as JsonType
353
377
  # literals.
354
- if isinstance(o, list):
355
- from .inline_expr import InlineList
356
- return InlineList(o)
357
- if isinstance(o, dict):
358
- from .inline_expr import InlineDict
359
- return InlineDict(o)
360
- obj_type = ts.ColumnType.infer_literal_type(o)
361
- if obj_type is not None:
362
- from .literal import Literal
363
- return Literal(o, col_type=obj_type)
378
+ if isinstance(o, (list, tuple, dict, Expr)):
379
+ expr: Optional[Expr] = None
380
+ if isinstance(o, (list, tuple)):
381
+ from .inline_expr import InlineList
382
+ expr = InlineList(o)
383
+ elif isinstance(o, dict):
384
+ from .inline_expr import InlineDict
385
+ expr = InlineDict(o)
386
+ elif isinstance(o, Expr):
387
+ expr = o
388
+ from .literal import Literal
389
+ if isinstance(expr, Literal):
390
+ return expr
391
+ # Check if the expression is constant
392
+ if expr is not None:
393
+ expr_value = expr.as_constant()
394
+ if expr_value is not None:
395
+ from .literal import Literal
396
+ return Literal(expr_value)
397
+ else:
398
+ return expr
399
+ else:
400
+ # convert scalar to a literal
401
+ obj_type = ts.ColumnType.infer_literal_type(o)
402
+ if obj_type is not None:
403
+ from .literal import Literal
404
+ return Literal(o, col_type=obj_type)
364
405
  return None
365
406
 
366
407
  @abc.abstractmethod
@@ -508,6 +549,10 @@ class Expr(abc.ABC):
508
549
  # Return the `MethodRef` object itself; it requires arguments to become a `FunctionCall`
509
550
  return method_ref
510
551
 
552
+ def __rshift__(self, other: object) -> 'exprs.Expr':
553
+ # Implemented here for type-checking purposes
554
+ raise excs.Error('The `>>` operator can only be applied to Json expressions')
555
+
511
556
  def __bool__(self) -> bool:
512
557
  raise TypeError(
513
558
  'Pixeltable expressions cannot be used in conjunction with Python boolean operators (and/or/not)')
pixeltable/exprs/function_call.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import inspect
4
4
  import json
5
5
  import sys
6
- from typing import Any, Optional
6
+ from typing import Any, Optional, Sequence
7
7
 
8
8
  import sqlalchemy as sql
9
9
 
@@ -26,6 +26,7 @@ class FunctionCall(Expr):
26
26
  fn: func.Function
27
27
  is_method_call: bool
28
28
  agg_init_args: dict[str, Any]
29
+ resource_pool: Optional[str]
29
30
 
30
31
  # tuple[Optional[int], Optional[Any]]:
31
32
  # - for Exprs: (index into components, None)
@@ -33,6 +34,12 @@ class FunctionCall(Expr):
33
34
  args: list[tuple[Optional[int], Optional[Any]]]
34
35
  kwargs: dict[str, tuple[Optional[int], Optional[Any]]]
35
36
 
37
+ # maps each parameter name to tuple representing the value it has in the call:
38
+ # - argument's index in components, if an argument is given in the call
39
+ # - default value, if no argument given in the call
40
+ # (in essence, this combines init()'s bound_args and default values)
41
+ _param_values: dict[str, tuple[Optional[int], Optional[Any]]]
42
+
36
43
  arg_types: list[ts.ColumnType]
37
44
  kwarg_types: dict[str, ts.ColumnType]
38
45
  return_type: ts.ColumnType
@@ -62,7 +69,8 @@ class FunctionCall(Expr):
62
69
 
63
70
  self.fn = fn
64
71
  self.is_method_call = is_method_call
65
-
72
+ #self.normalize_args(fn.name, signature, bound_args)
73
+ self.resource_pool = fn.call_resource_pool(bound_args)
66
74
  signature = fn.signature
67
75
 
68
76
  # If `return_type` is non-nullable, but the function call has a nullable input to any of its non-nullable
@@ -93,6 +101,7 @@ class FunctionCall(Expr):
93
101
  # construct components, args, kwargs
94
102
  self.args = []
95
103
  self.kwargs = {}
104
+ self._param_values = {}
96
105
 
97
106
  # we record the types of non-variable parameters for runtime type checks
98
107
  self.arg_types = []
@@ -106,9 +115,11 @@ class FunctionCall(Expr):
106
115
  arg = bound_args[py_param.name]
107
116
  if isinstance(arg, Expr):
108
117
  self.args.append((len(self.components), None))
118
+ self._param_values[py_param.name] = (len(self.components), None)
109
119
  self.components.append(arg.copy())
110
120
  else:
111
121
  self.args.append((None, arg))
122
+ self._param_values[py_param.name] = (None, arg)
112
123
  if py_param.kind != inspect.Parameter.VAR_POSITIONAL and py_param.kind != inspect.Parameter.VAR_KEYWORD:
113
124
  self.arg_types.append(signature.parameters[py_param.name].col_type)
114
125
  processed_args.add(py_param.name)
@@ -119,12 +130,21 @@ class FunctionCall(Expr):
119
130
  arg = bound_args[param_name]
120
131
  if isinstance(arg, Expr):
121
132
  self.kwargs[param_name] = (len(self.components), None)
133
+ self._param_values[param_name] = (len(self.components), None)
122
134
  self.components.append(arg.copy())
123
135
  else:
124
136
  self.kwargs[param_name] = (None, arg)
137
+ self._param_values[param_name] = (None, arg)
125
138
  if signature.py_signature.parameters[param_name].kind != inspect.Parameter.VAR_KEYWORD:
126
139
  self.kwarg_types[param_name] = signature.parameters[param_name].col_type
127
140
 
141
+ # fill in default values for parameters that don't have explicit arguments
142
+ for param in fn.signature.parameters.values():
143
+ if param.name not in self._param_values:
144
+ self._param_values[param.name] = (
145
+ (None, None) if param.default is inspect.Parameter.empty else (None, param.default)
146
+ )
147
+
128
148
  # window function state:
129
149
  # self.components[self.group_by_start_idx:self.group_by_stop_idx] contains group_by exprs
130
150
  self.group_by_start_idx, self.group_by_stop_idx = 0, 0
@@ -374,11 +394,11 @@ class FunctionCall(Expr):
374
394
  Update agg state
375
395
  """
376
396
  assert self.is_agg_fn_call
377
- args, kwargs = self._make_args(data_row)
397
+ args, kwargs = self.make_args(data_row)
378
398
  self.aggregator.update(*args, **kwargs)
379
399
 
380
- def _make_args(self, data_row: DataRow) -> tuple[list[Any], dict[str, Any]]:
381
- """Return args and kwargs, constructed for data_row"""
400
+ def make_args(self, data_row: DataRow) -> Optional[tuple[list[Any], dict[str, Any]]]:
401
+ """Return args and kwargs, constructed for data_row; returns None if any non-nullable arg is None."""
382
402
  kwargs: dict[str, Any] = {}
383
403
  for param_name, (component_idx, arg) in self.kwargs.items():
384
404
  val = arg if component_idx is None else data_row[self.components[component_idx].slot_idx]
@@ -388,6 +408,8 @@ class FunctionCall(Expr):
388
408
  kwargs.update(val)
389
409
  else:
390
410
  assert param.kind != inspect.Parameter.VAR_POSITIONAL
411
+ if not param.col_type.nullable and val is None:
412
+ return None
391
413
  kwargs[param_name] = val
392
414
 
393
415
  args: list[Any] = []
@@ -403,9 +425,30 @@ class FunctionCall(Expr):
403
425
  assert isinstance(val, dict)
404
426
  kwargs.update(val)
405
427
  else:
428
+ if not param.col_type.nullable and val is None:
429
+ return None
406
430
  args.append(val)
407
431
  return args, kwargs
408
432
 
433
+ def get_param_values(self, param_names: Sequence[str], data_rows: list[DataRow]) -> list[dict[str, Any]]:
434
+ """
435
+ Returns a list of dicts mapping each param name to its value when this FunctionCall is evaluated against
436
+ data_rows
437
+ """
438
+ assert all(name in self._param_values for name in param_names)
439
+ result: list[dict[str, Any]] = []
440
+ for row in data_rows:
441
+ d: dict[str, Any] = {}
442
+ for param_name in param_names:
443
+ component_idx, default_val = self._param_values[param_name]
444
+ if component_idx is None:
445
+ d[param_name] = default_val
446
+ else:
447
+ slot_idx = self.components[component_idx].slot_idx
448
+ d[param_name] = row[slot_idx]
449
+ result.append(d)
450
+ return result
451
+
409
452
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
410
453
  if isinstance(self.fn, func.ExprTemplateFunction):
411
454
  # we need to evaluate the template
@@ -419,20 +462,12 @@ class FunctionCall(Expr):
419
462
  data_row[self.slot_idx] = self.aggregator.value()
420
463
  return
421
464
 
422
- args, kwargs = self._make_args(data_row)
423
- signature = self.fn.signature
424
- if signature.parameters is not None:
425
- # check for nulls
426
- for i in range(len(self.arg_types)):
427
- if args[i] is None and not self.arg_types[i].nullable:
428
- # we can't evaluate this function
429
- data_row[self.slot_idx] = None
430
- return
431
- for param_name, param_type in self.kwarg_types.items():
432
- if kwargs[param_name] is None and not param_type.nullable:
433
- # we can't evaluate this function
434
- data_row[self.slot_idx] = None
435
- return
465
+ args_kwargs = self.make_args(data_row)
466
+ if args_kwargs is None:
467
+ # we can't evaluate this function
468
+ data_row[self.slot_idx] = None
469
+ return
470
+ args, kwargs = args_kwargs
436
471
 
437
472
  if isinstance(self.fn, func.CallableFunction) and not self.fn.is_batched:
438
473
  # optimization: avoid additional level of indirection we'd get from calling Function.exec()
pixeltable/exprs/inline_expr.py CHANGED
@@ -1,6 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- import copy
4
3
  from typing import Any, Iterable, Optional
5
4
 
6
5
  import numpy as np
@@ -26,8 +25,8 @@ class InlineArray(Expr):
26
25
  for el in elements:
27
26
  if isinstance(el, Expr):
28
27
  exprs.append(el)
29
- elif isinstance(el, list) or isinstance(el, tuple):
30
- exprs.append(InlineArray(el))
28
+ elif isinstance(el, (list, tuple)):
29
+ exprs.append(Expr.from_array(el))
31
30
  else:
32
31
  exprs.append(Literal(el))
33
32
 
@@ -83,6 +82,9 @@ class InlineArray(Expr):
83
82
  # loaded and their types are known.
84
83
  return InlineList(components) # type: ignore[return-value]
85
84
 
85
+ def _as_constant(self) -> Optional[np.ndarray]:
86
+ assert isinstance(self.col_type, ts.ArrayType)
87
+ return np.array([c.as_constant() for c in self.components], dtype=self.col_type.numpy_dtype())
86
88
 
87
89
  class InlineList(Expr):
88
90
  """
@@ -90,16 +92,7 @@ class InlineList(Expr):
90
92
  """
91
93
 
92
94
  def __init__(self, elements: Iterable):
93
- exprs = []
94
- for el in elements:
95
- if isinstance(el, Expr):
96
- exprs.append(el)
97
- elif isinstance(el, list) or isinstance(el, tuple):
98
- exprs.append(InlineList(el))
99
- elif isinstance(el, dict):
100
- exprs.append(InlineDict(el))
101
- else:
102
- exprs.append(Literal(el))
95
+ exprs = [Expr.from_object(el) for el in elements]
103
96
 
104
97
  json_schema = {
105
98
  'type': 'array',
@@ -131,6 +124,8 @@ class InlineList(Expr):
131
124
  def _from_dict(cls, _: dict, components: list[Expr]) -> InlineList:
132
125
  return cls(components)
133
126
 
127
+ def _as_constant(self) -> Optional[list[Any]]:
128
+ return list(c.as_constant() for c in self.components)
134
129
 
135
130
  class InlineDict(Expr):
136
131
  """
@@ -146,14 +141,7 @@ class InlineDict(Expr):
146
141
  if not isinstance(key, str):
147
142
  raise excs.Error(f'Dictionary requires string keys; {key} has type {type(key)}')
148
143
  self.keys.append(key)
149
- if isinstance(val, Expr):
150
- exprs.append(val)
151
- elif isinstance(val, dict):
152
- exprs.append(InlineDict(val))
153
- elif isinstance(val, list) or isinstance(val, tuple):
154
- exprs.append(InlineList(val))
155
- else:
156
- exprs.append(Literal(val))
144
+ exprs.append(Expr.from_object(val))
157
145
 
158
146
  json_schema: Optional[dict[str, Any]]
159
147
  try:
@@ -218,3 +206,6 @@ class InlineDict(Expr):
218
206
  assert len(d['keys']) == len(components)
219
207
  arg = dict(zip(d['keys'], components))
220
208
  return InlineDict(arg)
209
+
210
+ def _as_constant(self) -> Optional[dict[str, Any]]:
211
+ return dict(zip(self.keys, (c.as_constant() for c in self.components)))
pixeltable/exprs/literal.py CHANGED
@@ -4,6 +4,7 @@ import datetime
4
4
  from typing import Any, Optional
5
5
 
6
6
  import sqlalchemy as sql
7
+ import numpy as np
7
8
 
8
9
  import pixeltable.type_system as ts
9
10
  from pixeltable.env import Env
@@ -33,6 +34,9 @@ class Literal(Expr):
33
34
  val = val.replace(tzinfo=default_tz)
34
35
  # Now convert to UTC
35
36
  val = val.astimezone(datetime.timezone.utc)
37
+ if isinstance(val, tuple):
38
+ # Tuples are stored as a list
39
+ val = list(val)
36
40
  self.val = val
37
41
  self.id = self._create_id()
38
42
 
@@ -46,6 +50,9 @@ class Literal(Expr):
46
50
  assert isinstance(self.val, datetime.datetime)
47
51
  default_tz = Env.get().default_time_zone
48
52
  return f"'{self.val.astimezone(default_tz).isoformat()}'"
53
+ if self.col_type.is_array_type():
54
+ assert isinstance(self.val, np.ndarray)
55
+ return str(self.val.tolist())
49
56
  return str(self.val)
50
57
 
51
58
  def __repr__(self) -> str:
@@ -67,7 +74,7 @@ class Literal(Expr):
67
74
  data_row[self.slot_idx] = self.val
68
75
 
69
76
  def _as_dict(self) -> dict:
70
- # For some types, we need to explictly record their type, because JSON does not know
77
+ # For some types, we need to explicitly record their type, because JSON does not know
71
78
  # how to interpret them unambiguously
72
79
  if self.col_type.is_timestamp_type():
73
80
  assert isinstance(self.val, datetime.datetime)
@@ -76,18 +83,28 @@ class Literal(Expr):
76
83
  # stored as UTC in the database)
77
84
  encoded_val = self.val.isoformat()
78
85
  return {'val': encoded_val, 'val_t': self.col_type._type.name, **super()._as_dict()}
86
+ elif self.col_type.is_array_type():
87
+ assert isinstance(self.val, np.ndarray)
88
+ return {'val': self.val.tolist(), 'val_t': self.col_type._type.name, **super()._as_dict()}
79
89
  else:
80
90
  return {'val': self.val, **super()._as_dict()}
81
91
 
92
+ def _as_constant(self) -> Any:
93
+ return self.val
94
+
95
+ def is_constant(self) -> bool:
96
+ return True
97
+
82
98
  @classmethod
83
99
  def _from_dict(cls, d: dict, components: list[Expr]) -> Literal:
84
100
  assert 'val' in d
85
101
  if 'val_t' in d:
86
102
  val_t = d['val_t']
87
- # Currently the only special-cased literal type is TIMESTAMP
88
- assert val_t == ts.ColumnType.Type.TIMESTAMP.name
89
- dt = datetime.datetime.fromisoformat(d['val'])
90
- assert dt.tzinfo == datetime.timezone.utc # Must be UTC in the database
91
- return cls(dt)
92
- else:
93
- return cls(d['val'])
103
+ if val_t == ts.ColumnType.Type.TIMESTAMP.name:
104
+ dt = datetime.datetime.fromisoformat(d['val'])
105
+ assert dt.tzinfo == datetime.timezone.utc # Must be UTC in the database
106
+ return cls(dt)
107
+ elif val_t == ts.ColumnType.Type.ARRAY.name:
108
+ arrays = np.array(d['val'])
109
+ return cls(arrays)
110
+ return cls(d['val'])
pixeltable/exprs/row_builder.py CHANGED
@@ -6,12 +6,14 @@ from dataclasses import dataclass
6
6
  from typing import Any, Iterable, Optional, Sequence
7
7
  from uuid import UUID
8
8
 
9
+ import numpy as np
9
10
  import sqlalchemy as sql
10
11
 
11
12
  import pixeltable.catalog as catalog
12
13
  import pixeltable.exceptions as excs
13
14
  import pixeltable.func as func
14
15
  import pixeltable.utils as utils
16
+ from pixeltable.utils.media_store import MediaStore
15
17
  from .data_row import DataRow
16
18
  from .expr import Expr
17
19
  from .expr_set import ExprSet
@@ -68,6 +70,12 @@ class RowBuilder:
68
70
  # (list of set of slot_idxs, indexed by slot_idx)
69
71
  _exc_dependents: list[set[int]]
70
72
 
73
+ # dependents[i] = direct dependents of expr with slot idx i; dependents[i, j] == True: expr j depends on expr i
74
+ dependents: np.ndarray # of bool
75
+ transitive_dependents: np.ndarray # of bool
76
+ # dependencies[i] = direct dependencies of expr with slot idx i; transpose of dependents
77
+ dependencies: np.ndarray # of bool
78
+
71
79
  # records the output_expr that a subexpr belongs to
72
80
  # (a subexpr can be shared across multiple output exprs)
73
81
  output_expr_ids: list[set[int]]
@@ -176,6 +184,8 @@ class RowBuilder:
176
184
 
177
185
  # determine transitive dependencies for the purpose of exception propagation
178
186
  # (list of set of slot_idxs, indexed by slot_idx)
187
+ #self.dependents = np.zeros((self.num_materialized, self.num_materialized), dtype=bool)
188
+ self.dependencies = np.zeros((self.num_materialized, self.num_materialized), dtype=bool)
179
189
  exc_dependencies: list[set[int]] = [set() for _ in range(self.num_materialized)]
180
190
  from .column_property_ref import ColumnPropertyRef
181
191
  for expr in self.unique_exprs:
@@ -185,10 +195,19 @@ class RowBuilder:
185
195
  # error properties don't have exceptions themselves
186
196
  if isinstance(expr, ColumnPropertyRef) and expr.is_error_prop():
187
197
  continue
198
+ dependency_idxs = [d.slot_idx for d in expr.dependencies()]
199
+ self.dependencies[expr.slot_idx, dependency_idxs] = True
188
200
  for d in expr.dependencies():
189
201
  exc_dependencies[expr.slot_idx].add(d.slot_idx)
190
202
  exc_dependencies[expr.slot_idx].update(exc_dependencies[d.slot_idx])
191
203
 
204
+ self.dependents = self.dependencies.T
205
+ self.transitive_dependents = np.zeros((self.num_materialized, self.num_materialized), dtype=bool)
206
+ for i in reversed(range(self.num_materialized)):
207
+ self.transitive_dependents[i] = (
208
+ self.dependents[i] | np.any(self.transitive_dependents[self.dependents[i]], axis=0)
209
+ )
210
+
192
211
  self._exc_dependents = [set() for _ in range(self.num_materialized)]
193
212
  for expr in self.unique_exprs:
194
213
  assert expr.slot_idx is not None
@@ -389,6 +408,10 @@ class RowBuilder:
389
408
  table_row[col.errortype_store_name()] = type(exc).__name__
390
409
  table_row[col.errormsg_store_name()] = str(exc)
391
410
  else:
411
+ if col.col_type.is_image_type() and data_row.file_urls[slot_idx] is None:
412
+ # we have yet to store this image
413
+ filepath = str(MediaStore.prepare_media_path(col.tbl.id, col.id, col.tbl.version))
414
+ data_row.flush_img(slot_idx, filepath)
392
415
  val = data_row.get_stored_val(slot_idx, col.sa_col.type)
393
416
  table_row[col.store_name()] = val
394
417
  # we unfortunately need to set these, even if there are no errors
pixeltable/func/aggregate_function.py CHANGED
@@ -123,6 +123,10 @@ class AggregateFunction(Function):
123
123
  assert not self.is_polymorphic
124
124
  return self.agg_classes[0]
125
125
 
126
+ @property
127
+ def is_async(self) -> bool:
128
+ return False
129
+
126
130
  def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
127
131
  raise NotImplementedError
128
132
 
pixeltable/func/callable_function.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  import inspect
4
4
  from typing import Any, Callable, Optional, Sequence
5
5
  from uuid import UUID
6
+ import asyncio
6
7
 
7
8
  import cloudpickle # type: ignore[import-untyped]
8
9
 
@@ -19,6 +20,9 @@ class CallableFunction(Function):
19
20
  - references to lambdas and functions defined in notebooks, which are pickled and serialized to the store
20
21
  - functions that are defined in modules are serialized via the default mechanism
21
22
  """
23
+ py_fns: list[Callable]
24
+ self_name: Optional[str]
25
+ batch_size: Optional[int]
22
26
 
23
27
  def __init__(
24
28
  self,
@@ -48,6 +52,10 @@ class CallableFunction(Function):
48
52
  def is_batched(self) -> bool:
49
53
  return self.batch_size is not None
50
54
 
55
+ @property
56
+ def is_async(self) -> bool:
57
+ return inspect.iscoroutinefunction(self.py_fn)
58
+
51
59
  def _docstring(self) -> Optional[str]:
52
60
  return inspect.getdoc(self.py_fns[0])
53
61
 
@@ -56,6 +64,21 @@ class CallableFunction(Function):
56
64
  assert not self.is_polymorphic
57
65
  return self.py_fns[0]
58
66
 
67
+ async def aexec(self, *args: Any, **kwargs: Any) -> Any:
68
+ assert not self.is_polymorphic
69
+ assert self.is_async
70
+ if self.is_batched:
71
+ # Pack the batched parameters into singleton lists
72
+ constant_param_names = [p.name for p in self.signature.constant_parameters]
73
+ batched_args = [[arg] for arg in args]
74
+ constant_kwargs = {k: v for k, v in kwargs.items() if k in constant_param_names}
75
+ batched_kwargs = {k: [v] for k, v in kwargs.items() if k not in constant_param_names}
76
+ result = await self.py_fn(*batched_args, **constant_kwargs, **batched_kwargs)
77
+ assert len(result) == 1
78
+ return result[0]
79
+ else:
80
+ return await self.py_fn(*args, **kwargs)
81
+
59
82
  def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
60
83
  assert not self.is_polymorphic
61
84
  if self.is_batched:
@@ -64,11 +87,31 @@ class CallableFunction(Function):
64
87
  batched_args = [[arg] for arg in args]
65
88
  constant_kwargs = {k: v for k, v in kwargs.items() if k in constant_param_names}
66
89
  batched_kwargs = {k: [v] for k, v in kwargs.items() if k not in constant_param_names}
67
- result = self.py_fn(*batched_args, **constant_kwargs, **batched_kwargs)
90
+ result: list[Any]
91
+ if inspect.iscoroutinefunction(self.py_fn):
92
+ result = asyncio.run(self.py_fn(*batched_args, **constant_kwargs, **batched_kwargs))
93
+ else:
94
+ result = self.py_fn(*batched_args, **constant_kwargs, **batched_kwargs)
68
95
  assert len(result) == 1
69
96
  return result[0]
70
97
  else:
71
- return self.py_fn(*args, **kwargs)
98
+ if inspect.iscoroutinefunction(self.py_fn):
99
+ return asyncio.run(self.py_fn(*args, **kwargs))
100
+ else:
101
+ return self.py_fn(*args, **kwargs)
102
+
103
+ async def aexec_batch(self, *args: Any, **kwargs: Any) -> list:
104
+ """Execute the function with the given arguments and return the result.
105
+ The arguments are expected to be batched: if the corresponding parameter has type T,
106
+ then the argument should have type T if it's a constant parameter, or list[T] if it's
107
+ a batched parameter.
108
+ """
109
+ assert self.is_batched
110
+ assert self.is_async
111
+ assert not self.is_polymorphic
112
+ # Unpack the constant parameters
113
+ constant_kwargs, batched_kwargs = self.create_batch_kwargs(kwargs)
114
+ return await self.py_fn(*args, **constant_kwargs, **batched_kwargs)
72
115
 
73
116
  def exec_batch(self, args: list[Any], kwargs: dict[str, Any]) -> list:
74
117
  """Execute the function with the given arguments and return the result.
@@ -79,12 +122,19 @@ class CallableFunction(Function):
79
122
  assert self.is_batched
80
123
  assert not self.is_polymorphic
81
124
  # Unpack the constant parameters
125
+ constant_kwargs, batched_kwargs = self.create_batch_kwargs(kwargs)
126
+ if inspect.iscoroutinefunction(self.py_fn):
127
+ return asyncio.run(self.py_fn(*args, **constant_kwargs, **batched_kwargs))
128
+ else:
129
+ return self.py_fn(*args, **constant_kwargs, **batched_kwargs)
130
+
131
+ def create_batch_kwargs(self, kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, list[Any]]]:
132
+ """Converts kwargs containing lists into constant and batched kwargs in the format expected by a batched udf."""
82
133
  constant_param_names = [p.name for p in self.signature.constant_parameters]
83
134
  constant_kwargs = {k: v[0] for k, v in kwargs.items() if k in constant_param_names}
84
135
  batched_kwargs = {k: v for k, v in kwargs.items() if k not in constant_param_names}
85
- return self.py_fn(*args, **constant_kwargs, **batched_kwargs)
136
+ return constant_kwargs, batched_kwargs
86
137
 
87
- # TODO(aaron-siegel): Implement conditional batch sizing
88
138
  def get_batch_size(self, *args: Any, **kwargs: Any) -> Optional[int]:
89
139
  return self.batch_size
90
140
 
pixeltable/func/expr_template_function.py CHANGED
@@ -100,7 +100,7 @@ class ExprTemplateFunction(Function):
100
100
  assert not self.is_polymorphic
101
101
  expr = self.instantiate(args, kwargs)
102
102
  row_builder = exprs.RowBuilder(output_exprs=[expr], columns=[], input_exprs=[])
103
- row_batch = exec.DataRowBatch(tbl=None, row_builder=row_builder, len=1)
103
+ row_batch = exec.DataRowBatch(tbl=None, row_builder=row_builder, num_rows=1)
104
104
  row = row_batch[0]
105
105
  row_builder.eval(row, ctx=row_builder.default_eval_ctx)
106
106
  return row[row_builder.get_output_exprs()[0].slot_idx]
@@ -113,6 +113,10 @@ class ExprTemplateFunction(Function):
113
113
  def name(self) -> str:
114
114
  return self.self_name
115
115
 
116
+ @property
117
+ def is_async(self) -> bool:
118
+ return False
119
+
116
120
  def __str__(self) -> str:
117
121
  return str(self.templates[0].expr)
118
122