pixeltable 0.1.2__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__init__.py +21 -4
- pixeltable/catalog/__init__.py +13 -0
- pixeltable/catalog/catalog.py +159 -0
- pixeltable/catalog/column.py +200 -0
- pixeltable/catalog/dir.py +32 -0
- pixeltable/catalog/globals.py +33 -0
- pixeltable/catalog/insertable_table.py +191 -0
- pixeltable/catalog/named_function.py +36 -0
- pixeltable/catalog/path.py +58 -0
- pixeltable/catalog/path_dict.py +139 -0
- pixeltable/catalog/schema_object.py +39 -0
- pixeltable/catalog/table.py +581 -0
- pixeltable/catalog/table_version.py +749 -0
- pixeltable/catalog/table_version_path.py +133 -0
- pixeltable/catalog/view.py +203 -0
- pixeltable/client.py +520 -31
- pixeltable/dataframe.py +540 -349
- pixeltable/env.py +373 -48
- pixeltable/exceptions.py +12 -21
- pixeltable/exec/__init__.py +9 -0
- pixeltable/exec/aggregation_node.py +78 -0
- pixeltable/exec/cache_prefetch_node.py +113 -0
- pixeltable/exec/component_iteration_node.py +79 -0
- pixeltable/exec/data_row_batch.py +95 -0
- pixeltable/exec/exec_context.py +22 -0
- pixeltable/exec/exec_node.py +61 -0
- pixeltable/exec/expr_eval_node.py +217 -0
- pixeltable/exec/in_memory_data_node.py +69 -0
- pixeltable/exec/media_validation_node.py +43 -0
- pixeltable/exec/sql_scan_node.py +225 -0
- pixeltable/exprs/__init__.py +24 -0
- pixeltable/exprs/arithmetic_expr.py +102 -0
- pixeltable/exprs/array_slice.py +71 -0
- pixeltable/exprs/column_property_ref.py +77 -0
- pixeltable/exprs/column_ref.py +105 -0
- pixeltable/exprs/comparison.py +77 -0
- pixeltable/exprs/compound_predicate.py +98 -0
- pixeltable/exprs/data_row.py +187 -0
- pixeltable/exprs/expr.py +586 -0
- pixeltable/exprs/expr_set.py +39 -0
- pixeltable/exprs/function_call.py +380 -0
- pixeltable/exprs/globals.py +69 -0
- pixeltable/exprs/image_member_access.py +115 -0
- pixeltable/exprs/image_similarity_predicate.py +58 -0
- pixeltable/exprs/inline_array.py +107 -0
- pixeltable/exprs/inline_dict.py +101 -0
- pixeltable/exprs/is_null.py +38 -0
- pixeltable/exprs/json_mapper.py +121 -0
- pixeltable/exprs/json_path.py +159 -0
- pixeltable/exprs/literal.py +54 -0
- pixeltable/exprs/object_ref.py +41 -0
- pixeltable/exprs/predicate.py +44 -0
- pixeltable/exprs/row_builder.py +355 -0
- pixeltable/exprs/rowid_ref.py +94 -0
- pixeltable/exprs/type_cast.py +53 -0
- pixeltable/exprs/variable.py +45 -0
- pixeltable/func/__init__.py +9 -0
- pixeltable/func/aggregate_function.py +194 -0
- pixeltable/func/batched_function.py +53 -0
- pixeltable/func/callable_function.py +69 -0
- pixeltable/func/expr_template_function.py +82 -0
- pixeltable/func/function.py +110 -0
- pixeltable/func/function_registry.py +227 -0
- pixeltable/func/globals.py +36 -0
- pixeltable/func/nos_function.py +202 -0
- pixeltable/func/signature.py +166 -0
- pixeltable/func/udf.py +163 -0
- pixeltable/functions/__init__.py +52 -103
- pixeltable/functions/eval.py +216 -0
- pixeltable/functions/fireworks.py +61 -0
- pixeltable/functions/huggingface.py +120 -0
- pixeltable/functions/image.py +16 -0
- pixeltable/functions/openai.py +88 -0
- pixeltable/functions/pil/image.py +148 -7
- pixeltable/functions/string.py +13 -0
- pixeltable/functions/together.py +27 -0
- pixeltable/functions/util.py +41 -0
- pixeltable/functions/video.py +62 -0
- pixeltable/iterators/__init__.py +3 -0
- pixeltable/iterators/base.py +48 -0
- pixeltable/iterators/document.py +311 -0
- pixeltable/iterators/video.py +89 -0
- pixeltable/metadata/__init__.py +54 -0
- pixeltable/metadata/converters/convert_10.py +18 -0
- pixeltable/metadata/schema.py +211 -0
- pixeltable/plan.py +656 -0
- pixeltable/store.py +413 -182
- pixeltable/tests/conftest.py +143 -86
- pixeltable/tests/test_audio.py +65 -0
- pixeltable/tests/test_catalog.py +27 -0
- pixeltable/tests/test_client.py +14 -14
- pixeltable/tests/test_component_view.py +372 -0
- pixeltable/tests/test_dataframe.py +433 -0
- pixeltable/tests/test_dirs.py +78 -62
- pixeltable/tests/test_document.py +117 -0
- pixeltable/tests/test_exprs.py +591 -135
- pixeltable/tests/test_function.py +297 -67
- pixeltable/tests/test_functions.py +283 -1
- pixeltable/tests/test_migration.py +43 -0
- pixeltable/tests/test_nos.py +54 -0
- pixeltable/tests/test_snapshot.py +208 -0
- pixeltable/tests/test_table.py +1086 -258
- pixeltable/tests/test_transactional_directory.py +42 -0
- pixeltable/tests/test_types.py +5 -11
- pixeltable/tests/test_video.py +149 -34
- pixeltable/tests/test_view.py +530 -0
- pixeltable/tests/utils.py +186 -45
- pixeltable/tool/create_test_db_dump.py +149 -0
- pixeltable/type_system.py +490 -133
- pixeltable/utils/__init__.py +17 -46
- pixeltable/utils/clip.py +12 -15
- pixeltable/utils/coco.py +136 -0
- pixeltable/utils/documents.py +39 -0
- pixeltable/utils/filecache.py +195 -0
- pixeltable/utils/help.py +11 -0
- pixeltable/utils/media_store.py +76 -0
- pixeltable/utils/parquet.py +126 -0
- pixeltable/utils/pytorch.py +172 -0
- pixeltable/utils/s3.py +13 -0
- pixeltable/utils/sql.py +17 -0
- pixeltable/utils/transactional_directory.py +35 -0
- pixeltable-0.2.1.dist-info/LICENSE +18 -0
- pixeltable-0.2.1.dist-info/METADATA +119 -0
- pixeltable-0.2.1.dist-info/RECORD +125 -0
- {pixeltable-0.1.2.dist-info → pixeltable-0.2.1.dist-info}/WHEEL +1 -1
- pixeltable/catalog.py +0 -1421
- pixeltable/exprs.py +0 -1745
- pixeltable/function.py +0 -269
- pixeltable/functions/clip.py +0 -10
- pixeltable/functions/pil/__init__.py +0 -23
- pixeltable/functions/tf.py +0 -21
- pixeltable/index.py +0 -57
- pixeltable/tests/test_dict.py +0 -24
- pixeltable/tests/test_tf.py +0 -69
- pixeltable/tf.py +0 -33
- pixeltable/utils/tf.py +0 -33
- pixeltable/utils/video.py +0 -32
- pixeltable-0.1.2.dist-info/LICENSE +0 -201
- pixeltable-0.1.2.dist-info/METADATA +0 -89
- pixeltable-0.1.2.dist-info/RECORD +0 -37
|
@@ -1,11 +1,293 @@
|
|
|
1
|
-
import
|
|
1
|
+
from typing import Dict, Any
|
|
2
|
+
|
|
2
3
|
import pytest
|
|
3
4
|
|
|
5
|
+
import pixeltable as pxt
|
|
4
6
|
from pixeltable import catalog
|
|
7
|
+
from pixeltable.env import Env
|
|
5
8
|
from pixeltable.functions.pil.image import blend
|
|
9
|
+
from pixeltable.iterators import FrameIterator
|
|
10
|
+
from pixeltable.tests.utils import get_video_files, skip_test_if_not_installed, get_sentences, get_image_files
|
|
11
|
+
from pixeltable.type_system import VideoType, StringType, JsonType, ImageType, BoolType, FloatType, ArrayType
|
|
6
12
|
|
|
7
13
|
|
|
8
14
|
class TestFunctions:
|
|
9
15
|
def test_pil(self, img_tbl: catalog.Table) -> None:
    """Smoke test: select an image, its 90-degree rotation, and a blend of the two."""
    tbl = img_tbl
    # rotate() is called once per place it is needed, so each use gets its own expression
    select_exprs = (
        tbl.img,
        tbl.img.rotate(90),
        blend(tbl.img, tbl.img.rotate(90), 0.5),
    )
    # only checks that the query runs; the result itself is discarded
    tbl[select_exprs].show()
|
|
18
|
+
|
|
19
|
+
def test_eval_detections(self, test_client: pxt.Client) -> None:
    """Score two NOS YOLOX detectors against a third one that is used as ground truth.

    Frames come from a FrameIterator view over a video table. The outputs of
    yolox_nano and yolox_small are compared with those of yolox_large through
    eval_detections() and then aggregated with mean_ap().
    Skipped when 'nos' is not installed.
    """
    skip_test_if_not_installed('nos')
    client = test_client
    videos = client.create_table('video_tbl', {'video': VideoType()})
    # create frame view
    iterator_args = {'video': videos.video, 'fps': 1}
    frames = client.create_view(
        'test_view', videos, iterator_class=FrameIterator, iterator_args=iterator_args)

    video_paths = get_video_files()
    videos.insert(video=video_paths[-1])
    frames.add_column(frame_s=frames.frame.resize([640, 480]))
    from pixeltable.functions.nos.object_detection_2d import yolox_nano, yolox_small, yolox_large
    detectors = (('detections_a', yolox_nano), ('detections_b', yolox_small), ('gt', yolox_large))
    for col_name, detector in detectors:
        frames.add_column(**{col_name: detector(frames.frame_s)})
    from pixeltable.functions.eval import eval_detections, mean_ap

    def score_against_gt(pred_col: str):
        # the prediction column is looked up anew for each argument
        return eval_detections(
            getattr(frames, pred_col).bboxes,
            getattr(frames, pred_col).labels,
            getattr(frames, pred_col).scores,
            frames.gt.bboxes,
            frames.gt.labels,
        )

    # first as an ad-hoc query ...
    frames.select(score_against_gt('detections_a')).show()
    # ... then as added columns
    frames.add_column(eval_a=score_against_gt('detections_a'))
    frames.add_column(eval_b=score_against_gt('detections_b'))
    ap_a = frames.select(mean_ap(frames.eval_a)).show()[0, 0]
    ap_b = frames.select(mean_ap(frames.eval_b)).show()[0, 0]
    common_classes = set(ap_a.keys()) & set(ap_b.keys())

    # TODO: the following assertion is failing on CI.
    # It is not necessarily a bug, as the assert condition is not expected to always be true.
    # for k in common_classes:
    #     assert ap_a[k] <= ap_b[k]
|
|
53
|
+
|
|
54
|
+
def test_str(self, test_client: pxt.Client) -> None:
    """Check str_format() with a positional, a keyword, and mixed placeholder arguments."""
    tbl = test_client.create_table('test_tbl', {'input': StringType()})
    from pixeltable.functions.string import str_format
    # the columns are added before any data exists, so the insert below has to compute them
    tbl.add_column(s1=str_format('ABC {0}', tbl.input))
    tbl.add_column(s2=str_format('DEF {this}', this=tbl.input))
    tbl.add_column(s3=str_format('GHI {0} JKL {this}', tbl.input, this=tbl.input))

    insert_status = tbl.insert(input='MNO')
    assert insert_status.num_rows == 1
    assert insert_status.num_excs == 0

    expected_row = {
        'input': 'MNO',
        's1': 'ABC MNO',
        's2': 'DEF MNO',
        's3': 'GHI MNO JKL MNO',
    }
    first_row = tbl.head()[0]
    assert first_row == expected_row
|
|
66
|
+
|
|
67
|
+
def test_openai(self, test_client: pxt.Client) -> None:
    """Add OpenAI chat-completion, embedding and moderation columns and insert one row.

    Skipped when the 'openai' package is not installed or when no OpenAI client is
    available. The test only verifies that column creation, the insert and the final
    head() complete without raising.
    """
    skip_test_if_not_installed('openai')
    # single, shared check for a configured client (skips the test if there is none)
    TestFunctions.skip_test_if_no_openai_client()
    cl = test_client
    t = cl.create_table('test_tbl', {'input': StringType()})
    from pixeltable.functions.openai import chat_completions, embeddings, moderations
    msgs = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": t.input}
    ]
    # messages stored in their own column and referenced from the completion call
    t.add_column(input_msgs=msgs)
    t.add_column(chat_output=chat_completions(model='gpt-3.5-turbo', messages=t.input_msgs))
    # with inlined messages
    t.add_column(chat_output2=chat_completions(model='gpt-3.5-turbo', messages=msgs))
    t.add_column(ada_embed=embeddings(model='text-embedding-ada-002', input=t.input))
    t.add_column(text_3=embeddings(model='text-embedding-3-small', input=t.input))
    t.add_column(moderation=moderations(input=t.input))
    t.insert(input='I find you really annoying')
    _ = t.head()
|
|
88
|
+
|
|
89
|
+
def test_gpt_4_vision(self, test_client: pxt.Client) -> None:
    """Send a text prompt plus a base64-encoded image to 'gpt-4-vision-preview'.

    Passes when the extracted response content is non-empty. Skipped when 'openai'
    is not installed or no OpenAI client is available.
    """
    skip_test_if_not_installed('openai')
    TestFunctions.skip_test_if_no_openai_client()
    tbl = test_client.create_table('test_tbl', {'prompt': StringType(), 'img': ImageType()})
    from pixeltable.functions.openai import chat_completions
    from pixeltable.functions.string import str_format

    # the user message has two parts: the prompt text and the image as a data URL
    text_part = {'type': 'text', 'text': tbl.prompt}
    data_url = str_format('data:image/png;base64,{0}', tbl.img.b64_encode())
    image_part = {'type': 'image_url', 'image_url': {'url': data_url}}
    messages = [{'role': 'user', 'content': [text_part, image_part]}]

    tbl.add_column(response=chat_completions(model='gpt-4-vision-preview', messages=messages, max_tokens=300))
    tbl.add_column(response_content=tbl.response.choices[0].message.content)
    tbl.insert(prompt="What's in this image?", img=_sample_image_url)

    answer = tbl.collect()['response_content'][0]
    assert len(answer) > 0
|
|
110
|
+
|
|
111
|
+
@staticmethod
def skip_test_if_no_openai_client() -> None:
    """Skip the calling test when the environment has no OpenAI client."""
    client = Env.get().openai_client
    if client is not None:
        return
    pytest.skip('OpenAI client does not exist (missing API key?)')
|
|
115
|
+
|
|
116
|
+
def test_together(self, test_client: pxt.Client) -> None:
    """Run a Together completion as a computed column and expect non-empty output text.

    Skipped when 'together' is not installed or the environment reports no Together client.
    """
    skip_test_if_not_installed('together')
    has_client = Env.get().has_together_client
    if not has_client:
        pytest.skip('Together client does not exist (missing API key?)')

    tbl = test_client.create_table('test_tbl', {'input': StringType()})
    from pixeltable.functions.together import completions
    completion = completions(prompt=tbl.input, model='mistralai/Mixtral-8x7B-v0.1', stop=['\n'])
    tbl.add_column(output=completion)
    # pull the generated text out of the response
    tbl.add_column(output_text=tbl.output.output.choices[0].text)
    tbl.insert(input='I am going to the ')

    generated = tbl.select(tbl.output_text).collect()['output_text'][0]
    assert len(generated) > 0
|
|
128
|
+
|
|
129
|
+
def test_fireworks(self, test_client: pxt.Client) -> None:
    """Run a Fireworks chat completion as a computed column and expect non-empty output.

    Skipped when 'fireworks' is not installed or when importing/initializing the
    Fireworks integration fails.
    """
    skip_test_if_not_installed('fireworks')
    try:
        from pixeltable.functions.fireworks import initialize
        initialize()
    except Exception:
        # Exception (not a bare except) so that KeyboardInterrupt/SystemExit still propagate
        pytest.skip('Fireworks client does not exist (missing API key?)')
    cl = test_client
    t = cl.create_table('test_tbl', {'input': StringType()})
    from pixeltable.functions.fireworks import chat_completions
    # the column holds only the text of the first choice
    t['output'] = chat_completions(
        prompt=t.input, model='accounts/fireworks/models/llama-v2-7b-chat', max_tokens=256
    ).choices[0].text
    t.insert(input='I am going to the ')
    result = t.select(t.output).collect()['output'][0]
    assert len(result) > 0
|
|
143
|
+
|
|
144
|
+
def test_hf_function(self, test_client: pxt.Client) -> None:
    """Check sentence_transformer() as a computed column, including rejection of non-constant params.

    Skipped when 'sentence_transformers' is not installed.
    """
    skip_test_if_not_installed('sentence_transformers')
    tbl = test_client.create_table('test_tbl', {'input': StringType(), 'bool_col': BoolType()})
    from pixeltable.functions.huggingface import sentence_transformer
    model_id = 'intfloat/e5-large-v2'
    tbl.add_column(e5=sentence_transformer(tbl.input, model_id=model_id))
    sentences = get_sentences()
    insert_status = tbl.insert({'input': s, 'bool_col': True} for s in sentences)
    assert insert_status.num_rows == len(sentences)
    assert insert_status.num_excs == 0

    # verify handling of constant params: passing a column where a constant is required must fail
    with pytest.raises(ValueError) as exc_info:
        tbl.add_column(e5_2=sentence_transformer(tbl.input, model_id=tbl.input))
    model_id_error = str(exc_info.value)
    assert ': parameter model_id must be a constant value' in model_id_error

    with pytest.raises(ValueError) as exc_info:
        tbl.add_column(
            e5_2=sentence_transformer(tbl.input, model_id=model_id, normalize_embeddings=tbl.bool_col))
    normalize_error = str(exc_info.value)
    assert ': parameter normalize_embeddings must be a constant value' in normalize_error

    # make sure this doesn't cause an exception
    # TODO: is there some way to capture the output?
    tbl.describe()
|
|
167
|
+
|
|
168
|
+
def test_sentence_transformer(self, test_client: pxt.Client) -> None:
    """Check sentence_transformer() and sentence_transformer_list() for two models, before and after a reload.

    Skipped when 'sentence_transformers' is not installed.
    """
    skip_test_if_not_installed('sentence_transformers')
    client = test_client
    tbl = client.create_table('test_tbl', {'input': StringType(), 'input_list': JsonType()})
    sentences = get_sentences(10)

    def new_rows():
        # every row carries one sentence plus the complete sentence list
        return ({'input': s, 'input_list': sentences} for s in sentences)

    insert_status = tbl.insert(new_rows())
    assert insert_status.num_rows == len(sentences)
    assert insert_status.num_excs == 0

    # run multiple models one at a time in order to exercise batching
    from pixeltable.functions.huggingface import sentence_transformer, sentence_transformer_list
    model_ids = ['sentence-transformers/all-mpnet-base-v2', 'BAAI/bge-reranker-base']
    num_dims = [768, 768]
    embedding_type = ArrayType((None,), dtype=FloatType(), nullable=False)
    for idx, model_id in enumerate(model_ids):
        scalar_col = f'embed{idx}'
        tbl[scalar_col] = sentence_transformer(tbl.input, model_id=model_id, normalize_embeddings=True)
        assert tbl.column_types()[scalar_col] == embedding_type
        list_col = f'embed_list{idx}'
        tbl[list_col] = sentence_transformer_list(tbl.input_list, model_id=model_id, normalize_embeddings=True)
        assert tbl.column_types()[list_col] == JsonType()

    def verify_row(row: Dict[str, Any]) -> None:
        # one embedding per model, plus one embedding per listed sentence
        for idx, dim in enumerate(num_dims):
            assert row[f'embed{idx}'].shape == (dim,)
            list_embeddings = row[f'embed_list{idx}']
            assert len(list_embeddings) == len(sentences)
            assert all(len(v) == dim for v in list_embeddings)

    verify_row(tbl.tail(1)[0])

    # execution still works after reload
    client = pxt.Client(reload=True)
    tbl = client.get_table('test_tbl')
    insert_status = tbl.insert(new_rows())
    assert insert_status.num_rows == len(sentences)
    assert insert_status.num_excs == 0
    verify_row(tbl.tail(1)[0])
|
|
204
|
+
|
|
205
|
+
def test_cross_encoder(self, test_client: pxt.Client) -> None:
    """Check cross_encoder() and cross_encoder_list() for two models, before and after a reload.

    Skipped when 'sentence_transformers' is not installed.
    """
    skip_test_if_not_installed('sentence_transformers')
    cl = test_client
    t = cl.create_table('test_tbl', {'input': StringType(), 'input_list': JsonType()})
    sents = get_sentences(10)
    status = t.insert({'input': s, 'input_list': sents} for s in sents)
    assert status.num_rows == len(sents)
    assert status.num_excs == 0

    # run multiple models one at a time in order to exercise batching
    from pixeltable.functions.huggingface import cross_encoder, cross_encoder_list
    model_ids = ['cross-encoder/ms-marco-MiniLM-L-6-v2', 'cross-encoder/ms-marco-TinyBERT-L-2-v2']
    for idx, model_id in enumerate(model_ids):
        col_name = f'embed{idx}'
        t[col_name] = cross_encoder(t.input, t.input, model_id=model_id)
        assert t.column_types()[col_name] == FloatType()
        list_col_name = f'embed_list{idx}'
        t[list_col_name] = cross_encoder_list(t.input, t.input_list, model_id=model_id)
        assert t.column_types()[list_col_name] == JsonType()

    def verify_row(row: Dict[str, Any]) -> None:
        # check the list column of every model; the index must be the loop variable,
        # not a name captured from the enclosing scope
        for i in range(len(model_ids)):
            scores = row[f'embed_list{i}']
            assert len(scores) == len(sents)
            assert all(isinstance(v, float) for v in scores)

    verify_row(t.tail(1)[0])

    # execution still works after reload
    cl = pxt.Client(reload=True)
    t = cl.get_table('test_tbl')
    status = t.insert({'input': s, 'input_list': sents} for s in sents)
    assert status.num_rows == len(sents)
    assert status.num_excs == 0
    verify_row(t.tail(1)[0])
|
|
239
|
+
|
|
240
|
+
def test_clip(self, test_client: pxt.Client) -> None:
    """Check clip_text() and clip_image() for two CLIP models, before and after a reload.

    Skipped when 'transformers' is not installed.
    """
    skip_test_if_not_installed('transformers')
    client = test_client
    tbl = client.create_table('test_tbl', {'text': StringType(), 'img': ImageType()})
    num_rows = 10
    sentences = get_sentences(num_rows)
    images = get_image_files()[:num_rows]

    def new_rows():
        # pair sentences and images positionally
        return ({'text': text, 'img': img} for text, img in zip(sentences, images))

    insert_status = tbl.insert(new_rows())
    assert insert_status.num_rows == len(sentences)
    assert insert_status.num_excs == 0

    # run multiple models one at a time in order to exercise batching
    from pixeltable.functions.huggingface import clip_text, clip_image
    model_ids = ['openai/clip-vit-base-patch32', 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K']
    for idx, model_id in enumerate(model_ids):
        text_col = f'embed_text{idx}'
        tbl[text_col] = clip_text(tbl.text, model_id=model_id)
        assert tbl.column_types()[text_col] == ArrayType((None,), dtype=FloatType(), nullable=False)
        img_col = f'embed_img{idx}'
        tbl[img_col] = clip_image(tbl.img, model_id=model_id)
        assert tbl.column_types()[img_col] == ArrayType((None,), dtype=FloatType(), nullable=False)

    def verify_row(row: Dict[str, Any]) -> None:
        # text and image embeddings of every model are expected to have 512 entries
        for idx in range(len(model_ids)):
            assert row[f'embed_text{idx}'].shape == (512,)
            assert row[f'embed_img{idx}'].shape == (512,)

    verify_row(tbl.tail(1)[0])

    # execution still works after reload
    client = pxt.Client(reload=True)
    tbl = client.get_table('test_tbl')
    insert_status = tbl.insert(new_rows())
    assert insert_status.num_rows == len(sentences)
    assert insert_status.num_excs == 0
    verify_row(tbl.tail(1)[0])
|
|
276
|
+
|
|
277
|
+
def test_detr_for_object_detection(self, test_client: pxt.Client) -> None:
    """Run DETR object detection on the sample image and check for three expected labels.

    Skipped when 'transformers' is not installed.
    """
    skip_test_if_not_installed('transformers')
    tbl = test_client.create_table('test_tbl', {'img': ImageType()})
    from pixeltable.functions.huggingface import detr_for_object_detection
    tbl['detect'] = detr_for_object_detection(tbl.img, model_id='facebook/detr-resnet-50', threshold=0.8)

    insert_status = tbl.insert(img=_sample_image_url)
    assert insert_status.num_rows == 1
    assert insert_status.num_excs == 0

    detections = tbl.select(tbl.detect).collect()[0]['detect']
    for expected_label in ('orange', 'bowl', 'broccoli'):
        assert expected_label in detections['label_text']
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
_sample_image_url = \
|
|
293
|
+
'https://raw.githubusercontent.com/pixeltable/pixeltable/master/docs/source/data/images/000000000009.jpg'
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import glob
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
import subprocess
|
|
5
|
+
|
|
6
|
+
import pgserver
|
|
7
|
+
import pytest
|
|
8
|
+
|
|
9
|
+
import pixeltable as pxt
|
|
10
|
+
from pixeltable.env import Env
|
|
11
|
+
from pixeltable.tests.conftest import clean_db
|
|
12
|
+
|
|
13
|
+
_logger = logging.getLogger('pixeltable')
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TestMigration:
    """Restores database dumps and checks that a client can be created on top of them."""

    @pytest.mark.skip(reason='Suspended')
    def test_db_migration(self, init_env) -> None:
        """Restore each gzipped dump under pixeltable/tests/data/dbdumps and open a Client on it.

        Each dump is decompressed with ``gunzip -c`` and piped into the ``pg_restore``
        binary located inside the pgserver package. A failing ``pg_restore`` raises
        ``subprocess.CalledProcessError`` (``check=True``).
        """
        env = Env.get()
        pg_package_dir = os.path.dirname(pgserver.__file__)
        pg_restore_binary = f'{pg_package_dir}/pginstall/bin/pg_restore'
        _logger.info(f'Using pg_restore binary at: {pg_restore_binary}')
        dump_files = sorted(glob.glob('pixeltable/tests/data/dbdumps/*.dump.gz'))
        for dump_file in dump_files:
            _logger.info(f'Testing migration from DB dump {dump_file}.')
            _logger.info(f'DB URL: {env.db_url}')
            clean_db(restore_tables=False)
            # Popen is used as a context manager so that the pipe is closed and the gunzip
            # process is waited for, even when pg_restore fails
            with open(dump_file, 'rb') as dump, subprocess.Popen(
                ["gunzip", "-c"],
                stdin=dump,
                stdout=subprocess.PIPE
            ) as gunzip_process:
                subprocess.run(
                    [pg_restore_binary, '-d', env.db_url, '-U', 'postgres'],
                    stdin=gunzip_process.stdout,
                    check=True
                )
            # TODO(aaron-siegel) This will test that the migration succeeds without raising any exceptions.
            #   We should also add some assertions to sanity-check the outcome.
            _ = pxt.Client()
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
import pixeltable as pxt
|
|
4
|
+
from pixeltable.iterators import FrameIterator
|
|
5
|
+
from pixeltable.tests.utils import get_video_files, skip_test_if_not_installed
|
|
6
|
+
from pixeltable.type_system import ImageType, VideoType
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestNOS:
    """Tests for NOS-backed functions; every test is skipped when 'nos' is not installed."""

    def test_basic(self, test_client: pxt.Client) -> None:
        """Add NOS detection and embedding columns to a frame view, then insert a video.

        There are no assertions: the test passes when nothing raises.
        """
        skip_test_if_not_installed('nos')
        cl = test_client
        video_t = cl.create_table('video_tbl', {'video': VideoType()})
        # create frame view
        args = {'video': video_t.video, 'fps': 1}
        v = cl.create_view('test_view', video_t, iterator_class=FrameIterator, iterator_args=args)
        # unstored column that feeds the NOS calls below
        v.add_column(transform1=v.frame.rotate(30), stored=False)
        from pixeltable.functions.nos.object_detection_2d import \
            torchvision_fasterrcnn_mobilenet_v3_large_320_fpn as fasterrcnn
        v.add_column(detections=fasterrcnn(v.transform1))
        from pixeltable.functions.nos.image_embedding import openai_clip
        v.add_column(embed=openai_clip(v.transform1.resize([224, 224])))
        # add a stored column that isn't referenced in nos calls
        v.add_column(transform2=v.frame.rotate(60), stored=True)

        video_t.insert(video=get_video_files()[0])

    def test_exceptions(self, test_client: pxt.Client) -> None:
        """Check that an error in an input column is recorded for exactly one detection row."""
        skip_test_if_not_installed('nos')
        cl = test_client
        video_t = cl.create_table('video_tbl', {'video': VideoType()})
        # create frame view
        args = {'video': video_t.video, 'fps': 1}
        v = cl.create_view('test_view', video_t, iterator_class=FrameIterator, iterator_args=args)
        video_t.insert(video=get_video_files()[0])

        v.add_column(frame_s=v.frame.resize([640, 480]))
        # 'rotated' has exceptions
        # NOTE(review): presumably 360 / frame_idx divides by zero for the first frame - confirm
        v.add_column(rotated=lambda frame_s, frame_idx: frame_s.rotate(int(360 / frame_idx)), type=ImageType())
        from pixeltable.functions.nos.object_detection_2d import yolox_medium
        v.add_column(detections=yolox_medium(v.rotated), stored=True)
        # '!= None' is deliberate: 'is not None' would evaluate to a plain Python bool
        # instead of producing an expression for where()
        assert v.where(v.detections.errortype != None).count() == 1  # noqa: E711

    @pytest.mark.skip(reason='too slow')
    def test_sd(self, test_client: pxt.Client) -> None:
        """Test model that mixes batched with scalar parameters"""
        skip_test_if_not_installed('nos')
        t = test_client.create_table('sd_test', {'prompt': pxt.StringType()})
        t.insert(prompt='cat on a sofa')
        from pixeltable.functions.nos.image_generation import stabilityai_stable_diffusion_2 as sd2
        t.add_column(img=sd2(t.prompt, 1, 512, 512), stored=True)
        img = t[t.img].show(1)[0, 0]
        assert img.size == (512, 512)
|
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
from typing import Any, Dict
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
import pixeltable as pxt
|
|
7
|
+
import pixeltable.exceptions as excs
|
|
8
|
+
from pixeltable.tests.utils import create_test_tbl, assert_resultset_eq
|
|
9
|
+
from pixeltable.type_system import IntType
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TestSnapshot:
|
|
13
|
+
def run_basic_test(
        self, cl: pxt.Client, tbl: pxt.Table, snap: pxt.Table, extra_items: Dict[str, Any], filter: Any,
        reload_md: bool
) -> None:
    """Verify that a snapshot is unaffected by later changes to its base table.

    Args:
        cl: client that owns ``tbl`` and ``snap``.
        tbl: base table of the snapshot.
        snap: snapshot of ``tbl``.
        extra_items: maps the names of the snapshot's additional columns to equivalent
            expressions over ``tbl``; used to compute the expected result.
        filter: predicate the snapshot was created with, passed to ``where()`` as is
            (callers in this file pass ``None`` for "no filter").
        reload_md: if True, re-create the client and re-fetch both tables before
            querying the snapshot.

    Both ``snap`` and ``tbl`` are dropped at the end.
    """
    tbl_path, snap_path = cl.get_path(tbl), cl.get_path(snap)
    # run the initial query against the base table here, before reloading, otherwise the filter breaks
    tbl_select_list = [tbl[col_name] for col_name in tbl.column_names()]
    tbl_select_list.extend(extra_items.values())
    orig_resultset = tbl.select(*tbl_select_list).where(filter).order_by(tbl.c2).collect()

    if reload_md:
        # reload md
        cl = pxt.Client(reload=True)
        tbl = cl.get_table(tbl_path)
        snap = cl.get_table(snap_path)

    # view select list: base cols followed by view cols
    snap_select_list = [snap[col_name] for col_name in snap.column_names()[len(extra_items):]]
    snap_select_list.extend(snap[col_name] for col_name in extra_items)
    snap_query = snap.select(*snap_select_list).order_by(snap.c2)
    assert_resultset_eq(snap_query.collect(), orig_resultset)

    # adding data to a base table doesn't change the snapshot
    rows = list(tbl.select(tbl.c1, tbl.c1n, tbl.c2, tbl.c3, tbl.c4, tbl.c5, tbl.c6, tbl.c7).collect())
    status = tbl.insert(rows)
    assert status.num_rows == len(rows)
    assert_resultset_eq(snap_query.collect(), orig_resultset)

    # update() doesn't affect the view
    status = tbl.update({'c3': tbl.c3 + 1.0})
    assert status.num_rows == tbl.count()
    assert_resultset_eq(snap_query.collect(), orig_resultset)

    # delete() doesn't affect the view
    num_tbl_rows = tbl.count()
    status = tbl.delete()
    assert status.num_rows == num_tbl_rows
    assert_resultset_eq(snap_query.collect(), orig_resultset)

    tbl.revert()  # undo delete()
    tbl.revert()  # undo update()
    tbl.revert()  # undo insert()
    # can't revert a version referenced by a snapshot
    with pytest.raises(excs.Error) as excinfo:
        tbl.revert()
    assert 'version is needed' in str(excinfo.value)

    # can't drop a table with snapshots
    with pytest.raises(excs.Error) as excinfo:
        cl.drop_table(tbl_path)
    assert snap_path in str(excinfo.value)

    cl.drop_table(snap_path)
    cl.drop_table(tbl_path)
|
|
69
|
+
|
|
70
|
+
def test_basic(self, test_client: pxt.Client) -> None:
    """Run run_basic_test() for every combination of reload / filter / extra columns."""
    test_client.create_dir('main')
    test_client.create_dir('snap')
    tbl_path = 'main.tbl1'
    snap_path = 'snap.snap1'

    flags = (False, True)
    # same iteration order as three nested loops: reload_md outermost, has_cols innermost
    combinations = [
        (reload_md, has_filter, has_cols)
        for reload_md in flags for has_filter in flags for has_cols in flags
    ]
    for reload_md, has_filter, has_cols in combinations:
        cl = pxt.Client(reload=True)
        tbl = create_test_tbl(name=tbl_path, client=cl)
        if has_cols:
            schema = {
                'v1': tbl.c3 * 2.0,
                # include a lambda to make sure that is handled correctly
                'v2': {'value': lambda c3: c3 * 2.0, 'type': pxt.FloatType()}
            }
            # expressions over the base table that mirror the snapshot's columns
            extra_items = {'v1': tbl.c3 * 2.0, 'v2': tbl.c3 * 2.0}
        else:
            schema = {}
            extra_items = {}
        snap_filter = None
        if has_filter:
            snap_filter = tbl.c2 < 10
        snap = cl.create_view(snap_path, tbl, schema=schema, filter=snap_filter, is_snapshot=True)
        self.run_basic_test(
            cl, tbl, snap, extra_items=extra_items, filter=snap_filter, reload_md=reload_md)
|
|
91
|
+
|
|
92
|
+
def test_views_of_snapshots(self, test_client: pxt.Client) -> None:
    """Build the chain table -> snapshot -> view -> snapshot -> view and check that row counts stay fixed."""
    client = test_client
    base = client.create_table('tbl', {'a': IntType()})
    rows = [{'a': 1}, {'a': 2}, {'a': 3}]
    insert_status = base.insert(rows)
    assert insert_status.num_rows == len(rows)
    assert insert_status.num_excs == 0
    s1 = client.create_view('s1', base, is_snapshot=True)
    v1 = client.create_view('v1', s1, is_snapshot=False)
    s2 = client.create_view('s2', v1, is_snapshot=True)
    v2 = client.create_view('v2', s2, is_snapshot=False)

    def verify(s1: pxt.Table, s2: pxt.Table, v1: pxt.Table, v2: pxt.Table) -> None:
        # everything downstream of the first snapshot keeps the original row count
        for tbl in (s1, v1, s2, v2):
            assert tbl.count() == len(rows)

    verify(s1, s2, v1, v2)

    # inserting into the base table must not show up in any of them
    insert_status = base.insert(rows)
    assert insert_status.num_rows == len(rows)
    assert insert_status.num_excs == 0
    verify(s1, s2, v1, v2)

    # same result after reloading the client
    client = pxt.Client(reload=True)
    reloaded = [client.get_table(name) for name in ('s1', 's2', 'v1', 'v2')]
    verify(*reloaded)
|
|
123
|
+
|
|
124
|
+
def test_snapshot_of_view_chain(self, test_client: pxt.Client) -> None:
    """Check a snapshot taken at the end of a chain of two non-snapshot views.

    Builds tbl -> v1 -> v2 -> s (snapshot). Both views are expected to match the base table's
    current row count, while the snapshot is expected to keep the row count of the first insert.
    The checks run after creation, after a second insert, and after a client reload.
    """
    client = test_client
    base = client.create_table('tbl', {'a': IntType()})
    rows = [{'a': 1}, {'a': 2}, {'a': 3}]
    insert_status = base.insert(rows)
    assert insert_status.num_rows == len(rows)
    assert insert_status.num_excs == 0

    v1 = client.create_view('v1', base, is_snapshot=False)
    v2 = client.create_view('v2', v1, is_snapshot=False)
    s = client.create_view('s', v2, is_snapshot=True)

    def verify(v1: pxt.Table, v2: pxt.Table, s: pxt.Table) -> None:
        # 'base' is the handle created above; it is not re-fetched after the client reload
        for view in (v1, v2):
            assert view.count() == base.count()
        assert s.count() == len(rows)

    verify(v1, v2, s)

    insert_status = base.insert(rows)
    assert insert_status.num_rows == len(rows) * 3  # we also updated 2 views
    assert insert_status.num_excs == 0
    verify(v1, v2, s)

    # repeat the check with view/snapshot handles obtained from a reloaded client
    client = pxt.Client(reload=True)
    v1, v2, s = (client.get_table(name) for name in ('v1', 'v2', 's'))
    verify(v1, v2, s)
|
|
152
|
+
|
|
153
|
+
def test_multiple_snapshot_paths(self, test_client: pxt.Client) -> None:
    """Check snapshots of one view that were taken at different versions of its base table.

    Four snapshots of view v are created, separated by a column drop, an update of c6 and an
    update of c3 on the base table. Afterwards each snapshot must still show the c3/c4/v1
    values that were current when it was created, both with the original handles and with
    handles re-fetched from a reloaded client.

    All compared values are read ordered by c2, so the positional comparisons do not depend
    on the order in which an unordered query happens to return rows.
    """
    cl = test_client
    t = create_test_tbl(cl)

    def ordered(tbl: pxt.Table, col_name: str):
        """Return column col_name of tbl as a pandas Series, with rows ordered by c2."""
        return tbl.select(getattr(tbl, col_name)).order_by(tbl.c2).collect().to_pandas()[col_name]

    # reference values, captured before any schema or data change
    c4 = ordered(t, 'c4')
    orig_c3 = ordered(t, 'c3')

    v = cl.create_view('v', base=t, schema={'v1': t.c3 + 1})
    s1 = cl.create_view('s1', v, is_snapshot=True)
    t.drop_column('c4')
    # s2 references the same view version as s1, but a different version of t (due to a schema change)
    s2 = cl.create_view('s2', v, is_snapshot=True)
    t.update({'c6': {'a': 17}})
    # s3 references the same view version as s2, but a different version of t (due to a data change)
    s3 = cl.create_view('s3', v, is_snapshot=True)
    t.update({'c3': t.c3 + 1})
    # s4 references different versions of t and v
    s4 = cl.create_view('s4', v, is_snapshot=True)

    def validate(t: pxt.Table, v: pxt.Table, s1: pxt.Table, s2: pxt.Table, s3: pxt.Table, s4: pxt.Table) -> None:
        # c4 is only visible in s1
        assert np.all(ordered(s1, 'c4') == c4)
        for tbl in (t, v, s2, s3, s4):
            with pytest.raises(AttributeError):
                _ = tbl.select(tbl.c4).collect()

        # c3: only t and s4 see the incremented values
        assert np.all(ordered(t, 'c3') == orig_c3 + 1)
        assert np.all(ordered(s1, 'c3') == orig_c3)
        assert np.all(ordered(s2, 'c3') == orig_c3)
        assert np.all(ordered(s3, 'c3') == orig_c3)
        assert np.all(ordered(s4, 'c3') == orig_c3 + 1)

        # v1: v and s4 track the current c3 of t, s1-s3 keep the value derived from the original c3
        assert np.all(ordered(v, 'v1') == ordered(t, 'c3') + 1)
        assert np.all(ordered(s1, 'v1') == orig_c3 + 1)
        assert np.all(ordered(s2, 'v1') == orig_c3 + 1)
        assert np.all(ordered(s3, 'v1') == orig_c3 + 1)
        assert np.all(ordered(s4, 'v1') == ordered(t, 'c3') + 1)

    validate(t, v, s1, s2, s3, s4)

    # make sure it works after metadata reload
    cl = pxt.Client(reload=True)
    t, v = cl.get_table('test_tbl'), cl.get_table('v')
    s1, s2, s3, s4 = cl.get_table('s1'), cl.get_table('s2'), cl.get_table('s3'), cl.get_table('s4')
    validate(t, v, s1, s2, s3, s4)
|