pixeltable 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

@@ -1,27 +1,122 @@
1
+ import base64
2
+ import io
1
3
  from typing import Optional
2
4
 
5
+ import PIL.Image
6
+ import numpy as np
7
+ import together
8
+
3
9
  import pixeltable as pxt
10
+ from pixeltable import env
11
+ from pixeltable.func import Batch
12
+
13
+
14
+ def together_client() -> together.Together:
15
+ return env.Env.get().get_client('together', lambda api_key: together.Together(api_key=api_key))
4
16
 
5
17
 
6
18
  @pxt.udf
7
19
  def completions(
8
20
  prompt: str,
21
+ *,
9
22
  model: str,
10
23
  max_tokens: Optional[int] = None,
11
- repetition_penalty: Optional[float] = None,
12
24
  stop: Optional[list] = None,
13
- top_k: Optional[int] = None,
25
+ temperature: Optional[float] = None,
14
26
  top_p: Optional[float] = None,
15
- temperature: Optional[float] = None
27
+ top_k: Optional[int] = None,
28
+ repetition_penalty: Optional[float] = None,
29
+ logprobs: Optional[int] = None,
30
+ echo: Optional[bool] = None,
31
+ n: Optional[int] = None,
32
+ safety_model: Optional[str] = None
16
33
  ) -> dict:
17
- import together
18
- return together.Complete.create(
19
- prompt,
20
- model,
34
+ return together_client().completions.create(
35
+ prompt=prompt,
36
+ model=model,
21
37
  max_tokens=max_tokens,
22
- repetition_penalty=repetition_penalty,
23
38
  stop=stop,
39
+ temperature=temperature,
40
+ top_p=top_p,
24
41
  top_k=top_k,
42
+ repetition_penalty=repetition_penalty,
43
+ logprobs=logprobs,
44
+ echo=echo,
45
+ n=n,
46
+ safety_model=safety_model
47
+ ).dict()
48
+
49
+
50
+ @pxt.udf
51
+ def chat_completions(
52
+ messages: list[dict[str, str]],
53
+ *,
54
+ model: str,
55
+ max_tokens: Optional[int] = None,
56
+ stop: Optional[list[str]] = None,
57
+ temperature: Optional[float] = None,
58
+ top_p: Optional[float] = None,
59
+ top_k: Optional[int] = None,
60
+ repetition_penalty: Optional[float] = None,
61
+ logprobs: Optional[int] = None,
62
+ echo: Optional[bool] = None,
63
+ n: Optional[int] = None,
64
+ safety_model: Optional[str] = None,
65
+ response_format: Optional[dict] = None,
66
+ tools: Optional[dict] = None,
67
+ tool_choice: Optional[dict] = None
68
+ ) -> dict:
69
+ return together_client().chat.completions.create(
70
+ messages=messages,
71
+ model=model,
72
+ max_tokens=max_tokens,
73
+ stop=stop,
74
+ temperature=temperature,
25
75
  top_p=top_p,
26
- temperature=temperature
76
+ top_k=top_k,
77
+ repetition_penalty=repetition_penalty,
78
+ logprobs=logprobs,
79
+ echo=echo,
80
+ n=n,
81
+ safety_model=safety_model,
82
+ response_format=response_format,
83
+ tools=tools,
84
+ tool_choice=tool_choice
85
+ ).dict()
86
+
87
+
88
+ @pxt.udf(batch_size=32, return_type=pxt.ArrayType((None,), dtype=pxt.FloatType()))
89
+ def embeddings(input: Batch[str], *, model: str) -> Batch[np.ndarray]:
90
+ result = together_client().embeddings.create(input=input, model=model)
91
+ return [
92
+ np.array(data.embedding, dtype=np.float64)
93
+ for data in result.data
94
+ ]
95
+
96
+
97
+ @pxt.udf
98
+ def image_generations(
99
+ prompt: str,
100
+ *,
101
+ model: str,
102
+ steps: Optional[int] = None,
103
+ seed: Optional[int] = None,
104
+ height: Optional[int] = None,
105
+ width: Optional[int] = None,
106
+ negative_prompt: Optional[str] = None,
107
+ ) -> PIL.Image.Image:
108
+ # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
109
+ result = together_client().images.generate(
110
+ prompt=prompt,
111
+ model=model,
112
+ steps=steps,
113
+ seed=seed,
114
+ height=height,
115
+ width=width,
116
+ negative_prompt=negative_prompt
27
117
  )
118
+ b64_str = result.data[0].b64_json
119
+ b64_bytes = base64.b64decode(b64_str)
120
+ img = PIL.Image.open(io.BytesIO(b64_bytes))
121
+ img.load()
122
+ return img
@@ -80,8 +80,8 @@ def test_tbl(test_client: pxt.Client) -> catalog.Table:
80
80
  return create_test_tbl(test_client)
81
81
 
82
82
  # @pytest.fixture(scope='function')
83
- # def test_stored_fn(test_client: pt.Client) -> pt.Function:
84
- # @pt.udf(return_type=pt.IntType(), param_types=[pt.IntType()])
83
+ # def test_stored_fn(test_client: pxt.Client) -> pxt.Function:
84
+ # @pxt.udf(return_type=pxt.IntType(), param_types=[pxt.IntType()])
85
85
  # def test_fn(x):
86
86
  # return x + 1
87
87
  # test_client.create_function('test_fn', test_fn)
@@ -89,7 +89,7 @@ def test_tbl(test_client: pxt.Client) -> catalog.Table:
89
89
 
90
90
  @pytest.fixture(scope='function')
91
91
  def test_tbl_exprs(test_tbl: catalog.Table) -> List[exprs.Expr]:
92
- #def test_tbl_exprs(test_tbl: catalog.Table, test_stored_fn: pt.Function) -> List[exprs.Expr]:
92
+ #def test_tbl_exprs(test_tbl: catalog.Table, test_stored_fn: pxt.Function) -> List[exprs.Expr]:
93
93
 
94
94
  t = test_tbl
95
95
  return [
@@ -156,7 +156,7 @@ def img_tbl_exprs(img_tbl: catalog.Table) -> List[exprs.Expr]:
156
156
  # TODO: why does this not work with a session scope? (some user tables don't get created with create_all())
157
157
  #@pytest.fixture(scope='session')
158
158
  #def indexed_img_tbl(init_env: None) -> catalog.Table:
159
- # cl = pt.Client()
159
+ # cl = pxt.Client()
160
160
  # db = cl.create_db('test_indexed')
161
161
  @pytest.fixture(scope='function')
162
162
  def indexed_img_tbl(test_client: pxt.Client) -> catalog.Table:
@@ -0,0 +1,42 @@
1
+ import pytest
2
+
3
+ import pixeltable as pxt
4
+ import pixeltable.exceptions as excs
5
+ from pixeltable.tests.utils import skip_test_if_not_installed, validate_update_status
6
+
7
+
8
+ class TestFireworks:
9
+
10
+ def test_fireworks(self, test_client: pxt.Client) -> None:
11
+ skip_test_if_not_installed('fireworks')
12
+ TestFireworks.skip_test_if_no_fireworks_client()
13
+ cl = test_client
14
+ t = cl.create_table('test_tbl', {'input': pxt.StringType()})
15
+ from pixeltable.functions.fireworks import chat_completions
16
+ messages = [{'role': 'user', 'content': t.input}]
17
+ t['output'] = chat_completions(
18
+ messages=messages,
19
+ model='accounts/fireworks/models/llama-v2-7b-chat'
20
+ )
21
+ t['output_2'] = chat_completions(
22
+ messages=messages,
23
+ model='accounts/fireworks/models/llama-v2-7b-chat',
24
+ max_tokens=300,
25
+ top_k=40,
26
+ top_p=0.9,
27
+ temperature=0.7
28
+ )
29
+ validate_update_status(t.insert(input="How's everything going today?"), 1)
30
+ results = t.collect()
31
+ assert len(results['output'][0]['choices'][0]['message']['content']) > 0
32
+ assert len(results['output_2'][0]['choices'][0]['message']['content']) > 0
33
+
34
+ # This ensures that the test will be skipped, rather than returning an error, when no API key is
35
+ # available (for example, when a PR runs in CI).
36
+ @staticmethod
37
+ def skip_test_if_no_fireworks_client() -> None:
38
+ try:
39
+ import pixeltable.functions.fireworks
40
+ _ = pixeltable.functions.fireworks.fireworks_client()
41
+ except excs.Error as exc:
42
+ pytest.skip(str(exc))
@@ -0,0 +1,60 @@
1
+ import pixeltable as pxt
2
+ from pixeltable import catalog
3
+ from pixeltable.functions.pil.image import blend
4
+ from pixeltable.iterators import FrameIterator
5
+ from pixeltable.tests.utils import get_video_files, skip_test_if_not_installed
6
+ from pixeltable.type_system import VideoType, StringType
7
+
8
+
9
+ class TestFunctions:
10
+ def test_pil(self, img_tbl: catalog.Table) -> None:
11
+ t = img_tbl
12
+ _ = t[t.img, t.img.rotate(90), blend(t.img, t.img.rotate(90), 0.5)].show()
13
+
14
+ def test_eval_detections(self, test_client: pxt.Client) -> None:
15
+ skip_test_if_not_installed('nos')
16
+ cl = test_client
17
+ video_t = cl.create_table('video_tbl', {'video': VideoType()})
18
+ # create frame view
19
+ args = {'video': video_t.video, 'fps': 1}
20
+ v = cl.create_view('test_view', video_t, iterator_class=FrameIterator, iterator_args=args)
21
+
22
+ files = get_video_files()
23
+ video_t.insert(video=files[-1])
24
+ v.add_column(frame_s=v.frame.resize([640, 480]))
25
+ from pixeltable.functions.nos.object_detection_2d import yolox_nano, yolox_small, yolox_large
26
+ v.add_column(detections_a=yolox_nano(v.frame_s))
27
+ v.add_column(detections_b=yolox_small(v.frame_s))
28
+ v.add_column(gt=yolox_large(v.frame_s))
29
+ from pixeltable.functions.eval import eval_detections, mean_ap
30
+ res = v.select(
31
+ eval_detections(
32
+ v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels
33
+ )).show()
34
+ v.add_column(
35
+ eval_a=eval_detections(
36
+ v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels))
37
+ v.add_column(
38
+ eval_b=eval_detections(
39
+ v.detections_b.bboxes, v.detections_b.labels, v.detections_b.scores, v.gt.bboxes, v.gt.labels))
40
+ ap_a = v.select(mean_ap(v.eval_a)).show()[0, 0]
41
+ ap_b = v.select(mean_ap(v.eval_b)).show()[0, 0]
42
+ common_classes = set(ap_a.keys()) & set(ap_b.keys())
43
+
44
+ ## TODO: following assertion is failing on CI,
45
+ # It is not necessarily a bug, as the assert condition is not expected to always be true
46
+ # for k in common_classes:
47
+ # assert ap_a[k] <= ap_b[k]
48
+
49
+ def test_str(self, test_client: pxt.Client) -> None:
50
+ cl = test_client
51
+ t = cl.create_table('test_tbl', {'input': StringType()})
52
+ from pixeltable.functions.string import str_format
53
+ t.add_column(s1=str_format('ABC {0}', t.input))
54
+ t.add_column(s2=str_format('DEF {this}', this=t.input))
55
+ t.add_column(s3=str_format('GHI {0} JKL {this}', t.input, this=t.input))
56
+ status = t.insert(input='MNO')
57
+ assert status.num_rows == 1
58
+ assert status.num_excs == 0
59
+ row = t.head()[0]
60
+ assert row == {'input': 'MNO', 's1': 'ABC MNO', 's2': 'DEF MNO', 's3': 'GHI MNO JKL MNO'}
@@ -3,144 +3,12 @@ from typing import Dict, Any
3
3
  import pytest
4
4
 
5
5
  import pixeltable as pxt
6
- from pixeltable import catalog
7
- from pixeltable.env import Env
8
- import pixeltable.exceptions as excs
9
- from pixeltable.functions.pil.image import blend
10
- from pixeltable.iterators import FrameIterator
11
- from pixeltable.tests.utils import get_video_files, skip_test_if_not_installed, get_sentences, get_image_files
12
- from pixeltable.type_system import VideoType, StringType, JsonType, ImageType, BoolType, FloatType, ArrayType
6
+ from pixeltable.tests.utils import skip_test_if_not_installed, get_sentences, get_image_files, \
7
+ SAMPLE_IMAGE_URL
8
+ from pixeltable.type_system import StringType, JsonType, ImageType, BoolType, FloatType, ArrayType
13
9
 
14
10
 
15
- class TestFunctions:
16
- def test_pil(self, img_tbl: catalog.Table) -> None:
17
- t = img_tbl
18
- _ = t[t.img, t.img.rotate(90), blend(t.img, t.img.rotate(90), 0.5)].show()
19
-
20
- def test_eval_detections(self, test_client: pxt.Client) -> None:
21
- skip_test_if_not_installed('nos')
22
- cl = test_client
23
- video_t = cl.create_table('video_tbl', {'video': VideoType()})
24
- # create frame view
25
- args = {'video': video_t.video, 'fps': 1}
26
- v = cl.create_view('test_view', video_t, iterator_class=FrameIterator, iterator_args=args)
27
-
28
- files = get_video_files()
29
- video_t.insert(video=files[-1])
30
- v.add_column(frame_s=v.frame.resize([640, 480]))
31
- from pixeltable.functions.nos.object_detection_2d import yolox_nano, yolox_small, yolox_large
32
- v.add_column(detections_a=yolox_nano(v.frame_s))
33
- v.add_column(detections_b=yolox_small(v.frame_s))
34
- v.add_column(gt=yolox_large(v.frame_s))
35
- from pixeltable.functions.eval import eval_detections, mean_ap
36
- res = v.select(
37
- eval_detections(
38
- v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels
39
- )).show()
40
- v.add_column(
41
- eval_a=eval_detections(
42
- v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels))
43
- v.add_column(
44
- eval_b=eval_detections(
45
- v.detections_b.bboxes, v.detections_b.labels, v.detections_b.scores, v.gt.bboxes, v.gt.labels))
46
- ap_a = v.select(mean_ap(v.eval_a)).show()[0, 0]
47
- ap_b = v.select(mean_ap(v.eval_b)).show()[0, 0]
48
- common_classes = set(ap_a.keys()) & set(ap_b.keys())
49
-
50
- ## TODO: following assertion is failing on CI,
51
- # It is not necessarily a bug, as the assert condition is not expected to always be true
52
- # for k in common_classes:
53
- # assert ap_a[k] <= ap_b[k]
54
-
55
- def test_str(self, test_client: pxt.Client) -> None:
56
- cl = test_client
57
- t = cl.create_table('test_tbl', {'input': StringType()})
58
- from pixeltable.functions.string import str_format
59
- t.add_column(s1=str_format('ABC {0}', t.input))
60
- t.add_column(s2=str_format('DEF {this}', this=t.input))
61
- t.add_column(s3=str_format('GHI {0} JKL {this}', t.input, this=t.input))
62
- status = t.insert(input='MNO')
63
- assert status.num_rows == 1
64
- assert status.num_excs == 0
65
- row = t.head()[0]
66
- assert row == {'input': 'MNO', 's1': 'ABC MNO', 's2': 'DEF MNO', 's3': 'GHI MNO JKL MNO'}
67
-
68
- def test_openai(self, test_client: pxt.Client) -> None:
69
- skip_test_if_not_installed('openai')
70
- TestFunctions.skip_test_if_no_openai_client()
71
- cl = test_client
72
- t = cl.create_table('test_tbl', {'input': StringType()})
73
- from pixeltable.functions.openai import chat_completions, embeddings, moderations
74
- msgs = [
75
- {"role": "system", "content": "You are a helpful assistant."},
76
- {"role": "user", "content": t.input}
77
- ]
78
- t.add_column(input_msgs=msgs)
79
- t.add_column(chat_output=chat_completions(model='gpt-3.5-turbo', messages=t.input_msgs))
80
- # with inlined messages
81
- t.add_column(chat_output2=chat_completions(model='gpt-3.5-turbo', messages=msgs))
82
- t.add_column(ada_embed=embeddings(model='text-embedding-ada-002', input=t.input))
83
- t.add_column(text_3=embeddings(model='text-embedding-3-small', input=t.input))
84
- t.add_column(moderation=moderations(input=t.input))
85
- t.insert(input='I find you really annoying')
86
- _ = t.head()
87
-
88
- def test_gpt_4_vision(self, test_client: pxt.Client) -> None:
89
- skip_test_if_not_installed('openai')
90
- TestFunctions.skip_test_if_no_openai_client()
91
- cl = test_client
92
- t = cl.create_table('test_tbl', {'prompt': StringType(), 'img': ImageType()})
93
- from pixeltable.functions.openai import chat_completions
94
- from pixeltable.functions.string import str_format
95
- msgs = [
96
- {'role': 'user',
97
- 'content': [
98
- {'type': 'text', 'text': t.prompt},
99
- {'type': 'image_url', 'image_url': {
100
- 'url': str_format('data:image/png;base64,{0}', t.img.b64_encode())
101
- }}
102
- ]}
103
- ]
104
- t.add_column(response=chat_completions(model='gpt-4-vision-preview', messages=msgs, max_tokens=300))
105
- t.add_column(response_content=t.response.choices[0].message.content)
106
- t.insert(prompt="What's in this image?", img=_sample_image_url)
107
- result = t.collect()['response_content'][0]
108
- assert len(result) > 0
109
-
110
- @staticmethod
111
- def skip_test_if_no_openai_client() -> None:
112
- try:
113
- _ = Env.get().openai_client
114
- except excs.Error as exc:
115
- pytest.skip(str(exc))
116
-
117
- def test_together(self, test_client: pxt.Client) -> None:
118
- skip_test_if_not_installed('together')
119
- if not Env.get().has_together_client:
120
- pytest.skip(f'Together client does not exist (missing API key?)')
121
- cl = test_client
122
- t = cl.create_table('test_tbl', {'input': StringType()})
123
- from pixeltable.functions.together import completions
124
- t.add_column(output=completions(prompt=t.input, model='mistralai/Mixtral-8x7B-v0.1', stop=['\n']))
125
- t.add_column(output_text=t.output.output.choices[0].text)
126
- t.insert(input='I am going to the ')
127
- result = t.select(t.output_text).collect()['output_text'][0]
128
- assert len(result) > 0
129
-
130
- def test_fireworks(self, test_client: pxt.Client) -> None:
131
- skip_test_if_not_installed('fireworks')
132
- try:
133
- from pixeltable.functions.fireworks import initialize
134
- initialize()
135
- except:
136
- pytest.skip(f'Fireworks client does not exist (missing API key?)')
137
- cl = test_client
138
- t = cl.create_table('test_tbl', {'input': StringType()})
139
- from pixeltable.functions.fireworks import chat_completions
140
- t['output'] = chat_completions(prompt=t.input, model='accounts/fireworks/models/llama-v2-7b-chat', max_tokens=256).choices[0].text
141
- t.insert(input='I am going to the ')
142
- result = t.select(t.output).collect()['output'][0]
143
- assert len(result) > 0
11
+ class TestHuggingface:
144
12
 
145
13
  def test_hf_function(self, test_client: pxt.Client) -> None:
146
14
  skip_test_if_not_installed('sentence_transformers')
@@ -281,14 +149,10 @@ class TestFunctions:
281
149
  t = cl.create_table('test_tbl', {'img': ImageType()})
282
150
  from pixeltable.functions.huggingface import detr_for_object_detection
283
151
  t['detect'] = detr_for_object_detection(t.img, model_id='facebook/detr-resnet-50', threshold=0.8)
284
- status = t.insert(img=_sample_image_url)
152
+ status = t.insert(img=SAMPLE_IMAGE_URL)
285
153
  assert status.num_rows == 1
286
154
  assert status.num_excs == 0
287
155
  result = t.select(t.detect).collect()[0]['detect']
288
156
  assert 'orange' in result['label_text']
289
157
  assert 'bowl' in result['label_text']
290
158
  assert 'broccoli' in result['label_text']
291
-
292
-
293
- _sample_image_url = \
294
- 'https://raw.githubusercontent.com/pixeltable/pixeltable/master/docs/source/data/images/000000000009.jpg'
@@ -0,0 +1,152 @@
1
+ import pytest
2
+
3
+ import pixeltable as pxt
4
+ import pixeltable.exceptions as excs
5
+ from pixeltable.tests.utils import SAMPLE_IMAGE_URL, skip_test_if_not_installed, validate_update_status
6
+ from pixeltable.type_system import StringType, ImageType
7
+
8
+
9
+ class TestOpenai:
10
+
11
+ def test_audio(self, test_client: pxt.Client) -> None:
12
+ skip_test_if_not_installed('openai')
13
+ TestOpenai.skip_test_if_no_openai_client()
14
+ cl = test_client
15
+ t = cl.create_table('test_tbl', {'input': StringType()})
16
+ from pixeltable.functions.openai import speech, transcriptions, translations
17
+ t.add_column(speech=speech(t.input, model='tts-1', voice='onyx'))
18
+ t.add_column(speech_2=speech(t.input, model='tts-1', voice='onyx', response_format='flac', speed=1.05))
19
+ t.add_column(transcription=transcriptions(t.speech, model='whisper-1'))
20
+ t.add_column(transcription_2=transcriptions(
21
+ t.speech, model='whisper-1', language='en', prompt='Transcribe the contents of this recording.'
22
+ ))
23
+ t.add_column(translation=translations(t.speech, model='whisper-1'))
24
+ t.add_column(translation_2=translations(
25
+ t.speech, model='whisper-1', prompt='Translate the recording from Spanish into English.', temperature=0.7
26
+ ))
27
+ validate_update_status(t.insert([
28
+ {'input': 'I am a banana.'},
29
+ {'input': 'Es fácil traducir del español al inglés.'}
30
+ ]), expected_rows=2)
31
+ # The audio generation -> transcription loop on these examples should be simple and clear enough
32
+ # that the unit test can reliably expect the output closely enough to pass these checks.
33
+ results = t.collect()
34
+ assert results[0]['transcription']['text'] in ['I am a banana.', "I'm a banana."]
35
+ assert results[0]['transcription_2']['text'] in ['I am a banana.', "I'm a banana."]
36
+ assert 'easy to translate from Spanish' in results[1]['translation']['text']
37
+ assert 'easy to translate from Spanish' in results[1]['translation_2']['text']
38
+
39
+ def test_chat_completions(self, test_client: pxt.Client) -> None:
40
+ skip_test_if_not_installed('openai')
41
+ TestOpenai.skip_test_if_no_openai_client()
42
+ cl = test_client
43
+ t = cl.create_table('test_tbl', {'input': StringType()})
44
+ from pixeltable.functions.openai import chat_completions
45
+ msgs = [
46
+ {"role": "system", "content": "You are a helpful assistant."},
47
+ {"role": "user", "content": t.input}
48
+ ]
49
+ t.add_column(input_msgs=msgs)
50
+ t.add_column(chat_output=chat_completions(model='gpt-3.5-turbo', messages=t.input_msgs))
51
+ # with inlined messages
52
+ t.add_column(chat_output_2=chat_completions(model='gpt-3.5-turbo', messages=msgs))
53
+ # test a bunch of the parameters
54
+ t.add_column(chat_output_3=chat_completions(
55
+ model='gpt-3.5-turbo', messages=msgs, frequency_penalty=0.1, logprobs=True, top_logprobs=3,
56
+ max_tokens=500, n=3, presence_penalty=0.1, seed=4171780, stop=['\n'], temperature=0.7, top_p=0.8,
57
+ user='pixeltable'
58
+ ))
59
+ # test with JSON output enforced
60
+ t.add_column(chat_output_4=chat_completions(
61
+ model='gpt-3.5-turbo', messages=msgs, response_format={'type': 'json_object'}
62
+ ))
63
+ # TODO Also test the `tools` and `tool_choice` parameters.
64
+ validate_update_status(t.insert(input='Give me an example of a typical JSON structure.'), 1)
65
+ result = t.collect()
66
+ assert len(result['chat_output'][0]['choices'][0]['message']['content']) > 0
67
+ assert len(result['chat_output_2'][0]['choices'][0]['message']['content']) > 0
68
+ assert len(result['chat_output_3'][0]['choices'][0]['message']['content']) > 0
69
+ assert len(result['chat_output_4'][0]['choices'][0]['message']['content']) > 0
70
+
71
+ # When OpenAI gets a request with `response_format` equal to `json_object`, but the prompt does not
72
+ # contain the string "json", it refuses the request.
73
+ # TODO This should probably not be throwing an exception, but rather logging the error in
74
+ # `t.chat_output_4.errormsg` etc.
75
+ with pytest.raises(excs.ExprEvalError) as exc_info:
76
+ t.insert(input='Say something interesting.')
77
+ assert "\\'messages\\' must contain the word \\'json\\'" in str(exc_info.value)
78
+
79
+ def test_gpt_4_vision(self, test_client: pxt.Client) -> None:
80
+ skip_test_if_not_installed('openai')
81
+ TestOpenai.skip_test_if_no_openai_client()
82
+ cl = test_client
83
+ t = cl.create_table('test_tbl', {'prompt': StringType(), 'img': ImageType()})
84
+ from pixeltable.functions.openai import chat_completions, vision
85
+ from pixeltable.functions.string import str_format
86
+ t.add_column(response=vision(prompt="What's in this image?", image=t.img))
87
+ # Also get the response the low-level way, by calling chat_completions
88
+ msgs = [
89
+ {'role': 'user',
90
+ 'content': [
91
+ {'type': 'text', 'text': t.prompt},
92
+ {'type': 'image_url', 'image_url': {
93
+ 'url': str_format('data:image/png;base64,{0}', t.img.b64_encode())
94
+ }}
95
+ ]}
96
+ ]
97
+ t.add_column(response_2=chat_completions(model='gpt-4-vision-preview', messages=msgs, max_tokens=300).choices[0].message.content)
98
+ validate_update_status(t.insert(prompt="What's in this image?", img=SAMPLE_IMAGE_URL), 1)
99
+ result = t.collect()['response_2'][0]
100
+ assert len(result) > 0
101
+
102
+ def test_embeddings(self, test_client: pxt.Client) -> None:
103
+ skip_test_if_not_installed('openai')
104
+ TestOpenai.skip_test_if_no_openai_client()
105
+ cl = test_client
106
+ from pixeltable.functions.openai import embeddings
107
+ t = cl.create_table('test_tbl', {'input': StringType()})
108
+ t.add_column(ada_embed=embeddings(model='text-embedding-ada-002', input=t.input))
109
+ t.add_column(text_3=embeddings(model='text-embedding-3-small', input=t.input, user='pixeltable'))
110
+ validate_update_status(t.insert(input='Say something interesting.'), 1)
111
+ _ = t.head()
112
+
113
+ def test_moderations(self, test_client: pxt.Client) -> None:
114
+ skip_test_if_not_installed('openai')
115
+ TestOpenai.skip_test_if_no_openai_client()
116
+ cl = test_client
117
+ t = cl.create_table('test_tbl', {'input': StringType()})
118
+ from pixeltable.functions.openai import moderations
119
+ t.add_column(moderation=moderations(input=t.input))
120
+ t.add_column(moderation_2=moderations(input=t.input, model='text-moderation-stable'))
121
+ validate_update_status(t.insert(input='Say something interesting.'), 1)
122
+ _ = t.head()
123
+
124
+ def test_image_generations(self, test_client: pxt.Client) -> None:
125
+ skip_test_if_not_installed('openai')
126
+ TestOpenai.skip_test_if_no_openai_client()
127
+ cl = test_client
128
+ t = cl.create_table('test_tbl', {'input': StringType()})
129
+ from pixeltable.functions.openai import image_generations
130
+ t.add_column(img=image_generations(t.input))
131
+ # Test dall-e-2 options
132
+ t.add_column(img_2=image_generations(
133
+ t.input, model='dall-e-2', size='512x512', user='pixeltable'
134
+ ))
135
+ # Test dall-e-3 options
136
+ t.add_column(img_3=image_generations(
137
+ t.input, model='dall-e-3', quality='hd', size='1792x1024', style='natural', user='pixeltable'
138
+ ))
139
+ validate_update_status(t.insert(input='A friendly dinosaur playing tennis in a cornfield'), 1)
140
+ assert t.collect()['img'][0].size == (1024, 1024)
141
+ assert t.collect()['img_2'][0].size == (512, 512)
142
+ assert t.collect()['img_3'][0].size == (1792, 1024)
143
+
144
+ # This ensures that the test will be skipped, rather than returning an error, when no API key is
145
+ # available (for example, when a PR runs in CI).
146
+ @staticmethod
147
+ def skip_test_if_no_openai_client() -> None:
148
+ try:
149
+ import pixeltable.functions.openai
150
+ _ = pixeltable.functions.openai.openai_client()
151
+ except excs.Error as exc:
152
+ pytest.skip(str(exc))