pixeltable 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (110)
  1. pixeltable/__init__.py +20 -9
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/column.py +23 -7
  4. pixeltable/catalog/insertable_table.py +32 -19
  5. pixeltable/catalog/table.py +210 -20
  6. pixeltable/catalog/table_version.py +272 -111
  7. pixeltable/catalog/table_version_path.py +6 -1
  8. pixeltable/dataframe.py +184 -110
  9. pixeltable/datatransfer/__init__.py +1 -0
  10. pixeltable/datatransfer/label_studio.py +526 -0
  11. pixeltable/datatransfer/remote.py +113 -0
  12. pixeltable/env.py +213 -79
  13. pixeltable/exec/__init__.py +2 -1
  14. pixeltable/exec/data_row_batch.py +6 -7
  15. pixeltable/exec/expr_eval_node.py +28 -28
  16. pixeltable/exec/sql_scan_node.py +7 -6
  17. pixeltable/exprs/__init__.py +4 -3
  18. pixeltable/exprs/column_ref.py +11 -2
  19. pixeltable/exprs/comparison.py +39 -1
  20. pixeltable/exprs/data_row.py +7 -0
  21. pixeltable/exprs/expr.py +26 -19
  22. pixeltable/exprs/function_call.py +17 -18
  23. pixeltable/exprs/globals.py +14 -2
  24. pixeltable/exprs/image_member_access.py +9 -28
  25. pixeltable/exprs/in_predicate.py +96 -0
  26. pixeltable/exprs/inline_array.py +13 -11
  27. pixeltable/exprs/inline_dict.py +15 -13
  28. pixeltable/exprs/row_builder.py +7 -1
  29. pixeltable/exprs/similarity_expr.py +67 -0
  30. pixeltable/ext/functions/whisperx.py +30 -0
  31. pixeltable/ext/functions/yolox.py +16 -0
  32. pixeltable/func/__init__.py +0 -2
  33. pixeltable/func/aggregate_function.py +5 -2
  34. pixeltable/func/callable_function.py +57 -13
  35. pixeltable/func/expr_template_function.py +14 -3
  36. pixeltable/func/function.py +35 -4
  37. pixeltable/func/signature.py +5 -15
  38. pixeltable/func/udf.py +8 -12
  39. pixeltable/functions/fireworks.py +9 -4
  40. pixeltable/functions/huggingface.py +48 -5
  41. pixeltable/functions/openai.py +49 -11
  42. pixeltable/functions/pil/image.py +61 -64
  43. pixeltable/functions/together.py +32 -6
  44. pixeltable/functions/util.py +0 -43
  45. pixeltable/functions/video.py +46 -8
  46. pixeltable/globals.py +443 -0
  47. pixeltable/index/__init__.py +1 -0
  48. pixeltable/index/base.py +9 -2
  49. pixeltable/index/btree.py +54 -0
  50. pixeltable/index/embedding_index.py +91 -15
  51. pixeltable/io/__init__.py +4 -0
  52. pixeltable/io/globals.py +59 -0
  53. pixeltable/{utils → io}/hf_datasets.py +48 -17
  54. pixeltable/io/pandas.py +148 -0
  55. pixeltable/{utils → io}/parquet.py +58 -33
  56. pixeltable/iterators/__init__.py +1 -1
  57. pixeltable/iterators/base.py +8 -4
  58. pixeltable/iterators/document.py +225 -93
  59. pixeltable/iterators/video.py +16 -9
  60. pixeltable/metadata/__init__.py +8 -4
  61. pixeltable/metadata/converters/convert_12.py +3 -0
  62. pixeltable/metadata/converters/convert_13.py +41 -0
  63. pixeltable/metadata/converters/convert_14.py +13 -0
  64. pixeltable/metadata/converters/convert_15.py +29 -0
  65. pixeltable/metadata/converters/util.py +63 -0
  66. pixeltable/metadata/schema.py +12 -6
  67. pixeltable/plan.py +11 -24
  68. pixeltable/store.py +16 -23
  69. pixeltable/tool/create_test_db_dump.py +49 -14
  70. pixeltable/type_system.py +27 -58
  71. pixeltable/utils/coco.py +94 -0
  72. pixeltable/utils/documents.py +42 -12
  73. pixeltable/utils/http_server.py +70 -0
  74. pixeltable-0.2.7.dist-info/METADATA +137 -0
  75. pixeltable-0.2.7.dist-info/RECORD +126 -0
  76. {pixeltable-0.2.5.dist-info → pixeltable-0.2.7.dist-info}/WHEEL +1 -1
  77. pixeltable/client.py +0 -600
  78. pixeltable/exprs/image_similarity_predicate.py +0 -58
  79. pixeltable/func/batched_function.py +0 -53
  80. pixeltable/func/nos_function.py +0 -202
  81. pixeltable/tests/conftest.py +0 -171
  82. pixeltable/tests/ext/test_yolox.py +0 -21
  83. pixeltable/tests/functions/test_fireworks.py +0 -43
  84. pixeltable/tests/functions/test_functions.py +0 -60
  85. pixeltable/tests/functions/test_huggingface.py +0 -158
  86. pixeltable/tests/functions/test_openai.py +0 -162
  87. pixeltable/tests/functions/test_together.py +0 -112
  88. pixeltable/tests/test_audio.py +0 -65
  89. pixeltable/tests/test_catalog.py +0 -27
  90. pixeltable/tests/test_client.py +0 -21
  91. pixeltable/tests/test_component_view.py +0 -379
  92. pixeltable/tests/test_dataframe.py +0 -440
  93. pixeltable/tests/test_dirs.py +0 -107
  94. pixeltable/tests/test_document.py +0 -120
  95. pixeltable/tests/test_exprs.py +0 -802
  96. pixeltable/tests/test_function.py +0 -332
  97. pixeltable/tests/test_index.py +0 -138
  98. pixeltable/tests/test_migration.py +0 -44
  99. pixeltable/tests/test_nos.py +0 -54
  100. pixeltable/tests/test_snapshot.py +0 -231
  101. pixeltable/tests/test_table.py +0 -1343
  102. pixeltable/tests/test_transactional_directory.py +0 -42
  103. pixeltable/tests/test_types.py +0 -52
  104. pixeltable/tests/test_video.py +0 -159
  105. pixeltable/tests/test_view.py +0 -535
  106. pixeltable/tests/utils.py +0 -442
  107. pixeltable/utils/clip.py +0 -18
  108. pixeltable-0.2.5.dist-info/METADATA +0 -128
  109. pixeltable-0.2.5.dist-info/RECORD +0 -139
  110. {pixeltable-0.2.5.dist-info → pixeltable-0.2.7.dist-info}/LICENSE +0 -0
@@ -1,15 +1,15 @@
1
1
  from __future__ import annotations
2
- from typing import Optional, List, Any, Dict, Tuple
2
+
3
3
  import copy
4
+ from typing import Optional, List, Any, Dict, Tuple
4
5
 
5
6
  import sqlalchemy as sql
6
7
 
7
- from .expr import Expr
8
- from .data_row import DataRow
9
- from .row_builder import RowBuilder
10
8
  import pixeltable.exceptions as excs
11
- import pixeltable.catalog as catalog
12
9
  import pixeltable.type_system as ts
10
+ from .data_row import DataRow
11
+ from .expr import Expr
12
+ from .row_builder import RowBuilder
13
13
 
14
14
 
15
15
  class InlineDict(Expr):
@@ -21,8 +21,8 @@ class InlineDict(Expr):
21
21
  super().__init__(ts.JsonType()) # we need to call this in order to populate self.components
22
22
  # dict_items contains
23
23
  # - for Expr fields: (key, index into components, None)
24
- # - for non-Expr fields: (key, -1, value)
25
- self.dict_items: List[Tuple[str, int, Any]] = []
24
+ # - for non-Expr fields: (key, None, value)
25
+ self.dict_items: List[Tuple[str, Optional[int], Any]] = []
26
26
  for key, val in d.items():
27
27
  if not isinstance(key, str):
28
28
  raise excs.Error(f'Dictionary requires string keys, {key} has type {type(key)}')
@@ -35,11 +35,11 @@ class InlineDict(Expr):
35
35
  self.dict_items.append((key, len(self.components), None))
36
36
  self.components.append(val)
37
37
  else:
38
- self.dict_items.append((key, -1, val))
38
+ self.dict_items.append((key, None, val))
39
39
 
40
40
  self.type_spec: Optional[Dict[str, ts.ColumnType]] = {}
41
41
  for key, idx, _ in self.dict_items:
42
- if idx == -1:
42
+ if idx is None:
43
43
  # TODO: implement type inference for values
44
44
  self.type_spec = None
45
45
  break
@@ -56,7 +56,7 @@ class InlineDict(Expr):
56
56
  return f"'{val}'"
57
57
  return str(val)
58
58
  for key, idx, val in self.dict_items:
59
- if idx != -1:
59
+ if idx is not None:
60
60
  item_strs.append(f"'{key}': {str(self.components[i])}")
61
61
  i += 1
62
62
  else:
@@ -71,7 +71,7 @@ class InlineDict(Expr):
71
71
 
72
72
  def to_dict(self) -> Dict[str, Any]:
73
73
  """Return the original dict used to construct this"""
74
- return {key: val if idx == -1 else self.components[idx] for key, idx, val in self.dict_items}
74
+ return {key: val if idx is None else self.components[idx] for key, idx, val in self.dict_items}
75
75
 
76
76
  def sql_expr(self) -> Optional[sql.ClauseElement]:
77
77
  return None
@@ -80,7 +80,7 @@ class InlineDict(Expr):
80
80
  result = {}
81
81
  for key, idx, val in self.dict_items:
82
82
  assert isinstance(key, str)
83
- if idx >= 0:
83
+ if idx is not None:
84
84
  result[key] = data_row[self.components[idx].slot_idx]
85
85
  else:
86
86
  result[key] = copy.deepcopy(val)
@@ -94,7 +94,9 @@ class InlineDict(Expr):
94
94
  assert 'dict_items' in d
95
95
  arg: Dict[str, Any] = {}
96
96
  for key, idx, val in d['dict_items']:
97
- if idx >= 0:
97
+ # TODO Normalize idx -1 to None via schema migrations.
98
+ # Long-term we should not be allowing idx == -1.
99
+ if idx is not None and idx >= 0: # Older schemas might have -1 instead of None
98
100
  arg[key] = components[idx]
99
101
  else:
100
102
  arg[key] = val
@@ -60,6 +60,8 @@ class RowBuilder:
60
60
  Args:
61
61
  output_exprs: list of Exprs to be evaluated
62
62
  columns: list of columns to be materialized
63
+ input_exprs: list of Exprs that are excluded from evaluation (because they're already materialized)
64
+ TODO: enforce that output_exprs doesn't overlap with input_exprs?
63
65
  """
64
66
  self.unique_exprs = ExprSet() # dependencies precede their dependents
65
67
  self.next_slot_idx = 0
@@ -179,12 +181,16 @@ class RowBuilder:
179
181
  for i, c in enumerate(expr.components):
180
182
  # make sure we only refer to components that have themselves been recorded
181
183
  expr.components[i] = self._record_unique_expr(c, True)
182
- assert expr.slot_idx < 0
184
+ assert expr.slot_idx is None
183
185
  expr.slot_idx = self._next_slot_idx()
184
186
  self.unique_exprs.append(expr)
185
187
  return expr
186
188
 
187
189
  def _record_output_expr_id(self, e: Expr, output_expr_id: int) -> None:
190
+ assert e.slot_idx is not None
191
+ assert output_expr_id is not None
192
+ if e.slot_idx in self.input_expr_slot_idxs:
193
+ return
188
194
  self.output_expr_ids[e.slot_idx].add(output_expr_id)
189
195
  for d in e.dependencies():
190
196
  self._record_output_expr_id(d, output_expr_id)
@@ -0,0 +1,67 @@
1
+ from typing import Optional, List
2
+
3
+ import sqlalchemy as sql
4
+ import PIL.Image
5
+
6
+ import pixeltable.exceptions as excs
7
+ import pixeltable.type_system as ts
8
+ from .column_ref import ColumnRef
9
+ from .data_row import DataRow
10
+ from .expr import Expr
11
+ from .literal import Literal
12
+ from .row_builder import RowBuilder
13
+
14
+
15
+ class SimilarityExpr(Expr):
16
+
17
+ def __init__(self, col_ref: ColumnRef, item: Expr):
18
+ super().__init__(ts.FloatType())
19
+ self.components = [col_ref, item]
20
+ self.id = self._create_id()
21
+ assert isinstance(item, Literal)
22
+ assert item.col_type.is_string_type() or item.col_type.is_image_type()
23
+
24
+ # determine index to use
25
+ idx_info = col_ref.col.get_idx_info()
26
+ import pixeltable.index as index
27
+ embedding_idx_info = [info for info in idx_info.values() if isinstance(info.idx, index.EmbeddingIndex)]
28
+ if len(embedding_idx_info) == 0:
29
+ raise excs.Error(f'No index found for column {col_ref.col}')
30
+ if len(embedding_idx_info) > 1:
31
+ raise excs.Error(
32
+ f'Column {col_ref.col.name} has multiple indices; use the index name to disambiguate, '
33
+ f'e.g., `{col_ref.col.name}.<index-name>.similarity(...)`')
34
+ self.idx_info = embedding_idx_info[0]
35
+ idx = self.idx_info.idx
36
+
37
+ if item.col_type.is_string_type() and idx.txt_embed is None:
38
+ raise excs.Error(
39
+ f'Embedding index {self.idx_info.name} on column {self.idx_info.col.name} was created without the '
40
+ f'text_embed parameter and does not support text queries')
41
+ if item.col_type.is_image_type() and idx.img_embed is None:
42
+ raise excs.Error(
43
+ f'Embedding index {self.idx_info.name} on column {self.idx_info.col.name} was created without the '
44
+ f'img_embed parameter and does not support image queries')
45
+
46
+ def __str__(self) -> str:
47
+ return f'{self.components[0]}.similarity({self.components[1]})'
48
+
49
+ def sql_expr(self) -> Optional[sql.ClauseElement]:
50
+ assert isinstance(self.components[1], Literal)
51
+ item = self.components[1].val
52
+ return self.idx_info.idx.similarity_clause(self.idx_info.val_col, item)
53
+
54
+ def as_order_by_clause(self, is_asc: bool) -> Optional[sql.ClauseElement]:
55
+ assert isinstance(self.components[1], Literal)
56
+ item = self.components[1].val
57
+ return self.idx_info.idx.order_by_clause(self.idx_info.val_col, item, is_asc)
58
+
59
+ def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
60
+ # this should never get called
61
+ assert False
62
+
63
+ @classmethod
64
+ def _from_dict(cls, d: dict, components: List[Expr]) -> Expr:
65
+ assert len(components) == 2
66
+ assert isinstance(components[0], ColumnRef)
67
+ return cls(components[0], components[1])
@@ -0,0 +1,30 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import whisperx
5
+ from whisperx.asr import FasterWhisperPipeline
6
+
7
+ import pixeltable as pxt
8
+
9
+
10
+ @pxt.udf(param_types=[pxt.AudioType(), pxt.StringType(), pxt.StringType(), pxt.StringType(), pxt.IntType()])
11
+ def transcribe(
12
+ audio: str, *, model: str, compute_type: Optional[str] = None, language: Optional[str] = None, chunk_size: int = 30
13
+ ) -> dict:
14
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
+ compute_type = compute_type or ('float16' if device == 'cuda' else 'int8')
16
+ model = _lookup_model(model, device, compute_type)
17
+ audio_array = whisperx.load_audio(audio)
18
+ result = model.transcribe(audio_array, batch_size=16, language=language, chunk_size=chunk_size)
19
+ return result
20
+
21
+
22
+ def _lookup_model(model_id: str, device: str, compute_type: str) -> FasterWhisperPipeline:
23
+ key = (model_id, device, compute_type)
24
+ if key not in _model_cache:
25
+ model = whisperx.load_model(model_id, device, compute_type=compute_type)
26
+ _model_cache[key] = model
27
+ return _model_cache[key]
28
+
29
+
30
+ _model_cache = {}
@@ -56,6 +56,22 @@ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0
56
56
  return results
57
57
 
58
58
 
59
+ @pxt.udf
60
+ def yolo_to_coco(detections: dict) -> list:
61
+ bboxes, labels = detections['bboxes'], detections['labels']
62
+ num_annotations = len(detections['bboxes'])
63
+ assert num_annotations == len(detections['labels'])
64
+ result = []
65
+ for i in range(num_annotations):
66
+ bbox = bboxes[i]
67
+ ann = {
68
+ 'bbox': [round(bbox[0]), round(bbox[1]), round(bbox[2] - bbox[0]), round(bbox[3] - bbox[1])],
69
+ 'category': labels[i],
70
+ }
71
+ result.append(ann)
72
+ return result
73
+
74
+
59
75
  def _images_to_tensors(images: Iterable[PIL.Image.Image], exp: Exp) -> Iterator[torch.Tensor]:
60
76
  for image in images:
61
77
  image_transform, _ = _val_transform(np.array(image), None, exp.test_size)
@@ -1,9 +1,7 @@
1
1
  from .aggregate_function import Aggregator, AggregateFunction, uda
2
- from .batched_function import BatchedFunction, ExplicitBatchedFunction
3
2
  from .callable_function import CallableFunction
4
3
  from .expr_template_function import ExprTemplateFunction
5
4
  from .function import Function
6
5
  from .function_registry import FunctionRegistry
7
- from .nos_function import NOSFunction
8
6
  from .signature import Signature, Parameter, Batch
9
7
  from .udf import udf, make_function, expr_udf
@@ -72,6 +72,9 @@ class AggregateFunction(Function):
72
72
  if param.lower() in self.RESERVED_PARAMS:
73
73
  raise excs.Error(f'{self.name}(): parameter name {param} is reserved')
74
74
 
75
+ def exec(self, *args: Any, **kwargs: Any) -> Any:
76
+ raise NotImplementedError
77
+
75
78
  def help_str(self) -> str:
76
79
  res = super().help_str()
77
80
  res += '\n\n' + inspect.getdoc(self.agg_cls.update)
@@ -137,7 +140,7 @@ def uda(
137
140
  update_types: List[ts.ColumnType],
138
141
  init_types: Optional[List[ts.ColumnType]] = None,
139
142
  requires_order_by: bool = False, allows_std_agg: bool = True, allows_window: bool = False,
140
- ) -> Callable:
143
+ ) -> Callable[[Type[Aggregator]], AggregateFunction]:
141
144
  """Decorator for user-defined aggregate functions.
142
145
 
143
146
  The decorated class must inherit from Aggregator and implement the following methods:
@@ -159,7 +162,7 @@ def uda(
159
162
  if init_types is None:
160
163
  init_types = []
161
164
 
162
- def decorator(cls: Type[Aggregator]) -> Type[Function]:
165
+ def decorator(cls: Type[Aggregator]) -> AggregateFunction:
163
166
  # validate type parameters
164
167
  num_init_params = len(inspect.signature(cls.__init__).parameters) - 1
165
168
  if num_init_params > 0:
@@ -1,16 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
- import sys
5
- from typing import Optional, Dict, Callable, List, Tuple
4
+ from typing import Optional, Callable, Tuple, Any
6
5
  from uuid import UUID
6
+
7
7
  import cloudpickle
8
8
 
9
- import pixeltable.type_system as ts
10
- import pixeltable.exceptions as excs
11
9
  from .function import Function
12
- from .function_registry import FunctionRegistry
13
- from .globals import get_caller_module_path
14
10
  from .signature import Signature
15
11
 
16
12
 
@@ -24,13 +20,48 @@ class CallableFunction(Function):
24
20
 
25
21
  def __init__(
26
22
  self, signature: Signature, py_fn: Callable, self_path: Optional[str] = None,
27
- self_name: Optional[str] = None):
23
+ self_name: Optional[str] = None, batch_size: Optional[int] = None):
28
24
  assert py_fn is not None
29
25
  self.py_fn = py_fn
30
26
  self.self_name = self_name
27
+ self.batch_size = batch_size
31
28
  py_signature = inspect.signature(self.py_fn)
32
29
  super().__init__(signature, py_signature, self_path=self_path)
33
30
 
31
+ @property
32
+ def is_batched(self) -> bool:
33
+ return self.batch_size is not None
34
+
35
+ def exec(self, *args: Any, **kwargs: Any) -> Any:
36
+ if self.is_batched:
37
+ # Pack the batched parameters into singleton lists
38
+ constant_param_names = [p.name for p in self.signature.constant_parameters]
39
+ batched_args = [[arg] for arg in args]
40
+ constant_kwargs = {k: v for k, v in kwargs.items() if k in constant_param_names}
41
+ batched_kwargs = {k: [v] for k, v in kwargs.items() if k not in constant_param_names}
42
+ result = self.py_fn(*batched_args, **constant_kwargs, **batched_kwargs)
43
+ assert len(result) == 1
44
+ return result[0]
45
+ else:
46
+ return self.py_fn(*args, **kwargs)
47
+
48
+ def exec_batch(self, *args: Any, **kwargs: Any) -> list:
49
+ """Execute the function with the given arguments and return the result.
50
+ The arguments are expected to be batched: if the corresponding parameter has type T,
51
+ then the argument should have type T if it's a constant parameter, or list[T] if it's
52
+ a batched parameter.
53
+ """
54
+ assert self.is_batched
55
+ # Unpack the constant parameters
56
+ constant_param_names = [p.name for p in self.signature.constant_parameters]
57
+ constant_kwargs = {k: v[0] for k, v in kwargs.items() if k in constant_param_names}
58
+ batched_kwargs = {k: v for k, v in kwargs.items() if k not in constant_param_names}
59
+ return self.py_fn(*args, **constant_kwargs, **batched_kwargs)
60
+
61
+ # TODO(aaron-siegel): Implement conditional batch sizing
62
+ def get_batch_size(self, *args: Any, **kwargs: Any) -> Optional[int]:
63
+ return self.batch_size
64
+
34
65
  @property
35
66
  def display_name(self) -> str:
36
67
  return self.self_name
@@ -44,7 +75,7 @@ class CallableFunction(Function):
44
75
  res += '\n\n' + inspect.getdoc(self.py_fn)
45
76
  return res
46
77
 
47
- def _as_dict(self) -> Dict:
78
+ def _as_dict(self) -> dict:
48
79
  if self.self_path is None:
49
80
  # this is not a module function
50
81
  from .function_registry import FunctionRegistry
@@ -53,17 +84,30 @@ class CallableFunction(Function):
53
84
  return super()._as_dict()
54
85
 
55
86
  @classmethod
56
- def _from_dict(cls, d: Dict) -> Function:
87
+ def _from_dict(cls, d: dict) -> Function:
57
88
  if 'id' in d:
58
89
  from .function_registry import FunctionRegistry
59
90
  return FunctionRegistry.get().get_stored_function(UUID(hex=d['id']))
60
91
  return super()._from_dict(d)
61
92
 
62
- def to_store(self) -> Tuple[Dict, bytes]:
63
- return (self.signature.as_dict(), cloudpickle.dumps(self.py_fn))
93
+ def to_store(self) -> tuple[dict, bytes]:
94
+ md = self.signature.as_dict()
95
+ if self.batch_size is not None:
96
+ md['batch_size'] = self.batch_size
97
+ return md, cloudpickle.dumps(self.py_fn)
64
98
 
65
99
  @classmethod
66
- def from_store(cls, name: Optional[str], md: Dict, binary_obj: bytes) -> Function:
100
+ def from_store(cls, name: Optional[str], md: dict, binary_obj: bytes) -> Function:
67
101
  py_fn = cloudpickle.loads(binary_obj)
68
102
  assert isinstance(py_fn, Callable)
69
- return CallableFunction(Signature.from_dict(md), py_fn, self_name=name)
103
+ return CallableFunction(Signature.from_dict(md), py_fn, self_name=name, batch_size=md.get('batch_size'))
104
+
105
+ def validate_call(self, bound_args: dict[str, Any]) -> None:
106
+ import pixeltable.exprs as exprs
107
+ if self.is_batched:
108
+ for param in self.signature.constant_parameters:
109
+ if param.name in bound_args and isinstance(bound_args[param.name], exprs.Expr):
110
+ raise ValueError(
111
+ f'{self.display_name}(): '
112
+ f'parameter {param.name} must be a constant value, not a Pixeltable expression'
113
+ )
@@ -1,9 +1,8 @@
1
1
  import inspect
2
- from typing import Dict, Optional, Callable, List
2
+ from typing import Dict, Optional, Any
3
3
 
4
4
  import pixeltable
5
5
  import pixeltable.exceptions as excs
6
- import pixeltable.type_system as ts
7
6
  from .function import Function
8
7
  from .signature import Signature, Parameter
9
8
 
@@ -51,6 +50,7 @@ class ExprTemplateFunction(Function):
51
50
  {param_name: default for param_name, default in self.defaults.items() if param_name not in bound_args})
52
51
  result = self.expr.copy()
53
52
  import pixeltable.exprs as exprs
53
+ arg_exprs: dict[exprs.Expr, exprs.Expr] = {}
54
54
  for param_name, arg in bound_args.items():
55
55
  param_expr = self.param_exprs_by_name[param_name]
56
56
  if not isinstance(arg, exprs.Expr):
@@ -60,11 +60,22 @@ class ExprTemplateFunction(Function):
60
60
  raise excs.Error(f'{self.self_name}(): cannot convert argument {arg} to a Pixeltable expression')
61
61
  else:
62
62
  arg_expr = arg
63
- result = result.substitute(param_expr, arg_expr)
63
+ arg_exprs[param_expr] = arg_expr
64
+ result = result.substitute(arg_exprs)
64
65
  import pixeltable.exprs as exprs
65
66
  assert not result.contains(exprs.Variable)
66
67
  return result
67
68
 
69
+ def exec(self, *args: Any, **kwargs: Any) -> Any:
70
+ expr = self.instantiate(*args, **kwargs)
71
+ import pixeltable.exprs as exprs
72
+ row_builder = exprs.RowBuilder(output_exprs=[expr], columns=[], input_exprs=[])
73
+ import pixeltable.exec as exec
74
+ row_batch = exec.DataRowBatch(tbl=None, row_builder=row_builder, len=1)
75
+ row = row_batch[0]
76
+ row_builder.eval(row, ctx=row_builder.default_eval_ctx)
77
+ return row[row_builder.get_output_exprs()[0].slot_idx]
78
+
68
79
  @property
69
80
  def display_name(self) -> str:
70
81
  return self.self_name
@@ -3,9 +3,10 @@ from __future__ import annotations
3
3
  import abc
4
4
  import importlib
5
5
  import inspect
6
- import pixeltable
7
- from typing import Optional, Dict, Any, Tuple
6
+ from typing import Optional, Dict, Any, Tuple, Callable
8
7
 
8
+ import pixeltable
9
+ import pixeltable.type_system as ts
9
10
  from .globals import resolve_symbol
10
11
  from .signature import Signature
11
12
 
@@ -18,10 +19,13 @@ class Function(abc.ABC):
18
19
  via the member self_path.
19
20
  """
20
21
 
21
- def __init__(self, signature: Signature, py_signature: inspect.Signature, self_path: Optional[str] = None):
22
+ def __init__(
23
+ self, signature: Signature, py_signature: inspect.Signature, self_path: Optional[str] = None
24
+ ):
22
25
  self.signature = signature
23
26
  self.py_signature = py_signature
24
27
  self.self_path = self_path # fully-qualified path to self
28
+ self._conditional_return_type: Optional[Callable[..., ts.ColumnType]] = None
25
29
 
26
30
  @property
27
31
  def name(self) -> str:
@@ -40,7 +44,7 @@ class Function(abc.ABC):
40
44
  def help_str(self) -> str:
41
45
  return self.display_name + str(self.signature)
42
46
 
43
- def __call__(self, *args: object, **kwargs: object) -> 'pixeltable.exprs.Expr':
47
+ def __call__(self, *args: Any, **kwargs: Any) -> 'pixeltable.exprs.Expr':
44
48
  from pixeltable import exprs
45
49
  bound_args = self.py_signature.bind(*args, **kwargs)
46
50
  self.validate_call(bound_args.arguments)
@@ -50,6 +54,33 @@ class Function(abc.ABC):
50
54
  """Override this to do custom validation of the arguments"""
51
55
  pass
52
56
 
57
+ def call_return_type(self, kwargs: dict[str, Any]) -> ts.ColumnType:
58
+ """Return the type of the value returned by calling this function with the given arguments"""
59
+ if self._conditional_return_type is None:
60
+ return self.signature.return_type
61
+ bound_args = self.py_signature.bind(**kwargs)
62
+ kw_args: dict[str, Any] = {}
63
+ sig = inspect.signature(self._conditional_return_type)
64
+ for param in sig.parameters.values():
65
+ if param.name in bound_args.arguments:
66
+ kw_args[param.name] = bound_args.arguments[param.name]
67
+ return self._conditional_return_type(**kw_args)
68
+
69
+ def conditional_return_type(self, fn: Callable[..., ts.ColumnType]) -> Callable[..., ts.ColumnType]:
70
+ """Instance decorator for specifying a conditional return type for this function"""
71
+ # verify that call_return_type only has parameters that are also present in the signature
72
+ sig = inspect.signature(fn)
73
+ for param in sig.parameters.values():
74
+ if param.name not in self.signature.parameters:
75
+ raise ValueError(f'`conditional_return_type` has parameter `{param.name}` that is not in the signature')
76
+ self._conditional_return_type = fn
77
+ return fn
78
+
79
+ @abc.abstractmethod
80
+ def exec(self, *args: Any, **kwargs: Any) -> Any:
81
+ """Execute the function with the given arguments and return the result."""
82
+ pass
83
+
53
84
  def __eq__(self, other: object) -> bool:
54
85
  if not isinstance(other, self.__class__):
55
86
  return False
@@ -29,21 +29,12 @@ class Signature:
29
29
  """
30
30
  Represents the signature of a Pixeltable function.
31
31
 
32
- Regarding return type:
33
- - most functions will have a fixed return type, which is specified directly
34
- - some functions will have a return type that depends on the argument values;
35
- ex.: PIL.Image.Image.resize() returns an image with dimensions specified as a parameter
36
- - in the latter case, the 'return_type' field is a function that takes the bound arguments and returns the
37
- return type; if no bound arguments are specified, a generic return type is returned (eg, ImageType() without a
38
- size)
39
32
  - self.is_batched: return type is a Batch[...] type
40
33
  """
41
34
  SPECIAL_PARAM_NAMES = ['group_by', 'order_by']
42
35
 
43
- def __init__(
44
- self,
45
- return_type: Union[ts.ColumnType, Callable[[Dict[str, Any]], ts.ColumnType]],
46
- parameters: List[Parameter], is_batched: bool = False):
36
+ def __init__(self, return_type: ts.ColumnType, parameters: List[Parameter], is_batched: bool = False):
37
+ assert isinstance(return_type, ts.ColumnType)
47
38
  self.return_type = return_type
48
39
  self.is_batched = is_batched
49
40
  # we rely on the ordering guarantee of dicts in Python >=3.7
@@ -52,10 +43,9 @@ class Signature:
52
43
  self.constant_parameters = [p for p in parameters if not p.is_batched]
53
44
  self.batched_parameters = [p for p in parameters if p.is_batched]
54
45
 
55
- def get_return_type(self, bound_args: Optional[Dict[str, Any]] = None) -> ts.ColumnType:
56
- if isinstance(self.return_type, ts.ColumnType):
57
- return self.return_type
58
- return self.return_type(bound_args)
46
+ def get_return_type(self) -> ts.ColumnType:
47
+ assert isinstance(self.return_type, ts.ColumnType)
48
+ return self.return_type
59
49
 
60
50
  def as_dict(self) -> Dict[str, Any]:
61
51
  result = {
pixeltable/func/udf.py CHANGED
@@ -6,7 +6,6 @@ from typing import List, Callable, Optional, overload, Any
6
6
  import pixeltable as pxt
7
7
  import pixeltable.exceptions as excs
8
8
  import pixeltable.type_system as ts
9
- from .batched_function import ExplicitBatchedFunction
10
9
  from .callable_function import CallableFunction
11
10
  from .expr_template_function import ExprTemplateFunction
12
11
  from .function import Function
@@ -29,7 +28,7 @@ def udf(
29
28
  batch_size: Optional[int] = None,
30
29
  substitute_fn: Optional[Callable] = None,
31
30
  _force_stored: bool = False
32
- ) -> Callable: ...
31
+ ) -> Callable[[Callable], Function]: ...
33
32
 
34
33
 
35
34
  def udf(*args, **kwargs):
@@ -62,8 +61,8 @@ def udf(*args, **kwargs):
62
61
 
63
62
  def decorator(decorated_fn: Callable):
64
63
  return make_function(
65
- decorated_fn, return_type, param_types, batch_size, substitute_fn=substitute_fn,
66
- force_stored=force_stored)
64
+ decorated_fn, return_type, param_types, batch_size,
65
+ substitute_fn=substitute_fn, force_stored=force_stored)
67
66
 
68
67
  return decorator
69
68
 
@@ -78,8 +77,8 @@ def make_function(
78
77
  force_stored: bool = False
79
78
  ) -> Function:
80
79
  """
81
- Constructs a `CallableFunction` or `BatchedFunction`, depending on the
82
- supplied parameters. If `substitute_fn` is specified, then `decorated_fn`
80
+ Constructs a `CallableFunction` from the specified parameters.
81
+ If `substitute_fn` is specified, then `decorated_fn`
83
82
  will be used only for its signature, with execution delegated to
84
83
  `substitute_fn`.
85
84
  """
@@ -117,11 +116,8 @@ def make_function(
117
116
  raise excs.Error(f'{errmsg_name}(): @udf decorator with a `substitute_fn` can only be used in a module')
118
117
  py_fn = substitute_fn
119
118
 
120
- if batch_size is None:
121
- result = CallableFunction(signature=sig, py_fn=py_fn, self_path=function_path, self_name=function_name)
122
- else:
123
- result = ExplicitBatchedFunction(
124
- signature=sig, batch_size=batch_size, invoker_fn=py_fn, self_path=function_path)
119
+ result = CallableFunction(
120
+ signature=sig, py_fn=py_fn, self_path=function_path, self_name=function_name, batch_size=batch_size)
125
121
 
126
122
  # If this function is part of a module, register it
127
123
  if function_path is not None:
@@ -135,7 +131,7 @@ def make_function(
135
131
  def expr_udf(py_fn: Callable) -> ExprTemplateFunction: ...
136
132
 
137
133
  @overload
138
- def expr_udf(*, param_types: Optional[List[ts.ColumnType]] = None) -> Callable: ...
134
+ def expr_udf(*, param_types: Optional[List[ts.ColumnType]] = None) -> Callable[[Callable], ExprTemplateFunction]: ...
139
135
 
140
136
  def expr_udf(*args: Any, **kwargs: Any) -> Any:
141
137
  def decorator(py_fn: Callable, param_types: Optional[List[ts.ColumnType]]) -> ExprTemplateFunction:
@@ -6,8 +6,13 @@ import pixeltable as pxt
6
6
  from pixeltable import env
7
7
 
8
8
 
9
- def fireworks_client() -> fireworks.client.Fireworks:
10
- return env.Env.get().get_client('fireworks', lambda api_key: fireworks.client.Fireworks(api_key=api_key))
9
+ @env.register_client('fireworks')
10
+ def _(api_key: str) -> fireworks.client.Fireworks:
11
+ return fireworks.client.Fireworks(api_key=api_key)
12
+
13
+
14
+ def _fireworks_client() -> fireworks.client.Fireworks:
15
+ return env.Env.get().get_client('fireworks')
11
16
 
12
17
 
13
18
  @pxt.udf
@@ -26,8 +31,8 @@ def chat_completions(
26
31
  'top_p': top_p,
27
32
  'temperature': temperature
28
33
  }
29
- kwargs_not_none = dict(filter(lambda x: x[1] is not None, kwargs.items()))
30
- return fireworks_client().chat.completions.create(
34
+ kwargs_not_none = {k: v for k, v in kwargs.items() if v is not None}
35
+ return _fireworks_client().chat.completions.create(
31
36
  model=model,
32
37
  messages=messages,
33
38
  **kwargs_not_none