pixeltable 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
Potentially problematic release: this version of pixeltable has been flagged as potentially problematic.
- pixeltable/catalog/column.py +26 -49
- pixeltable/catalog/insertable_table.py +7 -4
- pixeltable/catalog/table.py +163 -57
- pixeltable/catalog/table_version.py +416 -140
- pixeltable/catalog/table_version_path.py +2 -2
- pixeltable/client.py +72 -6
- pixeltable/dataframe.py +65 -21
- pixeltable/env.py +52 -53
- pixeltable/exec/cache_prefetch_node.py +1 -1
- pixeltable/exec/in_memory_data_node.py +11 -7
- pixeltable/exprs/comparison.py +3 -3
- pixeltable/exprs/data_row.py +5 -1
- pixeltable/exprs/literal.py +16 -4
- pixeltable/exprs/row_builder.py +8 -40
- pixeltable/ext/__init__.py +5 -0
- pixeltable/ext/functions/yolox.py +92 -0
- pixeltable/func/aggregate_function.py +15 -15
- pixeltable/func/expr_template_function.py +9 -1
- pixeltable/func/globals.py +24 -14
- pixeltable/func/signature.py +18 -12
- pixeltable/func/udf.py +7 -2
- pixeltable/functions/__init__.py +9 -9
- pixeltable/functions/eval.py +7 -8
- pixeltable/functions/fireworks.py +10 -37
- pixeltable/functions/huggingface.py +47 -19
- pixeltable/functions/openai.py +192 -24
- pixeltable/functions/together.py +104 -9
- pixeltable/functions/util.py +11 -0
- pixeltable/index/__init__.py +2 -0
- pixeltable/index/base.py +49 -0
- pixeltable/index/embedding_index.py +95 -0
- pixeltable/metadata/schema.py +45 -22
- pixeltable/plan.py +15 -34
- pixeltable/store.py +38 -41
- pixeltable/tests/conftest.py +8 -14
- pixeltable/tests/ext/test_yolox.py +21 -0
- pixeltable/tests/functions/test_fireworks.py +43 -0
- pixeltable/tests/functions/test_functions.py +60 -0
- pixeltable/tests/{test_functions.py → functions/test_huggingface.py} +7 -143
- pixeltable/tests/functions/test_openai.py +162 -0
- pixeltable/tests/functions/test_together.py +112 -0
- pixeltable/tests/test_component_view.py +14 -5
- pixeltable/tests/test_dataframe.py +23 -22
- pixeltable/tests/test_exprs.py +99 -102
- pixeltable/tests/test_function.py +51 -43
- pixeltable/tests/test_index.py +138 -0
- pixeltable/tests/test_migration.py +2 -1
- pixeltable/tests/test_snapshot.py +24 -1
- pixeltable/tests/test_table.py +205 -26
- pixeltable/tests/test_types.py +30 -0
- pixeltable/tests/test_video.py +16 -16
- pixeltable/tests/test_view.py +5 -0
- pixeltable/tests/utils.py +171 -14
- pixeltable/tool/create_test_db_dump.py +16 -0
- pixeltable/type_system.py +77 -128
- pixeltable/utils/arrow.py +98 -0
- pixeltable/utils/hf_datasets.py +157 -0
- pixeltable/utils/parquet.py +68 -27
- pixeltable/utils/pytorch.py +16 -97
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/METADATA +35 -28
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/RECORD +63 -50
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/WHEEL +0 -0
@@ -0,0 +1,162 @@
+import pytest
+
+import pixeltable as pxt
+import pixeltable.exceptions as excs
+from pixeltable.tests.utils import SAMPLE_IMAGE_URL, skip_test_if_not_installed, validate_update_status
+from pixeltable.type_system import StringType, ImageType
+
+
+@pytest.mark.remote_api
+class TestOpenai:
+
+    def test_audio(self, test_client: pxt.Client) -> None:
+        skip_test_if_not_installed('openai')
+        TestOpenai.skip_test_if_no_openai_client()
+        cl = test_client
+        t = cl.create_table('test_tbl', {'input': StringType()})
+        from pixeltable.functions.openai import speech, transcriptions, translations
+        t.add_column(speech=speech(t.input, model='tts-1', voice='onyx'))
+        t.add_column(speech_2=speech(t.input, model='tts-1', voice='onyx', response_format='flac', speed=1.05))
+        t.add_column(transcription=transcriptions(t.speech, model='whisper-1'))
+        t.add_column(transcription_2=transcriptions(
+            t.speech, model='whisper-1', language='en', prompt='Transcribe the contents of this recording.'
+        ))
+        t.add_column(translation=translations(t.speech, model='whisper-1'))
+        t.add_column(translation_2=translations(
+            t.speech, model='whisper-1', prompt='Translate the recording from Spanish into English.', temperature=0.05
+        ))
+        validate_update_status(t.insert([
+            {'input': 'I am a banana.'},
+            {'input': 'Es fácil traducir del español al inglés.'}
+        ]), expected_rows=2)
+        # The audio generation -> transcription loop on these examples should be simple and clear enough
+        # that the unit test can reliably expect the output closely enough to pass these checks.
+        results = t.collect()
+        assert results[0]['transcription']['text'] in ['I am a banana.', "I'm a banana."]
+        assert results[0]['transcription_2']['text'] in ['I am a banana.', "I'm a banana."]
+        assert 'easy to translate' in results[1]['translation']['text']
+        assert 'easy to translate' in results[1]['translation_2']['text']
+
+    def test_chat_completions(self, test_client: pxt.Client) -> None:
+        skip_test_if_not_installed('openai')
+        TestOpenai.skip_test_if_no_openai_client()
+        cl = test_client
+        t = cl.create_table('test_tbl', {'input': StringType()})
+        from pixeltable.functions.openai import chat_completions
+        msgs = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": t.input}
+        ]
+        t.add_column(input_msgs=msgs)
+        t.add_column(chat_output=chat_completions(model='gpt-3.5-turbo', messages=t.input_msgs))
+        # with inlined messages
+        t.add_column(chat_output_2=chat_completions(model='gpt-3.5-turbo', messages=msgs))
+        # test a bunch of the parameters
+        t.add_column(chat_output_3=chat_completions(
+            model='gpt-3.5-turbo', messages=msgs, frequency_penalty=0.1, logprobs=True, top_logprobs=3,
+            max_tokens=500, n=3, presence_penalty=0.1, seed=4171780, stop=['\n'], temperature=0.7, top_p=0.8,
+            user='pixeltable'
+        ))
+        # test with JSON output enforced
+        t.add_column(chat_output_4=chat_completions(
+            model='gpt-3.5-turbo', messages=msgs, response_format={'type': 'json_object'}
+        ))
+        # TODO Also test the `tools` and `tool_choice` parameters.
+        validate_update_status(t.insert(input='Give me an example of a typical JSON structure.'), 1)
+        result = t.collect()
+        assert len(result['chat_output'][0]['choices'][0]['message']['content']) > 0
+        assert len(result['chat_output_2'][0]['choices'][0]['message']['content']) > 0
+        assert len(result['chat_output_3'][0]['choices'][0]['message']['content']) > 0
+        assert len(result['chat_output_4'][0]['choices'][0]['message']['content']) > 0
+
+        # When OpenAI gets a request with `response_format` equal to `json_object`, but the prompt does not
+        # contain the string "json", it refuses the request.
+        # TODO This should probably not be throwing an exception, but rather logging the error in
+        # `t.chat_output_4.errormsg` etc.
+        with pytest.raises(excs.ExprEvalError) as exc_info:
+            t.insert(input='Say something interesting.')
+        assert "\\'messages\\' must contain the word \\'json\\'" in str(exc_info.value)
+
+    def test_gpt_4_vision(self, test_client: pxt.Client) -> None:
+        skip_test_if_not_installed('openai')
+        TestOpenai.skip_test_if_no_openai_client()
+        cl = test_client
+        t = cl.create_table('test_tbl', {'prompt': StringType(), 'img': ImageType()})
+        from pixeltable.functions.openai import chat_completions, vision
+        from pixeltable.functions.string import str_format
+        t.add_column(response=vision(prompt="What's in this image?", image=t.img))
+        # Also get the response the low-level way, by calling chat_completions
+        msgs = [
+            {'role': 'user',
+             'content': [
+                 {'type': 'text', 'text': t.prompt},
+                 {'type': 'image_url', 'image_url': {
+                     'url': str_format('data:image/png;base64,{0}', t.img.b64_encode())
+                 }}
+             ]}
+        ]
+        t.add_column(response_2=chat_completions(model='gpt-4-vision-preview', messages=msgs, max_tokens=300).choices[0].message.content)
+        validate_update_status(t.insert(prompt="What's in this image?", img=SAMPLE_IMAGE_URL), 1)
+        result = t.collect()['response_2'][0]
+        assert len(result) > 0
+
+    def test_embeddings(self, test_client: pxt.Client) -> None:
+        skip_test_if_not_installed('openai')
+        TestOpenai.skip_test_if_no_openai_client()
+        cl = test_client
+        from pixeltable.functions.openai import embeddings
+        t = cl.create_table('test_tbl', {'input': StringType()})
+        t.add_column(ada_embed=embeddings(model='text-embedding-ada-002', input=t.input))
+        t.add_column(text_3=embeddings(model='text-embedding-3-small', input=t.input, user='pixeltable'))
+        validate_update_status(t.insert(input='Say something interesting.'), 1)
+        _ = t.head()
+
+    def test_moderations(self, test_client: pxt.Client) -> None:
+        skip_test_if_not_installed('openai')
+        TestOpenai.skip_test_if_no_openai_client()
+        cl = test_client
+        t = cl.create_table('test_tbl', {'input': StringType()})
+        from pixeltable.functions.openai import moderations
+        t.add_column(moderation=moderations(input=t.input))
+        t.add_column(moderation_2=moderations(input=t.input, model='text-moderation-stable'))
+        validate_update_status(t.insert(input='Say something interesting.'), 1)
+        _ = t.head()
+
+    def test_image_generations(self, test_client: pxt.Client) -> None:
+        skip_test_if_not_installed('openai')
+        TestOpenai.skip_test_if_no_openai_client()
+        cl = test_client
+        t = cl.create_table('test_tbl', {'input': StringType()})
+        from pixeltable.functions.openai import image_generations
+        t.add_column(img=image_generations(t.input))
+        # Test dall-e-2 options
+        t.add_column(img_2=image_generations(
+            t.input, model='dall-e-2', size='512x512', user='pixeltable'
+        ))
+        validate_update_status(t.insert(input='A friendly dinosaur playing tennis in a cornfield'), 1)
+        assert t.collect()['img'][0].size == (1024, 1024)
+        assert t.collect()['img_2'][0].size == (512, 512)
+
+    @pytest.mark.skip('Test is expensive and slow')
+    def test_image_generations_dall_e_3(self, test_client: pxt.Client) -> None:
+        skip_test_if_not_installed('openai')
+        TestOpenai.skip_test_if_no_openai_client()
+        cl = test_client
+        t = cl.create_table('test_tbl', {'input': StringType()})
+        from pixeltable.functions.openai import image_generations
+        # Test dall-e-3 options
+        t.add_column(img_3=image_generations(
+            t.input, model='dall-e-3', quality='hd', size='1792x1024', style='natural', user='pixeltable'
+        ))
+        validate_update_status(t.insert(input='A friendly dinosaur playing tennis in a cornfield'), 1)
+        assert t.collect()['img_3'][0].size == (1792, 1024)
+
+    # This ensures that the test will be skipped, rather than returning an error, when no API key is
+    # available (for example, when a PR runs in CI).
+    @staticmethod
+    def skip_test_if_no_openai_client() -> None:
+        try:
+            import pixeltable.functions.openai
+            _ = pixeltable.functions.openai.openai_client()
+        except excs.Error as exc:
+            pytest.skip(str(exc))
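
Taken together, the new OpenAI tests all exercise one pattern: an API call is registered as a computed column, and the call then runs once per inserted row. A minimal standalone sketch of that pattern, assuming the `openai` package is installed and an API key is configured (the table and column names below are illustrative, not taken from the diff):

import pixeltable as pxt
from pixeltable.functions.openai import chat_completions
from pixeltable.type_system import StringType

cl = pxt.Client()
t = cl.create_table('demo_tbl', {'input': StringType()})
# The messages list can embed column references; they are substituted per row.
msgs = [{'role': 'user', 'content': t.input}]
# Registered as a computed column: each insert triggers one API request.
t.add_column(output=chat_completions(model='gpt-3.5-turbo', messages=msgs))
t.insert(input='Say hello.')
print(t.collect()['output'][0]['choices'][0]['message']['content'])
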
@@ -0,0 +1,112 @@
+import pytest
+
+import pixeltable as pxt
+import pixeltable.exceptions as excs
+from pixeltable.tests.utils import skip_test_if_not_installed, validate_update_status
+
+
+@pytest.mark.remote_api
+class TestTogether:
+
+    def test_completions(self, test_client: pxt.Client) -> None:
+        skip_test_if_not_installed('together')
+        TestTogether.skip_test_if_no_together_client()
+        cl = test_client
+        t = cl.create_table('test_tbl', {'input': pxt.StringType()})
+        from pixeltable.functions.together import completions
+        t.add_column(output=completions(prompt=t.input, model='mistralai/Mixtral-8x7B-v0.1', stop=['\n']))
+        t.add_column(output_2=completions(
+            prompt=t.input,
+            model='mistralai/Mixtral-8x7B-v0.1',
+            max_tokens=300,
+            stop=['\n'],
+            temperature=0.7,
+            top_p=0.9,
+            top_k=40,
+            repetition_penalty=1.1,
+            logprobs=1,
+            echo=True,
+            n=3,
+            safety_model='Meta-Llama/Llama-Guard-7b'
+        ))
+        validate_update_status(t.insert(input='I am going to the '), 1)
+        result = t.collect()
+        assert len(result['output'][0]['choices'][0]['text']) > 0
+        assert len(result['output_2'][0]['choices'][0]['text']) > 0
+
+    def test_chat_completions(self, test_client: pxt.Client) -> None:
+        skip_test_if_not_installed('together')
+        TestTogether.skip_test_if_no_together_client()
+        cl = test_client
+        t = cl.create_table('test_tbl', {'input': pxt.StringType()})
+        messages = [{'role': 'user', 'content': t.input}]
+        from pixeltable.functions.together import chat_completions
+        t.add_column(output=chat_completions(messages=messages, model='mistralai/Mixtral-8x7B-v0.1', stop=['\n']))
+        t.add_column(output_2=chat_completions(
+            messages=messages,
+            model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+            max_tokens=300,
+            stop=['\n'],
+            temperature=0.7,
+            top_p=0.9,
+            top_k=40,
+            repetition_penalty=1.1,
+            logprobs=1,
+            echo=True,
+            n=3,
+            safety_model='Meta-Llama/Llama-Guard-7b',
+            response_format={'type': 'json_object'}
+        ))
+        validate_update_status(t.insert(input='Give me a typical example of a JSON structure.'), 1)
+        result = t.collect()
+        assert len(result['output'][0]['choices'][0]['message']) > 0
+        assert len(result['output_2'][0]['choices'][0]['message']) > 0
+
+    def test_embeddings(self, test_client: pxt.Client) -> None:
+        skip_test_if_not_installed('together')
+        TestTogether.skip_test_if_no_together_client()
+        cl = test_client
+        t = cl.create_table('test_tbl', {'input': pxt.StringType()})
+        from pixeltable.functions.together import embeddings
+        t.add_column(embed=embeddings(input=t.input, model='togethercomputer/m2-bert-80M-8k-retrieval'))
+        validate_update_status(t.insert(input='Together AI provides a variety of embeddings models.'), 1)
+        assert len(t.collect()['embed'][0]) > 0
+
+    def test_image_generations(self, test_client: pxt.Client) -> None:
+        skip_test_if_not_installed('together')
+        TestTogether.skip_test_if_no_together_client()
+        cl = test_client
+        t = cl.create_table(
+            'test_tbl',
+            {'input': pxt.StringType(), 'negative_prompt': pxt.StringType(nullable=True)}
+        )
+        from pixeltable.functions.together import image_generations
+        t.add_column(img=image_generations(t.input, model='runwayml/stable-diffusion-v1-5'))
+        t.add_column(img_2=image_generations(
+            t.input,
+            model='stabilityai/stable-diffusion-2-1',
+            steps=30,
+            seed=4178780,
+            height=768,
+            width=512,
+            negative_prompt=t.negative_prompt
+        ))
+        validate_update_status(t.insert([
+            {'input': 'A friendly dinosaur playing tennis in a cornfield'},
+            {'input': 'A friendly dinosaur playing tennis in a cornfield',
+             'negative_prompt': 'tennis court'}
+        ]), 2)
+        assert t.collect()['img'][0].size == (512, 512)
+        assert t.collect()['img_2'][0].size == (512, 768)
+        assert t.collect()['img'][1].size == (512, 512)
+        assert t.collect()['img_2'][1].size == (512, 768)
+
+    # This ensures that the test will be skipped, rather than returning an error, when no API key is
+    # available (for example, when a PR runs in CI).
+    @staticmethod
+    def skip_test_if_no_together_client() -> None:
+        try:
+            import pixeltable.functions.together
+            _ = pixeltable.functions.together.together_client()
+        except excs.Error as exc:
+            pytest.skip(str(exc))
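
Both new test modules lean on validate_update_status from pixeltable/tests/utils.py, whose diff is not expanded on this page. A plausible shape for that helper, inferred only from its call sites above and from the direct `status.num_excs` assertion in the component-view hunk below; this is a hypothetical reconstruction, not the actual source:

from typing import Optional

def validate_update_status(status, expected_rows: Optional[int] = None) -> None:
    # No row should have raised while computed columns were evaluated.
    # (Attribute names here are assumptions based on the calling code.)
    assert status.num_excs == 0
    if expected_rows is not None:
        assert status.num_rows == expected_rows
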
@@ -9,7 +9,7 @@ import pixeltable as pxt
 from pixeltable import exceptions as excs
 from pixeltable.iterators import ComponentIterator
 from pixeltable.iterators.video import FrameIterator
-from pixeltable.tests.utils import assert_resultset_eq, get_test_video_files
+from pixeltable.tests.utils import assert_resultset_eq, get_test_video_files, validate_update_status
 from pixeltable.type_system import IntType, VideoType, JsonType
 
 class ConstantImgIterator(ComponentIterator):

@@ -157,10 +157,19 @@ class TestComponentView:
         assert status.num_excs == 0
         import urllib
         video_url = urllib.parse.urljoin('file:', urllib.request.pathname2url(video_filepaths[0]))
-
-
-
-        assert
+        validate_update_status(
+            view_t.update({'annotation': {'a': 1}}, where=view_t.video == video_url),
+            expected_rows=view_t.where(view_t.video == video_url).count())
+        assert view_t.where(view_t.annotation != None).count() == view_t.where(view_t.video == video_url).count()
+
+        # batch update with _rowid works
+        validate_update_status(
+            view_t.batch_update(
+                [{'annotation': {'a': 1}, '_rowid': (1, 0)}, {'annotation': {'a': 1}, '_rowid': (1, 1)}]),
+            expected_rows=2)
+        with pytest.raises(AssertionError):
+            # malformed _rowid
+            view_t.batch_update([{'annotation': {'a': 1}, '_rowid': (1,)}])
 
         with pytest.raises(excs.Error) as excinfo:
             _ = cl.create_view(
@@ -16,6 +16,22 @@ from pixeltable.tests.utils import get_video_files, get_audio_files, skip_test_i
 
 
 class TestDataFrame:
+
+    @pxt.udf(return_type=pxt.JsonType(nullable=False), param_types=[pxt.JsonType(nullable=False)])
+    def yolo_to_coco(detections):
+        bboxes, labels = detections['bboxes'], detections['labels']
+        num_annotations = len(detections['bboxes'])
+        assert num_annotations == len(detections['labels'])
+        result = []
+        for i in range(num_annotations):
+            bbox = bboxes[i]
+            ann = {
+                'bbox': [round(bbox[0]), round(bbox[1]), round(bbox[2] - bbox[0]), round(bbox[3] - bbox[1])],
+                'category': labels[i],
+            }
+            result.append(ann)
+        return result
+
     def test_select_where(self, test_tbl: catalog.Table) -> None:
         t = test_tbl
         res1 = t[t.c1, t.c2, t.c3].show(0)
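
The hunk above moves yolo_to_coco to class scope and registers it as a Pixeltable UDF with explicit JSON parameter and return types, so it can be applied directly to column expressions. A minimal sketch of the same registration pattern (the function and names below are hypothetical, mirroring the decorator usage above):

import pixeltable as pxt

@pxt.udf(return_type=pxt.JsonType(nullable=False), param_types=[pxt.JsonType(nullable=False)])
def num_detections(detections):
    # Detections arrive as JSON of the form {'bboxes': [...], 'labels': [...]}.
    return {'count': len(detections['bboxes'])}

# Applied to a JSON column like any built-in function, e.g.:
#   view_t.select(num_detections(view_t.detections)).collect()
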
@@ -33,7 +49,7 @@ class TestDataFrame:
         assert res1 == res4
 
         _ = t.where(t.c2 < 10).select(t.c2, t.c2).show(0) # repeated name no error
-
+
         # duplicate select list
         with pytest.raises(excs.Error) as exc_info:
             _ = t.select(t.c1).select(t.c2).show(0)

@@ -156,7 +172,7 @@ class TestDataFrame:
         _ = df.__repr__()
         _ = df._repr_html_()
 
-    def test_count(self, test_tbl: catalog.Table,
+    def test_count(self, test_tbl: catalog.Table, small_img_tbl) -> None:
         skip_test_if_not_installed('nos')
         t = test_tbl
         cnt = t.count()

@@ -166,7 +182,7 @@
         assert cnt == 10
 
         # count() doesn't work with similarity search
-        t =
+        t = small_img_tbl
         probe = t.select(t.img).show(1)
         img = probe[0, 0]
         with pytest.raises(excs.Error):

@@ -220,7 +236,7 @@
         for tup in ds:
             for col in df.get_column_names():
                 assert col in tup
-
+
             arrval = tup['c_array']
             assert isinstance(arrval, np.ndarray)
             col_type = type_dict['c_array']

@@ -304,7 +320,7 @@
         def restrict_json_for_default_collate(obj):
             keys = ['id', 'label', 'iscrowd', 'bounding_box']
             return {k: obj[k] for k in keys}
-
+
         t = all_datatypes_tbl
         df = t.select(
             t.row_id,

@@ -370,7 +386,7 @@
         # check result cached
         ds1 = t.to_pytorch_dataset(image_format='pt')
         ds1_mtimes = _get_mtimes(ds1.path)
-
+
         ds2 = t.to_pytorch_dataset(image_format='pt')
         ds2_mtimes = _get_mtimes(ds2.path)
         assert ds2.path == ds1.path, 'result should be cached'

@@ -397,22 +413,7 @@
         view_t.add_column(detections=yolox_medium(view_t.frame))
         base_t.insert(video=get_video_files()[0])
 
-
-        def yolo_to_coco(detections):
-            bboxes, labels = detections['bboxes'], detections['labels']
-            num_annotations = len(detections['bboxes'])
-            assert num_annotations == len(detections['labels'])
-            result = []
-            for i in range(num_annotations):
-                bbox = bboxes[i]
-                ann = {
-                    'bbox': [round(bbox[0]), round(bbox[1]), round(bbox[2] - bbox[0]), round(bbox[3] - bbox[1])],
-                    'category': labels[i],
-                }
-                result.append(ann)
-            return result
-
-        query = view_t.select({'image': view_t.frame, 'annotations': yolo_to_coco(view_t.detections)})
+        query = view_t.select({'image': view_t.frame, 'annotations': self.yolo_to_coco(view_t.detections)})
         path = query.to_coco_dataset()
         # we get a valid COCO dataset
         coco_ds = COCO(path)
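
For reference, yolo_to_coco converts corner-format boxes [x1, y1, x2, y2] into COCO's [x, y, width, height]. A quick standalone check of that arithmetic, with made-up values:

bbox = [10.2, 20.7, 110.4, 220.9]  # illustrative [x1, y1, x2, y2] detection
coco = [
    round(bbox[0]),            # x: left edge
    round(bbox[1]),            # y: top edge
    round(bbox[2] - bbox[0]),  # width  = x2 - x1
    round(bbox[3] - bbox[1]),  # height = y2 - y1
]
assert coco == [10, 21, 100, 200]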