pixeltable 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. (The source page linked to further details, which are not reproduced here.)

Files changed (88):
  1. pixeltable/__init__.py +7 -19
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +7 -7
  4. pixeltable/catalog/globals.py +3 -0
  5. pixeltable/catalog/insertable_table.py +9 -7
  6. pixeltable/catalog/table.py +220 -143
  7. pixeltable/catalog/table_version.py +36 -18
  8. pixeltable/catalog/table_version_path.py +0 -8
  9. pixeltable/catalog/view.py +3 -3
  10. pixeltable/dataframe.py +9 -24
  11. pixeltable/env.py +107 -36
  12. pixeltable/exceptions.py +7 -4
  13. pixeltable/exec/__init__.py +1 -1
  14. pixeltable/exec/aggregation_node.py +22 -15
  15. pixeltable/exec/component_iteration_node.py +62 -41
  16. pixeltable/exec/data_row_batch.py +7 -7
  17. pixeltable/exec/exec_node.py +35 -7
  18. pixeltable/exec/expr_eval_node.py +2 -1
  19. pixeltable/exec/in_memory_data_node.py +9 -9
  20. pixeltable/exec/sql_node.py +265 -136
  21. pixeltable/exprs/__init__.py +1 -0
  22. pixeltable/exprs/data_row.py +30 -19
  23. pixeltable/exprs/expr.py +15 -14
  24. pixeltable/exprs/expr_dict.py +55 -0
  25. pixeltable/exprs/expr_set.py +21 -15
  26. pixeltable/exprs/function_call.py +21 -8
  27. pixeltable/exprs/json_path.py +3 -6
  28. pixeltable/exprs/rowid_ref.py +2 -2
  29. pixeltable/exprs/sql_element_cache.py +5 -1
  30. pixeltable/ext/functions/whisperx.py +7 -2
  31. pixeltable/func/callable_function.py +2 -2
  32. pixeltable/func/function_registry.py +6 -7
  33. pixeltable/func/query_template_function.py +11 -12
  34. pixeltable/func/signature.py +17 -15
  35. pixeltable/func/udf.py +0 -4
  36. pixeltable/functions/__init__.py +1 -1
  37. pixeltable/functions/audio.py +4 -6
  38. pixeltable/functions/globals.py +86 -42
  39. pixeltable/functions/huggingface.py +12 -14
  40. pixeltable/functions/image.py +59 -45
  41. pixeltable/functions/json.py +0 -1
  42. pixeltable/functions/mistralai.py +2 -2
  43. pixeltable/functions/openai.py +22 -25
  44. pixeltable/functions/string.py +50 -50
  45. pixeltable/functions/timestamp.py +20 -20
  46. pixeltable/functions/together.py +26 -12
  47. pixeltable/functions/video.py +11 -20
  48. pixeltable/functions/whisper.py +2 -20
  49. pixeltable/globals.py +57 -56
  50. pixeltable/index/base.py +2 -2
  51. pixeltable/index/btree.py +7 -7
  52. pixeltable/index/embedding_index.py +8 -10
  53. pixeltable/io/external_store.py +11 -5
  54. pixeltable/io/globals.py +3 -1
  55. pixeltable/io/hf_datasets.py +4 -4
  56. pixeltable/io/label_studio.py +6 -6
  57. pixeltable/io/parquet.py +14 -13
  58. pixeltable/iterators/document.py +10 -8
  59. pixeltable/iterators/video.py +10 -1
  60. pixeltable/metadata/__init__.py +3 -2
  61. pixeltable/metadata/converters/convert_14.py +4 -2
  62. pixeltable/metadata/converters/convert_15.py +1 -1
  63. pixeltable/metadata/converters/convert_19.py +1 -0
  64. pixeltable/metadata/converters/convert_20.py +1 -1
  65. pixeltable/metadata/converters/util.py +9 -8
  66. pixeltable/metadata/schema.py +32 -21
  67. pixeltable/plan.py +136 -154
  68. pixeltable/store.py +51 -36
  69. pixeltable/tool/create_test_db_dump.py +7 -7
  70. pixeltable/tool/doc_plugins/griffe.py +3 -34
  71. pixeltable/tool/mypy_plugin.py +32 -0
  72. pixeltable/type_system.py +243 -60
  73. pixeltable/utils/arrow.py +10 -9
  74. pixeltable/utils/coco.py +4 -4
  75. pixeltable/utils/documents.py +1 -1
  76. pixeltable/utils/filecache.py +131 -84
  77. pixeltable/utils/formatter.py +1 -1
  78. pixeltable/utils/http_server.py +2 -5
  79. pixeltable/utils/media_store.py +6 -6
  80. pixeltable/utils/pytorch.py +10 -11
  81. pixeltable/utils/sql.py +2 -1
  82. {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/METADATA +16 -7
  83. pixeltable-0.2.21.dist-info/RECORD +148 -0
  84. pixeltable/utils/help.py +0 -11
  85. pixeltable-0.2.19.dist-info/RECORD +0 -147
  86. {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/LICENSE +0 -0
  87. {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/WHEEL +0 -0
  88. {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/entry_points.txt +0 -0
pixeltable/exprs/expr_dict.py ADDED
@@ -0,0 +1,55 @@
1
+ from typing import Generic, TypeVar, Optional, Iterator, Iterable
2
+
3
+ T = TypeVar('T')
4
+
5
+ from .expr import Expr
6
+
7
+ class ExprDict(Generic[T]):
8
+ """
9
+ A dictionary that maps Expr instances to values of type T.
10
+
11
+ We cannot use dict[Expr, T] because Expr.__eq__() serves a different purpose than the default __eq__.
12
+ """
13
+
14
+ _data: dict[int, tuple[Expr, T]]
15
+
16
+ def __init__(self, iterable: Optional[Iterable[tuple[Expr, T]]] = None):
17
+ self._data = {}
18
+
19
+ if iterable is not None:
20
+ for key, value in iterable:
21
+ self[key] = value
22
+
23
+ def __setitem__(self, key: Expr, value: T) -> None:
24
+ self._data[key.id] = (key, value)
25
+
26
+ def __getitem__(self, key: Expr) -> T:
27
+ return self._data[key.id][1]
28
+
29
+ def __delitem__(self, key: Expr) -> None:
30
+ del self._data[key.id]
31
+
32
+ def __len__(self) -> int:
33
+ return len(self._data)
34
+
35
+ def __iter__(self) -> Iterator[Expr]:
36
+ return (expr for expr, _ in self._data.values())
37
+
38
+ def __contains__(self, key: Expr) -> bool:
39
+ return key.id in self._data
40
+
41
+ def get(self, key: Expr, default: Optional[T] = None) -> Optional[T]:
42
+ item = self._data.get(key.id)
43
+ return item[1] if item is not None else default
44
+
45
+ def clear(self) -> None:
46
+ self._data.clear()
47
+
48
+ def keys(self) -> Iterator[Expr]:
49
+ return self.__iter__()
50
+
51
+ def values(self) -> Iterator[T]:
52
+ return (value for _, value in self._data.values())
53
+
54
+ def items(self) -> Iterator[tuple[Expr, T]]:
55
+ return ((expr, value) for expr, value in self._data.values())
pixeltable/exprs/expr_set.py CHANGED
@@ -1,25 +1,26 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Optional, Iterable, Iterator
3
+ from typing import Optional, Iterable, Iterator, TypeVar, Generic
4
4
 
5
5
  from .expr import Expr
6
6
 
7
+ T = TypeVar('T', bound='Expr')
7
8
 
8
- class ExprSet:
9
+ class ExprSet(Generic[T]):
9
10
  """
10
11
  A set that also supports indexed lookup (by slot_idx and Expr.id). Exprs are uniquely identified by Expr.id.
11
12
  """
12
- exprs: dict[int, Expr] # key: Expr.id
13
- exprs_by_idx: dict[int, Expr] # key: slot_idx
13
+ exprs: dict[int, T] # key: Expr.id
14
+ exprs_by_idx: dict[int, T] # key: slot_idx
14
15
 
15
- def __init__(self, elements: Optional[Iterable[Expr]] = None):
16
+ def __init__(self, elements: Optional[Iterable[T]] = None):
16
17
  self.exprs = {}
17
18
  self.exprs_by_idx = {}
18
19
  if elements is not None:
19
20
  for e in elements:
20
21
  self.add(e)
21
22
 
22
- def add(self, expr: Expr) -> None:
23
+ def add(self, expr: T) -> None:
23
24
  if expr.id in self.exprs:
24
25
  return
25
26
  self.exprs[expr.id] = expr
@@ -27,24 +28,22 @@ class ExprSet:
27
28
  return
28
29
  self.exprs_by_idx[expr.slot_idx] = expr
29
30
 
30
- def update(self, *others: Iterable[Expr]) -> None:
31
+ def update(self, *others: Iterable[T]) -> None:
31
32
  for other in others:
32
33
  for e in other:
33
34
  self.add(e)
34
35
 
35
- def __contains__(self, item: Expr) -> bool:
36
+ def __contains__(self, item: T) -> bool:
36
37
  return item.id in self.exprs
37
38
 
38
39
  def __len__(self) -> int:
39
40
  return len(self.exprs)
40
41
 
41
- def __iter__(self) -> Iterator[Expr]:
42
+ def __iter__(self) -> Iterator[T]:
42
43
  return iter(self.exprs.values())
43
44
 
44
- def __getitem__(self, index: object) -> Optional[Expr]:
45
+ def __getitem__(self, index: object) -> Optional[T]:
45
46
  """Indexed lookup by slot_idx or Expr.id."""
46
- if not isinstance(index, int) and not isinstance(index, Expr):
47
- pass
48
47
  assert isinstance(index, int) or isinstance(index, Expr)
49
48
  if isinstance(index, int):
50
49
  # return expr with matching slot_idx
@@ -52,11 +51,18 @@ class ExprSet:
52
51
  else:
53
52
  return self.exprs.get(index.id)
54
53
 
55
- def issuperset(self, other: ExprSet) -> bool:
54
+ def issuperset(self, other: ExprSet[T]) -> bool:
56
55
  return self.exprs.keys() >= other.exprs.keys()
57
56
 
58
- def __ge__(self, other: ExprSet) -> bool:
57
+ def __ge__(self, other: ExprSet[T]) -> bool:
59
58
  return self.issuperset(other)
60
59
 
61
- def __le__(self, other: ExprSet) -> bool:
60
+ def __le__(self, other: ExprSet[T]) -> bool:
62
61
  return other.issuperset(self)
62
+
63
+ def difference(self, *others: Iterable[T]) -> ExprSet[T]:
64
+ id_diff = set(self.exprs.keys()).difference(e.id for other_set in others for e in other_set)
65
+ return ExprSet(self.exprs[id] for id in id_diff)
66
+
67
+ def __sub__(self, other: ExprSet[T]) -> ExprSet[T]:
68
+ return self.difference(other)
pixeltable/exprs/function_call.py CHANGED
@@ -294,6 +294,10 @@ class FunctionCall(Expr):
294
294
  def get_window_sort_exprs(self) -> tuple[list[Expr], list[Expr]]:
295
295
  return self.group_by, self.order_by
296
296
 
297
+ def get_window_ordering(self) -> list[tuple[Expr, bool]]:
298
+ # ordering is implicitly ascending
299
+ return [(e, None) for e in self.group_by] + [(e, True) for e in self.order_by]
300
+
297
301
  @property
298
302
  def is_agg_fn_call(self) -> bool:
299
303
  return isinstance(self.fn, func.AggregateFunction)
@@ -303,6 +307,10 @@ class FunctionCall(Expr):
303
307
  return self.order_by
304
308
 
305
309
  def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ColumnElement]:
310
+ # we currently can't translate aggregate functions with grouping and/or ordering to SQL
311
+ if self.has_group_by() or len(self.order_by) > 0:
312
+ return None
313
+
306
314
  # try to construct args and kwargs to call self.fn._to_sql()
307
315
  kwargs: dict[str, sql.ColumnElement] = {}
308
316
  for param_name, (component_idx, arg) in self.kwargs.items():
@@ -374,6 +382,18 @@ class FunctionCall(Expr):
374
382
  return args, kwargs
375
383
 
376
384
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
385
+ if isinstance(self.fn, func.ExprTemplateFunction):
386
+ # we need to evaluate the template
387
+ # TODO: can we get rid of this extra copy?
388
+ fn_expr = self.components[self.fn_expr_idx]
389
+ data_row[self.slot_idx] = data_row[fn_expr.slot_idx]
390
+ return
391
+ elif self.is_agg_fn_call and not self.is_window_fn_call:
392
+ if self.aggregator is None:
393
+ pass
394
+ data_row[self.slot_idx] = self.aggregator.value()
395
+ return
396
+
377
397
  args, kwargs = self._make_args(data_row)
378
398
  signature = self.fn.signature
379
399
  if signature.parameters is not None:
@@ -389,12 +409,7 @@ class FunctionCall(Expr):
389
409
  data_row[self.slot_idx] = None
390
410
  return
391
411
 
392
- if isinstance(self.fn, func.ExprTemplateFunction):
393
- # we need to evaluate the template
394
- # TODO: can we get rid of this extra copy?
395
- fn_expr = self.components[self.fn_expr_idx]
396
- data_row[self.slot_idx] = data_row[fn_expr.slot_idx]
397
- elif isinstance(self.fn, func.CallableFunction) and not self.fn.is_batched:
412
+ if isinstance(self.fn, func.CallableFunction) and not self.fn.is_batched:
398
413
  # optimization: avoid additional level of indirection we'd get from calling Function.exec()
399
414
  data_row[self.slot_idx] = self.fn.py_fn(*args, **kwargs)
400
415
  elif self.is_window_fn_call:
@@ -410,8 +425,6 @@ class FunctionCall(Expr):
410
425
  self.aggregator = self.fn.agg_cls(**self.agg_init_args)
411
426
  self.aggregator.update(*args)
412
427
  data_row[self.slot_idx] = self.aggregator.value()
413
- elif self.is_agg_fn_call:
414
- data_row[self.slot_idx] = self.aggregator.value()
415
428
  else:
416
429
  data_row[self.slot_idx] = self.fn.exec(*args, **kwargs)
417
430
 
pixeltable/exprs/json_path.py CHANGED
@@ -105,12 +105,9 @@ class JsonPath(Expr):
105
105
  return JsonPath(self._anchor, self.path_elements + [name])
106
106
 
107
107
  def __getitem__(self, index: object) -> 'JsonPath':
108
- if isinstance(index, str):
109
- if index != '*':
110
- raise excs.Error(f'Invalid json list index: {index}')
111
- elif not isinstance(index, (int, slice)):
112
- raise excs.Error(f'Invalid json list index: {index}')
113
- return JsonPath(self._anchor, self.path_elements + [index])
108
+ if isinstance(index, (int, slice, str)):
109
+ return JsonPath(self._anchor, self.path_elements + [index])
110
+ raise excs.Error(f'Invalid json list index: {index}')
114
111
 
115
112
  def __rshift__(self, other: object) -> 'JsonMapper':
116
113
  rhs_expr = Expr.from_object(other)
pixeltable/exprs/rowid_ref.py CHANGED
@@ -68,8 +68,8 @@ class RowidRef(Expr):
68
68
  """
69
69
  if self.tbl_id == tbl.tbl_version.id:
70
70
  return
71
- tbl_version_ids = [tbl_version.id for tbl_version in tbl.get_tbl_versions()]
72
- assert self.tbl_id in tbl_version_ids
71
+ base_ids = [tbl_version.id for tbl_version in tbl.get_tbl_versions()]
72
+ assert self.tbl_id in base_ids # our current TableVersion is a base of the new TableVersion
73
73
  self.tbl = tbl.tbl_version
74
74
  self.tbl_id = self.tbl.id
75
75
 
pixeltable/exprs/sql_element_cache.py CHANGED
@@ -3,6 +3,7 @@ from typing import Iterable, Union, Optional
3
3
  import sqlalchemy as sql
4
4
 
5
5
  from .expr import Expr
6
+ from .expr_dict import ExprDict
6
7
 
7
8
 
8
9
  class SqlElementCache:
@@ -10,8 +11,11 @@ class SqlElementCache:
10
11
 
11
12
  cache: dict[int, Optional[sql.ColumnElement]] # key: Expr.id
12
13
 
13
- def __init__(self):
14
+ def __init__(self, elements: Optional[ExprDict[sql.ColumnElement]] = None):
14
15
  self.cache = {}
16
+ if elements is not None:
17
+ for e, el in elements.items():
18
+ self.cache[e.id] = el
15
19
 
16
20
  def get(self, e: Expr) -> Optional[sql.ColumnElement]:
17
21
  """Returns the sql.ColumnElement for the given Expr, or None if Expr.to_sql() returns None."""
pixeltable/ext/functions/whisperx.py CHANGED
@@ -8,9 +8,14 @@ if TYPE_CHECKING:
8
8
  import pixeltable as pxt
9
9
 
10
10
 
11
- @pxt.udf(param_types=[pxt.AudioType(), pxt.StringType(), pxt.StringType(), pxt.StringType(), pxt.IntType()])
11
+ @pxt.udf
12
12
  def transcribe(
13
- audio: str, *, model: str, compute_type: Optional[str] = None, language: Optional[str] = None, chunk_size: int = 30
13
+ audio: pxt.Audio,
14
+ *,
15
+ model: str,
16
+ compute_type: Optional[str] = None,
17
+ language: Optional[str] = None,
18
+ chunk_size: int = 30
14
19
  ) -> dict:
15
20
  """
16
21
  Transcribe an audio file using WhisperX.
pixeltable/func/callable_function.py CHANGED
@@ -4,7 +4,7 @@ import inspect
4
4
  from typing import Any, Callable, Optional
5
5
  from uuid import UUID
6
6
 
7
- import cloudpickle
7
+ import cloudpickle # type: ignore[import-untyped]
8
8
 
9
9
  from .function import Function
10
10
  from .signature import Signature
@@ -108,7 +108,7 @@ class CallableFunction(Function):
108
108
  @classmethod
109
109
  def from_store(cls, name: Optional[str], md: dict, binary_obj: bytes) -> Function:
110
110
  py_fn = cloudpickle.loads(binary_obj)
111
- assert isinstance(py_fn, Callable)
111
+ assert callable(py_fn)
112
112
  sig = Signature.from_dict(md['signature'])
113
113
  batch_size = md['batch_size']
114
114
  return CallableFunction(sig, py_fn, self_name=name, batch_size=batch_size)
pixeltable/func/function_registry.py CHANGED
@@ -4,7 +4,7 @@ import dataclasses
4
4
  import importlib
5
5
  import logging
6
6
  import sys
7
- from typing import Optional, Dict, List
7
+ from typing import Optional
8
8
  from uuid import UUID
9
9
 
10
10
  import sqlalchemy as sql
@@ -14,7 +14,6 @@ import pixeltable.exceptions as excs
14
14
  import pixeltable.type_system as ts
15
15
  from pixeltable.metadata import schema
16
16
  from .function import Function
17
- from .globals import get_caller_module_path
18
17
 
19
18
  _logger = logging.getLogger('pixeltable')
20
19
 
@@ -32,15 +31,15 @@ class FunctionRegistry:
32
31
  return cls._instance
33
32
 
34
33
  def __init__(self):
35
- self.stored_fns_by_id: Dict[UUID, Function] = {}
36
- self.module_fns: Dict[str, Function] = {} # fqn -> Function
34
+ self.stored_fns_by_id: dict[UUID, Function] = {}
35
+ self.module_fns: dict[str, Function] = {} # fqn -> Function
37
36
  self.type_methods: dict[ts.ColumnType.Type, dict[str, Function]] = {}
38
37
 
39
38
  def clear_cache(self) -> None:
40
39
  """
41
40
  Useful during testing
42
41
  """
43
- self.stored_fns_by_id: Dict[UUID, Function] = {}
42
+ self.stored_fns_by_id = {}
44
43
 
45
44
  # def register_std_modules(self) -> None:
46
45
  # """Register all submodules of pixeltable.functions"""
@@ -76,7 +75,7 @@ class FunctionRegistry:
76
75
  raise excs.Error(f'Duplicate method name for type {base_type}: {fn.name}')
77
76
  self.type_methods[base_type][fn.name] = fn
78
77
 
79
- def list_functions(self) -> List[Function]:
78
+ def list_functions(self) -> list[Function]:
80
79
  # retrieve Function.Metadata data for all existing stored functions from store directly
81
80
  # (self.stored_fns_by_id isn't guaranteed to contain all functions)
82
81
  # TODO: have the client do this, once the client takes over the Db functionality
@@ -85,7 +84,7 @@ class FunctionRegistry:
85
84
  # schema.Db.name, schema.Dir.path, sql_func.length(schema.Function.init_obj))\
86
85
  # .where(schema.Function.db_id == schema.Db.id)\
87
86
  # .where(schema.Function.dir_id == schema.Dir.id)
88
- # stored_fn_md: List[Function.Metadata] = []
87
+ # stored_fn_md: list[Function.Metadata] = []
89
88
  # with Env.get().engine.begin() as conn:
90
89
  # rows = conn.execute(stmt)
91
90
  # for name, md_dict, db_name, dir_path, init_obj_len in rows:
pixeltable/func/query_template_function.py CHANGED
@@ -1,14 +1,15 @@
1
1
  from __future__ import annotations
2
+
2
3
  import inspect
3
- from typing import Dict, Optional, Any, Callable
4
+ from typing import Any, Callable, Optional
4
5
 
5
6
  import sqlalchemy as sql
6
7
 
7
- import pixeltable
8
- import pixeltable.exceptions as excs
9
- import pixeltable.type_system as ts
8
+ import pixeltable as pxt
9
+ from pixeltable import exprs
10
+
10
11
  from .function import Function
11
- from .signature import Signature, Parameter
12
+ from .signature import Signature
12
13
 
13
14
 
14
15
  class QueryTemplateFunction(Function):
@@ -16,24 +17,23 @@ class QueryTemplateFunction(Function):
16
17
 
17
18
  @classmethod
18
19
  def create(
19
- cls, template_callable: Callable, param_types: Optional[list[ts.ColumnType]], path: str, name: str
20
+ cls, template_callable: Callable, param_types: Optional[list[pxt.ColumnType]], path: str, name: str
20
21
  ) -> QueryTemplateFunction:
21
22
  # we need to construct a template df and a signature
22
23
  py_sig = inspect.signature(template_callable)
23
24
  py_params = list(py_sig.parameters.values())
24
25
  params = Signature.create_parameters(py_params=py_params, param_types=param_types)
25
26
  # invoke template_callable with parameter expressions to construct a DataFrame with parameters
26
- import pixeltable.exprs as exprs
27
27
  var_exprs = [exprs.Variable(param.name, param.col_type) for param in params]
28
28
  template_df = template_callable(*var_exprs)
29
29
  from pixeltable import DataFrame
30
30
  assert isinstance(template_df, DataFrame)
31
31
  # we take params and return json
32
- sig = Signature(return_type=ts.JsonType(), parameters=params)
32
+ sig = Signature(return_type=pxt.JsonType(), parameters=params)
33
33
  return QueryTemplateFunction(template_df, sig, path=path, name=name)
34
34
 
35
35
  def __init__(
36
- self, template_df: Optional['pixeltable.DataFrame'], sig: Optional[Signature], path: Optional[str] = None,
36
+ self, template_df: Optional['pxt.DataFrame'], sig: Optional[Signature], path: Optional[str] = None,
37
37
  name: Optional[str] = None,
38
38
  ):
39
39
  super().__init__(sig, self_path=path)
@@ -46,7 +46,6 @@ class QueryTemplateFunction(Function):
46
46
  self.conn: Optional[sql.engine.Connection] = None
47
47
 
48
48
  # convert defaults to Literals
49
- import pixeltable.exprs as exprs
50
49
  self.defaults: dict[str, exprs.Literal] = {} # key: param name, value: default value converted to a Literal
51
50
  param_types = self.template_df.parameters()
52
51
  for param in [p for p in self.signature.parameters.values() if p.has_default()]:
@@ -75,10 +74,10 @@ class QueryTemplateFunction(Function):
75
74
  def name(self) -> str:
76
75
  return self.self_name
77
76
 
78
- def _as_dict(self) -> Dict:
77
+ def _as_dict(self) -> dict:
79
78
  return {'name': self.name, 'signature': self.signature.as_dict(), 'df': self.template_df.as_dict()}
80
79
 
81
80
  @classmethod
82
- def _from_dict(cls, d: Dict) -> Function:
81
+ def _from_dict(cls, d: dict) -> Function:
83
82
  from pixeltable.dataframe import DataFrame
84
83
  return cls(DataFrame.from_dict(d['df']), Signature.from_dict(d['signature']), name=d['name'])
pixeltable/func/signature.py CHANGED
@@ -1,12 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
- import json
4
3
  import dataclasses
5
4
  import enum
6
5
  import inspect
6
+ import json
7
7
  import logging
8
8
  import typing
9
- from typing import Optional, Callable, Dict, List, Any, Union, Tuple
9
+ from typing import Any, Callable, Optional, Union
10
10
 
11
11
  import pixeltable.exceptions as excs
12
12
  import pixeltable.type_system as ts
@@ -18,7 +18,7 @@ _logger = logging.getLogger('pixeltable')
18
18
  class Parameter:
19
19
  name: str
20
20
  col_type: Optional[ts.ColumnType] # None for variable parameters
21
- kind: enum.Enum # inspect.Parameter.kind; inspect._ParameterKind is private
21
+ kind: inspect._ParameterKind
22
22
  # for some reason, this needs to precede is_batched in the dataclass definition,
23
23
  # otherwise Python complains that an argument with a default is followed by an argument without a default
24
24
  default: Any = inspect.Parameter.empty # default value for the parameter
@@ -82,7 +82,7 @@ class Signature:
82
82
  """
83
83
  SPECIAL_PARAM_NAMES = ['group_by', 'order_by']
84
84
 
85
- def __init__(self, return_type: ts.ColumnType, parameters: List[Parameter], is_batched: bool = False):
85
+ def __init__(self, return_type: ts.ColumnType, parameters: list[Parameter], is_batched: bool = False):
86
86
  assert isinstance(return_type, ts.ColumnType)
87
87
  self.return_type = return_type
88
88
  self.is_batched = is_batched
@@ -97,7 +97,7 @@ class Signature:
97
97
  assert isinstance(self.return_type, ts.ColumnType)
98
98
  return self.return_type
99
99
 
100
- def as_dict(self) -> Dict[str, Any]:
100
+ def as_dict(self) -> dict[str, Any]:
101
101
  result = {
102
102
  'return_type': self.get_return_type().as_dict(),
103
103
  'parameters': [p.as_dict() for p in self.parameters.values()],
@@ -106,11 +106,13 @@ class Signature:
106
106
  return result
107
107
 
108
108
  @classmethod
109
- def from_dict(cls, d: Dict[str, Any]) -> Signature:
109
+ def from_dict(cls, d: dict[str, Any]) -> Signature:
110
110
  parameters = [Parameter.from_dict(param_dict) for param_dict in d['parameters']]
111
111
  return cls(ts.ColumnType.from_dict(d['return_type']), parameters, d['is_batched'])
112
112
 
113
- def __eq__(self, other: Signature) -> bool:
113
+ def __eq__(self, other: object) -> bool:
114
+ if not isinstance(other, Signature):
115
+ return False
114
116
  if self.get_return_type() != other.get_return_type():
115
117
  return False
116
118
  if len(self.parameters) != len(other.parameters):
@@ -122,7 +124,7 @@ class Signature:
122
124
  return True
123
125
 
124
126
  def __str__(self) -> str:
125
- param_strs: List[str] = []
127
+ param_strs: list[str] = []
126
128
  for p in self.parameters.values():
127
129
  if p.kind == inspect.Parameter.VAR_POSITIONAL:
128
130
  param_strs.append(f'*{p.name}')
@@ -133,7 +135,7 @@ class Signature:
133
135
  return f'({", ".join(param_strs)}) -> {str(self.get_return_type())}'
134
136
 
135
137
  @classmethod
136
- def _infer_type(cls, annotation: Optional[type]) -> Tuple[Optional[ts.ColumnType], Optional[bool]]:
138
+ def _infer_type(cls, annotation: Optional[type]) -> tuple[Optional[ts.ColumnType], Optional[bool]]:
137
139
  """Returns: (column type, is_batched) or (None, ...) if the type cannot be inferred"""
138
140
  if annotation is None:
139
141
  return (None, None)
@@ -154,13 +156,13 @@ class Signature:
154
156
  @classmethod
155
157
  def create_parameters(
156
158
  cls, py_fn: Optional[Callable] = None, py_params: Optional[list[inspect.Parameter]] = None,
157
- param_types: Optional[List[ts.ColumnType]] = None
158
- ) -> List[Parameter]:
159
+ param_types: Optional[list[ts.ColumnType]] = None
160
+ ) -> list[Parameter]:
159
161
  assert (py_fn is None) != (py_params is None)
160
162
  if py_fn is not None:
161
163
  sig = inspect.signature(py_fn)
162
164
  py_params = list(sig.parameters.values())
163
- parameters: List[Parameter] = []
165
+ parameters: list[Parameter] = []
164
166
 
165
167
  for idx, param in enumerate(py_params):
166
168
  if param.name in cls.SPECIAL_PARAM_NAMES:
@@ -187,9 +189,9 @@ class Signature:
187
189
 
188
190
  @classmethod
189
191
  def create(
190
- cls, py_fn: Callable,
191
- param_types: Optional[List[ts.ColumnType]] = None,
192
- return_type: Optional[Union[ts.ColumnType, Callable]] = None
192
+ cls, py_fn: Callable,
193
+ param_types: Optional[list[ts.ColumnType]] = None,
194
+ return_type: Optional[ts.ColumnType] = None
193
195
  ) -> Signature:
194
196
  """Create a signature for the given Callable.
195
197
  Infer the parameter and return types, if none are specified.
pixeltable/func/udf.py CHANGED
@@ -38,10 +38,6 @@ def udf(*args, **kwargs):
38
38
  >>> @pxt.udf
39
39
  ... def my_function(x: int) -> int:
40
40
  ... return x + 1
41
-
42
- >>> @pxt.udf(param_types=[pxt.IntType()], return_type=pxt.IntType())
43
- ... def my_function(x):
44
- ... return x + 1
45
41
  """
46
42
  if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
47
43
 
pixeltable/functions/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from pixeltable.utils.code import local_public_names
2
2
 
3
3
  from . import (anthropic, audio, fireworks, huggingface, image, json, mistralai, openai, string, timestamp, together,
4
- video, vision)
4
+ video, vision, whisper)
5
5
  from .globals import *
6
6
 
7
7
  __all__ = local_public_names(__name__, exclude=['globals']) + local_public_names(globals.__name__)
pixeltable/functions/audio.py CHANGED
@@ -11,18 +11,16 @@ t.select(pxtf.audio.get_metadata()).collect()
11
11
  ```
12
12
  """
13
13
 
14
- import pixeltable.func as func
15
- import pixeltable.type_system as ts
14
+ import pixeltable as pxt
16
15
  from pixeltable.utils.code import local_public_names
17
16
 
18
17
 
19
- @func.udf(return_type=ts.JsonType(nullable=False), param_types=[ts.AudioType(nullable=False)], is_method=True)
20
- def get_metadata(audio: str) -> dict:
18
+ @pxt.udf(is_method=True)
19
+ def get_metadata(audio: pxt.Audio) -> dict:
21
20
  """
22
21
  Gets various metadata associated with an audio file and returns it as a dictionary.
23
22
  """
24
- import pixeltable.functions as pxtf
25
- return pxtf.video._get_metadata(audio)
23
+ return pxt.functions.video._get_metadata(audio)
26
24
 
27
25
 
28
26
  __all__ = local_public_names(__name__)