pixeltable 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic.
- pixeltable/catalog/column.py +1 -1
- pixeltable/client.py +72 -2
- pixeltable/env.py +36 -52
- pixeltable/functions/__init__.py +1 -1
- pixeltable/functions/fireworks.py +10 -37
- pixeltable/functions/openai.py +192 -24
- pixeltable/functions/together.py +104 -9
- pixeltable/tests/conftest.py +4 -4
- pixeltable/tests/functions/test_fireworks.py +42 -0
- pixeltable/tests/functions/test_functions.py +60 -0
- pixeltable/tests/{test_functions.py → functions/test_huggingface.py} +5 -141
- pixeltable/tests/functions/test_openai.py +152 -0
- pixeltable/tests/functions/test_together.py +111 -0
- pixeltable/tests/test_dataframe.py +4 -4
- pixeltable/tests/test_table.py +105 -2
- pixeltable/tests/utils.py +128 -5
- pixeltable/type_system.py +41 -84
- pixeltable/utils/arrow.py +98 -0
- pixeltable/utils/hf_datasets.py +157 -0
- pixeltable/utils/parquet.py +68 -27
- pixeltable/utils/pytorch.py +16 -97
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.4.dist-info}/METADATA +33 -27
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.4.dist-info}/RECORD +25 -19
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.4.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.4.dist-info}/WHEEL +0 -0
pixeltable/tests/functions/test_together.py
ADDED
@@ -0,0 +1,111 @@
+import pytest
+
+import pixeltable as pxt
+import pixeltable.exceptions as excs
+from pixeltable.tests.utils import skip_test_if_not_installed, validate_update_status
+
+
+class TestTogether:
+
+    def test_completions(self, test_client: pxt.Client) -> None:
+        skip_test_if_not_installed('together')
+        TestTogether.skip_test_if_no_together_client()
+        cl = test_client
+        t = cl.create_table('test_tbl', {'input': pxt.StringType()})
+        from pixeltable.functions.together import completions
+        t.add_column(output=completions(prompt=t.input, model='mistralai/Mixtral-8x7B-v0.1', stop=['\n']))
+        t.add_column(output_2=completions(
+            prompt=t.input,
+            model='mistralai/Mixtral-8x7B-v0.1',
+            max_tokens=300,
+            stop=['\n'],
+            temperature=0.7,
+            top_p=0.9,
+            top_k=40,
+            repetition_penalty=1.1,
+            logprobs=1,
+            echo=True,
+            n=3,
+            safety_model='Meta-Llama/Llama-Guard-7b'
+        ))
+        validate_update_status(t.insert(input='I am going to the '), 1)
+        result = t.collect()
+        assert len(result['output'][0]['choices'][0]['text']) > 0
+        assert len(result['output_2'][0]['choices'][0]['text']) > 0
+
+    def test_chat_completions(self, test_client: pxt.Client) -> None:
+        skip_test_if_not_installed('together')
+        TestTogether.skip_test_if_no_together_client()
+        cl = test_client
+        t = cl.create_table('test_tbl', {'input': pxt.StringType()})
+        messages = [{'role': 'user', 'content': t.input}]
+        from pixeltable.functions.together import chat_completions
+        t.add_column(output=chat_completions(messages=messages, model='mistralai/Mixtral-8x7B-v0.1', stop=['\n']))
+        t.add_column(output_2=chat_completions(
+            messages=messages,
+            model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+            max_tokens=300,
+            stop=['\n'],
+            temperature=0.7,
+            top_p=0.9,
+            top_k=40,
+            repetition_penalty=1.1,
+            logprobs=1,
+            echo=True,
+            n=3,
+            safety_model='Meta-Llama/Llama-Guard-7b',
+            response_format={'type': 'json_object'}
+        ))
+        validate_update_status(t.insert(input='Give me a typical example of a JSON structure.'), 1)
+        result = t.collect()
+        assert len(result['output'][0]['choices'][0]['message']) > 0
+        assert len(result['output_2'][0]['choices'][0]['message']) > 0
+
+    def test_embeddings(self, test_client: pxt.Client) -> None:
+        skip_test_if_not_installed('together')
+        TestTogether.skip_test_if_no_together_client()
+        cl = test_client
+        t = cl.create_table('test_tbl', {'input': pxt.StringType()})
+        from pixeltable.functions.together import embeddings
+        t.add_column(embed=embeddings(input=t.input, model='togethercomputer/m2-bert-80M-8k-retrieval'))
+        validate_update_status(t.insert(input='Together AI provides a variety of embeddings models.'), 1)
+        assert len(t.collect()['embed'][0]) > 0
+
+    def test_image_generations(self, test_client: pxt.Client) -> None:
+        skip_test_if_not_installed('together')
+        TestTogether.skip_test_if_no_together_client()
+        cl = test_client
+        t = cl.create_table(
+            'test_tbl',
+            {'input': pxt.StringType(), 'negative_prompt': pxt.StringType(nullable=True)}
+        )
+        from pixeltable.functions.together import image_generations
+        t.add_column(img=image_generations(t.input, model='runwayml/stable-diffusion-v1-5'))
+        t.add_column(img_2=image_generations(
+            t.input,
+            model='stabilityai/stable-diffusion-2-1',
+            steps=30,
+            seed=4178780,
+            height=768,
+            width=512,
+            negative_prompt=t.negative_prompt
+        ))
+        validate_update_status(t.insert([
+            {'input': 'A friendly dinosaur playing tennis in a cornfield'},
+            {'input': 'A friendly dinosaur playing tennis in a cornfield',
+             'negative_prompt': 'tennis court'}
+        ]), 2)
+        assert t.collect()['img'][0].size == (512, 512)
+        assert t.collect()['img_2'][0].size == (512, 768)
+        assert t.collect()['img'][1].size == (512, 512)
+        assert t.collect()['img_2'][1].size == (512, 768)
+
+    # This ensures that the test will be skipped, rather than returning an error, when no API key is
+    # available (for example, when a PR runs in CI).
+    @staticmethod
+    def skip_test_if_no_together_client() -> None:
+        try:
+            import pixeltable.functions.together
+            _ = pixeltable.functions.together.together_client()
+        except excs.Error as exc:
+            pytest.skip(str(exc))
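The new test file above also serves as a usage reference for the Together integration added in this release. Below is a minimal sketch of the same pattern outside the test fixtures (assumptions: the `together` package is installed, a Together API key is configured, and the 0.2.x `pxt.Client()` entry point is used; the table and column names are illustrative, not from the package):

import pixeltable as pxt
from pixeltable.functions.together import chat_completions, embeddings

cl = pxt.Client()
t = cl.create_table('together_demo', {'prompt': pxt.StringType()})

# computed columns: each inserted row triggers a Together API call
t.add_column(response=chat_completions(
    messages=[{'role': 'user', 'content': t.prompt}],
    model='mistralai/Mixtral-8x7B-Instruct-v0.1',
))
t.add_column(embedding=embeddings(input=t.prompt, model='togethercomputer/m2-bert-80M-8k-retrieval'))

t.insert(prompt='What does Pixeltable do?')
# response payloads follow the same shape the assertions above check
print(t.collect()['response'][0]['choices'][0]['message'])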
pixeltable/tests/test_dataframe.py
CHANGED
@@ -33,7 +33,7 @@ class TestDataFrame:
         assert res1 == res4
 
         _ = t.where(t.c2 < 10).select(t.c2, t.c2).show(0)  # repeated name no error
-
+
         # duplicate select list
         with pytest.raises(excs.Error) as exc_info:
             _ = t.select(t.c1).select(t.c2).show(0)
@@ -220,7 +220,7 @@ class TestDataFrame:
         for tup in ds:
             for col in df.get_column_names():
                 assert col in tup
-
+
             arrval = tup['c_array']
             assert isinstance(arrval, np.ndarray)
             col_type = type_dict['c_array']
@@ -304,7 +304,7 @@ class TestDataFrame:
         def restrict_json_for_default_collate(obj):
             keys = ['id', 'label', 'iscrowd', 'bounding_box']
             return {k: obj[k] for k in keys}
-
+
         t = all_datatypes_tbl
         df = t.select(
             t.row_id,
@@ -370,7 +370,7 @@ class TestDataFrame:
         # check result cached
         ds1 = t.to_pytorch_dataset(image_format='pt')
         ds1_mtimes = _get_mtimes(ds1.path)
-
+
         ds2 = t.to_pytorch_dataset(image_format='pt')
         ds2_mtimes = _get_mtimes(ds2.path)
         assert ds2.path == ds1.path, 'result should be cached'
pixeltable/tests/test_table.py
CHANGED
@@ -8,6 +8,7 @@ import PIL
 import cv2
 import numpy as np
 import pandas as pd
+import pathlib
 import pytest
 
 import pixeltable as pxt
@@ -17,7 +18,7 @@ from pixeltable import exceptions as excs
 from pixeltable.iterators import FrameIterator
 from pixeltable.tests.utils import \
     make_tbl, create_table_data, read_data_file, get_video_files, get_audio_files, get_image_files, get_documents, \
-    assert_resultset_eq
+    assert_resultset_eq, assert_hf_dataset_equal, make_test_arrow_table
 from pixeltable.tests.utils import skip_test_if_not_installed
 from pixeltable.type_system import \
     StringType, IntType, FloatType, TimestampType, ImageType, VideoType, JsonType, BoolType, ArrayType, AudioType, \
@@ -25,7 +26,6 @@ from pixeltable.type_system import \
 from pixeltable.utils.filecache import FileCache
 from pixeltable.utils.media_store import MediaStore
 
-
 class TestTable:
     # exc for a % 10 == 0
     @pxt.udf(return_type=FloatType(), param_types=[IntType()])
@@ -116,6 +116,100 @@ class TestTable:
         tbl.revert()
         assert tbl.num_retained_versions == num_retained_versions
 
+    def test_import_parquet(self, test_client: pxt.Client, tmp_path: pathlib.Path) -> None:
+        skip_test_if_not_installed('pyarrow')
+        import pyarrow as pa
+        from pixeltable.utils.arrow import iter_tuples
+
+        parquet_dir = tmp_path / 'test_data'
+        parquet_dir.mkdir()
+        make_test_arrow_table(parquet_dir)
+
+        tab = test_client.import_parquet('test_parquet', parquet_path=str(parquet_dir))
+        assert 'test_parquet' in test_client.list_tables()
+        assert tab is not None
+        num_elts = tab.count()
+        arrow_tab: pa.Table = pa.parquet.read_table(str(parquet_dir))
+        assert num_elts == arrow_tab.num_rows
+        assert set(tab.column_names()) == set(arrow_tab.column_names)
+
+        result_set = tab.order_by(tab.c_id).collect()
+        column_types = tab.column_types()
+
+        for tup, arrow_tup in zip(result_set, iter_tuples(arrow_tab)):
+            assert tup['c_id'] == arrow_tup['c_id']
+            for col, val in tup.items():
+                if val is None:
+                    assert arrow_tup[col] is None
+                    continue
+
+                if column_types[col].is_array_type():
+                    assert (val == arrow_tup[col]).all()
+                else:
+                    assert val == arrow_tup[col]
+
+    def test_import_huggingface_dataset(self, test_client: pxt.Client, tmp_path: pathlib.Path) -> None:
+        skip_test_if_not_installed('datasets')
+        import datasets
+
+        test_cases = [
+            # {  # includes a timestamp; ~20MB for this specific slice
+            #     # Disabled this test case because the download is failing, and it's not critical.
+            #     'dataset_name': 'c4',
+            #     # see https://huggingface.co/datasets/allenai/c4/blob/main/realnewslike/c4-train.00000-of-00512.json.gz
+            #     'dataset': datasets.load_dataset(
+            #         "allenai/c4",
+            #         data_dir="realnewslike",
+            #         data_files="c4-train.00000-of-00512.json.gz",
+            #         split='train[:1000]',
+            #         cache_dir=tmp_path
+            #     ),
+            # },
+            {  # includes an embedding (array type), common in a few RAG datasets.
+                'dataset_name': 'cohere_wikipedia',
+                'dataset': datasets.load_dataset("Cohere/wikipedia-2023-11-embed-multilingual-v3",
+                                                 data_dir='cr').select_columns(['url', 'title', 'text', 'emb']),
+                # a column named `_id` is not currently allowed by pixeltable rules,
+                # so filter out that column.
+                # the cr subdir has a small number of rows, to avoid running out of space in the CI runner;
+                # see https://huggingface.co/datasets/Cohere/wikipedia-2023-11-embed-multilingual-v3/tree/main/cr
+                'schema_override': {'emb': ArrayType((1024,), dtype=FloatType(), nullable=False)}
+            },
+            # example of a dataset dictionary with multiple splits
+            {
+                'dataset_name': 'rotten_tomatoes',
+                'dataset': datasets.load_dataset("rotten_tomatoes"),
+            },
+        ]
+
+        # test a column name for splits other than the default of 'split'
+        split_column_name = 'my_split_col'
+        for rec in test_cases:
+            dataset_name = rec['dataset_name']
+            hf_dataset = rec['dataset']
+
+            tab = test_client.import_huggingface_dataset(
+                dataset_name,
+                hf_dataset,
+                column_name_for_split=split_column_name,
+                schema_override=rec.get('schema_override', None),
+            )
+            if isinstance(hf_dataset, datasets.Dataset):
+                assert_hf_dataset_equal(hf_dataset, tab.df(), split_column_name)
+            elif isinstance(hf_dataset, datasets.DatasetDict):
+                assert tab.count() == sum(hf_dataset.num_rows.values())
+                assert split_column_name in tab.column_names()
+
+                for dataset_name in hf_dataset:
+                    df = tab.where(tab.my_split_col == dataset_name)
+                    assert_hf_dataset_equal(hf_dataset[dataset_name], df, split_column_name)
+            else:
+                assert False
+
+        with pytest.raises(excs.Error) as exc_info:
+            test_client.import_huggingface_dataset('test', {})
+        assert 'type(dataset)' in str(exc_info.value)
+
     def test_image_table(self, test_client: pxt.Client) -> None:
         n_sample_rows = 20
         cl = test_client
@@ -533,6 +627,15 @@ class TestTable:
             t.insert(c5=np.ndarray((3, 2)))
         assert 'expected ndarray((2, 3)' in str(exc_info.value)
 
+    def test_insert_string_with_null(self, test_client: pxt.Client) -> None:
+        cl = test_client
+        t = cl.create_table('test', {'c1': StringType()})
+
+        t.insert([{'c1': 'this is a python\x00string'}])
+        assert t.count() == 1
+        for tup in t.df().collect():
+            assert tup['c1'] == 'this is a python string'
+
     def test_query(self, test_client: pxt.Client) -> None:
         skip_test_if_not_installed('boto3')
         cl = test_client
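The test_import_parquet and test_import_huggingface_dataset additions above exercise the new client-level import entry points in this release. Below is a minimal sketch of the same calls outside the test harness (assumptions: the 0.2.x `pxt.Client()` entry point, with `pyarrow` and `datasets` installed; the table names, parquet directory, and dataset choice are illustrative):

import pixeltable as pxt

cl = pxt.Client()

# Parquet import: point at a directory containing parquet files
tab = cl.import_parquet('parquet_demo', parquet_path='/data/my_parquet_dir')
print(tab.count(), tab.column_names())

# Hugging Face import: accepts a datasets.Dataset or datasets.DatasetDict;
# for a DatasetDict, the split name lands in an extra column ('split' by default)
import datasets
ds = datasets.load_dataset('rotten_tomatoes')
tab2 = cl.import_huggingface_dataset('rotten_tomatoes', ds, column_name_for_split='split')
print(tab2.where(tab2.split == 'train').count())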
pixeltable/tests/utils.py
CHANGED
@@ -2,8 +2,9 @@ import datetime
 import glob
 import json
 import os
+from collections import namedtuple
 from pathlib import Path
-from typing import
+from typing import Any, Dict, List, Optional, Set
 
 import numpy as np
 import pandas as pd
@@ -12,12 +13,21 @@ import pytest
 import pixeltable as pxt
 import pixeltable.type_system as ts
 from pixeltable import catalog
+from pixeltable.catalog.globals import UpdateStatus
 from pixeltable.dataframe import DataFrameResultSet
 from pixeltable.env import Env
-from pixeltable.type_system import
-
-
-
+from pixeltable.type_system import (
+    ArrayType,
+    BoolType,
+    ColumnType,
+    FloatType,
+    ImageType,
+    IntType,
+    JsonType,
+    StringType,
+    TimestampType,
+    VideoType,
+)
 
 
 def make_default_type(t: ColumnType.Type) -> ColumnType:
@@ -266,6 +276,7 @@ def get_sentences(n: int = 100) -> List[str]:
     # this dataset contains \' around the questions
     return [q['question'].replace("'", '') for q in questions_list[:n]]
 
+
 def assert_resultset_eq(r1: DataFrameResultSet, r2: DataFrameResultSet) -> None:
     assert len(r1) == len(r2)
     assert len(r1.column_names()) == len(r2.column_names())  # we don't care about the actual column names
@@ -280,6 +291,118 @@ def assert_resultset_eq(r1: DataFrameResultSet, r2: DataFrameResultSet) -> None:
     else:
         assert s1.equals(s2)
 
+
 def skip_test_if_not_installed(package) -> None:
     if not Env.get().is_installed_package(package):
         pytest.skip(f'Package `{package}` is not installed.')
+
+
+def validate_update_status(status: UpdateStatus, expected_rows: Optional[int] = None) -> None:
+    assert status.num_excs == 0
+    if expected_rows is not None:
+        assert status.num_rows == expected_rows
+
+
+def make_test_arrow_table(output_path: Path) -> None:
+    import pyarrow as pa
+
+    value_dict = {
+        'c_id': [1, 2, 3, 4, 5],
+        'c_int64': [-10, -20, -30, -40, None],
+        'c_int32': [-1, -2, -3, -4, None],
+        'c_float32': [1.1, 2.2, 3.3, 4.4, None],
+        'c_string': ['aaa', 'bbb', 'ccc', 'ddd', None],
+        'c_boolean': [True, False, True, False, None],
+        'c_timestamp': [
+            datetime.datetime(2012, 1, 1, 12, 0, 0, 25),
+            datetime.datetime(2012, 1, 2, 12, 0, 0, 25),
+            datetime.datetime(2012, 1, 3, 12, 0, 0, 25),
+            datetime.datetime(2012, 1, 4, 12, 0, 0, 25),
+            None,
+        ],
+        # The pyarrow fixed_shape_tensor type does not support NULLs (currently they can be written but not read),
+        # so there are no nulls in this column.
+        'c_array_float32': [
+            [
+                1.0,
+                2.0,
+            ],
+            [
+                10.0,
+                20.0,
+            ],
+            [
+                100.0,
+                200.0,
+            ],
+            [
+                1000.0,
+                2000.0,
+            ],
+            [10000.0, 20000.0],
+        ],
+    }
+
+    arr_size = len(value_dict['c_array_float32'][0])
+    tensor_type = pa.fixed_shape_tensor(pa.float32(), (arr_size,))
+
+    schema = pa.schema(
+        [
+            ('c_id', pa.int32()),
+            ('c_int64', pa.int64()),
+            ('c_int32', pa.int32()),
+            ('c_float32', pa.float32()),
+            ('c_string', pa.string()),
+            ('c_boolean', pa.bool_()),
+            ('c_timestamp', pa.timestamp('us')),
+            ('c_array_float32', tensor_type),
+        ]
+    )
+
+    test_table = pa.Table.from_pydict(value_dict, schema=schema)
+    pa.parquet.write_table(test_table, str(output_path / 'test.parquet'))
+
+
+def assert_hf_dataset_equal(hf_dataset: 'datasets.Dataset', df: pxt.DataFrame, split_column_name: str) -> None:
+    import datasets
+    assert df.count() == hf_dataset.num_rows
+    assert set(df.get_column_names()) == (set(hf_dataset.features.keys()) | {split_column_name})
+
+    # immutable, so we can use it in a set
+    DatasetTuple = namedtuple('DatasetTuple', ' '.join(hf_dataset.features.keys()))
+    acc_dataset: Set[DatasetTuple] = set()
+    for tup in hf_dataset:
+        immutable_tup = {}
+        for k in tup:
+            if isinstance(tup[k], list):
+                immutable_tup[k] = tuple(tup[k])
+            else:
+                immutable_tup[k] = tup[k]
+
+        acc_dataset.add(DatasetTuple(**immutable_tup))
+
+    for tup in df.collect():
+        assert tup[split_column_name] in hf_dataset.split._name
+
+        encoded_tup = {}
+        for column_name, value in tup.items():
+            if column_name == split_column_name:
+                continue
+            feature_type = hf_dataset.features[column_name]
+            if isinstance(feature_type, datasets.ClassLabel):
+                assert value in feature_type.names
+                # must use the index of the class label as the value to
+                # compare with dataset iteration output.
+                value = feature_type.encode_example(value)
+            elif isinstance(feature_type, datasets.Sequence):
+                assert feature_type.feature.dtype == 'float32', 'may need to add more types'
+                value = tuple([float(x) for x in value])
+
+            encoded_tup[column_name] = value
+
+        check_tup = DatasetTuple(**encoded_tup)
+        assert check_tup in acc_dataset
+
+
+SAMPLE_IMAGE_URL = \
+    'https://raw.githubusercontent.com/pixeltable/pixeltable/master/docs/source/data/images/000000000009.jpg'