pixeltable 0.3.14__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pixeltable/__init__.py +42 -8
- pixeltable/{dataframe.py → _query.py} +470 -206
- pixeltable/_version.py +1 -0
- pixeltable/catalog/__init__.py +5 -4
- pixeltable/catalog/catalog.py +1785 -432
- pixeltable/catalog/column.py +190 -113
- pixeltable/catalog/dir.py +2 -4
- pixeltable/catalog/globals.py +19 -46
- pixeltable/catalog/insertable_table.py +191 -98
- pixeltable/catalog/path.py +63 -23
- pixeltable/catalog/schema_object.py +11 -15
- pixeltable/catalog/table.py +843 -436
- pixeltable/catalog/table_metadata.py +103 -0
- pixeltable/catalog/table_version.py +978 -657
- pixeltable/catalog/table_version_handle.py +72 -16
- pixeltable/catalog/table_version_path.py +112 -43
- pixeltable/catalog/tbl_ops.py +53 -0
- pixeltable/catalog/update_status.py +191 -0
- pixeltable/catalog/view.py +134 -90
- pixeltable/config.py +134 -22
- pixeltable/env.py +471 -157
- pixeltable/exceptions.py +6 -0
- pixeltable/exec/__init__.py +4 -1
- pixeltable/exec/aggregation_node.py +7 -8
- pixeltable/exec/cache_prefetch_node.py +83 -110
- pixeltable/exec/cell_materialization_node.py +268 -0
- pixeltable/exec/cell_reconstruction_node.py +168 -0
- pixeltable/exec/component_iteration_node.py +4 -3
- pixeltable/exec/data_row_batch.py +8 -65
- pixeltable/exec/exec_context.py +16 -4
- pixeltable/exec/exec_node.py +13 -36
- pixeltable/exec/expr_eval/evaluators.py +11 -7
- pixeltable/exec/expr_eval/expr_eval_node.py +27 -12
- pixeltable/exec/expr_eval/globals.py +8 -5
- pixeltable/exec/expr_eval/row_buffer.py +1 -2
- pixeltable/exec/expr_eval/schedulers.py +106 -56
- pixeltable/exec/globals.py +35 -0
- pixeltable/exec/in_memory_data_node.py +19 -19
- pixeltable/exec/object_store_save_node.py +293 -0
- pixeltable/exec/row_update_node.py +16 -9
- pixeltable/exec/sql_node.py +351 -84
- pixeltable/exprs/__init__.py +1 -1
- pixeltable/exprs/arithmetic_expr.py +27 -22
- pixeltable/exprs/array_slice.py +3 -3
- pixeltable/exprs/column_property_ref.py +36 -23
- pixeltable/exprs/column_ref.py +213 -89
- pixeltable/exprs/comparison.py +5 -5
- pixeltable/exprs/compound_predicate.py +5 -4
- pixeltable/exprs/data_row.py +164 -54
- pixeltable/exprs/expr.py +70 -44
- pixeltable/exprs/expr_dict.py +3 -3
- pixeltable/exprs/expr_set.py +17 -10
- pixeltable/exprs/function_call.py +100 -40
- pixeltable/exprs/globals.py +2 -2
- pixeltable/exprs/in_predicate.py +4 -4
- pixeltable/exprs/inline_expr.py +18 -32
- pixeltable/exprs/is_null.py +7 -3
- pixeltable/exprs/json_mapper.py +8 -8
- pixeltable/exprs/json_path.py +56 -22
- pixeltable/exprs/literal.py +27 -5
- pixeltable/exprs/method_ref.py +2 -2
- pixeltable/exprs/object_ref.py +2 -2
- pixeltable/exprs/row_builder.py +167 -67
- pixeltable/exprs/rowid_ref.py +25 -10
- pixeltable/exprs/similarity_expr.py +58 -40
- pixeltable/exprs/sql_element_cache.py +4 -4
- pixeltable/exprs/string_op.py +5 -5
- pixeltable/exprs/type_cast.py +3 -5
- pixeltable/func/__init__.py +1 -0
- pixeltable/func/aggregate_function.py +8 -8
- pixeltable/func/callable_function.py +9 -9
- pixeltable/func/expr_template_function.py +17 -11
- pixeltable/func/function.py +18 -20
- pixeltable/func/function_registry.py +6 -7
- pixeltable/func/globals.py +2 -3
- pixeltable/func/mcp.py +74 -0
- pixeltable/func/query_template_function.py +29 -27
- pixeltable/func/signature.py +46 -19
- pixeltable/func/tools.py +31 -13
- pixeltable/func/udf.py +18 -20
- pixeltable/functions/__init__.py +16 -0
- pixeltable/functions/anthropic.py +123 -77
- pixeltable/functions/audio.py +147 -10
- pixeltable/functions/bedrock.py +13 -6
- pixeltable/functions/date.py +7 -4
- pixeltable/functions/deepseek.py +35 -43
- pixeltable/functions/document.py +81 -0
- pixeltable/functions/fal.py +76 -0
- pixeltable/functions/fireworks.py +11 -20
- pixeltable/functions/gemini.py +195 -39
- pixeltable/functions/globals.py +142 -14
- pixeltable/functions/groq.py +108 -0
- pixeltable/functions/huggingface.py +1056 -24
- pixeltable/functions/image.py +115 -57
- pixeltable/functions/json.py +1 -1
- pixeltable/functions/llama_cpp.py +28 -13
- pixeltable/functions/math.py +67 -5
- pixeltable/functions/mistralai.py +18 -55
- pixeltable/functions/net.py +70 -0
- pixeltable/functions/ollama.py +20 -13
- pixeltable/functions/openai.py +240 -226
- pixeltable/functions/openrouter.py +143 -0
- pixeltable/functions/replicate.py +4 -4
- pixeltable/functions/reve.py +250 -0
- pixeltable/functions/string.py +239 -69
- pixeltable/functions/timestamp.py +16 -16
- pixeltable/functions/together.py +24 -84
- pixeltable/functions/twelvelabs.py +188 -0
- pixeltable/functions/util.py +6 -1
- pixeltable/functions/uuid.py +30 -0
- pixeltable/functions/video.py +1515 -107
- pixeltable/functions/vision.py +8 -8
- pixeltable/functions/voyageai.py +289 -0
- pixeltable/functions/whisper.py +16 -8
- pixeltable/functions/whisperx.py +179 -0
- pixeltable/{ext/functions → functions}/yolox.py +2 -4
- pixeltable/globals.py +362 -115
- pixeltable/index/base.py +17 -21
- pixeltable/index/btree.py +28 -22
- pixeltable/index/embedding_index.py +100 -118
- pixeltable/io/__init__.py +4 -2
- pixeltable/io/datarows.py +8 -7
- pixeltable/io/external_store.py +56 -105
- pixeltable/io/fiftyone.py +13 -13
- pixeltable/io/globals.py +31 -30
- pixeltable/io/hf_datasets.py +61 -16
- pixeltable/io/label_studio.py +74 -70
- pixeltable/io/lancedb.py +3 -0
- pixeltable/io/pandas.py +21 -12
- pixeltable/io/parquet.py +25 -105
- pixeltable/io/table_data_conduit.py +250 -123
- pixeltable/io/utils.py +4 -4
- pixeltable/iterators/__init__.py +2 -1
- pixeltable/iterators/audio.py +26 -25
- pixeltable/iterators/base.py +9 -3
- pixeltable/iterators/document.py +112 -78
- pixeltable/iterators/image.py +12 -15
- pixeltable/iterators/string.py +11 -4
- pixeltable/iterators/video.py +523 -120
- pixeltable/metadata/__init__.py +14 -3
- pixeltable/metadata/converters/convert_13.py +2 -2
- pixeltable/metadata/converters/convert_18.py +2 -2
- pixeltable/metadata/converters/convert_19.py +2 -2
- pixeltable/metadata/converters/convert_20.py +2 -2
- pixeltable/metadata/converters/convert_21.py +2 -2
- pixeltable/metadata/converters/convert_22.py +2 -2
- pixeltable/metadata/converters/convert_24.py +2 -2
- pixeltable/metadata/converters/convert_25.py +2 -2
- pixeltable/metadata/converters/convert_26.py +2 -2
- pixeltable/metadata/converters/convert_29.py +4 -4
- pixeltable/metadata/converters/convert_30.py +34 -21
- pixeltable/metadata/converters/convert_34.py +2 -2
- pixeltable/metadata/converters/convert_35.py +9 -0
- pixeltable/metadata/converters/convert_36.py +38 -0
- pixeltable/metadata/converters/convert_37.py +15 -0
- pixeltable/metadata/converters/convert_38.py +39 -0
- pixeltable/metadata/converters/convert_39.py +124 -0
- pixeltable/metadata/converters/convert_40.py +73 -0
- pixeltable/metadata/converters/convert_41.py +12 -0
- pixeltable/metadata/converters/convert_42.py +9 -0
- pixeltable/metadata/converters/convert_43.py +44 -0
- pixeltable/metadata/converters/util.py +20 -31
- pixeltable/metadata/notes.py +9 -0
- pixeltable/metadata/schema.py +140 -53
- pixeltable/metadata/utils.py +74 -0
- pixeltable/mypy/__init__.py +3 -0
- pixeltable/mypy/mypy_plugin.py +123 -0
- pixeltable/plan.py +382 -115
- pixeltable/share/__init__.py +1 -1
- pixeltable/share/packager.py +547 -83
- pixeltable/share/protocol/__init__.py +33 -0
- pixeltable/share/protocol/common.py +165 -0
- pixeltable/share/protocol/operation_types.py +33 -0
- pixeltable/share/protocol/replica.py +119 -0
- pixeltable/share/publish.py +257 -59
- pixeltable/store.py +311 -194
- pixeltable/type_system.py +373 -211
- pixeltable/utils/__init__.py +2 -3
- pixeltable/utils/arrow.py +131 -17
- pixeltable/utils/av.py +298 -0
- pixeltable/utils/azure_store.py +346 -0
- pixeltable/utils/coco.py +6 -6
- pixeltable/utils/code.py +3 -3
- pixeltable/utils/console_output.py +4 -1
- pixeltable/utils/coroutine.py +6 -23
- pixeltable/utils/dbms.py +32 -6
- pixeltable/utils/description_helper.py +4 -5
- pixeltable/utils/documents.py +7 -18
- pixeltable/utils/exception_handler.py +7 -30
- pixeltable/utils/filecache.py +6 -6
- pixeltable/utils/formatter.py +86 -48
- pixeltable/utils/gcs_store.py +295 -0
- pixeltable/utils/http.py +133 -0
- pixeltable/utils/http_server.py +2 -3
- pixeltable/utils/iceberg.py +1 -2
- pixeltable/utils/image.py +17 -0
- pixeltable/utils/lancedb.py +90 -0
- pixeltable/utils/local_store.py +322 -0
- pixeltable/utils/misc.py +5 -0
- pixeltable/utils/object_stores.py +573 -0
- pixeltable/utils/pydantic.py +60 -0
- pixeltable/utils/pytorch.py +5 -6
- pixeltable/utils/s3_store.py +527 -0
- pixeltable/utils/sql.py +26 -0
- pixeltable/utils/system.py +30 -0
- pixeltable-0.5.7.dist-info/METADATA +579 -0
- pixeltable-0.5.7.dist-info/RECORD +227 -0
- {pixeltable-0.3.14.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
- pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
- pixeltable/__version__.py +0 -3
- pixeltable/catalog/named_function.py +0 -40
- pixeltable/ext/__init__.py +0 -17
- pixeltable/ext/functions/__init__.py +0 -11
- pixeltable/ext/functions/whisperx.py +0 -77
- pixeltable/utils/media_store.py +0 -77
- pixeltable/utils/s3.py +0 -17
- pixeltable-0.3.14.dist-info/METADATA +0 -434
- pixeltable-0.3.14.dist-info/RECORD +0 -186
- pixeltable-0.3.14.dist-info/entry_points.txt +0 -3
- {pixeltable-0.3.14.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
pixeltable/functions/openai.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
"""
|
|
2
|
-
Pixeltable
|
|
2
|
+
Pixeltable UDFs
|
|
3
3
|
that wrap various endpoints from the OpenAI API. In order to use them, you must
|
|
4
4
|
first `pip install openai` and configure your OpenAI credentials, as described in
|
|
5
|
-
the [Working with OpenAI](https://pixeltable.
|
|
5
|
+
the [Working with OpenAI](https://docs.pixeltable.com/notebooks/integrations/working-with-openai) tutorial.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
import base64
|
|
@@ -13,18 +13,19 @@ import logging
|
|
|
13
13
|
import math
|
|
14
14
|
import pathlib
|
|
15
15
|
import re
|
|
16
|
-
import
|
|
17
|
-
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Type, TypeVar, Union, cast
|
|
16
|
+
from typing import TYPE_CHECKING, Any, Callable, Type
|
|
18
17
|
|
|
19
18
|
import httpx
|
|
20
19
|
import numpy as np
|
|
21
20
|
import PIL
|
|
22
21
|
|
|
23
22
|
import pixeltable as pxt
|
|
24
|
-
import
|
|
25
|
-
from pixeltable import
|
|
23
|
+
from pixeltable import env, exprs, type_system as ts
|
|
24
|
+
from pixeltable.config import Config
|
|
26
25
|
from pixeltable.func import Batch, Tools
|
|
27
26
|
from pixeltable.utils.code import local_public_names
|
|
27
|
+
from pixeltable.utils.local_store import TempStore
|
|
28
|
+
from pixeltable.utils.system import set_file_descriptor_limit
|
|
28
29
|
|
|
29
30
|
if TYPE_CHECKING:
|
|
30
31
|
import openai
|
|
@@ -33,13 +34,28 @@ _logger = logging.getLogger('pixeltable')
|
|
|
33
34
|
|
|
34
35
|
|
|
35
36
|
@env.register_client('openai')
|
|
36
|
-
def _(api_key: str) -> 'openai.AsyncOpenAI':
|
|
37
|
+
def _(api_key: str, base_url: str | None = None, api_version: str | None = None) -> 'openai.AsyncOpenAI':
|
|
37
38
|
import openai
|
|
38
39
|
|
|
40
|
+
max_connections = Config.get().get_int_value('openai.max_connections') or 2000
|
|
41
|
+
max_keepalive_connections = Config.get().get_int_value('openai.max_keepalive_connections') or 100
|
|
42
|
+
set_file_descriptor_limit(max_connections * 2)
|
|
43
|
+
default_query = None if api_version is None else {'api-version': api_version}
|
|
44
|
+
|
|
45
|
+
# Pixeltable scheduler's retry logic takes into account the rate limit-related response headers, so in theory we can
|
|
46
|
+
# benefit from disabling retries in the OpenAI client (max_retries=0). However to do that, we need to get smarter
|
|
47
|
+
# about idempotency keys and possibly more.
|
|
39
48
|
return openai.AsyncOpenAI(
|
|
40
49
|
api_key=api_key,
|
|
50
|
+
base_url=base_url,
|
|
51
|
+
default_query=default_query,
|
|
41
52
|
# recommended to increase limits for async client to avoid connection errors
|
|
42
|
-
http_client=httpx.AsyncClient(
|
|
53
|
+
http_client=httpx.AsyncClient(
|
|
54
|
+
limits=httpx.Limits(max_keepalive_connections=max_keepalive_connections, max_connections=max_connections),
|
|
55
|
+
# HTTP1 tends to perform better on this kind of workloads
|
|
56
|
+
http2=False,
|
|
57
|
+
http1=True,
|
|
58
|
+
),
|
|
43
59
|
)
|
|
44
60
|
|
|
45
61
|
|
|
@@ -89,6 +105,99 @@ def _rate_limits_pool(model: str) -> str:
|
|
|
89
105
|
return f'rate-limits:openai:{model}'
|
|
90
106
|
|
|
91
107
|
|
|
108
|
+
def _parse_header_duration(duration_str: str) -> float | None:
|
|
109
|
+
"""Parses the value of x-ratelimit-reset-* header into seconds.
|
|
110
|
+
|
|
111
|
+
Returns None if the input cannot be parsed.
|
|
112
|
+
|
|
113
|
+
Real life examples of header values:
|
|
114
|
+
* '1m33.792s'
|
|
115
|
+
* '857ms'
|
|
116
|
+
* '0s'
|
|
117
|
+
* '47.874s'
|
|
118
|
+
* '156h58m48.601s'
|
|
119
|
+
"""
|
|
120
|
+
if duration_str is None or duration_str.strip() == '':
|
|
121
|
+
return None
|
|
122
|
+
units = {
|
|
123
|
+
86400: r'(\d+)d', # days
|
|
124
|
+
3600: r'(\d+)h', # hours
|
|
125
|
+
60: r'(\d+)m(?:[^s]|$)', # minutes
|
|
126
|
+
1: r'([\d.]+)s', # seconds
|
|
127
|
+
0.001: r'(\d+)ms', # millis
|
|
128
|
+
}
|
|
129
|
+
seconds = None
|
|
130
|
+
for unit_value, pattern in units.items():
|
|
131
|
+
match = re.search(pattern, duration_str)
|
|
132
|
+
if match:
|
|
133
|
+
seconds = seconds or 0.0
|
|
134
|
+
seconds += float(match.group(1)) * unit_value
|
|
135
|
+
_logger.debug(f'Parsed duration header value "{duration_str}" into {seconds} seconds')
|
|
136
|
+
return seconds
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _get_header_info(
|
|
140
|
+
headers: httpx.Headers,
|
|
141
|
+
) -> tuple[tuple[int, int, datetime.datetime] | None, tuple[int, int, datetime.datetime] | None]:
|
|
142
|
+
"""Parses rate limit related headers"""
|
|
143
|
+
# Requests and project-requests are two separate limits of requests per minute. project-requests headers will be
|
|
144
|
+
# present if an RPM limit is configured on the project limit.
|
|
145
|
+
requests_info = _get_resource_info(headers, 'requests')
|
|
146
|
+
requests_fraction_remaining = _fract_remaining(requests_info)
|
|
147
|
+
project_requests_info = _get_resource_info(headers, 'project-requests')
|
|
148
|
+
project_requests_fraction_remaining = _fract_remaining(project_requests_info)
|
|
149
|
+
|
|
150
|
+
# If both limit infos are present, pick the one with the least percentage remaining
|
|
151
|
+
best_requests_info = requests_info or project_requests_info
|
|
152
|
+
if (
|
|
153
|
+
requests_fraction_remaining is not None
|
|
154
|
+
and project_requests_fraction_remaining is not None
|
|
155
|
+
and project_requests_fraction_remaining < requests_fraction_remaining
|
|
156
|
+
):
|
|
157
|
+
best_requests_info = project_requests_info
|
|
158
|
+
|
|
159
|
+
# Same story with tokens
|
|
160
|
+
tokens_info = _get_resource_info(headers, 'tokens')
|
|
161
|
+
tokens_fraction_remaining = _fract_remaining(tokens_info)
|
|
162
|
+
project_tokens_info = _get_resource_info(headers, 'project-tokens')
|
|
163
|
+
project_tokens_fraction_remaining = _fract_remaining(project_tokens_info)
|
|
164
|
+
|
|
165
|
+
best_tokens_info = tokens_info or project_tokens_info
|
|
166
|
+
if (
|
|
167
|
+
tokens_fraction_remaining is not None
|
|
168
|
+
and project_tokens_fraction_remaining is not None
|
|
169
|
+
and project_tokens_fraction_remaining < tokens_fraction_remaining
|
|
170
|
+
):
|
|
171
|
+
best_tokens_info = project_tokens_info
|
|
172
|
+
|
|
173
|
+
if best_requests_info is None or best_tokens_info is None:
|
|
174
|
+
_logger.debug(f'get_header_info(): incomplete rate limit info: {headers}')
|
|
175
|
+
|
|
176
|
+
return best_requests_info, best_tokens_info
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _get_resource_info(headers: httpx.Headers, resource: str) -> tuple[int, int, datetime.datetime] | None:
|
|
180
|
+
remaining_str = headers.get(f'x-ratelimit-remaining-{resource}')
|
|
181
|
+
if remaining_str is None:
|
|
182
|
+
return None
|
|
183
|
+
remaining = int(remaining_str)
|
|
184
|
+
limit_str = headers.get(f'x-ratelimit-limit-{resource}')
|
|
185
|
+
limit = int(limit_str) if limit_str is not None else None
|
|
186
|
+
reset_str = headers.get(f'x-ratelimit-reset-{resource}')
|
|
187
|
+
reset_in_seconds = _parse_header_duration(reset_str) or 5.0 # Default to 5 seconds
|
|
188
|
+
reset_ts = datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(seconds=reset_in_seconds)
|
|
189
|
+
return (limit, remaining, reset_ts)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _fract_remaining(resource_info: tuple[int, int, datetime.datetime] | None) -> float | None:
|
|
193
|
+
if resource_info is None:
|
|
194
|
+
return None
|
|
195
|
+
limit, remaining, _ = resource_info
|
|
196
|
+
if limit is None or remaining is None:
|
|
197
|
+
return None
|
|
198
|
+
return remaining / limit
|
|
199
|
+
|
|
200
|
+
|
|
92
201
|
class OpenAIRateLimitsInfo(env.RateLimitsInfo):
|
|
93
202
|
retryable_errors: tuple[Type[Exception], ...]
|
|
94
203
|
|
|
@@ -109,61 +218,36 @@ class OpenAIRateLimitsInfo(env.RateLimitsInfo):
|
|
|
109
218
|
openai.InternalServerError,
|
|
110
219
|
)
|
|
111
220
|
|
|
112
|
-
def
|
|
221
|
+
def record_exc(self, request_ts: datetime.datetime, exc: Exception) -> None:
|
|
113
222
|
import openai
|
|
114
223
|
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
return 1.0
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
# RE pattern for duration in '*-reset' headers;
|
|
122
|
-
# examples: 1d2h3ms, 4m5.6s; # fractional seconds can be reported as 0.5s or 500ms
|
|
123
|
-
_header_duration_pattern = re.compile(r'(?:(\d+)d)?(?:(\d+)h)?(?:(\d+)ms)|(?:(\d+)m)?(?:([\d.]+)s)?')
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
def _parse_header_duration(duration_str: str) -> datetime.timedelta:
|
|
127
|
-
match = _header_duration_pattern.match(duration_str)
|
|
128
|
-
if not match:
|
|
129
|
-
raise ValueError('Invalid duration format')
|
|
224
|
+
_ = isinstance(exc, openai.APIError)
|
|
225
|
+
if not isinstance(exc, openai.APIError) or not hasattr(exc, 'response') or not hasattr(exc.response, 'headers'):
|
|
226
|
+
return
|
|
130
227
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
228
|
+
requests_info, tokens_info = _get_header_info(exc.response.headers)
|
|
229
|
+
_logger.debug(
|
|
230
|
+
f'record_exc(): request_ts: {request_ts}, requests_info={requests_info} tokens_info={tokens_info}'
|
|
231
|
+
)
|
|
232
|
+
self.record(request_ts=request_ts, requests=requests_info, tokens=tokens_info)
|
|
233
|
+
self.has_exc = True
|
|
136
234
|
|
|
137
|
-
|
|
235
|
+
def _retry_delay_from_exception(self, exc: Exception) -> float | None:
|
|
236
|
+
try:
|
|
237
|
+
retry_after_str = exc.response.headers.get('retry-after') # type: ignore
|
|
238
|
+
except AttributeError:
|
|
239
|
+
return None
|
|
240
|
+
if retry_after_str is not None and re.fullmatch(r'\d{1,4}', retry_after_str):
|
|
241
|
+
return float(retry_after_str)
|
|
242
|
+
return None
|
|
138
243
|
|
|
244
|
+
def get_retry_delay(self, exc: Exception, attempt: int) -> float | None:
|
|
245
|
+
import openai
|
|
139
246
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
|
145
|
-
|
|
146
|
-
requests_info: Optional[tuple[int, int, datetime.datetime]] = None
|
|
147
|
-
if requests:
|
|
148
|
-
requests_limit_str = headers.get('x-ratelimit-limit-requests')
|
|
149
|
-
requests_limit = int(requests_limit_str) if requests_limit_str is not None else None
|
|
150
|
-
requests_remaining_str = headers.get('x-ratelimit-remaining-requests')
|
|
151
|
-
requests_remaining = int(requests_remaining_str) if requests_remaining_str is not None else None
|
|
152
|
-
requests_reset_str = headers.get('x-ratelimit-reset-requests')
|
|
153
|
-
requests_reset_ts = now + _parse_header_duration(requests_reset_str)
|
|
154
|
-
requests_info = (requests_limit, requests_remaining, requests_reset_ts)
|
|
155
|
-
|
|
156
|
-
tokens_info: Optional[tuple[int, int, datetime.datetime]] = None
|
|
157
|
-
if tokens:
|
|
158
|
-
tokens_limit_str = headers.get('x-ratelimit-limit-tokens')
|
|
159
|
-
tokens_limit = int(tokens_limit_str) if tokens_limit_str is not None else None
|
|
160
|
-
tokens_remaining_str = headers.get('x-ratelimit-remaining-tokens')
|
|
161
|
-
tokens_remaining = int(tokens_remaining_str) if tokens_remaining_str is not None else None
|
|
162
|
-
tokens_reset_str = headers.get('x-ratelimit-reset-tokens')
|
|
163
|
-
tokens_reset_ts = now + _parse_header_duration(tokens_reset_str)
|
|
164
|
-
tokens_info = (tokens_limit, tokens_remaining, tokens_reset_ts)
|
|
165
|
-
|
|
166
|
-
return requests_info, tokens_info
|
|
247
|
+
if not isinstance(exc, self.retryable_errors):
|
|
248
|
+
return None
|
|
249
|
+
assert isinstance(exc, openai.APIError)
|
|
250
|
+
return self._retry_delay_from_exception(exc) or super().get_retry_delay(exc, attempt)
|
|
167
251
|
|
|
168
252
|
|
|
169
253
|
#####################################
|
|
@@ -171,15 +255,7 @@ def _get_header_info(
|
|
|
171
255
|
|
|
172
256
|
|
|
173
257
|
@pxt.udf
|
|
174
|
-
async def speech(
|
|
175
|
-
input: str,
|
|
176
|
-
*,
|
|
177
|
-
model: str,
|
|
178
|
-
voice: str,
|
|
179
|
-
response_format: Optional[str] = None,
|
|
180
|
-
speed: Optional[float] = None,
|
|
181
|
-
timeout: Optional[float] = None,
|
|
182
|
-
) -> pxt.Audio:
|
|
258
|
+
async def speech(input: str, *, model: str, voice: str, model_kwargs: dict[str, Any] | None = None) -> pxt.Audio:
|
|
183
259
|
"""
|
|
184
260
|
Generates audio from the input text.
|
|
185
261
|
|
|
@@ -199,8 +275,8 @@ async def speech(
|
|
|
199
275
|
model: The model to use for speech synthesis.
|
|
200
276
|
voice: The voice profile to use for speech synthesis. Supported options include:
|
|
201
277
|
`alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`.
|
|
202
|
-
|
|
203
|
-
|
|
278
|
+
model_kwargs: Additional keyword args for the OpenAI `audio/speech` API. For details on the available
|
|
279
|
+
parameters, see: <https://platform.openai.com/docs/api-reference/audio/createSpeech>
|
|
204
280
|
|
|
205
281
|
Returns:
|
|
206
282
|
An audio file containing the synthesized speech.
|
|
@@ -211,30 +287,18 @@ async def speech(
|
|
|
211
287
|
|
|
212
288
|
>>> tbl.add_computed_column(audio=speech(tbl.text, model='tts-1', voice='nova'))
|
|
213
289
|
"""
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
timeout=_opt(timeout),
|
|
221
|
-
)
|
|
222
|
-
ext = response_format or 'mp3'
|
|
223
|
-
output_filename = str(env.Env.get().tmp_dir / f'{uuid.uuid4()}.{ext}')
|
|
290
|
+
if model_kwargs is None:
|
|
291
|
+
model_kwargs = {}
|
|
292
|
+
|
|
293
|
+
content = await _openai_client().audio.speech.create(input=input, model=model, voice=voice, **model_kwargs)
|
|
294
|
+
ext = model_kwargs.get('response_format', 'mp3')
|
|
295
|
+
output_filename = str(TempStore.create_path(extension=f'.{ext}'))
|
|
224
296
|
content.write_to_file(output_filename)
|
|
225
297
|
return output_filename
|
|
226
298
|
|
|
227
299
|
|
|
228
300
|
@pxt.udf
|
|
229
|
-
async def transcriptions(
|
|
230
|
-
audio: pxt.Audio,
|
|
231
|
-
*,
|
|
232
|
-
model: str,
|
|
233
|
-
language: Optional[str] = None,
|
|
234
|
-
prompt: Optional[str] = None,
|
|
235
|
-
temperature: Optional[float] = None,
|
|
236
|
-
timeout: Optional[float] = None,
|
|
237
|
-
) -> dict:
|
|
301
|
+
async def transcriptions(audio: pxt.Audio, *, model: str, model_kwargs: dict[str, Any] | None = None) -> dict:
|
|
238
302
|
"""
|
|
239
303
|
Transcribes audio into the input language.
|
|
240
304
|
|
|
@@ -252,8 +316,8 @@ async def transcriptions(
|
|
|
252
316
|
Args:
|
|
253
317
|
audio: The audio to transcribe.
|
|
254
318
|
model: The model to use for speech transcription.
|
|
255
|
-
|
|
256
|
-
|
|
319
|
+
model_kwargs: Additional keyword args for the OpenAI `audio/transcriptions` API. For details on the available
|
|
320
|
+
parameters, see: <https://platform.openai.com/docs/api-reference/audio/createTranscription>
|
|
257
321
|
|
|
258
322
|
Returns:
|
|
259
323
|
A dictionary containing the transcription and other metadata.
|
|
@@ -264,27 +328,16 @@ async def transcriptions(
|
|
|
264
328
|
|
|
265
329
|
>>> tbl.add_computed_column(transcription=transcriptions(tbl.audio, model='whisper-1', language='en'))
|
|
266
330
|
"""
|
|
331
|
+
if model_kwargs is None:
|
|
332
|
+
model_kwargs = {}
|
|
333
|
+
|
|
267
334
|
file = pathlib.Path(audio)
|
|
268
|
-
transcription = await _openai_client().audio.transcriptions.create(
|
|
269
|
-
file=file,
|
|
270
|
-
model=model,
|
|
271
|
-
language=_opt(language),
|
|
272
|
-
prompt=_opt(prompt),
|
|
273
|
-
temperature=_opt(temperature),
|
|
274
|
-
timeout=_opt(timeout),
|
|
275
|
-
)
|
|
335
|
+
transcription = await _openai_client().audio.transcriptions.create(file=file, model=model, **model_kwargs)
|
|
276
336
|
return transcription.dict()
|
|
277
337
|
|
|
278
338
|
|
|
279
339
|
@pxt.udf
|
|
280
|
-
async def translations(
|
|
281
|
-
audio: pxt.Audio,
|
|
282
|
-
*,
|
|
283
|
-
model: str,
|
|
284
|
-
prompt: Optional[str] = None,
|
|
285
|
-
temperature: Optional[float] = None,
|
|
286
|
-
timeout: Optional[float] = None,
|
|
287
|
-
) -> dict:
|
|
340
|
+
async def translations(audio: pxt.Audio, *, model: str, model_kwargs: dict[str, Any] | None = None) -> dict:
|
|
288
341
|
"""
|
|
289
342
|
Translates audio into English.
|
|
290
343
|
|
|
@@ -302,8 +355,8 @@ async def translations(
|
|
|
302
355
|
Args:
|
|
303
356
|
audio: The audio to translate.
|
|
304
357
|
model: The model to use for speech transcription and translation.
|
|
305
|
-
|
|
306
|
-
|
|
358
|
+
model_kwargs: Additional keyword args for the OpenAI `audio/translations` API. For details on the available
|
|
359
|
+
parameters, see: <https://platform.openai.com/docs/api-reference/audio/createTranslation>
|
|
307
360
|
|
|
308
361
|
Returns:
|
|
309
362
|
A dictionary containing the translation and other metadata.
|
|
@@ -314,10 +367,11 @@ async def translations(
|
|
|
314
367
|
|
|
315
368
|
>>> tbl.add_computed_column(translation=translations(tbl.audio, model='whisper-1', language='en'))
|
|
316
369
|
"""
|
|
370
|
+
if model_kwargs is None:
|
|
371
|
+
model_kwargs = {}
|
|
372
|
+
|
|
317
373
|
file = pathlib.Path(audio)
|
|
318
|
-
translation = await _openai_client().audio.translations.create(
|
|
319
|
-
file=file, model=model, prompt=_opt(prompt), temperature=_opt(temperature), timeout=_opt(timeout)
|
|
320
|
-
)
|
|
374
|
+
translation = await _openai_client().audio.translations.create(file=file, model=model, **model_kwargs)
|
|
321
375
|
return translation.dict()
|
|
322
376
|
|
|
323
377
|
|
|
@@ -353,8 +407,15 @@ def _is_model_family(model: str, family: str) -> bool:
|
|
|
353
407
|
|
|
354
408
|
|
|
355
409
|
def _chat_completions_get_request_resources(
|
|
356
|
-
messages: list, model: str,
|
|
410
|
+
messages: list, model: str, model_kwargs: dict[str, Any] | None
|
|
357
411
|
) -> dict[str, int]:
|
|
412
|
+
if model_kwargs is None:
|
|
413
|
+
model_kwargs = {}
|
|
414
|
+
|
|
415
|
+
max_completion_tokens = model_kwargs.get('max_completion_tokens')
|
|
416
|
+
max_tokens = model_kwargs.get('max_tokens')
|
|
417
|
+
n = model_kwargs.get('n')
|
|
418
|
+
|
|
358
419
|
completion_tokens = (n or 1) * (max_completion_tokens or max_tokens or _default_max_tokens(model))
|
|
359
420
|
|
|
360
421
|
num_tokens = 0.0
|
|
@@ -373,24 +434,10 @@ async def chat_completions(
|
|
|
373
434
|
messages: list,
|
|
374
435
|
*,
|
|
375
436
|
model: str,
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
max_completion_tokens: Optional[int] = None,
|
|
381
|
-
max_tokens: Optional[int] = None,
|
|
382
|
-
n: Optional[int] = None,
|
|
383
|
-
presence_penalty: Optional[float] = None,
|
|
384
|
-
reasoning_effort: Optional[Literal['low', 'medium', 'high']] = None,
|
|
385
|
-
response_format: Optional[dict] = None,
|
|
386
|
-
seed: Optional[int] = None,
|
|
387
|
-
stop: Optional[list[str]] = None,
|
|
388
|
-
temperature: Optional[float] = None,
|
|
389
|
-
tools: Optional[list[dict]] = None,
|
|
390
|
-
tool_choice: Optional[dict] = None,
|
|
391
|
-
top_p: Optional[float] = None,
|
|
392
|
-
user: Optional[str] = None,
|
|
393
|
-
timeout: Optional[float] = None,
|
|
437
|
+
model_kwargs: dict[str, Any] | None = None,
|
|
438
|
+
tools: list[dict[str, Any]] | None = None,
|
|
439
|
+
tool_choice: dict[str, Any] | None = None,
|
|
440
|
+
_runtime_ctx: env.RuntimeCtx | None = None,
|
|
394
441
|
) -> dict:
|
|
395
442
|
"""
|
|
396
443
|
Creates a model response for the given chat conversation.
|
|
@@ -409,8 +456,8 @@ async def chat_completions(
|
|
|
409
456
|
Args:
|
|
410
457
|
messages: A list of messages to use for chat completion, as described in the OpenAI API documentation.
|
|
411
458
|
model: The model to use for chat completion.
|
|
412
|
-
|
|
413
|
-
|
|
459
|
+
model_kwargs: Additional keyword args for the OpenAI `chat/completions` API. For details on the available
|
|
460
|
+
parameters, see: <https://platform.openai.com/docs/api-reference/chat/create>
|
|
414
461
|
|
|
415
462
|
Returns:
|
|
416
463
|
A dictionary containing the response and other metadata.
|
|
@@ -420,27 +467,28 @@ async def chat_completions(
|
|
|
420
467
|
of the table `tbl`:
|
|
421
468
|
|
|
422
469
|
>>> messages = [
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
470
|
+
... {'role': 'system', 'content': 'You are a helpful assistant.'},
|
|
471
|
+
... {'role': 'user', 'content': tbl.prompt}
|
|
472
|
+
... ]
|
|
473
|
+
>>> tbl.add_computed_column(response=chat_completions(messages, model='gpt-4o-mini'))
|
|
427
474
|
"""
|
|
475
|
+
if model_kwargs is None:
|
|
476
|
+
model_kwargs = {}
|
|
477
|
+
|
|
428
478
|
if tools is not None:
|
|
429
|
-
tools = [{'type': 'function', 'function': tool} for tool in tools]
|
|
479
|
+
model_kwargs['tools'] = [{'type': 'function', 'function': tool} for tool in tools]
|
|
430
480
|
|
|
431
|
-
tool_choice_: Union[str, dict, None] = None
|
|
432
481
|
if tool_choice is not None:
|
|
433
482
|
if tool_choice['auto']:
|
|
434
|
-
|
|
483
|
+
model_kwargs['tool_choice'] = 'auto'
|
|
435
484
|
elif tool_choice['required']:
|
|
436
|
-
|
|
485
|
+
model_kwargs['tool_choice'] = 'required'
|
|
437
486
|
else:
|
|
438
487
|
assert tool_choice['tool'] is not None
|
|
439
|
-
|
|
488
|
+
model_kwargs['tool_choice'] = {'type': 'function', 'function': {'name': tool_choice['tool']}}
|
|
440
489
|
|
|
441
|
-
extra_body: Optional[dict[str, Any]] = None
|
|
442
490
|
if tool_choice is not None and not tool_choice['parallel_tool_calls']:
|
|
443
|
-
|
|
491
|
+
model_kwargs['parallel_tool_calls'] = False
|
|
444
492
|
|
|
445
493
|
# make sure the pool info exists prior to making the request
|
|
446
494
|
resource_pool = _rate_limits_pool(model)
|
|
@@ -448,45 +496,28 @@ async def chat_completions(
|
|
|
448
496
|
resource_pool, lambda: OpenAIRateLimitsInfo(_chat_completions_get_request_resources)
|
|
449
497
|
)
|
|
450
498
|
|
|
451
|
-
|
|
499
|
+
request_ts = datetime.datetime.now(tz=datetime.timezone.utc)
|
|
452
500
|
result = await _openai_client().chat.completions.with_raw_response.create(
|
|
453
|
-
messages=messages,
|
|
454
|
-
model=model,
|
|
455
|
-
frequency_penalty=_opt(frequency_penalty),
|
|
456
|
-
logit_bias=_opt(logit_bias),
|
|
457
|
-
logprobs=_opt(logprobs),
|
|
458
|
-
top_logprobs=_opt(top_logprobs),
|
|
459
|
-
max_completion_tokens=_opt(max_completion_tokens),
|
|
460
|
-
max_tokens=_opt(max_tokens),
|
|
461
|
-
n=_opt(n),
|
|
462
|
-
presence_penalty=_opt(presence_penalty),
|
|
463
|
-
reasoning_effort=_opt(reasoning_effort),
|
|
464
|
-
response_format=_opt(cast(Any, response_format)),
|
|
465
|
-
seed=_opt(seed),
|
|
466
|
-
stop=_opt(stop),
|
|
467
|
-
temperature=_opt(temperature),
|
|
468
|
-
tools=_opt(cast(Any, tools)),
|
|
469
|
-
tool_choice=_opt(cast(Any, tool_choice_)),
|
|
470
|
-
top_p=_opt(top_p),
|
|
471
|
-
user=_opt(user),
|
|
472
|
-
timeout=_opt(timeout),
|
|
473
|
-
extra_body=extra_body,
|
|
501
|
+
messages=messages, model=model, **model_kwargs
|
|
474
502
|
)
|
|
475
503
|
|
|
476
504
|
requests_info, tokens_info = _get_header_info(result.headers)
|
|
477
|
-
|
|
505
|
+
is_retry = _runtime_ctx is not None and _runtime_ctx.is_retry
|
|
506
|
+
rate_limits_info.record(request_ts=request_ts, requests=requests_info, tokens=tokens_info, reset_exc=is_retry)
|
|
478
507
|
|
|
479
508
|
return json.loads(result.text)
|
|
480
509
|
|
|
481
510
|
|
|
482
511
|
def _vision_get_request_resources(
|
|
483
|
-
prompt: str,
|
|
484
|
-
image: PIL.Image.Image,
|
|
485
|
-
model: str,
|
|
486
|
-
max_completion_tokens: Optional[int],
|
|
487
|
-
max_tokens: Optional[int],
|
|
488
|
-
n: Optional[int],
|
|
512
|
+
prompt: str, image: PIL.Image.Image, model: str, model_kwargs: dict[str, Any] | None = None
|
|
489
513
|
) -> dict[str, int]:
|
|
514
|
+
if model_kwargs is None:
|
|
515
|
+
model_kwargs = {}
|
|
516
|
+
|
|
517
|
+
max_completion_tokens = model_kwargs.get('max_completion_tokens')
|
|
518
|
+
max_tokens = model_kwargs.get('max_tokens')
|
|
519
|
+
n = model_kwargs.get('n')
|
|
520
|
+
|
|
490
521
|
completion_tokens = (n or 1) * (max_completion_tokens or max_tokens or _default_max_tokens(model))
|
|
491
522
|
prompt_tokens = len(prompt) / 4
|
|
492
523
|
|
|
@@ -519,10 +550,8 @@ async def vision(
|
|
|
519
550
|
image: PIL.Image.Image,
|
|
520
551
|
*,
|
|
521
552
|
model: str,
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
n: Optional[int] = 1,
|
|
525
|
-
timeout: Optional[float] = None,
|
|
553
|
+
model_kwargs: dict[str, Any] | None = None,
|
|
554
|
+
_runtime_ctx: env.RuntimeCtx | None = None,
|
|
526
555
|
) -> str:
|
|
527
556
|
"""
|
|
528
557
|
Analyzes an image with the OpenAI vision capability. This is a convenience function that takes an image and
|
|
@@ -552,6 +581,9 @@ async def vision(
|
|
|
552
581
|
|
|
553
582
|
>>> tbl.add_computed_column(response=vision("What's in this image?", tbl.image, model='gpt-4o-mini'))
|
|
554
583
|
"""
|
|
584
|
+
if model_kwargs is None:
|
|
585
|
+
model_kwargs = {}
|
|
586
|
+
|
|
555
587
|
# TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
|
|
556
588
|
bytes_arr = io.BytesIO()
|
|
557
589
|
image.save(bytes_arr, format='png')
|
|
@@ -573,17 +605,17 @@ async def vision(
|
|
|
573
605
|
resource_pool, lambda: OpenAIRateLimitsInfo(_vision_get_request_resources)
|
|
574
606
|
)
|
|
575
607
|
|
|
608
|
+
request_ts = datetime.datetime.now(tz=datetime.timezone.utc)
|
|
576
609
|
result = await _openai_client().chat.completions.with_raw_response.create(
|
|
577
610
|
messages=messages, # type: ignore
|
|
578
611
|
model=model,
|
|
579
|
-
|
|
580
|
-
max_tokens=_opt(max_tokens),
|
|
581
|
-
n=_opt(n),
|
|
582
|
-
timeout=_opt(timeout),
|
|
612
|
+
**model_kwargs,
|
|
583
613
|
)
|
|
584
614
|
|
|
615
|
+
# _logger.debug(f'vision(): headers={result.headers}')
|
|
585
616
|
requests_info, tokens_info = _get_header_info(result.headers)
|
|
586
|
-
|
|
617
|
+
is_retry = _runtime_ctx is not None and _runtime_ctx.is_retry
|
|
618
|
+
rate_limits_info.record(request_ts=request_ts, requests=requests_info, tokens=tokens_info, reset_exc=is_retry)
|
|
587
619
|
|
|
588
620
|
result = json.loads(result.text)
|
|
589
621
|
return result['choices'][0]['message']['content']
|
|
@@ -609,9 +641,8 @@ async def embeddings(
|
|
|
609
641
|
input: Batch[str],
|
|
610
642
|
*,
|
|
611
643
|
model: str,
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
timeout: Optional[float] = None,
|
|
644
|
+
model_kwargs: dict[str, Any] | None = None,
|
|
645
|
+
_runtime_ctx: env.RuntimeCtx | None = None,
|
|
615
646
|
) -> Batch[pxt.Array[(None,), pxt.Float]]:
|
|
616
647
|
"""
|
|
617
648
|
Creates an embedding vector representing the input text.
|
|
@@ -630,10 +661,8 @@ async def embeddings(
|
|
|
630
661
|
Args:
|
|
631
662
|
input: The text to embed.
|
|
632
663
|
model: The model to use for the embedding.
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
For details on the other parameters, see: <https://platform.openai.com/docs/api-reference/embeddings>
|
|
664
|
+
model_kwargs: Additional keyword args for the OpenAI `embeddings` API. For details on the available
|
|
665
|
+
parameters, see: <https://platform.openai.com/docs/api-reference/embeddings>
|
|
637
666
|
|
|
638
667
|
Returns:
|
|
639
668
|
An array representing the application of the given embedding to `input`.
|
|
@@ -648,26 +677,29 @@ async def embeddings(
|
|
|
648
677
|
|
|
649
678
|
>>> tbl.add_embedding_index(embedding=embeddings.using(model='text-embedding-3-small'))
|
|
650
679
|
"""
|
|
680
|
+
if model_kwargs is None:
|
|
681
|
+
model_kwargs = {}
|
|
682
|
+
|
|
651
683
|
_logger.debug(f'embeddings: batch_size={len(input)}')
|
|
652
684
|
resource_pool = _rate_limits_pool(model)
|
|
653
685
|
rate_limits_info = env.Env.get().get_resource_pool_info(
|
|
654
686
|
resource_pool, lambda: OpenAIRateLimitsInfo(_embeddings_get_request_resources)
|
|
655
687
|
)
|
|
688
|
+
request_ts = datetime.datetime.now(tz=datetime.timezone.utc)
|
|
656
689
|
result = await _openai_client().embeddings.with_raw_response.create(
|
|
657
|
-
input=input,
|
|
658
|
-
model=model,
|
|
659
|
-
dimensions=_opt(dimensions),
|
|
660
|
-
user=_opt(user),
|
|
661
|
-
encoding_format='float',
|
|
662
|
-
timeout=_opt(timeout),
|
|
690
|
+
input=input, model=model, encoding_format='float', **model_kwargs
|
|
663
691
|
)
|
|
664
692
|
requests_info, tokens_info = _get_header_info(result.headers)
|
|
665
|
-
|
|
693
|
+
is_retry = _runtime_ctx is not None and _runtime_ctx.is_retry
|
|
694
|
+
rate_limits_info.record(request_ts=request_ts, requests=requests_info, tokens=tokens_info, reset_exc=is_retry)
|
|
666
695
|
return [np.array(data['embedding'], dtype=np.float64) for data in json.loads(result.content)['data']]
|
|
667
696
|
|
|
668
697
|
|
|
669
698
|
@embeddings.conditional_return_type
|
|
670
|
-
def _(model: str, dimensions: Optional[int] = None) -> ts.ArrayType:
|
|
699
|
+
def _(model: str, model_kwargs: dict[str, Any] | None = None) -> ts.ArrayType:
|
|
700
|
+
dimensions: int | None = None
|
|
701
|
+
if model_kwargs is not None:
|
|
702
|
+
dimensions = model_kwargs.get('dimensions')
|
|
671
703
|
if dimensions is None:
|
|
672
704
|
if model not in _embedding_dimensions_cache:
|
|
673
705
|
# TODO: find some other way to retrieve a sample
|
|
@@ -682,14 +714,7 @@ def _(model: str, dimensions: Optional[int] = None) -> ts.ArrayType:
|
|
|
682
714
|
|
|
683
715
|
@pxt.udf
|
|
684
716
|
async def image_generations(
|
|
685
|
-
prompt: str,
|
|
686
|
-
*,
|
|
687
|
-
model: str = 'dall-e-2',
|
|
688
|
-
quality: Optional[str] = None,
|
|
689
|
-
size: Optional[str] = None,
|
|
690
|
-
style: Optional[str] = None,
|
|
691
|
-
user: Optional[str] = None,
|
|
692
|
-
timeout: Optional[float] = None,
|
|
717
|
+
prompt: str, *, model: str = 'dall-e-2', model_kwargs: dict[str, Any] | None = None
|
|
693
718
|
) -> PIL.Image.Image:
|
|
694
719
|
"""
|
|
695
720
|
Creates an image given a prompt.
|
|
@@ -708,8 +733,8 @@ async def image_generations(
|
|
|
708
733
|
Args:
|
|
709
734
|
prompt: Prompt for the image.
|
|
710
735
|
model: The model to use for the generations.
|
|
711
|
-
|
|
712
|
-
|
|
736
|
+
model_kwargs: Additional keyword args for the OpenAI `images/generations` API. For details on the available
|
|
737
|
+
parameters, see: <https://platform.openai.com/docs/api-reference/images/create>
|
|
713
738
|
|
|
714
739
|
Returns:
|
|
715
740
|
The generated image.
|
|
@@ -720,16 +745,12 @@ async def image_generations(
|
|
|
720
745
|
|
|
721
746
|
>>> tbl.add_computed_column(gen_image=image_generations(tbl.text, model='dall-e-2'))
|
|
722
747
|
"""
|
|
748
|
+
if model_kwargs is None:
|
|
749
|
+
model_kwargs = {}
|
|
750
|
+
|
|
723
751
|
# TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
|
|
724
752
|
result = await _openai_client().images.generate(
|
|
725
|
-
prompt=prompt,
|
|
726
|
-
model=_opt(model),
|
|
727
|
-
quality=_opt(quality), # type: ignore
|
|
728
|
-
size=_opt(size), # type: ignore
|
|
729
|
-
style=_opt(style), # type: ignore
|
|
730
|
-
user=_opt(user),
|
|
731
|
-
response_format='b64_json',
|
|
732
|
-
timeout=_opt(timeout),
|
|
753
|
+
prompt=prompt, model=model, response_format='b64_json', **model_kwargs
|
|
733
754
|
)
|
|
734
755
|
b64_str = result.data[0].b64_json
|
|
735
756
|
b64_bytes = base64.b64decode(b64_str)
|
|
@@ -739,9 +760,11 @@ async def image_generations(
|
|
|
739
760
|
|
|
740
761
|
|
|
741
762
|
@image_generations.conditional_return_type
|
|
742
|
-
def _(
|
|
743
|
-
if
|
|
763
|
+
def _(model_kwargs: dict[str, Any] | None = None) -> ts.ImageType:
|
|
764
|
+
if model_kwargs is None or 'size' not in model_kwargs:
|
|
765
|
+
# default size is 1024x1024
|
|
744
766
|
return ts.ImageType(size=(1024, 1024))
|
|
767
|
+
size = model_kwargs['size']
|
|
745
768
|
x_pos = size.find('x')
|
|
746
769
|
if x_pos == -1:
|
|
747
770
|
return ts.ImageType()
|
|
@@ -787,7 +810,7 @@ async def moderations(input: str, *, model: str = 'omni-moderation-latest') -> d
|
|
|
787
810
|
|
|
788
811
|
>>> tbl.add_computed_column(moderations=moderations(tbl.text, model='text-moderation-stable'))
|
|
789
812
|
"""
|
|
790
|
-
result = await _openai_client().moderations.create(input=input, model=
|
|
813
|
+
result = await _openai_client().moderations.create(input=input, model=model)
|
|
791
814
|
return result.dict()
|
|
792
815
|
|
|
793
816
|
|
|
@@ -813,7 +836,7 @@ def invoke_tools(tools: Tools, response: exprs.Expr) -> exprs.InlineDict:
|
|
|
813
836
|
|
|
814
837
|
|
|
815
838
|
@pxt.udf
|
|
816
|
-
def _openai_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
|
|
839
|
+
def _openai_response_to_pxt_tool_calls(response: dict) -> dict | None:
|
|
817
840
|
if 'tool_calls' not in response['choices'][0]['message'] or response['choices'][0]['message']['tool_calls'] is None:
|
|
818
841
|
return None
|
|
819
842
|
openai_tool_calls = response['choices'][0]['message']['tool_calls']
|
|
@@ -826,15 +849,6 @@ def _openai_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
|
|
|
826
849
|
return pxt_tool_calls
|
|
827
850
|
|
|
828
851
|
|
|
829
|
-
_T = TypeVar('_T')
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
def _opt(arg: _T) -> Union[_T, 'openai.NotGiven']:
|
|
833
|
-
import openai
|
|
834
|
-
|
|
835
|
-
return arg if arg is not None else openai.NOT_GIVEN
|
|
836
|
-
|
|
837
|
-
|
|
838
852
|
__all__ = local_public_names(__name__)
|
|
839
853
|
|
|
840
854
|
|