pixeltable 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/insertable_table.py +2 -2
- pixeltable/catalog/table.py +5 -5
- pixeltable/catalog/table_version.py +12 -14
- pixeltable/catalog/view.py +2 -2
- pixeltable/dataframe.py +7 -6
- pixeltable/exec/expr_eval_node.py +8 -1
- pixeltable/exec/sql_scan_node.py +1 -1
- pixeltable/exprs/__init__.py +0 -1
- pixeltable/exprs/comparison.py +5 -5
- pixeltable/exprs/compound_predicate.py +12 -12
- pixeltable/exprs/expr.py +32 -0
- pixeltable/exprs/in_predicate.py +3 -3
- pixeltable/exprs/is_null.py +5 -5
- pixeltable/func/aggregate_function.py +10 -4
- pixeltable/func/callable_function.py +4 -0
- pixeltable/func/function_registry.py +2 -0
- pixeltable/functions/globals.py +36 -1
- pixeltable/functions/huggingface.py +62 -4
- pixeltable/functions/image.py +17 -0
- pixeltable/functions/string.py +622 -7
- pixeltable/functions/video.py +26 -8
- pixeltable/globals.py +3 -3
- pixeltable/io/globals.py +53 -4
- pixeltable/io/label_studio.py +42 -2
- pixeltable/io/pandas.py +18 -7
- pixeltable/plan.py +6 -6
- pixeltable/tool/create_test_db_dump.py +1 -1
- pixeltable/tool/doc_plugins/griffe.py +77 -0
- pixeltable/tool/doc_plugins/mkdocstrings.py +6 -0
- pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +135 -0
- pixeltable/utils/s3.py +1 -1
- pixeltable-0.2.13.dist-info/METADATA +206 -0
- {pixeltable-0.2.12.dist-info → pixeltable-0.2.13.dist-info}/RECORD +37 -34
- pixeltable-0.2.13.dist-info/entry_points.txt +3 -0
- pixeltable/exprs/predicate.py +0 -44
- pixeltable-0.2.12.dist-info/METADATA +0 -137
- {pixeltable-0.2.12.dist-info → pixeltable-0.2.13.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.12.dist-info → pixeltable-0.2.13.dist-info}/WHEEL +0 -0
pixeltable/__version__.py
CHANGED
|
@@ -1,3 +1,3 @@
|
|
|
1
1
|
# These version placeholders will be replaced during build.
|
|
2
|
-
__version__ = "0.2.12"
|
|
3
|
-
__version_tuple__ = (0, 2, 12)
|
|
2
|
+
__version__ = "0.2.13"
|
|
3
|
+
__version_tuple__ = (0, 2, 13)
|
|
@@ -129,11 +129,11 @@ class InsertableTable(Table):
|
|
|
129
129
|
msg = str(e)
|
|
130
130
|
raise excs.Error(f'Error in column {col.name}: {msg[0].lower() + msg[1:]}\nRow: {row}')
|
|
131
131
|
|
|
132
|
-
def delete(self, where: Optional['pixeltable.exprs.
|
|
132
|
+
def delete(self, where: Optional['pixeltable.exprs.Expr'] = None) -> UpdateStatus:
|
|
133
133
|
"""Delete rows in this table.
|
|
134
134
|
|
|
135
135
|
Args:
|
|
136
|
-
where: a
|
|
136
|
+
where: a predicate to filter rows to delete.
|
|
137
137
|
|
|
138
138
|
Examples:
|
|
139
139
|
Delete all rows in a table:
|
pixeltable/catalog/table.py
CHANGED
|
@@ -113,7 +113,7 @@ class Table(SchemaObject):
|
|
|
113
113
|
from pixeltable.dataframe import DataFrame
|
|
114
114
|
return DataFrame(self._tbl_version_path).select(*items, **named_items)
|
|
115
115
|
|
|
116
|
-
def where(self, pred: 'exprs.
|
|
116
|
+
def where(self, pred: 'exprs.Expr') -> 'pixeltable.dataframe.DataFrame':
|
|
117
117
|
"""Return a DataFrame for this table.
|
|
118
118
|
"""
|
|
119
119
|
# local import: avoid circular imports
|
|
@@ -716,13 +716,13 @@ class Table(SchemaObject):
|
|
|
716
716
|
raise NotImplementedError
|
|
717
717
|
|
|
718
718
|
def update(
|
|
719
|
-
self, value_spec: dict[str, Any], where: Optional['pixeltable.exprs.
|
|
719
|
+
self, value_spec: dict[str, Any], where: Optional['pixeltable.exprs.Expr'] = None, cascade: bool = True
|
|
720
720
|
) -> UpdateStatus:
|
|
721
721
|
"""Update rows in this table.
|
|
722
722
|
|
|
723
723
|
Args:
|
|
724
724
|
value_spec: a dictionary mapping column names to literal values or Pixeltable expressions.
|
|
725
|
-
where: a
|
|
725
|
+
where: a predicate to filter rows to update.
|
|
726
726
|
cascade: if True, also update all computed columns that transitively depend on the updated columns.
|
|
727
727
|
|
|
728
728
|
Examples:
|
|
@@ -786,11 +786,11 @@ class Table(SchemaObject):
|
|
|
786
786
|
row_updates.append(col_vals)
|
|
787
787
|
return self._tbl_version.batch_update(row_updates, rowids, cascade)
|
|
788
788
|
|
|
789
|
-
def delete(self, where: Optional['pixeltable.exprs.
|
|
789
|
+
def delete(self, where: Optional['pixeltable.exprs.Expr'] = None) -> UpdateStatus:
|
|
790
790
|
"""Delete rows in this table.
|
|
791
791
|
|
|
792
792
|
Args:
|
|
793
|
-
where: a
|
|
793
|
+
where: a predicate to filter rows to delete.
|
|
794
794
|
|
|
795
795
|
Examples:
|
|
796
796
|
Delete all rows in a table:
|
|
@@ -678,12 +678,12 @@ class TableVersion:
|
|
|
678
678
|
return result
|
|
679
679
|
|
|
680
680
|
def update(
|
|
681
|
-
self, value_spec: dict[str, Any], where: Optional['exprs.
|
|
681
|
+
self, value_spec: dict[str, Any], where: Optional['exprs.Expr'] = None, cascade: bool = True
|
|
682
682
|
) -> UpdateStatus:
|
|
683
683
|
"""Update rows in this TableVersionPath.
|
|
684
684
|
Args:
|
|
685
685
|
value_spec: a list of (column, value) pairs specifying the columns to update and their new values.
|
|
686
|
-
where: a
|
|
686
|
+
where: a predicate to filter rows to update.
|
|
687
687
|
cascade: if True, also update all computed columns that transitively depend on the updated columns,
|
|
688
688
|
including within views.
|
|
689
689
|
"""
|
|
@@ -694,8 +694,8 @@ class TableVersion:
|
|
|
694
694
|
|
|
695
695
|
update_spec = self._validate_update_spec(value_spec, allow_pk=False, allow_exprs=True)
|
|
696
696
|
if where is not None:
|
|
697
|
-
if not isinstance(where, exprs.Predicate):
|
|
698
|
-
raise excs.Error(f"'where' argument must be a
|
|
697
|
+
if not isinstance(where, exprs.Expr):
|
|
698
|
+
raise excs.Error(f"'where' argument must be a predicate, got {type(where)}")
|
|
699
699
|
analysis_info = Planner.analyze(self.path, where)
|
|
700
700
|
# for now we require that the updated rows can be identified via SQL, rather than via a Python filter
|
|
701
701
|
if analysis_info.filter is not None:
|
|
@@ -757,7 +757,7 @@ class TableVersion:
|
|
|
757
757
|
|
|
758
758
|
def _update(
|
|
759
759
|
self, conn: sql.engine.Connection, update_targets: dict[Column, 'pixeltable.exprs.Expr'],
|
|
760
|
-
where_clause: Optional['pixeltable.exprs.
|
|
760
|
+
where_clause: Optional['pixeltable.exprs.Expr'] = None, cascade: bool = True,
|
|
761
761
|
show_progress: bool = True
|
|
762
762
|
) -> UpdateStatus:
|
|
763
763
|
from pixeltable.plan import Planner
|
|
@@ -789,8 +789,6 @@ class TableVersion:
|
|
|
789
789
|
raise excs.Error(f'Column {col_name} is computed and cannot be updated')
|
|
790
790
|
if col.is_pk and not allow_pk:
|
|
791
791
|
raise excs.Error(f'Column {col_name} is a primary key column and cannot be updated')
|
|
792
|
-
if col.col_type.is_media_type():
|
|
793
|
-
raise excs.Error(f'Column {col_name} has type image/video/audio/document and cannot be updated')
|
|
794
792
|
|
|
795
793
|
# make sure that the value is compatible with the column type
|
|
796
794
|
try:
|
|
@@ -848,17 +846,17 @@ class TableVersion:
|
|
|
848
846
|
result.cols_with_excs = list(dict.fromkeys(result.cols_with_excs).keys()) # remove duplicates
|
|
849
847
|
return result
|
|
850
848
|
|
|
851
|
-
def delete(self, where: Optional['exprs.
|
|
849
|
+
def delete(self, where: Optional['exprs.Expr'] = None) -> UpdateStatus:
|
|
852
850
|
"""Delete rows in this table.
|
|
853
851
|
Args:
|
|
854
|
-
where: a
|
|
852
|
+
where: a predicate to filter rows to delete.
|
|
855
853
|
"""
|
|
856
854
|
assert self.is_insertable()
|
|
857
|
-
from pixeltable.exprs import Predicate
|
|
855
|
+
from pixeltable.exprs import Expr
|
|
858
856
|
from pixeltable.plan import Planner
|
|
859
857
|
if where is not None:
|
|
860
|
-
if not isinstance(where, Predicate):
|
|
861
|
-
raise excs.Error(f"'where' argument must be a
|
|
858
|
+
if not isinstance(where, Expr):
|
|
859
|
+
raise excs.Error(f"'where' argument must be a predicate, got {type(where)}")
|
|
862
860
|
analysis_info = Planner.analyze(self.path, where)
|
|
863
861
|
# for now we require that the updated rows can be identified via SQL, rather than via a Python filter
|
|
864
862
|
if analysis_info.filter is not None:
|
|
@@ -872,11 +870,11 @@ class TableVersion:
|
|
|
872
870
|
return status
|
|
873
871
|
|
|
874
872
|
def propagate_delete(
|
|
875
|
-
self, where: Optional['exprs.
|
|
873
|
+
self, where: Optional['exprs.Expr'], base_versions: List[Optional[int]],
|
|
876
874
|
conn: sql.engine.Connection, timestamp: float) -> int:
|
|
877
875
|
"""Delete rows in this table and propagate to views.
|
|
878
876
|
Args:
|
|
879
|
-
where: a
|
|
877
|
+
where: a predicate to filter rows to delete.
|
|
880
878
|
Returns:
|
|
881
879
|
number of deleted rows
|
|
882
880
|
"""
|
pixeltable/catalog/view.py
CHANGED
|
@@ -51,7 +51,7 @@ class View(Table):
|
|
|
51
51
|
@classmethod
|
|
52
52
|
def create(
|
|
53
53
|
cls, dir_id: UUID, name: str, base: TableVersionPath, schema: Dict[str, Any],
|
|
54
|
-
predicate: 'pxt.exprs.
|
|
54
|
+
predicate: 'pxt.exprs.Expr', is_snapshot: bool, num_retained_versions: int, comment: str,
|
|
55
55
|
iterator_cls: Optional[Type[ComponentIterator]], iterator_args: Optional[Dict]
|
|
56
56
|
) -> View:
|
|
57
57
|
columns = cls._create_columns(schema)
|
|
@@ -213,5 +213,5 @@ class View(Table):
|
|
|
213
213
|
) -> UpdateStatus:
|
|
214
214
|
raise excs.Error(f'{self.display_name()} {self._name!r}: cannot insert into view')
|
|
215
215
|
|
|
216
|
-
def delete(self, where: Optional['pixeltable.exprs.
|
|
216
|
+
def delete(self, where: Optional['pixeltable.exprs.Expr'] = None) -> UpdateStatus:
|
|
217
217
|
raise excs.Error(f'{self.display_name()} {self._name!r}: cannot delete from view')
|
pixeltable/dataframe.py
CHANGED
|
@@ -153,7 +153,7 @@ class DataFrame:
|
|
|
153
153
|
self,
|
|
154
154
|
tbl: catalog.TableVersionPath,
|
|
155
155
|
select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]] = None,
|
|
156
|
-
where_clause: Optional[exprs.
|
|
156
|
+
where_clause: Optional[exprs.Expr] = None,
|
|
157
157
|
group_by_clause: Optional[List[exprs.Expr]] = None,
|
|
158
158
|
grouping_tbl: Optional[catalog.TableVersion] = None,
|
|
159
159
|
order_by_clause: Optional[List[Tuple[exprs.Expr, bool]]] = None, # List[(expr, asc)]
|
|
@@ -530,7 +530,11 @@ class DataFrame:
|
|
|
530
530
|
limit=self.limit_val,
|
|
531
531
|
)
|
|
532
532
|
|
|
533
|
-
def where(self, pred: exprs.Predicate) -> DataFrame:
|
|
533
|
+
def where(self, pred: exprs.Expr) -> DataFrame:
|
|
534
|
+
if not isinstance(pred, exprs.Expr):
|
|
535
|
+
raise excs.Error(f'Where() requires a Pixeltable expression, but instead got {type(pred)}')
|
|
536
|
+
if not pred.col_type.is_bool_type():
|
|
537
|
+
raise excs.Error(f'Where(): expression needs to return bool, but instead returns {pred.col_type}')
|
|
534
538
|
return DataFrame(
|
|
535
539
|
self.tbl,
|
|
536
540
|
select_list=self.select_list,
|
|
@@ -628,12 +632,9 @@ class DataFrame:
|
|
|
628
632
|
def __getitem__(self, index: object) -> DataFrame:
|
|
629
633
|
"""
|
|
630
634
|
Allowed:
|
|
631
|
-
- [<Predicate>]: filter operation
|
|
632
635
|
- [List[Expr]]/[Tuple[Expr]]: setting the select list
|
|
633
636
|
- [Expr]: setting a single-col select list
|
|
634
637
|
"""
|
|
635
|
-
if isinstance(index, exprs.Predicate):
|
|
636
|
-
return self.where(index)
|
|
637
638
|
if isinstance(index, tuple):
|
|
638
639
|
index = list(index)
|
|
639
640
|
if isinstance(index, exprs.Expr):
|
|
@@ -668,7 +669,7 @@ class DataFrame:
|
|
|
668
669
|
tbl = catalog.TableVersionPath.from_dict(d['tbl'])
|
|
669
670
|
select_list = [(exprs.Expr.from_dict(e), name) for e, name in d['select_list']] \
|
|
670
671
|
if d['select_list'] is not None else None
|
|
671
|
-
where_clause = exprs.
|
|
672
|
+
where_clause = exprs.Expr.from_dict(d['where_clause']) \
|
|
672
673
|
if d['where_clause'] is not None else None
|
|
673
674
|
group_by_clause = [exprs.Expr.from_dict(e) for e in d['group_by_clause']] \
|
|
674
675
|
if d['group_by_clause'] is not None else None
|
|
@@ -50,7 +50,14 @@ class ExprEvalNode(ExecNode):
|
|
|
50
50
|
|
|
51
51
|
def _open(self) -> None:
|
|
52
52
|
warnings.simplefilter("ignore", category=TqdmWarning)
|
|
53
|
-
|
|
53
|
+
# This is a temporary hack. When B-tree indices on string columns were implemented (via computed columns
|
|
54
|
+
# that invoke the `BtreeIndex.str_filter` udf), it resulted in frivolous progress bars appearing on every
|
|
55
|
+
# insertion. This special-cases the `str_filter` call to suppress the corresponding progress bar.
|
|
56
|
+
# TODO(aaron-siegel) Remove this hack once we clean up progress bars more generally.
|
|
57
|
+
is_str_filter_node = all(
|
|
58
|
+
isinstance(expr, exprs.FunctionCall) and expr.fn.name == 'str_filter' for expr in self.output_exprs
|
|
59
|
+
)
|
|
60
|
+
if self.ctx.show_pbar and not is_str_filter_node:
|
|
54
61
|
self.pbar = tqdm(
|
|
55
62
|
total=len(self.target_exprs) * self.ctx.num_rows,
|
|
56
63
|
desc='Computing cells',
|
pixeltable/exec/sql_scan_node.py
CHANGED
|
@@ -19,7 +19,7 @@ class SqlScanNode(ExecNode):
|
|
|
19
19
|
def __init__(
|
|
20
20
|
self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
|
|
21
21
|
select_list: Iterable[exprs.Expr],
|
|
22
|
-
where_clause: Optional[exprs.Expr] = None, filter: Optional[exprs.
|
|
22
|
+
where_clause: Optional[exprs.Expr] = None, filter: Optional[exprs.Expr] = None,
|
|
23
23
|
order_by_items: Optional[List[Tuple[exprs.Expr, bool]]] = None,
|
|
24
24
|
limit: int = 0, set_pk: bool = False, exact_version_only: Optional[List[catalog.TableVersion]] = None
|
|
25
25
|
):
|
pixeltable/exprs/__init__.py
CHANGED
|
@@ -17,7 +17,6 @@ from .json_mapper import JsonMapper
|
|
|
17
17
|
from .json_path import RELATIVE_PATH_ROOT, JsonPath
|
|
18
18
|
from .literal import Literal
|
|
19
19
|
from .object_ref import ObjectRef
|
|
20
|
-
from .predicate import Predicate
|
|
21
20
|
from .row_builder import RowBuilder, ColumnSlotIdx, ExecProfile
|
|
22
21
|
from .rowid_ref import RowidRef
|
|
23
22
|
from .similarity_expr import SimilarityExpr
|
pixeltable/exprs/comparison.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from typing import Optional, List, Any, Dict
|
|
3
|
+
from typing import Optional, List, Any, Dict
|
|
4
4
|
|
|
5
5
|
import sqlalchemy as sql
|
|
6
6
|
|
|
@@ -9,15 +9,15 @@ from .data_row import DataRow
|
|
|
9
9
|
from .expr import Expr
|
|
10
10
|
from .globals import ComparisonOperator
|
|
11
11
|
from .literal import Literal
|
|
12
|
-
from .predicate import Predicate
|
|
13
12
|
from .row_builder import RowBuilder
|
|
14
13
|
import pixeltable.exceptions as excs
|
|
15
14
|
import pixeltable.index as index
|
|
15
|
+
import pixeltable.type_system as ts
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
class Comparison(Predicate):
|
|
18
|
+
class Comparison(Expr):
|
|
19
19
|
def __init__(self, operator: ComparisonOperator, op1: Expr, op2: Expr):
|
|
20
|
-
super().__init__()
|
|
20
|
+
super().__init__(ts.BoolType())
|
|
21
21
|
self.operator = operator
|
|
22
22
|
|
|
23
23
|
# if this is a comparison of a column to a literal (ie, could be used as a search argument in an index lookup),
|
|
@@ -50,7 +50,7 @@ class Comparison(Predicate):
|
|
|
50
50
|
def _equals(self, other: Comparison) -> bool:
|
|
51
51
|
return self.operator == other.operator
|
|
52
52
|
|
|
53
|
-
def _id_attrs(self) ->
|
|
53
|
+
def _id_attrs(self) -> list[tuple[str, Any]]:
|
|
54
54
|
return super()._id_attrs() + [('operator', self.operator.value)]
|
|
55
55
|
|
|
56
56
|
@property
|
|
@@ -1,20 +1,20 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
2
|
+
|
|
3
3
|
import operator
|
|
4
|
+
from typing import Optional, List, Any, Dict, Callable
|
|
4
5
|
|
|
5
6
|
import sqlalchemy as sql
|
|
6
7
|
|
|
8
|
+
from .data_row import DataRow
|
|
7
9
|
from .expr import Expr
|
|
8
10
|
from .globals import LogicalOperator
|
|
9
|
-
from .predicate import Predicate
|
|
10
|
-
from .data_row import DataRow
|
|
11
11
|
from .row_builder import RowBuilder
|
|
12
|
-
import pixeltable.
|
|
12
|
+
import pixeltable.type_system as ts
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
class CompoundPredicate(Predicate):
|
|
16
|
-
def __init__(self, operator: LogicalOperator, operands: List[
|
|
17
|
-
super().__init__()
|
|
15
|
+
class CompoundPredicate(Expr):
|
|
16
|
+
def __init__(self, operator: LogicalOperator, operands: List[Expr]):
|
|
17
|
+
super().__init__(ts.BoolType())
|
|
18
18
|
self.operator = operator
|
|
19
19
|
# operands are stored in self.components
|
|
20
20
|
if self.operator == LogicalOperator.NOT:
|
|
@@ -22,7 +22,7 @@ class CompoundPredicate(Predicate):
|
|
|
22
22
|
self.components = operands
|
|
23
23
|
else:
|
|
24
24
|
assert len(operands) > 1
|
|
25
|
-
self.operands: List[
|
|
25
|
+
self.operands: List[Expr] = []
|
|
26
26
|
for operand in operands:
|
|
27
27
|
self._merge_operand(operand)
|
|
28
28
|
|
|
@@ -34,14 +34,14 @@ class CompoundPredicate(Predicate):
|
|
|
34
34
|
return f' {self.operator} '.join([f'({e})' for e in self.components])
|
|
35
35
|
|
|
36
36
|
@classmethod
|
|
37
|
-
def make_conjunction(cls, operands: List[
|
|
37
|
+
def make_conjunction(cls, operands: List[Expr]) -> Optional[Expr]:
|
|
38
38
|
if len(operands) == 0:
|
|
39
39
|
return None
|
|
40
40
|
if len(operands) == 1:
|
|
41
41
|
return operands[0]
|
|
42
42
|
return CompoundPredicate(LogicalOperator.AND, operands)
|
|
43
43
|
|
|
44
|
-
def _merge_operand(self, operand:
|
|
44
|
+
def _merge_operand(self, operand: Expr) -> None:
|
|
45
45
|
"""
|
|
46
46
|
Merge this operand, if possible, otherwise simply record it.
|
|
47
47
|
"""
|
|
@@ -55,11 +55,11 @@ class CompoundPredicate(Predicate):
|
|
|
55
55
|
def _equals(self, other: CompoundPredicate) -> bool:
|
|
56
56
|
return self.operator == other.operator
|
|
57
57
|
|
|
58
|
-
def _id_attrs(self) ->
|
|
58
|
+
def _id_attrs(self) -> list[tuple[str, Any]]:
|
|
59
59
|
return super()._id_attrs() + [('operator', self.operator.value)]
|
|
60
60
|
|
|
61
61
|
def split_conjuncts(
|
|
62
|
-
self, condition: Callable[[
|
|
62
|
+
self, condition: Callable[[Expr], bool]) -> tuple[list[Expr], Optional[Expr]]:
|
|
63
63
|
if self.operator == LogicalOperator.OR or self.operator == LogicalOperator.NOT:
|
|
64
64
|
return super().split_conjuncts(condition)
|
|
65
65
|
matches = [op for op in self.components if condition(op)]
|
pixeltable/exprs/expr.py
CHANGED
|
@@ -518,6 +518,38 @@ class Expr(abc.ABC):
|
|
|
518
518
|
return ArithmeticExpr(op, self, Literal(other)) # type: ignore[arg-type]
|
|
519
519
|
raise TypeError(f'Other must be Expr or literal: {type(other)}')
|
|
520
520
|
|
|
521
|
+
def __and__(self, other: object) -> Expr:
|
|
522
|
+
if not isinstance(other, Expr):
|
|
523
|
+
raise TypeError(f'Other needs to be an expression: {type(other)}')
|
|
524
|
+
if not other.col_type.is_bool_type():
|
|
525
|
+
raise TypeError(f'Other needs to be an expression that returns a boolean: {other.col_type}')
|
|
526
|
+
from .compound_predicate import CompoundPredicate
|
|
527
|
+
return CompoundPredicate(LogicalOperator.AND, [self, other])
|
|
528
|
+
|
|
529
|
+
def __or__(self, other: object) -> Expr:
|
|
530
|
+
if not isinstance(other, Expr):
|
|
531
|
+
raise TypeError(f'Other needs to be an expression: {type(other)}')
|
|
532
|
+
if not other.col_type.is_bool_type():
|
|
533
|
+
raise TypeError(f'Other needs to be an expression that returns a boolean: {other.col_type}')
|
|
534
|
+
from .compound_predicate import CompoundPredicate
|
|
535
|
+
return CompoundPredicate(LogicalOperator.OR, [self, other])
|
|
536
|
+
|
|
537
|
+
def __invert__(self) -> Expr:
|
|
538
|
+
from .compound_predicate import CompoundPredicate
|
|
539
|
+
return CompoundPredicate(LogicalOperator.NOT, [self])
|
|
540
|
+
|
|
541
|
+
def split_conjuncts(
|
|
542
|
+
self, condition: Callable[[Expr], bool]) -> tuple[list[Expr], Optional[Expr]]:
|
|
543
|
+
"""
|
|
544
|
+
Returns clauses of a conjunction that meet condition in the first element.
|
|
545
|
+
The second element contains remaining clauses, rolled into a conjunction.
|
|
546
|
+
"""
|
|
547
|
+
assert self.col_type.is_bool_type() # only valid for predicates
|
|
548
|
+
if condition(self):
|
|
549
|
+
return [self], None
|
|
550
|
+
else:
|
|
551
|
+
return [], self
|
|
552
|
+
|
|
521
553
|
def _make_applicator_function(self, fn: Callable, col_type: Optional[ts.ColumnType]) -> 'pixeltable.func.Function':
|
|
522
554
|
"""
|
|
523
555
|
Creates a unary pixeltable `Function` that encapsulates a python `Callable`. The result type of
|
pixeltable/exprs/in_predicate.py
CHANGED
|
@@ -5,20 +5,20 @@ from typing import Optional, List, Any, Dict, Tuple, Iterable
|
|
|
5
5
|
import sqlalchemy as sql
|
|
6
6
|
|
|
7
7
|
import pixeltable.exceptions as excs
|
|
8
|
+
import pixeltable.type_system as ts
|
|
8
9
|
from .data_row import DataRow
|
|
9
10
|
from .expr import Expr
|
|
10
|
-
from .predicate import Predicate
|
|
11
11
|
from .row_builder import RowBuilder
|
|
12
12
|
|
|
13
13
|
|
|
14
|
-
class InPredicate(Predicate):
|
|
14
|
+
class InPredicate(Expr):
|
|
15
15
|
"""Predicate corresponding to the SQL IN operator."""
|
|
16
16
|
|
|
17
17
|
def __init__(self, lhs: Expr, value_set_literal: Optional[Iterable] = None, value_set_expr: Optional[Expr] = None):
|
|
18
18
|
assert (value_set_literal is None) != (value_set_expr is None)
|
|
19
19
|
if not lhs.col_type.is_scalar_type():
|
|
20
20
|
raise excs.Error(f'isin(): only supported for scalar types, not {lhs.col_type}')
|
|
21
|
-
super().__init__()
|
|
21
|
+
super().__init__(ts.BoolType())
|
|
22
22
|
|
|
23
23
|
self.value_list: Optional[list] = None # only contains values of the correct type
|
|
24
24
|
if value_set_expr is not None:
|
pixeltable/exprs/is_null.py
CHANGED
|
@@ -1,18 +1,18 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
from typing import Optional, List, Dict
|
|
3
4
|
|
|
4
5
|
import sqlalchemy as sql
|
|
5
6
|
|
|
6
|
-
|
|
7
|
-
from .expr import Expr
|
|
7
|
+
import pixeltable.type_system as ts
|
|
8
8
|
from .data_row import DataRow
|
|
9
|
+
from .expr import Expr
|
|
9
10
|
from .row_builder import RowBuilder
|
|
10
|
-
import pixeltable.catalog as catalog
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
class IsNull(
|
|
13
|
+
class IsNull(Expr):
|
|
14
14
|
def __init__(self, e: Expr):
|
|
15
|
-
super().__init__()
|
|
15
|
+
super().__init__(ts.BoolType())
|
|
16
16
|
self.components = [e]
|
|
17
17
|
self.id = self._create_id()
|
|
18
18
|
|
|
@@ -1,16 +1,18 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import abc
|
|
4
|
-
import importlib
|
|
5
4
|
import inspect
|
|
6
|
-
from typing import
|
|
7
|
-
import itertools
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type
|
|
8
6
|
|
|
9
7
|
import pixeltable.exceptions as excs
|
|
10
8
|
import pixeltable.type_system as ts
|
|
9
|
+
|
|
11
10
|
from .function import Function
|
|
12
|
-
from .signature import Signature, Parameter
|
|
13
11
|
from .globals import validate_symbol_path
|
|
12
|
+
from .signature import Parameter, Signature
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
import pixeltable
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
class Aggregator(abc.ABC):
|
|
@@ -40,6 +42,7 @@ class AggregateFunction(Function):
|
|
|
40
42
|
self.requires_order_by = requires_order_by
|
|
41
43
|
self.allows_std_agg = allows_std_agg
|
|
42
44
|
self.allows_window = allows_window
|
|
45
|
+
self.__doc__ = aggregator_class.__doc__
|
|
43
46
|
|
|
44
47
|
# our signature is the signature of 'update', but without self,
|
|
45
48
|
# plus the parameters of 'init' as keyword-only parameters
|
|
@@ -135,6 +138,9 @@ class AggregateFunction(Function):
|
|
|
135
138
|
f'expression'
|
|
136
139
|
)
|
|
137
140
|
|
|
141
|
+
def __repr__(self) -> str:
|
|
142
|
+
return f'<Pixeltable Aggregator {self.name}>'
|
|
143
|
+
|
|
138
144
|
|
|
139
145
|
def uda(
|
|
140
146
|
*,
|
|
@@ -25,6 +25,7 @@ class CallableFunction(Function):
|
|
|
25
25
|
self.py_fn = py_fn
|
|
26
26
|
self.self_name = self_name
|
|
27
27
|
self.batch_size = batch_size
|
|
28
|
+
self.__doc__ = py_fn.__doc__
|
|
28
29
|
super().__init__(signature, self_path=self_path)
|
|
29
30
|
|
|
30
31
|
@property
|
|
@@ -113,3 +114,6 @@ class CallableFunction(Function):
|
|
|
113
114
|
f'{self.display_name}(): '
|
|
114
115
|
f'parameter {param.name} must be a constant value, not a Pixeltable expression'
|
|
115
116
|
)
|
|
117
|
+
|
|
118
|
+
def __repr__(self) -> str:
|
|
119
|
+
return f'<Pixeltable UDF {self.name}>'
|
|
@@ -66,6 +66,8 @@ class FunctionRegistry:
|
|
|
66
66
|
# self.module_fns[fn_path] = obj
|
|
67
67
|
|
|
68
68
|
def register_function(self, fqn: str, fn: Function) -> None:
|
|
69
|
+
if fqn in self.module_fns:
|
|
70
|
+
raise excs.Error(f'A UDF with that name already exists: {fqn}')
|
|
69
71
|
self.module_fns[fqn] = fn
|
|
70
72
|
|
|
71
73
|
def list_functions(self) -> List[Function]:
|
pixeltable/functions/globals.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Union
|
|
1
|
+
from typing import Optional, Union
|
|
2
2
|
|
|
3
3
|
import pixeltable.func as func
|
|
4
4
|
import pixeltable.type_system as ts
|
|
@@ -14,6 +14,7 @@ def cast(expr: exprs.Expr, target_type: ts.ColumnType) -> exprs.Expr:
|
|
|
14
14
|
|
|
15
15
|
@func.uda(update_types=[ts.IntType()], value_type=ts.IntType(), allows_window=True, requires_order_by=False)
|
|
16
16
|
class sum(func.Aggregator):
|
|
17
|
+
"""Sums the selected integers or floats."""
|
|
17
18
|
def __init__(self):
|
|
18
19
|
self.sum: Union[int, float] = 0
|
|
19
20
|
|
|
@@ -38,6 +39,40 @@ class count(func.Aggregator):
|
|
|
38
39
|
return self.count
|
|
39
40
|
|
|
40
41
|
|
|
42
|
+
@func.uda(update_types=[ts.FloatType()], value_type=ts.FloatType(nullable=True), allows_window=True, requires_order_by=False)
|
|
43
|
+
class max(func.Aggregator):
|
|
44
|
+
def __init__(self):
|
|
45
|
+
self.val = None
|
|
46
|
+
|
|
47
|
+
def update(self, val: Optional[float]) -> None:
|
|
48
|
+
if val is not None:
|
|
49
|
+
if self.val is None:
|
|
50
|
+
self.val = val
|
|
51
|
+
else:
|
|
52
|
+
import builtins
|
|
53
|
+
self.val = builtins.max(self.val, val)
|
|
54
|
+
|
|
55
|
+
def value(self) -> Optional[float]:
|
|
56
|
+
return self.val
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@func.uda(update_types=[ts.FloatType()], value_type=ts.FloatType(nullable=True), allows_window=True, requires_order_by=False)
|
|
60
|
+
class min(func.Aggregator):
|
|
61
|
+
def __init__(self):
|
|
62
|
+
self.val = None
|
|
63
|
+
|
|
64
|
+
def update(self, val: Optional[float]) -> None:
|
|
65
|
+
if val is not None:
|
|
66
|
+
if self.val is None:
|
|
67
|
+
self.val = val
|
|
68
|
+
else:
|
|
69
|
+
import builtins
|
|
70
|
+
self.val = builtins.min(self.val, val)
|
|
71
|
+
|
|
72
|
+
def value(self) -> Optional[float]:
|
|
73
|
+
return self.val
|
|
74
|
+
|
|
75
|
+
|
|
41
76
|
@func.uda(update_types=[ts.IntType()], value_type=ts.FloatType(), allows_window=False, requires_order_by=False)
|
|
42
77
|
class mean(func.Aggregator):
|
|
43
78
|
def __init__(self):
|
|
@@ -1,3 +1,12 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs)
|
|
3
|
+
that wrap various models from the Hugging Face `transformers` package.
|
|
4
|
+
|
|
5
|
+
These UDFs will cause Pixeltable to invoke the relevant models locally. In order to use them, you must
|
|
6
|
+
first `pip install transformers` (or in some cases, `sentence-transformers`, as noted in the specific
|
|
7
|
+
UDFs).
|
|
8
|
+
"""
|
|
9
|
+
|
|
1
10
|
from typing import Callable, TypeVar, Optional, Any
|
|
2
11
|
|
|
3
12
|
import PIL.Image
|
|
@@ -13,15 +22,39 @@ from pixeltable.utils.code import local_public_names
|
|
|
13
22
|
|
|
14
23
|
@pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType()))
|
|
15
24
|
def sentence_transformer(
|
|
16
|
-
|
|
25
|
+
sentence: Batch[str], *, model_id: str, normalize_embeddings: bool = False
|
|
17
26
|
) -> Batch[np.ndarray]:
|
|
18
|
-
"""
|
|
27
|
+
"""
|
|
28
|
+
Runs the specified pretrained sentence-transformers model. `model_id` should be a pretrained model, as described
|
|
29
|
+
in the [Sentence Transformers Pretrained Models](https://sbert.net/docs/sentence_transformer/pretrained_models.html)
|
|
30
|
+
documentation.
|
|
31
|
+
|
|
32
|
+
__Requirements:__
|
|
33
|
+
|
|
34
|
+
- `pip install sentence-transformers`
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
sentence: The sentence to embed.
|
|
38
|
+
model_id: The pretrained model to use for the encoding.
|
|
39
|
+
normalize_embeddings: If `True`, normalizes embeddings to length 1; see the
|
|
40
|
+
[Sentence Transformers API Docs](https://sbert.net/docs/package_reference/sentence_transformer/SentenceTransformer.html)
|
|
41
|
+
for more details
|
|
42
|
+
|
|
43
|
+
Returns:
|
|
44
|
+
An array containing the output of the embedding model.
|
|
45
|
+
|
|
46
|
+
Examples:
|
|
47
|
+
Add a computed column that applies the model `all-mpnet-base-v2` to an existing Pixeltable column `tbl.sentence`
|
|
48
|
+
of the table `tbl`:
|
|
49
|
+
|
|
50
|
+
>>> tbl['result'] = sentence_transformer(tbl.sentence, model_id='all-mpnet-base-v2')
|
|
51
|
+
"""
|
|
19
52
|
env.Env.get().require_package('sentence_transformers')
|
|
20
53
|
from sentence_transformers import SentenceTransformer
|
|
21
54
|
|
|
22
55
|
model = _lookup_model(model_id, SentenceTransformer)
|
|
23
56
|
|
|
24
|
-
array = model.encode(
|
|
57
|
+
array = model.encode(sentence, normalize_embeddings=normalize_embeddings)
|
|
25
58
|
return [array[i] for i in range(array.shape[0])]
|
|
26
59
|
|
|
27
60
|
|
|
@@ -49,7 +82,32 @@ def sentence_transformer_list(sentences: list, *, model_id: str, normalize_embed
|
|
|
49
82
|
|
|
50
83
|
@pxt.udf(batch_size=32)
|
|
51
84
|
def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: str) -> Batch[float]:
|
|
52
|
-
"""
|
|
85
|
+
"""
|
|
86
|
+
Runs the specified cross-encoder model to compute similarity scores for pairs of sentences.
|
|
87
|
+
`model_id` should be a pretrained model, as described in the
|
|
88
|
+
[Cross-Encoder Pretrained Models](https://www.sbert.net/docs/cross_encoder/pretrained_models.html)
|
|
89
|
+
documentation.
|
|
90
|
+
|
|
91
|
+
__Requirements:__
|
|
92
|
+
|
|
93
|
+
- `pip install sentence-transformers`
|
|
94
|
+
|
|
95
|
+
Parameters:
|
|
96
|
+
sentences1: The first sentence to be paired.
|
|
97
|
+
sentences2: The second sentence to be paired.
|
|
98
|
+
model_id: The identifier of the cross-encoder model to use.
|
|
99
|
+
|
|
100
|
+
Returns:
|
|
101
|
+
The similarity score between the inputs.
|
|
102
|
+
|
|
103
|
+
Examples:
|
|
104
|
+
Add a computed column that applies the model `ms-marco-MiniLM-L-4-v2` to the sentences in
|
|
105
|
+
columns `tbl.sentence1` and `tbl.sentence2`:
|
|
106
|
+
|
|
107
|
+
>>> tbl['result'] = cross_encoder(
|
|
108
|
+
tbl.sentence1, tbl.sentence2, model_id='ms-marco-MiniLM-L-4-v2'
|
|
109
|
+
)
|
|
110
|
+
"""
|
|
53
111
|
env.Env.get().require_package('sentence_transformers')
|
|
54
112
|
from sentence_transformers import CrossEncoder
|
|
55
113
|
|