pixeltable 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (99)
  1. pixeltable/__init__.py +18 -9
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/column.py +31 -50
  4. pixeltable/catalog/insertable_table.py +7 -6
  5. pixeltable/catalog/table.py +171 -57
  6. pixeltable/catalog/table_version.py +417 -140
  7. pixeltable/catalog/table_version_path.py +2 -2
  8. pixeltable/dataframe.py +239 -121
  9. pixeltable/env.py +82 -16
  10. pixeltable/exec/__init__.py +2 -1
  11. pixeltable/exec/cache_prefetch_node.py +1 -1
  12. pixeltable/exec/data_row_batch.py +6 -7
  13. pixeltable/exec/expr_eval_node.py +28 -28
  14. pixeltable/exec/in_memory_data_node.py +11 -7
  15. pixeltable/exec/sql_scan_node.py +7 -6
  16. pixeltable/exprs/__init__.py +4 -3
  17. pixeltable/exprs/column_ref.py +9 -0
  18. pixeltable/exprs/comparison.py +3 -3
  19. pixeltable/exprs/data_row.py +5 -1
  20. pixeltable/exprs/expr.py +15 -7
  21. pixeltable/exprs/function_call.py +17 -15
  22. pixeltable/exprs/image_member_access.py +9 -28
  23. pixeltable/exprs/in_predicate.py +96 -0
  24. pixeltable/exprs/inline_array.py +13 -11
  25. pixeltable/exprs/inline_dict.py +15 -13
  26. pixeltable/exprs/literal.py +16 -4
  27. pixeltable/exprs/row_builder.py +15 -41
  28. pixeltable/exprs/similarity_expr.py +65 -0
  29. pixeltable/ext/__init__.py +5 -0
  30. pixeltable/ext/functions/yolox.py +92 -0
  31. pixeltable/func/__init__.py +0 -2
  32. pixeltable/func/aggregate_function.py +18 -15
  33. pixeltable/func/callable_function.py +57 -13
  34. pixeltable/func/expr_template_function.py +20 -3
  35. pixeltable/func/function.py +35 -4
  36. pixeltable/func/globals.py +24 -14
  37. pixeltable/func/signature.py +23 -27
  38. pixeltable/func/udf.py +13 -12
  39. pixeltable/functions/__init__.py +8 -8
  40. pixeltable/functions/eval.py +7 -8
  41. pixeltable/functions/huggingface.py +64 -17
  42. pixeltable/functions/openai.py +36 -3
  43. pixeltable/functions/pil/image.py +61 -64
  44. pixeltable/functions/together.py +21 -0
  45. pixeltable/functions/util.py +11 -0
  46. pixeltable/globals.py +425 -0
  47. pixeltable/index/__init__.py +2 -0
  48. pixeltable/index/base.py +51 -0
  49. pixeltable/index/embedding_index.py +168 -0
  50. pixeltable/io/__init__.py +3 -0
  51. pixeltable/{utils → io}/hf_datasets.py +48 -17
  52. pixeltable/io/pandas.py +148 -0
  53. pixeltable/{utils → io}/parquet.py +58 -33
  54. pixeltable/iterators/__init__.py +1 -1
  55. pixeltable/iterators/base.py +4 -0
  56. pixeltable/iterators/document.py +218 -97
  57. pixeltable/iterators/video.py +8 -9
  58. pixeltable/metadata/__init__.py +7 -3
  59. pixeltable/metadata/converters/convert_12.py +3 -0
  60. pixeltable/metadata/converters/convert_13.py +41 -0
  61. pixeltable/metadata/schema.py +45 -22
  62. pixeltable/plan.py +15 -51
  63. pixeltable/store.py +38 -41
  64. pixeltable/tool/create_test_db_dump.py +39 -4
  65. pixeltable/type_system.py +47 -96
  66. pixeltable/utils/documents.py +42 -12
  67. pixeltable/utils/http_server.py +70 -0
  68. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/METADATA +14 -10
  69. pixeltable-0.2.6.dist-info/RECORD +119 -0
  70. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/WHEEL +1 -1
  71. pixeltable/client.py +0 -604
  72. pixeltable/exprs/image_similarity_predicate.py +0 -58
  73. pixeltable/func/batched_function.py +0 -53
  74. pixeltable/tests/conftest.py +0 -177
  75. pixeltable/tests/functions/test_fireworks.py +0 -42
  76. pixeltable/tests/functions/test_functions.py +0 -60
  77. pixeltable/tests/functions/test_huggingface.py +0 -158
  78. pixeltable/tests/functions/test_openai.py +0 -152
  79. pixeltable/tests/functions/test_together.py +0 -111
  80. pixeltable/tests/test_audio.py +0 -65
  81. pixeltable/tests/test_catalog.py +0 -27
  82. pixeltable/tests/test_client.py +0 -21
  83. pixeltable/tests/test_component_view.py +0 -370
  84. pixeltable/tests/test_dataframe.py +0 -439
  85. pixeltable/tests/test_dirs.py +0 -107
  86. pixeltable/tests/test_document.py +0 -120
  87. pixeltable/tests/test_exprs.py +0 -805
  88. pixeltable/tests/test_function.py +0 -324
  89. pixeltable/tests/test_migration.py +0 -43
  90. pixeltable/tests/test_nos.py +0 -54
  91. pixeltable/tests/test_snapshot.py +0 -208
  92. pixeltable/tests/test_table.py +0 -1267
  93. pixeltable/tests/test_transactional_directory.py +0 -42
  94. pixeltable/tests/test_types.py +0 -22
  95. pixeltable/tests/test_video.py +0 -159
  96. pixeltable/tests/test_view.py +0 -530
  97. pixeltable/tests/utils.py +0 -408
  98. pixeltable-0.2.4.dist-info/RECORD +0 -132
  99. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/LICENSE +0 -0
@@ -0,0 +1,92 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import Iterable, Iterator
4
+ from urllib.request import urlretrieve
5
+
6
+ import PIL.Image
7
+ import numpy as np
8
+ import torch
9
+ from yolox.data import ValTransform
10
+ from yolox.exp import get_exp, Exp
11
+ from yolox.models import YOLOX
12
+ from yolox.utils import postprocess
13
+
14
+ import pixeltable as pxt
15
+ from pixeltable import env
16
+ from pixeltable.func import Batch
17
+ from pixeltable.functions.util import resolve_torch_device
18
+
19
+ _logger = logging.getLogger('pixeltable')
20
+
21
+
22
+ @pxt.udf(batch_size=4)
23
+ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5) -> Batch[dict]:
24
+ """
25
+ Runs the specified YOLOX object detection model on an image.
26
+
27
+ YOLOX support is part of the `pixeltable.ext` package: long-term support is not guaranteed, and it is not
28
+ intended for use in production applications.
29
+
30
+ Parameters:
31
+ - `model_id` - one of: `yolox_nano`, `yolox_tiny`, `yolox_s`, `yolox_m`, `yolox_l`, `yolox_x`
32
+ - `threshold` - the threshold for object detection
33
+ """
34
+ model, exp = _lookup_model(model_id, 'cpu')
35
+ image_tensors = list(_images_to_tensors(images, exp))
36
+ batch_tensor = torch.stack(image_tensors)
37
+
38
+ with torch.no_grad():
39
+ output_tensor = model(batch_tensor)
40
+
41
+ outputs = postprocess(
42
+ output_tensor, 80, threshold, exp.nmsthre, class_agnostic=False
43
+ )
44
+
45
+ results: list[dict] = []
46
+ for image in images:
47
+ ratio = min(exp.test_size[0] / image.height, exp.test_size[1] / image.width)
48
+ if outputs[0] is None:
49
+ results.append({'bboxes': [], 'scores': [], 'labels': []})
50
+ else:
51
+ results.append({
52
+ 'bboxes': [(output[:4] / ratio).tolist() for output in outputs[0]],
53
+ 'scores': [output[4].item() * output[5].item() for output in outputs[0]],
54
+ 'labels': [int(output[6]) for output in outputs[0]]
55
+ })
56
+ return results
57
+
58
+
59
+ def _images_to_tensors(images: Iterable[PIL.Image.Image], exp: Exp) -> Iterator[torch.Tensor]:
60
+ for image in images:
61
+ image_transform, _ = _val_transform(np.array(image), None, exp.test_size)
62
+ yield torch.from_numpy(image_transform)
63
+
64
+
65
+ def _lookup_model(model_id: str, device: str) -> (YOLOX, Exp):
66
+ key = (model_id, device)
67
+ if key in _model_cache:
68
+ return _model_cache[key]
69
+
70
+ weights_url = f'https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/{model_id}.pth'
71
+ weights_file = Path(f'{env.Env.get().tmp_dir}/{model_id}.pth')
72
+ if not weights_file.exists():
73
+ _logger.info(f'Downloading weights for YOLOX model {model_id}: from {weights_url} -> {weights_file}')
74
+ urlretrieve(weights_url, weights_file)
75
+
76
+ exp = get_exp(exp_name=model_id)
77
+ model = exp.get_model().to(device)
78
+
79
+ model.eval()
80
+ model.head.training = False
81
+ model.training = False
82
+
83
+ # Load in the weights from training
84
+ weights = torch.load(weights_file, map_location=torch.device(device))
85
+ model.load_state_dict(weights['model'])
86
+
87
+ _model_cache[key] = (model, exp)
88
+ return model, exp
89
+
90
+
91
+ _model_cache = {}
92
+ _val_transform = ValTransform(legacy=False)
@@ -1,9 +1,7 @@
1
1
  from .aggregate_function import Aggregator, AggregateFunction, uda
2
- from .batched_function import BatchedFunction, ExplicitBatchedFunction
3
2
  from .callable_function import CallableFunction
4
3
  from .expr_template_function import ExprTemplateFunction
5
4
  from .function import Function
6
5
  from .function_registry import FunctionRegistry
7
- from .nos_function import NOSFunction
8
6
  from .signature import Signature, Parameter, Batch
9
7
  from .udf import udf, make_function, expr_udf
@@ -3,13 +3,14 @@ from __future__ import annotations
3
3
  import abc
4
4
  import importlib
5
5
  import inspect
6
- from typing import Optional, Any, Type, List, Dict
6
+ from typing import Optional, Any, Type, List, Dict, Callable
7
7
  import itertools
8
8
 
9
9
  import pixeltable.exceptions as excs
10
10
  import pixeltable.type_system as ts
11
11
  from .function import Function
12
12
  from .signature import Signature, Parameter
13
+ from .globals import validate_symbol_path
13
14
 
14
15
 
15
16
  class Aggregator(abc.ABC):
@@ -71,6 +72,9 @@ class AggregateFunction(Function):
71
72
  if param.lower() in self.RESERVED_PARAMS:
72
73
  raise excs.Error(f'{self.name}(): parameter name {param} is reserved')
73
74
 
75
+ def exec(self, *args: Any, **kwargs: Any) -> Any:
76
+ raise NotImplementedError
77
+
74
78
  def help_str(self) -> str:
75
79
  res = super().help_str()
76
80
  res += '\n\n' + inspect.getdoc(self.agg_cls.update)
@@ -136,8 +140,7 @@ def uda(
136
140
  update_types: List[ts.ColumnType],
137
141
  init_types: Optional[List[ts.ColumnType]] = None,
138
142
  requires_order_by: bool = False, allows_std_agg: bool = True, allows_window: bool = False,
139
- name: Optional[str] = None
140
- ) -> Type[Aggregator]:
143
+ ) -> Callable:
141
144
  """Decorator for user-defined aggregate functions.
142
145
 
143
146
  The decorated class must inherit from Aggregator and implement the following methods:
@@ -155,14 +158,11 @@ def uda(
155
158
  - requires_order_by: if True, the first parameter to the function is the order-by expression
156
159
  - allows_std_agg: if True, the function can be used as a standard aggregate function w/o a window
157
160
  - allows_window: if True, the function can be used with a window
158
- - name: name of the AggregateFunction instance; if None, the class name is used
159
161
  """
160
- if name is not None and not name.isidentifier():
161
- raise excs.Error(f'Invalid name: {name}')
162
162
  if init_types is None:
163
163
  init_types = []
164
164
 
165
- def decorator(cls: Type[Aggregator]) -> Type[Aggregator]:
165
+ def decorator(cls: Type[Aggregator]) -> Type[Function]:
166
166
  # validate type parameters
167
167
  num_init_params = len(inspect.signature(cls.__init__).parameters) - 1
168
168
  if num_init_params > 0:
@@ -178,17 +178,20 @@ def uda(
178
178
  assert value_type is not None
179
179
 
180
180
  # the AggregateFunction instance resides in the same module as cls
181
- module_path = cls.__module__
182
- nonlocal name
183
- name = name or cls.__name__
184
- instance_path = f'{module_path}.{name}'
181
+ class_path = f'{cls.__module__}.{cls.__qualname__}'
182
+ # nonlocal name
183
+ # name = name or cls.__name__
184
+ # instance_path_elements = class_path.split('.')[:-1] + [name]
185
+ # instance_path = '.'.join(instance_path_elements)
185
186
 
186
187
  # create the corresponding AggregateFunction instance
187
188
  instance = AggregateFunction(
188
- cls, instance_path, init_types, update_types, value_type, requires_order_by, allows_std_agg, allows_window)
189
- module = importlib.import_module(module_path)
190
- setattr(module, name, instance)
189
+ cls, class_path, init_types, update_types, value_type, requires_order_by, allows_std_agg, allows_window)
190
+ # do the path validation at the very end, in order to be able to write tests for the other failure cases
191
+ validate_symbol_path(class_path)
192
+ #module = importlib.import_module(cls.__module__)
193
+ #setattr(module, name, instance)
191
194
 
192
- return cls
195
+ return instance
193
196
 
194
197
  return decorator
@@ -1,16 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
- import sys
5
- from typing import Optional, Dict, Callable, List, Tuple
4
+ from typing import Optional, Callable, Tuple, Any
6
5
  from uuid import UUID
6
+
7
7
  import cloudpickle
8
8
 
9
- import pixeltable.type_system as ts
10
- import pixeltable.exceptions as excs
11
9
  from .function import Function
12
- from .function_registry import FunctionRegistry
13
- from .globals import get_caller_module_path
14
10
  from .signature import Signature
15
11
 
16
12
 
@@ -24,13 +20,48 @@ class CallableFunction(Function):
24
20
 
25
21
  def __init__(
26
22
  self, signature: Signature, py_fn: Callable, self_path: Optional[str] = None,
27
- self_name: Optional[str] = None):
23
+ self_name: Optional[str] = None, batch_size: Optional[int] = None):
28
24
  assert py_fn is not None
29
25
  self.py_fn = py_fn
30
26
  self.self_name = self_name
27
+ self.batch_size = batch_size
31
28
  py_signature = inspect.signature(self.py_fn)
32
29
  super().__init__(signature, py_signature, self_path=self_path)
33
30
 
31
+ @property
32
+ def is_batched(self) -> bool:
33
+ return self.batch_size is not None
34
+
35
+ def exec(self, *args: Any, **kwargs: Any) -> Any:
36
+ if self.is_batched:
37
+ # Pack the batched parameters into singleton lists
38
+ constant_param_names = [p.name for p in self.signature.constant_parameters]
39
+ batched_args = [[arg] for arg in args]
40
+ constant_kwargs = {k: v for k, v in kwargs.items() if k in constant_param_names}
41
+ batched_kwargs = {k: [v] for k, v in kwargs.items() if k not in constant_param_names}
42
+ result = self.py_fn(*batched_args, **constant_kwargs, **batched_kwargs)
43
+ assert len(result) == 1
44
+ return result[0]
45
+ else:
46
+ return self.py_fn(*args, **kwargs)
47
+
48
+ def exec_batch(self, *args: Any, **kwargs: Any) -> list:
49
+ """Execute the function with the given arguments and return the result.
50
+ The arguments are expected to be batched: if the corresponding parameter has type T,
51
+ then the argument should have type T if it's a constant parameter, or list[T] if it's
52
+ a batched parameter.
53
+ """
54
+ assert self.is_batched
55
+ # Unpack the constant parameters
56
+ constant_param_names = [p.name for p in self.signature.constant_parameters]
57
+ constant_kwargs = {k: v[0] for k, v in kwargs.items() if k in constant_param_names}
58
+ batched_kwargs = {k: v for k, v in kwargs.items() if k not in constant_param_names}
59
+ return self.py_fn(*args, **constant_kwargs, **batched_kwargs)
60
+
61
+ # TODO(aaron-siegel): Implement conditional batch sizing
62
+ def get_batch_size(self, *args: Any, **kwargs: Any) -> Optional[int]:
63
+ return self.batch_size
64
+
34
65
  @property
35
66
  def display_name(self) -> str:
36
67
  return self.self_name
@@ -44,7 +75,7 @@ class CallableFunction(Function):
44
75
  res += '\n\n' + inspect.getdoc(self.py_fn)
45
76
  return res
46
77
 
47
- def _as_dict(self) -> Dict:
78
+ def _as_dict(self) -> dict:
48
79
  if self.self_path is None:
49
80
  # this is not a module function
50
81
  from .function_registry import FunctionRegistry
@@ -53,17 +84,30 @@ class CallableFunction(Function):
53
84
  return super()._as_dict()
54
85
 
55
86
  @classmethod
56
- def _from_dict(cls, d: Dict) -> Function:
87
+ def _from_dict(cls, d: dict) -> Function:
57
88
  if 'id' in d:
58
89
  from .function_registry import FunctionRegistry
59
90
  return FunctionRegistry.get().get_stored_function(UUID(hex=d['id']))
60
91
  return super()._from_dict(d)
61
92
 
62
- def to_store(self) -> Tuple[Dict, bytes]:
63
- return (self.signature.as_dict(), cloudpickle.dumps(self.py_fn))
93
+ def to_store(self) -> tuple[dict, bytes]:
94
+ md = self.signature.as_dict()
95
+ if self.batch_size is not None:
96
+ md['batch_size'] = self.batch_size
97
+ return md, cloudpickle.dumps(self.py_fn)
64
98
 
65
99
  @classmethod
66
- def from_store(cls, name: Optional[str], md: Dict, binary_obj: bytes) -> Function:
100
+ def from_store(cls, name: Optional[str], md: dict, binary_obj: bytes) -> Function:
67
101
  py_fn = cloudpickle.loads(binary_obj)
68
102
  assert isinstance(py_fn, Callable)
69
- return CallableFunction(Signature.from_dict(md), py_fn, self_name=name)
103
+ return CallableFunction(Signature.from_dict(md), py_fn, self_name=name, batch_size=md.get('batch_size'))
104
+
105
+ def validate_call(self, bound_args: dict[str, Any]) -> None:
106
+ import pixeltable.exprs as exprs
107
+ if self.is_batched:
108
+ for param in self.signature.constant_parameters:
109
+ if param.name in bound_args and isinstance(bound_args[param.name], exprs.Expr):
110
+ raise ValueError(
111
+ f'{self.display_name}(): '
112
+ f'parameter {param.name} must be a constant value, not a Pixeltable expression'
113
+ )
@@ -1,9 +1,8 @@
1
1
  import inspect
2
- from typing import Dict, Optional, Callable, List
2
+ from typing import Dict, Optional, Any
3
3
 
4
4
  import pixeltable
5
5
  import pixeltable.exceptions as excs
6
- import pixeltable.type_system as ts
7
6
  from .function import Function
8
7
  from .signature import Signature, Parameter
9
8
 
@@ -50,13 +49,31 @@ class ExprTemplateFunction(Function):
50
49
  bound_args.update(
51
50
  {param_name: default for param_name, default in self.defaults.items() if param_name not in bound_args})
52
51
  result = self.expr.copy()
52
+ import pixeltable.exprs as exprs
53
53
  for param_name, arg in bound_args.items():
54
54
  param_expr = self.param_exprs_by_name[param_name]
55
- result = result.substitute(param_expr, arg)
55
+ if not isinstance(arg, exprs.Expr):
56
+ # TODO: use the available param_expr.col_type
57
+ arg_expr = exprs.Expr.from_object(arg)
58
+ if arg_expr is None:
59
+ raise excs.Error(f'{self.self_name}(): cannot convert argument {arg} to a Pixeltable expression')
60
+ else:
61
+ arg_expr = arg
62
+ result = result.substitute(param_expr, arg_expr)
56
63
  import pixeltable.exprs as exprs
57
64
  assert not result.contains(exprs.Variable)
58
65
  return result
59
66
 
67
+ def exec(self, *args: Any, **kwargs: Any) -> Any:
68
+ expr = self.instantiate(*args, **kwargs)
69
+ import pixeltable.exprs as exprs
70
+ row_builder = exprs.RowBuilder(output_exprs=[expr], columns=[], input_exprs=[])
71
+ import pixeltable.exec as exec
72
+ row_batch = exec.DataRowBatch(tbl=None, row_builder=row_builder, len=1)
73
+ row = row_batch[0]
74
+ row_builder.eval(row, ctx=row_builder.default_eval_ctx)
75
+ return row[row_builder.get_output_exprs()[0].slot_idx]
76
+
60
77
  @property
61
78
  def display_name(self) -> str:
62
79
  return self.self_name
@@ -3,9 +3,10 @@ from __future__ import annotations
3
3
  import abc
4
4
  import importlib
5
5
  import inspect
6
- import pixeltable
7
- from typing import Optional, Dict, Any, Tuple
6
+ from typing import Optional, Dict, Any, Tuple, Callable
8
7
 
8
+ import pixeltable
9
+ import pixeltable.type_system as ts
9
10
  from .globals import resolve_symbol
10
11
  from .signature import Signature
11
12
 
@@ -18,10 +19,13 @@ class Function(abc.ABC):
18
19
  via the member self_path.
19
20
  """
20
21
 
21
- def __init__(self, signature: Signature, py_signature: inspect.Signature, self_path: Optional[str] = None):
22
+ def __init__(
23
+ self, signature: Signature, py_signature: inspect.Signature, self_path: Optional[str] = None
24
+ ):
22
25
  self.signature = signature
23
26
  self.py_signature = py_signature
24
27
  self.self_path = self_path # fully-qualified path to self
28
+ self._conditional_return_type: Optional[Callable[..., ts.ColumnType]] = None
25
29
 
26
30
  @property
27
31
  def name(self) -> str:
@@ -40,7 +44,7 @@ class Function(abc.ABC):
40
44
  def help_str(self) -> str:
41
45
  return self.display_name + str(self.signature)
42
46
 
43
- def __call__(self, *args: object, **kwargs: object) -> 'pixeltable.exprs.Expr':
47
+ def __call__(self, *args: Any, **kwargs: Any) -> 'pixeltable.exprs.Expr':
44
48
  from pixeltable import exprs
45
49
  bound_args = self.py_signature.bind(*args, **kwargs)
46
50
  self.validate_call(bound_args.arguments)
@@ -50,6 +54,33 @@ class Function(abc.ABC):
50
54
  """Override this to do custom validation of the arguments"""
51
55
  pass
52
56
 
57
+ def call_return_type(self, kwargs: dict[str, Any]) -> ts.ColumnType:
58
+ """Return the type of the value returned by calling this function with the given arguments"""
59
+ if self._conditional_return_type is None:
60
+ return self.signature.return_type
61
+ bound_args = self.py_signature.bind(**kwargs)
62
+ kw_args: dict[str, Any] = {}
63
+ sig = inspect.signature(self._conditional_return_type)
64
+ for param in sig.parameters.values():
65
+ if param.name in bound_args.arguments:
66
+ kw_args[param.name] = bound_args.arguments[param.name]
67
+ return self._conditional_return_type(**kw_args)
68
+
69
+ def conditional_return_type(self, fn: Callable[..., ts.ColumnType]) -> Callable[..., ts.ColumnType]:
70
+ """Instance decorator for specifying a conditional return type for this function"""
71
+ # verify that `fn` only has parameters that are also present in this function's signature
72
+ sig = inspect.signature(fn)
73
+ for param in sig.parameters.values():
74
+ if param.name not in self.signature.parameters:
75
+ raise ValueError(f'`conditional_return_type` has parameter `{param.name}` that is not in the signature')
76
+ self._conditional_return_type = fn
77
+ return fn
78
+
79
+ @abc.abstractmethod
80
+ def exec(self, *args: Any, **kwargs: Any) -> Any:
81
+ """Execute the function with the given arguments and return the result."""
82
+ pass
83
+
53
84
  def __eq__(self, other: object) -> bool:
54
85
  if not isinstance(other, self.__class__):
55
86
  return False
@@ -1,29 +1,39 @@
1
- from typing import Optional
2
- from types import ModuleType
3
1
  import importlib
4
2
  import inspect
3
+ from types import ModuleType
4
+ from typing import Optional
5
5
 
6
+ import pixeltable.exceptions as excs
6
7
 
7
- def resolve_symbol(symbol_path: str) -> object:
8
+
9
+ def resolve_symbol(symbol_path: str) -> Optional[object]:
8
10
  path_elems = symbol_path.split('.')
9
11
  module: Optional[ModuleType] = None
10
- if path_elems[0:2] == ['pixeltable', 'functions'] and len(path_elems) > 2:
11
- # if this is a pixeltable.functions submodule, it cannot be resolved via pixeltable.functions;
12
- # try to import the submodule directly
13
- submodule_path = '.'.join(path_elems[0:3])
12
+ i = len(path_elems) - 1
13
+ while i > 0 and module is None:
14
14
  try:
15
- module = importlib.import_module(submodule_path)
16
- path_elems = path_elems[3:]
15
+ module = importlib.import_module('.'.join(path_elems[:i]))
17
16
  except ModuleNotFoundError:
18
- pass
19
- if module is None:
20
- module = importlib.import_module(path_elems[0])
21
- path_elems = path_elems[1:]
17
+ i -= 1
18
+ if i == 0:
19
+ return None # Not resolvable
22
20
  obj = module
23
- for el in path_elems:
21
+ for el in path_elems[i:]:
24
22
  obj = getattr(obj, el)
25
23
  return obj
26
24
 
25
+
26
+ def validate_symbol_path(fn_path: str) -> None:
27
+ path_elems = fn_path.split('.')
28
+ fn_name = path_elems[-1]
29
+ if any(el == '<locals>' for el in path_elems):
30
+ raise excs.Error(
31
+ f'{fn_name}(): nested functions are not supported. Move the function to the module level or into a class.')
32
+ if any(not el.isidentifier() for el in path_elems):
33
+ raise excs.Error(
34
+ f'{fn_name}(): cannot resolve symbol path {fn_path}. Move the function to the module level or into a class.')
35
+
36
+
27
37
  def get_caller_module_path() -> str:
28
38
  """Return the module path of our caller's caller"""
29
39
  stack = inspect.stack()
@@ -29,21 +29,12 @@ class Signature:
29
29
  """
30
30
  Represents the signature of a Pixeltable function.
31
31
 
32
- Regarding return type:
33
- - most functions will have a fixed return type, which is specified directly
34
- - some functions will have a return type that depends on the argument values;
35
- ex.: PIL.Image.Image.resize() returns an image with dimensions specified as a parameter
36
- - in the latter case, the 'return_type' field is a function that takes the bound arguments and returns the
37
- return type; if no bound arguments are specified, a generic return type is returned (eg, ImageType() without a
38
- size)
39
32
  - self.is_batched: return type is a Batch[...] type
40
33
  """
41
34
  SPECIAL_PARAM_NAMES = ['group_by', 'order_by']
42
35
 
43
- def __init__(
44
- self,
45
- return_type: Union[ts.ColumnType, Callable[[Dict[str, Any]], ts.ColumnType]],
46
- parameters: List[Parameter], is_batched: bool = False):
36
+ def __init__(self, return_type: ts.ColumnType, parameters: List[Parameter], is_batched: bool = False):
37
+ assert isinstance(return_type, ts.ColumnType)
47
38
  self.return_type = return_type
48
39
  self.is_batched = is_batched
49
40
  # we rely on the ordering guarantee of dicts in Python >=3.7
@@ -52,10 +43,9 @@ class Signature:
52
43
  self.constant_parameters = [p for p in parameters if not p.is_batched]
53
44
  self.batched_parameters = [p for p in parameters if p.is_batched]
54
45
 
55
- def get_return_type(self, bound_args: Optional[Dict[str, Any]] = None) -> ts.ColumnType:
56
- if isinstance(self.return_type, ts.ColumnType):
57
- return self.return_type
58
- return self.return_type(bound_args)
46
+ def get_return_type(self) -> ts.ColumnType:
47
+ assert isinstance(self.return_type, ts.ColumnType)
48
+ return self.return_type
59
49
 
60
50
  def as_dict(self) -> Dict[str, Any]:
61
51
  result = {
@@ -114,20 +104,12 @@ class Signature:
114
104
  return (col_type, is_batched)
115
105
 
116
106
  @classmethod
117
- def create(
118
- cls, c: Callable,
119
- param_types: Optional[List[ts.ColumnType]] = None,
120
- return_type: Optional[Union[ts.ColumnType, Callable]] = None
121
- ) -> Signature:
122
- """Create a signature for the given Callable.
123
- Infer the parameter and return types, if none are specified.
124
- Raises an exception if the types cannot be inferred.
125
- """
107
+ def create_parameters(
108
+ cls, c: Callable, param_types: Optional[List[ts.ColumnType]] = None) -> List[Parameter]:
126
109
  sig = inspect.signature(c)
127
110
  py_parameters = list(sig.parameters.values())
128
-
129
- # check non-var parameters for name collisions and default value compatibility
130
111
  parameters: List[Parameter] = []
112
+
131
113
  for idx, param in enumerate(py_parameters):
132
114
  if param.name in cls.SPECIAL_PARAM_NAMES:
133
115
  raise excs.Error(f"'{param.name}' is a reserved parameter name")
@@ -135,6 +117,7 @@ class Signature:
135
117
  parameters.append(Parameter(param.name, None, param.kind, False))
136
118
  continue
137
119
 
120
+ # check non-var parameters for name collisions and default value compatibility
138
121
  if param_types is not None:
139
122
  if idx >= len(param_types):
140
123
  raise excs.Error(f'Missing type for parameter {param.name}')
@@ -155,7 +138,20 @@ class Signature:
155
138
 
156
139
  parameters.append(Parameter(param.name, param_type, param.kind, is_batched))
157
140
 
158
- return_is_batched = False
141
+ return parameters
142
+
143
+ @classmethod
144
+ def create(
145
+ cls, c: Callable,
146
+ param_types: Optional[List[ts.ColumnType]] = None,
147
+ return_type: Optional[Union[ts.ColumnType, Callable]] = None
148
+ ) -> Signature:
149
+ """Create a signature for the given Callable.
150
+ Infer the parameter and return types, if none are specified.
151
+ Raises an exception if the types cannot be inferred.
152
+ """
153
+ parameters = cls.create_parameters(c, param_types)
154
+ sig = inspect.signature(c)
159
155
  if return_type is None:
160
156
  return_type, return_is_batched = cls._infer_type(sig.return_annotation)
161
157
  if return_type is None:
pixeltable/func/udf.py CHANGED
@@ -6,11 +6,11 @@ from typing import List, Callable, Optional, overload, Any
6
6
  import pixeltable as pxt
7
7
  import pixeltable.exceptions as excs
8
8
  import pixeltable.type_system as ts
9
- from .batched_function import ExplicitBatchedFunction
10
9
  from .callable_function import CallableFunction
11
10
  from .expr_template_function import ExprTemplateFunction
12
11
  from .function import Function
13
12
  from .function_registry import FunctionRegistry
13
+ from .globals import validate_symbol_path
14
14
  from .signature import Signature
15
15
 
16
16
 
@@ -61,8 +61,8 @@ def udf(*args, **kwargs):
61
61
 
62
62
  def decorator(decorated_fn: Callable):
63
63
  return make_function(
64
- decorated_fn, return_type, param_types, batch_size, substitute_fn=substitute_fn,
65
- force_stored=force_stored)
64
+ decorated_fn, return_type, param_types, batch_size,
65
+ substitute_fn=substitute_fn, force_stored=force_stored)
66
66
 
67
67
  return decorator
68
68
 
@@ -77,8 +77,8 @@ def make_function(
77
77
  force_stored: bool = False
78
78
  ) -> Function:
79
79
  """
80
- Constructs a `CallableFunction` or `BatchedFunction`, depending on the
81
- supplied parameters. If `substitute_fn` is specified, then `decorated_fn`
80
+ Constructs a `CallableFunction` from the specified parameters.
81
+ If `substitute_fn` is specified, then `decorated_fn`
82
82
  will be used only for its signature, with execution delegated to
83
83
  `substitute_fn`.
84
84
  """
@@ -116,14 +116,13 @@ def make_function(
116
116
  raise excs.Error(f'{errmsg_name}(): @udf decorator with a `substitute_fn` can only be used in a module')
117
117
  py_fn = substitute_fn
118
118
 
119
- if batch_size is None:
120
- result = CallableFunction(signature=sig, py_fn=py_fn, self_path=function_path, self_name=function_name)
121
- else:
122
- result = ExplicitBatchedFunction(
123
- signature=sig, batch_size=batch_size, invoker_fn=py_fn, self_path=function_path)
119
+ result = CallableFunction(
120
+ signature=sig, py_fn=py_fn, self_path=function_path, self_name=function_name, batch_size=batch_size)
124
121
 
125
122
  # If this function is part of a module, register it
126
123
  if function_path is not None:
124
+ # do the validation at the very end, so it's easier to write tests for other failure scenarios
125
+ validate_symbol_path(function_path)
127
126
  FunctionRegistry.get().register_function(function_path, result)
128
127
 
129
128
  return result
@@ -142,17 +141,19 @@ def expr_udf(*args: Any, **kwargs: Any) -> Any:
142
141
  else:
143
142
  function_path = None
144
143
 
145
- sig = Signature.create(py_fn, param_types=param_types, return_type=None)
146
144
  # TODO: verify that the inferred return type matches that of the template
147
145
  # TODO: verify that the signature doesn't contain batched parameters
148
146
 
149
147
  # construct Parameters from the function signature
148
+ params = Signature.create_parameters(py_fn, param_types=param_types)
150
149
  import pixeltable.exprs as exprs
151
- var_exprs = [exprs.Variable(param.name, param.col_type) for param in sig.parameters.values()]
150
+ var_exprs = [exprs.Variable(param.name, param.col_type) for param in params]
152
151
  # call the function with the parameter expressions to construct an Expr with parameters
153
152
  template = py_fn(*var_exprs)
154
153
  assert isinstance(template, exprs.Expr)
155
154
  py_sig = inspect.signature(py_fn)
155
+ if function_path is not None:
156
+ validate_symbol_path(function_path)
156
157
  return ExprTemplateFunction(template, py_signature=py_sig, self_path=function_path, name=py_fn.__name__)
157
158
 
158
159
  if len(args) == 1: