pixeltable 0.2.22__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- pixeltable/__init__.py +2 -2
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/column.py +8 -22
- pixeltable/catalog/insertable_table.py +26 -8
- pixeltable/catalog/table.py +179 -83
- pixeltable/catalog/table_version.py +13 -39
- pixeltable/catalog/table_version_path.py +2 -2
- pixeltable/catalog/view.py +2 -2
- pixeltable/dataframe.py +20 -28
- pixeltable/env.py +2 -0
- pixeltable/exec/cache_prefetch_node.py +189 -43
- pixeltable/exec/data_row_batch.py +3 -3
- pixeltable/exec/exec_context.py +2 -2
- pixeltable/exec/exec_node.py +2 -2
- pixeltable/exec/expr_eval_node.py +8 -8
- pixeltable/exprs/arithmetic_expr.py +9 -4
- pixeltable/exprs/column_ref.py +4 -0
- pixeltable/exprs/comparison.py +5 -0
- pixeltable/exprs/json_path.py +1 -1
- pixeltable/func/aggregate_function.py +8 -8
- pixeltable/func/expr_template_function.py +6 -5
- pixeltable/func/udf.py +6 -11
- pixeltable/functions/huggingface.py +136 -25
- pixeltable/functions/llama_cpp.py +3 -2
- pixeltable/functions/mistralai.py +1 -1
- pixeltable/functions/openai.py +1 -1
- pixeltable/functions/together.py +1 -1
- pixeltable/functions/util.py +5 -2
- pixeltable/globals.py +55 -6
- pixeltable/plan.py +1 -1
- pixeltable/tool/create_test_db_dump.py +1 -1
- pixeltable/type_system.py +83 -35
- pixeltable/utils/coco.py +5 -5
- pixeltable/utils/formatter.py +3 -3
- pixeltable/utils/s3.py +6 -3
- {pixeltable-0.2.22.dist-info → pixeltable-0.2.23.dist-info}/METADATA +119 -46
- {pixeltable-0.2.22.dist-info → pixeltable-0.2.23.dist-info}/RECORD +40 -40
- {pixeltable-0.2.22.dist-info → pixeltable-0.2.23.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.22.dist-info → pixeltable-0.2.23.dist-info}/WHEEL +0 -0
- {pixeltable-0.2.22.dist-info → pixeltable-0.2.23.dist-info}/entry_points.txt +0 -0
pixeltable/exec/expr_eval_node.py CHANGED

@@ -3,7 +3,7 @@ import sys
 import time
 import warnings
 from dataclasses import dataclass
-from typing import Iterable,
+from typing import Iterable, Optional

 from tqdm import TqdmWarning, tqdm

@@ -22,10 +22,10 @@ class ExprEvalNode(ExecNode):
     @dataclass
     class Cohort:
         """List of exprs that form an evaluation context and contain calls to at most one external function"""
-        exprs_:
+        exprs_: list[exprs.Expr]
         batched_fn: Optional[CallableFunction]
-        segment_ctxs:
-        target_slot_idxs:
+        segment_ctxs: list['exprs.RowBuilder.EvalCtx']
+        target_slot_idxs: list[int]
         batch_size: int = 8

     def __init__(
@@ -38,7 +38,7 @@ class ExprEvalNode(ExecNode):
         # we're only materializing exprs that are not already in the input
         self.target_exprs = [e for e in output_exprs if e.slot_idx not in input_slot_idxs]
         self.pbar: Optional[tqdm] = None
-        self.cohorts:
+        self.cohorts: list[ExprEvalNode.Cohort] = []
         self._create_cohorts()

     def __next__(self) -> DataRowBatch:
@@ -83,7 +83,7 @@ class ExprEvalNode(ExecNode):
         all_exprs = self.row_builder.get_dependencies(self.target_exprs)
         # break up all_exprs into cohorts such that each cohort contains calls to at most one external function;
         # seed the cohorts with only the ext fn calls
-        cohorts:
+        cohorts: list[list[exprs.Expr]] = []
         current_batched_fn: Optional[CallableFunction] = None
         for e in all_exprs:
             if not self._is_batched_fn_call(e):
@@ -100,7 +100,7 @@
         # cohorts are evaluated in order, so we can exclude the target slots from preceding cohorts and input slots
         exclude = set(e.slot_idx for e in self.input_exprs)
         all_target_slot_idxs = set(e.slot_idx for e in self.target_exprs)
-        target_slot_idxs:
+        target_slot_idxs: list[list[int]] = []  # the ones materialized by each cohort
         for i in range(len(cohorts)):
             cohorts[i] = self.row_builder.get_dependencies(
                 cohorts[i], exclude=[self.row_builder.unique_exprs[slot_idx] for slot_idx in exclude])
@@ -171,7 +171,7 @@ class ExprEvalNode(ExecNode):
         arg_batches: list[list[exprs.Expr]] = [[] for _ in range(len(fn_call.args))]
         kwarg_batches: dict[str, list[exprs.Expr]] = {k: [] for k in fn_call.kwargs.keys()}

-        valid_batch_idxs:
+        valid_batch_idxs: list[int] = []  # rows with exceptions are not valid
         for row_idx in range(batch_start_idx, batch_start_idx + num_batch_rows):
             row = rows[row_idx]
             if row.has_exc(fn_call.slot_idx):
pixeltable/exprs/arithmetic_expr.py CHANGED

@@ -69,11 +69,15 @@ class ArithmeticExpr(Expr):
             return left * right
         if self.operator == ArithmeticOperator.DIV:
             assert self.col_type.is_float_type()
+            # Avoid DivisionByZero: if right is 0, make this a NULL
+            # TODO: Should we cast the NULLs to NaNs when they are retrieved back into Python?
+            nullif = sql.sql.func.nullif(right, 0)
             # We have to cast to a `float`, or else we'll get a `Decimal`
-            return sql.sql.expression.cast(left /
+            return sql.sql.expression.cast(left / nullif, sql.Float)
         if self.operator == ArithmeticOperator.MOD:
             if self.col_type.is_int_type():
-
+                nullif = sql.sql.func.nullif(right, 0)
+                return left % nullif
             if self.col_type.is_float_type():
                 # Postgres does not support modulus for floats
                 return None
@@ -83,10 +87,11 @@ class ArithmeticExpr(Expr):
         # We need the behavior to be consistent, so that expressions will evaluate the same way
         # whether or not their operands can be translated to SQL. These SQL clauses should
         # mimic the behavior of Python's // operator.
+        nullif = sql.sql.func.nullif(right, 0)
         if self.col_type.is_int_type():
-            return sql.sql.expression.cast(sql.func.floor(left /
+            return sql.sql.expression.cast(sql.func.floor(left / nullif), sql.Integer)
         if self.col_type.is_float_type():
-            return sql.sql.expression.cast(sql.func.floor(left /
+            return sql.sql.expression.cast(sql.func.floor(left / nullif), sql.Float)
         assert False

     def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
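Note on the arithmetic changes above: wrapping the right operand in NULLIF(right, 0) makes a zero divisor produce SQL NULL instead of raising a division-by-zero error in Postgres. A minimal SQLAlchemy sketch of the same construction (the literals and the printed SQL are illustrative only):

import sqlalchemy as sql

left = sql.literal(7.0)
right = sql.literal(0.0)

# NULLIF(right, 0) evaluates to NULL when right == 0, so the division
# yields NULL rather than a DivisionByZero error.
expr = sql.cast(left / sql.func.nullif(right, 0), sql.Float)
print(expr.compile(compile_kwargs={'literal_binds': True}))
# roughly: CAST(7.0 / NULLIF(0.0, 0) AS FLOAT)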
pixeltable/exprs/column_ref.py CHANGED

@@ -135,6 +135,10 @@ class ColumnRef(Expr):
     def __repr__(self) -> str:
         return f'ColumnRef({self.col!r})'

+    def _repr_html_(self) -> str:
+        tbl = catalog.Catalog.get().tbls[self.col.tbl.id]
+        return tbl._description_html(cols=[self.col])._repr_html_()  # type: ignore[attr-defined]
+
     def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
         return None if self.perform_validation else self.col.sa_col

pixeltable/exprs/comparison.py CHANGED

@@ -67,6 +67,11 @@ class Comparison(Expr):
         return self.components[1]

     def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ColumnElement]:
+        if str(self._op1.col_type.to_sa_type()) != str(self._op2.col_type.to_sa_type()):
+            # Comparing columns of different SQL types (e.g., string vs. json); this can only be done in Python
+            # TODO(aaron-siegel): We may be able to handle some cases in SQL by casting one side to the other's type
+            return None
+
         left = sql_elements.get(self._op1)
         if self.is_search_arg_comparison:
             # reference the index value column if there is an index and this is not a snapshot
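The new guard above makes Comparison.sql_expr return None whenever the two operands map to different SQL types, so such predicates are evaluated in Python rather than pushed down into SQL. A hypothetical illustration (the table and columns are made up; get_table, where, and collect are the standard Pixeltable calls):

import pixeltable as pxt

t = pxt.get_table('docs')
# t.metadata is a JSON column, so t.metadata.category has JSON type; comparing it
# to a string can't be expressed as a single SQL comparison, but the filter still
# works because it now falls back to Python evaluation.
rows = t.where(t.metadata.category == 'news').collect()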
pixeltable/exprs/json_path.py CHANGED

@@ -32,7 +32,7 @@ class JsonPath(Expr):
         """
         if path_elements is None:
             path_elements = []
-        super().__init__(ts.JsonType())
+        super().__init__(ts.JsonType(nullable=True))  # JsonPath expressions are always nullable
         if anchor is not None:
             self.components = [anchor]
         self.path_elements: list[Union[str, int, slice]] = path_elements
pixeltable/func/aggregate_function.py CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations

 import abc
 import inspect
-from typing import TYPE_CHECKING, Any, Callable,
+from typing import TYPE_CHECKING, Any, Callable, Optional

 import pixeltable.exceptions as excs
 import pixeltable.type_system as ts
@@ -36,8 +36,8 @@ class AggregateFunction(Function):
     RESERVED_PARAMS = {ORDER_BY_PARAM, GROUP_BY_PARAM}

     def __init__(
-            self, aggregator_class:
-            init_types:
+            self, aggregator_class: type[Aggregator], self_path: str,
+            init_types: list[ts.ColumnType], update_types: list[ts.ColumnType], value_type: ts.ColumnType,
             requires_order_by: bool, allows_std_agg: bool, allows_window: bool):
         self.agg_cls = aggregator_class
         self.requires_order_by = requires_order_by
@@ -128,7 +128,7 @@ class AggregateFunction(Function):
             order_by_clause=[order_by_clause] if order_by_clause is not None else [],
             group_by_clause=[group_by_clause] if group_by_clause is not None else [])

-    def validate_call(self, bound_args:
+    def validate_call(self, bound_args: dict[str, Any]) -> None:
         # check that init parameters are not Exprs
         # TODO: do this in the planner (check that init parameters are either constants or only refer to grouping exprs)
         import pixeltable.exprs as exprs
@@ -146,10 +146,10 @@ class AggregateFunction(Function):
 def uda(
     *,
     value_type: ts.ColumnType,
-    update_types:
-    init_types: Optional[
+    update_types: list[ts.ColumnType],
+    init_types: Optional[list[ts.ColumnType]] = None,
     requires_order_by: bool = False, allows_std_agg: bool = True, allows_window: bool = False,
-) -> Callable[[
+) -> Callable[[type[Aggregator]], AggregateFunction]:
     """Decorator for user-defined aggregate functions.

     The decorated class must inherit from Aggregator and implement the following methods:
@@ -171,7 +171,7 @@ def uda(
     if init_types is None:
         init_types = []

-    def decorator(cls:
+    def decorator(cls: type[Aggregator]) -> AggregateFunction:
         # validate type parameters
         num_init_params = len(inspect.signature(cls.__init__).parameters) - 1
         if num_init_params > 0:
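With the annotations above, uda is typed as a decorator that turns an Aggregator subclass into an AggregateFunction. A hypothetical sketch of what a UDA might look like under that signature; the import path and the update()/value() method names are assumptions, not something this diff confirms:

import pixeltable.type_system as ts
from pixeltable.func.aggregate_function import Aggregator, uda  # assumed import path

@uda(value_type=ts.FloatType(nullable=True), update_types=[ts.FloatType(nullable=True)])
class running_total(Aggregator):
    def __init__(self):
        self.total = 0.0

    def update(self, val: float) -> None:
        # assumed protocol: called once per input row
        if val is not None:
            self.total += val

    def value(self) -> float:
        # assumed protocol: returns the final aggregate value
        return self.total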
pixeltable/func/expr_template_function.py CHANGED

@@ -1,10 +1,11 @@
 import inspect
-from typing import
+from typing import Any, Optional

 import pixeltable
 import pixeltable.exceptions as excs
+
 from .function import Function
-from .signature import Signature
+from .signature import Signature

 
 class ExprTemplateFunction(Function):
@@ -22,7 +23,7 @@ class ExprTemplateFunction(Function):
         self.param_exprs_by_name = {p.name: p for p in self.param_exprs}

         # verify default values
-        self.defaults:
+        self.defaults: dict[str, exprs.Literal] = {}  # key: param name, value: default value converted to a Literal
         for param in signature.parameters.values():
             if param.default is inspect.Parameter.empty:
                 continue
@@ -77,7 +78,7 @@ class ExprTemplateFunction(Function):
     def name(self) -> str:
         return self.self_name

-    def _as_dict(self) ->
+    def _as_dict(self) -> dict:
         if self.self_path is not None:
             return super()._as_dict()
         return {
@@ -87,7 +88,7 @@ class ExprTemplateFunction(Function):
         }

     @classmethod
-    def _from_dict(cls, d:
+    def _from_dict(cls, d: dict) -> Function:
         if 'expr' not in d:
             return super()._from_dict(d)
         assert 'signature' in d and 'name' in d
pixeltable/func/udf.py CHANGED

@@ -1,9 +1,10 @@
 from __future__ import annotations

-from typing import
+from typing import Any, Callable, Optional, overload

 import pixeltable.exceptions as excs
 import pixeltable.type_system as ts
+
 from .callable_function import CallableFunction
 from .expr_template_function import ExprTemplateFunction
 from .function import Function
@@ -21,8 +22,6 @@ def udf(decorated_fn: Callable) -> Function: ...
 @overload
 def udf(
     *,
-    return_type: Optional[ts.ColumnType] = None,
-    param_types: Optional[List[ts.ColumnType]] = None,
     batch_size: Optional[int] = None,
     substitute_fn: Optional[Callable] = None,
     is_method: bool = False,
@@ -49,8 +48,6 @@ def udf(*args, **kwargs):

         # Decorator schema invoked with parentheses: @pxt.udf(**kwargs)
         # Create a decorator for the specified schema.
-        return_type = kwargs.pop('return_type', None)
-        param_types = kwargs.pop('param_types', None)
         batch_size = kwargs.pop('batch_size', None)
         substitute_fn = kwargs.pop('substitute_fn', None)
         is_method = kwargs.pop('is_method', None)
@@ -64,9 +61,7 @@ def udf(*args, **kwargs):
         def decorator(decorated_fn: Callable):
             return make_function(
                 decorated_fn,
-
-                param_types,
-                batch_size,
+                batch_size=batch_size,
                 substitute_fn=substitute_fn,
                 is_method=is_method,
                 is_property=is_property,
@@ -79,7 +74,7 @@ def udf(*args, **kwargs):
 def make_function(
     decorated_fn: Callable,
     return_type: Optional[ts.ColumnType] = None,
-    param_types: Optional[
+    param_types: Optional[list[ts.ColumnType]] = None,
     batch_size: Optional[int] = None,
     substitute_fn: Optional[Callable] = None,
     is_method: bool = False,
@@ -158,10 +153,10 @@ def make_function(
 def expr_udf(py_fn: Callable) -> ExprTemplateFunction: ...

 @overload
-def expr_udf(*, param_types: Optional[
+def expr_udf(*, param_types: Optional[list[ts.ColumnType]] = None) -> Callable[[Callable], ExprTemplateFunction]: ...

 def expr_udf(*args: Any, **kwargs: Any) -> Any:
-    def make_expr_template(py_fn: Callable, param_types: Optional[
+    def make_expr_template(py_fn: Callable, param_types: Optional[list[ts.ColumnType]]) -> ExprTemplateFunction:
         if py_fn.__module__ != '__main__' and py_fn.__name__.isidentifier():
             # this is a named function in a module
             function_path = f'{py_fn.__module__}.{py_fn.__qualname__}'
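The return_type and param_types keyword arguments disappear from the public @udf overload above; types now come from the Python signature, and the remaining options are forwarded to make_function by keyword. A minimal sketch of the annotation-driven style (the function itself is made up for illustration):

import pixeltable as pxt

@pxt.udf
def word_count(text: str) -> int:
    # parameter and return types are inferred from the annotations,
    # so no return_type= / param_types= arguments are needed
    return len(text.split())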
pixeltable/functions/huggingface.py CHANGED

@@ -7,21 +7,22 @@ first `pip install transformers` (or in some cases, `sentence-transformers`, as
 UDFs).
 """

-from typing import
+from typing import Any, Callable, Optional, TypeVar

 import PIL.Image

 import pixeltable as pxt
 import pixeltable.env as env
+import pixeltable.exceptions as excs
 from pixeltable.func import Batch
-from pixeltable.functions.util import
+from pixeltable.functions.util import normalize_image_mode, resolve_torch_device
 from pixeltable.utils.code import local_public_names


 @pxt.udf(batch_size=32)
 def sentence_transformer(
     sentence: Batch[str], *, model_id: str, normalize_embeddings: bool = False
-) -> Batch[pxt.Array[(None,),
+) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
     Computes sentence embeddings. `model_id` should be a pretrained Sentence Transformers model, as described
     in the [Sentence Transformers Pretrained Models](https://sbert.net/docs/sentence_transformer/pretrained_models.html)
@@ -29,7 +30,7 @@ def sentence_transformer(

     __Requirements:__

-    - `pip install sentence-transformers`
+    - `pip install torch sentence-transformers`

     Args:
         sentence: The sentence to embed.
@@ -48,11 +49,15 @@ def sentence_transformer(
         >>> tbl['result'] = sentence_transformer(tbl.sentence, model_id='all-mpnet-base-v2')
     """
     env.Env.get().require_package('sentence_transformers')
+    device = resolve_torch_device('auto')
+    import torch
     from sentence_transformers import SentenceTransformer  # type: ignore

-    model
+    # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
+    model = _lookup_model(model_id, SentenceTransformer, device=device, pass_device_to_create=True)

-
+    # specifying the device, uses it for computation
+    array = model.encode(sentence, device=device, normalize_embeddings=normalize_embeddings)
     return [array[i] for i in range(array.shape[0])]

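The sentence-transformers UDFs above now resolve a torch device once and hand it to _lookup_model with pass_device_to_create=True, because SentenceTransformer accepts a device argument in its constructor. A standalone sketch of the underlying library behavior, outside Pixeltable (the model id and device are illustrative):

from sentence_transformers import SentenceTransformer

# the constructor places the model on the given device ('cuda', 'mps', or 'cpu') ...
model = SentenceTransformer('all-mpnet-base-v2', device='cpu')
# ... and encode() uses the same device for the forward pass
embeddings = model.encode(['hello world'], device='cpu', normalize_embeddings=True)
print(embeddings.shape)  # (1, 768) for this model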
@@ -70,11 +75,15 @@ def _(model_id: str) -> pxt.ArrayType:
 @pxt.udf
 def sentence_transformer_list(sentences: list, *, model_id: str, normalize_embeddings: bool = False) -> list:
     env.Env.get().require_package('sentence_transformers')
+    device = resolve_torch_device('auto')
+    import torch
     from sentence_transformers import SentenceTransformer

-    model
+    # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
+    model = _lookup_model(model_id, SentenceTransformer, device=device, pass_device_to_create=True)

-
+    # specifying the device, uses it for computation
+    array = model.encode(sentences, device=device, normalize_embeddings=normalize_embeddings)
     return [array[i].tolist() for i in range(array.shape[0])]

@@ -88,7 +97,7 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s

     __Requirements:__

-    - `pip install sentence-transformers`
+    - `pip install torch sentence-transformers`

     Parameters:
         sentences1: The first sentence to be paired.
@@ -107,9 +116,13 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s
     )
     """
     env.Env.get().require_package('sentence_transformers')
+    device = resolve_torch_device('auto')
+    import torch
     from sentence_transformers import CrossEncoder

-    model
+    # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
+    # and uses the device for predict computation
+    model = _lookup_model(model_id, CrossEncoder, device=device, pass_device_to_create=True)

     array = model.predict([[s1, s2] for s1, s2 in zip(sentences1, sentences2)], convert_to_numpy=True)
     return array.tolist()
@@ -118,23 +131,27 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s
 @pxt.udf
 def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> list:
     env.Env.get().require_package('sentence_transformers')
+    device = resolve_torch_device('auto')
+    import torch
     from sentence_transformers import CrossEncoder

-    model
+    # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
+    # and uses the device for predict computation
+    model = _lookup_model(model_id, CrossEncoder, device=device, pass_device_to_create=True)

     array = model.predict([[sentence1, s2] for s2 in sentences2], convert_to_numpy=True)
     return array.tolist()


 @pxt.udf(batch_size=32)
-def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,),
+def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
     Computes a CLIP embedding for the specified text. `model_id` should be a reference to a pretrained
     [CLIP Model](https://huggingface.co/docs/transformers/model_doc/clip).

     __Requirements:__

-    - `pip install transformers`
+    - `pip install torch transformers`

     Args:
         text: The string to embed.
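Both CLIP UDFs now declare Batch[pxt.Array[(None,), pxt.Float]] return types. A usage sketch in the style of the docstring examples elsewhere in this module (table and column names are hypothetical):

import pixeltable as pxt
from pixeltable.functions.huggingface import clip_image, clip_text

tbl = pxt.get_table('photos')
tbl['caption_embed'] = clip_text(tbl.caption, model_id='openai/clip-vit-base-patch32')
tbl['img_embed'] = clip_image(tbl.img, model_id='openai/clip-vit-base-patch32')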
@@ -165,14 +182,14 @@ def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), fl


 @pxt.udf(batch_size=32)
-def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Array[(None,),
+def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
     Computes a CLIP embedding for the specified image. `model_id` should be a reference to a pretrained
     [CLIP Model](https://huggingface.co/docs/transformers/model_doc/clip).

     __Requirements:__

-    - `pip install transformers`
+    - `pip install torch transformers`

     Args:
         image: The image to embed.
@@ -215,14 +232,20 @@ def _(model_id: str) -> pxt.ArrayType:


 @pxt.udf(batch_size=4)
-def detr_for_object_detection(
+def detr_for_object_detection(
+    image: Batch[PIL.Image.Image],
+    *,
+    model_id: str,
+    threshold: float = 0.5,
+    revision: str = 'no_timm',
+) -> Batch[dict]:
     """
     Computes DETR object detections for the specified image. `model_id` should be a reference to a pretrained
     [DETR Model](https://huggingface.co/docs/transformers/model_doc/detr).

     __Requirements:__

-    - `pip install transformers`
+    - `pip install torch transformers`

     Args:
         image: The image to embed.
@@ -254,12 +277,12 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
     env.Env.get().require_package('transformers')
     device = resolve_torch_device('auto')
     import torch
-    from transformers import
+    from transformers import DetrForObjectDetection, DetrImageProcessor

     model = _lookup_model(
-        model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision=
+        model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision=revision), device=device
     )
-    processor = _lookup_processor(model_id, lambda x: DetrImageProcessor.from_pretrained(x, revision=
+    processor = _lookup_processor(model_id, lambda x: DetrImageProcessor.from_pretrained(x, revision=revision))
     normalized_images = [normalize_image_mode(img) for img in image]

     with torch.no_grad():
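detr_for_object_detection now spells out its keyword parameters, including the checkpoint revision (default 'no_timm'), and threads that revision through to both the model and the image processor. A usage sketch based on the new signature (table and column names are hypothetical):

import pixeltable as pxt
from pixeltable.functions.huggingface import detr_for_object_detection

tbl = pxt.get_table('photos')
tbl['detections'] = detr_for_object_detection(
    tbl.img,
    model_id='facebook/detr-resnet-50',
    threshold=0.8,  # keep only fairly confident boxes; revision defaults to 'no_timm'
)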
@@ -299,7 +322,7 @@ def vit_for_image_classification(

     __Requirements:__

-    - `pip install transformers`
+    - `pip install torch transformers`

     Args:
         image: The image to classify.
@@ -330,7 +353,7 @@ def vit_for_image_classification(
     env.Env.get().require_package('transformers')
     device = resolve_torch_device('auto')
     import torch
-    from transformers import
+    from transformers import ViTForImageClassification, ViTImageProcessor

     model: ViTForImageClassification = _lookup_model(model_id, ViTForImageClassification.from_pretrained, device=device)
     processor = _lookup_processor(model_id, ViTImageProcessor.from_pretrained)
@@ -356,6 +379,86 @@ def vit_for_image_classification(
     ]


+@pxt.udf
+def speech2text_for_conditional_generation(
+    audio: pxt.Audio,
+    *,
+    model_id: str,
+    language: Optional[str] = None,
+) -> str:
+    """
+    Transcribes or translates speech to text using a Speech2Text model. `model_id` should be a reference to a
+    pretrained [Speech2Text](https://huggingface.co/docs/transformers/en/model_doc/speech_to_text) model.
+
+    __Requirements:__
+
+    - `pip install torch torchaudio sentencepiece transformers`
+
+    Args:
+        audio: The audio clip to transcribe or translate.
+        model_id: The pretrained model to use for the transcription or translation.
+        language: If using a multilingual translation model, the language code to translate to. If not provided,
+            the model's default language will be used. If the model is not translation model, is not a
+            multilingual model, or does not support the specified language, an error will be raised.
+
+    Returns:
+        The transcribed or translated text.
+
+    Examples:
+        Add a computed column that applies the model `facebook/s2t-small-librispeech-asr` to an existing
+        Pixeltable column `audio` of the table `tbl`:
+
+        >>> tbl['transcription'] = speech2text_for_conditional_generation(
+        ...     tbl.audio,
+        ...     model_id='facebook/s2t-small-librispeech-asr'
+        ... )
+
+        Add a computed column that applies the model `facebook/s2t-medium-mustc-multilingual-st` to an existing
+        Pixeltable column `audio` of the table `tbl`, translating the audio to French:
+
+        >>> tbl['translation'] = speech2text_for_conditional_generation(
+        ...     tbl.audio,
+        ...     model_id='facebook/s2t-medium-mustc-multilingual-st',
+        ...     language='fr'
+        ... )
+    """
+    env.Env.get().require_package('transformers')
+    env.Env.get().require_package('torchaudio')
+    env.Env.get().require_package('sentencepiece')
+    device = resolve_torch_device('auto', allow_mps=False)  # Doesn't seem to work on 'mps'; use 'cpu' instead
+    import librosa
+    import torch
+    from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
+
+    # facebook/s2t-small-librispeech-asr
+    # facebook/s2t-small-mustc-en-fr-st
+    model = _lookup_model(model_id, Speech2TextForConditionalGeneration.from_pretrained, device=device)
+    processor = _lookup_processor(model_id, Speech2TextProcessor.from_pretrained)
+    assert isinstance(processor, Speech2TextProcessor)
+
+    if language is not None and language not in processor.tokenizer.lang_code_to_id:
+        raise excs.Error(
+            f"Language code '{language}' is not supported by the model '{model_id}'. "
+            f"Supported languages are: {list(processor.tokenizer.lang_code_to_id.keys())}")
+
+    forced_bos_token_id: Optional[int] = None if language is None else processor.tokenizer.lang_code_to_id[language]
+
+    # Get the model's sampling rate. Default to 16 kHz (the standard) if not in config
+    model_sampling_rate = getattr(model.config, 'sampling_rate', 16_000)
+    waveform, sampling_rate = librosa.load(audio, sr=model_sampling_rate, mono=True)
+
+    with torch.no_grad():
+        inputs = processor(
+            waveform,
+            sampling_rate=sampling_rate,
+            return_tensors='pt'
+        )
+        generated_ids = model.generate(**inputs.to(device), forced_bos_token_id=forced_bos_token_id).to('cpu')
+
+    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
+    return transcription
+
+
 @pxt.udf
 def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str, Any]:
     """
@@ -385,14 +488,22 @@ def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str,
 T = TypeVar('T')


-def _lookup_model(
+def _lookup_model(
+    model_id: str,
+    create: Callable[..., T],
+    device: Optional[str] = None,
+    pass_device_to_create: bool = False
+) -> T:
     from torch import nn

     key = (model_id, create, device)  # For safety, include the `create` callable in the cache key
     if key not in _model_cache:
-
+        if pass_device_to_create:
+            model = create(model_id, device=device)
+        else:
+            model = create(model_id)
         if isinstance(model, nn.Module):
-            if device is not None:
+            if not pass_device_to_create and device is not None:
                 model.to(device)
             model.eval()
         _model_cache[key] = model
pixeltable/functions/llama_cpp.py CHANGED

@@ -76,7 +76,7 @@ def _lookup_local_model(model_path: str, n_gpu_layers: int) -> 'llama_cpp.Llama'

     key = (model_path, None, n_gpu_layers)
     if key not in _model_cache:
-        llm = llama_cpp.Llama(model_path, n_gpu_layers=n_gpu_layers)
+        llm = llama_cpp.Llama(model_path, n_gpu_layers=n_gpu_layers, verbose=False)
         _model_cache[key] = llm
     return _model_cache[key]

@@ -89,7 +89,8 @@ def _lookup_pretrained_model(repo_id: str, filename: Optional[str], n_gpu_layers
         llm = llama_cpp.Llama.from_pretrained(
             repo_id=repo_id,
             filename=filename,
-            n_gpu_layers=n_gpu_layers
+            n_gpu_layers=n_gpu_layers,
+            verbose=False,
         )
         _model_cache[key] = llm
     return _model_cache[key]
pixeltable/functions/mistralai.py CHANGED

@@ -141,7 +141,7 @@ _embedding_dimensions_cache: dict[str, int] = {


 @pxt.udf(batch_size=16)
-def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,),
+def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
     Embeddings API.


pixeltable/functions/openai.py CHANGED

@@ -304,7 +304,7 @@ _embedding_dimensions_cache: dict[str, int] = {
 @pxt.udf(batch_size=32)
 def embeddings(
     input: Batch[str], *, model: str, dimensions: Optional[int] = None, user: Optional[str] = None
-) -> Batch[pxt.Array[(None,),
+) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
     Creates an embedding vector representing the input text.


pixeltable/functions/together.py CHANGED

@@ -186,7 +186,7 @@ _embedding_dimensions_cache = {


 @pxt.udf(batch_size=32)
-def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,),
+def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
     Query an embedding model for a given string of text.

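The Mistral, OpenAI, and Together embedding UDFs above all switch their return annotation to Batch[pxt.Array[(None,), pxt.Float]]. A usage sketch in the style of the library's docstring examples (table, column, and model names are illustrative):

import pixeltable as pxt
from pixeltable.functions.openai import embeddings

tbl = pxt.get_table('docs')
# computed column holding one float vector per row
tbl['embedding'] = embeddings(tbl.text, model='text-embedding-3-small')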
pixeltable/functions/util.py CHANGED

@@ -1,13 +1,16 @@
 import PIL.Image

+from pixeltable.env import Env

-
+
+def resolve_torch_device(device: str, allow_mps: bool = True) -> str:
+    Env.get().require_package('torch')
     import torch

     if device == 'auto':
         if torch.cuda.is_available():
             return 'cuda'
-        if torch.backends.mps.is_available():
+        if allow_mps and torch.backends.mps.is_available():
             return 'mps'
         return 'cpu'
     return device