pixeltable 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

@@ -0,0 +1,111 @@
1
+ import pytest
2
+
3
+ import pixeltable as pxt
4
+ import pixeltable.exceptions as excs
5
+ from pixeltable.tests.utils import skip_test_if_not_installed, validate_update_status
6
+
7
+
8
+ class TestTogether:
9
+
10
+ def test_completions(self, test_client: pxt.Client) -> None:
11
+ skip_test_if_not_installed('together')
12
+ TestTogether.skip_test_if_no_together_client()
13
+ cl = test_client
14
+ t = cl.create_table('test_tbl', {'input': pxt.StringType()})
15
+ from pixeltable.functions.together import completions
16
+ t.add_column(output=completions(prompt=t.input, model='mistralai/Mixtral-8x7B-v0.1', stop=['\n']))
17
+ t.add_column(output_2=completions(
18
+ prompt=t.input,
19
+ model='mistralai/Mixtral-8x7B-v0.1',
20
+ max_tokens=300,
21
+ stop=['\n'],
22
+ temperature=0.7,
23
+ top_p=0.9,
24
+ top_k=40,
25
+ repetition_penalty=1.1,
26
+ logprobs=1,
27
+ echo=True,
28
+ n=3,
29
+ safety_model='Meta-Llama/Llama-Guard-7b'
30
+ ))
31
+ validate_update_status(t.insert(input='I am going to the '), 1)
32
+ result = t.collect()
33
+ assert len(result['output'][0]['choices'][0]['text']) > 0
34
+ assert len(result['output_2'][0]['choices'][0]['text']) > 0
35
+
36
+ def test_chat_completions(self, test_client: pxt.Client) -> None:
37
+ skip_test_if_not_installed('together')
38
+ TestTogether.skip_test_if_no_together_client()
39
+ cl = test_client
40
+ t = cl.create_table('test_tbl', {'input': pxt.StringType()})
41
+ messages = [{'role': 'user', 'content': t.input}]
42
+ from pixeltable.functions.together import chat_completions
43
+ t.add_column(output=chat_completions(messages=messages, model='mistralai/Mixtral-8x7B-v0.1', stop=['\n']))
44
+ t.add_column(output_2=chat_completions(
45
+ messages=messages,
46
+ model='mistralai/Mixtral-8x7B-Instruct-v0.1',
47
+ max_tokens=300,
48
+ stop=['\n'],
49
+ temperature=0.7,
50
+ top_p=0.9,
51
+ top_k=40,
52
+ repetition_penalty=1.1,
53
+ logprobs=1,
54
+ echo=True,
55
+ n=3,
56
+ safety_model='Meta-Llama/Llama-Guard-7b',
57
+ response_format={'type': 'json_object'}
58
+ ))
59
+ validate_update_status(t.insert(input='Give me a typical example of a JSON structure.'), 1)
60
+ result = t.collect()
61
+ assert len(result['output'][0]['choices'][0]['message']) > 0
62
+ assert len(result['output_2'][0]['choices'][0]['message']) > 0
63
+
64
+ def test_embeddings(self, test_client: pxt.Client) -> None:
65
+ skip_test_if_not_installed('together')
66
+ TestTogether.skip_test_if_no_together_client()
67
+ cl = test_client
68
+ t = cl.create_table('test_tbl', {'input': pxt.StringType()})
69
+ from pixeltable.functions.together import embeddings
70
+ t.add_column(embed=embeddings(input=t.input, model='togethercomputer/m2-bert-80M-8k-retrieval'))
71
+ validate_update_status(t.insert(input='Together AI provides a variety of embeddings models.'), 1)
72
+ assert len(t.collect()['embed'][0]) > 0
73
+
74
+ def test_image_generations(self, test_client: pxt.Client) -> None:
75
+ skip_test_if_not_installed('together')
76
+ TestTogether.skip_test_if_no_together_client()
77
+ cl = test_client
78
+ t = cl.create_table(
79
+ 'test_tbl',
80
+ {'input': pxt.StringType(), 'negative_prompt': pxt.StringType(nullable=True)}
81
+ )
82
+ from pixeltable.functions.together import image_generations
83
+ t.add_column(img=image_generations(t.input, model='runwayml/stable-diffusion-v1-5'))
84
+ t.add_column(img_2=image_generations(
85
+ t.input,
86
+ model='stabilityai/stable-diffusion-2-1',
87
+ steps=30,
88
+ seed=4178780,
89
+ height=768,
90
+ width=512,
91
+ negative_prompt=t.negative_prompt
92
+ ))
93
+ validate_update_status(t.insert([
94
+ {'input': 'A friendly dinosaur playing tennis in a cornfield'},
95
+ {'input': 'A friendly dinosaur playing tennis in a cornfield',
96
+ 'negative_prompt': 'tennis court'}
97
+ ]), 2)
98
+ assert t.collect()['img'][0].size == (512, 512)
99
+ assert t.collect()['img_2'][0].size == (512, 768)
100
+ assert t.collect()['img'][1].size == (512, 512)
101
+ assert t.collect()['img_2'][1].size == (512, 768)
102
+
103
+ # This ensures that the test will be skipped, rather than returning an error, when no API key is
104
+ # available (for example, when a PR runs in CI).
105
+ @staticmethod
106
+ def skip_test_if_no_together_client() -> None:
107
+ try:
108
+ import pixeltable.functions.together
109
+ _ = pixeltable.functions.together.together_client()
110
+ except excs.Error as exc:
111
+ pytest.skip(str(exc))
@@ -33,7 +33,7 @@ class TestDataFrame:
33
33
  assert res1 == res4
34
34
 
35
35
  _ = t.where(t.c2 < 10).select(t.c2, t.c2).show(0) # repeated name no error
36
-
36
+
37
37
  # duplicate select list
38
38
  with pytest.raises(excs.Error) as exc_info:
39
39
  _ = t.select(t.c1).select(t.c2).show(0)
@@ -220,7 +220,7 @@ class TestDataFrame:
220
220
  for tup in ds:
221
221
  for col in df.get_column_names():
222
222
  assert col in tup
223
-
223
+
224
224
  arrval = tup['c_array']
225
225
  assert isinstance(arrval, np.ndarray)
226
226
  col_type = type_dict['c_array']
@@ -304,7 +304,7 @@ class TestDataFrame:
304
304
  def restrict_json_for_default_collate(obj):
305
305
  keys = ['id', 'label', 'iscrowd', 'bounding_box']
306
306
  return {k: obj[k] for k in keys}
307
-
307
+
308
308
  t = all_datatypes_tbl
309
309
  df = t.select(
310
310
  t.row_id,
@@ -370,7 +370,7 @@ class TestDataFrame:
370
370
  # check result cached
371
371
  ds1 = t.to_pytorch_dataset(image_format='pt')
372
372
  ds1_mtimes = _get_mtimes(ds1.path)
373
-
373
+
374
374
  ds2 = t.to_pytorch_dataset(image_format='pt')
375
375
  ds2_mtimes = _get_mtimes(ds2.path)
376
376
  assert ds2.path == ds1.path, 'result should be cached'
@@ -8,6 +8,7 @@ import PIL
8
8
  import cv2
9
9
  import numpy as np
10
10
  import pandas as pd
11
+ import pathlib
11
12
  import pytest
12
13
 
13
14
  import pixeltable as pxt
@@ -17,7 +18,7 @@ from pixeltable import exceptions as excs
17
18
  from pixeltable.iterators import FrameIterator
18
19
  from pixeltable.tests.utils import \
19
20
  make_tbl, create_table_data, read_data_file, get_video_files, get_audio_files, get_image_files, get_documents, \
20
- assert_resultset_eq
21
+ assert_resultset_eq, assert_hf_dataset_equal, make_test_arrow_table
21
22
  from pixeltable.tests.utils import skip_test_if_not_installed
22
23
  from pixeltable.type_system import \
23
24
  StringType, IntType, FloatType, TimestampType, ImageType, VideoType, JsonType, BoolType, ArrayType, AudioType, \
@@ -25,7 +26,6 @@ from pixeltable.type_system import \
25
26
  from pixeltable.utils.filecache import FileCache
26
27
  from pixeltable.utils.media_store import MediaStore
27
28
 
28
-
29
29
  class TestTable:
30
30
  # exc for a % 10 == 0
31
31
  @pxt.udf(return_type=FloatType(), param_types=[IntType()])
@@ -116,6 +116,100 @@ class TestTable:
116
116
  tbl.revert()
117
117
  assert tbl.num_retained_versions == num_retained_versions
118
118
 
119
+ def test_import_parquet(self, test_client: pxt.Client, tmp_path: pathlib.Path) -> None:
120
+ skip_test_if_not_installed('pyarrow')
121
+ import pyarrow as pa
122
+ from pixeltable.utils.arrow import iter_tuples
123
+
124
+ parquet_dir = tmp_path / 'test_data'
125
+ parquet_dir.mkdir()
126
+ make_test_arrow_table(parquet_dir)
127
+
128
+ tab = test_client.import_parquet('test_parquet', parquet_path=str(parquet_dir))
129
+ assert 'test_parquet' in test_client.list_tables()
130
+ assert tab is not None
131
+ num_elts = tab.count()
132
+ arrow_tab: pa.Table = pa.parquet.read_table(str(parquet_dir))
133
+ assert num_elts == arrow_tab.num_rows
134
+ assert set(tab.column_names()) == set(arrow_tab.column_names)
135
+
136
+ result_set = tab.order_by(tab.c_id).collect()
137
+ column_types = tab.column_types()
138
+
139
+ for tup, arrow_tup in zip(result_set, iter_tuples(arrow_tab)):
140
+ assert tup['c_id'] == arrow_tup['c_id']
141
+ for col, val in tup.items():
142
+ if val is None:
143
+ assert arrow_tup[col] is None
144
+ continue
145
+
146
+ if column_types[col].is_array_type():
147
+ assert (val == arrow_tup[col]).all()
148
+ else:
149
+ assert val == arrow_tup[col]
150
+
151
+ def test_import_huggingface_dataset(self, test_client: pxt.Client, tmp_path: pathlib.Path) -> None:
152
+ skip_test_if_not_installed('datasets')
153
+ import datasets
154
+
155
+ test_cases = [
156
+ # { # includes a timestamp. 20MB for specific slice
157
+ # Disabled this test case because the download is failing, and it's not critical.
158
+ # 'dataset_name': 'c4',
159
+ # # see https://huggingface.co/datasets/allenai/c4/blob/main/realnewslike/c4-train.00000-of-00512.json.gz
160
+ # 'dataset': datasets.load_dataset(
161
+ # "allenai/c4",
162
+ # data_dir="realnewslike",
163
+ # data_files="c4-train.00000-of-00512.json.gz",
164
+ # split='train[:1000]',
165
+ # cache_dir=tmp_path
166
+ # ),
167
+ # },
168
+ { # includes an embedding (array type), common in a few RAG datasets.
169
+ 'dataset_name': 'cohere_wikipedia',
170
+ 'dataset': datasets.load_dataset("Cohere/wikipedia-2023-11-embed-multilingual-v3",
171
+ data_dir='cr').select_columns(['url', 'title', 'text', 'emb']),
172
+ # a column named `_id` is not currently allowed by pixeltable rules,
173
+ # so filter out that column.
174
+ # cr subdir has a small number of rows, avoid running out of space in CI runner
175
+ # see https://huggingface.co/datasets/Cohere/wikipedia-2023-11-embed-multilingual-v3/tree/main/cr
176
+ 'schema_override': {'emb': ArrayType((1024,), dtype=FloatType(), nullable=False)}
177
+ },
178
+ # example of dataset dictionary with multiple splits
179
+ {
180
+ 'dataset_name': 'rotten_tomatoes',
181
+ 'dataset': datasets.load_dataset("rotten_tomatoes"),
182
+ },
183
+ ]
184
+
185
+ # test a column name for splits other than the default of 'split'
186
+ split_column_name = 'my_split_col'
187
+ for rec in test_cases:
188
+ dataset_name = rec['dataset_name']
189
+ hf_dataset = rec['dataset']
190
+
191
+ tab = test_client.import_huggingface_dataset(
192
+ dataset_name,
193
+ hf_dataset,
194
+ column_name_for_split=split_column_name,
195
+ schema_override=rec.get('schema_override', None),
196
+ )
197
+ if isinstance(hf_dataset, datasets.Dataset):
198
+ assert_hf_dataset_equal(hf_dataset, tab.df(), split_column_name)
199
+ elif isinstance(hf_dataset, datasets.DatasetDict):
200
+ assert tab.count() == sum(hf_dataset.num_rows.values())
201
+ assert split_column_name in tab.column_names()
202
+
203
+ for dataset_name in hf_dataset:
204
+ df = tab.where(tab.my_split_col == dataset_name)
205
+ assert_hf_dataset_equal(hf_dataset[dataset_name], df, split_column_name)
206
+ else:
207
+ assert False
208
+
209
+ with pytest.raises(excs.Error) as exc_info:
210
+ test_client.import_huggingface_dataset('test', {})
211
+ assert 'type(dataset)' in str(exc_info.value)
212
+
119
213
  def test_image_table(self, test_client: pxt.Client) -> None:
120
214
  n_sample_rows = 20
121
215
  cl = test_client
@@ -533,6 +627,15 @@ class TestTable:
533
627
  t.insert(c5=np.ndarray((3, 2)))
534
628
  assert 'expected ndarray((2, 3)' in str(exc_info.value)
535
629
 
630
+ def test_insert_string_with_null(self, test_client: pxt.Client) -> None:
631
+ cl = test_client
632
+ t = cl.create_table('test', {'c1': StringType()})
633
+
634
+ t.insert([{'c1': 'this is a python\x00string'}])
635
+ assert t.count() == 1
636
+ for tup in t.df().collect():
637
+ assert tup['c1'] == 'this is a python string'
638
+
536
639
  def test_query(self, test_client: pxt.Client) -> None:
537
640
  skip_test_if_not_installed('boto3')
538
641
  cl = test_client
pixeltable/tests/utils.py CHANGED
@@ -2,8 +2,9 @@ import datetime
2
2
  import glob
3
3
  import json
4
4
  import os
5
+ from collections import namedtuple
5
6
  from pathlib import Path
6
- from typing import Dict, Any, List, Optional
7
+ from typing import Any, Dict, List, Optional, Set
7
8
 
8
9
  import numpy as np
9
10
  import pandas as pd
@@ -12,12 +13,21 @@ import pytest
12
13
  import pixeltable as pxt
13
14
  import pixeltable.type_system as ts
14
15
  from pixeltable import catalog
16
+ from pixeltable.catalog.globals import UpdateStatus
15
17
  from pixeltable.dataframe import DataFrameResultSet
16
18
  from pixeltable.env import Env
17
- from pixeltable.type_system import \
18
- ColumnType, StringType, IntType, FloatType, ArrayType, BoolType, TimestampType, JsonType, ImageType, VideoType
19
-
20
-
19
+ from pixeltable.type_system import (
20
+ ArrayType,
21
+ BoolType,
22
+ ColumnType,
23
+ FloatType,
24
+ ImageType,
25
+ IntType,
26
+ JsonType,
27
+ StringType,
28
+ TimestampType,
29
+ VideoType,
30
+ )
21
31
 
22
32
 
23
33
  def make_default_type(t: ColumnType.Type) -> ColumnType:
@@ -266,6 +276,7 @@ def get_sentences(n: int = 100) -> List[str]:
266
276
  # this dataset contains \' around the questions
267
277
  return [q['question'].replace("'", '') for q in questions_list[:n]]
268
278
 
279
+
269
280
  def assert_resultset_eq(r1: DataFrameResultSet, r2: DataFrameResultSet) -> None:
270
281
  assert len(r1) == len(r2)
271
282
  assert len(r1.column_names()) == len(r2.column_names()) # we don't care about the actual column names
@@ -280,6 +291,118 @@ def assert_resultset_eq(r1: DataFrameResultSet, r2: DataFrameResultSet) -> None:
280
291
  else:
281
292
  assert s1.equals(s2)
282
293
 
294
+
283
295
  def skip_test_if_not_installed(package) -> None:
284
296
  if not Env.get().is_installed_package(package):
285
297
  pytest.skip(f'Package `{package}` is not installed.')
298
+
299
+
300
+ def validate_update_status(status: UpdateStatus, expected_rows: Optional[int] = None) -> None:
301
+ assert status.num_excs == 0
302
+ if expected_rows is not None:
303
+ assert status.num_rows == expected_rows
304
+
305
+
306
+ def make_test_arrow_table(output_path: Path) -> None:
307
+ import pyarrow as pa
308
+
309
+ value_dict = {
310
+ 'c_id': [1, 2, 3, 4, 5],
311
+ 'c_int64': [-10, -20, -30, -40, None],
312
+ 'c_int32': [-1, -2, -3, -4, None],
313
+ 'c_float32': [1.1, 2.2, 3.3, 4.4, None],
314
+ 'c_string': ['aaa', 'bbb', 'ccc', 'ddd', None],
315
+ 'c_boolean': [True, False, True, False, None],
316
+ 'c_timestamp': [
317
+ datetime.datetime(2012, 1, 1, 12, 0, 0, 25),
318
+ datetime.datetime(2012, 1, 2, 12, 0, 0, 25),
319
+ datetime.datetime(2012, 1, 3, 12, 0, 0, 25),
320
+ datetime.datetime(2012, 1, 4, 12, 0, 0, 25),
321
+ None,
322
+ ],
323
+ # The pyarrow fixed_shape_tensor type does not support NULLs (currently can write them but not read them)
324
+ # So, no nulls in this column
325
+ 'c_array_float32': [
326
+ [
327
+ 1.0,
328
+ 2.0,
329
+ ],
330
+ [
331
+ 10.0,
332
+ 20.0,
333
+ ],
334
+ [
335
+ 100.0,
336
+ 200.0,
337
+ ],
338
+ [
339
+ 1000.0,
340
+ 2000.0,
341
+ ],
342
+ [10000.0, 20000.0],
343
+ ],
344
+ }
345
+
346
+ arr_size = len(value_dict['c_array_float32'][0])
347
+ tensor_type = pa.fixed_shape_tensor(pa.float32(), (arr_size,))
348
+
349
+ schema = pa.schema(
350
+ [
351
+ ('c_id', pa.int32()),
352
+ ('c_int64', pa.int64()),
353
+ ('c_int32', pa.int32()),
354
+ ('c_float32', pa.float32()),
355
+ ('c_string', pa.string()),
356
+ ('c_boolean', pa.bool_()),
357
+ ('c_timestamp', pa.timestamp('us')),
358
+ ('c_array_float32', tensor_type),
359
+ ]
360
+ )
361
+
362
+ test_table = pa.Table.from_pydict(value_dict, schema=schema)
363
+ pa.parquet.write_table(test_table, str(output_path / 'test.parquet'))
364
+
365
+
366
+ def assert_hf_dataset_equal(hf_dataset: 'datasets.Dataset', df: pxt.DataFrame, split_column_name: str) -> None:
367
+ import datasets
368
+ assert df.count() == hf_dataset.num_rows
369
+ assert set(df.get_column_names()) == (set(hf_dataset.features.keys()) | {split_column_name})
370
+
371
+ # immutable (hashable) so we can store it in a set
372
+ DatasetTuple = namedtuple('DatasetTuple', ' '.join(hf_dataset.features.keys()))
373
+ acc_dataset: Set[DatasetTuple] = set()
374
+ for tup in hf_dataset:
375
+ immutable_tup = {}
376
+ for k in tup:
377
+ if isinstance(tup[k], list):
378
+ immutable_tup[k] = tuple(tup[k])
379
+ else:
380
+ immutable_tup[k] = tup[k]
381
+
382
+ acc_dataset.add(DatasetTuple(**immutable_tup))
383
+
384
+ for tup in df.collect():
385
+ assert tup[split_column_name] in hf_dataset.split._name
386
+
387
+ encoded_tup = {}
388
+ for column_name, value in tup.items():
389
+ if column_name == split_column_name:
390
+ continue
391
+ feature_type = hf_dataset.features[column_name]
392
+ if isinstance(feature_type, datasets.ClassLabel):
393
+ assert value in feature_type.names
394
+ # must use the index of the class label as the value to
395
+ # compare with dataset iteration output.
396
+ value = feature_type.encode_example(value)
397
+ elif isinstance(feature_type, datasets.Sequence):
398
+ assert feature_type.feature.dtype == 'float32', 'may need to add more types'
399
+ value = tuple([float(x) for x in value])
400
+
401
+ encoded_tup[column_name] = value
402
+
403
+ check_tup = DatasetTuple(**encoded_tup)
404
+ assert check_tup in acc_dataset
405
+
406
+
407
+ SAMPLE_IMAGE_URL = \
408
+ 'https://raw.githubusercontent.com/pixeltable/pixeltable/master/docs/source/data/images/000000000009.jpg'