pixeltable 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable has been flagged as potentially problematic; see the package registry's advisory page for details.
- pixeltable/__init__.py +21 -4
- pixeltable/catalog/__init__.py +13 -0
- pixeltable/catalog/catalog.py +159 -0
- pixeltable/catalog/column.py +200 -0
- pixeltable/catalog/dir.py +32 -0
- pixeltable/catalog/globals.py +33 -0
- pixeltable/catalog/insertable_table.py +191 -0
- pixeltable/catalog/named_function.py +36 -0
- pixeltable/catalog/path.py +58 -0
- pixeltable/catalog/path_dict.py +139 -0
- pixeltable/catalog/schema_object.py +39 -0
- pixeltable/catalog/table.py +581 -0
- pixeltable/catalog/table_version.py +749 -0
- pixeltable/catalog/table_version_path.py +133 -0
- pixeltable/catalog/view.py +203 -0
- pixeltable/client.py +520 -31
- pixeltable/dataframe.py +540 -349
- pixeltable/env.py +373 -48
- pixeltable/exceptions.py +12 -21
- pixeltable/exec/__init__.py +9 -0
- pixeltable/exec/aggregation_node.py +78 -0
- pixeltable/exec/cache_prefetch_node.py +113 -0
- pixeltable/exec/component_iteration_node.py +79 -0
- pixeltable/exec/data_row_batch.py +95 -0
- pixeltable/exec/exec_context.py +22 -0
- pixeltable/exec/exec_node.py +61 -0
- pixeltable/exec/expr_eval_node.py +217 -0
- pixeltable/exec/in_memory_data_node.py +69 -0
- pixeltable/exec/media_validation_node.py +43 -0
- pixeltable/exec/sql_scan_node.py +225 -0
- pixeltable/exprs/__init__.py +24 -0
- pixeltable/exprs/arithmetic_expr.py +102 -0
- pixeltable/exprs/array_slice.py +71 -0
- pixeltable/exprs/column_property_ref.py +77 -0
- pixeltable/exprs/column_ref.py +105 -0
- pixeltable/exprs/comparison.py +77 -0
- pixeltable/exprs/compound_predicate.py +98 -0
- pixeltable/exprs/data_row.py +187 -0
- pixeltable/exprs/expr.py +586 -0
- pixeltable/exprs/expr_set.py +39 -0
- pixeltable/exprs/function_call.py +380 -0
- pixeltable/exprs/globals.py +69 -0
- pixeltable/exprs/image_member_access.py +115 -0
- pixeltable/exprs/image_similarity_predicate.py +58 -0
- pixeltable/exprs/inline_array.py +107 -0
- pixeltable/exprs/inline_dict.py +101 -0
- pixeltable/exprs/is_null.py +38 -0
- pixeltable/exprs/json_mapper.py +121 -0
- pixeltable/exprs/json_path.py +159 -0
- pixeltable/exprs/literal.py +54 -0
- pixeltable/exprs/object_ref.py +41 -0
- pixeltable/exprs/predicate.py +44 -0
- pixeltable/exprs/row_builder.py +355 -0
- pixeltable/exprs/rowid_ref.py +94 -0
- pixeltable/exprs/type_cast.py +53 -0
- pixeltable/exprs/variable.py +45 -0
- pixeltable/func/__init__.py +9 -0
- pixeltable/func/aggregate_function.py +194 -0
- pixeltable/func/batched_function.py +53 -0
- pixeltable/func/callable_function.py +69 -0
- pixeltable/func/expr_template_function.py +82 -0
- pixeltable/func/function.py +110 -0
- pixeltable/func/function_registry.py +227 -0
- pixeltable/func/globals.py +36 -0
- pixeltable/func/nos_function.py +202 -0
- pixeltable/func/signature.py +166 -0
- pixeltable/func/udf.py +163 -0
- pixeltable/functions/__init__.py +52 -103
- pixeltable/functions/eval.py +216 -0
- pixeltable/functions/fireworks.py +61 -0
- pixeltable/functions/huggingface.py +120 -0
- pixeltable/functions/image.py +16 -0
- pixeltable/functions/openai.py +88 -0
- pixeltable/functions/pil/image.py +148 -7
- pixeltable/functions/string.py +13 -0
- pixeltable/functions/together.py +27 -0
- pixeltable/functions/util.py +41 -0
- pixeltable/functions/video.py +62 -0
- pixeltable/iterators/__init__.py +3 -0
- pixeltable/iterators/base.py +48 -0
- pixeltable/iterators/document.py +311 -0
- pixeltable/iterators/video.py +89 -0
- pixeltable/metadata/__init__.py +54 -0
- pixeltable/metadata/converters/convert_10.py +18 -0
- pixeltable/metadata/schema.py +211 -0
- pixeltable/plan.py +656 -0
- pixeltable/store.py +413 -182
- pixeltable/tests/conftest.py +143 -86
- pixeltable/tests/test_audio.py +65 -0
- pixeltable/tests/test_catalog.py +27 -0
- pixeltable/tests/test_client.py +14 -14
- pixeltable/tests/test_component_view.py +372 -0
- pixeltable/tests/test_dataframe.py +433 -0
- pixeltable/tests/test_dirs.py +78 -62
- pixeltable/tests/test_document.py +117 -0
- pixeltable/tests/test_exprs.py +591 -135
- pixeltable/tests/test_function.py +297 -67
- pixeltable/tests/test_functions.py +283 -1
- pixeltable/tests/test_migration.py +43 -0
- pixeltable/tests/test_nos.py +54 -0
- pixeltable/tests/test_snapshot.py +208 -0
- pixeltable/tests/test_table.py +1086 -258
- pixeltable/tests/test_transactional_directory.py +42 -0
- pixeltable/tests/test_types.py +5 -11
- pixeltable/tests/test_video.py +149 -34
- pixeltable/tests/test_view.py +530 -0
- pixeltable/tests/utils.py +186 -45
- pixeltable/tool/create_test_db_dump.py +149 -0
- pixeltable/type_system.py +490 -133
- pixeltable/utils/__init__.py +17 -46
- pixeltable/utils/clip.py +12 -15
- pixeltable/utils/coco.py +136 -0
- pixeltable/utils/documents.py +39 -0
- pixeltable/utils/filecache.py +195 -0
- pixeltable/utils/help.py +11 -0
- pixeltable/utils/media_store.py +76 -0
- pixeltable/utils/parquet.py +126 -0
- pixeltable/utils/pytorch.py +172 -0
- pixeltable/utils/s3.py +13 -0
- pixeltable/utils/sql.py +17 -0
- pixeltable/utils/transactional_directory.py +35 -0
- pixeltable-0.2.0.dist-info/LICENSE +18 -0
- pixeltable-0.2.0.dist-info/METADATA +117 -0
- pixeltable-0.2.0.dist-info/RECORD +125 -0
- {pixeltable-0.1.2.dist-info → pixeltable-0.2.0.dist-info}/WHEEL +1 -1
- pixeltable/catalog.py +0 -1421
- pixeltable/exprs.py +0 -1745
- pixeltable/function.py +0 -269
- pixeltable/functions/clip.py +0 -10
- pixeltable/functions/pil/__init__.py +0 -23
- pixeltable/functions/tf.py +0 -21
- pixeltable/index.py +0 -57
- pixeltable/tests/test_dict.py +0 -24
- pixeltable/tests/test_tf.py +0 -69
- pixeltable/tf.py +0 -33
- pixeltable/utils/tf.py +0 -33
- pixeltable/utils/video.py +0 -32
- pixeltable-0.1.2.dist-info/LICENSE +0 -201
- pixeltable-0.1.2.dist-info/METADATA +0 -89
- pixeltable-0.1.2.dist-info/RECORD +0 -37
pixeltable/tests/test_exprs.py
CHANGED
|
@@ -1,45 +1,50 @@
|
|
|
1
|
-
import
|
|
1
|
+
import json
|
|
2
|
+
import urllib.parse
|
|
3
|
+
from typing import List, Dict
|
|
4
|
+
|
|
2
5
|
import pytest
|
|
6
|
+
import sqlalchemy as sql
|
|
3
7
|
|
|
8
|
+
import pixeltable as pxt
|
|
9
|
+
import pixeltable.func as func
|
|
4
10
|
from pixeltable import catalog
|
|
5
|
-
from pixeltable
|
|
6
|
-
from pixeltable
|
|
7
|
-
from pixeltable.exprs import Expr,
|
|
11
|
+
from pixeltable import exceptions as excs
|
|
12
|
+
from pixeltable import exprs
|
|
13
|
+
from pixeltable.exprs import Expr, ColumnRef
|
|
8
14
|
from pixeltable.exprs import RELATIVE_PATH_ROOT as R
|
|
9
|
-
from pixeltable.functions import
|
|
15
|
+
from pixeltable.functions import cast, sum, count
|
|
10
16
|
from pixeltable.functions.pil.image import blend
|
|
11
|
-
from pixeltable.
|
|
12
|
-
from pixeltable import
|
|
13
|
-
from pixeltable.
|
|
17
|
+
from pixeltable.iterators import FrameIterator
|
|
18
|
+
from pixeltable.tests.utils import get_image_files, skip_test_if_not_installed
|
|
19
|
+
from pixeltable.type_system import StringType, BoolType, IntType, ArrayType, ColumnType, FloatType, \
|
|
20
|
+
VideoType
|
|
14
21
|
|
|
15
22
|
|
|
16
23
|
class TestExprs:
|
|
17
|
-
# This breaks with exception 'cannot pickle _thread._local obj'
|
|
18
|
-
# sum = Function(
|
|
19
|
-
# IntType(), [IntType()],
|
|
20
|
-
# init_fn=lambda: TestExprs.SumAggregator(), update_fn=SumAggregator.update, value_fn=SumAggregator.value)
|
|
21
|
-
|
|
22
24
|
def test_basic(self, test_tbl: catalog.Table) -> None:
|
|
23
25
|
t = test_tbl
|
|
24
|
-
assert
|
|
26
|
+
assert t['c1'].equals(t.c1)
|
|
27
|
+
assert t['c7']['*'].f5.equals(t.c7['*'].f5)
|
|
28
|
+
|
|
29
|
+
assert isinstance(t.c1 == None, Expr)
|
|
25
30
|
assert isinstance(t.c1 < 'a', Expr)
|
|
26
|
-
assert isinstance(t['c1'] <= 'a', Expr)
|
|
27
31
|
assert isinstance(t.c1 <= 'a', Expr)
|
|
28
|
-
assert isinstance(t['c1'] == 'a', Expr)
|
|
29
32
|
assert isinstance(t.c1 == 'a', Expr)
|
|
30
|
-
assert isinstance(t['c1'] != 'a', Expr)
|
|
31
33
|
assert isinstance(t.c1 != 'a', Expr)
|
|
32
|
-
assert isinstance(t['c1'] > 'a', Expr)
|
|
33
34
|
assert isinstance(t.c1 > 'a', Expr)
|
|
34
|
-
assert isinstance(t['c1'] >= 'a', Expr)
|
|
35
35
|
assert isinstance(t.c1 >= 'a', Expr)
|
|
36
36
|
assert isinstance((t.c1 == 'a') & (t.c2 < 5), Expr)
|
|
37
37
|
assert isinstance((t.c1 == 'a') | (t.c2 < 5), Expr)
|
|
38
38
|
assert isinstance(~(t.c1 == 'a'), Expr)
|
|
39
|
+
with pytest.raises(AttributeError) as excinfo:
|
|
40
|
+
_ = t.does_not_exist
|
|
41
|
+
assert 'unknown' in str(excinfo.value).lower()
|
|
39
42
|
|
|
40
43
|
def test_compound_predicates(self, test_tbl: catalog.Table) -> None:
|
|
41
44
|
t = test_tbl
|
|
42
45
|
# compound predicates that can be fully evaluated in SQL
|
|
46
|
+
_ = t.where((t.c1 == 'test string') & (t.c6.f1 > 50)).collect()
|
|
47
|
+
_ = t.where((t.c1 == 'test string') & (t.c2 > 50)).collect()
|
|
43
48
|
e = ((t.c1 == 'test string') & (t.c2 > 50)).sql_expr()
|
|
44
49
|
assert len(e.clauses) == 2
|
|
45
50
|
|
|
@@ -52,46 +57,192 @@ class TestExprs:
|
|
|
52
57
|
e = (~(t.c1 == 'test string')).sql_expr()
|
|
53
58
|
assert isinstance(e, sql.sql.expression.BinaryExpression)
|
|
54
59
|
|
|
60
|
+
with pytest.raises(TypeError) as exc_info:
|
|
61
|
+
_ = t.where((t.c1 == 'test string') or (t.c6.f1 > 50)).collect()
|
|
62
|
+
assert 'cannot be used in conjunction with python boolean operators' in str(exc_info.value).lower()
|
|
63
|
+
|
|
55
64
|
# compound predicates with Python functions
|
|
56
|
-
udf
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
sql_pred,
|
|
69
|
-
assert
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
assert
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
#
|
|
79
|
-
|
|
80
|
-
assert
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
65
|
+
@pxt.udf(return_type=BoolType(), param_types=[StringType()])
|
|
66
|
+
def udf(_: str) -> bool:
|
|
67
|
+
return True
|
|
68
|
+
@pxt.udf(return_type=BoolType(), param_types=[IntType()])
|
|
69
|
+
def udf2(_: int) -> bool:
|
|
70
|
+
return True
|
|
71
|
+
|
|
72
|
+
# TODO: find a way to test this
|
|
73
|
+
# # & can be split
|
|
74
|
+
# p = (t.c1 == 'test string') & udf(t.c1)
|
|
75
|
+
# assert p.sql_expr() is None
|
|
76
|
+
# sql_pred, other_pred = p.extract_sql_predicate()
|
|
77
|
+
# assert isinstance(sql_pred, sql.sql.expression.BinaryExpression)
|
|
78
|
+
# assert isinstance(other_pred, FunctionCall)
|
|
79
|
+
#
|
|
80
|
+
# p = (t.c1 == 'test string') & udf(t.c1) & (t.c2 > 50)
|
|
81
|
+
# assert p.sql_expr() is None
|
|
82
|
+
# sql_pred, other_pred = p.extract_sql_predicate()
|
|
83
|
+
# assert len(sql_pred.clauses) == 2
|
|
84
|
+
# assert isinstance(other_pred, FunctionCall)
|
|
85
|
+
#
|
|
86
|
+
# p = (t.c1 == 'test string') & udf(t.c1) & (t.c2 > 50) & udf2(t.c2)
|
|
87
|
+
# assert p.sql_expr() is None
|
|
88
|
+
# sql_pred, other_pred = p.extract_sql_predicate()
|
|
89
|
+
# assert len(sql_pred.clauses) == 2
|
|
90
|
+
# assert isinstance(other_pred, CompoundPredicate)
|
|
91
|
+
#
|
|
92
|
+
# # | cannot be split
|
|
93
|
+
# p = (t.c1 == 'test string') | udf(t.c1)
|
|
94
|
+
# assert p.sql_expr() is None
|
|
95
|
+
# sql_pred, other_pred = p.extract_sql_predicate()
|
|
96
|
+
# assert sql_pred is None
|
|
97
|
+
# assert isinstance(other_pred, CompoundPredicate)
|
|
98
|
+
|
|
99
|
+
def test_filters(self, test_tbl: catalog.Table) -> None:
|
|
86
100
|
t = test_tbl
|
|
87
101
|
_ = t[t.c1 == 'test string'].show()
|
|
88
102
|
print(_)
|
|
89
103
|
_ = t[t.c2 > 50].show()
|
|
90
104
|
print(_)
|
|
105
|
+
_ = t[t.c1n == None].show()
|
|
106
|
+
print(_)
|
|
107
|
+
_ = t[t.c1n != None].show(0)
|
|
108
|
+
print(_)
|
|
109
|
+
|
|
110
|
+
def test_exception_handling(self, test_tbl: catalog.Table) -> None:
|
|
111
|
+
t = test_tbl
|
|
112
|
+
|
|
113
|
+
# error in expr that's handled in SQL
|
|
114
|
+
with pytest.raises(excs.Error):
|
|
115
|
+
_ = t[(t.c2 + 1) / t.c2].show()
|
|
116
|
+
|
|
117
|
+
# error in expr that's handled in Python
|
|
118
|
+
with pytest.raises(excs.Error):
|
|
119
|
+
_ = t[(t.c6.f2 + 1) / (t.c2 - 10)].show()
|
|
120
|
+
|
|
121
|
+
# the same, but with an inline function
|
|
122
|
+
@pxt.udf(return_type=FloatType(), param_types=[IntType(), IntType()])
|
|
123
|
+
def f(a: int, b: int) -> float:
|
|
124
|
+
return a / b
|
|
125
|
+
with pytest.raises(excs.Error):
|
|
126
|
+
_ = t[f(t.c2 + 1, t.c2)].show()
|
|
127
|
+
|
|
128
|
+
# error in agg.init()
|
|
129
|
+
@pxt.uda(update_types=[IntType()], value_type=IntType(), name='agg')
|
|
130
|
+
class Aggregator(pxt.Aggregator):
|
|
131
|
+
def __init__(self):
|
|
132
|
+
self.sum = 1 / 0
|
|
133
|
+
def update(self, val):
|
|
134
|
+
pass
|
|
135
|
+
def value(self):
|
|
136
|
+
return 1
|
|
137
|
+
with pytest.raises(excs.Error):
|
|
138
|
+
_ = t[agg(t.c2)].show()
|
|
139
|
+
|
|
140
|
+
# error in agg.update()
|
|
141
|
+
@pxt.uda(update_types=[IntType()], value_type=IntType(), name='agg')
|
|
142
|
+
class Aggregator(pxt.Aggregator):
|
|
143
|
+
def __init__(self):
|
|
144
|
+
self.sum = 0
|
|
145
|
+
def update(self, val):
|
|
146
|
+
self.sum += 1 / val
|
|
147
|
+
def value(self):
|
|
148
|
+
return 1
|
|
149
|
+
with pytest.raises(excs.Error):
|
|
150
|
+
_ = t[agg(t.c2 - 10)].show()
|
|
151
|
+
|
|
152
|
+
# error in agg.value()
|
|
153
|
+
@pxt.uda(update_types=[IntType()], value_type=IntType(), name='agg')
|
|
154
|
+
class Aggregator(pxt.Aggregator):
|
|
155
|
+
def __init__(self):
|
|
156
|
+
self.sum = 0
|
|
157
|
+
def update(self, val):
|
|
158
|
+
self.sum += val
|
|
159
|
+
def value(self):
|
|
160
|
+
return 1 / self.sum
|
|
161
|
+
with pytest.raises(excs.Error):
|
|
162
|
+
_ = t[t.c2 <= 2][agg(t.c2 - 1)].show()
|
|
163
|
+
|
|
164
|
+
def test_props(self, test_tbl: catalog.Table, img_tbl: catalog.Table) -> None:
|
|
165
|
+
t = test_tbl
|
|
166
|
+
# errortype/-msg for computed column
|
|
167
|
+
res = t.select(error=t.c8.errortype).collect()
|
|
168
|
+
assert res.to_pandas()['error'].isna().all()
|
|
169
|
+
res = t.select(error=t.c8.errormsg).collect()
|
|
170
|
+
assert res.to_pandas()['error'].isna().all()
|
|
171
|
+
|
|
172
|
+
img_t = img_tbl
|
|
173
|
+
# fileurl
|
|
174
|
+
res = img_t.select(img_t.img.fileurl).show(0).to_pandas()
|
|
175
|
+
stored_urls = set(res.iloc[:, 0])
|
|
176
|
+
assert len(stored_urls) == len(res)
|
|
177
|
+
all_urls = set([urllib.parse.urljoin('file:', path) for path in get_image_files()])
|
|
178
|
+
assert stored_urls <= all_urls
|
|
179
|
+
|
|
180
|
+
# localpath
|
|
181
|
+
res = img_t.select(img_t.img.localpath).show(0).to_pandas()
|
|
182
|
+
stored_paths = set(res.iloc[:, 0])
|
|
183
|
+
assert len(stored_paths) == len(res)
|
|
184
|
+
all_paths = set(get_image_files())
|
|
185
|
+
assert stored_paths <= all_paths
|
|
186
|
+
|
|
187
|
+
# errortype/-msg for image column
|
|
188
|
+
res = img_t.select(error=img_t.img.errortype).collect().to_pandas()
|
|
189
|
+
assert res['error'].isna().all()
|
|
190
|
+
res = img_t.select(error=img_t.img.errormsg).collect().to_pandas()
|
|
191
|
+
assert res['error'].isna().all()
|
|
192
|
+
|
|
193
|
+
for c in [t.c1, t.c1n, t.c2, t.c3, t.c4, t.c5, t.c6, t.c7]:
|
|
194
|
+
# errortype/errormsg only applies to stored computed and media columns
|
|
195
|
+
with pytest.raises(excs.Error) as excinfo:
|
|
196
|
+
_ = t.select(c.errortype).show()
|
|
197
|
+
assert 'only valid for' in str(excinfo.value)
|
|
198
|
+
with pytest.raises(excs.Error) as excinfo:
|
|
199
|
+
_ = t.select(c.errormsg).show()
|
|
200
|
+
assert 'only valid for' in str(excinfo.value)
|
|
201
|
+
|
|
202
|
+
# fileurl/localpath only applies to media columns
|
|
203
|
+
with pytest.raises(excs.Error) as excinfo:
|
|
204
|
+
_ = t.select(t.c1.fileurl).show()
|
|
205
|
+
assert 'only valid for' in str(excinfo.value)
|
|
206
|
+
with pytest.raises(excs.Error) as excinfo:
|
|
207
|
+
_ = t.select(t.c1.localpath).show()
|
|
208
|
+
assert 'only valid for' in str(excinfo.value)
|
|
209
|
+
|
|
210
|
+
# fileurl/localpath doesn't apply to unstored computed img columns
|
|
211
|
+
img_t.add_column(c9=img_t.img.rotate(30))
|
|
212
|
+
with pytest.raises(excs.Error) as excinfo:
|
|
213
|
+
_ = img_t.select(img_t.c9.localpath).show()
|
|
214
|
+
assert 'computed unstored' in str(excinfo.value)
|
|
215
|
+
|
|
216
|
+
def test_null_args(self, test_client: pxt.Client) -> None:
|
|
217
|
+
# create table with two int columns
|
|
218
|
+
schema = {'c1': FloatType(nullable=True), 'c2': FloatType(nullable=True)}
|
|
219
|
+
t = test_client.create_table('test', schema)
|
|
220
|
+
|
|
221
|
+
# computed column that doesn't allow nulls
|
|
222
|
+
t.add_column(c3=lambda c1, c2: c1 + c2, type=FloatType(nullable=False))
|
|
223
|
+
# function that does allow nulls
|
|
224
|
+
@pxt.udf(return_type=FloatType(nullable=True),
|
|
225
|
+
param_types=[FloatType(nullable=False), FloatType(nullable=True)])
|
|
226
|
+
def f(a: int, b: int) -> int:
|
|
227
|
+
if b is None:
|
|
228
|
+
return a
|
|
229
|
+
return a + b
|
|
230
|
+
t.add_column(c4=f(t.c1, t.c2))
|
|
231
|
+
|
|
232
|
+
# data that tests all combinations of nulls
|
|
233
|
+
data = [{'c1': 1.0, 'c2': 1.0}, {'c1': 1.0, 'c2': None}, {'c1': None, 'c2': 1.0}, {'c1': None, 'c2': None}]
|
|
234
|
+
status = t.insert(data, fail_on_exception=False)
|
|
235
|
+
assert status.num_rows == len(data)
|
|
236
|
+
assert status.num_excs == len(data) - 1
|
|
237
|
+
result = t.select(t.c3, t.c4).collect()
|
|
238
|
+
assert result['c3'] == [2.0, None, None, None]
|
|
239
|
+
assert result['c4'] == [2.0, 1.0, None, None]
|
|
91
240
|
|
|
92
241
|
def test_arithmetic_exprs(self, test_tbl: catalog.Table) -> None:
|
|
93
242
|
t = test_tbl
|
|
94
243
|
|
|
244
|
+
_ = t[t.c2, t.c6.f3, t.c2 + t.c6.f3, (t.c2 + t.c6.f3) / (t.c6.f3 + 1)].show()
|
|
245
|
+
_ = t[t.c2 + t.c2].show()
|
|
95
246
|
for op1, op2 in [(t.c2, t.c2), (t.c3, t.c3)]:
|
|
96
247
|
_ = t[op1 + op2].show()
|
|
97
248
|
_ = t[op1 - op2].show()
|
|
@@ -103,13 +254,13 @@ class TestExprs:
|
|
|
103
254
|
(t.c1, t.c2), (t.c1, 1), (t.c2, t.c1), (t.c2, 'a'),
|
|
104
255
|
(t.c1, t.c3), (t.c1, 1.0), (t.c3, t.c1), (t.c3, 'a')
|
|
105
256
|
]:
|
|
106
|
-
with pytest.raises(
|
|
257
|
+
with pytest.raises(excs.Error):
|
|
107
258
|
_ = t[op1 + op2]
|
|
108
|
-
with pytest.raises(
|
|
259
|
+
with pytest.raises(excs.Error):
|
|
109
260
|
_ = t[op1 - op2]
|
|
110
|
-
with pytest.raises(
|
|
261
|
+
with pytest.raises(excs.Error):
|
|
111
262
|
_ = t[op1 * op2]
|
|
112
|
-
with pytest.raises(
|
|
263
|
+
with pytest.raises(excs.Error):
|
|
113
264
|
_ = t[op1 / op2]
|
|
114
265
|
|
|
115
266
|
# TODO: test division; requires predicate
|
|
@@ -117,16 +268,18 @@ class TestExprs:
|
|
|
117
268
|
_ = t[op1 + op2].show()
|
|
118
269
|
_ = t[op1 - op2].show()
|
|
119
270
|
_ = t[op1 * op2].show()
|
|
271
|
+
with pytest.raises(excs.Error):
|
|
272
|
+
_ = t[op1 / op2].show()
|
|
120
273
|
|
|
121
274
|
for op1, op2 in [
|
|
122
275
|
(t.c6.f1, t.c6.f2), (t.c6.f1, t.c6.f3), (t.c6.f1, 1), (t.c6.f1, 1.0),
|
|
123
276
|
(t.c6.f2, t.c6.f1), (t.c6.f3, t.c6.f1), (t.c6.f2, 'a'), (t.c6.f3, 'a'),
|
|
124
277
|
]:
|
|
125
|
-
with pytest.raises(
|
|
278
|
+
with pytest.raises(excs.Error):
|
|
126
279
|
_ = t[op1 + op2].show()
|
|
127
|
-
with pytest.raises(
|
|
280
|
+
with pytest.raises(excs.Error):
|
|
128
281
|
_ = t[op1 - op2].show()
|
|
129
|
-
with pytest.raises(
|
|
282
|
+
with pytest.raises(excs.Error):
|
|
130
283
|
_ = t[op1 * op2].show()
|
|
131
284
|
|
|
132
285
|
|
|
@@ -138,8 +291,8 @@ class TestExprs:
|
|
|
138
291
|
|
|
139
292
|
def test_inline_array(self, test_tbl: catalog.Table) -> None:
|
|
140
293
|
t = test_tbl
|
|
141
|
-
result = t[[
|
|
142
|
-
t = result.
|
|
294
|
+
result = t.select([[t.c2, 1], [t.c2, 2]]).show()
|
|
295
|
+
t = result.column_types()[0]
|
|
143
296
|
assert t.is_array_type()
|
|
144
297
|
assert isinstance(t, ArrayType)
|
|
145
298
|
assert t.shape == (2, 2)
|
|
@@ -167,41 +320,164 @@ class TestExprs:
|
|
|
167
320
|
_ = t[t.c6.f1]
|
|
168
321
|
_ = _.show()
|
|
169
322
|
print(_)
|
|
323
|
+
# predicate on dict field
|
|
324
|
+
_ = t[t.c6.f2 < 2].show()
|
|
170
325
|
#_ = t[t.c6.f2].show()
|
|
171
326
|
#_ = t[t.c6.f5].show()
|
|
172
327
|
_ = t[t.c6.f6.f8].show()
|
|
173
|
-
_ = t[cast(t.c6.f6.f8, ArrayType((4,),
|
|
328
|
+
_ = t[cast(t.c6.f6.f8, ArrayType((4,), FloatType()))].show()
|
|
174
329
|
|
|
175
330
|
# top-level is array
|
|
176
331
|
#_ = t[t.c7['*'].f1].show()
|
|
177
332
|
#_ = t[t.c7['*'].f2].show()
|
|
178
333
|
#_ = t[t.c7['*'].f5].show()
|
|
179
334
|
_ = t[t.c7['*'].f6.f8].show()
|
|
180
|
-
_ = t[
|
|
335
|
+
_ = t[t.c7[0].f6.f8].show()
|
|
336
|
+
_ = t[t.c7[:2].f6.f8].show()
|
|
337
|
+
_ = t[t.c7[::-1].f6.f8].show()
|
|
338
|
+
_ = t[cast(t.c7['*'].f6.f8, ArrayType((2, 4), FloatType()))].show()
|
|
181
339
|
print(_)
|
|
182
340
|
|
|
183
341
|
def test_arrays(self, test_tbl: catalog.Table) -> None:
|
|
184
342
|
t = test_tbl
|
|
185
|
-
t.add_column(
|
|
343
|
+
t.add_column(array_col=[[t.c2, 1], [1, t.c2]])
|
|
186
344
|
_ = t[t.array_col].show()
|
|
187
345
|
print(_)
|
|
188
346
|
_ = t[t.array_col[:, 0]].show()
|
|
189
347
|
print(_)
|
|
190
348
|
|
|
349
|
+
def test_astype(self, test_tbl: catalog.Table) -> None:
|
|
350
|
+
t = test_tbl
|
|
351
|
+
# Convert int to float
|
|
352
|
+
status = t.add_column(c2_as_float=t.c2.astype(FloatType()))
|
|
353
|
+
assert status.num_excs == 0
|
|
354
|
+
data = t.select(t.c2, t.c2_as_float).collect()
|
|
355
|
+
for row in data:
|
|
356
|
+
assert isinstance(row['c2'], int)
|
|
357
|
+
assert isinstance(row['c2_as_float'], float)
|
|
358
|
+
assert row['c2'] == row['c2_as_float']
|
|
359
|
+
# Compound expression
|
|
360
|
+
status = t.add_column(compound_as_float=(t.c2 + 1).astype(FloatType()))
|
|
361
|
+
assert status.num_excs == 0
|
|
362
|
+
data = t.select(t.c2, t.compound_as_float).collect()
|
|
363
|
+
for row in data:
|
|
364
|
+
assert isinstance(row['compound_as_float'], float)
|
|
365
|
+
assert row['c2'] + 1 == row['compound_as_float']
|
|
366
|
+
# Type conversion error
|
|
367
|
+
status = t.add_column(c2_as_string=t.c2.astype(StringType()))
|
|
368
|
+
assert status.num_excs == t.count()
|
|
369
|
+
|
|
370
|
+
def test_apply(self, test_tbl: catalog.Table) -> None:
|
|
371
|
+
|
|
372
|
+
t = test_tbl
|
|
373
|
+
|
|
374
|
+
# For each column c1, ..., c5, we create a new column ci_as_str that converts it to
|
|
375
|
+
# a string, then check that each row is correctly converted
|
|
376
|
+
# (For c1 this is the no-op string-to-string conversion)
|
|
377
|
+
for col_id in range(1, 6):
|
|
378
|
+
col_name = f'c{col_id}'
|
|
379
|
+
str_col_name = f'c{col_id}_str'
|
|
380
|
+
status = t.add_column(**{str_col_name: t[col_name].apply(str)})
|
|
381
|
+
assert status.num_excs == 0
|
|
382
|
+
data = t.select(t[col_name], t[str_col_name]).collect()
|
|
383
|
+
for row in data:
|
|
384
|
+
assert row[str_col_name] == str(row[col_name])
|
|
385
|
+
|
|
386
|
+
# Test a compound expression with apply
|
|
387
|
+
status = t.add_column(c2_plus_1_str=(t.c2 + 1).apply(str))
|
|
388
|
+
assert status.num_excs == 0
|
|
389
|
+
data = t.select(t.c2, t.c2_plus_1_str).collect()
|
|
390
|
+
for row in data:
|
|
391
|
+
assert row['c2_plus_1_str'] == str(row['c2'] + 1)
|
|
392
|
+
|
|
393
|
+
# For columns c6, c7, try using json.dumps and json.loads to emit and parse JSON <-> str
|
|
394
|
+
for col_id in range(6, 8):
|
|
395
|
+
col_name = f'c{col_id}'
|
|
396
|
+
str_col_name = f'c{col_id}_str'
|
|
397
|
+
back_to_json_col_name = f'c{col_id}_back_to_json'
|
|
398
|
+
status = t.add_column(**{str_col_name: t[col_name].apply(json.dumps)})
|
|
399
|
+
assert status.num_excs == 0
|
|
400
|
+
status = t.add_column(**{back_to_json_col_name: t[str_col_name].apply(json.loads)})
|
|
401
|
+
assert status.num_excs == 0
|
|
402
|
+
data = t.select(t[col_name], t[str_col_name], t[back_to_json_col_name]).collect()
|
|
403
|
+
for row in data:
|
|
404
|
+
assert row[str_col_name] == json.dumps(row[col_name])
|
|
405
|
+
assert row[back_to_json_col_name] == row[col_name]
|
|
406
|
+
|
|
407
|
+
def f1(x):
|
|
408
|
+
return str(x)
|
|
409
|
+
|
|
410
|
+
# Now test that a function without a return type throws an exception ...
|
|
411
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
412
|
+
t.c2.apply(f1)
|
|
413
|
+
assert 'Column type of `f1` cannot be inferred.' in str(exc_info.value)
|
|
414
|
+
|
|
415
|
+
# ... but works if the type is specified explicitly.
|
|
416
|
+
status = t.add_column(c2_str_f1=t.c2.apply(f1, col_type=StringType()))
|
|
417
|
+
assert status.num_excs == 0
|
|
418
|
+
|
|
419
|
+
# Test that the return type of a function can be successfully inferred.
|
|
420
|
+
def f2(x) -> str:
|
|
421
|
+
return str(x)
|
|
422
|
+
|
|
423
|
+
status = t.add_column(c2_str_f2=t.c2.apply(f2))
|
|
424
|
+
assert status.num_excs == 0
|
|
425
|
+
|
|
426
|
+
# Test various validation failures.
|
|
427
|
+
|
|
428
|
+
def f3(x, y) -> str:
|
|
429
|
+
return f'{x}{y}'
|
|
430
|
+
|
|
431
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
432
|
+
t.c2.apply(f3) # Too many required parameters
|
|
433
|
+
assert str(exc_info.value) == 'Function `f3` has multiple required parameters.'
|
|
434
|
+
|
|
435
|
+
def f4() -> str:
|
|
436
|
+
return "pixeltable"
|
|
437
|
+
|
|
438
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
439
|
+
t.c2.apply(f4) # No positional parameters
|
|
440
|
+
assert str(exc_info.value) == 'Function `f4` has no positional parameters.'
|
|
441
|
+
|
|
442
|
+
def f5(**kwargs) -> str:
|
|
443
|
+
return ""
|
|
444
|
+
|
|
445
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
446
|
+
t.c2.apply(f5) # No positional parameters
|
|
447
|
+
assert str(exc_info.value) == 'Function `f5` has no positional parameters.'
|
|
448
|
+
|
|
449
|
+
# Ensure these varargs signatures are acceptable
|
|
450
|
+
|
|
451
|
+
def f6(x, **kwargs) -> str:
|
|
452
|
+
return x
|
|
453
|
+
|
|
454
|
+
t.c2.apply(f6)
|
|
455
|
+
|
|
456
|
+
def f7(x, *args) -> str:
|
|
457
|
+
return x
|
|
458
|
+
|
|
459
|
+
t.c2.apply(f7)
|
|
460
|
+
|
|
461
|
+
def f8(*args) -> str:
|
|
462
|
+
return ''
|
|
463
|
+
|
|
464
|
+
t.c2.apply(f8)
|
|
465
|
+
|
|
191
466
|
def test_select_list(self, img_tbl) -> None:
|
|
192
467
|
t = img_tbl
|
|
193
468
|
result = t[t.img].show(n=100)
|
|
194
469
|
_ = result._repr_html_()
|
|
195
|
-
df = t[t.img, udf_call(lambda img: img.rotate(60), ImageType(), tbl=t)]
|
|
196
|
-
_ = df.show(n=100)._repr_html_()
|
|
197
470
|
df = t[[t.img, t.img.rotate(60)]]
|
|
198
471
|
_ = df.show(n=100)._repr_html_()
|
|
199
472
|
|
|
200
|
-
with pytest.raises(
|
|
473
|
+
with pytest.raises(excs.Error):
|
|
201
474
|
_ = t[t.img.rotate]
|
|
202
475
|
|
|
203
476
|
def test_img_members(self, img_tbl) -> None:
|
|
204
477
|
t = img_tbl
|
|
478
|
+
# make sure the limit is applied in Python, not in the SELECT
|
|
479
|
+
result = t[t.img.height > 200][t.img].show(n=3)
|
|
480
|
+
assert len(result) == 3
|
|
205
481
|
result = t[t.img.crop((10, 10, 60, 60))].show(n=100)
|
|
206
482
|
result = t[t.img.crop((10, 10, 60, 60)).resize((100, 100))].show(n=100)
|
|
207
483
|
result = t[t.img.crop((10, 10, 60, 60)).resize((100, 100)).convert('L')].show(n=100)
|
|
@@ -210,20 +486,21 @@ class TestExprs:
|
|
|
210
486
|
_ = result._repr_html_()
|
|
211
487
|
|
|
212
488
|
def test_img_functions(self, img_tbl) -> None:
|
|
489
|
+
skip_test_if_not_installed('nos')
|
|
213
490
|
t = img_tbl
|
|
491
|
+
from pixeltable.functions.pil.image import resize
|
|
492
|
+
result = t[t.img.resize((224, 224))].show(0)
|
|
493
|
+
result = t[resize(t.img, (224, 224))].show(0)
|
|
214
494
|
result = t[blend(t.img, t.img.rotate(90), 0.5)].show(100)
|
|
215
495
|
print(result)
|
|
216
|
-
|
|
496
|
+
from pixeltable.functions.nos.image_embedding import openai_clip
|
|
497
|
+
result = t[openai_clip(t.img.resize((224, 224)))].show(10)
|
|
217
498
|
print(result)
|
|
218
499
|
_ = result._repr_html_()
|
|
219
500
|
_ = t.img.entropy() > 1
|
|
220
|
-
_ = _.extract_sql_predicate()
|
|
221
501
|
_ = (t.img.entropy() > 1) & (t.split == 'train')
|
|
222
|
-
_ = _.extract_sql_predicate()
|
|
223
502
|
_ = (t.img.entropy() > 1) & (t.split == 'train') & (t.split == 'val')
|
|
224
|
-
_ = _.extract_sql_predicate()
|
|
225
503
|
_ = (t.split == 'train') & (t.img.entropy() > 1) & (t.split == 'val') & (t.img.entropy() < 0)
|
|
226
|
-
_ = _.extract_sql_predicate()
|
|
227
504
|
_ = t[(t.split == 'train') & (t.category == 'n03445777')][t.img].show()
|
|
228
505
|
print(_)
|
|
229
506
|
result = t[t.img.width > 1].show()
|
|
@@ -235,32 +512,33 @@ class TestExprs:
|
|
|
235
512
|
][t.img, t.split].show()
|
|
236
513
|
print(result)
|
|
237
514
|
|
|
238
|
-
def test_categoricals_map(self, img_tbl) -> None:
|
|
239
|
-
t = img_tbl
|
|
240
|
-
m = t[t.category].categorical_map()
|
|
241
|
-
_ = t[dict_map(t.category, m)].show()
|
|
242
|
-
print(_)
|
|
243
|
-
|
|
244
515
|
def test_similarity(self, indexed_img_tbl: catalog.Table) -> None:
|
|
516
|
+
skip_test_if_not_installed('nos')
|
|
245
517
|
t = indexed_img_tbl
|
|
246
518
|
_ = t.show(30)
|
|
247
|
-
probe = t
|
|
519
|
+
probe = t.select(t.img, t.category).show(1)
|
|
248
520
|
img = probe[0, 0]
|
|
249
|
-
result = t
|
|
521
|
+
result = t.where(t.img.nearest(img)).show(10)
|
|
250
522
|
assert len(result) == 10
|
|
251
523
|
# nearest() with one SQL predicate and one Python predicate
|
|
252
524
|
result = t[t.img.nearest(img) & (t.category == probe[0, 1]) & (t.img.width > 1)].show(10)
|
|
253
525
|
# TODO: figure out how to verify results
|
|
254
|
-
#assert len(result) == 3
|
|
255
526
|
|
|
256
|
-
|
|
527
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
528
|
+
_ = t[t.img.nearest(img)].order_by(t.category).show()
|
|
529
|
+
assert 'cannot be used in conjunction with' in str(exc_info.value)
|
|
530
|
+
|
|
531
|
+
result = t[t.img.nearest('musical instrument')].show(10)
|
|
257
532
|
assert len(result) == 10
|
|
258
533
|
# matches() with one SQL predicate and one Python predicate
|
|
259
534
|
french_horn_category = 'n03394916'
|
|
260
535
|
result = t[
|
|
261
|
-
t.img.
|
|
536
|
+
t.img.nearest('musical instrument') & (t.category == french_horn_category) & (t.img.width > 1)
|
|
262
537
|
].show(10)
|
|
263
|
-
|
|
538
|
+
|
|
539
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
540
|
+
_ = t[t.img.nearest(5)].show()
|
|
541
|
+
assert 'requires' in str(exc_info.value)
|
|
264
542
|
|
|
265
543
|
# TODO: this doesn't work when combined with test_similarity(), for some reason the data table for img_tbl
|
|
266
544
|
# doesn't get created; why?
|
|
@@ -269,48 +547,47 @@ class TestExprs:
|
|
|
269
547
|
probe = t[t.img].show(1)
|
|
270
548
|
img = probe[0, 0]
|
|
271
549
|
|
|
272
|
-
with pytest.raises(
|
|
550
|
+
with pytest.raises(excs.Error):
|
|
273
551
|
_ = t[t.img.nearest(img)].show(10)
|
|
274
|
-
with pytest.raises(
|
|
275
|
-
_ = t[t.img.
|
|
276
|
-
|
|
277
|
-
def
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
(t.c2 > 5) | (t.c1 == 'test'),
|
|
296
|
-
t.c7['*'].f5 >> [R[3], R[2], R[1], R[0]],
|
|
297
|
-
t.c8[0, 1:],
|
|
298
|
-
utils.sum_uda(t.c2).window(partition_by=t.c4, order_by=t.c3),
|
|
299
|
-
]
|
|
300
|
-
for e in test_exprs:
|
|
552
|
+
with pytest.raises(excs.Error):
|
|
553
|
+
_ = t[t.img.nearest('musical instrument')].show(10)
|
|
554
|
+
|
|
555
|
+
def test_ids(
|
|
556
|
+
self, test_tbl: catalog.Table, test_tbl_exprs: List[exprs.Expr],
|
|
557
|
+
img_tbl: catalog.Table, img_tbl_exprs: List[exprs.Expr]
|
|
558
|
+
) -> None:
|
|
559
|
+
d: Dict[int, exprs.Expr] = {}
|
|
560
|
+
for e in test_tbl_exprs:
|
|
561
|
+
assert e.id is not None
|
|
562
|
+
d[e.id] = e
|
|
563
|
+
for e in img_tbl_exprs:
|
|
564
|
+
assert e.id is not None
|
|
565
|
+
d[e.id] = e
|
|
566
|
+
assert len(d) == len(test_tbl_exprs) + len(img_tbl_exprs)
|
|
567
|
+
|
|
568
|
+
def test_serialization(
|
|
569
|
+
self, test_tbl_exprs: List[exprs.Expr], img_tbl_exprs: List[exprs.Expr]
|
|
570
|
+
) -> None:
|
|
571
|
+
"""Test as_dict()/from_dict() (via serialize()/deserialize()) for all exprs."""
|
|
572
|
+
for e in test_tbl_exprs:
|
|
301
573
|
e_serialized = e.serialize()
|
|
302
|
-
e_deserialized = Expr.deserialize(e_serialized
|
|
574
|
+
e_deserialized = Expr.deserialize(e_serialized)
|
|
303
575
|
assert e.equals(e_deserialized)
|
|
304
576
|
|
|
305
|
-
|
|
306
|
-
img_t.img.width,
|
|
307
|
-
img_t.img.rotate(90),
|
|
308
|
-
]
|
|
309
|
-
for e in img_test_exprs:
|
|
577
|
+
for e in img_tbl_exprs:
|
|
310
578
|
e_serialized = e.serialize()
|
|
311
|
-
e_deserialized = Expr.deserialize(e_serialized
|
|
579
|
+
e_deserialized = Expr.deserialize(e_serialized)
|
|
312
580
|
assert e.equals(e_deserialized)
|
|
313
581
|
|
|
582
|
+
def test_print(self, test_tbl_exprs: List[exprs.Expr], img_tbl_exprs: List[exprs.Expr]) -> None:
|
|
583
|
+
_ = func.FunctionRegistry.get().module_fns
|
|
584
|
+
for e in test_tbl_exprs:
|
|
585
|
+
_ = str(e)
|
|
586
|
+
print(_)
|
|
587
|
+
for e in img_tbl_exprs:
|
|
588
|
+
_ = str(e)
|
|
589
|
+
print(_)
|
|
590
|
+
|
|
314
591
|
def test_subexprs(self, img_tbl: catalog.Table) -> None:
|
|
315
592
|
t = img_tbl
|
|
316
593
|
e = t.img
|
|
@@ -318,31 +595,210 @@ class TestExprs:
|
|
|
318
595
|
assert len(subexprs) == 1
|
|
319
596
|
e = t.img.rotate(90).resize((224, 224))
|
|
320
597
|
subexprs = [s for s in e.subexprs()]
|
|
321
|
-
assert len(subexprs) ==
|
|
322
|
-
subexprs = [s for s in e.subexprs(
|
|
598
|
+
assert len(subexprs) == 4
|
|
599
|
+
subexprs = [s for s in e.subexprs(expr_class=ColumnRef)]
|
|
323
600
|
assert len(subexprs) == 1
|
|
324
601
|
assert t.img.equals(subexprs[0])
|
|
325
602
|
|
|
326
|
-
def test_window_fns(self,
|
|
327
|
-
|
|
603
|
+
def test_window_fns(self, test_client: pxt.Client, test_tbl: catalog.Table) -> None:
|
|
604
|
+
cl = test_client
|
|
328
605
|
t = test_tbl
|
|
329
|
-
_ = t
|
|
330
|
-
|
|
606
|
+
_ = t.select(sum(t.c2, group_by=t.c4, order_by=t.c3)).show(100)
|
|
607
|
+
|
|
608
|
+
# conflicting ordering requirements
|
|
609
|
+
with pytest.raises(excs.Error):
|
|
610
|
+
_ = t.select(sum(t.c2, group_by=t.c4, order_by=t.c3), sum(t.c2, group_by=t.c3, order_by=t.c4)).show(100)
|
|
611
|
+
with pytest.raises(excs.Error):
|
|
612
|
+
_ = t.select(sum(t.c2, group_by=t.c4, order_by=t.c3), sum(t.c2, group_by=t.c3, order_by=t.c4)).show(100)
|
|
613
|
+
|
|
331
614
|
# backfill works
|
|
332
|
-
t.add_column(
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
615
|
+
t.add_column(c9=sum(t.c2, group_by=t.c4, order_by=t.c3))
|
|
616
|
+
_ = t.c9.col.has_window_fn_call()
|
|
617
|
+
|
|
618
|
+
# ordering conflict between frame extraction and window fn
|
|
619
|
+
base_t = cl.create_table('videos', {'video': VideoType(), 'c2': IntType(nullable=False)})
|
|
620
|
+
args = {'video': base_t.video, 'fps': 0}
|
|
621
|
+
v = cl.create_view('frame_view', base_t, iterator_class=FrameIterator, iterator_args=args)
|
|
622
|
+
# compatible ordering
|
|
623
|
+
_ = v.select(v.frame, sum(v.frame_idx, group_by=base_t, order_by=v.pos)).show(100)
|
|
624
|
+
with pytest.raises(excs.Error):
|
|
625
|
+
# incompatible ordering
|
|
626
|
+
_ = v.select(v.frame, sum(v.c2, order_by=base_t, group_by=v.pos)).show(100)
|
|
627
|
+
|
|
628
|
+
schema = {
|
|
629
|
+
'c2': IntType(nullable=False),
|
|
630
|
+
'c3': FloatType(nullable=False),
|
|
631
|
+
'c4': BoolType(nullable=False),
|
|
632
|
+
}
|
|
633
|
+
new_t = cl.create_table('insert_test', schema=schema)
|
|
634
|
+
new_t.add_column(c2_sum=sum(new_t.c2, group_by=new_t.c4, order_by=new_t.c3))
|
|
635
|
+
rows = list(t.select(t.c2, t.c4, t.c3).collect())
|
|
636
|
+
new_t.insert(rows)
|
|
342
637
|
_ = new_t.show(0)
|
|
343
|
-
print(_)
|
|
344
638
|
|
|
345
639
|
def test_aggregates(self, test_tbl: catalog.Table) -> None:
|
|
346
640
|
t = test_tbl
|
|
347
|
-
_ = t[t.c2 % 2, sum(t.c2), count(t.c2), sum(t.c2) + count(t.c2), sum(t.c2) +
|
|
348
|
-
|
|
641
|
+
_ = t[t.c2 % 2, sum(t.c2), count(t.c2), sum(t.c2) + count(t.c2), sum(t.c2) + (t.c2 % 2)]\
|
|
642
|
+
.group_by(t.c2 % 2).show()
|
|
643
|
+
|
|
644
|
+
# check that aggregates don't show up in the wrong places
|
|
645
|
+
with pytest.raises(excs.Error):
|
|
646
|
+
# aggregate in where clause
|
|
647
|
+
_ = t[sum(t.c2) > 0][sum(t.c2)].group_by(t.c2 % 2).show()
|
|
648
|
+
with pytest.raises(excs.Error):
|
|
649
|
+
# aggregate in group_by clause
|
|
650
|
+
_ = t[sum(t.c2)].group_by(sum(t.c2)).show()
|
|
651
|
+
with pytest.raises(excs.Error):
|
|
652
|
+
# mixing aggregates and non-aggregates
|
|
653
|
+
_ = t[sum(t.c2) + t.c2].group_by(t.c2 % 2).show()
|
|
654
|
+
with pytest.raises(excs.Error):
|
|
655
|
+
# nested aggregates
|
|
656
|
+
_ = t[sum(count(t.c2))].group_by(t.c2 % 2).show()
|
|
657
|
+
|
|
658
|
+
def test_udas(self, test_tbl: catalog.Table) -> None:
|
|
659
|
+
t = test_tbl
|
|
660
|
+
|
|
661
|
+
@pxt.uda(
|
|
662
|
+
name='window_agg', init_types=[IntType()], update_types=[IntType()], value_type=IntType(),
|
|
663
|
+
allows_window=True, requires_order_by=False)
|
|
664
|
+
class WindowAgg:
|
|
665
|
+
def __init__(self, val: int = 0):
|
|
666
|
+
self.val = val
|
|
667
|
+
def update(self, ignore: int) -> None:
|
|
668
|
+
pass
|
|
669
|
+
def value(self) -> int:
|
|
670
|
+
return self.val
|
|
671
|
+
|
|
672
|
+
@pxt.uda(
|
|
673
|
+
name='ordered_agg', init_types=[IntType()], update_types=[IntType()], value_type=IntType(),
|
|
674
|
+
requires_order_by=True, allows_window=True)
|
|
675
|
+
class WindowAgg:
|
|
676
|
+
def __init__(self, val: int = 0):
|
|
677
|
+
self.val = val
|
|
678
|
+
def update(self, i: int) -> None:
|
|
679
|
+
pass
|
|
680
|
+
def value(self) -> int:
|
|
681
|
+
return self.val
|
|
682
|
+
|
|
683
|
+
@pxt.uda(
|
|
684
|
+
name='std_agg', init_types=[IntType()], update_types=[IntType()], value_type=IntType(),
|
|
685
|
+
requires_order_by=False, allows_window=False)
|
|
686
|
+
class StdAgg:
|
|
687
|
+
def __init__(self, val: int = 0):
|
|
688
|
+
self.val = val
|
|
689
|
+
def update(self, i: int) -> None:
|
|
690
|
+
pass
|
|
691
|
+
def value(self) -> int:
|
|
692
|
+
return self.val
|
|
693
|
+
|
|
694
|
+
# init arg is passed along
|
|
695
|
+
assert t.select(out=window_agg(t.c2, order_by=t.c2)).collect()[0]['out'] == 0
|
|
696
|
+
assert t.select(out=window_agg(t.c2, val=1, order_by=t.c2)).collect()[0]['out'] == 1
|
|
697
|
+
|
|
698
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
699
|
+
_ = t.select(window_agg(t.c2, val=t.c2, order_by=t.c2)).collect()
|
|
700
|
+
assert 'needs to be a constant' in str(exc_info.value)
|
|
701
|
+
|
|
702
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
703
|
+
# ordering expression not a pixeltable expr
|
|
704
|
+
_ = t.select(ordered_agg(1, t.c2)).collect()
|
|
705
|
+
assert 'but instead is a' in str(exc_info.value).lower()
|
|
706
|
+
|
|
707
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
708
|
+
# explicit order_by
|
|
709
|
+
_ = t.select(ordered_agg(t.c2, order_by=t.c2)).collect()
|
|
710
|
+
assert 'order_by invalid' in str(exc_info.value).lower()
|
|
711
|
+
|
|
712
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
713
|
+
# order_by for non-window function
|
|
714
|
+
_ = t.select(std_agg(t.c2, order_by=t.c2)).collect()
|
|
715
|
+
assert 'does not allow windows' in str(exc_info.value).lower()
|
|
716
|
+
|
|
717
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
718
|
+
# group_by for non-window function
|
|
719
|
+
_ = t.select(std_agg(t.c2, group_by=t.c4)).collect()
|
|
720
|
+
assert 'group_by invalid' in str(exc_info.value).lower()
|
|
721
|
+
|
|
722
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
723
|
+
# missing init type
|
|
724
|
+
@pxt.uda(update_types=[IntType()], value_type=IntType())
|
|
725
|
+
class WindowAgg:
|
|
726
|
+
def __init__(self, val: int = 0):
|
|
727
|
+
self.val = val
|
|
728
|
+
def update(self, ignore: int) -> None:
|
|
729
|
+
pass
|
|
730
|
+
def value(self) -> int:
|
|
731
|
+
return self.val
|
|
732
|
+
assert 'init_types must be a list of' in str(exc_info.value)
|
|
733
|
+
|
|
734
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
735
|
+
# missing update parameter
|
|
736
|
+
@pxt.uda(init_types=[IntType()], update_types=[], value_type=IntType())
|
|
737
|
+
class WindowAgg:
|
|
738
|
+
def __init__(self, val: int = 0):
|
|
739
|
+
self.val = val
|
|
740
|
+
def update(self) -> None:
|
|
741
|
+
pass
|
|
742
|
+
def value(self) -> int:
|
|
743
|
+
return self.val
|
|
744
|
+
assert 'must have at least one parameter' in str(exc_info.value)
|
|
745
|
+
|
|
746
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
747
|
+
# missing update type
|
|
748
|
+
@pxt.uda(init_types=[IntType()], update_types=[IntType()], value_type=IntType())
|
|
749
|
+
class WindowAgg:
|
|
750
|
+
def __init__(self, val: int = 0):
|
|
751
|
+
self.val = val
|
|
752
|
+
def update(self, i1: int, i2: int) -> None:
|
|
753
|
+
pass
|
|
754
|
+
def value(self) -> int:
|
|
755
|
+
return self.val
|
|
756
|
+
assert 'update_types must be a list of' in str(exc_info.value)
|
|
757
|
+
|
|
758
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
759
|
+
# duplicate parameter names
|
|
760
|
+
@pxt.uda(init_types=[IntType()], update_types=[IntType()], value_type=IntType())
|
|
761
|
+
class WindowAgg:
|
|
762
|
+
def __init__(self, val: int = 0):
|
|
763
|
+
self.val = val
|
|
764
|
+
def update(self, val: int) -> None:
|
|
765
|
+
pass
|
|
766
|
+
def value(self) -> int:
|
|
767
|
+
return self.val
|
|
768
|
+
assert 'cannot have parameters with the same name: val' in str(exc_info.value)
|
|
769
|
+
|
|
770
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
771
|
+
# invalid name
|
|
772
|
+
@pxt.uda(name='not an identifier', init_types=[IntType()], update_types=[IntType()], value_type=IntType())
|
|
773
|
+
class WindowAgg:
|
|
774
|
+
def __init__(self, val: int = 0):
|
|
775
|
+
self.val = val
|
|
776
|
+
def update(self, i1: int, i2: int) -> None:
|
|
777
|
+
pass
|
|
778
|
+
def value(self) -> int:
|
|
779
|
+
return self.val
|
|
780
|
+
assert 'invalid name' in str(exc_info.value).lower()
|
|
781
|
+
|
|
782
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
783
|
+
# reserved parameter name
|
|
784
|
+
@pxt.uda(init_types=[IntType()], update_types=[IntType()], value_type=IntType())
|
|
785
|
+
class WindowAgg:
|
|
786
|
+
def __init__(self, val: int = 0):
|
|
787
|
+
self.val = val
|
|
788
|
+
def update(self, order_by: int) -> None:
|
|
789
|
+
pass
|
|
790
|
+
def value(self) -> int:
|
|
791
|
+
return self.val
|
|
792
|
+
assert 'order_by is reserved' in str(exc_info.value).lower()
|
|
793
|
+
|
|
794
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
795
|
+
# reserved parameter name
|
|
796
|
+
@pxt.uda(init_types=[IntType()], update_types=[IntType()], value_type=IntType())
|
|
797
|
+
class WindowAgg:
|
|
798
|
+
def __init__(self, val: int = 0):
|
|
799
|
+
self.val = val
|
|
800
|
+
def update(self, group_by: int) -> None:
|
|
801
|
+
pass
|
|
802
|
+
def value(self) -> int:
|
|
803
|
+
return self.val
|
|
804
|
+
assert 'group_by is reserved' in str(exc_info.value).lower()
|