pixeltable 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (99) hide show
  1. pixeltable/__init__.py +18 -9
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/column.py +31 -50
  4. pixeltable/catalog/insertable_table.py +7 -6
  5. pixeltable/catalog/table.py +171 -57
  6. pixeltable/catalog/table_version.py +417 -140
  7. pixeltable/catalog/table_version_path.py +2 -2
  8. pixeltable/dataframe.py +239 -121
  9. pixeltable/env.py +82 -16
  10. pixeltable/exec/__init__.py +2 -1
  11. pixeltable/exec/cache_prefetch_node.py +1 -1
  12. pixeltable/exec/data_row_batch.py +6 -7
  13. pixeltable/exec/expr_eval_node.py +28 -28
  14. pixeltable/exec/in_memory_data_node.py +11 -7
  15. pixeltable/exec/sql_scan_node.py +7 -6
  16. pixeltable/exprs/__init__.py +4 -3
  17. pixeltable/exprs/column_ref.py +9 -0
  18. pixeltable/exprs/comparison.py +3 -3
  19. pixeltable/exprs/data_row.py +5 -1
  20. pixeltable/exprs/expr.py +15 -7
  21. pixeltable/exprs/function_call.py +17 -15
  22. pixeltable/exprs/image_member_access.py +9 -28
  23. pixeltable/exprs/in_predicate.py +96 -0
  24. pixeltable/exprs/inline_array.py +13 -11
  25. pixeltable/exprs/inline_dict.py +15 -13
  26. pixeltable/exprs/literal.py +16 -4
  27. pixeltable/exprs/row_builder.py +15 -41
  28. pixeltable/exprs/similarity_expr.py +65 -0
  29. pixeltable/ext/__init__.py +5 -0
  30. pixeltable/ext/functions/yolox.py +92 -0
  31. pixeltable/func/__init__.py +0 -2
  32. pixeltable/func/aggregate_function.py +18 -15
  33. pixeltable/func/callable_function.py +57 -13
  34. pixeltable/func/expr_template_function.py +20 -3
  35. pixeltable/func/function.py +35 -4
  36. pixeltable/func/globals.py +24 -14
  37. pixeltable/func/signature.py +23 -27
  38. pixeltable/func/udf.py +13 -12
  39. pixeltable/functions/__init__.py +8 -8
  40. pixeltable/functions/eval.py +7 -8
  41. pixeltable/functions/huggingface.py +64 -17
  42. pixeltable/functions/openai.py +36 -3
  43. pixeltable/functions/pil/image.py +61 -64
  44. pixeltable/functions/together.py +21 -0
  45. pixeltable/functions/util.py +11 -0
  46. pixeltable/globals.py +425 -0
  47. pixeltable/index/__init__.py +2 -0
  48. pixeltable/index/base.py +51 -0
  49. pixeltable/index/embedding_index.py +168 -0
  50. pixeltable/io/__init__.py +3 -0
  51. pixeltable/{utils → io}/hf_datasets.py +48 -17
  52. pixeltable/io/pandas.py +148 -0
  53. pixeltable/{utils → io}/parquet.py +58 -33
  54. pixeltable/iterators/__init__.py +1 -1
  55. pixeltable/iterators/base.py +4 -0
  56. pixeltable/iterators/document.py +218 -97
  57. pixeltable/iterators/video.py +8 -9
  58. pixeltable/metadata/__init__.py +7 -3
  59. pixeltable/metadata/converters/convert_12.py +3 -0
  60. pixeltable/metadata/converters/convert_13.py +41 -0
  61. pixeltable/metadata/schema.py +45 -22
  62. pixeltable/plan.py +15 -51
  63. pixeltable/store.py +38 -41
  64. pixeltable/tool/create_test_db_dump.py +39 -4
  65. pixeltable/type_system.py +47 -96
  66. pixeltable/utils/documents.py +42 -12
  67. pixeltable/utils/http_server.py +70 -0
  68. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/METADATA +14 -10
  69. pixeltable-0.2.6.dist-info/RECORD +119 -0
  70. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/WHEEL +1 -1
  71. pixeltable/client.py +0 -604
  72. pixeltable/exprs/image_similarity_predicate.py +0 -58
  73. pixeltable/func/batched_function.py +0 -53
  74. pixeltable/tests/conftest.py +0 -177
  75. pixeltable/tests/functions/test_fireworks.py +0 -42
  76. pixeltable/tests/functions/test_functions.py +0 -60
  77. pixeltable/tests/functions/test_huggingface.py +0 -158
  78. pixeltable/tests/functions/test_openai.py +0 -152
  79. pixeltable/tests/functions/test_together.py +0 -111
  80. pixeltable/tests/test_audio.py +0 -65
  81. pixeltable/tests/test_catalog.py +0 -27
  82. pixeltable/tests/test_client.py +0 -21
  83. pixeltable/tests/test_component_view.py +0 -370
  84. pixeltable/tests/test_dataframe.py +0 -439
  85. pixeltable/tests/test_dirs.py +0 -107
  86. pixeltable/tests/test_document.py +0 -120
  87. pixeltable/tests/test_exprs.py +0 -805
  88. pixeltable/tests/test_function.py +0 -324
  89. pixeltable/tests/test_migration.py +0 -43
  90. pixeltable/tests/test_nos.py +0 -54
  91. pixeltable/tests/test_snapshot.py +0 -208
  92. pixeltable/tests/test_table.py +0 -1267
  93. pixeltable/tests/test_transactional_directory.py +0 -42
  94. pixeltable/tests/test_types.py +0 -22
  95. pixeltable/tests/test_video.py +0 -159
  96. pixeltable/tests/test_view.py +0 -530
  97. pixeltable/tests/utils.py +0 -408
  98. pixeltable-0.2.4.dist-info/RECORD +0 -132
  99. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/LICENSE +0 -0
@@ -1,439 +0,0 @@
1
- import datetime
2
- import pickle
3
- from pathlib import Path
4
- from typing import Any, Dict
5
-
6
- import bs4
7
- import numpy as np
8
- import pytest
9
- import requests
10
-
11
- import pixeltable as pxt
12
- from pixeltable import catalog
13
- from pixeltable import exceptions as excs
14
- from pixeltable.iterators import FrameIterator
15
- from pixeltable.tests.utils import get_video_files, get_audio_files, skip_test_if_not_installed
16
-
17
-
18
- class TestDataFrame:
19
- def test_select_where(self, test_tbl: catalog.Table) -> None:
20
- t = test_tbl
21
- res1 = t[t.c1, t.c2, t.c3].show(0)
22
- res2 = t.select(t.c1, t.c2, t.c3).show(0)
23
- assert res1 == res2
24
-
25
- res1 = t[t.c2 < 10][t.c1, t.c2, t.c3].show(0)
26
- res2 = t.where(t.c2 < 10).select(t.c1, t.c2, t.c3).show(0)
27
- assert res1 == res2
28
-
29
- res3 = t.where(t.c2 < 10).select(c1=t.c1, c2=t.c2, c3=t.c3).show(0)
30
- assert res1 == res3
31
-
32
- res4 = t.where(t.c2 < 10).select(t.c1, c2=t.c2, c3=t.c3).show(0)
33
- assert res1 == res4
34
-
35
- _ = t.where(t.c2 < 10).select(t.c2, t.c2).show(0) # repeated name no error
36
-
37
- # duplicate select list
38
- with pytest.raises(excs.Error) as exc_info:
39
- _ = t.select(t.c1).select(t.c2).show(0)
40
- assert 'already specified' in str(exc_info.value)
41
-
42
- # invalid expr in select list: Callable is not a valid literal
43
- with pytest.raises(TypeError) as exc_info:
44
- _ = t.select(datetime.datetime.now).show(0)
45
- assert 'Not a valid literal' in str(exc_info.value)
46
-
47
- # catch invalid name in select list from user input
48
- # only check stuff that's not caught by python kwargs checker
49
- with pytest.raises(excs.Error) as exc_info:
50
- _ = t.select(t.c1, **{'c2-1': t.c2}).show(0)
51
- assert 'Invalid name' in str(exc_info.value)
52
-
53
- with pytest.raises(excs.Error) as exc_info:
54
- _ = t.select(t.c1, **{'': t.c2}).show(0)
55
- assert 'Invalid name' in str(exc_info.value)
56
-
57
- with pytest.raises(excs.Error) as exc_info:
58
- _ = t.select(t.c1, **{'foo.bar': t.c2}).show(0)
59
- assert 'Invalid name' in str(exc_info.value)
60
-
61
- with pytest.raises(excs.Error) as exc_info:
62
- _ = t.select(t.c1, _c3=t.c2).show(0)
63
- assert 'Invalid name' in str(exc_info.value)
64
-
65
- # catch repeated name from user input
66
- with pytest.raises(excs.Error) as exc_info:
67
- _ = t.select(t.c2, c2=t.c1).show(0)
68
- assert 'Repeated column name' in str(exc_info.value)
69
-
70
- with pytest.raises(excs.Error) as exc_info:
71
- _ = t.select(t.c2+1, col_0=t.c2).show(0)
72
- assert 'Repeated column name' in str(exc_info.value)
73
-
74
- def test_result_set_iterator(self, test_tbl: catalog.Table) -> None:
75
- t = test_tbl
76
- res = t.select(t.c1, t.c2, t.c3).collect()
77
- pd_df = res.to_pandas()
78
-
79
- def check_row(row: Dict[str, Any], idx: int) -> None:
80
- assert len(row) == 3
81
- assert 'c1' in row
82
- assert row['c1'] == pd_df['c1'][idx]
83
- assert 'c2' in row
84
- assert row['c2'] == pd_df['c2'][idx]
85
- assert 'c3' in row
86
- assert row['c3'] == pd_df['c3'][idx]
87
-
88
- # row iteration
89
- for idx, row in enumerate(res):
90
- check_row(row, idx)
91
-
92
- # row access
93
- row = res[0]
94
- check_row(row, 0)
95
-
96
- # column access
97
- col_values = res['c2']
98
- assert col_values == pd_df['c2'].values.tolist()
99
-
100
- # cell access
101
- assert res[0, 'c2'] == pd_df['c2'][0]
102
- assert res[0, 'c2'] == res[0, 1]
103
-
104
- with pytest.raises(excs.Error) as exc_info:
105
- _ = res['does_not_exist']
106
- assert 'Invalid column name' in str(exc_info.value)
107
-
108
- with pytest.raises(excs.Error) as exc_info:
109
- _ = res[0, 'does_not_exist']
110
- assert 'Invalid column name' in str(exc_info.value)
111
-
112
- with pytest.raises(excs.Error) as exc_info:
113
- _ = res[0, 0, 0]
114
- assert 'Bad index' in str(exc_info.value)
115
-
116
- with pytest.raises(excs.Error) as exc_info:
117
- _ = res['c2', 0]
118
- assert 'Bad index' in str(exc_info.value)
119
-
120
- def test_order_by(self, test_tbl: catalog.Table) -> None:
121
- t = test_tbl
122
- res = t.select(t.c4, t.c2).order_by(t.c4).order_by(t.c2, asc=False).show(0)
123
-
124
- # invalid expr in order_by()
125
- with pytest.raises(excs.Error) as exc_info:
126
- _ = t.order_by(datetime.datetime.now()).show(0)
127
- assert 'Invalid expression' in str(exc_info.value)
128
-
129
- def test_head_tail(self, test_tbl: catalog.Table) -> None:
130
- t = test_tbl
131
- res = t.head(10).to_pandas()
132
- assert np.all(res.c2 == list(range(10)))
133
- # Where is applied
134
- res = t.where(t.c2 > 9).head(10).to_pandas()
135
- assert np.all(res.c2 == list(range(10, 20)))
136
- # order_by() is an error
137
- with pytest.raises(excs.Error) as exc_info:
138
- _ = t.order_by(t.c2).head(10)
139
- assert 'cannot be used with order_by' in str(exc_info.value)
140
-
141
- res = t.tail().to_pandas()
142
- assert np.all(res.c2 == list(range(90, 100)))
143
- res = t.where(t.c2 < 90).tail().to_pandas()
144
- assert np.all(res.c2 == list(range(80, 90)))
145
- # order_by() is an error
146
- with pytest.raises(excs.Error) as exc_info:
147
- _ = t.order_by(t.c2).tail(10)
148
- assert 'cannot be used with order_by' in str(exc_info.value)
149
-
150
- def test_describe(self, test_tbl: catalog.Table) -> None:
151
- t = test_tbl
152
- df = t.select(t.c1).where(t.c2 < 10).limit(10)
153
- df.describe()
154
-
155
- # TODO: how to you check the output of these?
156
- _ = df.__repr__()
157
- _ = df._repr_html_()
158
-
159
- def test_count(self, test_tbl: catalog.Table, indexed_img_tbl: catalog.Table) -> None:
160
- skip_test_if_not_installed('nos')
161
- t = test_tbl
162
- cnt = t.count()
163
- assert cnt == 100
164
-
165
- cnt = t.where(t.c2 < 10).count()
166
- assert cnt == 10
167
-
168
- # count() doesn't work with similarity search
169
- t = indexed_img_tbl
170
- probe = t.select(t.img).show(1)
171
- img = probe[0, 0]
172
- with pytest.raises(excs.Error):
173
- _ = t.where(t.img.nearest(img)).count()
174
- with pytest.raises(excs.Error):
175
- _ = t.where(t.img.nearest('car')).count()
176
-
177
- # for now, count() doesn't work with non-SQL Where clauses
178
- with pytest.raises(excs.Error):
179
- _ = t.where(t.img.width > 100).count()
180
-
181
- def test_select_literal(self, test_tbl: catalog.Table) -> None:
182
- t = test_tbl
183
- res = t.select(1.0).where(t.c2 < 10).collect()
184
- assert res[res.column_names()[0]] == [1.0] * 10
185
-
186
- # TODO This test doesn't work on Windows due to reliance on the structure of file URLs
187
- @pytest.mark.skip('Test is not portable')
188
- def test_html_media_url(self, test_client: pxt.Client) -> None:
189
- tab = test_client.create_table('test_html_repr', {'video': pxt.VideoType(), 'audio': pxt.AudioType()})
190
- status = tab.insert(video=get_video_files()[0], audio=get_audio_files()[0])
191
- assert status.num_rows == 1
192
- assert status.num_excs == 0
193
-
194
- res = tab.select(tab.video, tab.audio).collect()
195
- doc = bs4.BeautifulSoup(res._repr_html_(), features='html.parser')
196
- video_tags = doc.find_all('video')
197
- assert len(video_tags) == 1
198
- audio_tags = doc.find_all('audio')
199
- assert len(audio_tags) == 1
200
-
201
- # get the source elements and test their src attributes
202
- for tag in video_tags + audio_tags:
203
- sources = tag.find_all('source')
204
- assert len(sources) == 1
205
- for src in sources:
206
- response = requests.get(src['src'])
207
- assert response.status_code == 200
208
-
209
- def test_to_pytorch_dataset(self, all_datatypes_tbl: catalog.Table):
210
- """ tests all types are handled correctly in this conversion
211
- """
212
- skip_test_if_not_installed('torch')
213
- import torch
214
-
215
- t = all_datatypes_tbl
216
- df = t.where(t.row_id < 1)
217
- assert df.count() > 0
218
- ds = df.to_pytorch_dataset()
219
- type_dict = dict(zip(df.get_column_names(),df.get_column_types()))
220
- for tup in ds:
221
- for col in df.get_column_names():
222
- assert col in tup
223
-
224
- arrval = tup['c_array']
225
- assert isinstance(arrval, np.ndarray)
226
- col_type = type_dict['c_array']
227
- assert arrval.dtype == col_type.numpy_dtype()
228
- assert arrval.shape == col_type.shape
229
- assert arrval.dtype == np.float32
230
- assert arrval.flags["WRITEABLE"], 'required by pytorch collate function'
231
-
232
- assert isinstance(tup['c_bool'], bool)
233
- assert isinstance(tup['c_int'], int)
234
- assert isinstance(tup['c_float'], float)
235
- assert isinstance(tup['c_timestamp'], float)
236
- assert torch.is_tensor(tup['c_image'])
237
- assert isinstance(tup['c_video'], str)
238
- assert isinstance(tup['c_json'], dict)
239
-
240
- def test_to_pytorch_image_format(self, all_datatypes_tbl: catalog.Table) -> None:
241
- """ tests the image_format parameter is honored
242
- """
243
- skip_test_if_not_installed('torch')
244
- import torch
245
- import torchvision.transforms as T
246
-
247
- W, H = 220, 224 # make different from each other
248
- t = all_datatypes_tbl
249
- df = t.select(
250
- t.row_id,
251
- t.c_image,
252
- c_image_xformed=t.c_image.resize([W, H]).convert('RGB')
253
- ).where(t.row_id < 1)
254
-
255
- pandas_df = df.show().to_pandas()
256
- im_plain = pandas_df['c_image'].values[0]
257
- im_xformed = pandas_df['c_image_xformed'].values[0]
258
- assert pandas_df.shape[0] == 1
259
-
260
- ds = df.to_pytorch_dataset(image_format='np')
261
- ds_ptformat = df.to_pytorch_dataset(image_format='pt')
262
-
263
- elt_count = 0
264
- for elt, elt_pt in zip(ds, ds_ptformat):
265
- arr_plain = elt['c_image']
266
- assert isinstance(arr_plain, np.ndarray)
267
- assert arr_plain.flags["WRITEABLE"], 'required by pytorch collate function'
268
-
269
- # NB: compare numpy array bc PIL.Image object itself is not using same file.
270
- assert (arr_plain == np.array(im_plain)).all(), 'numpy image should be the same as the original'
271
- arr_xformed = elt['c_image_xformed']
272
- assert isinstance(arr_xformed, np.ndarray)
273
- assert arr_xformed.flags["WRITEABLE"], 'required by pytorch collate function'
274
-
275
- assert arr_xformed.shape == (H, W, 3)
276
- assert arr_xformed.dtype == np.uint8
277
- # same as above, compare numpy array bc PIL.Image object itself is not using same file.
278
- assert (arr_xformed == np.array(im_xformed)).all(),\
279
- 'numpy image array for xformed image should be the same as the original'
280
-
281
- # now compare pytorch version
282
- arr_pt = elt_pt['c_image']
283
- assert torch.is_tensor(arr_pt)
284
- arr_pt = elt_pt['c_image_xformed']
285
- assert torch.is_tensor(arr_pt)
286
- assert arr_pt.shape == (3, H, W)
287
- assert arr_pt.dtype == torch.float32
288
- assert (0.0 <= arr_pt).all()
289
- assert (arr_pt <= 1.0).all()
290
- assert torch.isclose(T.ToTensor()(arr_xformed), arr_pt).all(),\
291
- 'pytorch image should be consistent with numpy image'
292
- elt_count += 1
293
- assert elt_count == 1
294
-
295
- @pytest.mark.skip('Flaky test (fails intermittently)')
296
- def test_to_pytorch_dataloader(self, all_datatypes_tbl: catalog.Table) -> None:
297
- """ Tests the dataset works well with pytorch dataloader:
298
- 1. compatibility with multiprocessing
299
- 2. compatibility of all types with default collate_fn
300
- """
301
- skip_test_if_not_installed('torch')
302
- import torch.utils.data
303
- @pxt.udf(param_types=[pxt.JsonType()], return_type=pxt.JsonType())
304
- def restrict_json_for_default_collate(obj):
305
- keys = ['id', 'label', 'iscrowd', 'bounding_box']
306
- return {k: obj[k] for k in keys}
307
-
308
- t = all_datatypes_tbl
309
- df = t.select(
310
- t.row_id,
311
- t.c_int,
312
- t.c_float,
313
- t.c_bool,
314
- t.c_timestamp,
315
- t.c_array,
316
- t.c_video,
317
- # default collate_fn doesnt support null values, nor lists of different lengths
318
- # but does allow some dictionaries if they are uniform
319
- c_json = restrict_json_for_default_collate(t.c_json.detections[0]),
320
- # images must be uniform shape for pytorch collate_fn to not fail
321
- c_image=t.c_image.resize([220, 224]).convert('RGB')
322
- )
323
- df_size = df.count()
324
- ds = df.to_pytorch_dataset(image_format='pt')
325
- # test serialization:
326
- # - pickle.dumps() and pickle.loads() must work so that
327
- # we can use num_workers > 0
328
- x = pickle.dumps(ds)
329
- _ = pickle.loads(x)
330
-
331
- # test we get all rows
332
- def check_recover_all_rows(ds, size : int, **kwargs):
333
- dl = torch.utils.data.DataLoader(ds, **kwargs)
334
- loaded_ids = set()
335
- for batch in dl:
336
- for row_id in batch['row_id']:
337
- val = int(row_id) # np.int -> int or will fail set equality test below.
338
- assert val not in loaded_ids, val
339
- loaded_ids.add(val)
340
-
341
- assert loaded_ids == set(range(size))
342
-
343
- # check different number of workers
344
- check_recover_all_rows(ds, size=df_size, batch_size=3, num_workers=0) # within this process
345
- check_recover_all_rows(ds, size=df_size, batch_size=3, num_workers=2) # two separate processes
346
-
347
- # check edge case where some workers get no rows
348
- short_size = 1
349
- df_short = df.where(t.row_id < short_size)
350
- ds_short = df_short.to_pytorch_dataset(image_format='pt')
351
- check_recover_all_rows(ds_short, size=short_size, batch_size=13, num_workers=short_size+1)
352
-
353
- def test_pytorch_dataset_caching(self, all_datatypes_tbl: catalog.Table) -> None:
354
- """ Tests that dataset caching works
355
- 1. using the same dataset twice in a row uses the cache
356
- 2. adding a row to the table invalidates the cached version
357
- 3. changing the select list invalidates the cached version
358
- """
359
- skip_test_if_not_installed('torch')
360
- t = all_datatypes_tbl
361
-
362
- t.drop_column('c_video') # null value video column triggers internal assertions in DataRow
363
- # see https://github.com/pixeltable/pixeltable/issues/38
364
-
365
- t.drop_column('c_array') # no support yet for null array values in the pytorch dataset
366
-
367
- def _get_mtimes(dir: Path):
368
- return {p.name: p.stat().st_mtime for p in dir.iterdir()}
369
-
370
- # check result cached
371
- ds1 = t.to_pytorch_dataset(image_format='pt')
372
- ds1_mtimes = _get_mtimes(ds1.path)
373
-
374
- ds2 = t.to_pytorch_dataset(image_format='pt')
375
- ds2_mtimes = _get_mtimes(ds2.path)
376
- assert ds2.path == ds1.path, 'result should be cached'
377
- assert ds2_mtimes == ds1_mtimes, 'no extra file system work should have occurred'
378
-
379
- # check invalidation on insert
380
- t_size = t.count()
381
- t.insert(row_id=t_size)
382
- ds3 = t.to_pytorch_dataset(image_format='pt')
383
- assert ds3.path != ds1.path, 'different path should be used'
384
-
385
- # check select list invalidation
386
- ds4 = t.select(t.row_id).to_pytorch_dataset(image_format='pt')
387
- assert ds4.path != ds3.path, 'different select list, hence different path should be used'
388
-
389
- def test_to_coco(self, test_client: pxt.Client) -> None:
390
- skip_test_if_not_installed('nos')
391
- from pycocotools.coco import COCO
392
- cl = test_client
393
- base_t = cl.create_table('videos', {'video': pxt.VideoType()})
394
- args = {'video': base_t.video, 'fps': 1}
395
- view_t = cl.create_view('frames', base_t, iterator_class=FrameIterator, iterator_args=args)
396
- from pixeltable.functions.nos.object_detection_2d import yolox_medium
397
- view_t.add_column(detections=yolox_medium(view_t.frame))
398
- base_t.insert(video=get_video_files()[0])
399
-
400
- @pxt.udf(return_type=pxt.JsonType(nullable=False), param_types=[pxt.JsonType(nullable=False)])
401
- def yolo_to_coco(detections):
402
- bboxes, labels = detections['bboxes'], detections['labels']
403
- num_annotations = len(detections['bboxes'])
404
- assert num_annotations == len(detections['labels'])
405
- result = []
406
- for i in range(num_annotations):
407
- bbox = bboxes[i]
408
- ann = {
409
- 'bbox': [round(bbox[0]), round(bbox[1]), round(bbox[2] - bbox[0]), round(bbox[3] - bbox[1])],
410
- 'category': labels[i],
411
- }
412
- result.append(ann)
413
- return result
414
-
415
- query = view_t.select({'image': view_t.frame, 'annotations': yolo_to_coco(view_t.detections)})
416
- path = query.to_coco_dataset()
417
- # we get a valid COCO dataset
418
- coco_ds = COCO(path)
419
- assert len(coco_ds.imgs) == view_t.count()
420
-
421
- # we call to_coco_dataset() again and get the cached dataset
422
- new_path = query.to_coco_dataset()
423
- assert path == new_path
424
-
425
- # the cache is invalidated when we add more data
426
- base_t.insert(video=get_video_files()[1])
427
- new_path = query.to_coco_dataset()
428
- assert path != new_path
429
- coco_ds = COCO(new_path)
430
- assert len(coco_ds.imgs) == view_t.count()
431
-
432
- # incorrect select list
433
- with pytest.raises(excs.Error) as exc_info:
434
- _ = view_t.select({'image': view_t.frame, 'annotations': view_t.detections}).to_coco_dataset()
435
- assert '"annotations" is not a list' in str(exc_info.value)
436
-
437
- with pytest.raises(excs.Error) as exc_info:
438
- _ = view_t.select(view_t.detections).to_coco_dataset()
439
- assert 'missing key "image"' in str(exc_info.value).lower()
@@ -1,107 +0,0 @@
1
- import pytest
2
-
3
- import pixeltable as pxt
4
- from pixeltable import exceptions as excs
5
- from pixeltable.tests.utils import make_tbl
6
-
7
-
8
- class TestDirs:
9
- def test_create(self, test_client: pxt.Client) -> None:
10
- cl = test_client
11
- dirs = ['dir1', 'dir1.sub1', 'dir1.sub1.subsub1']
12
- for name in dirs:
13
- cl.create_dir(name)
14
-
15
- # invalid names
16
- with pytest.raises(excs.Error):
17
- cl.create_dir('1dir')
18
- with pytest.raises(excs.Error):
19
- cl.create_dir('_dir1')
20
- with pytest.raises(excs.Error):
21
- cl.create_dir('dir 1')
22
- with pytest.raises(excs.Error):
23
- cl.create_dir('dir1..sub2')
24
- with pytest.raises(excs.Error):
25
- cl.create_dir('dir1.sub2.')
26
- with pytest.raises(excs.Error):
27
- cl.create_dir('dir1:sub2.')
28
-
29
- # existing dirs
30
- with pytest.raises(excs.Error):
31
- cl.create_dir('dir1')
32
- cl.create_dir('dir1', ignore_errors=True)
33
- with pytest.raises(excs.Error):
34
- cl.create_dir('dir1.sub1')
35
- with pytest.raises(excs.Error):
36
- cl.create_dir('dir1.sub1.subsub1')
37
-
38
- # existing table
39
- make_tbl(cl, 'dir1.t1')
40
- with pytest.raises(excs.Error):
41
- cl.create_dir('dir1.t1')
42
-
43
- with pytest.raises(excs.Error):
44
- cl.create_dir('dir2.sub2')
45
- make_tbl(cl, 't2')
46
- with pytest.raises(excs.Error):
47
- cl.create_dir('t2.sub2')
48
-
49
- # new client: force loading from store
50
- cl2 = pxt.Client(reload=True)
51
-
52
- listing = cl2.list_dirs(recursive=True)
53
- assert listing == dirs
54
- listing = cl2.list_dirs(recursive=False)
55
- assert listing == ['dir1']
56
- listing = cl2.list_dirs('dir1', recursive=True)
57
- assert listing == ['dir1.sub1', 'dir1.sub1.subsub1']
58
- listing = cl2.list_dirs('dir1', recursive=False)
59
- assert listing == ['dir1.sub1']
60
- listing = cl2.list_dirs('dir1.sub1', recursive=True)
61
- assert listing == ['dir1.sub1.subsub1']
62
- listing = cl2.list_dirs('dir1.sub1', recursive=False)
63
- assert listing == ['dir1.sub1.subsub1']
64
-
65
- def test_rm(self, test_client: pxt.Client) -> None:
66
- cl = test_client
67
- dirs = ['dir1', 'dir1.sub1', 'dir1.sub1.subsub1']
68
- for name in dirs:
69
- cl.create_dir(name)
70
- make_tbl(cl, 't1')
71
- make_tbl(cl, 'dir1.t1')
72
-
73
- # bad name
74
- with pytest.raises(excs.Error):
75
- cl.rm_dir('1dir')
76
- # bad path
77
- with pytest.raises(excs.Error):
78
- cl.rm_dir('dir1..sub1')
79
- # doesn't exist
80
- with pytest.raises(excs.Error):
81
- cl.rm_dir('dir2')
82
- # not empty
83
- with pytest.raises(excs.Error):
84
- cl.rm_dir('dir1')
85
-
86
- cl.rm_dir('dir1.sub1.subsub1')
87
- assert cl.list_dirs('dir1.sub1') == []
88
-
89
- # check after reloading
90
- cl = pxt.Client(reload=True)
91
- assert cl.list_dirs('dir1.sub1') == []
92
-
93
- def test_move(self, test_client: pxt.Client) -> None:
94
- cl = test_client
95
- cl.create_dir('dir1')
96
- cl.create_dir('dir1.sub1')
97
- make_tbl(cl, 'dir1.sub1.t1')
98
- assert cl.list_tables('dir1') == ['dir1.sub1.t1']
99
- cl.move('dir1.sub1.t1', 'dir1.sub1.t2')
100
- assert cl.list_tables('dir1') == ['dir1.sub1.t2']
101
- cl.create_dir('dir2')
102
- cl.move('dir1', 'dir2.dir1')
103
- assert cl.list_tables('dir2') == ['dir2.dir1.sub1.t2']
104
-
105
- # new client: force loading from store
106
- cl2 = pxt.Client(reload=True)
107
- assert cl2.list_tables('dir2') == ['dir2.dir1.sub1.t2']
@@ -1,120 +0,0 @@
1
- import itertools
2
- import json
3
- import re
4
- from typing import Optional, Set, List
5
-
6
- import pytest
7
-
8
- import pixeltable as pxt
9
- from pixeltable.iterators.document import DocumentSplitter
10
- from pixeltable.tests.utils import get_documents, get_video_files, get_audio_files, get_image_files
11
- from pixeltable.tests.utils import skip_test_if_not_installed
12
- from pixeltable.type_system import DocumentType
13
-
14
-
15
- class TestDocument:
16
- def valid_doc_paths(self) -> List[str]:
17
- return get_documents()
18
-
19
- def invalid_doc_paths(self) -> List[str]:
20
- return [get_video_files()[0], get_audio_files()[0], get_image_files()[0]]
21
-
22
- def test_insert(self, test_client: pxt.Client) -> None:
23
- file_paths = self.valid_doc_paths()
24
- cl = test_client
25
- doc_t = cl.create_table('docs', {'doc': DocumentType()})
26
- status = doc_t.insert({'doc': p} for p in file_paths)
27
- assert status.num_rows == len(file_paths)
28
- assert status.num_excs == 0
29
- stored_paths = doc_t.select(output=doc_t.doc.localpath).collect()['output']
30
- assert set(stored_paths) == set(file_paths)
31
-
32
- file_paths = self.invalid_doc_paths()
33
- status = doc_t.insert(({'doc': p} for p in file_paths), fail_on_exception=False)
34
- assert status.num_rows == len(file_paths)
35
- assert status.num_excs == len(file_paths)
36
-
37
- def test_doc_splitter(self, test_client: pxt.Client) -> None:
38
- skip_test_if_not_installed('tiktoken')
39
- file_paths = self.valid_doc_paths()
40
- cl = test_client
41
- doc_t = cl.create_table('docs', {'doc': DocumentType()})
42
- status = doc_t.insert({'doc': p} for p in file_paths)
43
- assert status.num_excs == 0
44
-
45
- def normalize(s: str) -> str:
46
- # remove whitespace
47
- res = re.sub(r'\s+', '', s)
48
- # remove non-ascii
49
- res = res.encode('ascii', 'ignore').decode()
50
- return res
51
-
52
- # run all combinations of (heading, paragraph, sentence) x (token_limit, char_limit, None)
53
- # and make sure they extract the same text in aggregate
54
- all_text_reference: Optional[str] = None # all text as a single string; normalized
55
- headings_reference: Set[str] = {} # headings metadata as a json-serialized string
56
- for sep1 in ['heading', 'paragraph', 'sentence']:
57
- for sep2 in [None, 'token_limit', 'char_limit']:
58
- chunk_limits = [10, 20, 100] if sep2 is not None else [None]
59
- for limit in chunk_limits:
60
- iterator_args = {
61
- 'document': doc_t.doc,
62
- 'separators': sep1 + (',' + sep2 if sep2 is not None else ''),
63
- 'metadata': 'title,headings,sourceline'
64
- }
65
- if sep2 is not None:
66
- iterator_args['limit'] = limit
67
- iterator_args['overlap'] = 0
68
- chunks_t = cl.create_view(
69
- f'chunks', doc_t, iterator_class=DocumentSplitter, iterator_args=iterator_args)
70
- res = list(chunks_t.order_by(chunks_t.doc, chunks_t.pos).collect())
71
-
72
- if all_text_reference is None:
73
- all_text_reference = normalize(''.join([r['text'] for r in res]))
74
- headings_reference = set(json.dumps(r['headings']) for r in res)
75
- else:
76
- all_text = normalize(''.join([r['text'] for r in res]))
77
- headings = set(json.dumps(r['headings']) for r in res)
78
-
79
- # for debugging
80
- first_diff_index = next(
81
- (i for i, (c1, c2) in enumerate(zip(all_text, all_text_reference)) if c1 != c2),
82
- len(all_text) if len(all_text) != len(all_text_reference) else None)
83
- if first_diff_index is not None:
84
- a = all_text[max(0, first_diff_index - 10):first_diff_index + 10]
85
- b = all_text_reference[max(0, first_diff_index - 10):first_diff_index + 10]
86
-
87
- assert all_text == all_text_reference, f'{sep1}, {sep2}, {limit}'
88
- assert headings == headings_reference, f'{sep1}, {sep2}, {limit}'
89
- # TODO: verify chunk limit
90
- cl.drop_table('chunks')
91
-
92
- def test_doc_splitter_headings(self, test_client: pxt.Client) -> None:
93
- skip_test_if_not_installed('spacy')
94
- file_paths = self.valid_doc_paths()
95
- cl = test_client
96
- doc_t = cl.create_table('docs', {'doc': DocumentType()})
97
- status = doc_t.insert({'doc': p} for p in file_paths)
98
- assert status.num_excs == 0
99
-
100
- # verify that only the requested metadata is present in the view
101
- md_elements = ['title', 'headings', 'sourceline']
102
- md_tuples = list(itertools.chain.from_iterable(itertools.combinations(md_elements, i) for i in range(len(md_elements) + 1)))
103
- _ = [','.join(t) for t in md_tuples]
104
- for md_str in [','.join(t) for t in md_tuples]:
105
- iterator_args = {
106
- 'document': doc_t.doc,
107
- 'separators': 'sentence',
108
- 'metadata': md_str
109
- }
110
- chunks_t = cl.create_view(
111
- f'chunks', doc_t, iterator_class=DocumentSplitter, iterator_args=iterator_args)
112
- res = chunks_t.order_by(chunks_t.doc, chunks_t.pos).collect()
113
- requested_md_elements = set(md_str.split(','))
114
- for md_element in md_elements:
115
- if md_element in requested_md_elements:
116
- _ = res[md_element]
117
- else:
118
- with pytest.raises(pxt.Error):
119
- _ = res[md_element]
120
- cl.drop_table('chunks')