pixeltable 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/catalog/column.py +1 -1
- pixeltable/client.py +72 -2
- pixeltable/env.py +36 -52
- pixeltable/functions/__init__.py +1 -1
- pixeltable/functions/fireworks.py +10 -37
- pixeltable/functions/openai.py +192 -24
- pixeltable/functions/together.py +104 -9
- pixeltable/tests/conftest.py +4 -4
- pixeltable/tests/functions/test_fireworks.py +42 -0
- pixeltable/tests/functions/test_functions.py +60 -0
- pixeltable/tests/{test_functions.py → functions/test_huggingface.py} +5 -141
- pixeltable/tests/functions/test_openai.py +152 -0
- pixeltable/tests/functions/test_together.py +111 -0
- pixeltable/tests/test_dataframe.py +4 -4
- pixeltable/tests/test_table.py +105 -2
- pixeltable/tests/utils.py +128 -5
- pixeltable/type_system.py +41 -84
- pixeltable/utils/arrow.py +98 -0
- pixeltable/utils/hf_datasets.py +157 -0
- pixeltable/utils/parquet.py +68 -27
- pixeltable/utils/pytorch.py +16 -97
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.4.dist-info}/METADATA +33 -27
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.4.dist-info}/RECORD +25 -19
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.4.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.4.dist-info}/WHEEL +0 -0
pixeltable/functions/together.py
CHANGED
|
@@ -1,27 +1,122 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import io
|
|
1
3
|
from typing import Optional
|
|
2
4
|
|
|
5
|
+
import PIL.Image
|
|
6
|
+
import numpy as np
|
|
7
|
+
import together
|
|
8
|
+
|
|
3
9
|
import pixeltable as pxt
|
|
10
|
+
from pixeltable import env
|
|
11
|
+
from pixeltable.func import Batch
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def together_client() -> together.Together:
    """Return the Together client registered with the pixeltable environment.

    The client is requested from ``Env.get_client`` under the name
    ``'together'``. The factory handed to ``get_client`` builds a
    ``together.Together`` instance from the API key it receives.
    """
    def _build_client(api_key):
        # Factory invoked by the environment with the configured API key.
        return together.Together(api_key=api_key)

    return env.Env.get().get_client('together', _build_client)
|
|
4
16
|
|
|
5
17
|
|
|
6
18
|
@pxt.udf
def completions(
    prompt: str,
    *,
    model: str,
    max_tokens: Optional[int] = None,
    stop: Optional[list] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    repetition_penalty: Optional[float] = None,
    logprobs: Optional[int] = None,
    echo: Optional[bool] = None,
    n: Optional[int] = None,
    safety_model: Optional[str] = None
) -> dict:
    """Run a Together completions request and return the response as a dict.

    All arguments are forwarded by name, unchanged, to
    ``together_client().completions.create``; arguments left at ``None`` are
    forwarded as ``None``. Their meaning is defined by the Together API and
    is not interpreted here.

    Returns:
        The result of calling ``.dict()`` on the response object.
    """
    request = dict(
        prompt=prompt,
        model=model,
        max_tokens=max_tokens,
        stop=stop,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        logprobs=logprobs,
        echo=echo,
        n=n,
        safety_model=safety_model,
    )
    response = together_client().completions.create(**request)
    return response.dict()
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@pxt.udf
def chat_completions(
    messages: list[dict[str, str]],
    *,
    model: str,
    max_tokens: Optional[int] = None,
    stop: Optional[list[str]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    repetition_penalty: Optional[float] = None,
    logprobs: Optional[int] = None,
    echo: Optional[bool] = None,
    n: Optional[int] = None,
    safety_model: Optional[str] = None,
    response_format: Optional[dict] = None,
    tools: Optional[dict] = None,
    tool_choice: Optional[dict] = None
) -> dict:
    """Run a Together chat-completions request and return the response as a dict.

    All arguments are forwarded by name, unchanged, to
    ``together_client().chat.completions.create``; arguments left at ``None``
    are forwarded as ``None``. Their meaning is defined by the Together API
    and is not interpreted here.

    Returns:
        The result of calling ``.dict()`` on the response object.
    """
    request = dict(
        messages=messages,
        model=model,
        max_tokens=max_tokens,
        stop=stop,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        logprobs=logprobs,
        echo=echo,
        n=n,
        safety_model=safety_model,
        response_format=response_format,
        tools=tools,
        tool_choice=tool_choice,
    )
    response = together_client().chat.completions.create(**request)
    return response.dict()
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@pxt.udf(batch_size=32, return_type=pxt.ArrayType((None,), dtype=pxt.FloatType()))
def embeddings(input: Batch[str], *, model: str) -> Batch[np.ndarray]:
    """Embed a batch of strings with a Together embedding model.

    The whole batch is sent in a single ``embeddings.create`` call. Each
    entry of the response's ``data`` contributes one ``float64`` array built
    from its ``embedding`` attribute, in the order the response lists them.
    """
    client = together_client()
    response = client.embeddings.create(input=input, model=model)
    return [np.array(entry.embedding, dtype=np.float64) for entry in response.data]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@pxt.udf
def image_generations(
    prompt: str,
    *,
    model: str,
    steps: Optional[int] = None,
    seed: Optional[int] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    negative_prompt: Optional[str] = None,
) -> PIL.Image.Image:
    """Generate an image with a Together image model and return it as a PIL image.

    All arguments are forwarded by name, unchanged, to
    ``together_client().images.generate``. Only the first entry of the
    response's ``data`` is used: its ``b64_json`` payload is base64-decoded
    and opened with PIL.
    """
    # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
    request = dict(
        prompt=prompt,
        model=model,
        steps=steps,
        seed=seed,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
    )
    response = together_client().images.generate(**request)
    encoded = response.data[0].b64_json
    image = PIL.Image.open(io.BytesIO(base64.b64decode(encoded)))
    # Image.open is lazy; load() reads the pixel data before the image is returned.
    image.load()
    return image
|
pixeltable/tests/conftest.py
CHANGED
|
@@ -80,8 +80,8 @@ def test_tbl(test_client: pxt.Client) -> catalog.Table:
|
|
|
80
80
|
return create_test_tbl(test_client)
|
|
81
81
|
|
|
82
82
|
# @pytest.fixture(scope='function')
|
|
83
|
-
# def test_stored_fn(test_client:
|
|
84
|
-
# @
|
|
83
|
+
# def test_stored_fn(test_client: pxt.Client) -> pxt.Function:
|
|
84
|
+
# @pxt.udf(return_type=pxt.IntType(), param_types=[pxt.IntType()])
|
|
85
85
|
# def test_fn(x):
|
|
86
86
|
# return x + 1
|
|
87
87
|
# test_client.create_function('test_fn', test_fn)
|
|
@@ -89,7 +89,7 @@ def test_tbl(test_client: pxt.Client) -> catalog.Table:
|
|
|
89
89
|
|
|
90
90
|
@pytest.fixture(scope='function')
|
|
91
91
|
def test_tbl_exprs(test_tbl: catalog.Table) -> List[exprs.Expr]:
|
|
92
|
-
#def test_tbl_exprs(test_tbl: catalog.Table, test_stored_fn:
|
|
92
|
+
#def test_tbl_exprs(test_tbl: catalog.Table, test_stored_fn: pxt.Function) -> List[exprs.Expr]:
|
|
93
93
|
|
|
94
94
|
t = test_tbl
|
|
95
95
|
return [
|
|
@@ -156,7 +156,7 @@ def img_tbl_exprs(img_tbl: catalog.Table) -> List[exprs.Expr]:
|
|
|
156
156
|
# TODO: why does this not work with a session scope? (some user tables don't get created with create_all())
|
|
157
157
|
#@pytest.fixture(scope='session')
|
|
158
158
|
#def indexed_img_tbl(init_env: None) -> catalog.Table:
|
|
159
|
-
# cl =
|
|
159
|
+
# cl = pxt.Client()
|
|
160
160
|
# db = cl.create_db('test_indexed')
|
|
161
161
|
@pytest.fixture(scope='function')
|
|
162
162
|
def indexed_img_tbl(test_client: pxt.Client) -> catalog.Table:
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
import pixeltable as pxt
|
|
4
|
+
import pixeltable.exceptions as excs
|
|
5
|
+
from pixeltable.tests.utils import skip_test_if_not_installed, validate_update_status
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TestFireworks:
    """Integration test for the Fireworks ``chat_completions`` UDF."""

    def test_fireworks(self, test_client: pxt.Client) -> None:
        """Add two computed chat-completion columns and check both produce content."""
        skip_test_if_not_installed('fireworks')
        TestFireworks.skip_test_if_no_fireworks_client()
        tbl = test_client.create_table('test_tbl', {'input': pxt.StringType()})
        from pixeltable.functions.fireworks import chat_completions

        model_id = 'accounts/fireworks/models/llama-v2-7b-chat'
        msgs = [{'role': 'user', 'content': tbl.input}]
        # First column: only the required arguments.
        tbl['output'] = chat_completions(messages=msgs, model=model_id)
        # Second column: same request with a few optional parameters set.
        tbl['output_2'] = chat_completions(
            messages=msgs,
            model=model_id,
            max_tokens=300,
            top_k=40,
            top_p=0.9,
            temperature=0.7
        )
        status = tbl.insert(input="How's everything going today?")
        validate_update_status(status, 1)
        rows = tbl.collect()
        for column in ('output', 'output_2'):
            assert len(rows[column][0]['choices'][0]['message']['content']) > 0

    # Skip (rather than error) when no API key is available, e.g. when a PR
    # runs in CI.
    @staticmethod
    def skip_test_if_no_fireworks_client() -> None:
        """Skip the current test if creating the Fireworks client raises ``excs.Error``."""
        try:
            import pixeltable.functions.fireworks as fireworks_functions
            _ = fireworks_functions.fireworks_client()
        except excs.Error as exc:
            pytest.skip(str(exc))
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import pixeltable as pxt
|
|
2
|
+
from pixeltable import catalog
|
|
3
|
+
from pixeltable.functions.pil.image import blend
|
|
4
|
+
from pixeltable.iterators import FrameIterator
|
|
5
|
+
from pixeltable.tests.utils import get_video_files, skip_test_if_not_installed
|
|
6
|
+
from pixeltable.type_system import VideoType, StringType
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestFunctions:
    """Tests for built-in functions: PIL image ops, detection evaluation, string formatting."""

    def test_pil(self, img_tbl: catalog.Table) -> None:
        """Select an image, its rotation, and a blend of the two; passes if the query runs."""
        tbl = img_tbl
        query = tbl[tbl.img, tbl.img.rotate(90), blend(tbl.img, tbl.img.rotate(90), 0.5)]
        _ = query.show()

    def test_eval_detections(self, test_client: pxt.Client) -> None:
        """Run three YOLOX detectors over video frames and evaluate two against the third."""
        skip_test_if_not_installed('nos')
        client = test_client
        videos = client.create_table('video_tbl', {'video': VideoType()})
        # Frame view over the video table.
        frame_args = {'video': videos.video, 'fps': 1}
        frames = client.create_view(
            'test_view', videos, iterator_class=FrameIterator, iterator_args=frame_args)

        video_paths = get_video_files()
        videos.insert(video=video_paths[-1])
        frames.add_column(frame_s=frames.frame.resize([640, 480]))
        from pixeltable.functions.nos.object_detection_2d import yolox_nano, yolox_small, yolox_large
        frames.add_column(detections_a=yolox_nano(frames.frame_s))
        frames.add_column(detections_b=yolox_small(frames.frame_s))
        # The yolox_large output is used as the ground truth column.
        frames.add_column(gt=yolox_large(frames.frame_s))
        from pixeltable.functions.eval import eval_detections, mean_ap
        # Ad-hoc evaluation in a select; the result itself is not inspected.
        _ = frames.select(
            eval_detections(
                frames.detections_a.bboxes,
                frames.detections_a.labels,
                frames.detections_a.scores,
                frames.gt.bboxes,
                frames.gt.labels,
            )
        ).show()
        frames.add_column(
            eval_a=eval_detections(
                frames.detections_a.bboxes,
                frames.detections_a.labels,
                frames.detections_a.scores,
                frames.gt.bboxes,
                frames.gt.labels,
            )
        )
        frames.add_column(
            eval_b=eval_detections(
                frames.detections_b.bboxes,
                frames.detections_b.labels,
                frames.detections_b.scores,
                frames.gt.bboxes,
                frames.gt.labels,
            )
        )
        ap_a = frames.select(mean_ap(frames.eval_a)).show()[0, 0]
        ap_b = frames.select(mean_ap(frames.eval_b)).show()[0, 0]
        common_classes = set(ap_a.keys()) & set(ap_b.keys())

        ## TODO: following assertion is failing on CI,
        # It is not necessarily a bug, as the assert condition is not expected to be always true
        # for k in common_classes:
        # assert ap_a[k] <= ap_b[k]

    def test_str(self, test_client: pxt.Client) -> None:
        """Check ``str_format`` with positional, keyword, and mixed arguments."""
        tbl = test_client.create_table('test_tbl', {'input': StringType()})
        from pixeltable.functions.string import str_format
        tbl.add_column(s1=str_format('ABC {0}', tbl.input))
        tbl.add_column(s2=str_format('DEF {this}', this=tbl.input))
        tbl.add_column(s3=str_format('GHI {0} JKL {this}', tbl.input, this=tbl.input))
        insert_status = tbl.insert(input='MNO')
        assert insert_status.num_rows == 1
        assert insert_status.num_excs == 0
        first_row = tbl.head()[0]
        expected = {'input': 'MNO', 's1': 'ABC MNO', 's2': 'DEF MNO', 's3': 'GHI MNO JKL MNO'}
        assert first_row == expected
|
|
@@ -3,144 +3,12 @@ from typing import Dict, Any
|
|
|
3
3
|
import pytest
|
|
4
4
|
|
|
5
5
|
import pixeltable as pxt
|
|
6
|
-
from pixeltable import
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
from pixeltable.functions.pil.image import blend
|
|
10
|
-
from pixeltable.iterators import FrameIterator
|
|
11
|
-
from pixeltable.tests.utils import get_video_files, skip_test_if_not_installed, get_sentences, get_image_files
|
|
12
|
-
from pixeltable.type_system import VideoType, StringType, JsonType, ImageType, BoolType, FloatType, ArrayType
|
|
6
|
+
from pixeltable.tests.utils import skip_test_if_not_installed, get_sentences, get_image_files, \
|
|
7
|
+
SAMPLE_IMAGE_URL
|
|
8
|
+
from pixeltable.type_system import StringType, JsonType, ImageType, BoolType, FloatType, ArrayType
|
|
13
9
|
|
|
14
10
|
|
|
15
|
-
class
|
|
16
|
-
def test_pil(self, img_tbl: catalog.Table) -> None:
|
|
17
|
-
t = img_tbl
|
|
18
|
-
_ = t[t.img, t.img.rotate(90), blend(t.img, t.img.rotate(90), 0.5)].show()
|
|
19
|
-
|
|
20
|
-
def test_eval_detections(self, test_client: pxt.Client) -> None:
|
|
21
|
-
skip_test_if_not_installed('nos')
|
|
22
|
-
cl = test_client
|
|
23
|
-
video_t = cl.create_table('video_tbl', {'video': VideoType()})
|
|
24
|
-
# create frame view
|
|
25
|
-
args = {'video': video_t.video, 'fps': 1}
|
|
26
|
-
v = cl.create_view('test_view', video_t, iterator_class=FrameIterator, iterator_args=args)
|
|
27
|
-
|
|
28
|
-
files = get_video_files()
|
|
29
|
-
video_t.insert(video=files[-1])
|
|
30
|
-
v.add_column(frame_s=v.frame.resize([640, 480]))
|
|
31
|
-
from pixeltable.functions.nos.object_detection_2d import yolox_nano, yolox_small, yolox_large
|
|
32
|
-
v.add_column(detections_a=yolox_nano(v.frame_s))
|
|
33
|
-
v.add_column(detections_b=yolox_small(v.frame_s))
|
|
34
|
-
v.add_column(gt=yolox_large(v.frame_s))
|
|
35
|
-
from pixeltable.functions.eval import eval_detections, mean_ap
|
|
36
|
-
res = v.select(
|
|
37
|
-
eval_detections(
|
|
38
|
-
v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels
|
|
39
|
-
)).show()
|
|
40
|
-
v.add_column(
|
|
41
|
-
eval_a=eval_detections(
|
|
42
|
-
v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels))
|
|
43
|
-
v.add_column(
|
|
44
|
-
eval_b=eval_detections(
|
|
45
|
-
v.detections_b.bboxes, v.detections_b.labels, v.detections_b.scores, v.gt.bboxes, v.gt.labels))
|
|
46
|
-
ap_a = v.select(mean_ap(v.eval_a)).show()[0, 0]
|
|
47
|
-
ap_b = v.select(mean_ap(v.eval_b)).show()[0, 0]
|
|
48
|
-
common_classes = set(ap_a.keys()) & set(ap_b.keys())
|
|
49
|
-
|
|
50
|
-
## TODO: following assertion is failing on CI,
|
|
51
|
-
# It is not necessarily a bug, as assert codition is not expected to be always true
|
|
52
|
-
# for k in common_classes:
|
|
53
|
-
# assert ap_a[k] <= ap_b[k]
|
|
54
|
-
|
|
55
|
-
def test_str(self, test_client: pxt.Client) -> None:
|
|
56
|
-
cl = test_client
|
|
57
|
-
t = cl.create_table('test_tbl', {'input': StringType()})
|
|
58
|
-
from pixeltable.functions.string import str_format
|
|
59
|
-
t.add_column(s1=str_format('ABC {0}', t.input))
|
|
60
|
-
t.add_column(s2=str_format('DEF {this}', this=t.input))
|
|
61
|
-
t.add_column(s3=str_format('GHI {0} JKL {this}', t.input, this=t.input))
|
|
62
|
-
status = t.insert(input='MNO')
|
|
63
|
-
assert status.num_rows == 1
|
|
64
|
-
assert status.num_excs == 0
|
|
65
|
-
row = t.head()[0]
|
|
66
|
-
assert row == {'input': 'MNO', 's1': 'ABC MNO', 's2': 'DEF MNO', 's3': 'GHI MNO JKL MNO'}
|
|
67
|
-
|
|
68
|
-
def test_openai(self, test_client: pxt.Client) -> None:
|
|
69
|
-
skip_test_if_not_installed('openai')
|
|
70
|
-
TestFunctions.skip_test_if_no_openai_client()
|
|
71
|
-
cl = test_client
|
|
72
|
-
t = cl.create_table('test_tbl', {'input': StringType()})
|
|
73
|
-
from pixeltable.functions.openai import chat_completions, embeddings, moderations
|
|
74
|
-
msgs = [
|
|
75
|
-
{"role": "system", "content": "You are a helpful assistant."},
|
|
76
|
-
{"role": "user", "content": t.input}
|
|
77
|
-
]
|
|
78
|
-
t.add_column(input_msgs=msgs)
|
|
79
|
-
t.add_column(chat_output=chat_completions(model='gpt-3.5-turbo', messages=t.input_msgs))
|
|
80
|
-
# with inlined messages
|
|
81
|
-
t.add_column(chat_output2=chat_completions(model='gpt-3.5-turbo', messages=msgs))
|
|
82
|
-
t.add_column(ada_embed=embeddings(model='text-embedding-ada-002', input=t.input))
|
|
83
|
-
t.add_column(text_3=embeddings(model='text-embedding-3-small', input=t.input))
|
|
84
|
-
t.add_column(moderation=moderations(input=t.input))
|
|
85
|
-
t.insert(input='I find you really annoying')
|
|
86
|
-
_ = t.head()
|
|
87
|
-
|
|
88
|
-
def test_gpt_4_vision(self, test_client: pxt.Client) -> None:
|
|
89
|
-
skip_test_if_not_installed('openai')
|
|
90
|
-
TestFunctions.skip_test_if_no_openai_client()
|
|
91
|
-
cl = test_client
|
|
92
|
-
t = cl.create_table('test_tbl', {'prompt': StringType(), 'img': ImageType()})
|
|
93
|
-
from pixeltable.functions.openai import chat_completions
|
|
94
|
-
from pixeltable.functions.string import str_format
|
|
95
|
-
msgs = [
|
|
96
|
-
{'role': 'user',
|
|
97
|
-
'content': [
|
|
98
|
-
{'type': 'text', 'text': t.prompt},
|
|
99
|
-
{'type': 'image_url', 'image_url': {
|
|
100
|
-
'url': str_format('data:image/png;base64,{0}', t.img.b64_encode())
|
|
101
|
-
}}
|
|
102
|
-
]}
|
|
103
|
-
]
|
|
104
|
-
t.add_column(response=chat_completions(model='gpt-4-vision-preview', messages=msgs, max_tokens=300))
|
|
105
|
-
t.add_column(response_content=t.response.choices[0].message.content)
|
|
106
|
-
t.insert(prompt="What's in this image?", img=_sample_image_url)
|
|
107
|
-
result = t.collect()['response_content'][0]
|
|
108
|
-
assert len(result) > 0
|
|
109
|
-
|
|
110
|
-
@staticmethod
|
|
111
|
-
def skip_test_if_no_openai_client() -> None:
|
|
112
|
-
try:
|
|
113
|
-
_ = Env.get().openai_client
|
|
114
|
-
except excs.Error as exc:
|
|
115
|
-
pytest.skip(str(exc))
|
|
116
|
-
|
|
117
|
-
def test_together(self, test_client: pxt.Client) -> None:
|
|
118
|
-
skip_test_if_not_installed('together')
|
|
119
|
-
if not Env.get().has_together_client:
|
|
120
|
-
pytest.skip(f'Together client does not exist (missing API key?)')
|
|
121
|
-
cl = test_client
|
|
122
|
-
t = cl.create_table('test_tbl', {'input': StringType()})
|
|
123
|
-
from pixeltable.functions.together import completions
|
|
124
|
-
t.add_column(output=completions(prompt=t.input, model='mistralai/Mixtral-8x7B-v0.1', stop=['\n']))
|
|
125
|
-
t.add_column(output_text=t.output.output.choices[0].text)
|
|
126
|
-
t.insert(input='I am going to the ')
|
|
127
|
-
result = t.select(t.output_text).collect()['output_text'][0]
|
|
128
|
-
assert len(result) > 0
|
|
129
|
-
|
|
130
|
-
def test_fireworks(self, test_client: pxt.Client) -> None:
|
|
131
|
-
skip_test_if_not_installed('fireworks')
|
|
132
|
-
try:
|
|
133
|
-
from pixeltable.functions.fireworks import initialize
|
|
134
|
-
initialize()
|
|
135
|
-
except:
|
|
136
|
-
pytest.skip(f'Fireworks client does not exist (missing API key?)')
|
|
137
|
-
cl = test_client
|
|
138
|
-
t = cl.create_table('test_tbl', {'input': StringType()})
|
|
139
|
-
from pixeltable.functions.fireworks import chat_completions
|
|
140
|
-
t['output'] = chat_completions(prompt=t.input, model='accounts/fireworks/models/llama-v2-7b-chat', max_tokens=256).choices[0].text
|
|
141
|
-
t.insert(input='I am going to the ')
|
|
142
|
-
result = t.select(t.output).collect()['output'][0]
|
|
143
|
-
assert len(result) > 0
|
|
11
|
+
class TestHuggingface:
|
|
144
12
|
|
|
145
13
|
def test_hf_function(self, test_client: pxt.Client) -> None:
|
|
146
14
|
skip_test_if_not_installed('sentence_transformers')
|
|
@@ -281,14 +149,10 @@ class TestFunctions:
|
|
|
281
149
|
t = cl.create_table('test_tbl', {'img': ImageType()})
|
|
282
150
|
from pixeltable.functions.huggingface import detr_for_object_detection
|
|
283
151
|
t['detect'] = detr_for_object_detection(t.img, model_id='facebook/detr-resnet-50', threshold=0.8)
|
|
284
|
-
status = t.insert(img=
|
|
152
|
+
status = t.insert(img=SAMPLE_IMAGE_URL)
|
|
285
153
|
assert status.num_rows == 1
|
|
286
154
|
assert status.num_excs == 0
|
|
287
155
|
result = t.select(t.detect).collect()[0]['detect']
|
|
288
156
|
assert 'orange' in result['label_text']
|
|
289
157
|
assert 'bowl' in result['label_text']
|
|
290
158
|
assert 'broccoli' in result['label_text']
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
_sample_image_url = \
|
|
294
|
-
'https://raw.githubusercontent.com/pixeltable/pixeltable/master/docs/source/data/images/000000000009.jpg'
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
import pixeltable as pxt
|
|
4
|
+
import pixeltable.exceptions as excs
|
|
5
|
+
from pixeltable.tests.utils import SAMPLE_IMAGE_URL, skip_test_if_not_installed, validate_update_status
|
|
6
|
+
from pixeltable.type_system import StringType, ImageType
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestOpenai:
    """Integration tests for the OpenAI UDFs in ``pixeltable.functions.openai``.

    Every test is skipped when the ``openai`` package is not installed or
    when the OpenAI client cannot be created (see
    ``skip_test_if_no_openai_client``).
    """

    def test_audio(self, test_client: pxt.Client) -> None:
        """Round-trip text through speech, transcription, and translation columns."""
        skip_test_if_not_installed('openai')
        TestOpenai.skip_test_if_no_openai_client()
        cl = test_client
        t = cl.create_table('test_tbl', {'input': StringType()})
        from pixeltable.functions.openai import speech, transcriptions, translations
        t.add_column(speech=speech(t.input, model='tts-1', voice='onyx'))
        t.add_column(speech_2=speech(t.input, model='tts-1', voice='onyx', response_format='flac', speed=1.05))
        t.add_column(transcription=transcriptions(t.speech, model='whisper-1'))
        t.add_column(transcription_2=transcriptions(
            t.speech, model='whisper-1', language='en', prompt='Transcribe the contents of this recording.'
        ))
        t.add_column(translation=translations(t.speech, model='whisper-1'))
        t.add_column(translation_2=translations(
            t.speech, model='whisper-1', prompt='Translate the recording from Spanish into English.', temperature=0.7
        ))
        validate_update_status(t.insert([
            {'input': 'I am a banana.'},
            {'input': 'Es fácil traducir del español al inglés.'}
        ]), expected_rows=2)
        # The audio generation -> transcription loop on these examples should be simple and clear enough
        # that the unit test can reliably expect the output closely enough to pass these checks.
        results = t.collect()
        assert results[0]['transcription']['text'] in ['I am a banana.', "I'm a banana."]
        assert results[0]['transcription_2']['text'] in ['I am a banana.', "I'm a banana."]
        assert 'easy to translate from Spanish' in results[1]['translation']['text']
        assert 'easy to translate from Spanish' in results[1]['translation_2']['text']

    def test_chat_completions(self, test_client: pxt.Client) -> None:
        """Exercise ``chat_completions`` with stored, inlined, and parameterized requests."""
        skip_test_if_not_installed('openai')
        TestOpenai.skip_test_if_no_openai_client()
        cl = test_client
        t = cl.create_table('test_tbl', {'input': StringType()})
        from pixeltable.functions.openai import chat_completions
        msgs = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": t.input}
        ]
        t.add_column(input_msgs=msgs)
        t.add_column(chat_output=chat_completions(model='gpt-3.5-turbo', messages=t.input_msgs))
        # with inlined messages
        t.add_column(chat_output_2=chat_completions(model='gpt-3.5-turbo', messages=msgs))
        # test a bunch of the parameters
        t.add_column(chat_output_3=chat_completions(
            model='gpt-3.5-turbo', messages=msgs, frequency_penalty=0.1, logprobs=True, top_logprobs=3,
            max_tokens=500, n=3, presence_penalty=0.1, seed=4171780, stop=['\n'], temperature=0.7, top_p=0.8,
            user='pixeltable'
        ))
        # test with JSON output enforced
        t.add_column(chat_output_4=chat_completions(
            model='gpt-3.5-turbo', messages=msgs, response_format={'type': 'json_object'}
        ))
        # TODO Also test the `tools` and `tool_choice` parameters.
        validate_update_status(t.insert(input='Give me an example of a typical JSON structure.'), 1)
        result = t.collect()
        assert len(result['chat_output'][0]['choices'][0]['message']['content']) > 0
        assert len(result['chat_output_2'][0]['choices'][0]['message']['content']) > 0
        assert len(result['chat_output_3'][0]['choices'][0]['message']['content']) > 0
        assert len(result['chat_output_4'][0]['choices'][0]['message']['content']) > 0

        # When OpenAI gets a request with `response_format` equal to `json_object`, but the prompt does not
        # contain the string "json", it refuses the request.
        # TODO This should probably not be throwing an exception, but rather logging the error in
        # `t.chat_output_4.errormsg` etc.
        with pytest.raises(excs.ExprEvalError) as exc_info:
            t.insert(input='Say something interesting.')
        # exc_info.value is only populated once the `with` block has exited.
        assert "\\'messages\\' must contain the word \\'json\\'" in str(exc_info.value)

    def test_gpt_4_vision(self, test_client: pxt.Client) -> None:
        """Describe an image via ``vision`` and via a hand-built ``chat_completions`` request."""
        skip_test_if_not_installed('openai')
        TestOpenai.skip_test_if_no_openai_client()
        cl = test_client
        t = cl.create_table('test_tbl', {'prompt': StringType(), 'img': ImageType()})
        from pixeltable.functions.openai import chat_completions, vision
        from pixeltable.functions.string import str_format
        t.add_column(response=vision(prompt="What's in this image?", image=t.img))
        # Also get the response the low-level way, by calling chat_completions
        msgs = [
            {'role': 'user',
             'content': [
                 {'type': 'text', 'text': t.prompt},
                 {'type': 'image_url', 'image_url': {
                     'url': str_format('data:image/png;base64,{0}', t.img.b64_encode())
                 }}
             ]}
        ]
        t.add_column(
            response_2=chat_completions(
                model='gpt-4-vision-preview', messages=msgs, max_tokens=300
            ).choices[0].message.content
        )
        validate_update_status(t.insert(prompt="What's in this image?", img=SAMPLE_IMAGE_URL), 1)
        results = t.collect()
        # Check both code paths: the original only inspected `response_2`, leaving the
        # `vision` column unverified.
        # NOTE(review): assumes the `vision` result supports len() (e.g. str) - confirm.
        assert len(results['response'][0]) > 0
        assert len(results['response_2'][0]) > 0

    def test_embeddings(self, test_client: pxt.Client) -> None:
        """Add embedding columns for two models; passes if insert and ``head()`` succeed."""
        skip_test_if_not_installed('openai')
        TestOpenai.skip_test_if_no_openai_client()
        cl = test_client
        from pixeltable.functions.openai import embeddings
        t = cl.create_table('test_tbl', {'input': StringType()})
        t.add_column(ada_embed=embeddings(model='text-embedding-ada-002', input=t.input))
        t.add_column(text_3=embeddings(model='text-embedding-3-small', input=t.input, user='pixeltable'))
        validate_update_status(t.insert(input='Say something interesting.'), 1)
        _ = t.head()

    def test_moderations(self, test_client: pxt.Client) -> None:
        """Add moderation columns (default and explicit model); passes if insert and ``head()`` succeed."""
        skip_test_if_not_installed('openai')
        TestOpenai.skip_test_if_no_openai_client()
        cl = test_client
        t = cl.create_table('test_tbl', {'input': StringType()})
        from pixeltable.functions.openai import moderations
        t.add_column(moderation=moderations(input=t.input))
        t.add_column(moderation_2=moderations(input=t.input, model='text-moderation-stable'))
        validate_update_status(t.insert(input='Say something interesting.'), 1)
        _ = t.head()

    def test_image_generations(self, test_client: pxt.Client) -> None:
        """Generate images with default, dall-e-2, and dall-e-3 options and check their sizes."""
        skip_test_if_not_installed('openai')
        TestOpenai.skip_test_if_no_openai_client()
        cl = test_client
        t = cl.create_table('test_tbl', {'input': StringType()})
        from pixeltable.functions.openai import image_generations
        t.add_column(img=image_generations(t.input))
        # Test dall-e-2 options
        t.add_column(img_2=image_generations(
            t.input, model='dall-e-2', size='512x512', user='pixeltable'
        ))
        # Test dall-e-3 options
        t.add_column(img_3=image_generations(
            t.input, model='dall-e-3', quality='hd', size='1792x1024', style='natural', user='pixeltable'
        ))
        validate_update_status(t.insert(input='A friendly dinosaur playing tennis in a cornfield'), 1)
        # Collect once and reuse the result instead of re-running the query per assertion.
        results = t.collect()
        assert results['img'][0].size == (1024, 1024)
        assert results['img_2'][0].size == (512, 512)
        assert results['img_3'][0].size == (1792, 1024)

    # This ensures that the test will be skipped, rather than returning an error, when no API key is
    # available (for example, when a PR runs in CI).
    @staticmethod
    def skip_test_if_no_openai_client() -> None:
        """Skip the current test if creating the OpenAI client raises ``excs.Error``."""
        try:
            import pixeltable.functions.openai
            _ = pixeltable.functions.openai.openai_client()
        except excs.Error as exc:
            pytest.skip(str(exc))
|