pixeltable 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (39)
  1. pixeltable/__version__.py +2 -2
  2. pixeltable/catalog/insertable_table.py +2 -2
  3. pixeltable/catalog/table.py +5 -5
  4. pixeltable/catalog/table_version.py +12 -14
  5. pixeltable/catalog/view.py +2 -2
  6. pixeltable/dataframe.py +7 -6
  7. pixeltable/exec/expr_eval_node.py +8 -1
  8. pixeltable/exec/sql_scan_node.py +1 -1
  9. pixeltable/exprs/__init__.py +0 -1
  10. pixeltable/exprs/comparison.py +5 -5
  11. pixeltable/exprs/compound_predicate.py +12 -12
  12. pixeltable/exprs/expr.py +32 -0
  13. pixeltable/exprs/in_predicate.py +3 -3
  14. pixeltable/exprs/is_null.py +5 -5
  15. pixeltable/func/aggregate_function.py +10 -4
  16. pixeltable/func/callable_function.py +4 -0
  17. pixeltable/func/function_registry.py +2 -0
  18. pixeltable/functions/globals.py +36 -1
  19. pixeltable/functions/huggingface.py +62 -4
  20. pixeltable/functions/image.py +17 -0
  21. pixeltable/functions/string.py +622 -7
  22. pixeltable/functions/video.py +26 -8
  23. pixeltable/globals.py +3 -3
  24. pixeltable/io/globals.py +53 -4
  25. pixeltable/io/label_studio.py +42 -2
  26. pixeltable/io/pandas.py +18 -7
  27. pixeltable/plan.py +6 -6
  28. pixeltable/tool/create_test_db_dump.py +1 -1
  29. pixeltable/tool/doc_plugins/griffe.py +77 -0
  30. pixeltable/tool/doc_plugins/mkdocstrings.py +6 -0
  31. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +135 -0
  32. pixeltable/utils/s3.py +1 -1
  33. pixeltable-0.2.13.dist-info/METADATA +206 -0
  34. {pixeltable-0.2.12.dist-info → pixeltable-0.2.13.dist-info}/RECORD +37 -34
  35. pixeltable-0.2.13.dist-info/entry_points.txt +3 -0
  36. pixeltable/exprs/predicate.py +0 -44
  37. pixeltable-0.2.12.dist-info/METADATA +0 -137
  38. {pixeltable-0.2.12.dist-info → pixeltable-0.2.13.dist-info}/LICENSE +0 -0
  39. {pixeltable-0.2.12.dist-info → pixeltable-0.2.13.dist-info}/WHEEL +0 -0
pixeltable/__version__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  # These version placeholders will be replaced during build.
2
- __version__ = "0.2.12"
3
- __version_tuple__ = (0, 2, 12)
2
+ __version__ = "0.2.13"
3
+ __version_tuple__ = (0, 2, 13)
@@ -129,11 +129,11 @@ class InsertableTable(Table):
129
129
  msg = str(e)
130
130
  raise excs.Error(f'Error in column {col.name}: {msg[0].lower() + msg[1:]}\nRow: {row}')
131
131
 
132
- def delete(self, where: Optional['pixeltable.exprs.Predicate'] = None) -> UpdateStatus:
132
+ def delete(self, where: Optional['pixeltable.exprs.Expr'] = None) -> UpdateStatus:
133
133
  """Delete rows in this table.
134
134
 
135
135
  Args:
136
- where: a Predicate to filter rows to delete.
136
+ where: a predicate to filter rows to delete.
137
137
 
138
138
  Examples:
139
139
  Delete all rows in a table:
@@ -113,7 +113,7 @@ class Table(SchemaObject):
113
113
  from pixeltable.dataframe import DataFrame
114
114
  return DataFrame(self._tbl_version_path).select(*items, **named_items)
115
115
 
116
- def where(self, pred: 'exprs.Predicate') -> 'pixeltable.dataframe.DataFrame':
116
+ def where(self, pred: 'exprs.Expr') -> 'pixeltable.dataframe.DataFrame':
117
117
  """Return a DataFrame for this table.
118
118
  """
119
119
  # local import: avoid circular imports
@@ -716,13 +716,13 @@ class Table(SchemaObject):
716
716
  raise NotImplementedError
717
717
 
718
718
  def update(
719
- self, value_spec: dict[str, Any], where: Optional['pixeltable.exprs.Predicate'] = None, cascade: bool = True
719
+ self, value_spec: dict[str, Any], where: Optional['pixeltable.exprs.Expr'] = None, cascade: bool = True
720
720
  ) -> UpdateStatus:
721
721
  """Update rows in this table.
722
722
 
723
723
  Args:
724
724
  value_spec: a dictionary mapping column names to literal values or Pixeltable expressions.
725
- where: a Predicate to filter rows to update.
725
+ where: a predicate to filter rows to update.
726
726
  cascade: if True, also update all computed columns that transitively depend on the updated columns.
727
727
 
728
728
  Examples:
@@ -786,11 +786,11 @@ class Table(SchemaObject):
786
786
  row_updates.append(col_vals)
787
787
  return self._tbl_version.batch_update(row_updates, rowids, cascade)
788
788
 
789
- def delete(self, where: Optional['pixeltable.exprs.Predicate'] = None) -> UpdateStatus:
789
+ def delete(self, where: Optional['pixeltable.exprs.Expr'] = None) -> UpdateStatus:
790
790
  """Delete rows in this table.
791
791
 
792
792
  Args:
793
- where: a Predicate to filter rows to delete.
793
+ where: a predicate to filter rows to delete.
794
794
 
795
795
  Examples:
796
796
  Delete all rows in a table:
@@ -678,12 +678,12 @@ class TableVersion:
678
678
  return result
679
679
 
680
680
  def update(
681
- self, value_spec: dict[str, Any], where: Optional['exprs.Predicate'] = None, cascade: bool = True
681
+ self, value_spec: dict[str, Any], where: Optional['exprs.Expr'] = None, cascade: bool = True
682
682
  ) -> UpdateStatus:
683
683
  """Update rows in this TableVersionPath.
684
684
  Args:
685
685
  value_spec: a list of (column, value) pairs specifying the columns to update and their new values.
686
- where: a Predicate to filter rows to update.
686
+ where: a predicate to filter rows to update.
687
687
  cascade: if True, also update all computed columns that transitively depend on the updated columns,
688
688
  including within views.
689
689
  """
@@ -694,8 +694,8 @@ class TableVersion:
694
694
 
695
695
  update_spec = self._validate_update_spec(value_spec, allow_pk=False, allow_exprs=True)
696
696
  if where is not None:
697
- if not isinstance(where, exprs.Predicate):
698
- raise excs.Error(f"'where' argument must be a Predicate, got {type(where)}")
697
+ if not isinstance(where, exprs.Expr):
698
+ raise excs.Error(f"'where' argument must be a predicate, got {type(where)}")
699
699
  analysis_info = Planner.analyze(self.path, where)
700
700
  # for now we require that the updated rows can be identified via SQL, rather than via a Python filter
701
701
  if analysis_info.filter is not None:
@@ -757,7 +757,7 @@ class TableVersion:
757
757
 
758
758
  def _update(
759
759
  self, conn: sql.engine.Connection, update_targets: dict[Column, 'pixeltable.exprs.Expr'],
760
- where_clause: Optional['pixeltable.exprs.Predicate'] = None, cascade: bool = True,
760
+ where_clause: Optional['pixeltable.exprs.Expr'] = None, cascade: bool = True,
761
761
  show_progress: bool = True
762
762
  ) -> UpdateStatus:
763
763
  from pixeltable.plan import Planner
@@ -789,8 +789,6 @@ class TableVersion:
789
789
  raise excs.Error(f'Column {col_name} is computed and cannot be updated')
790
790
  if col.is_pk and not allow_pk:
791
791
  raise excs.Error(f'Column {col_name} is a primary key column and cannot be updated')
792
- if col.col_type.is_media_type():
793
- raise excs.Error(f'Column {col_name} has type image/video/audio/document and cannot be updated')
794
792
 
795
793
  # make sure that the value is compatible with the column type
796
794
  try:
@@ -848,17 +846,17 @@ class TableVersion:
848
846
  result.cols_with_excs = list(dict.fromkeys(result.cols_with_excs).keys()) # remove duplicates
849
847
  return result
850
848
 
851
- def delete(self, where: Optional['exprs.Predicate'] = None) -> UpdateStatus:
849
+ def delete(self, where: Optional['exprs.Expr'] = None) -> UpdateStatus:
852
850
  """Delete rows in this table.
853
851
  Args:
854
- where: a Predicate to filter rows to delete.
852
+ where: a predicate to filter rows to delete.
855
853
  """
856
854
  assert self.is_insertable()
857
- from pixeltable.exprs import Predicate
855
+ from pixeltable.exprs import Expr
858
856
  from pixeltable.plan import Planner
859
857
  if where is not None:
860
- if not isinstance(where, Predicate):
861
- raise excs.Error(f"'where' argument must be a Predicate, got {type(where)}")
858
+ if not isinstance(where, Expr):
859
+ raise excs.Error(f"'where' argument must be a predicate, got {type(where)}")
862
860
  analysis_info = Planner.analyze(self.path, where)
863
861
  # for now we require that the updated rows can be identified via SQL, rather than via a Python filter
864
862
  if analysis_info.filter is not None:
@@ -872,11 +870,11 @@ class TableVersion:
872
870
  return status
873
871
 
874
872
  def propagate_delete(
875
- self, where: Optional['exprs.Predicate'], base_versions: List[Optional[int]],
873
+ self, where: Optional['exprs.Expr'], base_versions: List[Optional[int]],
876
874
  conn: sql.engine.Connection, timestamp: float) -> int:
877
875
  """Delete rows in this table and propagate to views.
878
876
  Args:
879
- where: a Predicate to filter rows to delete.
877
+ where: a predicate to filter rows to delete.
880
878
  Returns:
881
879
  number of deleted rows
882
880
  """
@@ -51,7 +51,7 @@ class View(Table):
51
51
  @classmethod
52
52
  def create(
53
53
  cls, dir_id: UUID, name: str, base: TableVersionPath, schema: Dict[str, Any],
54
- predicate: 'pxt.exprs.Predicate', is_snapshot: bool, num_retained_versions: int, comment: str,
54
+ predicate: 'pxt.exprs.Expr', is_snapshot: bool, num_retained_versions: int, comment: str,
55
55
  iterator_cls: Optional[Type[ComponentIterator]], iterator_args: Optional[Dict]
56
56
  ) -> View:
57
57
  columns = cls._create_columns(schema)
@@ -213,5 +213,5 @@ class View(Table):
213
213
  ) -> UpdateStatus:
214
214
  raise excs.Error(f'{self.display_name()} {self._name!r}: cannot insert into view')
215
215
 
216
- def delete(self, where: Optional['pixeltable.exprs.Predicate'] = None) -> UpdateStatus:
216
+ def delete(self, where: Optional['pixeltable.exprs.Expr'] = None) -> UpdateStatus:
217
217
  raise excs.Error(f'{self.display_name()} {self._name!r}: cannot delete from view')
pixeltable/dataframe.py CHANGED
@@ -153,7 +153,7 @@ class DataFrame:
153
153
  self,
154
154
  tbl: catalog.TableVersionPath,
155
155
  select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]] = None,
156
- where_clause: Optional[exprs.Predicate] = None,
156
+ where_clause: Optional[exprs.Expr] = None,
157
157
  group_by_clause: Optional[List[exprs.Expr]] = None,
158
158
  grouping_tbl: Optional[catalog.TableVersion] = None,
159
159
  order_by_clause: Optional[List[Tuple[exprs.Expr, bool]]] = None, # List[(expr, asc)]
@@ -530,7 +530,11 @@ class DataFrame:
530
530
  limit=self.limit_val,
531
531
  )
532
532
 
533
- def where(self, pred: exprs.Predicate) -> DataFrame:
533
+ def where(self, pred: exprs.Expr) -> DataFrame:
534
+ if not isinstance(pred, exprs.Expr):
535
+ raise excs.Error(f'Where() requires a Pixeltable expression, but instead got {type(pred)}')
536
+ if not pred.col_type.is_bool_type():
537
+ raise excs.Error(f'Where(): expression needs to return bool, but instead returns {pred.col_type}')
534
538
  return DataFrame(
535
539
  self.tbl,
536
540
  select_list=self.select_list,
@@ -628,12 +632,9 @@ class DataFrame:
628
632
  def __getitem__(self, index: object) -> DataFrame:
629
633
  """
630
634
  Allowed:
631
- - [<Predicate>]: filter operation
632
635
  - [List[Expr]]/[Tuple[Expr]]: setting the select list
633
636
  - [Expr]: setting a single-col select list
634
637
  """
635
- if isinstance(index, exprs.Predicate):
636
- return self.where(index)
637
638
  if isinstance(index, tuple):
638
639
  index = list(index)
639
640
  if isinstance(index, exprs.Expr):
@@ -668,7 +669,7 @@ class DataFrame:
668
669
  tbl = catalog.TableVersionPath.from_dict(d['tbl'])
669
670
  select_list = [(exprs.Expr.from_dict(e), name) for e, name in d['select_list']] \
670
671
  if d['select_list'] is not None else None
671
- where_clause = exprs.Predicate.from_dict(d['where_clause']) \
672
+ where_clause = exprs.Expr.from_dict(d['where_clause']) \
672
673
  if d['where_clause'] is not None else None
673
674
  group_by_clause = [exprs.Expr.from_dict(e) for e in d['group_by_clause']] \
674
675
  if d['group_by_clause'] is not None else None
@@ -50,7 +50,14 @@ class ExprEvalNode(ExecNode):
50
50
 
51
51
  def _open(self) -> None:
52
52
  warnings.simplefilter("ignore", category=TqdmWarning)
53
- if self.ctx.show_pbar:
53
+ # This is a temporary hack. When B-tree indices on string columns were implemented (via computed columns
54
+ # that invoke the `BtreeIndex.str_filter` udf), it resulted in frivolous progress bars appearing on every
55
+ # insertion. This special-cases the `str_filter` call to suppress the corresponding progress bar.
56
+ # TODO(aaron-siegel) Remove this hack once we clean up progress bars more generally.
57
+ is_str_filter_node = all(
58
+ isinstance(expr, exprs.FunctionCall) and expr.fn.name == 'str_filter' for expr in self.output_exprs
59
+ )
60
+ if self.ctx.show_pbar and not is_str_filter_node:
54
61
  self.pbar = tqdm(
55
62
  total=len(self.target_exprs) * self.ctx.num_rows,
56
63
  desc='Computing cells',
@@ -19,7 +19,7 @@ class SqlScanNode(ExecNode):
19
19
  def __init__(
20
20
  self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
21
21
  select_list: Iterable[exprs.Expr],
22
- where_clause: Optional[exprs.Expr] = None, filter: Optional[exprs.Predicate] = None,
22
+ where_clause: Optional[exprs.Expr] = None, filter: Optional[exprs.Expr] = None,
23
23
  order_by_items: Optional[List[Tuple[exprs.Expr, bool]]] = None,
24
24
  limit: int = 0, set_pk: bool = False, exact_version_only: Optional[List[catalog.TableVersion]] = None
25
25
  ):
@@ -17,7 +17,6 @@ from .json_mapper import JsonMapper
17
17
  from .json_path import RELATIVE_PATH_ROOT, JsonPath
18
18
  from .literal import Literal
19
19
  from .object_ref import ObjectRef
20
- from .predicate import Predicate
21
20
  from .row_builder import RowBuilder, ColumnSlotIdx, ExecProfile
22
21
  from .rowid_ref import RowidRef
23
22
  from .similarity_expr import SimilarityExpr
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Optional, List, Any, Dict, Tuple
3
+ from typing import Optional, List, Any, Dict
4
4
 
5
5
  import sqlalchemy as sql
6
6
 
@@ -9,15 +9,15 @@ from .data_row import DataRow
9
9
  from .expr import Expr
10
10
  from .globals import ComparisonOperator
11
11
  from .literal import Literal
12
- from .predicate import Predicate
13
12
  from .row_builder import RowBuilder
14
13
  import pixeltable.exceptions as excs
15
14
  import pixeltable.index as index
15
+ import pixeltable.type_system as ts
16
16
 
17
17
 
18
- class Comparison(Predicate):
18
+ class Comparison(Expr):
19
19
  def __init__(self, operator: ComparisonOperator, op1: Expr, op2: Expr):
20
- super().__init__()
20
+ super().__init__(ts.BoolType())
21
21
  self.operator = operator
22
22
 
23
23
  # if this is a comparison of a column to a literal (ie, could be used as a search argument in an index lookup),
@@ -50,7 +50,7 @@ class Comparison(Predicate):
50
50
  def _equals(self, other: Comparison) -> bool:
51
51
  return self.operator == other.operator
52
52
 
53
- def _id_attrs(self) -> List[Tuple[str, Any]]:
53
+ def _id_attrs(self) -> list[tuple[str, Any]]:
54
54
  return super()._id_attrs() + [('operator', self.operator.value)]
55
55
 
56
56
  @property
@@ -1,20 +1,20 @@
1
1
  from __future__ import annotations
2
- from typing import Optional, List, Any, Dict, Tuple, Callable
2
+
3
3
  import operator
4
+ from typing import Optional, List, Any, Dict, Callable
4
5
 
5
6
  import sqlalchemy as sql
6
7
 
8
+ from .data_row import DataRow
7
9
  from .expr import Expr
8
10
  from .globals import LogicalOperator
9
- from .predicate import Predicate
10
- from .data_row import DataRow
11
11
  from .row_builder import RowBuilder
12
- import pixeltable.catalog as catalog
12
+ import pixeltable.type_system as ts
13
13
 
14
14
 
15
- class CompoundPredicate(Predicate):
16
- def __init__(self, operator: LogicalOperator, operands: List[Predicate]):
17
- super().__init__()
15
+ class CompoundPredicate(Expr):
16
+ def __init__(self, operator: LogicalOperator, operands: List[Expr]):
17
+ super().__init__(ts.BoolType())
18
18
  self.operator = operator
19
19
  # operands are stored in self.components
20
20
  if self.operator == LogicalOperator.NOT:
@@ -22,7 +22,7 @@ class CompoundPredicate(Predicate):
22
22
  self.components = operands
23
23
  else:
24
24
  assert len(operands) > 1
25
- self.operands: List[Predicate] = []
25
+ self.operands: List[Expr] = []
26
26
  for operand in operands:
27
27
  self._merge_operand(operand)
28
28
 
@@ -34,14 +34,14 @@ class CompoundPredicate(Predicate):
34
34
  return f' {self.operator} '.join([f'({e})' for e in self.components])
35
35
 
36
36
  @classmethod
37
- def make_conjunction(cls, operands: List[Predicate]) -> Optional[Predicate]:
37
+ def make_conjunction(cls, operands: List[Expr]) -> Optional[Expr]:
38
38
  if len(operands) == 0:
39
39
  return None
40
40
  if len(operands) == 1:
41
41
  return operands[0]
42
42
  return CompoundPredicate(LogicalOperator.AND, operands)
43
43
 
44
- def _merge_operand(self, operand: Predicate) -> None:
44
+ def _merge_operand(self, operand: Expr) -> None:
45
45
  """
46
46
  Merge this operand, if possible, otherwise simply record it.
47
47
  """
@@ -55,11 +55,11 @@ class CompoundPredicate(Predicate):
55
55
  def _equals(self, other: CompoundPredicate) -> bool:
56
56
  return self.operator == other.operator
57
57
 
58
- def _id_attrs(self) -> List[Tuple[str, Any]]:
58
+ def _id_attrs(self) -> list[tuple[str, Any]]:
59
59
  return super()._id_attrs() + [('operator', self.operator.value)]
60
60
 
61
61
  def split_conjuncts(
62
- self, condition: Callable[[Predicate], bool]) -> Tuple[List[Predicate], Optional[Predicate]]:
62
+ self, condition: Callable[[Expr], bool]) -> tuple[list[Expr], Optional[Expr]]:
63
63
  if self.operator == LogicalOperator.OR or self.operator == LogicalOperator.NOT:
64
64
  return super().split_conjuncts(condition)
65
65
  matches = [op for op in self.components if condition(op)]
pixeltable/exprs/expr.py CHANGED
@@ -518,6 +518,38 @@ class Expr(abc.ABC):
518
518
  return ArithmeticExpr(op, self, Literal(other)) # type: ignore[arg-type]
519
519
  raise TypeError(f'Other must be Expr or literal: {type(other)}')
520
520
 
521
+ def __and__(self, other: object) -> Expr:
522
+ if not isinstance(other, Expr):
523
+ raise TypeError(f'Other needs to be an expression: {type(other)}')
524
+ if not other.col_type.is_bool_type():
525
+ raise TypeError(f'Other needs to be an expression that returns a boolean: {other.col_type}')
526
+ from .compound_predicate import CompoundPredicate
527
+ return CompoundPredicate(LogicalOperator.AND, [self, other])
528
+
529
+ def __or__(self, other: object) -> Expr:
530
+ if not isinstance(other, Expr):
531
+ raise TypeError(f'Other needs to be an expression: {type(other)}')
532
+ if not other.col_type.is_bool_type():
533
+ raise TypeError(f'Other needs to be an expression that returns a boolean: {other.col_type}')
534
+ from .compound_predicate import CompoundPredicate
535
+ return CompoundPredicate(LogicalOperator.OR, [self, other])
536
+
537
+ def __invert__(self) -> Expr:
538
+ from .compound_predicate import CompoundPredicate
539
+ return CompoundPredicate(LogicalOperator.NOT, [self])
540
+
541
+ def split_conjuncts(
542
+ self, condition: Callable[[Expr], bool]) -> tuple[list[Expr], Optional[Expr]]:
543
+ """
544
+ Returns clauses of a conjunction that meet condition in the first element.
545
+ The second element contains remaining clauses, rolled into a conjunction.
546
+ """
547
+ assert self.col_type.is_bool_type() # only valid for predicates
548
+ if condition(self):
549
+ return [self], None
550
+ else:
551
+ return [], self
552
+
521
553
  def _make_applicator_function(self, fn: Callable, col_type: Optional[ts.ColumnType]) -> 'pixeltable.func.Function':
522
554
  """
523
555
  Creates a unary pixeltable `Function` that encapsulates a python `Callable`. The result type of
@@ -5,20 +5,20 @@ from typing import Optional, List, Any, Dict, Tuple, Iterable
5
5
  import sqlalchemy as sql
6
6
 
7
7
  import pixeltable.exceptions as excs
8
+ import pixeltable.type_system as ts
8
9
  from .data_row import DataRow
9
10
  from .expr import Expr
10
- from .predicate import Predicate
11
11
  from .row_builder import RowBuilder
12
12
 
13
13
 
14
- class InPredicate(Predicate):
14
+ class InPredicate(Expr):
15
15
  """Predicate corresponding to the SQL IN operator."""
16
16
 
17
17
  def __init__(self, lhs: Expr, value_set_literal: Optional[Iterable] = None, value_set_expr: Optional[Expr] = None):
18
18
  assert (value_set_literal is None) != (value_set_expr is None)
19
19
  if not lhs.col_type.is_scalar_type():
20
20
  raise excs.Error(f'isin(): only supported for scalar types, not {lhs.col_type}')
21
- super().__init__()
21
+ super().__init__(ts.BoolType())
22
22
 
23
23
  self.value_list: Optional[list] = None # only contains values of the correct type
24
24
  if value_set_expr is not None:
@@ -1,18 +1,18 @@
1
1
  from __future__ import annotations
2
+
2
3
  from typing import Optional, List, Dict
3
4
 
4
5
  import sqlalchemy as sql
5
6
 
6
- from .predicate import Predicate
7
- from .expr import Expr
7
+ import pixeltable.type_system as ts
8
8
  from .data_row import DataRow
9
+ from .expr import Expr
9
10
  from .row_builder import RowBuilder
10
- import pixeltable.catalog as catalog
11
11
 
12
12
 
13
- class IsNull(Predicate):
13
+ class IsNull(Expr):
14
14
  def __init__(self, e: Expr):
15
- super().__init__()
15
+ super().__init__(ts.BoolType())
16
16
  self.components = [e]
17
17
  self.id = self._create_id()
18
18
 
@@ -1,16 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import abc
4
- import importlib
5
4
  import inspect
6
- from typing import Optional, Any, Type, List, Dict, Callable
7
- import itertools
5
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type
8
6
 
9
7
  import pixeltable.exceptions as excs
10
8
  import pixeltable.type_system as ts
9
+
11
10
  from .function import Function
12
- from .signature import Signature, Parameter
13
11
  from .globals import validate_symbol_path
12
+ from .signature import Parameter, Signature
13
+
14
+ if TYPE_CHECKING:
15
+ import pixeltable
14
16
 
15
17
 
16
18
  class Aggregator(abc.ABC):
@@ -40,6 +42,7 @@ class AggregateFunction(Function):
40
42
  self.requires_order_by = requires_order_by
41
43
  self.allows_std_agg = allows_std_agg
42
44
  self.allows_window = allows_window
45
+ self.__doc__ = aggregator_class.__doc__
43
46
 
44
47
  # our signature is the signature of 'update', but without self,
45
48
  # plus the parameters of 'init' as keyword-only parameters
@@ -135,6 +138,9 @@ class AggregateFunction(Function):
135
138
  f'expression'
136
139
  )
137
140
 
141
+ def __repr__(self) -> str:
142
+ return f'<Pixeltable Aggregator {self.name}>'
143
+
138
144
 
139
145
  def uda(
140
146
  *,
@@ -25,6 +25,7 @@ class CallableFunction(Function):
25
25
  self.py_fn = py_fn
26
26
  self.self_name = self_name
27
27
  self.batch_size = batch_size
28
+ self.__doc__ = py_fn.__doc__
28
29
  super().__init__(signature, self_path=self_path)
29
30
 
30
31
  @property
@@ -113,3 +114,6 @@ class CallableFunction(Function):
113
114
  f'{self.display_name}(): '
114
115
  f'parameter {param.name} must be a constant value, not a Pixeltable expression'
115
116
  )
117
+
118
+ def __repr__(self) -> str:
119
+ return f'<Pixeltable UDF {self.name}>'
@@ -66,6 +66,8 @@ class FunctionRegistry:
66
66
  # self.module_fns[fn_path] = obj
67
67
 
68
68
  def register_function(self, fqn: str, fn: Function) -> None:
69
+ if fqn in self.module_fns:
70
+ raise excs.Error(f'A UDF with that name already exists: {fqn}')
69
71
  self.module_fns[fqn] = fn
70
72
 
71
73
  def list_functions(self) -> List[Function]:
@@ -1,4 +1,4 @@
1
- from typing import Union
1
+ from typing import Optional, Union
2
2
 
3
3
  import pixeltable.func as func
4
4
  import pixeltable.type_system as ts
@@ -14,6 +14,7 @@ def cast(expr: exprs.Expr, target_type: ts.ColumnType) -> exprs.Expr:
14
14
 
15
15
  @func.uda(update_types=[ts.IntType()], value_type=ts.IntType(), allows_window=True, requires_order_by=False)
16
16
  class sum(func.Aggregator):
17
+ """Sums the selected integers or floats."""
17
18
  def __init__(self):
18
19
  self.sum: Union[int, float] = 0
19
20
 
@@ -38,6 +39,40 @@ class count(func.Aggregator):
38
39
  return self.count
39
40
 
40
41
 
42
+ @func.uda(update_types=[ts.FloatType()], value_type=ts.FloatType(nullable=True), allows_window=True, requires_order_by=False)
43
+ class max(func.Aggregator):
44
+ def __init__(self):
45
+ self.val = None
46
+
47
+ def update(self, val: Optional[float]) -> None:
48
+ if val is not None:
49
+ if self.val is None:
50
+ self.val = val
51
+ else:
52
+ import builtins
53
+ self.val = builtins.max(self.val, val)
54
+
55
+ def value(self) -> Optional[float]:
56
+ return self.val
57
+
58
+
59
+ @func.uda(update_types=[ts.FloatType()], value_type=ts.FloatType(nullable=True), allows_window=True, requires_order_by=False)
60
+ class min(func.Aggregator):
61
+ def __init__(self):
62
+ self.val = None
63
+
64
+ def update(self, val: Optional[float]) -> None:
65
+ if val is not None:
66
+ if self.val is None:
67
+ self.val = val
68
+ else:
69
+ import builtins
70
+ self.val = builtins.min(self.val, val)
71
+
72
+ def value(self) -> Optional[float]:
73
+ return self.val
74
+
75
+
41
76
  @func.uda(update_types=[ts.IntType()], value_type=ts.FloatType(), allows_window=False, requires_order_by=False)
42
77
  class mean(func.Aggregator):
43
78
  def __init__(self):
@@ -1,3 +1,12 @@
1
+ """
2
+ Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs)
3
+ that wrap various models from the Hugging Face `transformers` package.
4
+
5
+ These UDFs will cause Pixeltable to invoke the relevant models locally. In order to use them, you must
6
+ first `pip install transformers` (or in some cases, `sentence-transformers`, as noted in the specific
7
+ UDFs).
8
+ """
9
+
1
10
  from typing import Callable, TypeVar, Optional, Any
2
11
 
3
12
  import PIL.Image
@@ -13,15 +22,39 @@ from pixeltable.utils.code import local_public_names
13
22
 
14
23
  @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType()))
15
24
  def sentence_transformer(
16
- sentences: Batch[str], *, model_id: str, normalize_embeddings: bool = False
25
+ sentence: Batch[str], *, model_id: str, normalize_embeddings: bool = False
17
26
  ) -> Batch[np.ndarray]:
18
- """Runs the specified sentence transformer model."""
27
+ """
28
+ Runs the specified pretrained sentence-transformers model. `model_id` should be a pretrained model, as described
29
+ in the [Sentence Transformers Pretrained Models](https://sbert.net/docs/sentence_transformer/pretrained_models.html)
30
+ documentation.
31
+
32
+ __Requirements:__
33
+
34
+ - `pip install sentence-transformers`
35
+
36
+ Args:
37
+ sentence: The sentence to embed.
38
+ model_id: The pretrained model to use for the encoding.
39
+ normalize_embeddings: If `True`, normalizes embeddings to length 1; see the
40
+ [Sentence Transformers API Docs](https://sbert.net/docs/package_reference/sentence_transformer/SentenceTransformer.html)
41
+ for more details
42
+
43
+ Returns:
44
+ An array containing the output of the embedding model.
45
+
46
+ Examples:
47
+ Add a computed column that applies the model `all-mpnet-base-v2` to an existing Pixeltable column `tbl.sentence`
48
+ of the table `tbl`:
49
+
50
+ >>> tbl['result'] = sentence_transformer(tbl.sentence, model_id='all-mpnet-base-v2')
51
+ """
19
52
  env.Env.get().require_package('sentence_transformers')
20
53
  from sentence_transformers import SentenceTransformer
21
54
 
22
55
  model = _lookup_model(model_id, SentenceTransformer)
23
56
 
24
- array = model.encode(sentences, normalize_embeddings=normalize_embeddings)
57
+ array = model.encode(sentence, normalize_embeddings=normalize_embeddings)
25
58
  return [array[i] for i in range(array.shape[0])]
26
59
 
27
60
 
@@ -49,7 +82,32 @@ def sentence_transformer_list(sentences: list, *, model_id: str, normalize_embed
49
82
 
50
83
  @pxt.udf(batch_size=32)
51
84
  def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: str) -> Batch[float]:
52
- """Runs the specified cross-encoder model."""
85
+ """
86
+ Runs the specified cross-encoder model to compute similarity scores for pairs of sentences.
87
+ `model_id` should be a pretrained model, as described in the
88
+ [Cross-Encoder Pretrained Models](https://www.sbert.net/docs/cross_encoder/pretrained_models.html)
89
+ documentation.
90
+
91
+ __Requirements:__
92
+
93
+ - `pip install sentence-transformers`
94
+
95
+ Parameters:
96
+ sentences1: The first sentence to be paired.
97
+ sentences2: The second sentence to be paired.
98
+ model_id: The identifier of the cross-encoder model to use.
99
+
100
+ Returns:
101
+ The similarity score between the inputs.
102
+
103
+ Examples:
104
+ Add a computed column that applies the model `ms-marco-MiniLM-L-4-v2` to the sentences in
105
+ columns `tbl.sentence1` and `tbl.sentence2`:
106
+
107
+ >>> tbl['result'] = cross_encoder(
108
+ tbl.sentence1, tbl.sentence2, model_id='ms-marco-MiniLM-L-4-v2'
109
+ )
110
+ """
53
111
  env.Env.get().require_package('sentence_transformers')
54
112
  from sentence_transformers import CrossEncoder
55
113