pixeltable 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/catalog/column.py +1 -1
- pixeltable/client.py +72 -2
- pixeltable/env.py +36 -52
- pixeltable/functions/__init__.py +1 -1
- pixeltable/functions/fireworks.py +10 -37
- pixeltable/functions/openai.py +192 -24
- pixeltable/functions/together.py +104 -9
- pixeltable/tests/conftest.py +4 -4
- pixeltable/tests/functions/test_fireworks.py +42 -0
- pixeltable/tests/functions/test_functions.py +60 -0
- pixeltable/tests/{test_functions.py → functions/test_huggingface.py} +5 -141
- pixeltable/tests/functions/test_openai.py +152 -0
- pixeltable/tests/functions/test_together.py +111 -0
- pixeltable/tests/test_dataframe.py +4 -4
- pixeltable/tests/test_table.py +105 -2
- pixeltable/tests/utils.py +128 -5
- pixeltable/type_system.py +41 -84
- pixeltable/utils/arrow.py +98 -0
- pixeltable/utils/hf_datasets.py +157 -0
- pixeltable/utils/parquet.py +68 -27
- pixeltable/utils/pytorch.py +16 -97
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.4.dist-info}/METADATA +33 -27
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.4.dist-info}/RECORD +25 -19
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.4.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.4.dist-info}/WHEEL +0 -0
pixeltable/catalog/column.py
CHANGED
|
@@ -61,7 +61,7 @@ class Column:
|
|
|
61
61
|
raise excs.Error(f"Invalid column name: '{name}'")
|
|
62
62
|
self.name = name
|
|
63
63
|
if col_type is None and computed_with is None:
|
|
64
|
-
raise excs.Error(f'Column {name}
|
|
64
|
+
raise excs.Error(f'Column `{name}`: col_type is required if computed_with is not specified')
|
|
65
65
|
|
|
66
66
|
self.value_expr: Optional['Expr'] = None
|
|
67
67
|
self.compute_func: Optional[Callable] = None
|
pixeltable/client.py
CHANGED
|
@@ -2,12 +2,11 @@ from typing import List, Optional, Dict, Type, Any, Union
|
|
|
2
2
|
import pandas as pd
|
|
3
3
|
import logging
|
|
4
4
|
import dataclasses
|
|
5
|
-
from uuid import UUID
|
|
6
|
-
from collections import defaultdict
|
|
7
5
|
|
|
8
6
|
import sqlalchemy as sql
|
|
9
7
|
import sqlalchemy.orm as orm
|
|
10
8
|
|
|
9
|
+
import pixeltable
|
|
11
10
|
from pixeltable.metadata import schema
|
|
12
11
|
from pixeltable.env import Env
|
|
13
12
|
import pixeltable.func as func
|
|
@@ -16,6 +15,10 @@ from pixeltable import exceptions as excs
|
|
|
16
15
|
from pixeltable.exprs import Predicate
|
|
17
16
|
from pixeltable.iterators import ComponentIterator
|
|
18
17
|
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
import datasets
|
|
21
|
+
|
|
19
22
|
__all__ = [
|
|
20
23
|
'Client',
|
|
21
24
|
]
|
|
@@ -155,6 +158,73 @@ class Client:
|
|
|
155
158
|
_logger.info(f'Created table `{path_str}`.')
|
|
156
159
|
return tbl
|
|
157
160
|
|
|
161
|
+
def import_parquet(
|
|
162
|
+
self,
|
|
163
|
+
table_path: str,
|
|
164
|
+
*,
|
|
165
|
+
parquet_path: str,
|
|
166
|
+
schema_override: Optional[Dict[str, Any]] = None,
|
|
167
|
+
**kwargs,
|
|
168
|
+
) -> catalog.InsertableTable:
|
|
169
|
+
"""Create a new `InsertableTable` from a Parquet file or set of files. Requires pyarrow to be installed.
|
|
170
|
+
Args:
|
|
171
|
+
path_str: Path to the table within pixeltable.
|
|
172
|
+
parquet_path: Path to an individual Parquet file or directory of Parquet files.
|
|
173
|
+
schema_override: Optional dictionary mapping column names to column type to override the default
|
|
174
|
+
schema inferred from the Parquet file. The column type should be a pixeltable ColumnType.
|
|
175
|
+
For example, {'col_vid': VideoType()}, rather than {'col_vid': StringType()}.
|
|
176
|
+
Any fields not provided explicitly will map to types with `pixeltable.utils.parquet.parquet_schema_to_pixeltable_schema`
|
|
177
|
+
kwargs: Additional arguments to pass to `Client.create_table`.
|
|
178
|
+
|
|
179
|
+
Returns:
|
|
180
|
+
The newly created table. The table will have loaded the data from the Parquet file(s).
|
|
181
|
+
"""
|
|
182
|
+
from pixeltable.utils import parquet
|
|
183
|
+
|
|
184
|
+
return parquet.import_parquet(
|
|
185
|
+
self,
|
|
186
|
+
table_path=table_path,
|
|
187
|
+
parquet_path=parquet_path,
|
|
188
|
+
schema_override=schema_override,
|
|
189
|
+
**kwargs,
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
def import_huggingface_dataset(
|
|
193
|
+
self,
|
|
194
|
+
table_path: str,
|
|
195
|
+
dataset: Union['datasets.Dataset', 'datasets.DatasetDict'],
|
|
196
|
+
*,
|
|
197
|
+
column_name_for_split: Optional[str] = 'split',
|
|
198
|
+
schema_override: Optional[Dict[str, Any]] = None,
|
|
199
|
+
**kwargs
|
|
200
|
+
) -> catalog.InsertableTable:
|
|
201
|
+
"""Create a new `InsertableTable` from a Huggingface dataset, or dataset dict with multiple splits.
|
|
202
|
+
Requires datasets library to be installed.
|
|
203
|
+
|
|
204
|
+
Args:
|
|
205
|
+
path_str: Path to the table.
|
|
206
|
+
dataset: Huggingface datasts.Dataset or datasts.DatasetDict to insert into the table.
|
|
207
|
+
column_name_for_split: column name to use for split information. If None, no split information will be stored.
|
|
208
|
+
schema_override: Optional dictionary mapping column names to column type to override the corresponding defaults from
|
|
209
|
+
`pixeltable.utils.hf_datasets.huggingface_schema_to_pixeltable_schema`. The column type should be a pixeltable ColumnType.
|
|
210
|
+
For example, {'col_vid': VideoType()}, rather than {'col_vid': StringType()}.
|
|
211
|
+
|
|
212
|
+
kwargs: Additional arguments to pass to `create_table`.
|
|
213
|
+
|
|
214
|
+
Returns:
|
|
215
|
+
The newly created table. The table will have loaded the data from the dataset.
|
|
216
|
+
"""
|
|
217
|
+
from pixeltable.utils import hf_datasets
|
|
218
|
+
|
|
219
|
+
return hf_datasets.import_huggingface_dataset(
|
|
220
|
+
self,
|
|
221
|
+
table_path,
|
|
222
|
+
dataset,
|
|
223
|
+
column_name_for_split=column_name_for_split,
|
|
224
|
+
schema_override=schema_override,
|
|
225
|
+
**kwargs,
|
|
226
|
+
)
|
|
227
|
+
|
|
158
228
|
def create_view(
|
|
159
229
|
self, path_str: str, base: catalog.Table, *, schema: Optional[Dict[str, Any]] = None,
|
|
160
230
|
filter: Optional[Predicate] = None,
|
pixeltable/env.py
CHANGED
|
@@ -1,33 +1,28 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
import datetime
|
|
3
|
-
import
|
|
4
|
-
|
|
5
|
-
from pathlib import Path
|
|
6
|
-
import sqlalchemy as sql
|
|
7
|
-
import uuid
|
|
4
|
+
import glob
|
|
5
|
+
import http.server
|
|
8
6
|
import importlib
|
|
9
7
|
import importlib.util
|
|
10
|
-
|
|
11
|
-
import
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
12
10
|
import socketserver
|
|
11
|
+
import sys
|
|
13
12
|
import threading
|
|
14
13
|
import typing
|
|
15
14
|
import uuid
|
|
16
15
|
from pathlib import Path
|
|
17
|
-
from typing import Optional, Dict, Any, List
|
|
16
|
+
from typing import Callable, Optional, Dict, Any, List
|
|
18
17
|
|
|
18
|
+
import pgserver
|
|
19
|
+
import sqlalchemy as sql
|
|
19
20
|
import yaml
|
|
20
21
|
from sqlalchemy_utils.functions import database_exists, create_database, drop_database
|
|
21
|
-
import pgserver
|
|
22
|
-
import logging
|
|
23
|
-
import sys
|
|
24
|
-
import glob
|
|
25
22
|
|
|
26
|
-
from pixeltable import metadata
|
|
27
23
|
import pixeltable.exceptions as excs
|
|
24
|
+
from pixeltable import metadata
|
|
28
25
|
|
|
29
|
-
if typing.TYPE_CHECKING:
|
|
30
|
-
import openai
|
|
31
26
|
|
|
32
27
|
class Env:
|
|
33
28
|
"""
|
|
@@ -59,12 +54,12 @@ class Env:
|
|
|
59
54
|
# package name -> version; version == []: package is installed, but we haven't determined the version yet
|
|
60
55
|
self._installed_packages: Dict[str, Optional[List[int]]] = {}
|
|
61
56
|
self._nos_client: Optional[Any] = None
|
|
62
|
-
self._openai_client: Optional['openai.OpenAI'] = None
|
|
63
|
-
self._has_together_client: bool = False
|
|
64
57
|
self._spacy_nlp: Optional[Any] = None # spacy.Language
|
|
65
58
|
self._httpd: Optional[socketserver.TCPServer] = None
|
|
66
59
|
self._http_address: Optional[str] = None
|
|
67
60
|
|
|
61
|
+
self._registered_clients: dict[str, Any] = {}
|
|
62
|
+
|
|
68
63
|
# logging-related state
|
|
69
64
|
self._logger = logging.getLogger('pixeltable')
|
|
70
65
|
self._logger.setLevel(logging.DEBUG) # allow everything to pass, we filter in _log_filter()
|
|
@@ -256,31 +251,32 @@ class Env:
|
|
|
256
251
|
from pixeltable.functions.util import create_nos_modules
|
|
257
252
|
_ = create_nos_modules()
|
|
258
253
|
|
|
259
|
-
def
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
254
|
+
def get_client(self, name: str, init: Callable, environ: Optional[str] = None) -> Any:
|
|
255
|
+
"""
|
|
256
|
+
Gets the client with the specified name, using `init` to construct one if necessary.
|
|
257
|
+
|
|
258
|
+
- name: The name of the client
|
|
259
|
+
- init: A `Callable` with signature `fn(api_key: str) -> Any` that constructs a client object
|
|
260
|
+
- environ: The name of the environment variable to use for the API key, if no API key is found in config
|
|
261
|
+
(defaults to f'{name.upper()}_API_KEY')
|
|
262
|
+
"""
|
|
263
|
+
if name in self._registered_clients:
|
|
264
|
+
return self._registered_clients[name]
|
|
265
|
+
|
|
266
|
+
if environ is None:
|
|
267
|
+
environ = f'{name.upper()}_API_KEY'
|
|
271
268
|
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
api_key = self._config['together']['api_key']
|
|
269
|
+
if name in self._config and 'api_key' in self._config[name]:
|
|
270
|
+
api_key = self._config[name]['api_key']
|
|
275
271
|
else:
|
|
276
|
-
api_key = os.environ.get(
|
|
272
|
+
api_key = os.environ.get(environ)
|
|
277
273
|
if api_key is None or api_key == '':
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
self.
|
|
282
|
-
|
|
283
|
-
|
|
274
|
+
raise excs.Error(f'`{name}` client not initialized (no API key configured).')
|
|
275
|
+
|
|
276
|
+
client = init(api_key)
|
|
277
|
+
self._registered_clients[name] = client
|
|
278
|
+
self._logger.info(f'Initialized `{name}` client.')
|
|
279
|
+
return client
|
|
284
280
|
|
|
285
281
|
def _start_web_server(self) -> None:
|
|
286
282
|
"""
|
|
@@ -319,6 +315,7 @@ class Env:
|
|
|
319
315
|
else:
|
|
320
316
|
self._installed_packages[package] = None
|
|
321
317
|
|
|
318
|
+
check('datasets')
|
|
322
319
|
check('torch')
|
|
323
320
|
check('torchvision')
|
|
324
321
|
check('transformers')
|
|
@@ -332,8 +329,6 @@ class Env:
|
|
|
332
329
|
check('tiktoken')
|
|
333
330
|
check('openai')
|
|
334
331
|
check('together')
|
|
335
|
-
if self.is_installed_package('together'):
|
|
336
|
-
self._create_together_client()
|
|
337
332
|
check('fireworks')
|
|
338
333
|
check('nos')
|
|
339
334
|
if self.is_installed_package('nos'):
|
|
@@ -399,17 +394,6 @@ class Env:
|
|
|
399
394
|
def nos_client(self) -> Any:
|
|
400
395
|
return self._nos_client
|
|
401
396
|
|
|
402
|
-
@property
|
|
403
|
-
def openai_client(self) -> 'openai.OpenAI':
|
|
404
|
-
if self._openai_client is None:
|
|
405
|
-
self._create_openai_client()
|
|
406
|
-
assert self._openai_client is not None
|
|
407
|
-
return self._openai_client
|
|
408
|
-
|
|
409
|
-
@property
|
|
410
|
-
def has_together_client(self) -> bool:
|
|
411
|
-
return self._has_together_client
|
|
412
|
-
|
|
413
397
|
@property
|
|
414
398
|
def spacy_nlp(self) -> Any:
|
|
415
399
|
assert self._spacy_nlp is not None
|
pixeltable/functions/__init__.py
CHANGED
|
@@ -15,7 +15,7 @@ import pixeltable.functions.pil.image
|
|
|
15
15
|
from pixeltable import exprs
|
|
16
16
|
from pixeltable.type_system import IntType, ColumnType, FloatType, ImageType, VideoType
|
|
17
17
|
# automatically import all submodules so that the udfs get registered
|
|
18
|
-
from . import image, string, video,
|
|
18
|
+
from . import image, string, video, huggingface
|
|
19
19
|
|
|
20
20
|
# TODO: remove and replace calls with astype()
|
|
21
21
|
def cast(expr: exprs.Expr, target_type: ColumnType) -> exprs.Expr:
|
|
@@ -1,61 +1,34 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
import os
|
|
3
1
|
from typing import Optional
|
|
4
2
|
|
|
3
|
+
import fireworks.client
|
|
4
|
+
|
|
5
5
|
import pixeltable as pxt
|
|
6
|
-
import pixeltable.exceptions as excs
|
|
7
6
|
from pixeltable import env
|
|
8
7
|
|
|
9
8
|
|
|
9
|
+
def fireworks_client() -> fireworks.client.Fireworks:
|
|
10
|
+
return env.Env.get().get_client('fireworks', lambda api_key: fireworks.client.Fireworks(api_key=api_key))
|
|
11
|
+
|
|
12
|
+
|
|
10
13
|
@pxt.udf
|
|
11
14
|
def chat_completions(
|
|
12
|
-
|
|
13
|
-
model: str,
|
|
15
|
+
messages: list[dict[str, str]],
|
|
14
16
|
*,
|
|
17
|
+
model: str,
|
|
15
18
|
max_tokens: Optional[int] = None,
|
|
16
|
-
repetition_penalty: Optional[float] = None,
|
|
17
19
|
top_k: Optional[int] = None,
|
|
18
20
|
top_p: Optional[float] = None,
|
|
19
21
|
temperature: Optional[float] = None
|
|
20
22
|
) -> dict:
|
|
21
|
-
initialize()
|
|
22
23
|
kwargs = {
|
|
23
24
|
'max_tokens': max_tokens,
|
|
24
|
-
'repetition_penalty': repetition_penalty,
|
|
25
25
|
'top_k': top_k,
|
|
26
26
|
'top_p': top_p,
|
|
27
27
|
'temperature': temperature
|
|
28
28
|
}
|
|
29
29
|
kwargs_not_none = dict(filter(lambda x: x[1] is not None, kwargs.items()))
|
|
30
|
-
|
|
31
|
-
return fireworks.client.Completion.create(
|
|
30
|
+
return fireworks_client().chat.completions.create(
|
|
32
31
|
model=model,
|
|
33
|
-
|
|
32
|
+
messages=messages,
|
|
34
33
|
**kwargs_not_none
|
|
35
34
|
).dict()
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
def initialize():
|
|
39
|
-
global _is_fireworks_initialized
|
|
40
|
-
if _is_fireworks_initialized:
|
|
41
|
-
return
|
|
42
|
-
|
|
43
|
-
_logger.info('Initializing Fireworks client.')
|
|
44
|
-
|
|
45
|
-
config = pxt.env.Env.get().config
|
|
46
|
-
|
|
47
|
-
if 'fireworks' in config and 'api_key' in config['fireworks']:
|
|
48
|
-
api_key = config['fireworks']['api_key']
|
|
49
|
-
else:
|
|
50
|
-
api_key = os.environ.get('FIREWORKS_API_KEY')
|
|
51
|
-
if api_key is None or api_key == '':
|
|
52
|
-
raise excs.Error('Fireworks client not initialized (no API key configured).')
|
|
53
|
-
|
|
54
|
-
import fireworks.client
|
|
55
|
-
|
|
56
|
-
fireworks.client.api_key = api_key
|
|
57
|
-
_is_fireworks_initialized = True
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
_logger = logging.getLogger('pixeltable')
|
|
61
|
-
_is_fireworks_initialized = False
|
pixeltable/functions/openai.py
CHANGED
|
@@ -1,9 +1,14 @@
|
|
|
1
1
|
import base64
|
|
2
2
|
import io
|
|
3
|
-
|
|
3
|
+
import pathlib
|
|
4
|
+
import uuid
|
|
5
|
+
from typing import Optional, TypeVar, Union, Callable
|
|
4
6
|
|
|
5
7
|
import PIL.Image
|
|
6
8
|
import numpy as np
|
|
9
|
+
import openai
|
|
10
|
+
import tenacity
|
|
11
|
+
from openai._types import NOT_GIVEN, NotGiven
|
|
7
12
|
|
|
8
13
|
import pixeltable as pxt
|
|
9
14
|
import pixeltable.type_system as ts
|
|
@@ -11,43 +16,148 @@ from pixeltable import env
|
|
|
11
16
|
from pixeltable.func import Batch
|
|
12
17
|
|
|
13
18
|
|
|
19
|
+
def openai_client() -> openai.OpenAI:
|
|
20
|
+
return env.Env.get().get_client('openai', lambda api_key: openai.OpenAI(api_key=api_key))
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Exponential backoff decorator using tenacity.
|
|
24
|
+
# TODO(aaron-siegel): Right now this hardwires random exponential backoff with defaults suggested
|
|
25
|
+
# by OpenAI. Should we investigate making this more customizable in the future?
|
|
26
|
+
def _retry(fn: Callable) -> Callable:
|
|
27
|
+
return tenacity.retry(
|
|
28
|
+
retry=tenacity.retry_if_exception_type(openai.RateLimitError),
|
|
29
|
+
wait=tenacity.wait_random_exponential(min=1, max=60),
|
|
30
|
+
stop=tenacity.stop_after_attempt(6)
|
|
31
|
+
)(fn)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
#####################################
|
|
35
|
+
# Audio Endpoints
|
|
36
|
+
|
|
37
|
+
@pxt.udf(return_type=ts.AudioType())
|
|
38
|
+
@_retry
|
|
39
|
+
def speech(
|
|
40
|
+
input: str,
|
|
41
|
+
*,
|
|
42
|
+
model: str,
|
|
43
|
+
voice: str,
|
|
44
|
+
response_format: Optional[str] = None,
|
|
45
|
+
speed: Optional[float] = None
|
|
46
|
+
) -> str:
|
|
47
|
+
content = openai_client().audio.speech.create(
|
|
48
|
+
input=input,
|
|
49
|
+
model=model,
|
|
50
|
+
voice=voice,
|
|
51
|
+
response_format=_opt(response_format),
|
|
52
|
+
speed=_opt(speed)
|
|
53
|
+
)
|
|
54
|
+
ext = response_format or 'mp3'
|
|
55
|
+
output_filename = str(env.Env.get().tmp_dir / f"{uuid.uuid4()}.{ext}")
|
|
56
|
+
content.stream_to_file(output_filename, chunk_size=1 << 20)
|
|
57
|
+
return output_filename
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@pxt.udf(
|
|
61
|
+
param_types=[ts.AudioType(), ts.StringType(), ts.StringType(nullable=True),
|
|
62
|
+
ts.StringType(nullable=True), ts.FloatType(nullable=True)]
|
|
63
|
+
)
|
|
64
|
+
@_retry
|
|
65
|
+
def transcriptions(
|
|
66
|
+
audio: str,
|
|
67
|
+
*,
|
|
68
|
+
model: str,
|
|
69
|
+
language: Optional[str] = None,
|
|
70
|
+
prompt: Optional[str] = None,
|
|
71
|
+
temperature: Optional[float] = None
|
|
72
|
+
) -> dict:
|
|
73
|
+
file = pathlib.Path(audio)
|
|
74
|
+
transcription = openai_client().audio.transcriptions.create(
|
|
75
|
+
file=file,
|
|
76
|
+
model=model,
|
|
77
|
+
language=_opt(language),
|
|
78
|
+
prompt=_opt(prompt),
|
|
79
|
+
temperature=_opt(temperature)
|
|
80
|
+
)
|
|
81
|
+
return transcription.dict()
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@pxt.udf(
|
|
85
|
+
param_types=[ts.AudioType(), ts.StringType(), ts.StringType(nullable=True), ts.FloatType(nullable=True)]
|
|
86
|
+
)
|
|
87
|
+
@_retry
|
|
88
|
+
def translations(
|
|
89
|
+
audio: str,
|
|
90
|
+
*,
|
|
91
|
+
model: str,
|
|
92
|
+
prompt: Optional[str] = None,
|
|
93
|
+
temperature: Optional[float] = None
|
|
94
|
+
) -> dict:
|
|
95
|
+
file = pathlib.Path(audio)
|
|
96
|
+
translation = openai_client().audio.translations.create(
|
|
97
|
+
file=file,
|
|
98
|
+
model=model,
|
|
99
|
+
prompt=_opt(prompt),
|
|
100
|
+
temperature=_opt(temperature)
|
|
101
|
+
)
|
|
102
|
+
return translation.dict()
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
#####################################
|
|
106
|
+
# Chat Endpoints
|
|
107
|
+
|
|
14
108
|
@pxt.udf
|
|
109
|
+
@_retry
|
|
15
110
|
def chat_completions(
|
|
16
111
|
messages: list,
|
|
112
|
+
*,
|
|
17
113
|
model: str,
|
|
18
114
|
frequency_penalty: Optional[float] = None,
|
|
19
|
-
logit_bias: Optional[dict] = None,
|
|
115
|
+
logit_bias: Optional[dict[str, int]] = None,
|
|
116
|
+
logprobs: Optional[bool] = None,
|
|
117
|
+
top_logprobs: Optional[int] = None,
|
|
20
118
|
max_tokens: Optional[int] = None,
|
|
21
119
|
n: Optional[int] = None,
|
|
22
120
|
presence_penalty: Optional[float] = None,
|
|
23
121
|
response_format: Optional[dict] = None,
|
|
24
122
|
seed: Optional[int] = None,
|
|
123
|
+
stop: Optional[list[str]] = None,
|
|
124
|
+
temperature: Optional[float] = None,
|
|
25
125
|
top_p: Optional[float] = None,
|
|
26
|
-
|
|
126
|
+
tools: Optional[list[dict]] = None,
|
|
127
|
+
tool_choice: Optional[dict] = None,
|
|
128
|
+
user: Optional[str] = None
|
|
27
129
|
) -> dict:
|
|
28
|
-
|
|
29
|
-
result = env.Env.get().openai_client.chat.completions.create(
|
|
130
|
+
result = openai_client().chat.completions.create(
|
|
30
131
|
messages=messages,
|
|
31
132
|
model=model,
|
|
32
|
-
frequency_penalty=frequency_penalty
|
|
33
|
-
logit_bias=logit_bias
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
133
|
+
frequency_penalty=_opt(frequency_penalty),
|
|
134
|
+
logit_bias=_opt(logit_bias),
|
|
135
|
+
logprobs=_opt(logprobs),
|
|
136
|
+
top_logprobs=_opt(top_logprobs),
|
|
137
|
+
max_tokens=_opt(max_tokens),
|
|
138
|
+
n=_opt(n),
|
|
139
|
+
presence_penalty=_opt(presence_penalty),
|
|
140
|
+
response_format=_opt(response_format),
|
|
141
|
+
seed=_opt(seed),
|
|
142
|
+
stop=_opt(stop),
|
|
143
|
+
temperature=_opt(temperature),
|
|
144
|
+
top_p=_opt(top_p),
|
|
145
|
+
tools=_opt(tools),
|
|
146
|
+
tool_choice=_opt(tool_choice),
|
|
147
|
+
user=_opt(user)
|
|
41
148
|
)
|
|
42
149
|
return result.dict()
|
|
43
150
|
|
|
44
151
|
|
|
45
152
|
@pxt.udf
|
|
153
|
+
@_retry
|
|
46
154
|
def vision(
|
|
47
155
|
prompt: str,
|
|
48
156
|
image: PIL.Image.Image,
|
|
157
|
+
*,
|
|
49
158
|
model: str = 'gpt-4-vision-preview'
|
|
50
159
|
) -> str:
|
|
160
|
+
# TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
|
|
51
161
|
bytes_arr = io.BytesIO()
|
|
52
162
|
image.save(bytes_arr, format='png')
|
|
53
163
|
b64_bytes = base64.b64encode(bytes_arr.getvalue())
|
|
@@ -61,28 +171,86 @@ def vision(
|
|
|
61
171
|
}}
|
|
62
172
|
]}
|
|
63
173
|
]
|
|
64
|
-
result =
|
|
174
|
+
result = openai_client().chat.completions.create(
|
|
65
175
|
messages=messages,
|
|
66
176
|
model=model
|
|
67
177
|
)
|
|
68
178
|
return result.choices[0].message.content
|
|
69
179
|
|
|
70
180
|
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
result = env.Env().get().openai_client.moderations.create(input=input, model=model)
|
|
74
|
-
return result.dict()
|
|
75
|
-
|
|
181
|
+
#####################################
|
|
182
|
+
# Embeddings Endpoints
|
|
76
183
|
|
|
77
184
|
@pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType()))
|
|
78
|
-
|
|
79
|
-
|
|
185
|
+
@_retry
|
|
186
|
+
def embeddings(
|
|
187
|
+
input: Batch[str],
|
|
188
|
+
*,
|
|
189
|
+
model: str,
|
|
190
|
+
user: Optional[str] = None
|
|
191
|
+
) -> Batch[np.ndarray]:
|
|
192
|
+
result = openai_client().embeddings.create(
|
|
80
193
|
input=input,
|
|
81
194
|
model=model,
|
|
195
|
+
user=_opt(user),
|
|
82
196
|
encoding_format='float'
|
|
83
197
|
)
|
|
84
|
-
|
|
198
|
+
return [
|
|
85
199
|
np.array(data.embedding, dtype=np.float64)
|
|
86
200
|
for data in result.data
|
|
87
201
|
]
|
|
88
|
-
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
#####################################
|
|
205
|
+
# Images Endpoints
|
|
206
|
+
|
|
207
|
+
@pxt.udf
|
|
208
|
+
@_retry
|
|
209
|
+
def image_generations(
|
|
210
|
+
prompt: str,
|
|
211
|
+
*,
|
|
212
|
+
model: Optional[str] = None,
|
|
213
|
+
quality: Optional[str] = None,
|
|
214
|
+
size: Optional[str] = None,
|
|
215
|
+
style: Optional[str] = None,
|
|
216
|
+
user: Optional[str] = None
|
|
217
|
+
) -> PIL.Image.Image:
|
|
218
|
+
# TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
|
|
219
|
+
result = openai_client().images.generate(
|
|
220
|
+
prompt=prompt,
|
|
221
|
+
model=_opt(model),
|
|
222
|
+
quality=_opt(quality),
|
|
223
|
+
size=_opt(size),
|
|
224
|
+
style=_opt(style),
|
|
225
|
+
user=_opt(user),
|
|
226
|
+
response_format="b64_json"
|
|
227
|
+
)
|
|
228
|
+
b64_str = result.data[0].b64_json
|
|
229
|
+
b64_bytes = base64.b64decode(b64_str)
|
|
230
|
+
img = PIL.Image.open(io.BytesIO(b64_bytes))
|
|
231
|
+
img.load()
|
|
232
|
+
return img
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
#####################################
|
|
236
|
+
# Moderations Endpoints
|
|
237
|
+
|
|
238
|
+
@pxt.udf
|
|
239
|
+
@_retry
|
|
240
|
+
def moderations(
|
|
241
|
+
input: str,
|
|
242
|
+
*,
|
|
243
|
+
model: Optional[str] = None
|
|
244
|
+
) -> dict:
|
|
245
|
+
result = openai_client().moderations.create(
|
|
246
|
+
input=input,
|
|
247
|
+
model=_opt(model)
|
|
248
|
+
)
|
|
249
|
+
return result.dict()
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
_T = TypeVar('_T')
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def _opt(arg: _T) -> Union[_T, NotGiven]:
|
|
256
|
+
return arg if arg is not None else NOT_GIVEN
|