pixeltable-0.2.21-py3-none-any.whl → pixeltable-0.2.22-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic; see the registry's page for this release for more details.
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/__init__.py +1 -1
- pixeltable/catalog/column.py +37 -11
- pixeltable/catalog/globals.py +18 -0
- pixeltable/catalog/insertable_table.py +6 -4
- pixeltable/catalog/table.py +19 -3
- pixeltable/catalog/table_version.py +34 -14
- pixeltable/catalog/view.py +16 -17
- pixeltable/dataframe.py +7 -8
- pixeltable/env.py +5 -0
- pixeltable/exec/__init__.py +0 -1
- pixeltable/exec/aggregation_node.py +6 -3
- pixeltable/exec/cache_prefetch_node.py +1 -1
- pixeltable/exec/data_row_batch.py +2 -19
- pixeltable/exec/exec_node.py +2 -1
- pixeltable/exec/expr_eval_node.py +17 -10
- pixeltable/exec/in_memory_data_node.py +6 -3
- pixeltable/exec/sql_node.py +24 -25
- pixeltable/exprs/arithmetic_expr.py +3 -1
- pixeltable/exprs/array_slice.py +7 -7
- pixeltable/exprs/column_property_ref.py +37 -10
- pixeltable/exprs/column_ref.py +93 -14
- pixeltable/exprs/comparison.py +5 -5
- pixeltable/exprs/compound_predicate.py +8 -7
- pixeltable/exprs/data_row.py +27 -18
- pixeltable/exprs/expr.py +53 -52
- pixeltable/exprs/expr_set.py +5 -0
- pixeltable/exprs/function_call.py +32 -16
- pixeltable/exprs/globals.py +4 -1
- pixeltable/exprs/in_predicate.py +8 -7
- pixeltable/exprs/inline_expr.py +4 -4
- pixeltable/exprs/is_null.py +4 -4
- pixeltable/exprs/json_mapper.py +11 -12
- pixeltable/exprs/json_path.py +5 -10
- pixeltable/exprs/literal.py +5 -5
- pixeltable/exprs/method_ref.py +5 -4
- pixeltable/exprs/object_ref.py +2 -1
- pixeltable/exprs/row_builder.py +88 -36
- pixeltable/exprs/rowid_ref.py +12 -11
- pixeltable/exprs/similarity_expr.py +12 -7
- pixeltable/exprs/sql_element_cache.py +7 -5
- pixeltable/exprs/type_cast.py +8 -6
- pixeltable/exprs/variable.py +5 -4
- pixeltable/func/aggregate_function.py +1 -1
- pixeltable/func/function.py +11 -10
- pixeltable/functions/__init__.py +2 -2
- pixeltable/functions/globals.py +5 -7
- pixeltable/functions/huggingface.py +19 -20
- pixeltable/functions/llama_cpp.py +106 -0
- pixeltable/functions/ollama.py +147 -0
- pixeltable/functions/replicate.py +72 -0
- pixeltable/functions/string.py +9 -0
- pixeltable/globals.py +12 -20
- pixeltable/index/btree.py +16 -3
- pixeltable/index/embedding_index.py +4 -4
- pixeltable/io/__init__.py +1 -2
- pixeltable/io/fiftyone.py +178 -0
- pixeltable/io/globals.py +96 -2
- pixeltable/iterators/base.py +3 -2
- pixeltable/iterators/document.py +1 -1
- pixeltable/iterators/video.py +120 -63
- pixeltable/metadata/__init__.py +1 -1
- pixeltable/metadata/converters/convert_21.py +34 -0
- pixeltable/metadata/converters/util.py +45 -4
- pixeltable/metadata/notes.py +1 -0
- pixeltable/metadata/schema.py +8 -0
- pixeltable/plan.py +16 -14
- pixeltable/py.typed +0 -0
- pixeltable/store.py +7 -2
- pixeltable/tool/create_test_video.py +1 -1
- pixeltable/tool/embed_udf.py +1 -1
- pixeltable/tool/mypy_plugin.py +28 -5
- pixeltable/type_system.py +17 -1
- pixeltable/utils/documents.py +15 -1
- pixeltable/utils/formatter.py +9 -10
- {pixeltable-0.2.21.dist-info → pixeltable-0.2.22.dist-info}/METADATA +46 -10
- pixeltable-0.2.22.dist-info/RECORD +153 -0
- pixeltable/exec/media_validation_node.py +0 -43
- pixeltable-0.2.21.dist-info/RECORD +0 -148
- {pixeltable-0.2.21.dist-info → pixeltable-0.2.22.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.21.dist-info → pixeltable-0.2.22.dist-info}/WHEEL +0 -0
- {pixeltable-0.2.21.dist-info → pixeltable-0.2.22.dist-info}/entry_points.txt +0 -0
|
@@ -1,16 +1,16 @@
|
|
|
1
|
-
from typing import
|
|
2
|
-
from .sql_element_cache import SqlElementCache
|
|
1
|
+
from typing import Any, Optional
|
|
3
2
|
|
|
4
3
|
import sqlalchemy as sql
|
|
5
|
-
import PIL.Image
|
|
6
4
|
|
|
7
5
|
import pixeltable.exceptions as excs
|
|
8
6
|
import pixeltable.type_system as ts
|
|
7
|
+
|
|
9
8
|
from .column_ref import ColumnRef
|
|
10
9
|
from .data_row import DataRow
|
|
11
10
|
from .expr import Expr
|
|
12
11
|
from .literal import Literal
|
|
13
12
|
from .row_builder import RowBuilder
|
|
13
|
+
from .sql_element_cache import SqlElementCache
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
class SimilarityExpr(Expr):
|
|
@@ -27,7 +27,7 @@ class SimilarityExpr(Expr):
|
|
|
27
27
|
|
|
28
28
|
# determine index to use
|
|
29
29
|
idx_info = col_ref.col.get_idx_info()
|
|
30
|
-
|
|
30
|
+
from pixeltable import index
|
|
31
31
|
embedding_idx_info = {
|
|
32
32
|
info.name: info for info in idx_info.values() if isinstance(info.idx, index.EmbeddingIndex)
|
|
33
33
|
}
|
|
@@ -44,6 +44,7 @@ class SimilarityExpr(Expr):
|
|
|
44
44
|
else:
|
|
45
45
|
self.idx_info = next(iter(embedding_idx_info.values()))
|
|
46
46
|
idx = self.idx_info.idx
|
|
47
|
+
assert isinstance(idx, index.EmbeddingIndex)
|
|
47
48
|
|
|
48
49
|
if item_expr.col_type.is_string_type() and idx.string_embed is None:
|
|
49
50
|
raise excs.Error(
|
|
@@ -57,16 +58,20 @@ class SimilarityExpr(Expr):
|
|
|
57
58
|
def __str__(self) -> str:
|
|
58
59
|
return f'{self.components[0]}.similarity({self.components[1]})'
|
|
59
60
|
|
|
60
|
-
def sql_expr(self, _: SqlElementCache) -> Optional[sql.
|
|
61
|
+
def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
|
|
61
62
|
if not isinstance(self.components[1], Literal):
|
|
62
63
|
raise excs.Error(f'similarity(): requires a string or a PIL.Image.Image object, not an expression')
|
|
63
64
|
item = self.components[1].val
|
|
65
|
+
from pixeltable import index
|
|
66
|
+
assert isinstance(self.idx_info.idx, index.EmbeddingIndex)
|
|
64
67
|
return self.idx_info.idx.similarity_clause(self.idx_info.val_col, item)
|
|
65
68
|
|
|
66
|
-
def as_order_by_clause(self, is_asc: bool) -> Optional[sql.
|
|
69
|
+
def as_order_by_clause(self, is_asc: bool) -> Optional[sql.ColumnElement]:
|
|
67
70
|
if not isinstance(self.components[1], Literal):
|
|
68
71
|
raise excs.Error(f'similarity(): requires a string or a PIL.Image.Image object, not an expression')
|
|
69
72
|
item = self.components[1].val
|
|
73
|
+
from pixeltable import index
|
|
74
|
+
assert isinstance(self.idx_info.idx, index.EmbeddingIndex)
|
|
70
75
|
return self.idx_info.idx.order_by_clause(self.idx_info.val_col, item, is_asc)
|
|
71
76
|
|
|
72
77
|
def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
|
|
@@ -74,7 +79,7 @@ class SimilarityExpr(Expr):
|
|
|
74
79
|
assert False
|
|
75
80
|
|
|
76
81
|
@classmethod
|
|
77
|
-
def _from_dict(cls, d: dict, components:
|
|
82
|
+
def _from_dict(cls, d: dict, components: list[Expr]) -> 'SimilarityExpr':
|
|
78
83
|
assert len(components) == 2
|
|
79
84
|
assert isinstance(components[0], ColumnRef)
|
|
80
85
|
return cls(components[0], components[1])
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Iterable, Union, Optional
|
|
1
|
+
from typing import Iterable, Union, Optional, cast
|
|
2
2
|
|
|
3
3
|
import sqlalchemy as sql
|
|
4
4
|
|
|
@@ -27,8 +27,10 @@ class SqlElementCache:
|
|
|
27
27
|
self.cache[e.id] = el
|
|
28
28
|
return el
|
|
29
29
|
|
|
30
|
-
def contains(self,
|
|
31
|
-
"""Returns True if
|
|
32
|
-
|
|
33
|
-
|
|
30
|
+
def contains(self, item: Expr) -> bool:
|
|
31
|
+
"""Returns True if the cache contains a (non-None) value for the given Expr."""
|
|
32
|
+
return self.get(item) is not None
|
|
33
|
+
|
|
34
|
+
def contains_all(self, items: Iterable[Expr]) -> bool:
|
|
35
|
+
"""Returns True if the cache contains a (non-None) value for every item in the collection of Exprs."""
|
|
34
36
|
return all(self.get(e) is not None for e in items)
|
pixeltable/exprs/type_cast.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
|
-
from typing import
|
|
1
|
+
from typing import Any, Optional
|
|
2
2
|
|
|
3
3
|
import sqlalchemy as sql
|
|
4
4
|
|
|
5
5
|
import pixeltable.type_system as ts
|
|
6
|
+
|
|
6
7
|
from .expr import DataRow, Expr
|
|
7
8
|
from .row_builder import RowBuilder
|
|
8
9
|
from .sql_element_cache import SqlElementCache
|
|
@@ -15,7 +16,7 @@ class TypeCast(Expr):
|
|
|
15
16
|
"""
|
|
16
17
|
def __init__(self, underlying: Expr, new_type: ts.ColumnType):
|
|
17
18
|
super().__init__(new_type)
|
|
18
|
-
self.components:
|
|
19
|
+
self.components: list[Expr] = [underlying]
|
|
19
20
|
self.id: Optional[int] = self._create_id()
|
|
20
21
|
|
|
21
22
|
@property
|
|
@@ -26,10 +27,10 @@ class TypeCast(Expr):
|
|
|
26
27
|
# `TypeCast` has no properties beyond those captured by `Expr`.
|
|
27
28
|
return True
|
|
28
29
|
|
|
29
|
-
def _id_attrs(self) ->
|
|
30
|
+
def _id_attrs(self) -> list[tuple[str, Any]]:
|
|
30
31
|
return super()._id_attrs() + [('new_type', self.col_type)]
|
|
31
32
|
|
|
32
|
-
def sql_expr(self, _: SqlElementCache) -> Optional[sql.
|
|
33
|
+
def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
|
|
33
34
|
"""
|
|
34
35
|
sql_expr() is unimplemented for now, in order to sidestep potentially thorny
|
|
35
36
|
questions about consistency of doing type conversions in both Python and Postgres.
|
|
@@ -40,11 +41,12 @@ class TypeCast(Expr):
|
|
|
40
41
|
original_val = data_row[self._underlying.slot_idx]
|
|
41
42
|
data_row[self.slot_idx] = self.col_type.create_literal(original_val)
|
|
42
43
|
|
|
43
|
-
|
|
44
|
+
|
|
45
|
+
def _as_dict(self) -> dict:
|
|
44
46
|
return {'new_type': self.col_type.as_dict(), **super()._as_dict()}
|
|
45
47
|
|
|
46
48
|
@classmethod
|
|
47
|
-
def _from_dict(cls, d:
|
|
49
|
+
def _from_dict(cls, d: dict, components: list[Expr]) -> 'TypeCast':
|
|
48
50
|
assert 'new_type' in d
|
|
49
51
|
assert len(components) == 1
|
|
50
52
|
return cls(components[0], ts.ColumnType.from_dict(d['new_type']))
|
pixeltable/exprs/variable.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from typing import
|
|
3
|
+
from typing import Any, NoReturn
|
|
4
4
|
|
|
5
5
|
import pixeltable.type_system as ts
|
|
6
|
+
|
|
6
7
|
from .data_row import DataRow
|
|
7
8
|
from .expr import Expr
|
|
8
9
|
from .row_builder import RowBuilder
|
|
@@ -20,7 +21,7 @@ class Variable(Expr):
|
|
|
20
21
|
self.name = name
|
|
21
22
|
self.id = self._create_id()
|
|
22
23
|
|
|
23
|
-
def _id_attrs(self) ->
|
|
24
|
+
def _id_attrs(self) -> list[tuple[str, Any]]:
|
|
24
25
|
return super()._id_attrs() + [('name', self.name)]
|
|
25
26
|
|
|
26
27
|
def default_column_name(self) -> NoReturn:
|
|
@@ -38,9 +39,9 @@ class Variable(Expr):
|
|
|
38
39
|
def eval(self, data_row: DataRow, row_builder: RowBuilder) -> NoReturn:
|
|
39
40
|
raise NotImplementedError()
|
|
40
41
|
|
|
41
|
-
def _as_dict(self) ->
|
|
42
|
+
def _as_dict(self) -> dict:
|
|
42
43
|
return {'name': self.name, 'type': self.col_type.as_dict(), **super()._as_dict()}
|
|
43
44
|
|
|
44
45
|
@classmethod
|
|
45
|
-
def _from_dict(cls, d:
|
|
46
|
+
def _from_dict(cls, d: dict, _: list[Expr]) -> Variable:
|
|
46
47
|
return cls(d['name'], ts.ColumnType.from_dict(d['type']))
|
|
@@ -86,7 +86,7 @@ class AggregateFunction(Function):
|
|
|
86
86
|
res += '\n\n' + inspect.getdoc(self.agg_cls.update)
|
|
87
87
|
return res
|
|
88
88
|
|
|
89
|
-
def __call__(self, *args: object, **kwargs: object) -> 'pixeltable.exprs.
|
|
89
|
+
def __call__(self, *args: object, **kwargs: object) -> 'pixeltable.exprs.FunctionCall':
|
|
90
90
|
from pixeltable import exprs
|
|
91
91
|
|
|
92
92
|
# perform semantic analysis of special parameters 'order_by' and 'group_by'
|
pixeltable/func/function.py
CHANGED
|
@@ -3,12 +3,13 @@ from __future__ import annotations
|
|
|
3
3
|
import abc
|
|
4
4
|
import importlib
|
|
5
5
|
import inspect
|
|
6
|
-
from typing import Any, Callable,
|
|
6
|
+
from typing import Any, Callable, Optional
|
|
7
7
|
|
|
8
8
|
import sqlalchemy as sql
|
|
9
9
|
|
|
10
|
-
import pixeltable
|
|
10
|
+
import pixeltable as pxt
|
|
11
11
|
import pixeltable.type_system as ts
|
|
12
|
+
|
|
12
13
|
from .globals import resolve_symbol
|
|
13
14
|
from .signature import Signature
|
|
14
15
|
|
|
@@ -66,13 +67,13 @@ class Function(abc.ABC):
|
|
|
66
67
|
def help_str(self) -> str:
|
|
67
68
|
return self.display_name + str(self.signature)
|
|
68
69
|
|
|
69
|
-
def __call__(self, *args: Any, **kwargs: Any) -> '
|
|
70
|
+
def __call__(self, *args: Any, **kwargs: Any) -> 'pxt.exprs.FunctionCall':
|
|
70
71
|
from pixeltable import exprs
|
|
71
72
|
bound_args = self.signature.py_signature.bind(*args, **kwargs)
|
|
72
73
|
self.validate_call(bound_args.arguments)
|
|
73
74
|
return exprs.FunctionCall(self, bound_args.arguments)
|
|
74
75
|
|
|
75
|
-
def validate_call(self, bound_args:
|
|
76
|
+
def validate_call(self, bound_args: dict[str, Any]) -> None:
|
|
76
77
|
"""Override this to do custom validation of the arguments"""
|
|
77
78
|
pass
|
|
78
79
|
|
|
@@ -121,7 +122,7 @@ class Function(abc.ABC):
|
|
|
121
122
|
"""Print source code"""
|
|
122
123
|
print('source not available')
|
|
123
124
|
|
|
124
|
-
def as_dict(self) ->
|
|
125
|
+
def as_dict(self) -> dict:
|
|
125
126
|
"""
|
|
126
127
|
Return a serialized reference to the instance that can be passed to json.dumps() and converted back
|
|
127
128
|
to an instance with from_dict().
|
|
@@ -130,13 +131,13 @@ class Function(abc.ABC):
|
|
|
130
131
|
classpath = f'{self.__class__.__module__}.{self.__class__.__qualname__}'
|
|
131
132
|
return {'_classpath': classpath, **self._as_dict()}
|
|
132
133
|
|
|
133
|
-
def _as_dict(self) ->
|
|
134
|
+
def _as_dict(self) -> dict:
|
|
134
135
|
"""Default serialization: store the path to self (which includes the module path)"""
|
|
135
136
|
assert self.self_path is not None
|
|
136
137
|
return {'path': self.self_path}
|
|
137
138
|
|
|
138
139
|
@classmethod
|
|
139
|
-
def from_dict(cls, d:
|
|
140
|
+
def from_dict(cls, d: dict) -> Function:
|
|
140
141
|
"""
|
|
141
142
|
Turn dict that was produced by calling as_dict() into an instance of the correct Function subclass.
|
|
142
143
|
"""
|
|
@@ -147,14 +148,14 @@ class Function(abc.ABC):
|
|
|
147
148
|
return func_class._from_dict(d)
|
|
148
149
|
|
|
149
150
|
@classmethod
|
|
150
|
-
def _from_dict(cls, d:
|
|
151
|
+
def _from_dict(cls, d: dict) -> Function:
|
|
151
152
|
"""Default deserialization: load the symbol indicated by the stored symbol_path"""
|
|
152
153
|
assert 'path' in d and d['path'] is not None
|
|
153
154
|
instance = resolve_symbol(d['path'])
|
|
154
155
|
assert isinstance(instance, Function)
|
|
155
156
|
return instance
|
|
156
157
|
|
|
157
|
-
def to_store(self) ->
|
|
158
|
+
def to_store(self) -> tuple[dict, bytes]:
|
|
158
159
|
"""
|
|
159
160
|
Serialize the function to a format that can be stored in the Pixeltable store
|
|
160
161
|
Returns:
|
|
@@ -165,7 +166,7 @@ class Function(abc.ABC):
|
|
|
165
166
|
raise NotImplementedError()
|
|
166
167
|
|
|
167
168
|
@classmethod
|
|
168
|
-
def from_store(cls, name: Optional[str], md:
|
|
169
|
+
def from_store(cls, name: Optional[str], md: dict, binary_obj: bytes) -> Function:
|
|
169
170
|
"""
|
|
170
171
|
Create a Function instance from the serialized representation returned by to_store()
|
|
171
172
|
"""
|
pixeltable/functions/__init__.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from pixeltable.utils.code import local_public_names
|
|
2
2
|
|
|
3
|
-
from . import (anthropic, audio, fireworks, huggingface, image, json,
|
|
4
|
-
video, vision, whisper)
|
|
3
|
+
from . import (anthropic, audio, fireworks, huggingface, image, json, llama_cpp, mistralai, ollama, openai, string,
|
|
4
|
+
timestamp, together, video, vision, whisper)
|
|
5
5
|
from .globals import *
|
|
6
6
|
|
|
7
7
|
__all__ = local_public_names(__name__, exclude=['globals']) + local_public_names(globals.__name__)
|
pixeltable/functions/globals.py
CHANGED
|
@@ -36,9 +36,7 @@ class sum(func.Aggregator):
|
|
|
36
36
|
return self.sum
|
|
37
37
|
|
|
38
38
|
|
|
39
|
-
|
|
40
|
-
# TODO: find a way to have this type-checked
|
|
41
|
-
@sum.to_sql # type: ignore
|
|
39
|
+
@sum.to_sql
|
|
42
40
|
def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
|
|
43
41
|
# This can produce a Decimal. We are deliberately avoiding an explicit cast to a Bigint here, because that can
|
|
44
42
|
# cause overflows in Postgres. We're instead doing the conversion to the target type in SqlNode.__iter__().
|
|
@@ -58,7 +56,7 @@ class count(func.Aggregator):
|
|
|
58
56
|
return self.count
|
|
59
57
|
|
|
60
58
|
|
|
61
|
-
@count.to_sql
|
|
59
|
+
@count.to_sql
|
|
62
60
|
def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
|
|
63
61
|
return sql.sql.func.count(val)
|
|
64
62
|
|
|
@@ -82,7 +80,7 @@ class min(func.Aggregator):
|
|
|
82
80
|
return self.val
|
|
83
81
|
|
|
84
82
|
|
|
85
|
-
@min.to_sql
|
|
83
|
+
@min.to_sql
|
|
86
84
|
def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
|
|
87
85
|
return sql.sql.func.min(val)
|
|
88
86
|
|
|
@@ -106,7 +104,7 @@ class max(func.Aggregator):
|
|
|
106
104
|
return self.val
|
|
107
105
|
|
|
108
106
|
|
|
109
|
-
@max.to_sql
|
|
107
|
+
@max.to_sql
|
|
110
108
|
def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
|
|
111
109
|
return sql.sql.func.max(val)
|
|
112
110
|
|
|
@@ -134,7 +132,7 @@ class mean(func.Aggregator):
|
|
|
134
132
|
return self.sum / self.count
|
|
135
133
|
|
|
136
134
|
|
|
137
|
-
@mean.to_sql
|
|
135
|
+
@mean.to_sql
|
|
138
136
|
def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
|
|
139
137
|
return sql.sql.func.avg(val)
|
|
140
138
|
|
|
@@ -286,7 +286,7 @@ def vit_for_image_classification(
|
|
|
286
286
|
*,
|
|
287
287
|
model_id: str,
|
|
288
288
|
top_k: int = 5
|
|
289
|
-
) -> Batch[
|
|
289
|
+
) -> Batch[dict[str, Any]]:
|
|
290
290
|
"""
|
|
291
291
|
Computes image classifications for the specified image using a Vision Transformer (ViT) model.
|
|
292
292
|
`model_id` should be a reference to a pretrained [ViT Model](https://huggingface.co/docs/transformers/en/model_doc/vit).
|
|
@@ -307,24 +307,24 @@ def vit_for_image_classification(
|
|
|
307
307
|
top_k: The number of classes to return.
|
|
308
308
|
|
|
309
309
|
Returns:
|
|
310
|
-
A
|
|
311
|
-
in the following format:
|
|
310
|
+
A dictionary containing the output of the image classification model, in the following format:
|
|
312
311
|
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
312
|
+
```python
|
|
313
|
+
{
|
|
314
|
+
'scores': [0.325, 0.198, 0.105], # list of probabilities of the top-k most likely classes
|
|
315
|
+
'labels': [340, 353, 386], # list of class IDs for the top-k most likely classes
|
|
316
|
+
'label_text': ['zebra', 'gazelle', 'African elephant, Loxodonta africana'],
|
|
317
|
+
# corresponding text names of the top-k most likely classes
|
|
318
|
+
```
|
|
320
319
|
|
|
321
320
|
Examples:
|
|
322
321
|
Add a computed column that applies the model `google/vit-base-patch16-224` to an existing
|
|
323
|
-
Pixeltable column `image` of the table `tbl
|
|
322
|
+
Pixeltable column `image` of the table `tbl`, returning the 10 most likely classes for each image:
|
|
324
323
|
|
|
325
324
|
>>> tbl['image_class'] = vit_for_image_classification(
|
|
326
325
|
... tbl.image,
|
|
327
|
-
... model_id='google/vit-base-patch16-224'
|
|
326
|
+
... model_id='google/vit-base-patch16-224',
|
|
327
|
+
... top_k=10
|
|
328
328
|
... )
|
|
329
329
|
"""
|
|
330
330
|
env.Env.get().require_package('transformers')
|
|
@@ -344,15 +344,14 @@ def vit_for_image_classification(
|
|
|
344
344
|
probs = torch.softmax(logits, dim=-1)
|
|
345
345
|
top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
|
|
346
346
|
|
|
347
|
+
# There is no official post_process method for ViT models; for consistency, we structure the output
|
|
348
|
+
# the same way as the output of the DETR model given by `post_process_object_detection`.
|
|
347
349
|
return [
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
}
|
|
354
|
-
for k in range(top_k_probs.shape[1])
|
|
355
|
-
]
|
|
350
|
+
{
|
|
351
|
+
'scores': [top_k_probs[n, k].item() for k in range(top_k_probs.shape[1])],
|
|
352
|
+
'labels': [top_k_indices[n, k].item() for k in range(top_k_probs.shape[1])],
|
|
353
|
+
'label_text': [model.config.id2label[top_k_indices[n, k].item()] for k in range(top_k_probs.shape[1])],
|
|
354
|
+
}
|
|
356
355
|
for n in range(top_k_probs.shape[0])
|
|
357
356
|
]
|
|
358
357
|
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
3
|
+
|
|
4
|
+
import pixeltable as pxt
|
|
5
|
+
import pixeltable.exceptions as excs
|
|
6
|
+
from pixeltable.env import Env
|
|
7
|
+
from pixeltable.utils.code import local_public_names
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
import llama_cpp
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@pxt.udf
|
|
14
|
+
def create_chat_completion(
|
|
15
|
+
messages: list[dict],
|
|
16
|
+
*,
|
|
17
|
+
model_path: Optional[str] = None,
|
|
18
|
+
repo_id: Optional[str] = None,
|
|
19
|
+
repo_filename: Optional[str] = None,
|
|
20
|
+
args: Optional[dict[str, Any]] = None,
|
|
21
|
+
) -> dict:
|
|
22
|
+
"""
|
|
23
|
+
Generate a chat completion from a list of messages.
|
|
24
|
+
|
|
25
|
+
The model can be specified either as a local path, or as a repo_id and repo_filename that reference a pretrained
|
|
26
|
+
model on the Hugging Face model hub. Exactly one of `model_path` or `repo_id` must be provided; if `model_path`
|
|
27
|
+
is provided, then an optional `repo_filename` can also be specified.
|
|
28
|
+
|
|
29
|
+
For additional details, see the
|
|
30
|
+
[llama_cpp create_chat_completions documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion).
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
messages: A list of messages to generate a response for.
|
|
34
|
+
model_path: Path to the model (if using a local model).
|
|
35
|
+
repo_id: The Hugging Face model repo id (if using a pretrained model).
|
|
36
|
+
repo_filename: A filename or glob pattern to match the model file in the repo (optional, if using a
|
|
37
|
+
pretrained model).
|
|
38
|
+
args: Additional arguments to pass to the `create_chat_completions` call, such as `max_tokens`, `temperature`,
|
|
39
|
+
`top_p`, and `top_k`. For details, see the
|
|
40
|
+
[llama_cpp create_chat_completions documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion).
|
|
41
|
+
"""
|
|
42
|
+
Env.get().require_package('llama_cpp', min_version=[0, 3, 1])
|
|
43
|
+
|
|
44
|
+
if args is None:
|
|
45
|
+
args = {}
|
|
46
|
+
|
|
47
|
+
if (model_path is None) == (repo_id is None):
|
|
48
|
+
raise excs.Error('Exactly one of `model_path` or `repo_id` must be provided.')
|
|
49
|
+
if (repo_id is None) and (repo_filename is not None):
|
|
50
|
+
raise excs.Error('`repo_filename` can only be provided along with `repo_id`.')
|
|
51
|
+
|
|
52
|
+
n_gpu_layers = -1 if _is_gpu_available() else 0 # 0 = CPU only, -1 = offload all layers to GPU
|
|
53
|
+
|
|
54
|
+
if model_path is not None:
|
|
55
|
+
llm = _lookup_local_model(model_path, n_gpu_layers)
|
|
56
|
+
else:
|
|
57
|
+
Env.get().require_package('huggingface_hub')
|
|
58
|
+
llm = _lookup_pretrained_model(repo_id, repo_filename, n_gpu_layers)
|
|
59
|
+
return llm.create_chat_completion(messages, **args) # type: ignore
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _is_gpu_available() -> bool:
|
|
63
|
+
import llama_cpp
|
|
64
|
+
|
|
65
|
+
global _IS_GPU_AVAILABLE
|
|
66
|
+
if _IS_GPU_AVAILABLE is None:
|
|
67
|
+
llama_cpp_path = Path(llama_cpp.__file__).parent
|
|
68
|
+
lib = llama_cpp.llama_cpp.load_shared_library('llama', llama_cpp_path / 'lib')
|
|
69
|
+
_IS_GPU_AVAILABLE = bool(lib.llama_supports_gpu_offload())
|
|
70
|
+
|
|
71
|
+
return _IS_GPU_AVAILABLE
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _lookup_local_model(model_path: str, n_gpu_layers: int) -> 'llama_cpp.Llama':
|
|
75
|
+
import llama_cpp
|
|
76
|
+
|
|
77
|
+
key = (model_path, None, n_gpu_layers)
|
|
78
|
+
if key not in _model_cache:
|
|
79
|
+
llm = llama_cpp.Llama(model_path, n_gpu_layers=n_gpu_layers)
|
|
80
|
+
_model_cache[key] = llm
|
|
81
|
+
return _model_cache[key]
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _lookup_pretrained_model(repo_id: str, filename: Optional[str], n_gpu_layers: int) -> 'llama_cpp.Llama':
|
|
85
|
+
import llama_cpp
|
|
86
|
+
|
|
87
|
+
key = (repo_id, filename, n_gpu_layers)
|
|
88
|
+
if key not in _model_cache:
|
|
89
|
+
llm = llama_cpp.Llama.from_pretrained(
|
|
90
|
+
repo_id=repo_id,
|
|
91
|
+
filename=filename,
|
|
92
|
+
n_gpu_layers=n_gpu_layers
|
|
93
|
+
)
|
|
94
|
+
_model_cache[key] = llm
|
|
95
|
+
return _model_cache[key]
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
_model_cache: dict[tuple[str, str, int], Any] = {}
|
|
99
|
+
_IS_GPU_AVAILABLE: Optional[bool] = None
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
__all__ = local_public_names(__name__)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def __dir__():
|
|
106
|
+
return __all__
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Optional
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
import pixeltable as pxt
|
|
6
|
+
from pixeltable import env
|
|
7
|
+
from pixeltable.func import Batch
|
|
8
|
+
from pixeltable.utils.code import local_public_names
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import ollama
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@env.register_client('ollama')
|
|
15
|
+
def _(host: str) -> 'ollama.Client':
|
|
16
|
+
import ollama
|
|
17
|
+
return ollama.Client(host=host)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _ollama_client() -> Optional['ollama.Client']:
|
|
21
|
+
try:
|
|
22
|
+
return env.Env.get().get_client('ollama')
|
|
23
|
+
except Exception:
|
|
24
|
+
return None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@pxt.udf
|
|
28
|
+
def generate(
|
|
29
|
+
prompt: str,
|
|
30
|
+
*,
|
|
31
|
+
model: str,
|
|
32
|
+
suffix: str = '',
|
|
33
|
+
system: str = '',
|
|
34
|
+
template: str = '',
|
|
35
|
+
context: Optional[list[int]] = None,
|
|
36
|
+
raw: bool = False,
|
|
37
|
+
format: str = '',
|
|
38
|
+
options: Optional[dict] = None,
|
|
39
|
+
) -> dict:
|
|
40
|
+
"""
|
|
41
|
+
Generate a response for a given prompt with a provided model.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
prompt: The prompt to generate a response for.
|
|
45
|
+
model: The model name.
|
|
46
|
+
suffix: The text after the model response.
|
|
47
|
+
format: The format of the response; must be one of `'json'` or `''` (the empty string).
|
|
48
|
+
system: System message.
|
|
49
|
+
template: Prompt template to use.
|
|
50
|
+
context: The context parameter returned from a previous call to `generate()`.
|
|
51
|
+
raw: If `True`, no formatting will be applied to the prompt.
|
|
52
|
+
options: Additional options to pass to the `chat` call, such as `max_tokens`, `temperature`, `top_p`, and `top_k`.
|
|
53
|
+
For details, see the
|
|
54
|
+
[Valid Parameters and Values](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
|
|
55
|
+
section of the Ollama documentation.
|
|
56
|
+
"""
|
|
57
|
+
env.Env.get().require_package('ollama')
|
|
58
|
+
import ollama
|
|
59
|
+
|
|
60
|
+
client = _ollama_client() or ollama
|
|
61
|
+
return client.generate(
|
|
62
|
+
model=model,
|
|
63
|
+
prompt=prompt,
|
|
64
|
+
suffix=suffix,
|
|
65
|
+
system=system,
|
|
66
|
+
template=template,
|
|
67
|
+
context=context,
|
|
68
|
+
raw=raw,
|
|
69
|
+
format=format,
|
|
70
|
+
options=options,
|
|
71
|
+
) # type: ignore[call-overload]
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@pxt.udf
|
|
75
|
+
def chat(
|
|
76
|
+
messages: list[dict],
|
|
77
|
+
*,
|
|
78
|
+
model: str,
|
|
79
|
+
tools: Optional[list[dict]] = None,
|
|
80
|
+
format: str = '',
|
|
81
|
+
options: Optional[dict] = None,
|
|
82
|
+
) -> dict:
|
|
83
|
+
"""
|
|
84
|
+
Generate the next message in a chat with a provided model.
|
|
85
|
+
|
|
86
|
+
Args:
|
|
87
|
+
messages: The messages of the chat.
|
|
88
|
+
model: The model name.
|
|
89
|
+
tools: Tools for the model to use.
|
|
90
|
+
format: The format of the response; must be one of `'json'` or `''` (the empty string).
|
|
91
|
+
options: Additional options to pass to the `chat` call, such as `max_tokens`, `temperature`, `top_p`, and `top_k`.
|
|
92
|
+
For details, see the
|
|
93
|
+
[Valid Parameters and Values](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
|
|
94
|
+
section of the Ollama documentation.
|
|
95
|
+
"""
|
|
96
|
+
env.Env.get().require_package('ollama')
|
|
97
|
+
import ollama
|
|
98
|
+
|
|
99
|
+
client = _ollama_client() or ollama
|
|
100
|
+
return client.chat(
|
|
101
|
+
model=model,
|
|
102
|
+
messages=messages,
|
|
103
|
+
tools=tools,
|
|
104
|
+
format=format,
|
|
105
|
+
options=options,
|
|
106
|
+
) # type: ignore[call-overload]
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@pxt.udf(batch_size=16)
|
|
110
|
+
def embed(
|
|
111
|
+
input: Batch[str],
|
|
112
|
+
*,
|
|
113
|
+
model: str,
|
|
114
|
+
truncate: bool = True,
|
|
115
|
+
options: Optional[dict] = None,
|
|
116
|
+
) -> Batch[pxt.Array[(None,), pxt.Float]]:
|
|
117
|
+
"""
|
|
118
|
+
Generate embeddings from a model.
|
|
119
|
+
|
|
120
|
+
Args:
|
|
121
|
+
input: The input text to generate embeddings for.
|
|
122
|
+
model: The model name.
|
|
123
|
+
truncate: Truncates the end of each input to fit within context length.
|
|
124
|
+
Returns error if false and context length is exceeded.
|
|
125
|
+
options: Additional options to pass to the `embed` call.
|
|
126
|
+
For details, see the
|
|
127
|
+
[Valid Parameters and Values](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
|
|
128
|
+
section of the Ollama documentation.
|
|
129
|
+
"""
|
|
130
|
+
env.Env.get().require_package('ollama')
|
|
131
|
+
import ollama
|
|
132
|
+
|
|
133
|
+
client = _ollama_client() or ollama
|
|
134
|
+
results = client.embed(
|
|
135
|
+
model=model,
|
|
136
|
+
input=input,
|
|
137
|
+
truncate=truncate,
|
|
138
|
+
options=options, # type: ignore[arg-type]
|
|
139
|
+
)
|
|
140
|
+
return [np.array(data, dtype=np.float64) for data in results['embeddings']]
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
__all__ = local_public_names(__name__)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def __dir__():
|
|
147
|
+
return __all__
|