pixeltable 0.1.0__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (147) hide show
  1. pixeltable/__init__.py +34 -6
  2. pixeltable/catalog/__init__.py +13 -0
  3. pixeltable/catalog/catalog.py +159 -0
  4. pixeltable/catalog/column.py +200 -0
  5. pixeltable/catalog/dir.py +32 -0
  6. pixeltable/catalog/globals.py +33 -0
  7. pixeltable/catalog/insertable_table.py +191 -0
  8. pixeltable/catalog/named_function.py +36 -0
  9. pixeltable/catalog/path.py +58 -0
  10. pixeltable/catalog/path_dict.py +139 -0
  11. pixeltable/catalog/schema_object.py +39 -0
  12. pixeltable/catalog/table.py +581 -0
  13. pixeltable/catalog/table_version.py +749 -0
  14. pixeltable/catalog/table_version_path.py +133 -0
  15. pixeltable/catalog/view.py +203 -0
  16. pixeltable/client.py +590 -30
  17. pixeltable/dataframe.py +540 -349
  18. pixeltable/env.py +359 -45
  19. pixeltable/exceptions.py +12 -21
  20. pixeltable/exec/__init__.py +9 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +116 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +95 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +69 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +225 -0
  31. pixeltable/exprs/__init__.py +24 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +105 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +195 -0
  39. pixeltable/exprs/expr.py +586 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +380 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +115 -0
  44. pixeltable/exprs/image_similarity_predicate.py +58 -0
  45. pixeltable/exprs/inline_array.py +107 -0
  46. pixeltable/exprs/inline_dict.py +101 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +54 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +355 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/type_cast.py +53 -0
  56. pixeltable/exprs/variable.py +45 -0
  57. pixeltable/func/__init__.py +9 -0
  58. pixeltable/func/aggregate_function.py +194 -0
  59. pixeltable/func/batched_function.py +53 -0
  60. pixeltable/func/callable_function.py +69 -0
  61. pixeltable/func/expr_template_function.py +82 -0
  62. pixeltable/func/function.py +110 -0
  63. pixeltable/func/function_registry.py +227 -0
  64. pixeltable/func/globals.py +36 -0
  65. pixeltable/func/nos_function.py +202 -0
  66. pixeltable/func/signature.py +166 -0
  67. pixeltable/func/udf.py +163 -0
  68. pixeltable/functions/__init__.py +52 -103
  69. pixeltable/functions/eval.py +216 -0
  70. pixeltable/functions/fireworks.py +34 -0
  71. pixeltable/functions/huggingface.py +120 -0
  72. pixeltable/functions/image.py +16 -0
  73. pixeltable/functions/openai.py +256 -0
  74. pixeltable/functions/pil/image.py +148 -7
  75. pixeltable/functions/string.py +13 -0
  76. pixeltable/functions/together.py +122 -0
  77. pixeltable/functions/util.py +41 -0
  78. pixeltable/functions/video.py +62 -0
  79. pixeltable/iterators/__init__.py +3 -0
  80. pixeltable/iterators/base.py +48 -0
  81. pixeltable/iterators/document.py +311 -0
  82. pixeltable/iterators/video.py +89 -0
  83. pixeltable/metadata/__init__.py +54 -0
  84. pixeltable/metadata/converters/convert_10.py +18 -0
  85. pixeltable/metadata/schema.py +211 -0
  86. pixeltable/plan.py +656 -0
  87. pixeltable/store.py +418 -182
  88. pixeltable/tests/conftest.py +146 -88
  89. pixeltable/tests/functions/test_fireworks.py +42 -0
  90. pixeltable/tests/functions/test_functions.py +60 -0
  91. pixeltable/tests/functions/test_huggingface.py +158 -0
  92. pixeltable/tests/functions/test_openai.py +152 -0
  93. pixeltable/tests/functions/test_together.py +111 -0
  94. pixeltable/tests/test_audio.py +65 -0
  95. pixeltable/tests/test_catalog.py +27 -0
  96. pixeltable/tests/test_client.py +14 -14
  97. pixeltable/tests/test_component_view.py +370 -0
  98. pixeltable/tests/test_dataframe.py +439 -0
  99. pixeltable/tests/test_dirs.py +78 -62
  100. pixeltable/tests/test_document.py +120 -0
  101. pixeltable/tests/test_exprs.py +592 -135
  102. pixeltable/tests/test_function.py +297 -67
  103. pixeltable/tests/test_migration.py +43 -0
  104. pixeltable/tests/test_nos.py +54 -0
  105. pixeltable/tests/test_snapshot.py +208 -0
  106. pixeltable/tests/test_table.py +1195 -263
  107. pixeltable/tests/test_transactional_directory.py +42 -0
  108. pixeltable/tests/test_types.py +5 -11
  109. pixeltable/tests/test_video.py +151 -34
  110. pixeltable/tests/test_view.py +530 -0
  111. pixeltable/tests/utils.py +320 -45
  112. pixeltable/tool/create_test_db_dump.py +149 -0
  113. pixeltable/tool/create_test_video.py +81 -0
  114. pixeltable/type_system.py +445 -124
  115. pixeltable/utils/__init__.py +17 -46
  116. pixeltable/utils/arrow.py +98 -0
  117. pixeltable/utils/clip.py +12 -15
  118. pixeltable/utils/coco.py +136 -0
  119. pixeltable/utils/documents.py +39 -0
  120. pixeltable/utils/filecache.py +195 -0
  121. pixeltable/utils/help.py +11 -0
  122. pixeltable/utils/hf_datasets.py +157 -0
  123. pixeltable/utils/media_store.py +76 -0
  124. pixeltable/utils/parquet.py +167 -0
  125. pixeltable/utils/pytorch.py +91 -0
  126. pixeltable/utils/s3.py +13 -0
  127. pixeltable/utils/sql.py +17 -0
  128. pixeltable/utils/transactional_directory.py +35 -0
  129. pixeltable-0.2.4.dist-info/LICENSE +18 -0
  130. pixeltable-0.2.4.dist-info/METADATA +127 -0
  131. pixeltable-0.2.4.dist-info/RECORD +132 -0
  132. {pixeltable-0.1.0.dist-info → pixeltable-0.2.4.dist-info}/WHEEL +1 -1
  133. pixeltable/catalog.py +0 -1421
  134. pixeltable/exprs.py +0 -1745
  135. pixeltable/function.py +0 -269
  136. pixeltable/functions/clip.py +0 -10
  137. pixeltable/functions/pil/__init__.py +0 -23
  138. pixeltable/functions/tf.py +0 -21
  139. pixeltable/index.py +0 -57
  140. pixeltable/tests/test_dict.py +0 -24
  141. pixeltable/tests/test_functions.py +0 -11
  142. pixeltable/tests/test_tf.py +0 -69
  143. pixeltable/tf.py +0 -33
  144. pixeltable/utils/tf.py +0 -33
  145. pixeltable/utils/video.py +0 -32
  146. pixeltable-0.1.0.dist-info/METADATA +0 -34
  147. pixeltable-0.1.0.dist-info/RECORD +0 -36
@@ -0,0 +1,439 @@
1
+ import datetime
2
+ import pickle
3
+ from pathlib import Path
4
+ from typing import Any, Dict
5
+
6
+ import bs4
7
+ import numpy as np
8
+ import pytest
9
+ import requests
10
+
11
+ import pixeltable as pxt
12
+ from pixeltable import catalog
13
+ from pixeltable import exceptions as excs
14
+ from pixeltable.iterators import FrameIterator
15
+ from pixeltable.tests.utils import get_video_files, get_audio_files, skip_test_if_not_installed
16
+
17
+
18
+ class TestDataFrame:
19
+ def test_select_where(self, test_tbl: catalog.Table) -> None:
20
+ t = test_tbl
21
+ res1 = t[t.c1, t.c2, t.c3].show(0)
22
+ res2 = t.select(t.c1, t.c2, t.c3).show(0)
23
+ assert res1 == res2
24
+
25
+ res1 = t[t.c2 < 10][t.c1, t.c2, t.c3].show(0)
26
+ res2 = t.where(t.c2 < 10).select(t.c1, t.c2, t.c3).show(0)
27
+ assert res1 == res2
28
+
29
+ res3 = t.where(t.c2 < 10).select(c1=t.c1, c2=t.c2, c3=t.c3).show(0)
30
+ assert res1 == res3
31
+
32
+ res4 = t.where(t.c2 < 10).select(t.c1, c2=t.c2, c3=t.c3).show(0)
33
+ assert res1 == res4
34
+
35
+ _ = t.where(t.c2 < 10).select(t.c2, t.c2).show(0) # repeated name no error
36
+
37
+ # duplicate select list
38
+ with pytest.raises(excs.Error) as exc_info:
39
+ _ = t.select(t.c1).select(t.c2).show(0)
40
+ assert 'already specified' in str(exc_info.value)
41
+
42
+ # invalid expr in select list: Callable is not a valid literal
43
+ with pytest.raises(TypeError) as exc_info:
44
+ _ = t.select(datetime.datetime.now).show(0)
45
+ assert 'Not a valid literal' in str(exc_info.value)
46
+
47
+ # catch invalid name in select list from user input
48
+ # only check stuff that's not caught by python kwargs checker
49
+ with pytest.raises(excs.Error) as exc_info:
50
+ _ = t.select(t.c1, **{'c2-1': t.c2}).show(0)
51
+ assert 'Invalid name' in str(exc_info.value)
52
+
53
+ with pytest.raises(excs.Error) as exc_info:
54
+ _ = t.select(t.c1, **{'': t.c2}).show(0)
55
+ assert 'Invalid name' in str(exc_info.value)
56
+
57
+ with pytest.raises(excs.Error) as exc_info:
58
+ _ = t.select(t.c1, **{'foo.bar': t.c2}).show(0)
59
+ assert 'Invalid name' in str(exc_info.value)
60
+
61
+ with pytest.raises(excs.Error) as exc_info:
62
+ _ = t.select(t.c1, _c3=t.c2).show(0)
63
+ assert 'Invalid name' in str(exc_info.value)
64
+
65
+ # catch repeated name from user input
66
+ with pytest.raises(excs.Error) as exc_info:
67
+ _ = t.select(t.c2, c2=t.c1).show(0)
68
+ assert 'Repeated column name' in str(exc_info.value)
69
+
70
+ with pytest.raises(excs.Error) as exc_info:
71
+ _ = t.select(t.c2+1, col_0=t.c2).show(0)
72
+ assert 'Repeated column name' in str(exc_info.value)
73
+
74
+ def test_result_set_iterator(self, test_tbl: catalog.Table) -> None:
75
+ t = test_tbl
76
+ res = t.select(t.c1, t.c2, t.c3).collect()
77
+ pd_df = res.to_pandas()
78
+
79
+ def check_row(row: Dict[str, Any], idx: int) -> None:
80
+ assert len(row) == 3
81
+ assert 'c1' in row
82
+ assert row['c1'] == pd_df['c1'][idx]
83
+ assert 'c2' in row
84
+ assert row['c2'] == pd_df['c2'][idx]
85
+ assert 'c3' in row
86
+ assert row['c3'] == pd_df['c3'][idx]
87
+
88
+ # row iteration
89
+ for idx, row in enumerate(res):
90
+ check_row(row, idx)
91
+
92
+ # row access
93
+ row = res[0]
94
+ check_row(row, 0)
95
+
96
+ # column access
97
+ col_values = res['c2']
98
+ assert col_values == pd_df['c2'].values.tolist()
99
+
100
+ # cell access
101
+ assert res[0, 'c2'] == pd_df['c2'][0]
102
+ assert res[0, 'c2'] == res[0, 1]
103
+
104
+ with pytest.raises(excs.Error) as exc_info:
105
+ _ = res['does_not_exist']
106
+ assert 'Invalid column name' in str(exc_info.value)
107
+
108
+ with pytest.raises(excs.Error) as exc_info:
109
+ _ = res[0, 'does_not_exist']
110
+ assert 'Invalid column name' in str(exc_info.value)
111
+
112
+ with pytest.raises(excs.Error) as exc_info:
113
+ _ = res[0, 0, 0]
114
+ assert 'Bad index' in str(exc_info.value)
115
+
116
+ with pytest.raises(excs.Error) as exc_info:
117
+ _ = res['c2', 0]
118
+ assert 'Bad index' in str(exc_info.value)
119
+
120
+ def test_order_by(self, test_tbl: catalog.Table) -> None:
121
+ t = test_tbl
122
+ res = t.select(t.c4, t.c2).order_by(t.c4).order_by(t.c2, asc=False).show(0)
123
+
124
+ # invalid expr in order_by()
125
+ with pytest.raises(excs.Error) as exc_info:
126
+ _ = t.order_by(datetime.datetime.now()).show(0)
127
+ assert 'Invalid expression' in str(exc_info.value)
128
+
129
+ def test_head_tail(self, test_tbl: catalog.Table) -> None:
130
+ t = test_tbl
131
+ res = t.head(10).to_pandas()
132
+ assert np.all(res.c2 == list(range(10)))
133
+ # Where is applied
134
+ res = t.where(t.c2 > 9).head(10).to_pandas()
135
+ assert np.all(res.c2 == list(range(10, 20)))
136
+ # order_by() is an error
137
+ with pytest.raises(excs.Error) as exc_info:
138
+ _ = t.order_by(t.c2).head(10)
139
+ assert 'cannot be used with order_by' in str(exc_info.value)
140
+
141
+ res = t.tail().to_pandas()
142
+ assert np.all(res.c2 == list(range(90, 100)))
143
+ res = t.where(t.c2 < 90).tail().to_pandas()
144
+ assert np.all(res.c2 == list(range(80, 90)))
145
+ # order_by() is an error
146
+ with pytest.raises(excs.Error) as exc_info:
147
+ _ = t.order_by(t.c2).tail(10)
148
+ assert 'cannot be used with order_by' in str(exc_info.value)
149
+
150
+ def test_describe(self, test_tbl: catalog.Table) -> None:
151
+ t = test_tbl
152
+ df = t.select(t.c1).where(t.c2 < 10).limit(10)
153
+ df.describe()
154
+
155
+ # TODO: how do you check the output of these?
156
+ _ = df.__repr__()
157
+ _ = df._repr_html_()
158
+
159
+ def test_count(self, test_tbl: catalog.Table, indexed_img_tbl: catalog.Table) -> None:
160
+ skip_test_if_not_installed('nos')
161
+ t = test_tbl
162
+ cnt = t.count()
163
+ assert cnt == 100
164
+
165
+ cnt = t.where(t.c2 < 10).count()
166
+ assert cnt == 10
167
+
168
+ # count() doesn't work with similarity search
169
+ t = indexed_img_tbl
170
+ probe = t.select(t.img).show(1)
171
+ img = probe[0, 0]
172
+ with pytest.raises(excs.Error):
173
+ _ = t.where(t.img.nearest(img)).count()
174
+ with pytest.raises(excs.Error):
175
+ _ = t.where(t.img.nearest('car')).count()
176
+
177
+ # for now, count() doesn't work with non-SQL Where clauses
178
+ with pytest.raises(excs.Error):
179
+ _ = t.where(t.img.width > 100).count()
180
+
181
+ def test_select_literal(self, test_tbl: catalog.Table) -> None:
182
+ t = test_tbl
183
+ res = t.select(1.0).where(t.c2 < 10).collect()
184
+ assert res[res.column_names()[0]] == [1.0] * 10
185
+
186
+ # TODO This test doesn't work on Windows due to reliance on the structure of file URLs
187
+ @pytest.mark.skip('Test is not portable')
188
+ def test_html_media_url(self, test_client: pxt.Client) -> None:
189
+ tab = test_client.create_table('test_html_repr', {'video': pxt.VideoType(), 'audio': pxt.AudioType()})
190
+ status = tab.insert(video=get_video_files()[0], audio=get_audio_files()[0])
191
+ assert status.num_rows == 1
192
+ assert status.num_excs == 0
193
+
194
+ res = tab.select(tab.video, tab.audio).collect()
195
+ doc = bs4.BeautifulSoup(res._repr_html_(), features='html.parser')
196
+ video_tags = doc.find_all('video')
197
+ assert len(video_tags) == 1
198
+ audio_tags = doc.find_all('audio')
199
+ assert len(audio_tags) == 1
200
+
201
+ # get the source elements and test their src attributes
202
+ for tag in video_tags + audio_tags:
203
+ sources = tag.find_all('source')
204
+ assert len(sources) == 1
205
+ for src in sources:
206
+ response = requests.get(src['src'])
207
+ assert response.status_code == 200
208
+
209
+ def test_to_pytorch_dataset(self, all_datatypes_tbl: catalog.Table):
210
+ """ tests all types are handled correctly in this conversion
211
+ """
212
+ skip_test_if_not_installed('torch')
213
+ import torch
214
+
215
+ t = all_datatypes_tbl
216
+ df = t.where(t.row_id < 1)
217
+ assert df.count() > 0
218
+ ds = df.to_pytorch_dataset()
219
+ type_dict = dict(zip(df.get_column_names(),df.get_column_types()))
220
+ for tup in ds:
221
+ for col in df.get_column_names():
222
+ assert col in tup
223
+
224
+ arrval = tup['c_array']
225
+ assert isinstance(arrval, np.ndarray)
226
+ col_type = type_dict['c_array']
227
+ assert arrval.dtype == col_type.numpy_dtype()
228
+ assert arrval.shape == col_type.shape
229
+ assert arrval.dtype == np.float32
230
+ assert arrval.flags["WRITEABLE"], 'required by pytorch collate function'
231
+
232
+ assert isinstance(tup['c_bool'], bool)
233
+ assert isinstance(tup['c_int'], int)
234
+ assert isinstance(tup['c_float'], float)
235
+ assert isinstance(tup['c_timestamp'], float)
236
+ assert torch.is_tensor(tup['c_image'])
237
+ assert isinstance(tup['c_video'], str)
238
+ assert isinstance(tup['c_json'], dict)
239
+
240
+ def test_to_pytorch_image_format(self, all_datatypes_tbl: catalog.Table) -> None:
241
+ """ tests the image_format parameter is honored
242
+ """
243
+ skip_test_if_not_installed('torch')
244
+ import torch
245
+ import torchvision.transforms as T
246
+
247
+ W, H = 220, 224 # make different from each other
248
+ t = all_datatypes_tbl
249
+ df = t.select(
250
+ t.row_id,
251
+ t.c_image,
252
+ c_image_xformed=t.c_image.resize([W, H]).convert('RGB')
253
+ ).where(t.row_id < 1)
254
+
255
+ pandas_df = df.show().to_pandas()
256
+ im_plain = pandas_df['c_image'].values[0]
257
+ im_xformed = pandas_df['c_image_xformed'].values[0]
258
+ assert pandas_df.shape[0] == 1
259
+
260
+ ds = df.to_pytorch_dataset(image_format='np')
261
+ ds_ptformat = df.to_pytorch_dataset(image_format='pt')
262
+
263
+ elt_count = 0
264
+ for elt, elt_pt in zip(ds, ds_ptformat):
265
+ arr_plain = elt['c_image']
266
+ assert isinstance(arr_plain, np.ndarray)
267
+ assert arr_plain.flags["WRITEABLE"], 'required by pytorch collate function'
268
+
269
+ # NB: compare numpy array bc PIL.Image object itself is not using same file.
270
+ assert (arr_plain == np.array(im_plain)).all(), 'numpy image should be the same as the original'
271
+ arr_xformed = elt['c_image_xformed']
272
+ assert isinstance(arr_xformed, np.ndarray)
273
+ assert arr_xformed.flags["WRITEABLE"], 'required by pytorch collate function'
274
+
275
+ assert arr_xformed.shape == (H, W, 3)
276
+ assert arr_xformed.dtype == np.uint8
277
+ # same as above, compare numpy array bc PIL.Image object itself is not using same file.
278
+ assert (arr_xformed == np.array(im_xformed)).all(),\
279
+ 'numpy image array for xformed image should be the same as the original'
280
+
281
+ # now compare pytorch version
282
+ arr_pt = elt_pt['c_image']
283
+ assert torch.is_tensor(arr_pt)
284
+ arr_pt = elt_pt['c_image_xformed']
285
+ assert torch.is_tensor(arr_pt)
286
+ assert arr_pt.shape == (3, H, W)
287
+ assert arr_pt.dtype == torch.float32
288
+ assert (0.0 <= arr_pt).all()
289
+ assert (arr_pt <= 1.0).all()
290
+ assert torch.isclose(T.ToTensor()(arr_xformed), arr_pt).all(),\
291
+ 'pytorch image should be consistent with numpy image'
292
+ elt_count += 1
293
+ assert elt_count == 1
294
+
295
+ @pytest.mark.skip('Flaky test (fails intermittently)')
296
+ def test_to_pytorch_dataloader(self, all_datatypes_tbl: catalog.Table) -> None:
297
+ """ Tests the dataset works well with pytorch dataloader:
298
+ 1. compatibility with multiprocessing
299
+ 2. compatibility of all types with default collate_fn
300
+ """
301
+ skip_test_if_not_installed('torch')
302
+ import torch.utils.data
303
+ @pxt.udf(param_types=[pxt.JsonType()], return_type=pxt.JsonType())
304
+ def restrict_json_for_default_collate(obj):
305
+ keys = ['id', 'label', 'iscrowd', 'bounding_box']
306
+ return {k: obj[k] for k in keys}
307
+
308
+ t = all_datatypes_tbl
309
+ df = t.select(
310
+ t.row_id,
311
+ t.c_int,
312
+ t.c_float,
313
+ t.c_bool,
314
+ t.c_timestamp,
315
+ t.c_array,
316
+ t.c_video,
317
+ # default collate_fn doesn't support null values, nor lists of different lengths
318
+ # but does allow some dictionaries if they are uniform
319
+ c_json = restrict_json_for_default_collate(t.c_json.detections[0]),
320
+ # images must be uniform shape for pytorch collate_fn to not fail
321
+ c_image=t.c_image.resize([220, 224]).convert('RGB')
322
+ )
323
+ df_size = df.count()
324
+ ds = df.to_pytorch_dataset(image_format='pt')
325
+ # test serialization:
326
+ # - pickle.dumps() and pickle.loads() must work so that
327
+ # we can use num_workers > 0
328
+ x = pickle.dumps(ds)
329
+ _ = pickle.loads(x)
330
+
331
+ # test we get all rows
332
+ def check_recover_all_rows(ds, size : int, **kwargs):
333
+ dl = torch.utils.data.DataLoader(ds, **kwargs)
334
+ loaded_ids = set()
335
+ for batch in dl:
336
+ for row_id in batch['row_id']:
337
+ val = int(row_id) # np.int -> int or will fail set equality test below.
338
+ assert val not in loaded_ids, val
339
+ loaded_ids.add(val)
340
+
341
+ assert loaded_ids == set(range(size))
342
+
343
+ # check different number of workers
344
+ check_recover_all_rows(ds, size=df_size, batch_size=3, num_workers=0) # within this process
345
+ check_recover_all_rows(ds, size=df_size, batch_size=3, num_workers=2) # two separate processes
346
+
347
+ # check edge case where some workers get no rows
348
+ short_size = 1
349
+ df_short = df.where(t.row_id < short_size)
350
+ ds_short = df_short.to_pytorch_dataset(image_format='pt')
351
+ check_recover_all_rows(ds_short, size=short_size, batch_size=13, num_workers=short_size+1)
352
+
353
+ def test_pytorch_dataset_caching(self, all_datatypes_tbl: catalog.Table) -> None:
354
+ """ Tests that dataset caching works
355
+ 1. using the same dataset twice in a row uses the cache
356
+ 2. adding a row to the table invalidates the cached version
357
+ 3. changing the select list invalidates the cached version
358
+ """
359
+ skip_test_if_not_installed('torch')
360
+ t = all_datatypes_tbl
361
+
362
+ t.drop_column('c_video') # null value video column triggers internal assertions in DataRow
363
+ # see https://github.com/pixeltable/pixeltable/issues/38
364
+
365
+ t.drop_column('c_array') # no support yet for null array values in the pytorch dataset
366
+
367
+ def _get_mtimes(dir: Path):
368
+ return {p.name: p.stat().st_mtime for p in dir.iterdir()}
369
+
370
+ # check result cached
371
+ ds1 = t.to_pytorch_dataset(image_format='pt')
372
+ ds1_mtimes = _get_mtimes(ds1.path)
373
+
374
+ ds2 = t.to_pytorch_dataset(image_format='pt')
375
+ ds2_mtimes = _get_mtimes(ds2.path)
376
+ assert ds2.path == ds1.path, 'result should be cached'
377
+ assert ds2_mtimes == ds1_mtimes, 'no extra file system work should have occurred'
378
+
379
+ # check invalidation on insert
380
+ t_size = t.count()
381
+ t.insert(row_id=t_size)
382
+ ds3 = t.to_pytorch_dataset(image_format='pt')
383
+ assert ds3.path != ds1.path, 'different path should be used'
384
+
385
+ # check select list invalidation
386
+ ds4 = t.select(t.row_id).to_pytorch_dataset(image_format='pt')
387
+ assert ds4.path != ds3.path, 'different select list, hence different path should be used'
388
+
389
+ def test_to_coco(self, test_client: pxt.Client) -> None:
390
+ skip_test_if_not_installed('nos')
391
+ from pycocotools.coco import COCO
392
+ cl = test_client
393
+ base_t = cl.create_table('videos', {'video': pxt.VideoType()})
394
+ args = {'video': base_t.video, 'fps': 1}
395
+ view_t = cl.create_view('frames', base_t, iterator_class=FrameIterator, iterator_args=args)
396
+ from pixeltable.functions.nos.object_detection_2d import yolox_medium
397
+ view_t.add_column(detections=yolox_medium(view_t.frame))
398
+ base_t.insert(video=get_video_files()[0])
399
+
400
+ @pxt.udf(return_type=pxt.JsonType(nullable=False), param_types=[pxt.JsonType(nullable=False)])
401
+ def yolo_to_coco(detections):
402
+ bboxes, labels = detections['bboxes'], detections['labels']
403
+ num_annotations = len(detections['bboxes'])
404
+ assert num_annotations == len(detections['labels'])
405
+ result = []
406
+ for i in range(num_annotations):
407
+ bbox = bboxes[i]
408
+ ann = {
409
+ 'bbox': [round(bbox[0]), round(bbox[1]), round(bbox[2] - bbox[0]), round(bbox[3] - bbox[1])],
410
+ 'category': labels[i],
411
+ }
412
+ result.append(ann)
413
+ return result
414
+
415
+ query = view_t.select({'image': view_t.frame, 'annotations': yolo_to_coco(view_t.detections)})
416
+ path = query.to_coco_dataset()
417
+ # we get a valid COCO dataset
418
+ coco_ds = COCO(path)
419
+ assert len(coco_ds.imgs) == view_t.count()
420
+
421
+ # we call to_coco_dataset() again and get the cached dataset
422
+ new_path = query.to_coco_dataset()
423
+ assert path == new_path
424
+
425
+ # the cache is invalidated when we add more data
426
+ base_t.insert(video=get_video_files()[1])
427
+ new_path = query.to_coco_dataset()
428
+ assert path != new_path
429
+ coco_ds = COCO(new_path)
430
+ assert len(coco_ds.imgs) == view_t.count()
431
+
432
+ # incorrect select list
433
+ with pytest.raises(excs.Error) as exc_info:
434
+ _ = view_t.select({'image': view_t.frame, 'annotations': view_t.detections}).to_coco_dataset()
435
+ assert '"annotations" is not a list' in str(exc_info.value)
436
+
437
+ with pytest.raises(excs.Error) as exc_info:
438
+ _ = view_t.select(view_t.detections).to_coco_dataset()
439
+ assert 'missing key "image"' in str(exc_info.value).lower()
@@ -1,91 +1,107 @@
1
1
  import pytest
2
2
 
3
- import pixeltable as pt
4
- from pixeltable import exceptions as exc
3
+ import pixeltable as pxt
4
+ from pixeltable import exceptions as excs
5
5
  from pixeltable.tests.utils import make_tbl
6
- from pixeltable import catalog
7
6
 
8
7
 
9
8
  class TestDirs:
10
- def test_create(self, test_db: catalog.Db) -> None:
11
- db = test_db
9
+ def test_create(self, test_client: pxt.Client) -> None:
10
+ cl = test_client
12
11
  dirs = ['dir1', 'dir1.sub1', 'dir1.sub1.subsub1']
13
12
  for name in dirs:
14
- db.create_dir(name)
13
+ cl.create_dir(name)
15
14
 
16
- with pytest.raises(exc.BadFormatError):
17
- db.create_dir('1dir')
18
- with pytest.raises(exc.BadFormatError):
19
- db.create_dir('_dir1')
20
- with pytest.raises(exc.BadFormatError):
21
- db.create_dir('dir 1')
22
- with pytest.raises(exc.BadFormatError):
23
- db.create_dir('dir1..sub2')
24
- with pytest.raises(exc.BadFormatError):
25
- db.create_dir('dir1.sub2.')
26
- with pytest.raises(exc.BadFormatError):
27
- db.create_dir('dir1:sub2.')
15
+ # invalid names
16
+ with pytest.raises(excs.Error):
17
+ cl.create_dir('1dir')
18
+ with pytest.raises(excs.Error):
19
+ cl.create_dir('_dir1')
20
+ with pytest.raises(excs.Error):
21
+ cl.create_dir('dir 1')
22
+ with pytest.raises(excs.Error):
23
+ cl.create_dir('dir1..sub2')
24
+ with pytest.raises(excs.Error):
25
+ cl.create_dir('dir1.sub2.')
26
+ with pytest.raises(excs.Error):
27
+ cl.create_dir('dir1:sub2.')
28
28
 
29
29
  # existing dirs
30
- with pytest.raises(exc.DuplicateNameError):
31
- db.create_dir('dir1')
32
- with pytest.raises(exc.DuplicateNameError):
33
- db.create_dir('dir1.sub1')
34
- with pytest.raises(exc.DuplicateNameError):
35
- db.create_dir('dir1.sub1.subsub1')
30
+ with pytest.raises(excs.Error):
31
+ cl.create_dir('dir1')
32
+ cl.create_dir('dir1', ignore_errors=True)
33
+ with pytest.raises(excs.Error):
34
+ cl.create_dir('dir1.sub1')
35
+ with pytest.raises(excs.Error):
36
+ cl.create_dir('dir1.sub1.subsub1')
36
37
 
37
38
  # existing table
38
- make_tbl(db, 'dir1.t1')
39
- with pytest.raises(exc.DuplicateNameError):
40
- db.create_dir('dir1.t1')
39
+ make_tbl(cl, 'dir1.t1')
40
+ with pytest.raises(excs.Error):
41
+ cl.create_dir('dir1.t1')
41
42
 
42
- with pytest.raises(exc.UnknownEntityError):
43
- db.create_dir('dir2.sub2')
44
- make_tbl(db, 't2')
45
- with pytest.raises(exc.UnknownEntityError):
46
- db.create_dir('t2.sub2')
43
+ with pytest.raises(excs.Error):
44
+ cl.create_dir('dir2.sub2')
45
+ make_tbl(cl, 't2')
46
+ with pytest.raises(excs.Error):
47
+ cl.create_dir('t2.sub2')
47
48
 
48
49
  # new client: force loading from store
49
- cl2 = pt.Client()
50
- db = cl2.get_db('test')
50
+ cl2 = pxt.Client(reload=True)
51
51
 
52
- listing = db.list_dirs(recursive=True)
52
+ listing = cl2.list_dirs(recursive=True)
53
53
  assert listing == dirs
54
- listing = db.list_dirs(recursive=False)
54
+ listing = cl2.list_dirs(recursive=False)
55
55
  assert listing == ['dir1']
56
- listing = db.list_dirs('dir1', recursive=True)
56
+ listing = cl2.list_dirs('dir1', recursive=True)
57
57
  assert listing == ['dir1.sub1', 'dir1.sub1.subsub1']
58
- listing = db.list_dirs('dir1', recursive=False)
58
+ listing = cl2.list_dirs('dir1', recursive=False)
59
59
  assert listing == ['dir1.sub1']
60
- listing = db.list_dirs('dir1.sub1', recursive=True)
60
+ listing = cl2.list_dirs('dir1.sub1', recursive=True)
61
61
  assert listing == ['dir1.sub1.subsub1']
62
- listing = db.list_dirs('dir1.sub1', recursive=False)
62
+ listing = cl2.list_dirs('dir1.sub1', recursive=False)
63
63
  assert listing == ['dir1.sub1.subsub1']
64
64
 
65
- def test_rm(self, test_db: catalog.Db) -> None:
66
- db = test_db
65
+ def test_rm(self, test_client: pxt.Client) -> None:
66
+ cl = test_client
67
67
  dirs = ['dir1', 'dir1.sub1', 'dir1.sub1.subsub1']
68
68
  for name in dirs:
69
- db.create_dir(name)
70
- make_tbl(db, 't1')
71
- make_tbl(db, 'dir1.t1')
69
+ cl.create_dir(name)
70
+ make_tbl(cl, 't1')
71
+ make_tbl(cl, 'dir1.t1')
72
72
 
73
- with pytest.raises(exc.BadFormatError):
74
- db.rm_dir('1dir')
75
- with pytest.raises(exc.BadFormatError):
76
- db.rm_dir('dir1..sub1')
77
- with pytest.raises(exc.UnknownEntityError):
78
- db.rm_dir('dir2')
79
- with pytest.raises(exc.UnknownEntityError):
80
- db.rm_dir('t1')
73
+ # bad name
74
+ with pytest.raises(excs.Error):
75
+ cl.rm_dir('1dir')
76
+ # bad path
77
+ with pytest.raises(excs.Error):
78
+ cl.rm_dir('dir1..sub1')
79
+ # doesn't exist
80
+ with pytest.raises(excs.Error):
81
+ cl.rm_dir('dir2')
82
+ # not empty
83
+ with pytest.raises(excs.Error):
84
+ cl.rm_dir('dir1')
81
85
 
82
- with pytest.raises(exc.DirectoryNotEmptyError):
83
- db.rm_dir('dir1')
86
+ cl.rm_dir('dir1.sub1.subsub1')
87
+ assert cl.list_dirs('dir1.sub1') == []
84
88
 
85
- def test_rename_tbl(self, test_db: catalog.Db) -> None:
86
- db = test_db
87
- db.create_dir('dir1')
88
- make_tbl(db, 'dir1.t1')
89
- assert db.list_tables('dir1') == ['dir1.t1']
90
- db.rename_table('dir1.t1', 't2')
91
- assert db.list_tables('dir1') == ['dir1.t2']
89
+ # check after reloading
90
+ cl = pxt.Client(reload=True)
91
+ assert cl.list_dirs('dir1.sub1') == []
92
+
93
+ def test_move(self, test_client: pxt.Client) -> None:
94
+ cl = test_client
95
+ cl.create_dir('dir1')
96
+ cl.create_dir('dir1.sub1')
97
+ make_tbl(cl, 'dir1.sub1.t1')
98
+ assert cl.list_tables('dir1') == ['dir1.sub1.t1']
99
+ cl.move('dir1.sub1.t1', 'dir1.sub1.t2')
100
+ assert cl.list_tables('dir1') == ['dir1.sub1.t2']
101
+ cl.create_dir('dir2')
102
+ cl.move('dir1', 'dir2.dir1')
103
+ assert cl.list_tables('dir2') == ['dir2.dir1.sub1.t2']
104
+
105
+ # new client: force loading from store
106
+ cl2 = pxt.Client(reload=True)
107
+ assert cl2.list_tables('dir2') == ['dir2.dir1.sub1.t2']