pixeltable 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pixeltable/__init__.py +2 -2
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/__init__.py +1 -1
- pixeltable/catalog/column.py +41 -29
- pixeltable/catalog/globals.py +18 -0
- pixeltable/catalog/insertable_table.py +30 -10
- pixeltable/catalog/table.py +198 -86
- pixeltable/catalog/table_version.py +47 -53
- pixeltable/catalog/table_version_path.py +2 -2
- pixeltable/catalog/view.py +17 -18
- pixeltable/dataframe.py +27 -36
- pixeltable/env.py +7 -0
- pixeltable/exec/__init__.py +0 -1
- pixeltable/exec/aggregation_node.py +6 -3
- pixeltable/exec/cache_prefetch_node.py +189 -43
- pixeltable/exec/data_row_batch.py +5 -22
- pixeltable/exec/exec_context.py +2 -2
- pixeltable/exec/exec_node.py +3 -2
- pixeltable/exec/expr_eval_node.py +23 -16
- pixeltable/exec/in_memory_data_node.py +6 -3
- pixeltable/exec/sql_node.py +24 -25
- pixeltable/exprs/arithmetic_expr.py +12 -5
- pixeltable/exprs/array_slice.py +7 -7
- pixeltable/exprs/column_property_ref.py +37 -10
- pixeltable/exprs/column_ref.py +97 -14
- pixeltable/exprs/comparison.py +10 -5
- pixeltable/exprs/compound_predicate.py +8 -7
- pixeltable/exprs/data_row.py +27 -18
- pixeltable/exprs/expr.py +53 -52
- pixeltable/exprs/expr_set.py +5 -0
- pixeltable/exprs/function_call.py +32 -16
- pixeltable/exprs/globals.py +4 -1
- pixeltable/exprs/in_predicate.py +8 -7
- pixeltable/exprs/inline_expr.py +4 -4
- pixeltable/exprs/is_null.py +4 -4
- pixeltable/exprs/json_mapper.py +11 -12
- pixeltable/exprs/json_path.py +6 -11
- pixeltable/exprs/literal.py +5 -5
- pixeltable/exprs/method_ref.py +5 -4
- pixeltable/exprs/object_ref.py +2 -1
- pixeltable/exprs/row_builder.py +88 -36
- pixeltable/exprs/rowid_ref.py +12 -11
- pixeltable/exprs/similarity_expr.py +12 -7
- pixeltable/exprs/sql_element_cache.py +7 -5
- pixeltable/exprs/type_cast.py +8 -6
- pixeltable/exprs/variable.py +5 -4
- pixeltable/func/aggregate_function.py +9 -9
- pixeltable/func/expr_template_function.py +6 -5
- pixeltable/func/function.py +11 -10
- pixeltable/func/udf.py +6 -11
- pixeltable/functions/__init__.py +2 -2
- pixeltable/functions/globals.py +5 -7
- pixeltable/functions/huggingface.py +155 -45
- pixeltable/functions/llama_cpp.py +107 -0
- pixeltable/functions/mistralai.py +1 -1
- pixeltable/functions/ollama.py +147 -0
- pixeltable/functions/openai.py +1 -1
- pixeltable/functions/replicate.py +72 -0
- pixeltable/functions/string.py +9 -0
- pixeltable/functions/together.py +1 -1
- pixeltable/functions/util.py +5 -2
- pixeltable/globals.py +67 -26
- pixeltable/index/btree.py +16 -3
- pixeltable/index/embedding_index.py +4 -4
- pixeltable/io/__init__.py +1 -2
- pixeltable/io/fiftyone.py +178 -0
- pixeltable/io/globals.py +96 -2
- pixeltable/iterators/base.py +3 -2
- pixeltable/iterators/document.py +1 -1
- pixeltable/iterators/video.py +120 -63
- pixeltable/metadata/__init__.py +1 -1
- pixeltable/metadata/converters/convert_21.py +34 -0
- pixeltable/metadata/converters/util.py +45 -4
- pixeltable/metadata/notes.py +1 -0
- pixeltable/metadata/schema.py +8 -0
- pixeltable/plan.py +17 -15
- pixeltable/py.typed +0 -0
- pixeltable/store.py +7 -2
- pixeltable/tool/create_test_db_dump.py +1 -1
- pixeltable/tool/create_test_video.py +1 -1
- pixeltable/tool/embed_udf.py +1 -1
- pixeltable/tool/mypy_plugin.py +28 -5
- pixeltable/type_system.py +100 -36
- pixeltable/utils/coco.py +5 -5
- pixeltable/utils/documents.py +15 -1
- pixeltable/utils/formatter.py +12 -13
- pixeltable/utils/s3.py +6 -3
- {pixeltable-0.2.21.dist-info → pixeltable-0.2.23.dist-info}/METADATA +158 -49
- pixeltable-0.2.23.dist-info/RECORD +153 -0
- pixeltable/exec/media_validation_node.py +0 -43
- pixeltable-0.2.21.dist-info/RECORD +0 -148
- {pixeltable-0.2.21.dist-info → pixeltable-0.2.23.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.21.dist-info → pixeltable-0.2.23.dist-info}/WHEEL +0 -0
- {pixeltable-0.2.21.dist-info → pixeltable-0.2.23.dist-info}/entry_points.txt +0 -0
pixeltable/func/aggregate_function.py
CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import abc
 import inspect
-from typing import TYPE_CHECKING, Any, Callable,
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
 import pixeltable.exceptions as excs
 import pixeltable.type_system as ts
@@ -36,8 +36,8 @@ class AggregateFunction(Function):
     RESERVED_PARAMS = {ORDER_BY_PARAM, GROUP_BY_PARAM}
 
     def __init__(
-            self, aggregator_class:
-            init_types:
+            self, aggregator_class: type[Aggregator], self_path: str,
+            init_types: list[ts.ColumnType], update_types: list[ts.ColumnType], value_type: ts.ColumnType,
             requires_order_by: bool, allows_std_agg: bool, allows_window: bool):
         self.agg_cls = aggregator_class
         self.requires_order_by = requires_order_by
@@ -86,7 +86,7 @@ class AggregateFunction(Function):
         res += '\n\n' + inspect.getdoc(self.agg_cls.update)
         return res
 
-    def __call__(self, *args: object, **kwargs: object) -> 'pixeltable.exprs.
+    def __call__(self, *args: object, **kwargs: object) -> 'pixeltable.exprs.FunctionCall':
         from pixeltable import exprs
 
         # perform semantic analysis of special parameters 'order_by' and 'group_by'
@@ -128,7 +128,7 @@ class AggregateFunction(Function):
             order_by_clause=[order_by_clause] if order_by_clause is not None else [],
             group_by_clause=[group_by_clause] if group_by_clause is not None else [])
 
-    def validate_call(self, bound_args:
+    def validate_call(self, bound_args: dict[str, Any]) -> None:
        # check that init parameters are not Exprs
        # TODO: do this in the planner (check that init parameters are either constants or only refer to grouping exprs)
        import pixeltable.exprs as exprs
@@ -146,10 +146,10 @@ class AggregateFunction(Function):
 def uda(
     *,
     value_type: ts.ColumnType,
-    update_types:
-    init_types: Optional[
+    update_types: list[ts.ColumnType],
+    init_types: Optional[list[ts.ColumnType]] = None,
     requires_order_by: bool = False, allows_std_agg: bool = True, allows_window: bool = False,
-) -> Callable[[
+) -> Callable[[type[Aggregator]], AggregateFunction]:
     """Decorator for user-defined aggregate functions.
 
     The decorated class must inherit from Aggregator and implement the following methods:
@@ -171,7 +171,7 @@ def uda(
     if init_types is None:
         init_types = []
 
-    def decorator(cls:
+    def decorator(cls: type[Aggregator]) -> AggregateFunction:
        # validate type parameters
        num_init_params = len(inspect.signature(cls.__init__).parameters) - 1
        if num_init_params > 0:
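Note: the reworked `uda()` surface above is now fully typed (`type[Aggregator]`, `list[ts.ColumnType]`). For orientation, a minimal sketch of a user-defined aggregate written against that signature; it assumes `uda` and `Aggregator` are importable from `pixeltable.func` (the built-in aggregators in `pixeltable/functions/globals.py` below follow the same `update()`/`value()` protocol), and `int_sum` is a hypothetical name:

```python
import pixeltable.func as func
import pixeltable.type_system as ts

@func.uda(value_type=ts.IntType(), update_types=[ts.IntType()])
class int_sum(func.Aggregator):
    def __init__(self) -> None:
        # no __init__ parameters, so the init_types=None -> [] default applies
        self.total = 0

    def update(self, val: int) -> None:
        # exactly one parameter, matching update_types
        if val is not None:
            self.total += val

    def value(self) -> int:
        # the result type matches value_type
        return self.total
```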
pixeltable/func/expr_template_function.py
CHANGED

@@ -1,10 +1,11 @@
 import inspect
-from typing import
+from typing import Any, Optional
 
 import pixeltable
 import pixeltable.exceptions as excs
+
 from .function import Function
-from .signature import Signature
+from .signature import Signature
 
 
 class ExprTemplateFunction(Function):
@@ -22,7 +23,7 @@ class ExprTemplateFunction(Function):
         self.param_exprs_by_name = {p.name: p for p in self.param_exprs}
 
         # verify default values
-        self.defaults:
+        self.defaults: dict[str, exprs.Literal] = {}  # key: param name, value: default value converted to a Literal
         for param in signature.parameters.values():
             if param.default is inspect.Parameter.empty:
                 continue
@@ -77,7 +78,7 @@ class ExprTemplateFunction(Function):
     def name(self) -> str:
         return self.self_name
 
-    def _as_dict(self) ->
+    def _as_dict(self) -> dict:
         if self.self_path is not None:
             return super()._as_dict()
         return {
@@ -87,7 +88,7 @@ class ExprTemplateFunction(Function):
         }
 
     @classmethod
-    def _from_dict(cls, d:
+    def _from_dict(cls, d: dict) -> Function:
         if 'expr' not in d:
             return super()._from_dict(d)
         assert 'signature' in d and 'name' in d
pixeltable/func/function.py
CHANGED

@@ -3,12 +3,13 @@ from __future__ import annotations
 import abc
 import importlib
 import inspect
-from typing import Any, Callable,
+from typing import Any, Callable, Optional
 
 import sqlalchemy as sql
 
-import pixeltable
+import pixeltable as pxt
 import pixeltable.type_system as ts
+
 from .globals import resolve_symbol
 from .signature import Signature
 
@@ -66,13 +67,13 @@ class Function(abc.ABC):
     def help_str(self) -> str:
         return self.display_name + str(self.signature)
 
-    def __call__(self, *args: Any, **kwargs: Any) -> '
+    def __call__(self, *args: Any, **kwargs: Any) -> 'pxt.exprs.FunctionCall':
         from pixeltable import exprs
         bound_args = self.signature.py_signature.bind(*args, **kwargs)
         self.validate_call(bound_args.arguments)
         return exprs.FunctionCall(self, bound_args.arguments)
 
-    def validate_call(self, bound_args:
+    def validate_call(self, bound_args: dict[str, Any]) -> None:
         """Override this to do custom validation of the arguments"""
         pass
 
@@ -121,7 +122,7 @@ class Function(abc.ABC):
         """Print source code"""
         print('source not available')
 
-    def as_dict(self) ->
+    def as_dict(self) -> dict:
         """
         Return a serialized reference to the instance that can be passed to json.dumps() and converted back
         to an instance with from_dict().
@@ -130,13 +131,13 @@ class Function(abc.ABC):
         classpath = f'{self.__class__.__module__}.{self.__class__.__qualname__}'
         return {'_classpath': classpath, **self._as_dict()}
 
-    def _as_dict(self) ->
+    def _as_dict(self) -> dict:
         """Default serialization: store the path to self (which includes the module path)"""
         assert self.self_path is not None
         return {'path': self.self_path}
 
     @classmethod
-    def from_dict(cls, d:
+    def from_dict(cls, d: dict) -> Function:
         """
         Turn dict that was produced by calling as_dict() into an instance of the correct Function subclass.
         """
@@ -147,14 +148,14 @@ class Function(abc.ABC):
         return func_class._from_dict(d)
 
     @classmethod
-    def _from_dict(cls, d:
+    def _from_dict(cls, d: dict) -> Function:
         """Default deserialization: load the symbol indicated by the stored symbol_path"""
         assert 'path' in d and d['path'] is not None
         instance = resolve_symbol(d['path'])
         assert isinstance(instance, Function)
         return instance
 
-    def to_store(self) ->
+    def to_store(self) -> tuple[dict, bytes]:
         """
         Serialize the function to a format that can be stored in the Pixeltable store
         Returns:
@@ -165,7 +166,7 @@ class Function(abc.ABC):
         raise NotImplementedError()
 
     @classmethod
-    def from_store(cls, name: Optional[str], md:
+    def from_store(cls, name: Optional[str], md: dict, binary_obj: bytes) -> Function:
         """
         Create a Function instance from the serialized representation returned by to_store()
         """
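Note: the hunks above pin down the serialization surface (`as_dict() -> dict`, `from_dict(d: dict) -> Function`, `to_store() -> tuple[dict, bytes]`). A sketch of the round trip the `as_dict` docstring describes, using the default path-based serialization; the choice of `count` is illustrative and the `assert` expresses the expected behavior of `resolve_symbol`, not a tested guarantee:

```python
import json

from pixeltable.func.function import Function
from pixeltable.functions import count  # any Function with a module path works

md = count.as_dict()                     # {'_classpath': ..., 'path': ...}
payload = json.dumps(md)                 # JSON-safe, per the docstring
restored = Function.from_dict(json.loads(payload))
assert restored is count                 # default _from_dict resolves the stored path
```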
pixeltable/func/udf.py
CHANGED

@@ -1,9 +1,10 @@
 from __future__ import annotations
 
-from typing import
+from typing import Any, Callable, Optional, overload
 
 import pixeltable.exceptions as excs
 import pixeltable.type_system as ts
+
 from .callable_function import CallableFunction
 from .expr_template_function import ExprTemplateFunction
 from .function import Function
@@ -21,8 +22,6 @@ def udf(decorated_fn: Callable) -> Function: ...
 @overload
 def udf(
     *,
-    return_type: Optional[ts.ColumnType] = None,
-    param_types: Optional[List[ts.ColumnType]] = None,
     batch_size: Optional[int] = None,
     substitute_fn: Optional[Callable] = None,
     is_method: bool = False,
@@ -49,8 +48,6 @@ def udf(*args, **kwargs):
 
         # Decorator schema invoked with parentheses: @pxt.udf(**kwargs)
         # Create a decorator for the specified schema.
-        return_type = kwargs.pop('return_type', None)
-        param_types = kwargs.pop('param_types', None)
         batch_size = kwargs.pop('batch_size', None)
         substitute_fn = kwargs.pop('substitute_fn', None)
         is_method = kwargs.pop('is_method', None)
@@ -64,9 +61,7 @@ def udf(*args, **kwargs):
         def decorator(decorated_fn: Callable):
             return make_function(
                 decorated_fn,
-                return_type,
-                param_types,
-                batch_size,
+                batch_size=batch_size,
                 substitute_fn=substitute_fn,
                 is_method=is_method,
                 is_property=is_property,
@@ -79,7 +74,7 @@ def udf(*args, **kwargs):
 def make_function(
     decorated_fn: Callable,
     return_type: Optional[ts.ColumnType] = None,
-    param_types: Optional[
+    param_types: Optional[list[ts.ColumnType]] = None,
     batch_size: Optional[int] = None,
     substitute_fn: Optional[Callable] = None,
     is_method: bool = False,
@@ -158,10 +153,10 @@ def make_function(
 def expr_udf(py_fn: Callable) -> ExprTemplateFunction: ...
 
 @overload
-def expr_udf(*, param_types: Optional[
+def expr_udf(*, param_types: Optional[list[ts.ColumnType]] = None) -> Callable[[Callable], ExprTemplateFunction]: ...
 
 def expr_udf(*args: Any, **kwargs: Any) -> Any:
-    def make_expr_template(py_fn: Callable, param_types: Optional[
+    def make_expr_template(py_fn: Callable, param_types: Optional[list[ts.ColumnType]]) -> ExprTemplateFunction:
         if py_fn.__module__ != '__main__' and py_fn.__name__.isidentifier():
             # this is a named function in a module
             function_path = f'{py_fn.__module__}.{py_fn.__qualname__}'
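Note: with `return_type`/`param_types` removed from the `@udf` overload and its kwargs handling, the decorator now takes types from the Python annotations alone (`make_function` still accepts them for internal callers). A sketch of both decorator forms under the new surface; the function names are hypothetical:

```python
import pixeltable as pxt
from pixeltable.func import Batch

# types come from the annotations; return_type=/param_types= are no longer accepted
@pxt.udf
def shout(s: str) -> str:
    return s.upper() + '!'

# batch_size survives the cleanup, as do substitute_fn/is_method/is_property
@pxt.udf(batch_size=16)
def shout_many(strings: Batch[str]) -> Batch[str]:
    return [s.upper() + '!' for s in strings]
```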
pixeltable/functions/__init__.py
CHANGED

@@ -1,7 +1,7 @@
 from pixeltable.utils.code import local_public_names
 
-from . import (anthropic, audio, fireworks, huggingface, image, json,
-               video, vision, whisper)
+from . import (anthropic, audio, fireworks, huggingface, image, json, llama_cpp, mistralai, ollama, openai, string,
+               timestamp, together, video, vision, whisper)
 from .globals import *
 
 __all__ = local_public_names(__name__, exclude=['globals']) + local_public_names(globals.__name__)
pixeltable/functions/globals.py
CHANGED

@@ -36,9 +36,7 @@ class sum(func.Aggregator):
         return self.sum
 
 
-
-# TODO: find a way to have this type-checked
-@sum.to_sql # type: ignore
+@sum.to_sql
 def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
     # This can produce a Decimal. We are deliberately avoiding an explicit cast to a Bigint here, because that can
     # cause overflows in Postgres. We're instead doing the conversion to the target type in SqlNode.__iter__().
@@ -58,7 +56,7 @@ class count(func.Aggregator):
         return self.count
 
 
-@count.to_sql
+@count.to_sql
 def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
     return sql.sql.func.count(val)
 
@@ -82,7 +80,7 @@ class min(func.Aggregator):
         return self.val
 
 
-@min.to_sql
+@min.to_sql
 def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
     return sql.sql.func.min(val)
 
@@ -106,7 +104,7 @@ class max(func.Aggregator):
         return self.val
 
 
-@max.to_sql
+@max.to_sql
 def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
     return sql.sql.func.max(val)
 
@@ -134,7 +132,7 @@ class mean(func.Aggregator):
         return self.sum / self.count
 
 
-@mean.to_sql
+@mean.to_sql
 def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
     return sql.sql.func.avg(val)
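Note: the removed TODO ("find a way to have this type-checked") suggests the `@...to_sql` decorators are now type-checked, plausibly via the updated mypy plugin (`pixeltable/tool/mypy_plugin.py +28 -5` in the file list). Behaviorally nothing changes: each aggregator keeps a SQL counterpart, so the aggregation can run in Postgres when its input is a stored column. A usage sketch, assuming a table `demo` with an int column `val`:

```python
import pixeltable as pxt
from pixeltable.functions import count, mean, sum

t = pxt.get_table('demo')
# where possible the @...to_sql hooks evaluate these in SQL;
# otherwise the Aggregator classes run in Python
res = t.select(n=count(t.val), total=sum(t.val), avg=mean(t.val)).collect()
```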
pixeltable/functions/huggingface.py
CHANGED

@@ -7,21 +7,22 @@ first `pip install transformers` (or in some cases, `sentence-transformers`, as
 UDFs).
 """
 
-from typing import
+from typing import Any, Callable, Optional, TypeVar
 
 import PIL.Image
 
 import pixeltable as pxt
 import pixeltable.env as env
+import pixeltable.exceptions as excs
 from pixeltable.func import Batch
-from pixeltable.functions.util import
+from pixeltable.functions.util import normalize_image_mode, resolve_torch_device
 from pixeltable.utils.code import local_public_names
 
 
 @pxt.udf(batch_size=32)
 def sentence_transformer(
     sentence: Batch[str], *, model_id: str, normalize_embeddings: bool = False
-) -> Batch[pxt.Array[(None,),
+) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
     Computes sentence embeddings. `model_id` should be a pretrained Sentence Transformers model, as described
     in the [Sentence Transformers Pretrained Models](https://sbert.net/docs/sentence_transformer/pretrained_models.html)
@@ -29,7 +30,7 @@ def sentence_transformer(
 
     __Requirements:__
 
-    - `pip install sentence-transformers`
+    - `pip install torch sentence-transformers`
 
     Args:
         sentence: The sentence to embed.
@@ -48,11 +49,15 @@ def sentence_transformer(
         >>> tbl['result'] = sentence_transformer(tbl.sentence, model_id='all-mpnet-base-v2')
     """
     env.Env.get().require_package('sentence_transformers')
+    device = resolve_torch_device('auto')
+    import torch
     from sentence_transformers import SentenceTransformer  # type: ignore
 
-    model
+    # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
+    model = _lookup_model(model_id, SentenceTransformer, device=device, pass_device_to_create=True)
 
-
+    # specifying the device, uses it for computation
+    array = model.encode(sentence, device=device, normalize_embeddings=normalize_embeddings)
     return [array[i] for i in range(array.shape[0])]
 
 
@@ -70,11 +75,15 @@ def _(model_id: str) -> pxt.ArrayType:
 @pxt.udf
 def sentence_transformer_list(sentences: list, *, model_id: str, normalize_embeddings: bool = False) -> list:
     env.Env.get().require_package('sentence_transformers')
+    device = resolve_torch_device('auto')
+    import torch
     from sentence_transformers import SentenceTransformer
 
-    model
+    # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
+    model = _lookup_model(model_id, SentenceTransformer, device=device, pass_device_to_create=True)
 
-
+    # specifying the device, uses it for computation
+    array = model.encode(sentences, device=device, normalize_embeddings=normalize_embeddings)
     return [array[i].tolist() for i in range(array.shape[0])]
 
 
@@ -88,7 +97,7 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s
 
     __Requirements:__
 
-    - `pip install sentence-transformers`
+    - `pip install torch sentence-transformers`
 
     Parameters:
         sentences1: The first sentence to be paired.
@@ -107,9 +116,13 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s
     )
     """
     env.Env.get().require_package('sentence_transformers')
+    device = resolve_torch_device('auto')
+    import torch
     from sentence_transformers import CrossEncoder
 
-    model
+    # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
+    # and uses the device for predict computation
+    model = _lookup_model(model_id, CrossEncoder, device=device, pass_device_to_create=True)
 
     array = model.predict([[s1, s2] for s1, s2 in zip(sentences1, sentences2)], convert_to_numpy=True)
     return array.tolist()
@@ -118,23 +131,27 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s
 @pxt.udf
 def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> list:
     env.Env.get().require_package('sentence_transformers')
+    device = resolve_torch_device('auto')
+    import torch
     from sentence_transformers import CrossEncoder
 
-    model
+    # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
+    # and uses the device for predict computation
+    model = _lookup_model(model_id, CrossEncoder, device=device, pass_device_to_create=True)
 
     array = model.predict([[sentence1, s2] for s2 in sentences2], convert_to_numpy=True)
     return array.tolist()
 
 
 @pxt.udf(batch_size=32)
-def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,),
+def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
     Computes a CLIP embedding for the specified text. `model_id` should be a reference to a pretrained
     [CLIP Model](https://huggingface.co/docs/transformers/model_doc/clip).
 
     __Requirements:__
 
-    - `pip install transformers`
+    - `pip install torch transformers`
 
     Args:
         text: The string to embed.
@@ -165,14 +182,14 @@ def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), fl
 
 
 @pxt.udf(batch_size=32)
-def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Array[(None,),
+def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
     Computes a CLIP embedding for the specified image. `model_id` should be a reference to a pretrained
     [CLIP Model](https://huggingface.co/docs/transformers/model_doc/clip).
 
     __Requirements:__
 
-    - `pip install transformers`
+    - `pip install torch transformers`
 
     Args:
         image: The image to embed.
@@ -215,14 +232,20 @@ def _(model_id: str) -> pxt.ArrayType:
 
 
 @pxt.udf(batch_size=4)
-def detr_for_object_detection(
+def detr_for_object_detection(
+    image: Batch[PIL.Image.Image],
+    *,
+    model_id: str,
+    threshold: float = 0.5,
+    revision: str = 'no_timm',
+) -> Batch[dict]:
     """
     Computes DETR object detections for the specified image. `model_id` should be a reference to a pretrained
     [DETR Model](https://huggingface.co/docs/transformers/model_doc/detr).
 
     __Requirements:__
 
-    - `pip install transformers`
+    - `pip install torch transformers`
 
     Args:
         image: The image to embed.
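Note: the reflowed signature above adds `revision` as a per-call parameter (the revision was previously hard-wired inside the `from_pretrained` calls, as the next hunk shows); `threshold` was already part of the signature. A usage sketch in the style of the module's docstring examples; `facebook/detr-resnet-50` is the standard pretrained DETR checkpoint:

```python
tbl['detections'] = detr_for_object_detection(
    tbl.image,
    model_id='facebook/detr-resnet-50',
    threshold=0.8,  # detections scoring below this are dropped
    # revision='no_timm' is the default and can now be overridden per call
)
```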
@@ -254,12 +277,12 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
     env.Env.get().require_package('transformers')
     device = resolve_torch_device('auto')
     import torch
-    from transformers import
+    from transformers import DetrForObjectDetection, DetrImageProcessor
 
     model = _lookup_model(
-        model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision=
+        model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision=revision), device=device
     )
-    processor = _lookup_processor(model_id, lambda x: DetrImageProcessor.from_pretrained(x, revision=
+    processor = _lookup_processor(model_id, lambda x: DetrImageProcessor.from_pretrained(x, revision=revision))
     normalized_images = [normalize_image_mode(img) for img in image]
 
     with torch.no_grad():
@@ -286,7 +309,7 @@ def vit_for_image_classification(
     *,
     model_id: str,
     top_k: int = 5
-) -> Batch[
+) -> Batch[dict[str, Any]]:
     """
     Computes image classifications for the specified image using a Vision Transformer (ViT) model.
     `model_id` should be a reference to a pretrained [ViT Model](https://huggingface.co/docs/transformers/en/model_doc/vit).
@@ -299,7 +322,7 @@ def vit_for_image_classification(
 
     __Requirements:__
 
-    - `pip install transformers`
+    - `pip install torch transformers`
 
     Args:
         image: The image to classify.
@@ -307,30 +330,30 @@ def vit_for_image_classification(
         top_k: The number of classes to return.
 
     Returns:
-        A
-        in the following format:
+        A dictionary containing the output of the image classification model, in the following format:
 
-
-
-
-
-
-
-
+        ```python
+        {
+            'scores': [0.325, 0.198, 0.105],  # list of probabilities of the top-k most likely classes
+            'labels': [340, 353, 386],  # list of class IDs for the top-k most likely classes
+            'label_text': ['zebra', 'gazelle', 'African elephant, Loxodonta africana'],
+                # corresponding text names of the top-k most likely classes
+        ```
 
     Examples:
         Add a computed column that applies the model `google/vit-base-patch16-224` to an existing
-        Pixeltable column `image` of the table `tbl
+        Pixeltable column `image` of the table `tbl`, returning the 10 most likely classes for each image:
 
         >>> tbl['image_class'] = vit_for_image_classification(
         ...     tbl.image,
-        ...     model_id='google/vit-base-patch16-224'
+        ...     model_id='google/vit-base-patch16-224',
+        ...     top_k=10
         ... )
     """
     env.Env.get().require_package('transformers')
     device = resolve_torch_device('auto')
     import torch
-    from transformers import
+    from transformers import ViTForImageClassification, ViTImageProcessor
 
     model: ViTForImageClassification = _lookup_model(model_id, ViTForImageClassification.from_pretrained, device=device)
     processor = _lookup_processor(model_id, ViTImageProcessor.from_pretrained)
@@ -344,19 +367,98 @@ def vit_for_image_classification(
     probs = torch.softmax(logits, dim=-1)
     top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
 
+    # There is no official post_process method for ViT models; for consistency, we structure the output
+    # the same way as the output of the DETR model given by `post_process_object_detection`.
     return [
-
-
-
-
-
-        }
-        for k in range(top_k_probs.shape[1])
-    ]
+        {
+            'scores': [top_k_probs[n, k].item() for k in range(top_k_probs.shape[1])],
+            'labels': [top_k_indices[n, k].item() for k in range(top_k_probs.shape[1])],
+            'label_text': [model.config.id2label[top_k_indices[n, k].item()] for k in range(top_k_probs.shape[1])],
+        }
         for n in range(top_k_probs.shape[0])
     ]
 
 
+@pxt.udf
+def speech2text_for_conditional_generation(
+    audio: pxt.Audio,
+    *,
+    model_id: str,
+    language: Optional[str] = None,
+) -> str:
+    """
+    Transcribes or translates speech to text using a Speech2Text model. `model_id` should be a reference to a
+    pretrained [Speech2Text](https://huggingface.co/docs/transformers/en/model_doc/speech_to_text) model.
+
+    __Requirements:__
+
+    - `pip install torch torchaudio sentencepiece transformers`
+
+    Args:
+        audio: The audio clip to transcribe or translate.
+        model_id: The pretrained model to use for the transcription or translation.
+        language: If using a multilingual translation model, the language code to translate to. If not provided,
+            the model's default language will be used. If the model is not translation model, is not a
+            multilingual model, or does not support the specified language, an error will be raised.
+
+    Returns:
+        The transcribed or translated text.
+
+    Examples:
+        Add a computed column that applies the model `facebook/s2t-small-librispeech-asr` to an existing
+        Pixeltable column `audio` of the table `tbl`:
+
+        >>> tbl['transcription'] = speech2text_for_conditional_generation(
+        ...     tbl.audio,
+        ...     model_id='facebook/s2t-small-librispeech-asr'
+        ... )
+
+        Add a computed column that applies the model `facebook/s2t-medium-mustc-multilingual-st` to an existing
+        Pixeltable column `audio` of the table `tbl`, translating the audio to French:
+
+        >>> tbl['translation'] = speech2text_for_conditional_generation(
+        ...     tbl.audio,
+        ...     model_id='facebook/s2t-medium-mustc-multilingual-st',
+        ...     language='fr'
+        ... )
+    """
+    env.Env.get().require_package('transformers')
+    env.Env.get().require_package('torchaudio')
+    env.Env.get().require_package('sentencepiece')
+    device = resolve_torch_device('auto', allow_mps=False)  # Doesn't seem to work on 'mps'; use 'cpu' instead
+    import librosa
+    import torch
+    from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
+
+    # facebook/s2t-small-librispeech-asr
+    # facebook/s2t-small-mustc-en-fr-st
+    model = _lookup_model(model_id, Speech2TextForConditionalGeneration.from_pretrained, device=device)
+    processor = _lookup_processor(model_id, Speech2TextProcessor.from_pretrained)
+    assert isinstance(processor, Speech2TextProcessor)
+
+    if language is not None and language not in processor.tokenizer.lang_code_to_id:
+        raise excs.Error(
+            f"Language code '{language}' is not supported by the model '{model_id}'. "
+            f"Supported languages are: {list(processor.tokenizer.lang_code_to_id.keys())}")
+
+    forced_bos_token_id: Optional[int] = None if language is None else processor.tokenizer.lang_code_to_id[language]
+
+    # Get the model's sampling rate. Default to 16 kHz (the standard) if not in config
+    model_sampling_rate = getattr(model.config, 'sampling_rate', 16_000)
+    waveform, sampling_rate = librosa.load(audio, sr=model_sampling_rate, mono=True)
+
+    with torch.no_grad():
+        inputs = processor(
+            waveform,
+            sampling_rate=sampling_rate,
+            return_tensors='pt'
+        )
+        generated_ids = model.generate(**inputs.to(device), forced_bos_token_id=forced_bos_token_id).to('cpu')
+
+    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
+    return transcription
+
+
 @pxt.udf
 def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str, Any]:
     """
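Note: the final hunk below reworks the model cache behind these UDFs. A simplified sketch of the contract it implements (names trimmed; the real version also calls `model.eval()` on `nn.Module`s): constructors such as `SentenceTransformer` accept `device=` directly and are invoked with `pass_device_to_create=True`, while `from_pretrained`-style callables are created first and then moved with `.to(device)`:

```python
from typing import Any, Callable, Optional

_cache: dict[tuple, Any] = {}

def lookup(model_id: str, create: Callable[..., Any],
           device: Optional[str] = None, pass_device_to_create: bool = False) -> Any:
    key = (model_id, create, device)  # one cached instance per (model, constructor, device)
    if key not in _cache:
        if pass_device_to_create:
            model = create(model_id, device=device)  # e.g. SentenceTransformer(model_id, device=...)
        else:
            model = create(model_id)                 # e.g. SomeModel.from_pretrained(model_id)
            if device is not None and hasattr(model, 'to'):
                model.to(device)                     # move torch modules after creation
        _cache[key] = model
    return _cache[key]
```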
@@ -386,14 +488,22 @@ def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str,
 T = TypeVar('T')
 
 
-def _lookup_model(
+def _lookup_model(
+    model_id: str,
+    create: Callable[..., T],
+    device: Optional[str] = None,
+    pass_device_to_create: bool = False
+) -> T:
     from torch import nn
 
     key = (model_id, create, device)  # For safety, include the `create` callable in the cache key
     if key not in _model_cache:
-
+        if pass_device_to_create:
+            model = create(model_id, device=device)
+        else:
+            model = create(model_id)
         if isinstance(model, nn.Module):
-            if device is not None:
+            if not pass_device_to_create and device is not None:
                 model.to(device)
             model.eval()
         _model_cache[key] = model