pixeltable 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the pixeltable package as they appear in their public registry, and is provided for informational purposes only.
- pixeltable/catalog/column.py +26 -49
- pixeltable/catalog/insertable_table.py +7 -4
- pixeltable/catalog/table.py +163 -57
- pixeltable/catalog/table_version.py +416 -140
- pixeltable/catalog/table_version_path.py +2 -2
- pixeltable/client.py +72 -6
- pixeltable/dataframe.py +65 -21
- pixeltable/env.py +52 -53
- pixeltable/exec/cache_prefetch_node.py +1 -1
- pixeltable/exec/in_memory_data_node.py +11 -7
- pixeltable/exprs/comparison.py +3 -3
- pixeltable/exprs/data_row.py +5 -1
- pixeltable/exprs/literal.py +16 -4
- pixeltable/exprs/row_builder.py +8 -40
- pixeltable/ext/__init__.py +5 -0
- pixeltable/ext/functions/yolox.py +92 -0
- pixeltable/func/aggregate_function.py +15 -15
- pixeltable/func/expr_template_function.py +9 -1
- pixeltable/func/globals.py +24 -14
- pixeltable/func/signature.py +18 -12
- pixeltable/func/udf.py +7 -2
- pixeltable/functions/__init__.py +9 -9
- pixeltable/functions/eval.py +7 -8
- pixeltable/functions/fireworks.py +10 -37
- pixeltable/functions/huggingface.py +47 -19
- pixeltable/functions/openai.py +192 -24
- pixeltable/functions/together.py +104 -9
- pixeltable/functions/util.py +11 -0
- pixeltable/index/__init__.py +2 -0
- pixeltable/index/base.py +49 -0
- pixeltable/index/embedding_index.py +95 -0
- pixeltable/metadata/schema.py +45 -22
- pixeltable/plan.py +15 -34
- pixeltable/store.py +38 -41
- pixeltable/tests/conftest.py +8 -14
- pixeltable/tests/ext/test_yolox.py +21 -0
- pixeltable/tests/functions/test_fireworks.py +43 -0
- pixeltable/tests/functions/test_functions.py +60 -0
- pixeltable/tests/{test_functions.py → functions/test_huggingface.py} +7 -143
- pixeltable/tests/functions/test_openai.py +162 -0
- pixeltable/tests/functions/test_together.py +112 -0
- pixeltable/tests/test_component_view.py +14 -5
- pixeltable/tests/test_dataframe.py +23 -22
- pixeltable/tests/test_exprs.py +99 -102
- pixeltable/tests/test_function.py +51 -43
- pixeltable/tests/test_index.py +138 -0
- pixeltable/tests/test_migration.py +2 -1
- pixeltable/tests/test_snapshot.py +24 -1
- pixeltable/tests/test_table.py +205 -26
- pixeltable/tests/test_types.py +30 -0
- pixeltable/tests/test_video.py +16 -16
- pixeltable/tests/test_view.py +5 -0
- pixeltable/tests/utils.py +171 -14
- pixeltable/tool/create_test_db_dump.py +16 -0
- pixeltable/type_system.py +77 -128
- pixeltable/utils/arrow.py +98 -0
- pixeltable/utils/hf_datasets.py +157 -0
- pixeltable/utils/parquet.py +68 -27
- pixeltable/utils/pytorch.py +16 -97
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/METADATA +35 -28
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/RECORD +63 -50
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/WHEEL +0 -0
pixeltable/functions/openai.py
CHANGED
@@ -1,9 +1,14 @@
 import base64
 import io
-
+import pathlib
+import uuid
+from typing import Optional, TypeVar, Union, Callable

 import PIL.Image
 import numpy as np
+import openai
+import tenacity
+from openai._types import NOT_GIVEN, NotGiven

 import pixeltable as pxt
 import pixeltable.type_system as ts
@@ -11,43 +16,148 @@ from pixeltable import env
 from pixeltable.func import Batch


+def openai_client() -> openai.OpenAI:
+    return env.Env.get().get_client('openai', lambda api_key: openai.OpenAI(api_key=api_key))
+
+
+# Exponential backoff decorator using tenacity.
+# TODO(aaron-siegel): Right now this hardwires random exponential backoff with defaults suggested
+# by OpenAI. Should we investigate making this more customizable in the future?
+def _retry(fn: Callable) -> Callable:
+    return tenacity.retry(
+        retry=tenacity.retry_if_exception_type(openai.RateLimitError),
+        wait=tenacity.wait_random_exponential(multiplier=3, max=180),
+        stop=tenacity.stop_after_attempt(20)
+    )(fn)
+
+
+#####################################
+# Audio Endpoints
+
+@pxt.udf(return_type=ts.AudioType())
+@_retry
+def speech(
+    input: str,
+    *,
+    model: str,
+    voice: str,
+    response_format: Optional[str] = None,
+    speed: Optional[float] = None
+) -> str:
+    content = openai_client().audio.speech.create(
+        input=input,
+        model=model,
+        voice=voice,
+        response_format=_opt(response_format),
+        speed=_opt(speed)
+    )
+    ext = response_format or 'mp3'
+    output_filename = str(env.Env.get().tmp_dir / f"{uuid.uuid4()}.{ext}")
+    content.stream_to_file(output_filename, chunk_size=1 << 20)
+    return output_filename
+
+
+@pxt.udf(
+    param_types=[ts.AudioType(), ts.StringType(), ts.StringType(nullable=True),
+                 ts.StringType(nullable=True), ts.FloatType(nullable=True)]
+)
+@_retry
+def transcriptions(
+    audio: str,
+    *,
+    model: str,
+    language: Optional[str] = None,
+    prompt: Optional[str] = None,
+    temperature: Optional[float] = None
+) -> dict:
+    file = pathlib.Path(audio)
+    transcription = openai_client().audio.transcriptions.create(
+        file=file,
+        model=model,
+        language=_opt(language),
+        prompt=_opt(prompt),
+        temperature=_opt(temperature)
+    )
+    return transcription.dict()
+
+
+@pxt.udf(
+    param_types=[ts.AudioType(), ts.StringType(), ts.StringType(nullable=True), ts.FloatType(nullable=True)]
+)
+@_retry
+def translations(
+    audio: str,
+    *,
+    model: str,
+    prompt: Optional[str] = None,
+    temperature: Optional[float] = None
+) -> dict:
+    file = pathlib.Path(audio)
+    translation = openai_client().audio.translations.create(
+        file=file,
+        model=model,
+        prompt=_opt(prompt),
+        temperature=_opt(temperature)
+    )
+    return translation.dict()
+
+
+#####################################
+# Chat Endpoints
+
 @pxt.udf
+@_retry
 def chat_completions(
     messages: list,
+    *,
     model: str,
     frequency_penalty: Optional[float] = None,
-    logit_bias: Optional[dict] = None,
+    logit_bias: Optional[dict[str, int]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     max_tokens: Optional[int] = None,
     n: Optional[int] = None,
     presence_penalty: Optional[float] = None,
     response_format: Optional[dict] = None,
     seed: Optional[int] = None,
+    stop: Optional[list[str]] = None,
+    temperature: Optional[float] = None,
     top_p: Optional[float] = None,
-
+    tools: Optional[list[dict]] = None,
+    tool_choice: Optional[dict] = None,
+    user: Optional[str] = None
 ) -> dict:
-
-    result = env.Env.get().openai_client.chat.completions.create(
+    result = openai_client().chat.completions.create(
         messages=messages,
         model=model,
-        frequency_penalty=frequency_penalty
-        logit_bias=logit_bias
-
-
-
-
-
-
-
+        frequency_penalty=_opt(frequency_penalty),
+        logit_bias=_opt(logit_bias),
+        logprobs=_opt(logprobs),
+        top_logprobs=_opt(top_logprobs),
+        max_tokens=_opt(max_tokens),
+        n=_opt(n),
+        presence_penalty=_opt(presence_penalty),
+        response_format=_opt(response_format),
+        seed=_opt(seed),
+        stop=_opt(stop),
+        temperature=_opt(temperature),
+        top_p=_opt(top_p),
+        tools=_opt(tools),
+        tool_choice=_opt(tool_choice),
+        user=_opt(user)
     )
     return result.dict()


 @pxt.udf
+@_retry
 def vision(
     prompt: str,
     image: PIL.Image.Image,
+    *,
     model: str = 'gpt-4-vision-preview'
 ) -> str:
+    # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
     bytes_arr = io.BytesIO()
     image.save(bytes_arr, format='png')
     b64_bytes = base64.b64encode(bytes_arr.getvalue())
@@ -61,28 +171,86 @@ def vision(
         }}
     ]}
     ]
-    result =
+    result = openai_client().chat.completions.create(
         messages=messages,
         model=model
     )
     return result.choices[0].message.content


-
-
-    result = env.Env().get().openai_client.moderations.create(input=input, model=model)
-    return result.dict()
-
+#####################################
+# Embeddings Endpoints

 @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType()))
-
-
+@_retry
+def embeddings(
+    input: Batch[str],
+    *,
+    model: str,
+    user: Optional[str] = None
+) -> Batch[np.ndarray]:
+    result = openai_client().embeddings.create(
         input=input,
         model=model,
+        user=_opt(user),
         encoding_format='float'
     )
-
+    return [
         np.array(data.embedding, dtype=np.float64)
         for data in result.data
     ]
-
+
+
+#####################################
+# Images Endpoints
+
+@pxt.udf
+@_retry
+def image_generations(
+    prompt: str,
+    *,
+    model: Optional[str] = None,
+    quality: Optional[str] = None,
+    size: Optional[str] = None,
+    style: Optional[str] = None,
+    user: Optional[str] = None
+) -> PIL.Image.Image:
+    # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
+    result = openai_client().images.generate(
+        prompt=prompt,
+        model=_opt(model),
+        quality=_opt(quality),
+        size=_opt(size),
+        style=_opt(style),
+        user=_opt(user),
+        response_format="b64_json"
+    )
+    b64_str = result.data[0].b64_json
+    b64_bytes = base64.b64decode(b64_str)
+    img = PIL.Image.open(io.BytesIO(b64_bytes))
+    img.load()
+    return img
+
+
+#####################################
+# Moderations Endpoints
+
+@pxt.udf
+@_retry
+def moderations(
+    input: str,
+    *,
+    model: Optional[str] = None
+) -> dict:
+    result = openai_client().moderations.create(
+        input=input,
+        model=_opt(model)
+    )
+    return result.dict()
+
+
+_T = TypeVar('_T')
+
+
+def _opt(arg: _T) -> Union[_T, NotGiven]:
+    return arg if arg is not None else NOT_GIVEN
pixeltable/functions/together.py
CHANGED
@@ -1,27 +1,122 @@
+import base64
+import io
 from typing import Optional

+import PIL.Image
+import numpy as np
+import together
+
 import pixeltable as pxt
+from pixeltable import env
+from pixeltable.func import Batch
+
+
+def together_client() -> together.Together:
+    return env.Env.get().get_client('together', lambda api_key: together.Together(api_key=api_key))


 @pxt.udf
 def completions(
     prompt: str,
+    *,
     model: str,
     max_tokens: Optional[int] = None,
-    repetition_penalty: Optional[float] = None,
     stop: Optional[list] = None,
-
+    temperature: Optional[float] = None,
     top_p: Optional[float] = None,
-
+    top_k: Optional[int] = None,
+    repetition_penalty: Optional[float] = None,
+    logprobs: Optional[int] = None,
+    echo: Optional[bool] = None,
+    n: Optional[int] = None,
+    safety_model: Optional[str] = None
 ) -> dict:
-
-
-
-        model,
+    return together_client().completions.create(
+        prompt=prompt,
+        model=model,
         max_tokens=max_tokens,
-        repetition_penalty=repetition_penalty,
         stop=stop,
+        temperature=temperature,
+        top_p=top_p,
         top_k=top_k,
+        repetition_penalty=repetition_penalty,
+        logprobs=logprobs,
+        echo=echo,
+        n=n,
+        safety_model=safety_model
+    ).dict()
+
+
+@pxt.udf
+def chat_completions(
+    messages: list[dict[str, str]],
+    *,
+    model: str,
+    max_tokens: Optional[int] = None,
+    stop: Optional[list[str]] = None,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+    top_k: Optional[int] = None,
+    repetition_penalty: Optional[float] = None,
+    logprobs: Optional[int] = None,
+    echo: Optional[bool] = None,
+    n: Optional[int] = None,
+    safety_model: Optional[str] = None,
+    response_format: Optional[dict] = None,
+    tools: Optional[dict] = None,
+    tool_choice: Optional[dict] = None
+) -> dict:
+    return together_client().chat.completions.create(
+        messages=messages,
+        model=model,
+        max_tokens=max_tokens,
+        stop=stop,
+        temperature=temperature,
         top_p=top_p,
-
+        top_k=top_k,
+        repetition_penalty=repetition_penalty,
+        logprobs=logprobs,
+        echo=echo,
+        n=n,
+        safety_model=safety_model,
+        response_format=response_format,
+        tools=tools,
+        tool_choice=tool_choice
+    ).dict()
+
+
+@pxt.udf(batch_size=32, return_type=pxt.ArrayType((None,), dtype=pxt.FloatType()))
+def embeddings(input: Batch[str], *, model: str) -> Batch[np.ndarray]:
+    result = together_client().embeddings.create(input=input, model=model)
+    return [
+        np.array(data.embedding, dtype=np.float64)
+        for data in result.data
+    ]
+
+
+@pxt.udf
+def image_generations(
+    prompt: str,
+    *,
+    model: str,
+    steps: Optional[int] = None,
+    seed: Optional[int] = None,
+    height: Optional[int] = None,
+    width: Optional[int] = None,
+    negative_prompt: Optional[str] = None,
+) -> PIL.Image.Image:
+    # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
+    result = together_client().images.generate(
+        prompt=prompt,
+        model=model,
+        steps=steps,
+        seed=seed,
+        height=height,
+        width=width,
+        negative_prompt=negative_prompt
     )
+    b64_str = result.data[0].b64_json
+    b64_bytes = base64.b64decode(b64_str)
+    img = PIL.Image.open(io.BytesIO(b64_bytes))
+    img.load()
+    return img
pixeltable/functions/util.py
CHANGED
@@ -39,3 +39,14 @@ def create_nos_modules() -> List[types.ModuleType]:
         setattr(sub_module, model_id, pt_func)

     return new_modules
+
+
+def resolve_torch_device(device: str) -> str:
+    import torch
+    if device == 'auto':
+        if torch.cuda.is_available():
+            return 'cuda'
+        if torch.backends.mps.is_available():
+            return 'mps'
+        return 'cpu'
+    return device
pixeltable/index/base.py
ADDED
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+import abc
+from typing import Any
+
+import sqlalchemy as sql
+
+import pixeltable.catalog as catalog
+
+
+class IndexBase(abc.ABC):
+    """
+    Internal interface used by the catalog and runtime system to interact with indices:
+    - types and expressions needed to create and populate the index value column
+    - creating/dropping the index
+    - TODO: translating queries into sqlalchemy predicates
+    """
+    @abc.abstractmethod
+    def __init__(self, c: catalog.Column, **kwargs: Any):
+        pass
+
+    @abc.abstractmethod
+    def index_value_expr(self) -> 'pixeltable.exprs.Expr':
+        """Return expression that computes the value that goes into the index"""
+        pass
+
+    @abc.abstractmethod
+    def index_sa_type(self) -> sql.sqltypes.TypeEngine:
+        """Return the sqlalchemy type of the index value column"""
+        pass
+
+    @abc.abstractmethod
+    def create_index(self, index_name: str, index_value_col: catalog.Column, conn: sql.engine.Connection) -> None:
+        """Create the index on the index value column"""
+        pass
+
+    @classmethod
+    @abc.abstractmethod
+    def display_name(cls) -> str:
+        pass
+
+    @abc.abstractmethod
+    def as_dict(self) -> dict:
+        pass
+
+    @classmethod
+    @abc.abstractmethod
+    def from_dict(cls, c: catalog.Column, d: dict) -> IndexBase:
+        pass
pixeltable/index/embedding_index.py
ADDED
@@ -0,0 +1,95 @@
+from __future__ import annotations
+
+from typing import Optional
+
+import pgvector.sqlalchemy
+import sqlalchemy as sql
+
+import pixeltable.catalog as catalog
+import pixeltable.exceptions as excs
+import pixeltable.func as func
+import pixeltable.type_system as ts
+from .base import IndexBase
+
+
+class EmbeddingIndex(IndexBase):
+    """
+    Internal interface used by the catalog and runtime system to interact with (embedding) indices:
+    - types and expressions needed to create and populate the index value column
+    - creating/dropping the index
+    - translating 'matches' queries into sqlalchemy predicates
+    """
+
+    def __init__(
+            self, c: catalog.Column, text_embed: Optional[func.Function] = None,
+            img_embed: Optional[func.Function] = None):
+        if not c.col_type.is_string_type() and not c.col_type.is_image_type():
+            raise excs.Error(f'Embedding index requires string or image column')
+        if c.col_type.is_string_type() and text_embed is None:
+            raise excs.Error(f'Text embedding function is required for column {c.name} (parameter `txt_embed`)')
+        if c.col_type.is_image_type() and img_embed is None:
+            raise excs.Error(f'Image embedding function is required for column {c.name} (parameter `img_embed`)')
+        if text_embed is not None:
+            # verify signature
+            self._validate_embedding_fn(text_embed, 'txt_embed', ts.ColumnType.Type.STRING)
+        if img_embed is not None:
+            # verify signature
+            self._validate_embedding_fn(img_embed, 'img_embed', ts.ColumnType.Type.IMAGE)
+
+        from pixeltable.exprs import ColumnRef
+        self.value_expr = text_embed(ColumnRef(c)) if c.col_type.is_string_type() else img_embed(ColumnRef(c))
+        assert self.value_expr.col_type.is_array_type()
+        self.txt_embed = text_embed
+        self.img_embed = img_embed
+        vector_size = self.value_expr.col_type.shape[0]
+        assert vector_size is not None
+        self.index_col_type = pgvector.sqlalchemy.Vector(vector_size)
+
+    def index_value_expr(self) -> 'pixeltable.exprs.Expr':
+        """Return expression that computes the value that goes into the index"""
+        return self.value_expr
+
+    def index_sa_type(self) -> sql.sqltypes.TypeEngine:
+        """Return the sqlalchemy type of the index value column"""
+        return self.index_col_type
+
+    def create_index(self, index_name: str, index_value_col: catalog.Column, conn: sql.engine.Connection) -> None:
+        """Create the index on the index value column"""
+        idx = sql.Index(
+            index_name, index_value_col.sa_col,
+            postgresql_using='hnsw',
+            postgresql_with={'m': 16, 'ef_construction': 64},
+            postgresql_ops={index_value_col.sa_col.name: 'vector_cosine_ops'}
+        )
+        idx.create(bind=conn)
+
+    @classmethod
+    def display_name(cls) -> str:
+        return 'embedding'
+
+    @classmethod
+    def _validate_embedding_fn(cls, embed_fn: func.Function, name: str, expected_type: ts.ColumnType.Type) -> None:
+        """Validate the signature"""
+        assert isinstance(embed_fn, func.Function)
+        sig = embed_fn.signature
+        if not sig.return_type.is_array_type():
+            raise excs.Error(f'{name} must return an array, but returns {sig.return_type}')
+        else:
+            shape = sig.return_type.shape
+            if len(shape) != 1 or shape[0] == None:
+                raise excs.Error(f'{name} must return a 1D array of a specific length, but returns {sig.return_type}')
+        if len(sig.parameters) != 1 or sig.parameters_by_pos[0].col_type.type_enum != expected_type:
+            raise excs.Error(
+                f'{name} must take a single {expected_type.name.lower()} parameter, but has signature {sig}')
+
+    def as_dict(self) -> dict:
+        return {
+            'txt_embed': None if self.txt_embed is None else self.txt_embed.as_dict(),
+            'img_embed': None if self.img_embed is None else self.img_embed.as_dict()
+        }
+
+    @classmethod
+    def from_dict(cls, c: catalog.Column, d: dict) -> EmbeddingIndex:
+        txt_embed = func.Function.from_dict(d['txt_embed']) if d['txt_embed'] is not None else None
+        img_embed = func.Function.from_dict(d['img_embed']) if d['img_embed'] is not None else None
+        return cls(c, text_embed=txt_embed, img_embed=img_embed)
pixeltable/metadata/schema.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Optional, List,
+from typing import Optional, List, get_type_hints, Type, Any, TypeVar, Tuple, Union
 import platform
 import uuid
 import dataclasses
@@ -71,16 +71,43 @@ class Dir(Base):


 @dataclasses.dataclass
-class
+class ColumnMd:
     """
-    Records
-
-
-
+    Records the non-versioned metadata of a column.
+    - immutable attributes: type, primary key, etc.
+    - when a column was added/dropped, which is needed to GC unreachable storage columns
+      (a column that was added after table snapshot n and dropped before table snapshot n+1 can be removed
+      from the stored table).
     """
-
+    id: int
     schema_version_add: int
     schema_version_drop: Optional[int]
+    col_type: dict
+
+    # if True, is part of the primary key
+    is_pk: bool
+
+    # if set, this is a computed column
+    value_expr: Optional[dict]
+
+    # if True, the column is present in the stored table
+    stored: Optional[bool]
+
+
+@dataclasses.dataclass
+class IndexMd:
+    """
+    Metadata needed to instantiate an EmbeddingIndex
+    """
+    id: int
+    name: str
+    indexed_col_id: int  # column being indexed
+    index_val_col_id: int  # column holding the values to be indexed
+    index_val_undo_col_id: int  # column holding index values for deleted rows
+    schema_version_add: int
+    schema_version_drop: Optional[int]
+    class_fqn: str
+    init_args: dict[str, Any]


 @dataclasses.dataclass
@@ -91,13 +118,13 @@ class ViewMd:
     base_versions: List[Tuple[str, Optional[int]]]

     # filter predicate applied to the base table; view-only
-    predicate: Optional[
+    predicate: Optional[dict[str, Any]]

     # ComponentIterator subclass; only for component views
     iterator_class_fqn: Optional[str]

     # args to pass to the iterator class constructor; only for component views
-    iterator_args: Optional[
+    iterator_args: Optional[dict[str, Any]]


 @dataclasses.dataclass
@@ -109,15 +136,15 @@ class TableMd:
     # each version has a corresponding schema version (current_version >= current_schema_version)
     current_schema_version: int

-    # used to assign Column.id
-
+    next_col_id: int  # used to assign Column.id
+    next_idx_id: int  # used to assign IndexMd.id

     # - used to assign the rowid column in the storage table
     # - every row is assigned a unique and immutable rowid on insertion
     next_row_id: int

-
-
+    column_md: dict[int, ColumnMd]  # col_id -> ColumnMd
+    index_md: dict[int, IndexMd]  # index_id -> IndexMd
     view_md: Optional[ViewMd]


@@ -155,24 +182,20 @@ class TableVersion(Base):
 @dataclasses.dataclass
 class SchemaColumn:
     """
-    Records the
-    Contains the full set of columns for each new schema version: one record per (column x schema version).
+    Records the versioned metadata of a column.
     """
     pos: int
     name: str
-    col_type: dict
-    is_pk: bool
-    value_expr: Optional[dict]
-    stored: Optional[bool]
-    # if True, creates vector index for this column
-    is_indexed: bool


 @dataclasses.dataclass
 class TableSchemaVersionMd:
+    """
+    Records all versioned table metadata.
+    """
     schema_version: int
     preceding_schema_version: Optional[int]
-    columns:
+    columns: dict[int, SchemaColumn]  # col_id -> SchemaColumn
     num_retained_versions: int
     comment: str
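Purely to illustrate how the new metadata records relate, a hypothetical construction (the field values, the empty col_type dict, and the index class FQN are assumptions, not taken from the package):

from pixeltable.metadata.schema import ColumnMd, IndexMd

# user column 0, plus the hidden value/undo columns an embedding index adds
col = ColumnMd(id=0, schema_version_add=0, schema_version_drop=None,
               col_type={}, is_pk=False, value_expr=None, stored=True)
idx = IndexMd(id=0, name='idx0',
              indexed_col_id=0,          # the user column being indexed
              index_val_col_id=1,        # column holding the computed index values
              index_val_undo_col_id=2,   # column holding index values for deleted rows
              schema_version_add=1, schema_version_drop=None,
              class_fqn='pixeltable.index.EmbeddingIndex',  # assumed FQN
              init_args={})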