pixeltable 0.1.0__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic; see the package registry's release page for more details.
- pixeltable/__init__.py +34 -6
- pixeltable/catalog/__init__.py +13 -0
- pixeltable/catalog/catalog.py +159 -0
- pixeltable/catalog/column.py +200 -0
- pixeltable/catalog/dir.py +32 -0
- pixeltable/catalog/globals.py +33 -0
- pixeltable/catalog/insertable_table.py +191 -0
- pixeltable/catalog/named_function.py +36 -0
- pixeltable/catalog/path.py +58 -0
- pixeltable/catalog/path_dict.py +139 -0
- pixeltable/catalog/schema_object.py +39 -0
- pixeltable/catalog/table.py +581 -0
- pixeltable/catalog/table_version.py +749 -0
- pixeltable/catalog/table_version_path.py +133 -0
- pixeltable/catalog/view.py +203 -0
- pixeltable/client.py +590 -30
- pixeltable/dataframe.py +540 -349
- pixeltable/env.py +359 -45
- pixeltable/exceptions.py +12 -21
- pixeltable/exec/__init__.py +9 -0
- pixeltable/exec/aggregation_node.py +78 -0
- pixeltable/exec/cache_prefetch_node.py +116 -0
- pixeltable/exec/component_iteration_node.py +79 -0
- pixeltable/exec/data_row_batch.py +95 -0
- pixeltable/exec/exec_context.py +22 -0
- pixeltable/exec/exec_node.py +61 -0
- pixeltable/exec/expr_eval_node.py +217 -0
- pixeltable/exec/in_memory_data_node.py +69 -0
- pixeltable/exec/media_validation_node.py +43 -0
- pixeltable/exec/sql_scan_node.py +225 -0
- pixeltable/exprs/__init__.py +24 -0
- pixeltable/exprs/arithmetic_expr.py +102 -0
- pixeltable/exprs/array_slice.py +71 -0
- pixeltable/exprs/column_property_ref.py +77 -0
- pixeltable/exprs/column_ref.py +105 -0
- pixeltable/exprs/comparison.py +77 -0
- pixeltable/exprs/compound_predicate.py +98 -0
- pixeltable/exprs/data_row.py +195 -0
- pixeltable/exprs/expr.py +586 -0
- pixeltable/exprs/expr_set.py +39 -0
- pixeltable/exprs/function_call.py +380 -0
- pixeltable/exprs/globals.py +69 -0
- pixeltable/exprs/image_member_access.py +115 -0
- pixeltable/exprs/image_similarity_predicate.py +58 -0
- pixeltable/exprs/inline_array.py +107 -0
- pixeltable/exprs/inline_dict.py +101 -0
- pixeltable/exprs/is_null.py +38 -0
- pixeltable/exprs/json_mapper.py +121 -0
- pixeltable/exprs/json_path.py +159 -0
- pixeltable/exprs/literal.py +54 -0
- pixeltable/exprs/object_ref.py +41 -0
- pixeltable/exprs/predicate.py +44 -0
- pixeltable/exprs/row_builder.py +355 -0
- pixeltable/exprs/rowid_ref.py +94 -0
- pixeltable/exprs/type_cast.py +53 -0
- pixeltable/exprs/variable.py +45 -0
- pixeltable/func/__init__.py +9 -0
- pixeltable/func/aggregate_function.py +194 -0
- pixeltable/func/batched_function.py +53 -0
- pixeltable/func/callable_function.py +69 -0
- pixeltable/func/expr_template_function.py +82 -0
- pixeltable/func/function.py +110 -0
- pixeltable/func/function_registry.py +227 -0
- pixeltable/func/globals.py +36 -0
- pixeltable/func/nos_function.py +202 -0
- pixeltable/func/signature.py +166 -0
- pixeltable/func/udf.py +163 -0
- pixeltable/functions/__init__.py +52 -103
- pixeltable/functions/eval.py +216 -0
- pixeltable/functions/fireworks.py +34 -0
- pixeltable/functions/huggingface.py +120 -0
- pixeltable/functions/image.py +16 -0
- pixeltable/functions/openai.py +256 -0
- pixeltable/functions/pil/image.py +148 -7
- pixeltable/functions/string.py +13 -0
- pixeltable/functions/together.py +122 -0
- pixeltable/functions/util.py +41 -0
- pixeltable/functions/video.py +62 -0
- pixeltable/iterators/__init__.py +3 -0
- pixeltable/iterators/base.py +48 -0
- pixeltable/iterators/document.py +311 -0
- pixeltable/iterators/video.py +89 -0
- pixeltable/metadata/__init__.py +54 -0
- pixeltable/metadata/converters/convert_10.py +18 -0
- pixeltable/metadata/schema.py +211 -0
- pixeltable/plan.py +656 -0
- pixeltable/store.py +418 -182
- pixeltable/tests/conftest.py +146 -88
- pixeltable/tests/functions/test_fireworks.py +42 -0
- pixeltable/tests/functions/test_functions.py +60 -0
- pixeltable/tests/functions/test_huggingface.py +158 -0
- pixeltable/tests/functions/test_openai.py +152 -0
- pixeltable/tests/functions/test_together.py +111 -0
- pixeltable/tests/test_audio.py +65 -0
- pixeltable/tests/test_catalog.py +27 -0
- pixeltable/tests/test_client.py +14 -14
- pixeltable/tests/test_component_view.py +370 -0
- pixeltable/tests/test_dataframe.py +439 -0
- pixeltable/tests/test_dirs.py +78 -62
- pixeltable/tests/test_document.py +120 -0
- pixeltable/tests/test_exprs.py +592 -135
- pixeltable/tests/test_function.py +297 -67
- pixeltable/tests/test_migration.py +43 -0
- pixeltable/tests/test_nos.py +54 -0
- pixeltable/tests/test_snapshot.py +208 -0
- pixeltable/tests/test_table.py +1195 -263
- pixeltable/tests/test_transactional_directory.py +42 -0
- pixeltable/tests/test_types.py +5 -11
- pixeltable/tests/test_video.py +151 -34
- pixeltable/tests/test_view.py +530 -0
- pixeltable/tests/utils.py +320 -45
- pixeltable/tool/create_test_db_dump.py +149 -0
- pixeltable/tool/create_test_video.py +81 -0
- pixeltable/type_system.py +445 -124
- pixeltable/utils/__init__.py +17 -46
- pixeltable/utils/arrow.py +98 -0
- pixeltable/utils/clip.py +12 -15
- pixeltable/utils/coco.py +136 -0
- pixeltable/utils/documents.py +39 -0
- pixeltable/utils/filecache.py +195 -0
- pixeltable/utils/help.py +11 -0
- pixeltable/utils/hf_datasets.py +157 -0
- pixeltable/utils/media_store.py +76 -0
- pixeltable/utils/parquet.py +167 -0
- pixeltable/utils/pytorch.py +91 -0
- pixeltable/utils/s3.py +13 -0
- pixeltable/utils/sql.py +17 -0
- pixeltable/utils/transactional_directory.py +35 -0
- pixeltable-0.2.4.dist-info/LICENSE +18 -0
- pixeltable-0.2.4.dist-info/METADATA +127 -0
- pixeltable-0.2.4.dist-info/RECORD +132 -0
- {pixeltable-0.1.0.dist-info → pixeltable-0.2.4.dist-info}/WHEEL +1 -1
- pixeltable/catalog.py +0 -1421
- pixeltable/exprs.py +0 -1745
- pixeltable/function.py +0 -269
- pixeltable/functions/clip.py +0 -10
- pixeltable/functions/pil/__init__.py +0 -23
- pixeltable/functions/tf.py +0 -21
- pixeltable/index.py +0 -57
- pixeltable/tests/test_dict.py +0 -24
- pixeltable/tests/test_functions.py +0 -11
- pixeltable/tests/test_tf.py +0 -69
- pixeltable/tf.py +0 -33
- pixeltable/utils/tf.py +0 -33
- pixeltable/utils/video.py +0 -32
- pixeltable-0.1.0.dist-info/METADATA +0 -34
- pixeltable-0.1.0.dist-info/RECORD +0 -36
pixeltable/tests/test_exprs.py
CHANGED
|
@@ -1,45 +1,51 @@
|
|
|
1
|
-
import
|
|
1
|
+
import json
|
|
2
|
+
import urllib.parse
|
|
3
|
+
import urllib.request
|
|
4
|
+
from typing import List, Dict
|
|
5
|
+
|
|
2
6
|
import pytest
|
|
7
|
+
import sqlalchemy as sql
|
|
3
8
|
|
|
9
|
+
import pixeltable as pxt
|
|
10
|
+
import pixeltable.func as func
|
|
4
11
|
from pixeltable import catalog
|
|
5
|
-
from pixeltable
|
|
6
|
-
from pixeltable
|
|
7
|
-
from pixeltable.exprs import Expr,
|
|
12
|
+
from pixeltable import exceptions as excs
|
|
13
|
+
from pixeltable import exprs
|
|
14
|
+
from pixeltable.exprs import Expr, ColumnRef
|
|
8
15
|
from pixeltable.exprs import RELATIVE_PATH_ROOT as R
|
|
9
|
-
from pixeltable.functions import
|
|
16
|
+
from pixeltable.functions import cast, sum, count
|
|
10
17
|
from pixeltable.functions.pil.image import blend
|
|
11
|
-
from pixeltable.
|
|
12
|
-
from pixeltable import
|
|
13
|
-
from pixeltable.
|
|
18
|
+
from pixeltable.iterators import FrameIterator
|
|
19
|
+
from pixeltable.tests.utils import get_image_files, skip_test_if_not_installed
|
|
20
|
+
from pixeltable.type_system import StringType, BoolType, IntType, ArrayType, ColumnType, FloatType, \
|
|
21
|
+
VideoType
|
|
14
22
|
|
|
15
23
|
|
|
16
24
|
class TestExprs:
|
|
17
|
-
# This breaks with exception 'cannot pickle _thread._local obj'
|
|
18
|
-
# sum = Function(
|
|
19
|
-
# IntType(), [IntType()],
|
|
20
|
-
# init_fn=lambda: TestExprs.SumAggregator(), update_fn=SumAggregator.update, value_fn=SumAggregator.value)
|
|
21
|
-
|
|
22
25
|
def test_basic(self, test_tbl: catalog.Table) -> None:
|
|
23
26
|
t = test_tbl
|
|
24
|
-
assert
|
|
27
|
+
assert t['c1'].equals(t.c1)
|
|
28
|
+
assert t['c7']['*'].f5.equals(t.c7['*'].f5)
|
|
29
|
+
|
|
30
|
+
assert isinstance(t.c1 == None, Expr)
|
|
25
31
|
assert isinstance(t.c1 < 'a', Expr)
|
|
26
|
-
assert isinstance(t['c1'] <= 'a', Expr)
|
|
27
32
|
assert isinstance(t.c1 <= 'a', Expr)
|
|
28
|
-
assert isinstance(t['c1'] == 'a', Expr)
|
|
29
33
|
assert isinstance(t.c1 == 'a', Expr)
|
|
30
|
-
assert isinstance(t['c1'] != 'a', Expr)
|
|
31
34
|
assert isinstance(t.c1 != 'a', Expr)
|
|
32
|
-
assert isinstance(t['c1'] > 'a', Expr)
|
|
33
35
|
assert isinstance(t.c1 > 'a', Expr)
|
|
34
|
-
assert isinstance(t['c1'] >= 'a', Expr)
|
|
35
36
|
assert isinstance(t.c1 >= 'a', Expr)
|
|
36
37
|
assert isinstance((t.c1 == 'a') & (t.c2 < 5), Expr)
|
|
37
38
|
assert isinstance((t.c1 == 'a') | (t.c2 < 5), Expr)
|
|
38
39
|
assert isinstance(~(t.c1 == 'a'), Expr)
|
|
40
|
+
with pytest.raises(AttributeError) as excinfo:
|
|
41
|
+
_ = t.does_not_exist
|
|
42
|
+
assert 'unknown' in str(excinfo.value).lower()
|
|
39
43
|
|
|
40
44
|
def test_compound_predicates(self, test_tbl: catalog.Table) -> None:
|
|
41
45
|
t = test_tbl
|
|
42
46
|
# compound predicates that can be fully evaluated in SQL
|
|
47
|
+
_ = t.where((t.c1 == 'test string') & (t.c6.f1 > 50)).collect()
|
|
48
|
+
_ = t.where((t.c1 == 'test string') & (t.c2 > 50)).collect()
|
|
43
49
|
e = ((t.c1 == 'test string') & (t.c2 > 50)).sql_expr()
|
|
44
50
|
assert len(e.clauses) == 2
|
|
45
51
|
|
|
@@ -52,46 +58,192 @@ class TestExprs:
|
|
|
52
58
|
e = (~(t.c1 == 'test string')).sql_expr()
|
|
53
59
|
assert isinstance(e, sql.sql.expression.BinaryExpression)
|
|
54
60
|
|
|
61
|
+
with pytest.raises(TypeError) as exc_info:
|
|
62
|
+
_ = t.where((t.c1 == 'test string') or (t.c6.f1 > 50)).collect()
|
|
63
|
+
assert 'cannot be used in conjunction with python boolean operators' in str(exc_info.value).lower()
|
|
64
|
+
|
|
55
65
|
# compound predicates with Python functions
|
|
56
|
-
udf
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
sql_pred,
|
|
69
|
-
assert
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
assert
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
#
|
|
79
|
-
|
|
80
|
-
assert
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
66
|
+
@pxt.udf(return_type=BoolType(), param_types=[StringType()])
|
|
67
|
+
def udf(_: str) -> bool:
|
|
68
|
+
return True
|
|
69
|
+
@pxt.udf(return_type=BoolType(), param_types=[IntType()])
|
|
70
|
+
def udf2(_: int) -> bool:
|
|
71
|
+
return True
|
|
72
|
+
|
|
73
|
+
# TODO: find a way to test this
|
|
74
|
+
# # & can be split
|
|
75
|
+
# p = (t.c1 == 'test string') & udf(t.c1)
|
|
76
|
+
# assert p.sql_expr() is None
|
|
77
|
+
# sql_pred, other_pred = p.extract_sql_predicate()
|
|
78
|
+
# assert isinstance(sql_pred, sql.sql.expression.BinaryExpression)
|
|
79
|
+
# assert isinstance(other_pred, FunctionCall)
|
|
80
|
+
#
|
|
81
|
+
# p = (t.c1 == 'test string') & udf(t.c1) & (t.c2 > 50)
|
|
82
|
+
# assert p.sql_expr() is None
|
|
83
|
+
# sql_pred, other_pred = p.extract_sql_predicate()
|
|
84
|
+
# assert len(sql_pred.clauses) == 2
|
|
85
|
+
# assert isinstance(other_pred, FunctionCall)
|
|
86
|
+
#
|
|
87
|
+
# p = (t.c1 == 'test string') & udf(t.c1) & (t.c2 > 50) & udf2(t.c2)
|
|
88
|
+
# assert p.sql_expr() is None
|
|
89
|
+
# sql_pred, other_pred = p.extract_sql_predicate()
|
|
90
|
+
# assert len(sql_pred.clauses) == 2
|
|
91
|
+
# assert isinstance(other_pred, CompoundPredicate)
|
|
92
|
+
#
|
|
93
|
+
# # | cannot be split
|
|
94
|
+
# p = (t.c1 == 'test string') | udf(t.c1)
|
|
95
|
+
# assert p.sql_expr() is None
|
|
96
|
+
# sql_pred, other_pred = p.extract_sql_predicate()
|
|
97
|
+
# assert sql_pred is None
|
|
98
|
+
# assert isinstance(other_pred, CompoundPredicate)
|
|
99
|
+
|
|
100
|
+
def test_filters(self, test_tbl: catalog.Table) -> None:
|
|
86
101
|
t = test_tbl
|
|
87
102
|
_ = t[t.c1 == 'test string'].show()
|
|
88
103
|
print(_)
|
|
89
104
|
_ = t[t.c2 > 50].show()
|
|
90
105
|
print(_)
|
|
106
|
+
_ = t[t.c1n == None].show()
|
|
107
|
+
print(_)
|
|
108
|
+
_ = t[t.c1n != None].show(0)
|
|
109
|
+
print(_)
|
|
110
|
+
|
|
111
|
+
def test_exception_handling(self, test_tbl: catalog.Table) -> None:
|
|
112
|
+
t = test_tbl
|
|
113
|
+
|
|
114
|
+
# error in expr that's handled in SQL
|
|
115
|
+
with pytest.raises(excs.Error):
|
|
116
|
+
_ = t[(t.c2 + 1) / t.c2].show()
|
|
117
|
+
|
|
118
|
+
# error in expr that's handled in Python
|
|
119
|
+
with pytest.raises(excs.Error):
|
|
120
|
+
_ = t[(t.c6.f2 + 1) / (t.c2 - 10)].show()
|
|
121
|
+
|
|
122
|
+
# the same, but with an inline function
|
|
123
|
+
@pxt.udf(return_type=FloatType(), param_types=[IntType(), IntType()])
|
|
124
|
+
def f(a: int, b: int) -> float:
|
|
125
|
+
return a / b
|
|
126
|
+
with pytest.raises(excs.Error):
|
|
127
|
+
_ = t[f(t.c2 + 1, t.c2)].show()
|
|
128
|
+
|
|
129
|
+
# error in agg.init()
|
|
130
|
+
@pxt.uda(update_types=[IntType()], value_type=IntType(), name='agg')
|
|
131
|
+
class Aggregator(pxt.Aggregator):
|
|
132
|
+
def __init__(self):
|
|
133
|
+
self.sum = 1 / 0
|
|
134
|
+
def update(self, val):
|
|
135
|
+
pass
|
|
136
|
+
def value(self):
|
|
137
|
+
return 1
|
|
138
|
+
with pytest.raises(excs.Error):
|
|
139
|
+
_ = t[agg(t.c2)].show()
|
|
140
|
+
|
|
141
|
+
# error in agg.update()
|
|
142
|
+
@pxt.uda(update_types=[IntType()], value_type=IntType(), name='agg')
|
|
143
|
+
class Aggregator(pxt.Aggregator):
|
|
144
|
+
def __init__(self):
|
|
145
|
+
self.sum = 0
|
|
146
|
+
def update(self, val):
|
|
147
|
+
self.sum += 1 / val
|
|
148
|
+
def value(self):
|
|
149
|
+
return 1
|
|
150
|
+
with pytest.raises(excs.Error):
|
|
151
|
+
_ = t[agg(t.c2 - 10)].show()
|
|
152
|
+
|
|
153
|
+
# error in agg.value()
|
|
154
|
+
@pxt.uda(update_types=[IntType()], value_type=IntType(), name='agg')
|
|
155
|
+
class Aggregator(pxt.Aggregator):
|
|
156
|
+
def __init__(self):
|
|
157
|
+
self.sum = 0
|
|
158
|
+
def update(self, val):
|
|
159
|
+
self.sum += val
|
|
160
|
+
def value(self):
|
|
161
|
+
return 1 / self.sum
|
|
162
|
+
with pytest.raises(excs.Error):
|
|
163
|
+
_ = t[t.c2 <= 2][agg(t.c2 - 1)].show()
|
|
164
|
+
|
|
165
|
+
def test_props(self, test_tbl: catalog.Table, img_tbl: catalog.Table) -> None:
|
|
166
|
+
t = test_tbl
|
|
167
|
+
# errortype/-msg for computed column
|
|
168
|
+
res = t.select(error=t.c8.errortype).collect()
|
|
169
|
+
assert res.to_pandas()['error'].isna().all()
|
|
170
|
+
res = t.select(error=t.c8.errormsg).collect()
|
|
171
|
+
assert res.to_pandas()['error'].isna().all()
|
|
172
|
+
|
|
173
|
+
img_t = img_tbl
|
|
174
|
+
# fileurl
|
|
175
|
+
res = img_t.select(img_t.img.fileurl).show(0).to_pandas()
|
|
176
|
+
stored_urls = set(res.iloc[:, 0])
|
|
177
|
+
assert len(stored_urls) == len(res)
|
|
178
|
+
all_urls = set(urllib.parse.urljoin('file:', urllib.request.pathname2url(path)) for path in get_image_files())
|
|
179
|
+
assert stored_urls <= all_urls
|
|
180
|
+
|
|
181
|
+
# localpath
|
|
182
|
+
res = img_t.select(img_t.img.localpath).show(0).to_pandas()
|
|
183
|
+
stored_paths = set(res.iloc[:, 0])
|
|
184
|
+
assert len(stored_paths) == len(res)
|
|
185
|
+
all_paths = set(get_image_files())
|
|
186
|
+
assert stored_paths <= all_paths
|
|
187
|
+
|
|
188
|
+
# errortype/-msg for image column
|
|
189
|
+
res = img_t.select(error=img_t.img.errortype).collect().to_pandas()
|
|
190
|
+
assert res['error'].isna().all()
|
|
191
|
+
res = img_t.select(error=img_t.img.errormsg).collect().to_pandas()
|
|
192
|
+
assert res['error'].isna().all()
|
|
193
|
+
|
|
194
|
+
for c in [t.c1, t.c1n, t.c2, t.c3, t.c4, t.c5, t.c6, t.c7]:
|
|
195
|
+
# errortype/errormsg only applies to stored computed and media columns
|
|
196
|
+
with pytest.raises(excs.Error) as excinfo:
|
|
197
|
+
_ = t.select(c.errortype).show()
|
|
198
|
+
assert 'only valid for' in str(excinfo.value)
|
|
199
|
+
with pytest.raises(excs.Error) as excinfo:
|
|
200
|
+
_ = t.select(c.errormsg).show()
|
|
201
|
+
assert 'only valid for' in str(excinfo.value)
|
|
202
|
+
|
|
203
|
+
# fileurl/localpath only applies to media columns
|
|
204
|
+
with pytest.raises(excs.Error) as excinfo:
|
|
205
|
+
_ = t.select(t.c1.fileurl).show()
|
|
206
|
+
assert 'only valid for' in str(excinfo.value)
|
|
207
|
+
with pytest.raises(excs.Error) as excinfo:
|
|
208
|
+
_ = t.select(t.c1.localpath).show()
|
|
209
|
+
assert 'only valid for' in str(excinfo.value)
|
|
210
|
+
|
|
211
|
+
# fileurl/localpath doesn't apply to unstored computed img columns
|
|
212
|
+
img_t.add_column(c9=img_t.img.rotate(30))
|
|
213
|
+
with pytest.raises(excs.Error) as excinfo:
|
|
214
|
+
_ = img_t.select(img_t.c9.localpath).show()
|
|
215
|
+
assert 'computed unstored' in str(excinfo.value)
|
|
216
|
+
|
|
217
|
+
def test_null_args(self, test_client: pxt.Client) -> None:
|
|
218
|
+
# create table with two int columns
|
|
219
|
+
schema = {'c1': FloatType(nullable=True), 'c2': FloatType(nullable=True)}
|
|
220
|
+
t = test_client.create_table('test', schema)
|
|
221
|
+
|
|
222
|
+
# computed column that doesn't allow nulls
|
|
223
|
+
t.add_column(c3=lambda c1, c2: c1 + c2, type=FloatType(nullable=False))
|
|
224
|
+
# function that does allow nulls
|
|
225
|
+
@pxt.udf(return_type=FloatType(nullable=True),
|
|
226
|
+
param_types=[FloatType(nullable=False), FloatType(nullable=True)])
|
|
227
|
+
def f(a: int, b: int) -> int:
|
|
228
|
+
if b is None:
|
|
229
|
+
return a
|
|
230
|
+
return a + b
|
|
231
|
+
t.add_column(c4=f(t.c1, t.c2))
|
|
232
|
+
|
|
233
|
+
# data that tests all combinations of nulls
|
|
234
|
+
data = [{'c1': 1.0, 'c2': 1.0}, {'c1': 1.0, 'c2': None}, {'c1': None, 'c2': 1.0}, {'c1': None, 'c2': None}]
|
|
235
|
+
status = t.insert(data, fail_on_exception=False)
|
|
236
|
+
assert status.num_rows == len(data)
|
|
237
|
+
assert status.num_excs == len(data) - 1
|
|
238
|
+
result = t.select(t.c3, t.c4).collect()
|
|
239
|
+
assert result['c3'] == [2.0, None, None, None]
|
|
240
|
+
assert result['c4'] == [2.0, 1.0, None, None]
|
|
91
241
|
|
|
92
242
|
def test_arithmetic_exprs(self, test_tbl: catalog.Table) -> None:
|
|
93
243
|
t = test_tbl
|
|
94
244
|
|
|
245
|
+
_ = t[t.c2, t.c6.f3, t.c2 + t.c6.f3, (t.c2 + t.c6.f3) / (t.c6.f3 + 1)].show()
|
|
246
|
+
_ = t[t.c2 + t.c2].show()
|
|
95
247
|
for op1, op2 in [(t.c2, t.c2), (t.c3, t.c3)]:
|
|
96
248
|
_ = t[op1 + op2].show()
|
|
97
249
|
_ = t[op1 - op2].show()
|
|
@@ -103,13 +255,13 @@ class TestExprs:
|
|
|
103
255
|
(t.c1, t.c2), (t.c1, 1), (t.c2, t.c1), (t.c2, 'a'),
|
|
104
256
|
(t.c1, t.c3), (t.c1, 1.0), (t.c3, t.c1), (t.c3, 'a')
|
|
105
257
|
]:
|
|
106
|
-
with pytest.raises(
|
|
258
|
+
with pytest.raises(excs.Error):
|
|
107
259
|
_ = t[op1 + op2]
|
|
108
|
-
with pytest.raises(
|
|
260
|
+
with pytest.raises(excs.Error):
|
|
109
261
|
_ = t[op1 - op2]
|
|
110
|
-
with pytest.raises(
|
|
262
|
+
with pytest.raises(excs.Error):
|
|
111
263
|
_ = t[op1 * op2]
|
|
112
|
-
with pytest.raises(
|
|
264
|
+
with pytest.raises(excs.Error):
|
|
113
265
|
_ = t[op1 / op2]
|
|
114
266
|
|
|
115
267
|
# TODO: test division; requires predicate
|
|
@@ -117,16 +269,18 @@ class TestExprs:
|
|
|
117
269
|
_ = t[op1 + op2].show()
|
|
118
270
|
_ = t[op1 - op2].show()
|
|
119
271
|
_ = t[op1 * op2].show()
|
|
272
|
+
with pytest.raises(excs.Error):
|
|
273
|
+
_ = t[op1 / op2].show()
|
|
120
274
|
|
|
121
275
|
for op1, op2 in [
|
|
122
276
|
(t.c6.f1, t.c6.f2), (t.c6.f1, t.c6.f3), (t.c6.f1, 1), (t.c6.f1, 1.0),
|
|
123
277
|
(t.c6.f2, t.c6.f1), (t.c6.f3, t.c6.f1), (t.c6.f2, 'a'), (t.c6.f3, 'a'),
|
|
124
278
|
]:
|
|
125
|
-
with pytest.raises(
|
|
279
|
+
with pytest.raises(excs.Error):
|
|
126
280
|
_ = t[op1 + op2].show()
|
|
127
|
-
with pytest.raises(
|
|
281
|
+
with pytest.raises(excs.Error):
|
|
128
282
|
_ = t[op1 - op2].show()
|
|
129
|
-
with pytest.raises(
|
|
283
|
+
with pytest.raises(excs.Error):
|
|
130
284
|
_ = t[op1 * op2].show()
|
|
131
285
|
|
|
132
286
|
|
|
@@ -138,8 +292,8 @@ class TestExprs:
|
|
|
138
292
|
|
|
139
293
|
def test_inline_array(self, test_tbl: catalog.Table) -> None:
|
|
140
294
|
t = test_tbl
|
|
141
|
-
result = t[[
|
|
142
|
-
t = result.
|
|
295
|
+
result = t.select([[t.c2, 1], [t.c2, 2]]).show()
|
|
296
|
+
t = result.column_types()[0]
|
|
143
297
|
assert t.is_array_type()
|
|
144
298
|
assert isinstance(t, ArrayType)
|
|
145
299
|
assert t.shape == (2, 2)
|
|
@@ -167,41 +321,164 @@ class TestExprs:
|
|
|
167
321
|
_ = t[t.c6.f1]
|
|
168
322
|
_ = _.show()
|
|
169
323
|
print(_)
|
|
324
|
+
# predicate on dict field
|
|
325
|
+
_ = t[t.c6.f2 < 2].show()
|
|
170
326
|
#_ = t[t.c6.f2].show()
|
|
171
327
|
#_ = t[t.c6.f5].show()
|
|
172
328
|
_ = t[t.c6.f6.f8].show()
|
|
173
|
-
_ = t[cast(t.c6.f6.f8, ArrayType((4,),
|
|
329
|
+
_ = t[cast(t.c6.f6.f8, ArrayType((4,), FloatType()))].show()
|
|
174
330
|
|
|
175
331
|
# top-level is array
|
|
176
332
|
#_ = t[t.c7['*'].f1].show()
|
|
177
333
|
#_ = t[t.c7['*'].f2].show()
|
|
178
334
|
#_ = t[t.c7['*'].f5].show()
|
|
179
335
|
_ = t[t.c7['*'].f6.f8].show()
|
|
180
|
-
_ = t[
|
|
336
|
+
_ = t[t.c7[0].f6.f8].show()
|
|
337
|
+
_ = t[t.c7[:2].f6.f8].show()
|
|
338
|
+
_ = t[t.c7[::-1].f6.f8].show()
|
|
339
|
+
_ = t[cast(t.c7['*'].f6.f8, ArrayType((2, 4), FloatType()))].show()
|
|
181
340
|
print(_)
|
|
182
341
|
|
|
183
342
|
def test_arrays(self, test_tbl: catalog.Table) -> None:
|
|
184
343
|
t = test_tbl
|
|
185
|
-
t.add_column(
|
|
344
|
+
t.add_column(array_col=[[t.c2, 1], [1, t.c2]])
|
|
186
345
|
_ = t[t.array_col].show()
|
|
187
346
|
print(_)
|
|
188
347
|
_ = t[t.array_col[:, 0]].show()
|
|
189
348
|
print(_)
|
|
190
349
|
|
|
350
|
+
def test_astype(self, test_tbl: catalog.Table) -> None:
|
|
351
|
+
t = test_tbl
|
|
352
|
+
# Convert int to float
|
|
353
|
+
status = t.add_column(c2_as_float=t.c2.astype(FloatType()))
|
|
354
|
+
assert status.num_excs == 0
|
|
355
|
+
data = t.select(t.c2, t.c2_as_float).collect()
|
|
356
|
+
for row in data:
|
|
357
|
+
assert isinstance(row['c2'], int)
|
|
358
|
+
assert isinstance(row['c2_as_float'], float)
|
|
359
|
+
assert row['c2'] == row['c2_as_float']
|
|
360
|
+
# Compound expression
|
|
361
|
+
status = t.add_column(compound_as_float=(t.c2 + 1).astype(FloatType()))
|
|
362
|
+
assert status.num_excs == 0
|
|
363
|
+
data = t.select(t.c2, t.compound_as_float).collect()
|
|
364
|
+
for row in data:
|
|
365
|
+
assert isinstance(row['compound_as_float'], float)
|
|
366
|
+
assert row['c2'] + 1 == row['compound_as_float']
|
|
367
|
+
# Type conversion error
|
|
368
|
+
status = t.add_column(c2_as_string=t.c2.astype(StringType()))
|
|
369
|
+
assert status.num_excs == t.count()
|
|
370
|
+
|
|
371
|
+
def test_apply(self, test_tbl: catalog.Table) -> None:
|
|
372
|
+
|
|
373
|
+
t = test_tbl
|
|
374
|
+
|
|
375
|
+
# For each column c1, ..., c5, we create a new column ci_str that converts it to
|
|
376
|
+
# a string, then check that each row is correctly converted
|
|
377
|
+
# (For c1 this is the no-op string-to-string conversion)
|
|
378
|
+
for col_id in range(1, 6):
|
|
379
|
+
col_name = f'c{col_id}'
|
|
380
|
+
str_col_name = f'c{col_id}_str'
|
|
381
|
+
status = t.add_column(**{str_col_name: t[col_name].apply(str)})
|
|
382
|
+
assert status.num_excs == 0
|
|
383
|
+
data = t.select(t[col_name], t[str_col_name]).collect()
|
|
384
|
+
for row in data:
|
|
385
|
+
assert row[str_col_name] == str(row[col_name])
|
|
386
|
+
|
|
387
|
+
# Test a compound expression with apply
|
|
388
|
+
status = t.add_column(c2_plus_1_str=(t.c2 + 1).apply(str))
|
|
389
|
+
assert status.num_excs == 0
|
|
390
|
+
data = t.select(t.c2, t.c2_plus_1_str).collect()
|
|
391
|
+
for row in data:
|
|
392
|
+
assert row['c2_plus_1_str'] == str(row['c2'] + 1)
|
|
393
|
+
|
|
394
|
+
# For columns c6, c7, try using json.dumps and json.loads to emit and parse JSON <-> str
|
|
395
|
+
for col_id in range(6, 8):
|
|
396
|
+
col_name = f'c{col_id}'
|
|
397
|
+
str_col_name = f'c{col_id}_str'
|
|
398
|
+
back_to_json_col_name = f'c{col_id}_back_to_json'
|
|
399
|
+
status = t.add_column(**{str_col_name: t[col_name].apply(json.dumps)})
|
|
400
|
+
assert status.num_excs == 0
|
|
401
|
+
status = t.add_column(**{back_to_json_col_name: t[str_col_name].apply(json.loads)})
|
|
402
|
+
assert status.num_excs == 0
|
|
403
|
+
data = t.select(t[col_name], t[str_col_name], t[back_to_json_col_name]).collect()
|
|
404
|
+
for row in data:
|
|
405
|
+
assert row[str_col_name] == json.dumps(row[col_name])
|
|
406
|
+
assert row[back_to_json_col_name] == row[col_name]
|
|
407
|
+
|
|
408
|
+
def f1(x):
|
|
409
|
+
return str(x)
|
|
410
|
+
|
|
411
|
+
# Now test that a function without a return type throws an exception ...
|
|
412
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
413
|
+
t.c2.apply(f1)
|
|
414
|
+
assert 'Column type of `f1` cannot be inferred.' in str(exc_info.value)
|
|
415
|
+
|
|
416
|
+
# ... but works if the type is specified explicitly.
|
|
417
|
+
status = t.add_column(c2_str_f1=t.c2.apply(f1, col_type=StringType()))
|
|
418
|
+
assert status.num_excs == 0
|
|
419
|
+
|
|
420
|
+
# Test that the return type of a function can be successfully inferred.
|
|
421
|
+
def f2(x) -> str:
|
|
422
|
+
return str(x)
|
|
423
|
+
|
|
424
|
+
status = t.add_column(c2_str_f2=t.c2.apply(f2))
|
|
425
|
+
assert status.num_excs == 0
|
|
426
|
+
|
|
427
|
+
# Test various validation failures.
|
|
428
|
+
|
|
429
|
+
def f3(x, y) -> str:
|
|
430
|
+
return f'{x}{y}'
|
|
431
|
+
|
|
432
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
433
|
+
t.c2.apply(f3) # Too many required parameters
|
|
434
|
+
assert str(exc_info.value) == 'Function `f3` has multiple required parameters.'
|
|
435
|
+
|
|
436
|
+
def f4() -> str:
|
|
437
|
+
return "pixeltable"
|
|
438
|
+
|
|
439
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
440
|
+
t.c2.apply(f4) # No positional parameters
|
|
441
|
+
assert str(exc_info.value) == 'Function `f4` has no positional parameters.'
|
|
442
|
+
|
|
443
|
+
def f5(**kwargs) -> str:
|
|
444
|
+
return ""
|
|
445
|
+
|
|
446
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
447
|
+
t.c2.apply(f5) # No positional parameters
|
|
448
|
+
assert str(exc_info.value) == 'Function `f5` has no positional parameters.'
|
|
449
|
+
|
|
450
|
+
# Ensure these varargs signatures are acceptable
|
|
451
|
+
|
|
452
|
+
def f6(x, **kwargs) -> str:
|
|
453
|
+
return x
|
|
454
|
+
|
|
455
|
+
t.c2.apply(f6)
|
|
456
|
+
|
|
457
|
+
def f7(x, *args) -> str:
|
|
458
|
+
return x
|
|
459
|
+
|
|
460
|
+
t.c2.apply(f7)
|
|
461
|
+
|
|
462
|
+
def f8(*args) -> str:
|
|
463
|
+
return ''
|
|
464
|
+
|
|
465
|
+
t.c2.apply(f8)
|
|
466
|
+
|
|
191
467
|
def test_select_list(self, img_tbl) -> None:
|
|
192
468
|
t = img_tbl
|
|
193
469
|
result = t[t.img].show(n=100)
|
|
194
470
|
_ = result._repr_html_()
|
|
195
|
-
df = t[t.img, udf_call(lambda img: img.rotate(60), ImageType(), tbl=t)]
|
|
196
|
-
_ = df.show(n=100)._repr_html_()
|
|
197
471
|
df = t[[t.img, t.img.rotate(60)]]
|
|
198
472
|
_ = df.show(n=100)._repr_html_()
|
|
199
473
|
|
|
200
|
-
with pytest.raises(
|
|
474
|
+
with pytest.raises(excs.Error):
|
|
201
475
|
_ = t[t.img.rotate]
|
|
202
476
|
|
|
203
477
|
def test_img_members(self, img_tbl) -> None:
|
|
204
478
|
t = img_tbl
|
|
479
|
+
# make sure the limit is applied in Python, not in the SELECT
|
|
480
|
+
result = t[t.img.height > 200][t.img].show(n=3)
|
|
481
|
+
assert len(result) == 3
|
|
205
482
|
result = t[t.img.crop((10, 10, 60, 60))].show(n=100)
|
|
206
483
|
result = t[t.img.crop((10, 10, 60, 60)).resize((100, 100))].show(n=100)
|
|
207
484
|
result = t[t.img.crop((10, 10, 60, 60)).resize((100, 100)).convert('L')].show(n=100)
|
|
@@ -210,20 +487,21 @@ class TestExprs:
|
|
|
210
487
|
_ = result._repr_html_()
|
|
211
488
|
|
|
212
489
|
def test_img_functions(self, img_tbl) -> None:
|
|
490
|
+
skip_test_if_not_installed('nos')
|
|
213
491
|
t = img_tbl
|
|
492
|
+
from pixeltable.functions.pil.image import resize
|
|
493
|
+
result = t[t.img.resize((224, 224))].show(0)
|
|
494
|
+
result = t[resize(t.img, (224, 224))].show(0)
|
|
214
495
|
result = t[blend(t.img, t.img.rotate(90), 0.5)].show(100)
|
|
215
496
|
print(result)
|
|
216
|
-
|
|
497
|
+
from pixeltable.functions.nos.image_embedding import openai_clip
|
|
498
|
+
result = t[openai_clip(t.img.resize((224, 224)))].show(10)
|
|
217
499
|
print(result)
|
|
218
500
|
_ = result._repr_html_()
|
|
219
501
|
_ = t.img.entropy() > 1
|
|
220
|
-
_ = _.extract_sql_predicate()
|
|
221
502
|
_ = (t.img.entropy() > 1) & (t.split == 'train')
|
|
222
|
-
_ = _.extract_sql_predicate()
|
|
223
503
|
_ = (t.img.entropy() > 1) & (t.split == 'train') & (t.split == 'val')
|
|
224
|
-
_ = _.extract_sql_predicate()
|
|
225
504
|
_ = (t.split == 'train') & (t.img.entropy() > 1) & (t.split == 'val') & (t.img.entropy() < 0)
|
|
226
|
-
_ = _.extract_sql_predicate()
|
|
227
505
|
_ = t[(t.split == 'train') & (t.category == 'n03445777')][t.img].show()
|
|
228
506
|
print(_)
|
|
229
507
|
result = t[t.img.width > 1].show()
|
|
@@ -235,32 +513,33 @@ class TestExprs:
|
|
|
235
513
|
][t.img, t.split].show()
|
|
236
514
|
print(result)
|
|
237
515
|
|
|
238
|
-
def test_categoricals_map(self, img_tbl) -> None:
|
|
239
|
-
t = img_tbl
|
|
240
|
-
m = t[t.category].categorical_map()
|
|
241
|
-
_ = t[dict_map(t.category, m)].show()
|
|
242
|
-
print(_)
|
|
243
|
-
|
|
244
516
|
def test_similarity(self, indexed_img_tbl: catalog.Table) -> None:
|
|
517
|
+
skip_test_if_not_installed('nos')
|
|
245
518
|
t = indexed_img_tbl
|
|
246
519
|
_ = t.show(30)
|
|
247
|
-
probe = t
|
|
520
|
+
probe = t.select(t.img, t.category).show(1)
|
|
248
521
|
img = probe[0, 0]
|
|
249
|
-
result = t
|
|
522
|
+
result = t.where(t.img.nearest(img)).show(10)
|
|
250
523
|
assert len(result) == 10
|
|
251
524
|
# nearest() with one SQL predicate and one Python predicate
|
|
252
525
|
result = t[t.img.nearest(img) & (t.category == probe[0, 1]) & (t.img.width > 1)].show(10)
|
|
253
526
|
# TODO: figure out how to verify results
|
|
254
|
-
#assert len(result) == 3
|
|
255
527
|
|
|
256
|
-
|
|
528
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
529
|
+
_ = t[t.img.nearest(img)].order_by(t.category).show()
|
|
530
|
+
assert 'cannot be used in conjunction with' in str(exc_info.value)
|
|
531
|
+
|
|
532
|
+
result = t[t.img.nearest('musical instrument')].show(10)
|
|
257
533
|
assert len(result) == 10
|
|
258
534
|
# matches() with one SQL predicate and one Python predicate
|
|
259
535
|
french_horn_category = 'n03394916'
|
|
260
536
|
result = t[
|
|
261
|
-
t.img.
|
|
537
|
+
t.img.nearest('musical instrument') & (t.category == french_horn_category) & (t.img.width > 1)
|
|
262
538
|
].show(10)
|
|
263
|
-
|
|
539
|
+
|
|
540
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
541
|
+
_ = t[t.img.nearest(5)].show()
|
|
542
|
+
assert 'requires' in str(exc_info.value)
|
|
264
543
|
|
|
265
544
|
# TODO: this doesn't work when combined with test_similarity(), for some reason the data table for img_tbl
|
|
266
545
|
# doesn't get created; why?
|
|
@@ -269,48 +548,47 @@ class TestExprs:
|
|
|
269
548
|
probe = t[t.img].show(1)
|
|
270
549
|
img = probe[0, 0]
|
|
271
550
|
|
|
272
|
-
with pytest.raises(
|
|
551
|
+
with pytest.raises(excs.Error):
|
|
273
552
|
_ = t[t.img.nearest(img)].show(10)
|
|
274
|
-
with pytest.raises(
|
|
275
|
-
_ = t[t.img.
|
|
276
|
-
|
|
277
|
-
def
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
(t.c2 > 5) | (t.c1 == 'test'),
|
|
296
|
-
t.c7['*'].f5 >> [R[3], R[2], R[1], R[0]],
|
|
297
|
-
t.c8[0, 1:],
|
|
298
|
-
utils.sum_uda(t.c2).window(partition_by=t.c4, order_by=t.c3),
|
|
299
|
-
]
|
|
300
|
-
for e in test_exprs:
|
|
553
|
+
with pytest.raises(excs.Error):
|
|
554
|
+
_ = t[t.img.nearest('musical instrument')].show(10)
|
|
555
|
+
|
|
556
|
+
def test_ids(
|
|
557
|
+
self, test_tbl: catalog.Table, test_tbl_exprs: List[exprs.Expr],
|
|
558
|
+
img_tbl: catalog.Table, img_tbl_exprs: List[exprs.Expr]
|
|
559
|
+
) -> None:
|
|
560
|
+
d: Dict[int, exprs.Expr] = {}
|
|
561
|
+
for e in test_tbl_exprs:
|
|
562
|
+
assert e.id is not None
|
|
563
|
+
d[e.id] = e
|
|
564
|
+
for e in img_tbl_exprs:
|
|
565
|
+
assert e.id is not None
|
|
566
|
+
d[e.id] = e
|
|
567
|
+
assert len(d) == len(test_tbl_exprs) + len(img_tbl_exprs)
|
|
568
|
+
|
|
569
|
+
def test_serialization(
|
|
570
|
+
self, test_tbl_exprs: List[exprs.Expr], img_tbl_exprs: List[exprs.Expr]
|
|
571
|
+
) -> None:
|
|
572
|
+
"""Test as_dict()/from_dict() (via serialize()/deserialize()) for all exprs."""
|
|
573
|
+
for e in test_tbl_exprs:
|
|
301
574
|
e_serialized = e.serialize()
|
|
302
|
-
e_deserialized = Expr.deserialize(e_serialized
|
|
575
|
+
e_deserialized = Expr.deserialize(e_serialized)
|
|
303
576
|
assert e.equals(e_deserialized)
|
|
304
577
|
|
|
305
|
-
|
|
306
|
-
img_t.img.width,
|
|
307
|
-
img_t.img.rotate(90),
|
|
308
|
-
]
|
|
309
|
-
for e in img_test_exprs:
|
|
578
|
+
for e in img_tbl_exprs:
|
|
310
579
|
e_serialized = e.serialize()
|
|
311
|
-
e_deserialized = Expr.deserialize(e_serialized
|
|
580
|
+
e_deserialized = Expr.deserialize(e_serialized)
|
|
312
581
|
assert e.equals(e_deserialized)
|
|
313
582
|
|
|
583
|
+
def test_print(self, test_tbl_exprs: List[exprs.Expr], img_tbl_exprs: List[exprs.Expr]) -> None:
|
|
584
|
+
_ = func.FunctionRegistry.get().module_fns
|
|
585
|
+
for e in test_tbl_exprs:
|
|
586
|
+
_ = str(e)
|
|
587
|
+
print(_)
|
|
588
|
+
for e in img_tbl_exprs:
|
|
589
|
+
_ = str(e)
|
|
590
|
+
print(_)
|
|
591
|
+
|
|
314
592
|
def test_subexprs(self, img_tbl: catalog.Table) -> None:
|
|
315
593
|
t = img_tbl
|
|
316
594
|
e = t.img
|
|
@@ -318,31 +596,210 @@ class TestExprs:
|
|
|
318
596
|
assert len(subexprs) == 1
|
|
319
597
|
e = t.img.rotate(90).resize((224, 224))
|
|
320
598
|
subexprs = [s for s in e.subexprs()]
|
|
321
|
-
assert len(subexprs) ==
|
|
322
|
-
subexprs = [s for s in e.subexprs(
|
|
599
|
+
assert len(subexprs) == 4
|
|
600
|
+
subexprs = [s for s in e.subexprs(expr_class=ColumnRef)]
|
|
323
601
|
assert len(subexprs) == 1
|
|
324
602
|
assert t.img.equals(subexprs[0])
|
|
325
603
|
|
|
326
|
-
def test_window_fns(self,
|
|
327
|
-
|
|
604
|
+
def test_window_fns(self, test_client: pxt.Client, test_tbl: catalog.Table) -> None:
|
|
605
|
+
cl = test_client
|
|
328
606
|
t = test_tbl
|
|
329
|
-
_ = t
|
|
330
|
-
|
|
607
|
+
_ = t.select(sum(t.c2, group_by=t.c4, order_by=t.c3)).show(100)
|
|
608
|
+
|
|
609
|
+
# conflicting ordering requirements
|
|
610
|
+
with pytest.raises(excs.Error):
|
|
611
|
+
_ = t.select(sum(t.c2, group_by=t.c4, order_by=t.c3), sum(t.c2, group_by=t.c3, order_by=t.c4)).show(100)
|
|
612
|
+
with pytest.raises(excs.Error):
|
|
613
|
+
_ = t.select(sum(t.c2, group_by=t.c4, order_by=t.c3), sum(t.c2, group_by=t.c3, order_by=t.c4)).show(100)
|
|
614
|
+
|
|
331
615
|
# backfill works
|
|
332
|
-
t.add_column(
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
616
|
+
t.add_column(c9=sum(t.c2, group_by=t.c4, order_by=t.c3))
|
|
617
|
+
_ = t.c9.col.has_window_fn_call()
|
|
618
|
+
|
|
619
|
+
# ordering conflict between frame extraction and window fn
|
|
620
|
+
base_t = cl.create_table('videos', {'video': VideoType(), 'c2': IntType(nullable=False)})
|
|
621
|
+
args = {'video': base_t.video, 'fps': 0}
|
|
622
|
+
v = cl.create_view('frame_view', base_t, iterator_class=FrameIterator, iterator_args=args)
|
|
623
|
+
# compatible ordering
|
|
624
|
+
_ = v.select(v.frame, sum(v.frame_idx, group_by=base_t, order_by=v.pos)).show(100)
|
|
625
|
+
with pytest.raises(excs.Error):
|
|
626
|
+
# incompatible ordering
|
|
627
|
+
_ = v.select(v.frame, sum(v.c2, order_by=base_t, group_by=v.pos)).show(100)
|
|
628
|
+
|
|
629
|
+
schema = {
|
|
630
|
+
'c2': IntType(nullable=False),
|
|
631
|
+
'c3': FloatType(nullable=False),
|
|
632
|
+
'c4': BoolType(nullable=False),
|
|
633
|
+
}
|
|
634
|
+
new_t = cl.create_table('insert_test', schema=schema)
|
|
635
|
+
new_t.add_column(c2_sum=sum(new_t.c2, group_by=new_t.c4, order_by=new_t.c3))
|
|
636
|
+
rows = list(t.select(t.c2, t.c4, t.c3).collect())
|
|
637
|
+
new_t.insert(rows)
|
|
342
638
|
_ = new_t.show(0)
|
|
343
|
-
print(_)
|
|
344
639
|
|
|
345
640
|
def test_aggregates(self, test_tbl: catalog.Table) -> None:
|
|
346
641
|
t = test_tbl
|
|
347
|
-
_ = t[t.c2 % 2, sum(t.c2), count(t.c2), sum(t.c2) + count(t.c2), sum(t.c2) +
|
|
348
|
-
|
|
642
|
+
_ = t[t.c2 % 2, sum(t.c2), count(t.c2), sum(t.c2) + count(t.c2), sum(t.c2) + (t.c2 % 2)]\
|
|
643
|
+
.group_by(t.c2 % 2).show()
|
|
644
|
+
|
|
645
|
+
# check that aggregates don't show up in the wrong places
|
|
646
|
+
with pytest.raises(excs.Error):
|
|
647
|
+
# aggregate in where clause
|
|
648
|
+
_ = t[sum(t.c2) > 0][sum(t.c2)].group_by(t.c2 % 2).show()
|
|
649
|
+
with pytest.raises(excs.Error):
|
|
650
|
+
# aggregate in group_by clause
|
|
651
|
+
_ = t[sum(t.c2)].group_by(sum(t.c2)).show()
|
|
652
|
+
with pytest.raises(excs.Error):
|
|
653
|
+
# mixing aggregates and non-aggregates
|
|
654
|
+
_ = t[sum(t.c2) + t.c2].group_by(t.c2 % 2).show()
|
|
655
|
+
with pytest.raises(excs.Error):
|
|
656
|
+
# nested aggregates
|
|
657
|
+
_ = t[sum(count(t.c2))].group_by(t.c2 % 2).show()
|
|
658
|
+
|
|
659
|
+
def test_udas(self, test_tbl: catalog.Table) -> None:
|
|
660
|
+
t = test_tbl
|
|
661
|
+
|
|
662
|
+
@pxt.uda(
|
|
663
|
+
name='window_agg', init_types=[IntType()], update_types=[IntType()], value_type=IntType(),
|
|
664
|
+
allows_window=True, requires_order_by=False)
|
|
665
|
+
class WindowAgg:
|
|
666
|
+
def __init__(self, val: int = 0):
|
|
667
|
+
self.val = val
|
|
668
|
+
def update(self, ignore: int) -> None:
|
|
669
|
+
pass
|
|
670
|
+
def value(self) -> int:
|
|
671
|
+
return self.val
|
|
672
|
+
|
|
673
|
+
@pxt.uda(
|
|
674
|
+
name='ordered_agg', init_types=[IntType()], update_types=[IntType()], value_type=IntType(),
|
|
675
|
+
requires_order_by=True, allows_window=True)
|
|
676
|
+
class WindowAgg:
|
|
677
|
+
def __init__(self, val: int = 0):
|
|
678
|
+
self.val = val
|
|
679
|
+
def update(self, i: int) -> None:
|
|
680
|
+
pass
|
|
681
|
+
def value(self) -> int:
|
|
682
|
+
return self.val
|
|
683
|
+
|
|
684
|
+
@pxt.uda(
|
|
685
|
+
name='std_agg', init_types=[IntType()], update_types=[IntType()], value_type=IntType(),
|
|
686
|
+
requires_order_by=False, allows_window=False)
|
|
687
|
+
class StdAgg:
|
|
688
|
+
def __init__(self, val: int = 0):
|
|
689
|
+
self.val = val
|
|
690
|
+
def update(self, i: int) -> None:
|
|
691
|
+
pass
|
|
692
|
+
def value(self) -> int:
|
|
693
|
+
return self.val
|
|
694
|
+
|
|
695
|
+
# init arg is passed along
|
|
696
|
+
assert t.select(out=window_agg(t.c2, order_by=t.c2)).collect()[0]['out'] == 0
|
|
697
|
+
assert t.select(out=window_agg(t.c2, val=1, order_by=t.c2)).collect()[0]['out'] == 1
|
|
698
|
+
|
|
699
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
700
|
+
_ = t.select(window_agg(t.c2, val=t.c2, order_by=t.c2)).collect()
|
|
701
|
+
assert 'needs to be a constant' in str(exc_info.value)
|
|
702
|
+
|
|
703
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
704
|
+
# ordering expression not a pixeltable expr
|
|
705
|
+
_ = t.select(ordered_agg(1, t.c2)).collect()
|
|
706
|
+
assert 'but instead is a' in str(exc_info.value).lower()
|
|
707
|
+
|
|
708
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
709
|
+
# explicit order_by
|
|
710
|
+
_ = t.select(ordered_agg(t.c2, order_by=t.c2)).collect()
|
|
711
|
+
assert 'order_by invalid' in str(exc_info.value).lower()
|
|
712
|
+
|
|
713
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
714
|
+
# order_by for non-window function
|
|
715
|
+
_ = t.select(std_agg(t.c2, order_by=t.c2)).collect()
|
|
716
|
+
assert 'does not allow windows' in str(exc_info.value).lower()
|
|
717
|
+
|
|
718
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
719
|
+
# group_by for non-window function
|
|
720
|
+
_ = t.select(std_agg(t.c2, group_by=t.c4)).collect()
|
|
721
|
+
assert 'group_by invalid' in str(exc_info.value).lower()
|
|
722
|
+
|
|
723
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
724
|
+
# missing init type
|
|
725
|
+
@pxt.uda(update_types=[IntType()], value_type=IntType())
|
|
726
|
+
class WindowAgg:
|
|
727
|
+
def __init__(self, val: int = 0):
|
|
728
|
+
self.val = val
|
|
729
|
+
def update(self, ignore: int) -> None:
|
|
730
|
+
pass
|
|
731
|
+
def value(self) -> int:
|
|
732
|
+
return self.val
|
|
733
|
+
assert 'init_types must be a list of' in str(exc_info.value)
|
|
734
|
+
|
|
735
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
736
|
+
# missing update parameter
|
|
737
|
+
@pxt.uda(init_types=[IntType()], update_types=[], value_type=IntType())
|
|
738
|
+
class WindowAgg:
|
|
739
|
+
def __init__(self, val: int = 0):
|
|
740
|
+
self.val = val
|
|
741
|
+
def update(self) -> None:
|
|
742
|
+
pass
|
|
743
|
+
def value(self) -> int:
|
|
744
|
+
return self.val
|
|
745
|
+
assert 'must have at least one parameter' in str(exc_info.value)
|
|
746
|
+
|
|
747
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
748
|
+
# missing update type
|
|
749
|
+
@pxt.uda(init_types=[IntType()], update_types=[IntType()], value_type=IntType())
|
|
750
|
+
class WindowAgg:
|
|
751
|
+
def __init__(self, val: int = 0):
|
|
752
|
+
self.val = val
|
|
753
|
+
def update(self, i1: int, i2: int) -> None:
|
|
754
|
+
pass
|
|
755
|
+
def value(self) -> int:
|
|
756
|
+
return self.val
|
|
757
|
+
assert 'update_types must be a list of' in str(exc_info.value)
|
|
758
|
+
|
|
759
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
760
|
+
# duplicate parameter names
|
|
761
|
+
@pxt.uda(init_types=[IntType()], update_types=[IntType()], value_type=IntType())
|
|
762
|
+
class WindowAgg:
|
|
763
|
+
def __init__(self, val: int = 0):
|
|
764
|
+
self.val = val
|
|
765
|
+
def update(self, val: int) -> None:
|
|
766
|
+
pass
|
|
767
|
+
def value(self) -> int:
|
|
768
|
+
return self.val
|
|
769
|
+
assert 'cannot have parameters with the same name: val' in str(exc_info.value)
|
|
770
|
+
|
|
771
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
772
|
+
# invalid name
|
|
773
|
+
@pxt.uda(name='not an identifier', init_types=[IntType()], update_types=[IntType()], value_type=IntType())
|
|
774
|
+
class WindowAgg:
|
|
775
|
+
def __init__(self, val: int = 0):
|
|
776
|
+
self.val = val
|
|
777
|
+
def update(self, i1: int, i2: int) -> None:
|
|
778
|
+
pass
|
|
779
|
+
def value(self) -> int:
|
|
780
|
+
return self.val
|
|
781
|
+
assert 'invalid name' in str(exc_info.value).lower()
|
|
782
|
+
|
|
783
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
784
|
+
# reserved parameter name
|
|
785
|
+
@pxt.uda(init_types=[IntType()], update_types=[IntType()], value_type=IntType())
|
|
786
|
+
class WindowAgg:
|
|
787
|
+
def __init__(self, val: int = 0):
|
|
788
|
+
self.val = val
|
|
789
|
+
def update(self, order_by: int) -> None:
|
|
790
|
+
pass
|
|
791
|
+
def value(self) -> int:
|
|
792
|
+
return self.val
|
|
793
|
+
assert 'order_by is reserved' in str(exc_info.value).lower()
|
|
794
|
+
|
|
795
|
+
with pytest.raises(excs.Error) as exc_info:
|
|
796
|
+
# reserved parameter name
|
|
797
|
+
@pxt.uda(init_types=[IntType()], update_types=[IntType()], value_type=IntType())
|
|
798
|
+
class WindowAgg:
|
|
799
|
+
def __init__(self, val: int = 0):
|
|
800
|
+
self.val = val
|
|
801
|
+
def update(self, group_by: int) -> None:
|
|
802
|
+
pass
|
|
803
|
+
def value(self) -> int:
|
|
804
|
+
return self.val
|
|
805
|
+
assert 'group_by is reserved' in str(exc_info.value).lower()
|