pixeltable 0.2.18__py3-none-any.whl → 0.2.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic.
- pixeltable/__init__.py +1 -1
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/insertable_table.py +9 -7
- pixeltable/catalog/table.py +18 -5
- pixeltable/catalog/table_version.py +1 -1
- pixeltable/catalog/view.py +1 -1
- pixeltable/dataframe.py +1 -1
- pixeltable/env.py +140 -40
- pixeltable/exceptions.py +12 -5
- pixeltable/exec/component_iteration_node.py +63 -42
- pixeltable/exprs/__init__.py +1 -2
- pixeltable/exprs/expr.py +5 -6
- pixeltable/exprs/function_call.py +8 -10
- pixeltable/exprs/inline_expr.py +200 -0
- pixeltable/exprs/json_path.py +3 -6
- pixeltable/ext/functions/whisperx.py +2 -0
- pixeltable/ext/functions/yolox.py +5 -3
- pixeltable/functions/huggingface.py +89 -12
- pixeltable/functions/image.py +3 -3
- pixeltable/functions/together.py +37 -16
- pixeltable/functions/vision.py +43 -21
- pixeltable/functions/whisper.py +3 -0
- pixeltable/globals.py +7 -1
- pixeltable/io/globals.py +1 -1
- pixeltable/io/hf_datasets.py +3 -3
- pixeltable/iterators/document.py +1 -1
- pixeltable/metadata/__init__.py +1 -1
- pixeltable/metadata/converters/convert_18.py +1 -1
- pixeltable/metadata/converters/convert_20.py +56 -0
- pixeltable/metadata/converters/util.py +29 -4
- pixeltable/metadata/notes.py +1 -0
- pixeltable/tool/create_test_db_dump.py +15 -4
- pixeltable/type_system.py +3 -1
- pixeltable/utils/filecache.py +126 -79
- pixeltable-0.2.20.dist-info/LICENSE +201 -0
- {pixeltable-0.2.18.dist-info → pixeltable-0.2.20.dist-info}/METADATA +16 -6
- {pixeltable-0.2.18.dist-info → pixeltable-0.2.20.dist-info}/RECORD +39 -39
- pixeltable/exprs/inline_array.py +0 -117
- pixeltable/exprs/inline_dict.py +0 -104
- pixeltable-0.2.18.dist-info/LICENSE +0 -18
- {pixeltable-0.2.18.dist-info → pixeltable-0.2.20.dist-info}/WHEEL +0 -0
- {pixeltable-0.2.18.dist-info → pixeltable-0.2.20.dist-info}/entry_points.txt +0 -0
pixeltable/exprs/function_call.py CHANGED

@@ -3,7 +3,7 @@ from __future__ import annotations
 import inspect
 import json
 import sys
-from typing import
+from typing import Any, Optional
 
 import sqlalchemy as sql
 
@@ -11,10 +11,10 @@ import pixeltable.catalog as catalog
 import pixeltable.exceptions as excs
 import pixeltable.func as func
 import pixeltable.type_system as ts
+
 from .data_row import DataRow
 from .expr import Expr
-from .inline_array import InlineArray
-from .inline_dict import InlineDict
+from .inline_expr import InlineDict, InlineList
 from .row_builder import RowBuilder
 from .rowid_ref import RowidRef
 from .sql_element_cache import SqlElementCache
@@ -53,7 +53,7 @@ class FunctionCall(Expr):
         super().__init__(fn.call_return_type(bound_args))
         self.fn = fn
         self.is_method_call = is_method_call
-        self.normalize_args(signature, bound_args)
+        self.normalize_args(fn.name, signature, bound_args)
 
         self.agg_init_args = {}
         if self.is_agg_fn_call:
@@ -143,7 +143,7 @@ class FunctionCall(Expr):
         return super().default_column_name()
 
     @classmethod
-    def normalize_args(cls, signature: func.Signature, bound_args: dict[str, Any]) -> None:
+    def normalize_args(cls, fn_name: str, signature: func.Signature, bound_args: dict[str, Any]) -> None:
         """Converts all args to Exprs and checks that they are compatible with signature.
 
         Updates bound_args in place, where necessary.
@@ -163,9 +163,7 @@ class FunctionCall(Expr):
 
         if isinstance(arg, list) or isinstance(arg, tuple):
             try:
-
-                is_json = is_var_param or (param.col_type is not None and param.col_type.is_json_type())
-                arg = InlineArray(arg, force_json=is_json)
+                arg = InlineList(arg)
                 bound_args[param_name] = arg
                 continue
            except excs.Error:
@@ -177,7 +175,7 @@
         try:
             _ = json.dumps(arg)
         except TypeError:
-            raise excs.Error(f'Argument for parameter {param_name!r} is not json-serializable: {arg}')
+            raise excs.Error(f'Argument for parameter {param_name!r} is not json-serializable: {arg} (of type {type(arg)})')
         if arg is not None:
             try:
                 param_type = param.col_type
@@ -215,7 +213,7 @@
             or (arg.col_type.is_json_type() and param.col_type.is_scalar_type())
         ):
             raise excs.Error(
-                f'Parameter {param_name}: argument type {arg.col_type} does not match parameter type '
+                f'Parameter {param_name} (in function {fn_name}): argument type {arg.col_type} does not match parameter type '
                 f'{param.col_type}')
 
     def _equals(self, other: FunctionCall) -> bool:
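The net effect of the `normalize_args` changes above: a plain Python list or tuple passed as a UDF argument is now always wrapped as an `InlineList` (a JSON-typed expression) rather than an `InlineArray` with a `force_json` flag, and type-mismatch errors now name the offending function. A minimal sketch of the new wrapping logic (simplified, not part of the diff; the `pixeltable.exprs` import path is an assumption based on the changed `exprs/__init__.py`):

```python
from pixeltable.exprs import InlineList  # assumed re-export for this release

def wrap_list_arg(arg):
    # Mirrors the branch above: list/tuple arguments become a JSON-typed InlineList;
    # the old is_json / force_json distinction is gone.
    if isinstance(arg, (list, tuple)):
        return InlineList(arg)
    return arg
```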
pixeltable/exprs/inline_expr.py ADDED

@@ -0,0 +1,200 @@
+from __future__ import annotations
+
+import copy
+from typing import Any, Iterable, Optional
+
+import numpy as np
+import sqlalchemy as sql
+
+import pixeltable.exceptions as excs
+import pixeltable.type_system as ts
+
+from .data_row import DataRow
+from .expr import Expr
+from .literal import Literal
+from .row_builder import RowBuilder
+from .sql_element_cache import SqlElementCache
+
+
+class InlineArray(Expr):
+    """
+    Array 'literal' which can use Exprs as values.
+    """
+
+    def __init__(self, elements: Iterable):
+        exprs = []
+        for el in elements:
+            if isinstance(el, Expr):
+                exprs.append(el)
+            elif isinstance(el, list) or isinstance(el, tuple):
+                exprs.append(InlineArray(el))
+            else:
+                exprs.append(Literal(el))
+
+        inferred_element_type: Optional[ts.ColumnType] = ts.InvalidType()
+        for i, expr in enumerate(exprs):
+            supertype = inferred_element_type.supertype(expr.col_type)
+            if supertype is None:
+                raise excs.Error(
+                    f'Could not infer element type of array: element of type `{expr.col_type}` at index {i} '
+                    f'is not compatible with type `{inferred_element_type}` of preceding elements'
+                )
+            inferred_element_type = supertype
+
+        if inferred_element_type.is_scalar_type():
+            col_type = ts.ArrayType((len(exprs),), inferred_element_type)
+        elif inferred_element_type.is_array_type():
+            assert isinstance(inferred_element_type, ts.ArrayType)
+            col_type = ts.ArrayType(
+                (len(exprs), *inferred_element_type.shape),
+                ts.ColumnType.make_type(inferred_element_type.dtype)
+            )
+        else:
+            raise excs.Error(f'Element type is not a valid dtype for an array: {inferred_element_type}')
+
+        super().__init__(col_type)
+        self.components.extend(exprs)
+        self.id = self._create_id()
+
+    def __str__(self) -> str:
+        elem_strs = [str(expr) for expr in self.components]
+        return f'[{", ".join(elem_strs)}]'
+
+    def _equals(self, _: InlineArray) -> bool:
+        return True  # Always true if components match
+
+    def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
+        return None
+
+    def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
+        data_row[self.slot_idx] = np.array([data_row[el.slot_idx] for el in self.components])
+
+    def _as_dict(self) -> dict:
+        return super()._as_dict()
+
+    @classmethod
+    def _from_dict(cls, _: dict, components: list[Expr]) -> Expr:
+        try:
+            return cls(components)
+        except excs.Error:
+            # For legacy compatibility reasons, we need to try constructing as an `InlineList`.
+            # This is because in schema versions <= 19, `InlineArray` was serialized incorrectly, and
+            # there is no way to determine the correct expression type until the subexpressions are
+            # loaded and their types are known.
+            return InlineList(components)
+
+
+class InlineList(Expr):
+    """
+    List 'literal' which can use Exprs as values.
+    """
+
+    def __init__(self, elements: Iterable):
+        exprs = []
+        for el in elements:
+            if isinstance(el, Expr):
+                exprs.append(el)
+            elif isinstance(el, list) or isinstance(el, tuple):
+                exprs.append(InlineList(el))
+            elif isinstance(el, dict):
+                exprs.append(InlineDict(el))
+            else:
+                exprs.append(Literal(el))
+
+        super().__init__(ts.JsonType())
+        self.components.extend(exprs)
+        self.id = self._create_id()
+
+    def __str__(self) -> str:
+        elem_strs = [str(expr) for expr in self.components]
+        return f'[{", ".join(elem_strs)}]'
+
+    def _equals(self, _: InlineList) -> bool:
+        return True  # Always true if components match
+
+    def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
+        return None
+
+    def eval(self, data_row: DataRow, _: RowBuilder) -> None:
+        data_row[self.slot_idx] = [data_row[el.slot_idx] for el in self.components]
+
+    def _as_dict(self) -> dict:
+        return super()._as_dict()
+
+    @classmethod
+    def _from_dict(cls, _: dict, components: list[Expr]) -> Expr:
+        return cls(components)
+
+
+class InlineDict(Expr):
+    """
+    Dictionary 'literal' which can use Exprs as values.
+    """
+
+    keys: list[str]
+
+    def __init__(self, d: dict[str, Any]):
+        self.keys = []
+        exprs: list[Expr] = []
+        for key, val in d.items():
+            if not isinstance(key, str):
+                raise excs.Error(f'Dictionary requires string keys; {key} has type {type(key)}')
+            self.keys.append(key)
+            if isinstance(val, Expr):
+                exprs.append(val)
+            elif isinstance(val, dict):
+                exprs.append(InlineDict(val))
+            elif isinstance(val, list) or isinstance(val, tuple):
+                exprs.append(InlineList(val))
+            else:
+                exprs.append(Literal(val))
+
+        super().__init__(ts.JsonType())
+        self.components.extend(exprs)
+        self.id = self._create_id()
+
+    def __str__(self) -> str:
+        item_strs = list(f"'{key}': {str(expr)}" for key, expr in zip(self.keys, self.components))
+        return '{' + ', '.join(item_strs) + '}'
+
+    def _equals(self, other: InlineDict) -> bool:
+        # The dict values are just the components, which have already been checked
+        return self.keys == other.keys
+
+    def _id_attrs(self) -> list[tuple[str, Any]]:
+        return super()._id_attrs() + [('keys', self.keys)]
+
+    def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
+        return None
+
+    def eval(self, data_row: DataRow, _: RowBuilder) -> None:
+        assert len(self.keys) == len(self.components)
+        data_row[self.slot_idx] = {
+            key: data_row[expr.slot_idx]
+            for key, expr in zip(self.keys, self.components)
+        }
+
+    def to_kwargs(self) -> dict[str, Any]:
+        """Deconstructs this expression into a dictionary by recursively unwrapping all Literals,
+        InlineDicts, and InlineLists."""
+        return InlineDict._to_kwarg_element(self)
+
+    @classmethod
+    def _to_kwarg_element(cls, expr: Expr) -> Any:
+        if isinstance(expr, Literal):
+            return expr.val
+        if isinstance(expr, InlineDict):
+            return {key: cls._to_kwarg_element(val) for key, val in zip(expr.keys, expr.components)}
+        if isinstance(expr, InlineList):
+            return [cls._to_kwarg_element(el) for el in expr.components]
+        return expr
+
+    def _as_dict(self) -> dict[str, Any]:
+        return {'keys': self.keys, **super()._as_dict()}
+
+    @classmethod
+    def _from_dict(cls, d: dict, components: list[Expr]) -> Expr:
+        assert 'keys' in d
+        assert len(d['keys']) == len(components)
+        arg = dict(zip(d['keys'], components))
+        return InlineDict(arg)
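A rough sketch of how the three new classes behave, for orientation. This assumes the module path shown in the diff is importable and that `Literal` infers column types from plain Python values; it is not taken from the package itself:

```python
from pixeltable.exprs.inline_expr import InlineArray, InlineDict, InlineList  # path per this diff

arr = InlineArray([1, 2, 3])                    # common element supertype -> ArrayType with shape (3,)
lst = InlineList([1, 'a', [2, 3]])              # heterogeneous elements are fine -> JsonType
dct = InlineDict({'k': 1, 'nested': {'v': 2}})  # nested dicts/lists are wrapped recursively -> JsonType

print(arr.col_type, lst.col_type, dct.col_type)
# InlineArray raises excs.Error when no common element supertype exists or the inferred
# element type is not a valid array dtype (see the constructor above); InlineList/InlineDict
# are always JSON-typed.
```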
pixeltable/exprs/json_path.py CHANGED

@@ -105,12 +105,9 @@
         return JsonPath(self._anchor, self.path_elements + [name])
 
     def __getitem__(self, index: object) -> 'JsonPath':
-        if isinstance(index, str):
-
-
-        elif not isinstance(index, (int, slice)):
-            raise excs.Error(f'Invalid json list index: {index}')
-        return JsonPath(self._anchor, self.path_elements + [index])
+        if isinstance(index, (int, slice, str)):
+            return JsonPath(self._anchor, self.path_elements + [index])
+        raise excs.Error(f'Invalid json list index: {index}')
 
     def __rshift__(self, other: object) -> 'JsonMapper':
         rhs_expr = Expr.from_object(other)
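The consolidated `__getitem__` above accepts string keys, integer indices, and slices as JSON path elements and rejects everything else. A hypothetical usage sketch against a JSON column (the column and key names are illustrative, not from the diff):

```python
first_tag = tbl.meta['tags'][0]    # string key, then integer index
first_two = tbl.meta['tags'][0:2]  # slice
# tbl.meta[1.5]                    # raises excs.Error: Invalid json list index: 1.5
```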
pixeltable/ext/functions/whisperx.py CHANGED

@@ -19,6 +19,8 @@ def transcribe(
     equivalent to the WhisperX `transcribe` function, as described in the
     [WhisperX library documentation](https://github.com/m-bain/whisperX).
 
+    WhisperX is part of the `pixeltable.ext` package: long-term support in Pixeltable is not guaranteed.
+
     __Requirements:__
 
     - `pip install whisperx`
pixeltable/ext/functions/yolox.py CHANGED

@@ -26,8 +26,7 @@ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0
     Computes YOLOX object detections for the specified image. `model_id` should reference one of the models
     defined in the [YOLOX documentation](https://github.com/Megvii-BaseDetection/YOLOX).
 
-    YOLOX
-    intended for use in production applications.
+    YOLOX is part of the `pixeltable.ext` package: long-term support in Pixeltable is not guaranteed.
 
     __Requirements__:
 
@@ -79,6 +78,8 @@ def yolo_to_coco(detections: dict) -> list:
     """
     Converts the output of a YOLOX object detection model to COCO format.
 
+    YOLOX is part of the `pixeltable.ext` package: long-term support in Pixeltable is not guaranteed.
+
     Args:
         detections: The output of a YOLOX object detection model, as returned by `yolox`.
 
@@ -89,7 +90,8 @@ def yolo_to_coco(detections: dict) -> list:
         Add a computed column that converts the output `tbl.detections` to COCO format, where `tbl.image`
         is the image for which detections were computed:
 
-        >>> tbl['
+        >>> tbl['detections'] = yolox(tbl.image, model_id='yolox_m', threshold=0.8)
+        ... tbl['detections_coco'] = yolo_to_coco(tbl.detections)
     """
     bboxes, labels = detections['bboxes'], detections['labels']
     num_annotations = len(detections['bboxes'])
pixeltable/functions/huggingface.py CHANGED

@@ -185,7 +185,7 @@ def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[np.ndar
 
     Examples:
         Add a computed column that applies the model `openai/clip-vit-base-patch32` to an existing
-        Pixeltable column `
+        Pixeltable column `image` of the table `tbl`:
 
         >>> tbl['result'] = clip_image(tbl.image, model_id='openai/clip-vit-base-patch32')
     """
@@ -228,24 +228,24 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
 
     Args:
         image: The image to embed.
-        model_id: The pretrained model to use for
+        model_id: The pretrained model to use for object detection.
 
     Returns:
         A dictionary containing the output of the object detection model, in the following format:
 
-
-
-
-
-
-
-
-
-
+        ```python
+        {
+            'scores': [0.99, 0.999],  # list of confidence scores for each detected object
+            'labels': [25, 25],  # list of COCO class labels for each detected object
+            'label_text': ['giraffe', 'giraffe'],  # corresponding text names of class labels
+            'boxes': [[51.942, 356.174, 181.481, 413.975], [383.225, 58.66, 605.64, 361.346]]
+                # list of bounding boxes for each detected object, as [x1, y1, x2, y2]
+        }
+        ```
 
     Examples:
         Add a computed column that applies the model `facebook/detr-resnet-50` to an existing
-        Pixeltable column `
+        Pixeltable column `image` of the table `tbl`:
 
         >>> tbl['detections'] = detr_for_object_detection(
         ...     tbl.image,
@@ -282,6 +282,83 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
     ]
 
 
+@pxt.udf(batch_size=4)
+def vit_for_image_classification(
+    image: Batch[PIL.Image.Image],
+    *,
+    model_id: str,
+    top_k: int = 5
+) -> Batch[list[dict[str, Any]]]:
+    """
+    Computes image classifications for the specified image using a Vision Transformer (ViT) model.
+    `model_id` should be a reference to a pretrained [ViT Model](https://huggingface.co/docs/transformers/en/model_doc/vit).
+
+    __Note:__ Be sure the model is a ViT model that is trained for image classification (that is, a model designed for
+    use with the
+    [ViTForImageClassification](https://huggingface.co/docs/transformers/en/model_doc/vit#transformers.ViTForImageClassification)
+    class), such as `google/vit-base-patch16-224`. General feature-extraction models such as
+    `google/vit-base-patch16-224-in21k` will not produce the desired results.
+
+    __Requirements:__
+
+    - `pip install transformers`
+
+    Args:
+        image: The image to classify.
+        model_id: The pretrained model to use for the classification.
+        top_k: The number of classes to return.
+
+    Returns:
+        A list of the `top_k` highest-scoring classes for each image. Each element in the list is a dictionary
+        in the following format:
+
+        ```python
+        {
+            'p': 0.230,  # class probability
+            'class': 935,  # class ID
+            'label': 'mashed potato',  # class label
+        }
+        ```
+
+    Examples:
+        Add a computed column that applies the model `google/vit-base-patch16-224` to an existing
+        Pixeltable column `image` of the table `tbl`:
+
+        >>> tbl['image_class'] = vit_for_image_classification(
+        ...     tbl.image,
+        ...     model_id='google/vit-base-patch16-224'
+        ... )
+    """
+    env.Env.get().require_package('transformers')
+    device = resolve_torch_device('auto')
+    import torch
+    from transformers import ViTImageProcessor, ViTForImageClassification
+
+    model: ViTForImageClassification = _lookup_model(model_id, ViTForImageClassification.from_pretrained, device=device)
+    processor = _lookup_processor(model_id, ViTImageProcessor.from_pretrained)
+    normalized_images = [normalize_image_mode(img) for img in image]
+
+    with torch.no_grad():
+        inputs = processor(images=normalized_images, return_tensors='pt')
+        outputs = model(**inputs.to(device))
+        logits = outputs.logits
+
+    probs = torch.softmax(logits, dim=-1)
+    top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
+
+    return [
+        [
+            {
+                'p': top_k_probs[n, k].item(),
+                'class': top_k_indices[n, k].item(),
+                'label': model.config.id2label[top_k_indices[n, k].item()],
+            }
+            for k in range(top_k_probs.shape[1])
+        ]
+        for n in range(top_k_probs.shape[0])
+    ]
+
+
 @pxt.udf
 def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str, Any]:
     """
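Since `vit_for_image_classification` returns a JSON list of `{'p', 'class', 'label'}` entries per image, the top-ranked prediction can be pulled out with JSON path indexing. A hedged follow-on to the docstring example above (the derived column names are illustrative):

```python
tbl['image_class'] = vit_for_image_classification(tbl.image, model_id='google/vit-base-patch16-224')
tbl['top1_label'] = tbl.image_class[0]['label']  # highest-probability class label
tbl['top1_p'] = tbl.image_class[0]['p']          # its probability
```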
pixeltable/functions/image.py CHANGED

@@ -92,7 +92,7 @@ def _(self: Expr, mode: str) -> ts.ColumnType:
 
 
 # Image.crop()
-@func.udf(substitute_fn=PIL.Image.Image.crop,
+@func.udf(substitute_fn=PIL.Image.Image.crop, is_method=True)
 def crop(self: PIL.Image.Image, box: tuple[int, int, int, int]) -> PIL.Image.Image:
     """
     Return a rectangular region from the image. The box is a 4-tuple defining the left, upper, right, and lower pixel
@@ -151,7 +151,7 @@ def _(self: Expr) -> ts.ColumnType:
 
 
 # Image.resize()
-@func.udf(
+@func.udf(is_method=True)
 def resize(self: PIL.Image.Image, size: tuple[int, int]) -> PIL.Image.Image:
     """
     Return a resized copy of the image. The size parameter is a tuple containing the width and height of the new image.
@@ -366,7 +366,7 @@ def quantize(
 
 
 @func.udf(substitute_fn=PIL.Image.Image.reduce, is_method=True)
-def reduce(self: PIL.Image.Image, factor: int, box: Optional[tuple[int]] = None) -> PIL.Image.Image:
+def reduce(self: PIL.Image.Image, factor: int, box: Optional[tuple[int, int, int, int]] = None) -> PIL.Image.Image:
     """
     Reduce the image by the given factor.
 
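The `is_method=True` flag added to `crop()` and `resize()` above should make them callable in method style on image expressions, matching `reduce()` and the other image UDFs. A sketch (assumes `tbl.image` is an image column; not taken from the diff):

```python
tbl['thumb'] = tbl.image.resize((224, 224))       # method-style call enabled by is_method=True
tbl['region'] = tbl.image.crop((0, 0, 128, 128))  # box = (left, upper, right, lower)
```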
pixeltable/functions/together.py CHANGED

@@ -7,12 +7,15 @@ the [Working with Together AI](https://pixeltable.readme.io/docs/together-ai) tu
 
 import base64
 import io
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Callable, Optional, TypeVar
 
 import numpy as np
 import PIL.Image
+import requests
+import tenacity
 
 import pixeltable as pxt
+import pixeltable.exceptions as excs
 from pixeltable import env
 from pixeltable.func import Batch
 from pixeltable.utils.code import local_public_names
@@ -24,7 +27,6 @@ if TYPE_CHECKING:
 @env.register_client('together')
 def _(api_key: str) -> 'together.Together':
     import together
-
     return together.Together(api_key=api_key)
 
 
@@ -32,6 +34,18 @@ def _together_client() -> 'together.Together':
     return env.Env.get().get_client('together')
 
 
+T = TypeVar('T')
+
+
+def _retry(fn: Callable[..., T]) -> Callable[..., T]:
+    import together
+    return tenacity.retry(
+        retry=tenacity.retry_if_exception_type(together.error.RateLimitError),
+        wait=tenacity.wait_random_exponential(multiplier=1, max=60),
+        stop=tenacity.stop_after_attempt(20),
+    )(fn)
+
+
 @pxt.udf
 def completions(
     prompt: str,
@@ -74,8 +88,7 @@ def completions(
     >>> tbl['response'] = completions(tbl.prompt, model='mistralai/Mixtral-8x7B-v0.1')
     """
     return (
-        _together_client()
-        .completions.create(
+        _retry(_together_client().completions.create)(
            prompt=prompt,
            model=model,
            max_tokens=max_tokens,
@@ -139,8 +152,7 @@ def chat_completions(
     ...     tbl['response'] = chat_completions(messages, model='mistralai/Mixtral-8x7B-v0.1')
     """
     return (
-        _together_client()
-        .chat.completions.create(
+        _retry(_together_client().chat.completions.create)(
            messages=messages,
            model=model,
            max_tokens=max_tokens,
@@ -198,7 +210,7 @@ def embeddings(input: Batch[str], *, model: str) -> Batch[np.ndarray]:
 
     >>> tbl['response'] = embeddings(tbl.text, model='togethercomputer/m2-bert-80M-8k-retrieval')
     """
-    result = _together_client().embeddings.create(input=input, model=model)
+    result = _retry(_together_client().embeddings.create)(input=input, model=model)
     return [np.array(data.embedding, dtype=np.float64) for data in result.data]
 
 
@@ -242,20 +254,29 @@ def image_generations(
         The generated image.
 
     Examples:
-        Add a computed column that applies the model `
+        Add a computed column that applies the model `stabilityai/stable-diffusion-xl-base-1.0`
         to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
 
-        >>> tbl['response'] = image_generations(tbl.prompt, model='
+        >>> tbl['response'] = image_generations(tbl.prompt, model='stabilityai/stable-diffusion-xl-base-1.0')
     """
-
-    result = _together_client().images.generate(
+    result = _retry(_together_client().images.generate)(
        prompt=prompt, model=model, steps=steps, seed=seed, height=height, width=width, negative_prompt=negative_prompt
     )
-
-
-
-
-
+    if result.data[0].b64_json is not None:
+        b64_bytes = base64.b64decode(result.data[0].b64_json)
+        img = PIL.Image.open(io.BytesIO(b64_bytes))
+        img.load()
+        return img
+    if result.data[0].url is not None:
+        try:
+            resp = requests.get(result.data[0].url)
+            with io.BytesIO(resp.content) as fp:
+                image = PIL.Image.open(fp)
+                image.load()
+                return image
+        except Exception as exc:
+            raise excs.Error('Failed to download generated image from together.ai.') from exc
+    raise excs.Error('Response does not contain a generated image.')
 
 
 __all__ = local_public_names(__name__)
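The `_retry` helper above is applied per call site, wrapping the bound client method just before it is invoked. A standalone sketch of the same tenacity pattern, with the backoff parameters spelled out (the client construction line is illustrative):

```python
import tenacity
import together

def with_rate_limit_retry(fn):
    # Retry only on Together's RateLimitError, with jittered exponential backoff
    # capped at 60s between attempts and at most 20 attempts overall.
    return tenacity.retry(
        retry=tenacity.retry_if_exception_type(together.error.RateLimitError),
        wait=tenacity.wait_random_exponential(multiplier=1, max=60),
        stop=tenacity.stop_after_attempt(20),
    )(fn)

client = together.Together(api_key='...')
create = with_rate_limit_retry(client.chat.completions.create)
# create(messages=[...], model='mistralai/Mixtral-8x7B-v0.1', max_tokens=300)
```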