pixeltable 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pixeltable/catalog/column.py +25 -48
- pixeltable/catalog/insertable_table.py +7 -4
- pixeltable/catalog/table.py +163 -57
- pixeltable/catalog/table_version.py +416 -140
- pixeltable/catalog/table_version_path.py +2 -2
- pixeltable/client.py +0 -4
- pixeltable/dataframe.py +65 -21
- pixeltable/env.py +16 -1
- pixeltable/exec/cache_prefetch_node.py +1 -1
- pixeltable/exec/in_memory_data_node.py +11 -7
- pixeltable/exprs/comparison.py +3 -3
- pixeltable/exprs/data_row.py +5 -1
- pixeltable/exprs/literal.py +16 -4
- pixeltable/exprs/row_builder.py +8 -40
- pixeltable/ext/__init__.py +5 -0
- pixeltable/ext/functions/yolox.py +92 -0
- pixeltable/func/aggregate_function.py +15 -15
- pixeltable/func/expr_template_function.py +9 -1
- pixeltable/func/globals.py +24 -14
- pixeltable/func/signature.py +18 -12
- pixeltable/func/udf.py +7 -2
- pixeltable/functions/__init__.py +8 -8
- pixeltable/functions/eval.py +7 -8
- pixeltable/functions/huggingface.py +47 -19
- pixeltable/functions/openai.py +2 -2
- pixeltable/functions/util.py +11 -0
- pixeltable/index/__init__.py +2 -0
- pixeltable/index/base.py +49 -0
- pixeltable/index/embedding_index.py +95 -0
- pixeltable/metadata/schema.py +45 -22
- pixeltable/plan.py +15 -34
- pixeltable/store.py +38 -41
- pixeltable/tests/conftest.py +5 -11
- pixeltable/tests/ext/test_yolox.py +21 -0
- pixeltable/tests/functions/test_fireworks.py +1 -0
- pixeltable/tests/functions/test_huggingface.py +2 -2
- pixeltable/tests/functions/test_openai.py +15 -5
- pixeltable/tests/functions/test_together.py +1 -0
- pixeltable/tests/test_component_view.py +14 -5
- pixeltable/tests/test_dataframe.py +19 -18
- pixeltable/tests/test_exprs.py +99 -102
- pixeltable/tests/test_function.py +51 -43
- pixeltable/tests/test_index.py +138 -0
- pixeltable/tests/test_migration.py +2 -1
- pixeltable/tests/test_snapshot.py +24 -1
- pixeltable/tests/test_table.py +101 -25
- pixeltable/tests/test_types.py +30 -0
- pixeltable/tests/test_video.py +16 -16
- pixeltable/tests/test_view.py +5 -0
- pixeltable/tests/utils.py +43 -9
- pixeltable/tool/create_test_db_dump.py +16 -0
- pixeltable/type_system.py +37 -45
- {pixeltable-0.2.4.dist-info → pixeltable-0.2.5.dist-info}/METADATA +5 -4
- {pixeltable-0.2.4.dist-info → pixeltable-0.2.5.dist-info}/RECORD +56 -49
- {pixeltable-0.2.4.dist-info → pixeltable-0.2.5.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.4.dist-info → pixeltable-0.2.5.dist-info}/WHEEL +0 -0
pixeltable/func/globals.py
CHANGED

@@ -1,29 +1,39 @@
-from typing import Optional
-from types import ModuleType
 import importlib
 import inspect
+from types import ModuleType
+from typing import Optional

+import pixeltable.exceptions as excs

-
+
+def resolve_symbol(symbol_path: str) -> Optional[object]:
     path_elems = symbol_path.split('.')
     module: Optional[ModuleType] = None
-
-
-    # try to import the submodule directly
-    submodule_path = '.'.join(path_elems[0:3])
+    i = len(path_elems) - 1
+    while i > 0 and module is None:
         try:
-            module = importlib.import_module(
-            path_elems = path_elems[3:]
+            module = importlib.import_module('.'.join(path_elems[:i]))
         except ModuleNotFoundError:
-
-            if
-
-            path_elems = path_elems[1:]
+            i -= 1
+    if i == 0:
+        return None  # Not resolvable
     obj = module
-    for el in path_elems:
+    for el in path_elems[i:]:
         obj = getattr(obj, el)
     return obj

+
+def validate_symbol_path(fn_path: str) -> None:
+    path_elems = fn_path.split('.')
+    fn_name = path_elems[-1]
+    if any(el == '<locals>' for el in path_elems):
+        raise excs.Error(
+            f'{fn_name}(): nested functions are not supported. Move the function to the module level or into a class.')
+    if any(not el.isidentifier() for el in path_elems):
+        raise excs.Error(
+            f'{fn_name}(): cannot resolve symbol path {fn_path}. Move the function to the module level or into a class.')
+
+
 def get_caller_module_path() -> str:
     """Return the module path of our caller's caller"""
     stack = inspect.stack()
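Note on the change above: the old resolve_symbol() assumed the target lived under a fixed three-element module prefix, while the new version tries the longest importable prefix and resolves whatever remains via attribute lookup. A minimal standalone sketch of that strategy (plain Python, mirroring the names in the diff; the json.decoder example is ours, not pixeltable's):

import importlib
from types import ModuleType
from typing import Optional


def resolve_symbol(symbol_path: str) -> Optional[object]:
    path_elems = symbol_path.split('.')
    module: Optional[ModuleType] = None
    i = len(path_elems) - 1
    # try progressively shorter prefixes until one imports cleanly
    while i > 0 and module is None:
        try:
            module = importlib.import_module('.'.join(path_elems[:i]))
        except ModuleNotFoundError:
            i -= 1
    if i == 0:
        return None  # not resolvable
    obj = module
    # walk the remaining elements as attributes (classes, functions, ...)
    for el in path_elems[i:]:
        obj = getattr(obj, el)
    return obj


# e.g. 'json.decoder.JSONDecoder' imports 'json.decoder', then getattr('JSONDecoder')
assert resolve_symbol('json.decoder.JSONDecoder') is not None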
pixeltable/func/signature.py
CHANGED

@@ -114,20 +114,12 @@ class Signature:
         return (col_type, is_batched)

     @classmethod
-    def
-        cls, c: Callable,
-        param_types: Optional[List[ts.ColumnType]] = None,
-        return_type: Optional[Union[ts.ColumnType, Callable]] = None
-    ) -> Signature:
-        """Create a signature for the given Callable.
-        Infer the parameter and return types, if none are specified.
-        Raises an exception if the types cannot be inferred.
-        """
+    def create_parameters(
+            cls, c: Callable, param_types: Optional[List[ts.ColumnType]] = None) -> List[Parameter]:
         sig = inspect.signature(c)
         py_parameters = list(sig.parameters.values())
-
-        # check non-var parameters for name collisions and default value compatibility
         parameters: List[Parameter] = []
+
         for idx, param in enumerate(py_parameters):
             if param.name in cls.SPECIAL_PARAM_NAMES:
                 raise excs.Error(f"'{param.name}' is a reserved parameter name")

@@ -135,6 +127,7 @@ class Signature:
                 parameters.append(Parameter(param.name, None, param.kind, False))
                 continue

+            # check non-var parameters for name collisions and default value compatibility
             if param_types is not None:
                 if idx >= len(param_types):
                     raise excs.Error(f'Missing type for parameter {param.name}')

@@ -155,7 +148,20 @@ class Signature:

             parameters.append(Parameter(param.name, param_type, param.kind, is_batched))

-
+        return parameters
+
+    @classmethod
+    def create(
+            cls, c: Callable,
+            param_types: Optional[List[ts.ColumnType]] = None,
+            return_type: Optional[Union[ts.ColumnType, Callable]] = None
+    ) -> Signature:
+        """Create a signature for the given Callable.
+        Infer the parameter and return types, if none are specified.
+        Raises an exception if the types cannot be inferred.
+        """
+        parameters = cls.create_parameters(c, param_types)
+        sig = inspect.signature(c)
         if return_type is None:
             return_type, return_is_batched = cls._infer_type(sig.return_annotation)
             if return_type is None:
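The point of this refactoring is that parameter construction is now available on its own via Signature.create_parameters(), without forcing return-type inference; that is what expr_udf() relies on in the udf.py diff below. A hypothetical usage sketch, assuming Signature and the type constructors are importable as shown in the diff; the scale() function is made up:

import pixeltable.type_system as ts
from pixeltable.func.signature import Signature


def scale(x: int, factor: float) -> float:  # made-up UDF body, for illustration only
    return x * factor

# build just the Parameter list, as expr_udf() now does
params = Signature.create_parameters(scale, param_types=[ts.IntType(), ts.FloatType()])

# or build the full signature, including the return type
sig = Signature.create(scale, param_types=[ts.IntType(), ts.FloatType()], return_type=ts.FloatType())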
pixeltable/func/udf.py
CHANGED

@@ -11,6 +11,7 @@ from .callable_function import CallableFunction
 from .expr_template_function import ExprTemplateFunction
 from .function import Function
 from .function_registry import FunctionRegistry
+from .globals import validate_symbol_path
 from .signature import Signature


@@ -124,6 +125,8 @@ def make_function(

     # If this function is part of a module, register it
     if function_path is not None:
+        # do the validation at the very end, so it's easier to write tests for other failure scenarios
+        validate_symbol_path(function_path)
         FunctionRegistry.get().register_function(function_path, result)

     return result

@@ -142,17 +145,19 @@ def expr_udf(*args: Any, **kwargs: Any) -> Any:
         else:
            function_path = None

-        sig = Signature.create(py_fn, param_types=param_types, return_type=None)
         # TODO: verify that the inferred return type matches that of the template
         # TODO: verify that the signature doesn't contain batched parameters

         # construct Parameters from the function signature
+        params = Signature.create_parameters(py_fn, param_types=param_types)
         import pixeltable.exprs as exprs
-        var_exprs = [exprs.Variable(param.name, param.col_type) for param in
+        var_exprs = [exprs.Variable(param.name, param.col_type) for param in params]
         # call the function with the parameter expressions to construct an Expr with parameters
         template = py_fn(*var_exprs)
         assert isinstance(template, exprs.Expr)
         py_sig = inspect.signature(py_fn)
+        if function_path is not None:
+            validate_symbol_path(function_path)
         return ExprTemplateFunction(template, py_signature=py_sig, self_path=function_path, name=py_fn.__name__)

     if len(args) == 1:
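Background for the validate_symbol_path() calls added above: a function defined inside another function carries '<locals>' in its qualified name, so resolve_symbol() could never re-import it later. A plain-Python illustration of exactly that check, independent of pixeltable:

def outer():
    def inner():
        pass
    return inner

f = outer()
print(f.__module__, f.__qualname__)  # e.g. __main__ outer.<locals>.inner
path_elems = f"{f.__module__}.{f.__qualname__}".split('.')
assert any(el == '<locals>' for el in path_elems)  # the condition the new code rejects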
pixeltable/functions/__init__.py
CHANGED

@@ -23,8 +23,8 @@ def cast(expr: exprs.Expr, target_type: ColumnType) -> exprs.Expr:
     return expr

 @func.uda(
-    update_types=[IntType()], value_type=IntType(),
-class
+    update_types=[IntType()], value_type=IntType(), allows_window=True, requires_order_by=False)
+class sum(func.Aggregator):
     def __init__(self):
         self.sum: Union[int, float] = 0
     def update(self, val: Union[int, float]) -> None:

@@ -35,8 +35,8 @@ class SumAggregator(func.Aggregator):


 @func.uda(
-    update_types=[IntType()], value_type=IntType(),
-class
+    update_types=[IntType()], value_type=IntType(), allows_window = True, requires_order_by = False)
+class count(func.Aggregator):
     def __init__(self):
         self.count = 0
     def update(self, val: int) -> None:

@@ -47,8 +47,8 @@ class CountAggregator(func.Aggregator):


 @func.uda(
-    update_types=[IntType()], value_type=FloatType(),
-class
+    update_types=[IntType()], value_type=FloatType(), allows_window=False, requires_order_by=False)
+class mean(func.Aggregator):
     def __init__(self):
         self.sum = 0
         self.count = 0

@@ -63,9 +63,9 @@ class MeanAggregator(func.Aggregator):


 @func.uda(
-    init_types=[IntType()], update_types=[ImageType()], value_type=VideoType(),
+    init_types=[IntType()], update_types=[ImageType()], value_type=VideoType(),
     requires_order_by=True, allows_window=False)
-class
+class make_video(func.Aggregator):
     def __init__(self, fps: int = 25):
         """follows https://pyav.org/docs/develop/cookbook/numpy.html#generating-video"""
         self.container: Optional[av.container.OutputContainer] = None
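The built-in aggregators are now registered under their callable names (sum, count, mean, make_video) and pass the windowing flags explicitly to @func.uda. A hypothetical sketch of the same pattern for a made-up `min` aggregator; the decorator flags and base class come from the diff, while the value() method name is assumed from the aggregator protocol and is not shown in this excerpt:

from typing import Optional, Union

import pixeltable.func as func
from pixeltable.type_system import IntType


@func.uda(update_types=[IntType()], value_type=IntType(), allows_window=True, requires_order_by=False)
class min(func.Aggregator):
    def __init__(self):
        self.val: Optional[Union[int, float]] = None

    def update(self, val: Union[int, float]) -> None:
        # track the smallest non-null value seen so far
        if val is not None and (self.val is None or val < self.val):
            self.val = val

    def value(self) -> Optional[Union[int, float]]:  # assumed accessor name
        return self.val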
pixeltable/functions/eval.py
CHANGED

@@ -1,4 +1,3 @@
-from __future__ import annotations
 from typing import List, Tuple, Dict
 from collections import defaultdict
 import sys

@@ -157,16 +156,16 @@ def calculate_image_tpfp(
     ts.JsonType(nullable=False)
 ])
 def eval_detections(
-    pred_bboxes: List[List[int]],
-    gt_bboxes: List[List[int]],
+    pred_bboxes: List[List[int]], pred_labels: List[int], pred_scores: List[float],
+    gt_bboxes: List[List[int]], gt_labels: List[int]
 ) -> Dict:
-    class_idxs = list(set(
+    class_idxs = list(set(pred_labels + gt_labels))
     result: List[Dict] = []
     pred_bboxes_arr = np.asarray(pred_bboxes)
-    pred_classes_arr = np.asarray(
+    pred_classes_arr = np.asarray(pred_labels)
     pred_scores_arr = np.asarray(pred_scores)
     gt_bboxes_arr = np.asarray(gt_bboxes)
-    gt_classes_arr = np.asarray(
+    gt_classes_arr = np.asarray(gt_labels)
     for class_idx in class_idxs:
         pred_filter = pred_classes_arr == class_idx
         gt_filter = gt_classes_arr == class_idx

@@ -181,8 +180,8 @@ def eval_detections(
     return result

 @func.uda(
-    update_types=[ts.JsonType()], value_type=ts.JsonType(),
-class
+    update_types=[ts.JsonType()], value_type=ts.JsonType(), allows_std_agg=True, allows_window=False)
+class mean_ap(func.Aggregator):
     def __init__(self):
         self.class_tpfp: Dict[int, List[Dict]] = defaultdict(list)
pixeltable/functions/huggingface.py
CHANGED

@@ -1,4 +1,4 @@
-from typing import
+from typing import Callable, TypeVar, Optional

 import PIL.Image
 import numpy as np

@@ -7,10 +7,13 @@ import pixeltable as pxt
 import pixeltable.env as env
 import pixeltable.type_system as ts
 from pixeltable.func import Batch
+from pixeltable.functions.util import resolve_torch_device


 @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType()))
-def sentence_transformer(
+def sentence_transformer(
+        sentences: Batch[str], *, model_id: str, normalize_embeddings: bool = False
+) -> Batch[np.ndarray]:
     env.Env.get().require_package('sentence_transformers')
     from sentence_transformers import SentenceTransformer

@@ -53,44 +56,60 @@ def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> li
     return array.tolist()


-@pxt.udf(batch_size=32, return_type=ts.ArrayType((
+@pxt.udf(batch_size=32, return_type=ts.ArrayType((512,), dtype=ts.FloatType(), nullable=False))
 def clip_text(text: Batch[str], *, model_id: str) -> Batch[np.ndarray]:
     env.Env.get().require_package('transformers')
+    device = resolve_torch_device('auto')
+    import torch
     from transformers import CLIPModel, CLIPProcessor

-    model = _lookup_model(model_id, CLIPModel.from_pretrained)
+    model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
+    assert model.config.projection_dim == 512
     processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)

-
-
+    with torch.no_grad():
+        inputs = processor(text=text, return_tensors='pt', padding=True, truncation=True)
+        embeddings = model.get_text_features(**inputs.to(device)).detach().to('cpu').numpy()
+
     return [embeddings[i] for i in range(embeddings.shape[0])]


-@pxt.udf(batch_size=32, return_type=ts.ArrayType((
+@pxt.udf(batch_size=32, return_type=ts.ArrayType((512,), dtype=ts.FloatType(), nullable=False))
 def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[np.ndarray]:
     env.Env.get().require_package('transformers')
+    device = resolve_torch_device('auto')
+    import torch
     from transformers import CLIPModel, CLIPProcessor

-    model = _lookup_model(model_id, CLIPModel.from_pretrained)
+    model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
+    assert model.config.projection_dim == 512
     processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)

-
-
+    with torch.no_grad():
+        inputs = processor(images=image, return_tensors='pt', padding=True)
+        embeddings = model.get_image_features(**inputs.to(device)).detach().to('cpu').numpy()
+
     return [embeddings[i] for i in range(embeddings.shape[0])]


-@pxt.udf(batch_size=
+@pxt.udf(batch_size=4)
 def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5) -> Batch[dict]:
     env.Env.get().require_package('transformers')
+    device = resolve_torch_device('auto')
+    import torch
     from transformers import DetrImageProcessor, DetrForObjectDetection

-    model = _lookup_model(
+    model = _lookup_model(
+        model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision='no_timm'), device=device)
     processor = _lookup_processor(model_id, lambda x: DetrImageProcessor.from_pretrained(x, revision='no_timm'))

-
-
+    with torch.no_grad():
+        inputs = processor(images=image, return_tensors='pt')
+        outputs = model(**inputs.to(device))
+        results = processor.post_process_object_detection(
+            outputs, threshold=threshold, target_sizes=[(img.height, img.width) for img in image]
+        )

-    results = processor.post_process_object_detection(outputs, threshold=threshold)
     return [
         {
             'scores': [score.item() for score in result['scores']],

@@ -102,14 +121,23 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
     ]


-
-
+T = TypeVar('T')
+
+
+def _lookup_model(model_id: str, create: Callable[[str], T], device: Optional[str] = None) -> T:
+    from torch import nn
+    key = (model_id, create, device)  # For safety, include the `create` callable in the cache key
     if key not in _model_cache:
-
+        model = create(model_id)
+        if device is not None:
+            model.to(device)
+        if isinstance(model, nn.Module):
+            model.eval()
+        _model_cache[key] = model
     return _model_cache[key]


-def _lookup_processor(model_id: str, create: Callable) ->
+def _lookup_processor(model_id: str, create: Callable[[str], T]) -> T:
     key = (model_id, create)  # For safety, include the `create` callable in the cache key
     if key not in _processor_cache:
         _processor_cache[key] = create(model_id)
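The model cache above is now keyed on (model_id, create, device), moves nn.Module models to the selected device, and switches them to eval mode; the UDFs then run inference under torch.no_grad(). A standalone sketch of that caching pattern (the function and cache names here are illustrative, not imports from pixeltable):

from typing import Callable, Dict, Optional, Tuple, TypeVar

from torch import nn

T = TypeVar('T')
_model_cache: Dict[Tuple, object] = {}


def lookup_model(model_id: str, create: Callable[[str], T], device: Optional[str] = None) -> T:
    key = (model_id, create, device)  # include `create` so different loaders don't collide
    if key not in _model_cache:
        model = create(model_id)
        if device is not None:
            model.to(device)          # load once, on the device inference will use
        if isinstance(model, nn.Module):
            model.eval()              # inference only: disables dropout etc.
        _model_cache[key] = model
    return _model_cache[key]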
pixeltable/functions/openai.py
CHANGED

@@ -26,8 +26,8 @@ def openai_client() -> openai.OpenAI:
 def _retry(fn: Callable) -> Callable:
     return tenacity.retry(
         retry=tenacity.retry_if_exception_type(openai.RateLimitError),
-        wait=tenacity.wait_random_exponential(
-        stop=tenacity.stop_after_attempt(
+        wait=tenacity.wait_random_exponential(multiplier=3, max=180),
+        stop=tenacity.stop_after_attempt(20)
     )(fn)
pixeltable/functions/util.py
CHANGED

@@ -39,3 +39,14 @@ def create_nos_modules() -> List[types.ModuleType]:
         setattr(sub_module, model_id, pt_func)

     return new_modules
+
+
+def resolve_torch_device(device: str) -> str:
+    import torch
+    if device == 'auto':
+        if torch.cuda.is_available():
+            return 'cuda'
+        if torch.backends.mps.is_available():
+            return 'mps'
+        return 'cpu'
+    return device
pixeltable/index/base.py
ADDED

@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+import abc
+from typing import Any
+
+import sqlalchemy as sql
+
+import pixeltable.catalog as catalog
+
+
+class IndexBase(abc.ABC):
+    """
+    Internal interface used by the catalog and runtime system to interact with indices:
+    - types and expressions needed to create and populate the index value column
+    - creating/dropping the index
+    - TODO: translating queries into sqlalchemy predicates
+    """
+    @abc.abstractmethod
+    def __init__(self, c: catalog.Column, **kwargs: Any):
+        pass
+
+    @abc.abstractmethod
+    def index_value_expr(self) -> 'pixeltable.exprs.Expr':
+        """Return expression that computes the value that goes into the index"""
+        pass
+
+    @abc.abstractmethod
+    def index_sa_type(self) -> sql.sqltypes.TypeEngine:
+        """Return the sqlalchemy type of the index value column"""
+        pass
+
+    @abc.abstractmethod
+    def create_index(self, index_name: str, index_value_col: catalog.Column, conn: sql.engine.Connection) -> None:
+        """Create the index on the index value column"""
+        pass
+
+    @classmethod
+    @abc.abstractmethod
+    def display_name(cls) -> str:
+        pass
+
+    @abc.abstractmethod
+    def as_dict(self) -> dict:
+        pass
+
+    @classmethod
+    @abc.abstractmethod
+    def from_dict(cls, c: catalog.Column, d: dict) -> IndexBase:
+        pass
pixeltable/index/embedding_index.py
ADDED

@@ -0,0 +1,95 @@
+from __future__ import annotations
+
+from typing import Optional
+
+import pgvector.sqlalchemy
+import sqlalchemy as sql
+
+import pixeltable.catalog as catalog
+import pixeltable.exceptions as excs
+import pixeltable.func as func
+import pixeltable.type_system as ts
+from .base import IndexBase
+
+
+class EmbeddingIndex(IndexBase):
+    """
+    Internal interface used by the catalog and runtime system to interact with (embedding) indices:
+    - types and expressions needed to create and populate the index value column
+    - creating/dropping the index
+    - translating 'matches' queries into sqlalchemy predicates
+    """
+
+    def __init__(
+            self, c: catalog.Column, text_embed: Optional[func.Function] = None,
+            img_embed: Optional[func.Function] = None):
+        if not c.col_type.is_string_type() and not c.col_type.is_image_type():
+            raise excs.Error(f'Embedding index requires string or image column')
+        if c.col_type.is_string_type() and text_embed is None:
+            raise excs.Error(f'Text embedding function is required for column {c.name} (parameter `txt_embed`)')
+        if c.col_type.is_image_type() and img_embed is None:
+            raise excs.Error(f'Image embedding function is required for column {c.name} (parameter `img_embed`)')
+        if text_embed is not None:
+            # verify signature
+            self._validate_embedding_fn(text_embed, 'txt_embed', ts.ColumnType.Type.STRING)
+        if img_embed is not None:
+            # verify signature
+            self._validate_embedding_fn(img_embed, 'img_embed', ts.ColumnType.Type.IMAGE)
+
+        from pixeltable.exprs import ColumnRef
+        self.value_expr = text_embed(ColumnRef(c)) if c.col_type.is_string_type() else img_embed(ColumnRef(c))
+        assert self.value_expr.col_type.is_array_type()
+        self.txt_embed = text_embed
+        self.img_embed = img_embed
+        vector_size = self.value_expr.col_type.shape[0]
+        assert vector_size is not None
+        self.index_col_type = pgvector.sqlalchemy.Vector(vector_size)
+
+    def index_value_expr(self) -> 'pixeltable.exprs.Expr':
+        """Return expression that computes the value that goes into the index"""
+        return self.value_expr
+
+    def index_sa_type(self) -> sql.sqltypes.TypeEngine:
+        """Return the sqlalchemy type of the index value column"""
+        return self.index_col_type
+
+    def create_index(self, index_name: str, index_value_col: catalog.Column, conn: sql.engine.Connection) -> None:
+        """Create the index on the index value column"""
+        idx = sql.Index(
+            index_name, index_value_col.sa_col,
+            postgresql_using='hnsw',
+            postgresql_with={'m': 16, 'ef_construction': 64},
+            postgresql_ops={index_value_col.sa_col.name: 'vector_cosine_ops'}
+        )
+        idx.create(bind=conn)
+
+    @classmethod
+    def display_name(cls) -> str:
+        return 'embedding'
+
+    @classmethod
+    def _validate_embedding_fn(cls, embed_fn: func.Function, name: str, expected_type: ts.ColumnType.Type) -> None:
+        """Validate the signature"""
+        assert isinstance(embed_fn, func.Function)
+        sig = embed_fn.signature
+        if not sig.return_type.is_array_type():
+            raise excs.Error(f'{name} must return an array, but returns {sig.return_type}')
+        else:
+            shape = sig.return_type.shape
+            if len(shape) != 1 or shape[0] == None:
+                raise excs.Error(f'{name} must return a 1D array of a specific length, but returns {sig.return_type}')
+        if len(sig.parameters) != 1 or sig.parameters_by_pos[0].col_type.type_enum != expected_type:
+            raise excs.Error(
+                f'{name} must take a single {expected_type.name.lower()} parameter, but has signature {sig}')
+
+    def as_dict(self) -> dict:
+        return {
+            'txt_embed': None if self.txt_embed is None else self.txt_embed.as_dict(),
+            'img_embed': None if self.img_embed is None else self.img_embed.as_dict()
+        }
+
+    @classmethod
+    def from_dict(cls, c: catalog.Column, d: dict) -> EmbeddingIndex:
+        txt_embed = func.Function.from_dict(d['txt_embed']) if d['txt_embed'] is not None else None
+        img_embed = func.Function.from_dict(d['img_embed']) if d['img_embed'] is not None else None
+        return cls(c, text_embed=txt_embed, img_embed=img_embed)
pixeltable/metadata/schema.py
CHANGED

@@ -1,4 +1,4 @@
-from typing import Optional, List,
+from typing import Optional, List, get_type_hints, Type, Any, TypeVar, Tuple, Union
 import platform
 import uuid
 import dataclasses

@@ -71,16 +71,43 @@ class Dir(Base):


 @dataclasses.dataclass
-class
+class ColumnMd:
     """
-    Records
-
-
-
+    Records the non-versioned metadata of a column.
+    - immutable attributes: type, primary key, etc.
+    - when a column was added/dropped, which is needed to GC unreachable storage columns
+      (a column that was added after table snapshot n and dropped before table snapshot n+1 can be removed
+      from the stored table).
     """
-
+    id: int
     schema_version_add: int
     schema_version_drop: Optional[int]
+    col_type: dict
+
+    # if True, is part of the primary key
+    is_pk: bool
+
+    # if set, this is a computed column
+    value_expr: Optional[dict]
+
+    # if True, the column is present in the stored table
+    stored: Optional[bool]
+
+
+@dataclasses.dataclass
+class IndexMd:
+    """
+    Metadata needed to instantiate an EmbeddingIndex
+    """
+    id: int
+    name: str
+    indexed_col_id: int  # column being indexed
+    index_val_col_id: int  # column holding the values to be indexed
+    index_val_undo_col_id: int  # column holding index values for deleted rows
+    schema_version_add: int
+    schema_version_drop: Optional[int]
+    class_fqn: str
+    init_args: dict[str, Any]


 @dataclasses.dataclass

@@ -91,13 +118,13 @@ class ViewMd:
     base_versions: List[Tuple[str, Optional[int]]]

     # filter predicate applied to the base table; view-only
-    predicate: Optional[
+    predicate: Optional[dict[str, Any]]

     # ComponentIterator subclass; only for component views
     iterator_class_fqn: Optional[str]

     # args to pass to the iterator class constructor; only for component views
-    iterator_args: Optional[
+    iterator_args: Optional[dict[str, Any]]


 @dataclasses.dataclass

@@ -109,15 +136,15 @@ class TableMd:
     # each version has a corresponding schema version (current_version >= current_schema_version)
     current_schema_version: int

-    # used to assign Column.id
-
+    next_col_id: int  # used to assign Column.id
+    next_idx_id: int  # used to assign IndexMd.id

     # - used to assign the rowid column in the storage table
     # - every row is assigned a unique and immutable rowid on insertion
     next_row_id: int

-
-
+    column_md: dict[int, ColumnMd]  # col_id -> ColumnMd
+    index_md: dict[int, IndexMd]  # index_id -> IndexMd
     view_md: Optional[ViewMd]


@@ -155,24 +182,20 @@ class TableVersion(Base):
 @dataclasses.dataclass
 class SchemaColumn:
     """
-    Records the
-    Contains the full set of columns for each new schema version: one record per (column x schema version).
+    Records the versioned metadata of a column.
     """
     pos: int
     name: str
-    col_type: dict
-    is_pk: bool
-    value_expr: Optional[dict]
-    stored: Optional[bool]
-    # if True, creates vector index for this column
-    is_indexed: bool


 @dataclasses.dataclass
 class TableSchemaVersionMd:
+    """
+    Records all versioned table metadata.
+    """
     schema_version: int
     preceding_schema_version: Optional[int]
-    columns:
+    columns: dict[int, SchemaColumn]  # col_id -> SchemaColumn
     num_retained_versions: int
     comment: str
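Illustration only: the new metadata records (ColumnMd, IndexMd, SchemaColumn) are plain dataclasses, so they round-trip to dicts with the standard library; how pixeltable persists them, and the serialized format of col_type, are not shown in this diff, so the col_type value below is a stand-in.

import dataclasses
from pixeltable.metadata.schema import ColumnMd

md = ColumnMd(
    id=0, schema_version_add=0, schema_version_drop=None,
    col_type={},  # stand-in for a serialized ColumnType
    is_pk=False, value_expr=None, stored=True)
print(dataclasses.asdict(md))  # {'id': 0, 'schema_version_add': 0, ...}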