pixeltable 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. See the package registry page for more details.

Files changed (139)
  1. pixeltable/__init__.py +34 -6
  2. pixeltable/catalog/__init__.py +13 -0
  3. pixeltable/catalog/catalog.py +159 -0
  4. pixeltable/catalog/column.py +200 -0
  5. pixeltable/catalog/dir.py +32 -0
  6. pixeltable/catalog/globals.py +33 -0
  7. pixeltable/catalog/insertable_table.py +191 -0
  8. pixeltable/catalog/named_function.py +36 -0
  9. pixeltable/catalog/path.py +58 -0
  10. pixeltable/catalog/path_dict.py +139 -0
  11. pixeltable/catalog/schema_object.py +39 -0
  12. pixeltable/catalog/table.py +581 -0
  13. pixeltable/catalog/table_version.py +749 -0
  14. pixeltable/catalog/table_version_path.py +133 -0
  15. pixeltable/catalog/view.py +203 -0
  16. pixeltable/client.py +520 -30
  17. pixeltable/dataframe.py +540 -349
  18. pixeltable/env.py +373 -45
  19. pixeltable/exceptions.py +12 -21
  20. pixeltable/exec/__init__.py +9 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +113 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +95 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +69 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +225 -0
  31. pixeltable/exprs/__init__.py +24 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +105 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +187 -0
  39. pixeltable/exprs/expr.py +586 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +380 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +115 -0
  44. pixeltable/exprs/image_similarity_predicate.py +58 -0
  45. pixeltable/exprs/inline_array.py +107 -0
  46. pixeltable/exprs/inline_dict.py +101 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +54 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +355 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/type_cast.py +53 -0
  56. pixeltable/exprs/variable.py +45 -0
  57. pixeltable/func/__init__.py +9 -0
  58. pixeltable/func/aggregate_function.py +194 -0
  59. pixeltable/func/batched_function.py +53 -0
  60. pixeltable/func/callable_function.py +69 -0
  61. pixeltable/func/expr_template_function.py +82 -0
  62. pixeltable/func/function.py +110 -0
  63. pixeltable/func/function_registry.py +227 -0
  64. pixeltable/func/globals.py +36 -0
  65. pixeltable/func/nos_function.py +202 -0
  66. pixeltable/func/signature.py +166 -0
  67. pixeltable/func/udf.py +163 -0
  68. pixeltable/functions/__init__.py +52 -103
  69. pixeltable/functions/eval.py +216 -0
  70. pixeltable/functions/fireworks.py +61 -0
  71. pixeltable/functions/huggingface.py +120 -0
  72. pixeltable/functions/image.py +16 -0
  73. pixeltable/functions/openai.py +88 -0
  74. pixeltable/functions/pil/image.py +148 -7
  75. pixeltable/functions/string.py +13 -0
  76. pixeltable/functions/together.py +27 -0
  77. pixeltable/functions/util.py +41 -0
  78. pixeltable/functions/video.py +62 -0
  79. pixeltable/iterators/__init__.py +3 -0
  80. pixeltable/iterators/base.py +48 -0
  81. pixeltable/iterators/document.py +311 -0
  82. pixeltable/iterators/video.py +89 -0
  83. pixeltable/metadata/__init__.py +54 -0
  84. pixeltable/metadata/converters/convert_10.py +18 -0
  85. pixeltable/metadata/schema.py +211 -0
  86. pixeltable/plan.py +656 -0
  87. pixeltable/store.py +413 -182
  88. pixeltable/tests/conftest.py +143 -87
  89. pixeltable/tests/test_audio.py +65 -0
  90. pixeltable/tests/test_catalog.py +27 -0
  91. pixeltable/tests/test_client.py +14 -14
  92. pixeltable/tests/test_component_view.py +372 -0
  93. pixeltable/tests/test_dataframe.py +433 -0
  94. pixeltable/tests/test_dirs.py +78 -62
  95. pixeltable/tests/test_document.py +117 -0
  96. pixeltable/tests/test_exprs.py +591 -135
  97. pixeltable/tests/test_function.py +297 -67
  98. pixeltable/tests/test_functions.py +283 -1
  99. pixeltable/tests/test_migration.py +43 -0
  100. pixeltable/tests/test_nos.py +54 -0
  101. pixeltable/tests/test_snapshot.py +208 -0
  102. pixeltable/tests/test_table.py +1085 -262
  103. pixeltable/tests/test_transactional_directory.py +42 -0
  104. pixeltable/tests/test_types.py +5 -11
  105. pixeltable/tests/test_video.py +149 -34
  106. pixeltable/tests/test_view.py +530 -0
  107. pixeltable/tests/utils.py +186 -45
  108. pixeltable/tool/create_test_db_dump.py +149 -0
  109. pixeltable/type_system.py +490 -126
  110. pixeltable/utils/__init__.py +17 -46
  111. pixeltable/utils/clip.py +12 -15
  112. pixeltable/utils/coco.py +136 -0
  113. pixeltable/utils/documents.py +39 -0
  114. pixeltable/utils/filecache.py +195 -0
  115. pixeltable/utils/help.py +11 -0
  116. pixeltable/utils/media_store.py +76 -0
  117. pixeltable/utils/parquet.py +126 -0
  118. pixeltable/utils/pytorch.py +172 -0
  119. pixeltable/utils/s3.py +13 -0
  120. pixeltable/utils/sql.py +17 -0
  121. pixeltable/utils/transactional_directory.py +35 -0
  122. pixeltable-0.2.0.dist-info/LICENSE +18 -0
  123. pixeltable-0.2.0.dist-info/METADATA +117 -0
  124. pixeltable-0.2.0.dist-info/RECORD +125 -0
  125. {pixeltable-0.1.1.dist-info → pixeltable-0.2.0.dist-info}/WHEEL +1 -1
  126. pixeltable/catalog.py +0 -1421
  127. pixeltable/exprs.py +0 -1745
  128. pixeltable/function.py +0 -269
  129. pixeltable/functions/clip.py +0 -10
  130. pixeltable/functions/pil/__init__.py +0 -23
  131. pixeltable/functions/tf.py +0 -21
  132. pixeltable/index.py +0 -57
  133. pixeltable/tests/test_dict.py +0 -24
  134. pixeltable/tests/test_tf.py +0 -69
  135. pixeltable/tf.py +0 -33
  136. pixeltable/utils/tf.py +0 -33
  137. pixeltable/utils/video.py +0 -32
  138. pixeltable-0.1.1.dist-info/METADATA +0 -31
  139. pixeltable-0.1.1.dist-info/RECORD +0 -36
@@ -1,11 +1,293 @@
1
- import sqlalchemy as sql
1
+ from typing import Dict, Any
2
+
2
3
  import pytest
3
4
 
5
+ import pixeltable as pxt
4
6
  from pixeltable import catalog
7
+ from pixeltable.env import Env
5
8
  from pixeltable.functions.pil.image import blend
9
+ from pixeltable.iterators import FrameIterator
10
+ from pixeltable.tests.utils import get_video_files, skip_test_if_not_installed, get_sentences, get_image_files
11
+ from pixeltable.type_system import VideoType, StringType, JsonType, ImageType, BoolType, FloatType, ArrayType
6
12
 
7
13
 
8
14
class TestFunctions:
    """Tests for the built-in function libraries (PIL, string, eval, OpenAI, Together, Fireworks, Hugging Face).

    Tests that depend on an optional package or on an API key skip themselves when it is unavailable.
    """

    def test_pil(self, img_tbl: catalog.Table) -> None:
        t = img_tbl
        _ = t[t.img, t.img.rotate(90), blend(t.img, t.img.rotate(90), 0.5)].show()

    def test_eval_detections(self, test_client: pxt.Client) -> None:
        """Evaluate two small detectors against the output of yolox_large via eval_detections()/mean_ap()."""
        skip_test_if_not_installed('nos')
        cl = test_client
        video_t = cl.create_table('video_tbl', {'video': VideoType()})
        # create frame view
        args = {'video': video_t.video, 'fps': 1}
        v = cl.create_view('test_view', video_t, iterator_class=FrameIterator, iterator_args=args)

        files = get_video_files()
        video_t.insert(video=files[-1])
        v.add_column(frame_s=v.frame.resize([640, 480]))
        from pixeltable.functions.nos.object_detection_2d import yolox_nano, yolox_small, yolox_large
        v.add_column(detections_a=yolox_nano(v.frame_s))
        v.add_column(detections_b=yolox_small(v.frame_s))
        # the output of the largest model serves as ground truth
        v.add_column(gt=yolox_large(v.frame_s))
        from pixeltable.functions.eval import eval_detections, mean_ap
        # eval_detections() must also work in an ad-hoc query, not only as a computed column
        _ = v.select(
            eval_detections(
                v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels
            )).show()
        v.add_column(
            eval_a=eval_detections(
                v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels))
        v.add_column(
            eval_b=eval_detections(
                v.detections_b.bboxes, v.detections_b.labels, v.detections_b.scores, v.gt.bboxes, v.gt.labels))
        ap_a = v.select(mean_ap(v.eval_a)).show()[0, 0]
        ap_b = v.select(mean_ap(v.eval_b)).show()[0, 0]
        common_classes = set(ap_a.keys()) & set(ap_b.keys())

        # TODO: the following assertion fails on CI. That is not necessarily a bug: the condition is not
        # guaranteed to hold for every class.
        # for k in common_classes:
        #     assert ap_a[k] <= ap_b[k]

    def test_str(self, test_client: pxt.Client) -> None:
        """str_format() with positional, keyword, and mixed arguments."""
        cl = test_client
        t = cl.create_table('test_tbl', {'input': StringType()})
        from pixeltable.functions.string import str_format
        t.add_column(s1=str_format('ABC {0}', t.input))
        t.add_column(s2=str_format('DEF {this}', this=t.input))
        t.add_column(s3=str_format('GHI {0} JKL {this}', t.input, this=t.input))
        status = t.insert(input='MNO')
        assert status.num_rows == 1
        assert status.num_excs == 0
        row = t.head()[0]
        assert row == {'input': 'MNO', 's1': 'ABC MNO', 's2': 'DEF MNO', 's3': 'GHI MNO JKL MNO'}

    def test_openai(self, test_client: pxt.Client) -> None:
        """Chat completions (with column-based and inlined messages), embeddings, and moderations."""
        skip_test_if_not_installed('openai')
        TestFunctions.skip_test_if_no_openai_client()
        cl = test_client
        t = cl.create_table('test_tbl', {'input': StringType()})
        from pixeltable.functions.openai import chat_completions, embeddings, moderations
        msgs = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": t.input}
        ]
        t.add_column(input_msgs=msgs)
        t.add_column(chat_output=chat_completions(model='gpt-3.5-turbo', messages=t.input_msgs))
        # with inlined messages
        t.add_column(chat_output2=chat_completions(model='gpt-3.5-turbo', messages=msgs))
        t.add_column(ada_embed=embeddings(model='text-embedding-ada-002', input=t.input))
        t.add_column(text_3=embeddings(model='text-embedding-3-small', input=t.input))
        t.add_column(moderation=moderations(input=t.input))
        t.insert(input='I find you really annoying')
        _ = t.head()

    def test_gpt_4_vision(self, test_client: pxt.Client) -> None:
        """Send an image (as a base64 data URL) plus a text prompt to gpt-4-vision-preview."""
        skip_test_if_not_installed('openai')
        TestFunctions.skip_test_if_no_openai_client()
        cl = test_client
        t = cl.create_table('test_tbl', {'prompt': StringType(), 'img': ImageType()})
        from pixeltable.functions.openai import chat_completions
        from pixeltable.functions.string import str_format
        msgs = [
            {'role': 'user',
             'content': [
                 {'type': 'text', 'text': t.prompt},
                 {'type': 'image_url', 'image_url': {
                     'url': str_format('data:image/png;base64,{0}', t.img.b64_encode())
                 }}
             ]}
        ]
        t.add_column(response=chat_completions(model='gpt-4-vision-preview', messages=msgs, max_tokens=300))
        t.add_column(response_content=t.response.choices[0].message.content)
        t.insert(prompt="What's in this image?", img=_sample_image_url)
        result = t.collect()['response_content'][0]
        assert len(result) > 0

    @staticmethod
    def skip_test_if_no_openai_client() -> None:
        """Skip the calling test if no OpenAI client is configured in the environment."""
        if Env.get().openai_client is None:
            pytest.skip('OpenAI client does not exist (missing API key?)')

    def test_together(self, test_client: pxt.Client) -> None:
        skip_test_if_not_installed('together')
        if not Env.get().has_together_client:
            pytest.skip('Together client does not exist (missing API key?)')
        cl = test_client
        t = cl.create_table('test_tbl', {'input': StringType()})
        from pixeltable.functions.together import completions
        t.add_column(output=completions(prompt=t.input, model='mistralai/Mixtral-8x7B-v0.1', stop=['\n']))
        t.add_column(output_text=t.output.output.choices[0].text)
        t.insert(input='I am going to the ')
        result = t.select(t.output_text).collect()['output_text'][0]
        assert len(result) > 0

    def test_fireworks(self, test_client: pxt.Client) -> None:
        skip_test_if_not_installed('fireworks')
        try:
            from pixeltable.functions.fireworks import initialize
            initialize()
        except Exception:
            # presumably initialize() fails when no Fireworks API key is configured; a bare `except:` here
            # would also swallow KeyboardInterrupt/SystemExit
            pytest.skip('Fireworks client does not exist (missing API key?)')
        cl = test_client
        t = cl.create_table('test_tbl', {'input': StringType()})
        from pixeltable.functions.fireworks import chat_completions
        t['output'] = chat_completions(
            prompt=t.input, model='accounts/fireworks/models/llama-v2-7b-chat', max_tokens=256).choices[0].text
        t.insert(input='I am going to the ')
        result = t.select(t.output).collect()['output'][0]
        assert len(result) > 0

    def test_hf_function(self, test_client: pxt.Client) -> None:
        """sentence_transformer() computed column, plus rejection of non-constant model parameters."""
        skip_test_if_not_installed('sentence_transformers')
        cl = test_client
        t = cl.create_table('test_tbl', {'input': StringType(), 'bool_col': BoolType()})
        from pixeltable.functions.huggingface import sentence_transformer
        model_id = 'intfloat/e5-large-v2'
        t.add_column(e5=sentence_transformer(t.input, model_id=model_id))
        sents = get_sentences()
        status = t.insert({'input': s, 'bool_col': True} for s in sents)
        assert status.num_rows == len(sents)
        assert status.num_excs == 0

        # verify handling of constant params
        with pytest.raises(ValueError) as exc_info:
            t.add_column(e5_2=sentence_transformer(t.input, model_id=t.input))
        assert ': parameter model_id must be a constant value' in str(exc_info.value)
        with pytest.raises(ValueError) as exc_info:
            t.add_column(e5_2=sentence_transformer(t.input, model_id=model_id, normalize_embeddings=t.bool_col))
        assert ': parameter normalize_embeddings must be a constant value' in str(exc_info.value)

        # make sure this doesn't cause an exception
        # TODO: is there some way to capture the output?
        t.describe()

    def test_sentence_transformer(self, test_client: pxt.Client) -> None:
        skip_test_if_not_installed('sentence_transformers')
        cl = test_client
        t = cl.create_table('test_tbl', {'input': StringType(), 'input_list': JsonType()})
        sents = get_sentences(10)
        status = t.insert({'input': s, 'input_list': sents} for s in sents)
        assert status.num_rows == len(sents)
        assert status.num_excs == 0

        # run multiple models one at a time in order to exercise batching
        from pixeltable.functions.huggingface import sentence_transformer, sentence_transformer_list
        model_ids = ['sentence-transformers/all-mpnet-base-v2', 'BAAI/bge-reranker-base']
        # expected embedding dimensionality, parallel to model_ids
        num_dims = [768, 768]
        for idx, model_id in enumerate(model_ids):
            col_name = f'embed{idx}'
            t[col_name] = sentence_transformer(t.input, model_id=model_id, normalize_embeddings=True)
            assert t.column_types()[col_name] == ArrayType((None,), dtype=FloatType(), nullable=False)
            list_col_name = f'embed_list{idx}'
            t[list_col_name] = sentence_transformer_list(t.input_list, model_id=model_id, normalize_embeddings=True)
            assert t.column_types()[list_col_name] == JsonType()

        def verify_row(row: Dict[str, Any]) -> None:
            for idx, d in enumerate(num_dims):
                assert row[f'embed{idx}'].shape == (d,)
                assert len(row[f'embed_list{idx}']) == len(sents)
                assert all(len(v) == d for v in row[f'embed_list{idx}'])

        verify_row(t.tail(1)[0])

        # execution still works after reload
        cl = pxt.Client(reload=True)
        t = cl.get_table('test_tbl')
        status = t.insert({'input': s, 'input_list': sents} for s in sents)
        assert status.num_rows == len(sents)
        assert status.num_excs == 0
        verify_row(t.tail(1)[0])

    def test_cross_encoder(self, test_client: pxt.Client) -> None:
        skip_test_if_not_installed('sentence_transformers')
        cl = test_client
        t = cl.create_table('test_tbl', {'input': StringType(), 'input_list': JsonType()})
        sents = get_sentences(10)
        status = t.insert({'input': s, 'input_list': sents} for s in sents)
        assert status.num_rows == len(sents)
        assert status.num_excs == 0

        # run multiple models one at a time in order to exercise batching
        from pixeltable.functions.huggingface import cross_encoder, cross_encoder_list
        model_ids = ['cross-encoder/ms-marco-MiniLM-L-6-v2', 'cross-encoder/ms-marco-TinyBERT-L-2-v2']
        for idx, model_id in enumerate(model_ids):
            col_name = f'embed{idx}'
            t[col_name] = cross_encoder(t.input, t.input, model_id=model_id)
            assert t.column_types()[col_name] == FloatType()
            list_col_name = f'embed_list{idx}'
            t[list_col_name] = cross_encoder_list(t.input, t.input_list, model_id=model_id)
            assert t.column_types()[list_col_name] == JsonType()

        def verify_row(row: Dict[str, Any]) -> None:
            # bind the index locally: using the enclosing loop's `idx` would check only the last model
            for idx, _ in enumerate(model_ids):
                assert len(row[f'embed_list{idx}']) == len(sents)
                assert all(isinstance(v, float) for v in row[f'embed_list{idx}'])

        verify_row(t.tail(1)[0])

        # execution still works after reload
        cl = pxt.Client(reload=True)
        t = cl.get_table('test_tbl')
        status = t.insert({'input': s, 'input_list': sents} for s in sents)
        assert status.num_rows == len(sents)
        assert status.num_excs == 0
        verify_row(t.tail(1)[0])

    def test_clip(self, test_client: pxt.Client) -> None:
        skip_test_if_not_installed('transformers')
        cl = test_client
        t = cl.create_table('test_tbl', {'text': StringType(), 'img': ImageType()})
        num_rows = 10
        sents = get_sentences(num_rows)
        imgs = get_image_files()[:num_rows]
        status = t.insert({'text': text, 'img': img} for text, img in zip(sents, imgs))
        assert status.num_rows == len(sents)
        assert status.num_excs == 0

        # run multiple models one at a time in order to exercise batching
        from pixeltable.functions.huggingface import clip_text, clip_image
        model_ids = ['openai/clip-vit-base-patch32', 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K']
        for idx, model_id in enumerate(model_ids):
            col_name = f'embed_text{idx}'
            t[col_name] = clip_text(t.text, model_id=model_id)
            assert t.column_types()[col_name] == ArrayType((None,), dtype=FloatType(), nullable=False)
            col_name = f'embed_img{idx}'
            t[col_name] = clip_image(t.img, model_id=model_id)
            assert t.column_types()[col_name] == ArrayType((None,), dtype=FloatType(), nullable=False)

        def verify_row(row: Dict[str, Any]) -> None:
            for idx, _ in enumerate(model_ids):
                assert row[f'embed_text{idx}'].shape == (512,)
                assert row[f'embed_img{idx}'].shape == (512,)

        verify_row(t.tail(1)[0])

        # execution still works after reload
        cl = pxt.Client(reload=True)
        t = cl.get_table('test_tbl')
        status = t.insert({'text': text, 'img': img} for text, img in zip(sents, imgs))
        assert status.num_rows == len(sents)
        assert status.num_excs == 0
        verify_row(t.tail(1)[0])

    def test_detr_for_object_detection(self, test_client: pxt.Client) -> None:
        skip_test_if_not_installed('transformers')
        cl = test_client
        t = cl.create_table('test_tbl', {'img': ImageType()})
        from pixeltable.functions.huggingface import detr_for_object_detection
        t['detect'] = detr_for_object_detection(t.img, model_id='facebook/detr-resnet-50', threshold=0.8)
        status = t.insert(img=_sample_image_url)
        assert status.num_rows == 1
        assert status.num_excs == 0
        result = t.select(t.detect).collect()[0]['detect']
        assert 'orange' in result['label_text']
        assert 'bowl' in result['label_text']
        assert 'broccoli' in result['label_text']
+
291
+
292
+ _sample_image_url = \
293
+ 'https://raw.githubusercontent.com/pixeltable/pixeltable/master/docs/source/data/images/000000000009.jpg'
@@ -0,0 +1,43 @@
1
+ import glob
2
+ import logging
3
+ import os
4
+ import subprocess
5
+
6
+ import pgserver
7
+ import pytest
8
+
9
+ import pixeltable as pxt
10
+ from pixeltable.env import Env
11
+ from pixeltable.tests.conftest import clean_db
12
+
13
+ _logger = logging.getLogger('pixeltable')
14
+
15
+
16
class TestMigration:
    """Checks that a client can open databases restored from dumps made by earlier pixeltable versions."""

    @pytest.mark.skip(reason='Suspended')
    def test_db_migration(self, init_env) -> None:
        """Restore each gzipped dump in pixeltable/tests/data/dbdumps with pg_restore, then create a client.

        Raises:
            subprocess.CalledProcessError: if gunzip or pg_restore exits with a non-zero status.
        """
        env = Env.get()
        pg_package_dir = os.path.dirname(pgserver.__file__)
        pg_restore_binary = f'{pg_package_dir}/pginstall/bin/pg_restore'
        _logger.info(f'Using pg_restore binary at: {pg_restore_binary}')
        dump_files = sorted(glob.glob('pixeltable/tests/data/dbdumps/*.dump.gz'))
        for dump_file in dump_files:
            _logger.info(f'Testing migration from DB dump {dump_file}.')
            _logger.info(f'DB URL: {env.db_url}')
            clean_db(restore_tables=False)
            with open(dump_file, 'rb') as dump:
                # Using Popen as a context manager closes its stdout pipe and waits for gunzip to exit,
                # so the child is reaped and its fd released even if pg_restore fails.
                with subprocess.Popen(['gunzip', '-c'], stdin=dump, stdout=subprocess.PIPE) as gunzip_process:
                    subprocess.run(
                        [pg_restore_binary, '-d', env.db_url, '-U', 'postgres'],
                        stdin=gunzip_process.stdout,
                        check=True
                    )
            # surface decompression failures (e.g. a corrupt dump) instead of silently ignoring them
            if gunzip_process.returncode != 0:
                raise subprocess.CalledProcessError(gunzip_process.returncode, gunzip_process.args)
            # TODO(aaron-siegel) This will test that the migration succeeds without raising any exceptions.
            # We should also add some assertions to sanity-check the outcome.
            _ = pxt.Client()
@@ -0,0 +1,54 @@
1
+ import pytest
2
+
3
+ import pixeltable as pxt
4
+ from pixeltable.iterators import FrameIterator
5
+ from pixeltable.tests.utils import get_video_files, skip_test_if_not_installed
6
+ from pixeltable.type_system import ImageType, VideoType
7
+
8
+
9
class TestNOS:
    """Tests for NOS-backed functions; each test skips itself if the `nos` package is not installed."""

    def test_basic(self, test_client: pxt.Client) -> None:
        """Run NOS detection and embedding on a frame view that mixes stored and unstored columns."""
        skip_test_if_not_installed('nos')
        cl = test_client
        video_t = cl.create_table('video_tbl', {'video': VideoType()})
        # create frame view
        args = {'video': video_t.video, 'fps': 1}
        v = cl.create_view('test_view', video_t, iterator_class=FrameIterator, iterator_args=args)
        v.add_column(transform1=v.frame.rotate(30), stored=False)
        from pixeltable.functions.nos.object_detection_2d import \
            torchvision_fasterrcnn_mobilenet_v3_large_320_fpn as fasterrcnn
        v.add_column(detections=fasterrcnn(v.transform1))
        from pixeltable.functions.nos.image_embedding import openai_clip
        v.add_column(embed=openai_clip(v.transform1.resize([224, 224])))
        # add a stored column that isn't referenced in nos calls
        v.add_column(transform2=v.frame.rotate(60), stored=True)

        status = video_t.insert(video=get_video_files()[0])
        # none of the computed columns above is expected to fail
        assert status.num_excs == 0

    def test_exceptions(self, test_client: pxt.Client) -> None:
        """Errors in an input column propagate to the NOS function that consumes it."""
        skip_test_if_not_installed('nos')
        cl = test_client
        video_t = cl.create_table('video_tbl', {'video': VideoType()})
        # create frame view
        args = {'video': video_t.video, 'fps': 1}
        v = cl.create_view('test_view', video_t, iterator_class=FrameIterator, iterator_args=args)
        video_t.insert(video=get_video_files()[0])

        v.add_column(frame_s=v.frame.resize([640, 480]))
        # 'rotated' has exceptions: int(360 / frame_idx) divides by zero when frame_idx == 0
        # (presumably only the first frame), hence exactly one failing row below
        v.add_column(rotated=lambda frame_s, frame_idx: frame_s.rotate(int(360 / frame_idx)), type=ImageType())
        from pixeltable.functions.nos.object_detection_2d import yolox_medium
        v.add_column(detections=yolox_medium(v.rotated), stored=True)
        # `!= None` is deliberate: it builds a pixeltable predicate, and `is not` cannot be overloaded
        assert v.where(v.detections.errortype != None).count() == 1  # noqa: E711

    @pytest.mark.skip(reason='too slow')
    def test_sd(self, test_client: pxt.Client) -> None:
        """Test model that mixes batched with scalar parameters"""
        skip_test_if_not_installed('nos')
        t = test_client.create_table('sd_test', {'prompt': pxt.StringType()})
        t.insert(prompt='cat on a sofa')
        from pixeltable.functions.nos.image_generation import stabilityai_stable_diffusion_2 as sd2
        t.add_column(img=sd2(t.prompt, 1, 512, 512), stored=True)
        img = t[t.img].show(1)[0, 0]
        assert img.size == (512, 512)
@@ -0,0 +1,208 @@
1
+ from typing import Any, Dict
2
+
3
+ import numpy as np
4
+ import pytest
5
+
6
+ import pixeltable as pxt
7
+ import pixeltable.exceptions as excs
8
+ from pixeltable.tests.utils import create_test_tbl, assert_resultset_eq
9
+ from pixeltable.type_system import IntType
10
+
11
+
12
class TestSnapshot:
    """Tests for snapshots, i.e. views created with is_snapshot=True."""

    def run_basic_test(
            self, cl: pxt.Client, tbl: pxt.Table, snap: pxt.Table, extra_items: Dict[str, Any], filter: Any,
            reload_md: bool
    ) -> None:
        """Check that `snap` is unaffected by later insert/update/delete on `tbl`, then drop both.

        Args:
            cl: client used to resolve paths and to drop the tables
            tbl: base table of the snapshot
            snap: snapshot of `tbl`
            extra_items: snapshot column name -> expression over `tbl` that computes the same value
            filter: predicate the snapshot was created with, or None
            reload_md: if True, reload catalog metadata before querying the snapshot
        """
        tbl_path, snap_path = cl.get_path(tbl), cl.get_path(snap)
        # run the initial query against the base table here, before reloading, otherwise the filter breaks
        tbl_select_list = [tbl[col_name] for col_name in tbl.column_names()]
        tbl_select_list.extend(extra_items.values())
        orig_resultset = tbl.select(*tbl_select_list).where(filter).order_by(tbl.c2).collect()

        if reload_md:
            # reload md
            cl = pxt.Client(reload=True)
            tbl = cl.get_table(tbl_path)
            snap = cl.get_table(snap_path)

        # snapshot select list: base cols followed by snapshot cols, to line up with tbl_select_list
        # (assumes snap.column_names() lists the snapshot's own columns first — TODO confirm)
        snap_select_list = [snap[col_name] for col_name in snap.column_names()[len(extra_items):]]
        snap_select_list.extend(snap[col_name] for col_name in extra_items)
        snap_query = snap.select(*snap_select_list).order_by(snap.c2)
        assert_resultset_eq(snap_query.collect(), orig_resultset)

        # adding data to a base table doesn't change the snapshot
        rows = list(tbl.select(tbl.c1, tbl.c1n, tbl.c2, tbl.c3, tbl.c4, tbl.c5, tbl.c6, tbl.c7).collect())
        status = tbl.insert(rows)
        assert status.num_rows == len(rows)
        assert_resultset_eq(snap_query.collect(), orig_resultset)

        # update() doesn't affect the snapshot
        status = tbl.update({'c3': tbl.c3 + 1.0})
        assert status.num_rows == tbl.count()
        assert_resultset_eq(snap_query.collect(), orig_resultset)

        # delete() doesn't affect the snapshot
        num_tbl_rows = tbl.count()
        status = tbl.delete()
        assert status.num_rows == num_tbl_rows
        assert_resultset_eq(snap_query.collect(), orig_resultset)

        tbl.revert()  # undo delete()
        tbl.revert()  # undo update()
        tbl.revert()  # undo insert()
        # can't revert a version referenced by a snapshot
        with pytest.raises(excs.Error) as excinfo:
            tbl.revert()
        assert 'version is needed' in str(excinfo.value)

        # can't drop a table with snapshots
        with pytest.raises(excs.Error) as excinfo:
            cl.drop_table(tbl_path)
        assert snap_path in str(excinfo.value)

        cl.drop_table(snap_path)
        cl.drop_table(tbl_path)

    def test_basic(self, test_client: pxt.Client) -> None:
        """Run run_basic_test() for every combination of metadata reload, filter, and computed columns."""
        cl = test_client
        cl.create_dir('main')
        cl.create_dir('snap')
        tbl_path = 'main.tbl1'
        snap_path = 'snap.snap1'

        for reload_md in [False, True]:
            for has_filter in [False, True]:
                for has_cols in [False, True]:
                    cl = pxt.Client(reload=True)
                    tbl = create_test_tbl(name=tbl_path, client=cl)
                    schema = {
                        'v1': tbl.c3 * 2.0,
                        # include a lambda to make sure that is handled correctly
                        'v2': {'value': lambda c3: c3 * 2.0, 'type': pxt.FloatType()}
                    } if has_cols else {}
                    # expressions over tbl that compute the same values as the snapshot columns in `schema`
                    extra_items = {'v1': tbl.c3 * 2.0, 'v2': tbl.c3 * 2.0} if has_cols else {}
                    filter_expr = tbl.c2 < 10 if has_filter else None
                    snap = cl.create_view(snap_path, tbl, schema=schema, filter=filter_expr, is_snapshot=True)
                    self.run_basic_test(
                        cl, tbl, snap, extra_items=extra_items, filter=filter_expr, reload_md=reload_md)

    def test_views_of_snapshots(self, test_client: pxt.Client) -> None:
        """Alternating chain t -> s1 (snapshot) -> v1 (view) -> s2 (snapshot) -> v2 (view)."""
        cl = test_client
        t = cl.create_table('tbl', {'a': IntType()})
        rows = [{'a': 1}, {'a': 2}, {'a': 3}]
        status = t.insert(rows)
        assert status.num_rows == len(rows)
        assert status.num_excs == 0
        s1 = cl.create_view('s1', t, is_snapshot=True)
        v1 = cl.create_view('v1', s1, is_snapshot=False)
        s2 = cl.create_view('s2', v1, is_snapshot=True)
        v2 = cl.create_view('v2', s2, is_snapshot=False)

        def verify(s1: pxt.Table, s2: pxt.Table, v1: pxt.Table, v2: pxt.Table) -> None:
            # everything sits downstream of snapshot s1, so later inserts into t are never visible
            assert s1.count() == len(rows)
            assert v1.count() == len(rows)
            assert s2.count() == len(rows)
            assert v2.count() == len(rows)

        verify(s1, s2, v1, v2)

        status = t.insert(rows)
        assert status.num_rows == len(rows)
        assert status.num_excs == 0
        verify(s1, s2, v1, v2)

        cl = pxt.Client(reload=True)
        s1 = cl.get_table('s1')
        s2 = cl.get_table('s2')
        v1 = cl.get_table('v1')
        v2 = cl.get_table('v2')
        verify(s1, s2, v1, v2)

    def test_snapshot_of_view_chain(self, test_client: pxt.Client) -> None:
        """A snapshot at the end of a chain of (non-snapshot) views."""
        cl = test_client
        t = cl.create_table('tbl', {'a': IntType()})
        rows = [{'a': 1}, {'a': 2}, {'a': 3}]
        status = t.insert(rows)
        assert status.num_rows == len(rows)
        assert status.num_excs == 0
        v1 = cl.create_view('v1', t, is_snapshot=False)
        v2 = cl.create_view('v2', v1, is_snapshot=False)
        s = cl.create_view('s', v2, is_snapshot=True)

        def verify(v1: pxt.Table, v2: pxt.Table, s: pxt.Table) -> None:
            assert v1.count() == t.count()
            assert v2.count() == t.count()
            assert s.count() == len(rows)

        verify(v1, v2, s)

        status = t.insert(rows)
        assert status.num_rows == len(rows) * 3  # we also updated 2 views
        assert status.num_excs == 0
        verify(v1, v2, s)

        cl = pxt.Client(reload=True)
        v1 = cl.get_table('v1')
        v2 = cl.get_table('v2')
        s = cl.get_table('s')
        verify(v1, v2, s)

    def test_multiple_snapshot_paths(self, test_client: pxt.Client) -> None:
        """Snapshots of the same view taken across schema and data changes of its base table."""
        cl = test_client
        t = create_test_tbl(cl)
        # both baselines are ordered by c2, matching the order_by(c2) used in every comparison below
        c4 = t.select(t.c4).order_by(t.c2).collect().to_pandas()['c4']
        orig_c3 = t.select(t.c3).order_by(t.c2).collect().to_pandas()['c3']
        v = cl.create_view('v', base=t, schema={'v1': t.c3 + 1})
        s1 = cl.create_view('s1', v, is_snapshot=True)
        t.drop_column('c4')
        # s2 references the same view version as s1, but a different version of t (due to a schema change)
        s2 = cl.create_view('s2', v, is_snapshot=True)
        t.update({'c6': {'a': 17}})
        # s3 references the same view version as s2, but a different version of t (due to a data change)
        s3 = cl.create_view('s3', v, is_snapshot=True)
        t.update({'c3': t.c3 + 1})
        # s4 references different versions of t and v
        s4 = cl.create_view('s4', v, is_snapshot=True)

        def validate(t: pxt.Table, v: pxt.Table, s1: pxt.Table, s2: pxt.Table, s3: pxt.Table, s4: pxt.Table) -> None:
            # c4 is only visible in s1
            assert np.all(s1.select(s1.c4).collect().to_pandas()['c4'] == c4)
            with pytest.raises(AttributeError):
                _ = t.select(t.c4).collect()
            with pytest.raises(AttributeError):
                _ = v.select(v.c4).collect()
            with pytest.raises(AttributeError):
                _ = s2.select(s2.c4).collect()
            with pytest.raises(AttributeError):
                _ = s3.select(s3.c4).collect()
            with pytest.raises(AttributeError):
                _ = s4.select(s4.c4).collect()

            # c3
            assert np.all(t.select(t.c3).order_by(t.c2).collect().to_pandas()['c3'] == orig_c3 + 1)
            assert np.all(s1.select(s1.c3).order_by(s1.c2).collect().to_pandas()['c3'] == orig_c3)
            assert np.all(s2.select(s2.c3).order_by(s2.c2).collect().to_pandas()['c3'] == orig_c3)
            assert np.all(s3.select(s3.c3).order_by(s3.c2).collect().to_pandas()['c3'] == orig_c3)
            assert np.all(s4.select(s4.c3).order_by(s4.c2).collect().to_pandas()['c3'] == orig_c3 + 1)

            # v1
            assert np.all(
                v.select(v.v1).order_by(v.c2).collect().to_pandas()['v1'] ==
                t.select(t.c3).order_by(t.c2).collect().to_pandas()['c3'] + 1)
            assert np.all(s1.select(s1.v1).order_by(s1.c2).collect().to_pandas()['v1'] == orig_c3 + 1)
            assert np.all(s2.select(s2.v1).order_by(s2.c2).collect().to_pandas()['v1'] == orig_c3 + 1)
            assert np.all(s3.select(s3.v1).order_by(s3.c2).collect().to_pandas()['v1'] == orig_c3 + 1)
            assert np.all(
                s4.select(s4.v1).order_by(s4.c2).collect().to_pandas()['v1'] ==
                t.select(t.c3).order_by(t.c2).collect().to_pandas()['c3'] + 1)

        validate(t, v, s1, s2, s3, s4)

        # make sure it works after metadata reload
        cl = pxt.Client(reload=True)
        t, v = cl.get_table('test_tbl'), cl.get_table('v')
        s1, s2, s3, s4 = cl.get_table('s1'), cl.get_table('s2'), cl.get_table('s3'), cl.get_table('s4')
        validate(t, v, s1, s2, s3, s4)