pixeltable 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/catalog/column.py +26 -49
- pixeltable/catalog/insertable_table.py +7 -4
- pixeltable/catalog/table.py +163 -57
- pixeltable/catalog/table_version.py +416 -140
- pixeltable/catalog/table_version_path.py +2 -2
- pixeltable/client.py +72 -6
- pixeltable/dataframe.py +65 -21
- pixeltable/env.py +52 -53
- pixeltable/exec/cache_prefetch_node.py +1 -1
- pixeltable/exec/in_memory_data_node.py +11 -7
- pixeltable/exprs/comparison.py +3 -3
- pixeltable/exprs/data_row.py +5 -1
- pixeltable/exprs/literal.py +16 -4
- pixeltable/exprs/row_builder.py +8 -40
- pixeltable/ext/__init__.py +5 -0
- pixeltable/ext/functions/yolox.py +92 -0
- pixeltable/func/aggregate_function.py +15 -15
- pixeltable/func/expr_template_function.py +9 -1
- pixeltable/func/globals.py +24 -14
- pixeltable/func/signature.py +18 -12
- pixeltable/func/udf.py +7 -2
- pixeltable/functions/__init__.py +9 -9
- pixeltable/functions/eval.py +7 -8
- pixeltable/functions/fireworks.py +10 -37
- pixeltable/functions/huggingface.py +47 -19
- pixeltable/functions/openai.py +192 -24
- pixeltable/functions/together.py +104 -9
- pixeltable/functions/util.py +11 -0
- pixeltable/index/__init__.py +2 -0
- pixeltable/index/base.py +49 -0
- pixeltable/index/embedding_index.py +95 -0
- pixeltable/metadata/schema.py +45 -22
- pixeltable/plan.py +15 -34
- pixeltable/store.py +38 -41
- pixeltable/tests/conftest.py +8 -14
- pixeltable/tests/ext/test_yolox.py +21 -0
- pixeltable/tests/functions/test_fireworks.py +43 -0
- pixeltable/tests/functions/test_functions.py +60 -0
- pixeltable/tests/{test_functions.py → functions/test_huggingface.py} +7 -143
- pixeltable/tests/functions/test_openai.py +162 -0
- pixeltable/tests/functions/test_together.py +112 -0
- pixeltable/tests/test_component_view.py +14 -5
- pixeltable/tests/test_dataframe.py +23 -22
- pixeltable/tests/test_exprs.py +99 -102
- pixeltable/tests/test_function.py +51 -43
- pixeltable/tests/test_index.py +138 -0
- pixeltable/tests/test_migration.py +2 -1
- pixeltable/tests/test_snapshot.py +24 -1
- pixeltable/tests/test_table.py +205 -26
- pixeltable/tests/test_types.py +30 -0
- pixeltable/tests/test_video.py +16 -16
- pixeltable/tests/test_view.py +5 -0
- pixeltable/tests/utils.py +171 -14
- pixeltable/tool/create_test_db_dump.py +16 -0
- pixeltable/type_system.py +77 -128
- pixeltable/utils/arrow.py +98 -0
- pixeltable/utils/hf_datasets.py +157 -0
- pixeltable/utils/parquet.py +68 -27
- pixeltable/utils/pytorch.py +16 -97
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/METADATA +35 -28
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/RECORD +63 -50
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Iterable, Iterator
|
|
4
|
+
from urllib.request import urlretrieve
|
|
5
|
+
|
|
6
|
+
import PIL.Image
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
from yolox.data import ValTransform
|
|
10
|
+
from yolox.exp import get_exp, Exp
|
|
11
|
+
from yolox.models import YOLOX
|
|
12
|
+
from yolox.utils import postprocess
|
|
13
|
+
|
|
14
|
+
import pixeltable as pxt
|
|
15
|
+
from pixeltable import env
|
|
16
|
+
from pixeltable.func import Batch
|
|
17
|
+
from pixeltable.functions.util import resolve_torch_device
|
|
18
|
+
|
|
19
|
+
_logger = logging.getLogger('pixeltable')
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@pxt.udf(batch_size=4)
def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5) -> Batch[dict]:
    """
    Runs the specified YOLOX object detection model on a batch of images.

    YOLOX support is part of the `pixeltable.ext` package: long-term support is not guaranteed, and it is not
    intended for use in production applications.

    Parameters:
    - `images` - the batch of images to run detection on
    - `model_id` - one of: `yolox_nano`, `yolox_tiny`, `yolox_s`, `yolox_m`, `yolox_l`, `yolox_x`
    - `threshold` - the threshold for object detection

    Returns:
    One dict per input image, in input order, with the keys `bboxes`, `scores` and `labels`.
    All three lists are empty for an image without detections.
    """
    # NOTE(review): inference is pinned to 'cpu' here; confirm whether a device should be resolved instead.
    model, exp = _lookup_model(model_id, 'cpu')
    image_tensors = list(_images_to_tensors(images, exp))
    batch_tensor = torch.stack(image_tensors)

    with torch.no_grad():
        output_tensor = model(batch_tensor)

    # 80 is the class count passed to YOLOX's postprocess -- presumably the COCO class count; verify
    # against the weights in use.
    outputs = postprocess(
        output_tensor, 80, threshold, exp.nmsthre, class_agnostic=False
    )

    results: list[dict] = []
    # postprocess yields one entry per image of the batch (None when nothing was detected), so each
    # image must be paired with its own entry rather than with the first one.
    for image, detections in zip(images, outputs):
        if detections is None:
            results.append({'bboxes': [], 'scores': [], 'labels': []})
            continue
        # Scale factor between this image and the model's test size; dividing by it maps the
        # boxes back onto the original image.
        ratio = min(exp.test_size[0] / image.height, exp.test_size[1] / image.width)
        results.append({
            'bboxes': [(detection[:4] / ratio).tolist() for detection in detections],
            'scores': [detection[4].item() * detection[5].item() for detection in detections],
            'labels': [int(detection[6]) for detection in detections]
        })
    return results
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _images_to_tensors(images: Iterable[PIL.Image.Image], exp: Exp) -> Iterator[torch.Tensor]:
    """Lazily turn each image into a tensor for the model.

    Every image is converted to a numpy array, passed through the module-level `_val_transform`
    together with `exp.test_size`, and the transformed array is wrapped in a torch tensor.
    """
    for pil_image in images:
        pixels = np.array(pil_image)
        # The transform returns a pair; only the transformed image is needed (no targets are passed in).
        transformed, _ = _val_transform(pixels, None, exp.test_size)
        yield torch.from_numpy(transformed)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _lookup_model(model_id: str, device: str) -> tuple[YOLOX, Exp]:
    """Return the `(model, exp)` pair for `model_id` on `device`, loading it on first use.

    Results are memoized in the module-level `_model_cache`, keyed by `(model_id, device)`.
    On a cache miss the weights file is downloaded into the environment's tmp dir (unless it is
    already there), the model is built from the YOLOX experiment of the same name, switched to
    eval mode, and the downloaded weights are loaded into it.

    Parameters:
    - `model_id` - name of the YOLOX experiment; also used to build the weights URL and file name
    - `device` - torch device string the model is moved to and the weights are mapped to
    """
    key = (model_id, device)
    if key in _model_cache:
        return _model_cache[key]

    weights_url = f'https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/{model_id}.pth'
    weights_file = Path(f'{env.Env.get().tmp_dir}/{model_id}.pth')
    if not weights_file.exists():
        _logger.info(f'Downloading weights for YOLOX model {model_id}: from {weights_url} -> {weights_file}')
        # Download under a temporary name and rename once complete, so an interrupted download never
        # leaves a truncated file at `weights_file` that later calls would mistake for a full one.
        partial_file = weights_file.with_name(weights_file.name + '.part')
        urlretrieve(weights_url, partial_file)
        partial_file.replace(weights_file)

    exp = get_exp(exp_name=model_id)
    model = exp.get_model().to(device)

    model.eval()
    model.head.training = False
    model.training = False

    # Load in the weights from training
    # NOTE(security): torch.load unpickles the downloaded checkpoint; it is only as trustworthy as
    # the release URL above.
    weights = torch.load(weights_file, map_location=torch.device(device))
    model.load_state_dict(weights['model'])

    _model_cache[key] = (model, exp)
    return model, exp
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
_model_cache = {}
|
|
92
|
+
_val_transform = ValTransform(legacy=False)
|
|
@@ -3,13 +3,14 @@ from __future__ import annotations
|
|
|
3
3
|
import abc
|
|
4
4
|
import importlib
|
|
5
5
|
import inspect
|
|
6
|
-
from typing import Optional, Any, Type, List, Dict
|
|
6
|
+
from typing import Optional, Any, Type, List, Dict, Callable
|
|
7
7
|
import itertools
|
|
8
8
|
|
|
9
9
|
import pixeltable.exceptions as excs
|
|
10
10
|
import pixeltable.type_system as ts
|
|
11
11
|
from .function import Function
|
|
12
12
|
from .signature import Signature, Parameter
|
|
13
|
+
from .globals import validate_symbol_path
|
|
13
14
|
|
|
14
15
|
|
|
15
16
|
class Aggregator(abc.ABC):
|
|
@@ -136,8 +137,7 @@ def uda(
|
|
|
136
137
|
update_types: List[ts.ColumnType],
|
|
137
138
|
init_types: Optional[List[ts.ColumnType]] = None,
|
|
138
139
|
requires_order_by: bool = False, allows_std_agg: bool = True, allows_window: bool = False,
|
|
139
|
-
|
|
140
|
-
) -> Type[Aggregator]:
|
|
140
|
+
) -> Callable:
|
|
141
141
|
"""Decorator for user-defined aggregate functions.
|
|
142
142
|
|
|
143
143
|
The decorated class must inherit from Aggregator and implement the following methods:
|
|
@@ -155,14 +155,11 @@ def uda(
|
|
|
155
155
|
- requires_order_by: if True, the first parameter to the function is the order-by expression
|
|
156
156
|
- allows_std_agg: if True, the function can be used as a standard aggregate function w/o a window
|
|
157
157
|
- allows_window: if True, the function can be used with a window
|
|
158
|
-
- name: name of the AggregateFunction instance; if None, the class name is used
|
|
159
158
|
"""
|
|
160
|
-
if name is not None and not name.isidentifier():
|
|
161
|
-
raise excs.Error(f'Invalid name: {name}')
|
|
162
159
|
if init_types is None:
|
|
163
160
|
init_types = []
|
|
164
161
|
|
|
165
|
-
def decorator(cls: Type[Aggregator]) -> Type[
|
|
162
|
+
def decorator(cls: Type[Aggregator]) -> Type[Function]:
|
|
166
163
|
# validate type parameters
|
|
167
164
|
num_init_params = len(inspect.signature(cls.__init__).parameters) - 1
|
|
168
165
|
if num_init_params > 0:
|
|
@@ -178,17 +175,20 @@ def uda(
|
|
|
178
175
|
assert value_type is not None
|
|
179
176
|
|
|
180
177
|
# the AggregateFunction instance resides in the same module as cls
|
|
181
|
-
|
|
182
|
-
nonlocal name
|
|
183
|
-
name = name or cls.__name__
|
|
184
|
-
|
|
178
|
+
class_path = f'{cls.__module__}.{cls.__qualname__}'
|
|
179
|
+
# nonlocal name
|
|
180
|
+
# name = name or cls.__name__
|
|
181
|
+
# instance_path_elements = class_path.split('.')[:-1] + [name]
|
|
182
|
+
# instance_path = '.'.join(instance_path_elements)
|
|
185
183
|
|
|
186
184
|
# create the corresponding AggregateFunction instance
|
|
187
185
|
instance = AggregateFunction(
|
|
188
|
-
cls,
|
|
189
|
-
|
|
190
|
-
|
|
186
|
+
cls, class_path, init_types, update_types, value_type, requires_order_by, allows_std_agg, allows_window)
|
|
187
|
+
# do the path validation at the very end, in order to be able to write tests for the other failure cases
|
|
188
|
+
validate_symbol_path(class_path)
|
|
189
|
+
#module = importlib.import_module(cls.__module__)
|
|
190
|
+
#setattr(module, name, instance)
|
|
191
191
|
|
|
192
|
-
return
|
|
192
|
+
return instance
|
|
193
193
|
|
|
194
194
|
return decorator
|
|
@@ -50,9 +50,17 @@ class ExprTemplateFunction(Function):
|
|
|
50
50
|
bound_args.update(
|
|
51
51
|
{param_name: default for param_name, default in self.defaults.items() if param_name not in bound_args})
|
|
52
52
|
result = self.expr.copy()
|
|
53
|
+
import pixeltable.exprs as exprs
|
|
53
54
|
for param_name, arg in bound_args.items():
|
|
54
55
|
param_expr = self.param_exprs_by_name[param_name]
|
|
55
|
-
|
|
56
|
+
if not isinstance(arg, exprs.Expr):
|
|
57
|
+
# TODO: use the available param_expr.col_type
|
|
58
|
+
arg_expr = exprs.Expr.from_object(arg)
|
|
59
|
+
if arg_expr is None:
|
|
60
|
+
raise excs.Error(f'{self.self_name}(): cannot convert argument {arg} to a Pixeltable expression')
|
|
61
|
+
else:
|
|
62
|
+
arg_expr = arg
|
|
63
|
+
result = result.substitute(param_expr, arg_expr)
|
|
56
64
|
import pixeltable.exprs as exprs
|
|
57
65
|
assert not result.contains(exprs.Variable)
|
|
58
66
|
return result
|
pixeltable/func/globals.py
CHANGED
|
@@ -1,29 +1,39 @@
|
|
|
1
|
-
from typing import Optional
|
|
2
|
-
from types import ModuleType
|
|
3
1
|
import importlib
|
|
4
2
|
import inspect
|
|
3
|
+
from types import ModuleType
|
|
4
|
+
from typing import Optional
|
|
5
5
|
|
|
6
|
+
import pixeltable.exceptions as excs
|
|
6
7
|
|
|
7
|
-
|
|
8
|
+
|
|
9
|
+
def resolve_symbol(symbol_path: str) -> Optional[object]:
    """Resolve a dotted symbol path to the object it names.

    The longest proper prefix of the path that can be imported as a module is imported; the
    remaining path elements are then looked up as nested attributes of that module.

    Returns None when no proper prefix of the path is importable (this includes a path consisting
    of a single element). A missing attribute is not caught and surfaces as AttributeError.
    """
    parts = symbol_path.split('.')
    # Try prefixes from longest to shortest; the last element is never part of the module path.
    for split_at in range(len(parts) - 1, 0, -1):
        try:
            target = importlib.import_module('.'.join(parts[:split_at]))
        except ModuleNotFoundError:
            continue
        break
    else:
        return None  # Not resolvable
    for attr_name in parts[split_at:]:
        target = getattr(target, attr_name)
    return target
|
|
26
24
|
|
|
25
|
+
|
|
26
|
+
def validate_symbol_path(fn_path: str) -> None:
    """Check that a dotted symbol path consists only of plain identifiers.

    Raises excs.Error when the path contains a '<locals>' element (the symbol is nested inside a
    function) or any other element that is not a valid identifier. Returns None otherwise.
    """
    components = fn_path.split('.')
    leaf = components[-1]
    # '<locals>' is checked first so that nested functions get the more specific message.
    if '<locals>' in components:
        raise excs.Error(
            f'{leaf}(): nested functions are not supported. Move the function to the module level or into a class.')
    for component in components:
        if not component.isidentifier():
            raise excs.Error(
                f'{leaf}(): cannot resolve symbol path {fn_path}. Move the function to the module level or into a class.')
|
|
35
|
+
|
|
36
|
+
|
|
27
37
|
def get_caller_module_path() -> str:
|
|
28
38
|
"""Return the module path of our caller's caller"""
|
|
29
39
|
stack = inspect.stack()
|
pixeltable/func/signature.py
CHANGED
|
@@ -114,20 +114,12 @@ class Signature:
|
|
|
114
114
|
return (col_type, is_batched)
|
|
115
115
|
|
|
116
116
|
@classmethod
|
|
117
|
-
def
|
|
118
|
-
cls, c: Callable,
|
|
119
|
-
param_types: Optional[List[ts.ColumnType]] = None,
|
|
120
|
-
return_type: Optional[Union[ts.ColumnType, Callable]] = None
|
|
121
|
-
) -> Signature:
|
|
122
|
-
"""Create a signature for the given Callable.
|
|
123
|
-
Infer the parameter and return types, if none are specified.
|
|
124
|
-
Raises an exception if the types cannot be inferred.
|
|
125
|
-
"""
|
|
117
|
+
def create_parameters(
|
|
118
|
+
cls, c: Callable, param_types: Optional[List[ts.ColumnType]] = None) -> List[Parameter]:
|
|
126
119
|
sig = inspect.signature(c)
|
|
127
120
|
py_parameters = list(sig.parameters.values())
|
|
128
|
-
|
|
129
|
-
# check non-var parameters for name collisions and default value compatibility
|
|
130
121
|
parameters: List[Parameter] = []
|
|
122
|
+
|
|
131
123
|
for idx, param in enumerate(py_parameters):
|
|
132
124
|
if param.name in cls.SPECIAL_PARAM_NAMES:
|
|
133
125
|
raise excs.Error(f"'{param.name}' is a reserved parameter name")
|
|
@@ -135,6 +127,7 @@ class Signature:
|
|
|
135
127
|
parameters.append(Parameter(param.name, None, param.kind, False))
|
|
136
128
|
continue
|
|
137
129
|
|
|
130
|
+
# check non-var parameters for name collisions and default value compatibility
|
|
138
131
|
if param_types is not None:
|
|
139
132
|
if idx >= len(param_types):
|
|
140
133
|
raise excs.Error(f'Missing type for parameter {param.name}')
|
|
@@ -155,7 +148,20 @@ class Signature:
|
|
|
155
148
|
|
|
156
149
|
parameters.append(Parameter(param.name, param_type, param.kind, is_batched))
|
|
157
150
|
|
|
158
|
-
|
|
151
|
+
return parameters
|
|
152
|
+
|
|
153
|
+
@classmethod
|
|
154
|
+
def create(
|
|
155
|
+
cls, c: Callable,
|
|
156
|
+
param_types: Optional[List[ts.ColumnType]] = None,
|
|
157
|
+
return_type: Optional[Union[ts.ColumnType, Callable]] = None
|
|
158
|
+
) -> Signature:
|
|
159
|
+
"""Create a signature for the given Callable.
|
|
160
|
+
Infer the parameter and return types, if none are specified.
|
|
161
|
+
Raises an exception if the types cannot be inferred.
|
|
162
|
+
"""
|
|
163
|
+
parameters = cls.create_parameters(c, param_types)
|
|
164
|
+
sig = inspect.signature(c)
|
|
159
165
|
if return_type is None:
|
|
160
166
|
return_type, return_is_batched = cls._infer_type(sig.return_annotation)
|
|
161
167
|
if return_type is None:
|
pixeltable/func/udf.py
CHANGED
|
@@ -11,6 +11,7 @@ from .callable_function import CallableFunction
|
|
|
11
11
|
from .expr_template_function import ExprTemplateFunction
|
|
12
12
|
from .function import Function
|
|
13
13
|
from .function_registry import FunctionRegistry
|
|
14
|
+
from .globals import validate_symbol_path
|
|
14
15
|
from .signature import Signature
|
|
15
16
|
|
|
16
17
|
|
|
@@ -124,6 +125,8 @@ def make_function(
|
|
|
124
125
|
|
|
125
126
|
# If this function is part of a module, register it
|
|
126
127
|
if function_path is not None:
|
|
128
|
+
# do the validation at the very end, so it's easier to write tests for other failure scenarios
|
|
129
|
+
validate_symbol_path(function_path)
|
|
127
130
|
FunctionRegistry.get().register_function(function_path, result)
|
|
128
131
|
|
|
129
132
|
return result
|
|
@@ -142,17 +145,19 @@ def expr_udf(*args: Any, **kwargs: Any) -> Any:
|
|
|
142
145
|
else:
|
|
143
146
|
function_path = None
|
|
144
147
|
|
|
145
|
-
sig = Signature.create(py_fn, param_types=param_types, return_type=None)
|
|
146
148
|
# TODO: verify that the inferred return type matches that of the template
|
|
147
149
|
# TODO: verify that the signature doesn't contain batched parameters
|
|
148
150
|
|
|
149
151
|
# construct Parameters from the function signature
|
|
152
|
+
params = Signature.create_parameters(py_fn, param_types=param_types)
|
|
150
153
|
import pixeltable.exprs as exprs
|
|
151
|
-
var_exprs = [exprs.Variable(param.name, param.col_type) for param in
|
|
154
|
+
var_exprs = [exprs.Variable(param.name, param.col_type) for param in params]
|
|
152
155
|
# call the function with the parameter expressions to construct an Expr with parameters
|
|
153
156
|
template = py_fn(*var_exprs)
|
|
154
157
|
assert isinstance(template, exprs.Expr)
|
|
155
158
|
py_sig = inspect.signature(py_fn)
|
|
159
|
+
if function_path is not None:
|
|
160
|
+
validate_symbol_path(function_path)
|
|
156
161
|
return ExprTemplateFunction(template, py_signature=py_sig, self_path=function_path, name=py_fn.__name__)
|
|
157
162
|
|
|
158
163
|
if len(args) == 1:
|
pixeltable/functions/__init__.py
CHANGED
|
@@ -15,7 +15,7 @@ import pixeltable.functions.pil.image
|
|
|
15
15
|
from pixeltable import exprs
|
|
16
16
|
from pixeltable.type_system import IntType, ColumnType, FloatType, ImageType, VideoType
|
|
17
17
|
# automatically import all submodules so that the udfs get registered
|
|
18
|
-
from . import image, string, video,
|
|
18
|
+
from . import image, string, video, huggingface
|
|
19
19
|
|
|
20
20
|
# TODO: remove and replace calls with astype()
|
|
21
21
|
def cast(expr: exprs.Expr, target_type: ColumnType) -> exprs.Expr:
|
|
@@ -23,8 +23,8 @@ def cast(expr: exprs.Expr, target_type: ColumnType) -> exprs.Expr:
|
|
|
23
23
|
return expr
|
|
24
24
|
|
|
25
25
|
@func.uda(
|
|
26
|
-
update_types=[IntType()], value_type=IntType(),
|
|
27
|
-
class
|
|
26
|
+
update_types=[IntType()], value_type=IntType(), allows_window=True, requires_order_by=False)
|
|
27
|
+
class sum(func.Aggregator):
|
|
28
28
|
def __init__(self):
|
|
29
29
|
self.sum: Union[int, float] = 0
|
|
30
30
|
def update(self, val: Union[int, float]) -> None:
|
|
@@ -35,8 +35,8 @@ class SumAggregator(func.Aggregator):
|
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
@func.uda(
|
|
38
|
-
update_types=[IntType()], value_type=IntType(),
|
|
39
|
-
class
|
|
38
|
+
update_types=[IntType()], value_type=IntType(), allows_window = True, requires_order_by = False)
|
|
39
|
+
class count(func.Aggregator):
|
|
40
40
|
def __init__(self):
|
|
41
41
|
self.count = 0
|
|
42
42
|
def update(self, val: int) -> None:
|
|
@@ -47,8 +47,8 @@ class CountAggregator(func.Aggregator):
|
|
|
47
47
|
|
|
48
48
|
|
|
49
49
|
@func.uda(
|
|
50
|
-
update_types=[IntType()], value_type=FloatType(),
|
|
51
|
-
class
|
|
50
|
+
update_types=[IntType()], value_type=FloatType(), allows_window=False, requires_order_by=False)
|
|
51
|
+
class mean(func.Aggregator):
|
|
52
52
|
def __init__(self):
|
|
53
53
|
self.sum = 0
|
|
54
54
|
self.count = 0
|
|
@@ -63,9 +63,9 @@ class MeanAggregator(func.Aggregator):
|
|
|
63
63
|
|
|
64
64
|
|
|
65
65
|
@func.uda(
|
|
66
|
-
init_types=[IntType()], update_types=[ImageType()], value_type=VideoType(),
|
|
66
|
+
init_types=[IntType()], update_types=[ImageType()], value_type=VideoType(),
|
|
67
67
|
requires_order_by=True, allows_window=False)
|
|
68
|
-
class
|
|
68
|
+
class make_video(func.Aggregator):
|
|
69
69
|
def __init__(self, fps: int = 25):
|
|
70
70
|
"""follows https://pyav.org/docs/develop/cookbook/numpy.html#generating-video"""
|
|
71
71
|
self.container: Optional[av.container.OutputContainer] = None
|
pixeltable/functions/eval.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
1
|
from typing import List, Tuple, Dict
|
|
3
2
|
from collections import defaultdict
|
|
4
3
|
import sys
|
|
@@ -157,16 +156,16 @@ def calculate_image_tpfp(
|
|
|
157
156
|
ts.JsonType(nullable=False)
|
|
158
157
|
])
|
|
159
158
|
def eval_detections(
|
|
160
|
-
pred_bboxes: List[List[int]],
|
|
161
|
-
gt_bboxes: List[List[int]],
|
|
159
|
+
pred_bboxes: List[List[int]], pred_labels: List[int], pred_scores: List[float],
|
|
160
|
+
gt_bboxes: List[List[int]], gt_labels: List[int]
|
|
162
161
|
) -> Dict:
|
|
163
|
-
class_idxs = list(set(
|
|
162
|
+
class_idxs = list(set(pred_labels + gt_labels))
|
|
164
163
|
result: List[Dict] = []
|
|
165
164
|
pred_bboxes_arr = np.asarray(pred_bboxes)
|
|
166
|
-
pred_classes_arr = np.asarray(
|
|
165
|
+
pred_classes_arr = np.asarray(pred_labels)
|
|
167
166
|
pred_scores_arr = np.asarray(pred_scores)
|
|
168
167
|
gt_bboxes_arr = np.asarray(gt_bboxes)
|
|
169
|
-
gt_classes_arr = np.asarray(
|
|
168
|
+
gt_classes_arr = np.asarray(gt_labels)
|
|
170
169
|
for class_idx in class_idxs:
|
|
171
170
|
pred_filter = pred_classes_arr == class_idx
|
|
172
171
|
gt_filter = gt_classes_arr == class_idx
|
|
@@ -181,8 +180,8 @@ def eval_detections(
|
|
|
181
180
|
return result
|
|
182
181
|
|
|
183
182
|
@func.uda(
|
|
184
|
-
update_types=[ts.JsonType()], value_type=ts.JsonType(),
|
|
185
|
-
class
|
|
183
|
+
update_types=[ts.JsonType()], value_type=ts.JsonType(), allows_std_agg=True, allows_window=False)
|
|
184
|
+
class mean_ap(func.Aggregator):
|
|
186
185
|
def __init__(self):
|
|
187
186
|
self.class_tpfp: Dict[int, List[Dict]] = defaultdict(list)
|
|
188
187
|
|
|
@@ -1,61 +1,34 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
import os
|
|
3
1
|
from typing import Optional
|
|
4
2
|
|
|
3
|
+
import fireworks.client
|
|
4
|
+
|
|
5
5
|
import pixeltable as pxt
|
|
6
|
-
import pixeltable.exceptions as excs
|
|
7
6
|
from pixeltable import env
|
|
8
7
|
|
|
9
8
|
|
|
9
|
+
def fireworks_client() -> fireworks.client.Fireworks:
    """Return the Fireworks client obtained from the Pixeltable environment.

    The environment's `get_client` is called with the name 'fireworks' and a factory that builds
    a `fireworks.client.Fireworks` from an API key.
    """
    def build_client(api_key):
        return fireworks.client.Fireworks(api_key=api_key)

    return env.Env.get().get_client('fireworks', build_client)
|
|
11
|
+
|
|
12
|
+
|
|
10
13
|
@pxt.udf
def chat_completions(
        messages: list[dict[str, str]],
        *,
        model: str,
        max_tokens: Optional[int] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None
) -> dict:
    """Create a chat completion through the Fireworks client and return the response as a dict.

    `messages` and `model` are always forwarded; each optional sampling parameter is forwarded
    only when it is not None.
    """
    optional_args = {
        'max_tokens': max_tokens,
        'top_k': top_k,
        'top_p': top_p,
        'temperature': temperature
    }
    # Drop unset options so that they are not passed to the client at all.
    supplied_args = {name: value for name, value in optional_args.items() if value is not None}
    response = fireworks_client().chat.completions.create(
        model=model,
        messages=messages,
        **supplied_args
    )
    return response.dict()
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
def initialize():
|
|
39
|
-
global _is_fireworks_initialized
|
|
40
|
-
if _is_fireworks_initialized:
|
|
41
|
-
return
|
|
42
|
-
|
|
43
|
-
_logger.info('Initializing Fireworks client.')
|
|
44
|
-
|
|
45
|
-
config = pxt.env.Env.get().config
|
|
46
|
-
|
|
47
|
-
if 'fireworks' in config and 'api_key' in config['fireworks']:
|
|
48
|
-
api_key = config['fireworks']['api_key']
|
|
49
|
-
else:
|
|
50
|
-
api_key = os.environ.get('FIREWORKS_API_KEY')
|
|
51
|
-
if api_key is None or api_key == '':
|
|
52
|
-
raise excs.Error('Fireworks client not initialized (no API key configured).')
|
|
53
|
-
|
|
54
|
-
import fireworks.client
|
|
55
|
-
|
|
56
|
-
fireworks.client.api_key = api_key
|
|
57
|
-
_is_fireworks_initialized = True
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
_logger = logging.getLogger('pixeltable')
|
|
61
|
-
_is_fireworks_initialized = False
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import
|
|
1
|
+
from typing import Callable, TypeVar, Optional
|
|
2
2
|
|
|
3
3
|
import PIL.Image
|
|
4
4
|
import numpy as np
|
|
@@ -7,10 +7,13 @@ import pixeltable as pxt
|
|
|
7
7
|
import pixeltable.env as env
|
|
8
8
|
import pixeltable.type_system as ts
|
|
9
9
|
from pixeltable.func import Batch
|
|
10
|
+
from pixeltable.functions.util import resolve_torch_device
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
@pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType()))
|
|
13
|
-
def sentence_transformer(
|
|
14
|
+
def sentence_transformer(
|
|
15
|
+
sentences: Batch[str], *, model_id: str, normalize_embeddings: bool = False
|
|
16
|
+
) -> Batch[np.ndarray]:
|
|
14
17
|
env.Env.get().require_package('sentence_transformers')
|
|
15
18
|
from sentence_transformers import SentenceTransformer
|
|
16
19
|
|
|
@@ -53,44 +56,60 @@ def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> li
|
|
|
53
56
|
return array.tolist()
|
|
54
57
|
|
|
55
58
|
|
|
56
|
-
@pxt.udf(batch_size=32, return_type=ts.ArrayType((
|
|
59
|
+
@pxt.udf(batch_size=32, return_type=ts.ArrayType((512,), dtype=ts.FloatType(), nullable=False))
def clip_text(text: Batch[str], *, model_id: str) -> Batch[np.ndarray]:
    """Compute CLIP text embeddings for a batch of strings.

    Requires the `transformers` package. The model is looked up (and cached) per device; its
    projection dimension is asserted to be 512, matching the declared return type. Returns one
    embedding array per input string, in input order.
    """
    env.Env.get().require_package('transformers')
    device = resolve_torch_device('auto')
    import torch
    from transformers import CLIPModel, CLIPProcessor

    model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
    assert model.config.projection_dim == 512
    processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)

    with torch.no_grad():
        inputs = processor(text=text, return_tensors='pt', padding=True, truncation=True)
        features = model.get_text_features(**inputs.to(device))
        # Bring the result back to the CPU before converting it to numpy.
        embeddings = features.detach().to('cpu').numpy()

    # Split the batch array into one row per input.
    return list(embeddings)
|
|
67
75
|
|
|
68
76
|
|
|
69
|
-
@pxt.udf(batch_size=32, return_type=ts.ArrayType((
|
|
77
|
+
@pxt.udf(batch_size=32, return_type=ts.ArrayType((512,), dtype=ts.FloatType(), nullable=False))
def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[np.ndarray]:
    """Compute CLIP image embeddings for a batch of images.

    Requires the `transformers` package. The model is looked up (and cached) per device; its
    projection dimension is asserted to be 512, matching the declared return type. Returns one
    embedding array per input image, in input order.
    """
    env.Env.get().require_package('transformers')
    device = resolve_torch_device('auto')
    import torch
    from transformers import CLIPModel, CLIPProcessor

    model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
    assert model.config.projection_dim == 512
    processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)

    with torch.no_grad():
        inputs = processor(images=image, return_tensors='pt', padding=True)
        features = model.get_image_features(**inputs.to(device))
        # Bring the result back to the CPU before converting it to numpy.
        embeddings = features.detach().to('cpu').numpy()

    # Split the batch array into one row per input.
    return list(embeddings)
|
|
80
93
|
|
|
81
94
|
|
|
82
|
-
@pxt.udf(batch_size=
|
|
95
|
+
@pxt.udf(batch_size=4)
|
|
83
96
|
def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5) -> Batch[dict]:
|
|
84
97
|
env.Env.get().require_package('transformers')
|
|
98
|
+
device = resolve_torch_device('auto')
|
|
99
|
+
import torch
|
|
85
100
|
from transformers import DetrImageProcessor, DetrForObjectDetection
|
|
86
101
|
|
|
87
|
-
model = _lookup_model(
|
|
102
|
+
model = _lookup_model(
|
|
103
|
+
model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision='no_timm'), device=device)
|
|
88
104
|
processor = _lookup_processor(model_id, lambda x: DetrImageProcessor.from_pretrained(x, revision='no_timm'))
|
|
89
105
|
|
|
90
|
-
|
|
91
|
-
|
|
106
|
+
with torch.no_grad():
|
|
107
|
+
inputs = processor(images=image, return_tensors='pt')
|
|
108
|
+
outputs = model(**inputs.to(device))
|
|
109
|
+
results = processor.post_process_object_detection(
|
|
110
|
+
outputs, threshold=threshold, target_sizes=[(img.height, img.width) for img in image]
|
|
111
|
+
)
|
|
92
112
|
|
|
93
|
-
results = processor.post_process_object_detection(outputs, threshold=threshold)
|
|
94
113
|
return [
|
|
95
114
|
{
|
|
96
115
|
'scores': [score.item() for score in result['scores']],
|
|
@@ -102,14 +121,23 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
|
|
|
102
121
|
]
|
|
103
122
|
|
|
104
123
|
|
|
105
|
-
|
|
106
|
-
|
|
124
|
+
T = TypeVar('T')
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _lookup_model(model_id: str, create: Callable[[str], T], device: Optional[str] = None) -> T:
    """Return the cached model for `(model_id, create, device)`, creating it on first use.

    A newly created model is moved to `device` when one is given and, if it is a torch
    `nn.Module`, put into eval mode before being stored in the module-level `_model_cache`.

    NOTE(review): `create` is part of the cache key, so a caller that builds a new lambda on
    every call produces a new key each time -- confirm whether such callers ever hit the cache.
    """
    from torch import nn
    key = (model_id, create, device)  # For safety, include the `create` callable in the cache key
    if key in _model_cache:
        return _model_cache[key]

    model = create(model_id)
    if device is not None:
        model.to(device)
    if isinstance(model, nn.Module):
        model.eval()
    _model_cache[key] = model
    return _model_cache[key]
|
|
110
138
|
|
|
111
139
|
|
|
112
|
-
def _lookup_processor(model_id: str, create: Callable) ->
|
|
140
|
+
def _lookup_processor(model_id: str, create: Callable[[str], T]) -> T:
|
|
113
141
|
key = (model_id, create) # For safety, include the `create` callable in the cache key
|
|
114
142
|
if key not in _processor_cache:
|
|
115
143
|
_processor_cache[key] = create(model_id)
|