pixeltable 0.2.26__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. pixeltable/__init__.py +83 -19
  2. pixeltable/_query.py +1444 -0
  3. pixeltable/_version.py +1 -0
  4. pixeltable/catalog/__init__.py +7 -4
  5. pixeltable/catalog/catalog.py +2394 -119
  6. pixeltable/catalog/column.py +225 -104
  7. pixeltable/catalog/dir.py +38 -9
  8. pixeltable/catalog/globals.py +53 -34
  9. pixeltable/catalog/insertable_table.py +265 -115
  10. pixeltable/catalog/path.py +80 -17
  11. pixeltable/catalog/schema_object.py +28 -43
  12. pixeltable/catalog/table.py +1270 -677
  13. pixeltable/catalog/table_metadata.py +103 -0
  14. pixeltable/catalog/table_version.py +1270 -751
  15. pixeltable/catalog/table_version_handle.py +109 -0
  16. pixeltable/catalog/table_version_path.py +137 -42
  17. pixeltable/catalog/tbl_ops.py +53 -0
  18. pixeltable/catalog/update_status.py +191 -0
  19. pixeltable/catalog/view.py +251 -134
  20. pixeltable/config.py +215 -0
  21. pixeltable/env.py +736 -285
  22. pixeltable/exceptions.py +26 -2
  23. pixeltable/exec/__init__.py +7 -2
  24. pixeltable/exec/aggregation_node.py +39 -21
  25. pixeltable/exec/cache_prefetch_node.py +87 -109
  26. pixeltable/exec/cell_materialization_node.py +268 -0
  27. pixeltable/exec/cell_reconstruction_node.py +168 -0
  28. pixeltable/exec/component_iteration_node.py +25 -28
  29. pixeltable/exec/data_row_batch.py +11 -46
  30. pixeltable/exec/exec_context.py +26 -11
  31. pixeltable/exec/exec_node.py +35 -27
  32. pixeltable/exec/expr_eval/__init__.py +3 -0
  33. pixeltable/exec/expr_eval/evaluators.py +365 -0
  34. pixeltable/exec/expr_eval/expr_eval_node.py +413 -0
  35. pixeltable/exec/expr_eval/globals.py +200 -0
  36. pixeltable/exec/expr_eval/row_buffer.py +74 -0
  37. pixeltable/exec/expr_eval/schedulers.py +413 -0
  38. pixeltable/exec/globals.py +35 -0
  39. pixeltable/exec/in_memory_data_node.py +35 -27
  40. pixeltable/exec/object_store_save_node.py +293 -0
  41. pixeltable/exec/row_update_node.py +44 -29
  42. pixeltable/exec/sql_node.py +414 -115
  43. pixeltable/exprs/__init__.py +8 -5
  44. pixeltable/exprs/arithmetic_expr.py +79 -45
  45. pixeltable/exprs/array_slice.py +5 -5
  46. pixeltable/exprs/column_property_ref.py +40 -26
  47. pixeltable/exprs/column_ref.py +254 -61
  48. pixeltable/exprs/comparison.py +14 -9
  49. pixeltable/exprs/compound_predicate.py +9 -10
  50. pixeltable/exprs/data_row.py +213 -72
  51. pixeltable/exprs/expr.py +270 -104
  52. pixeltable/exprs/expr_dict.py +6 -5
  53. pixeltable/exprs/expr_set.py +20 -11
  54. pixeltable/exprs/function_call.py +383 -284
  55. pixeltable/exprs/globals.py +18 -5
  56. pixeltable/exprs/in_predicate.py +7 -7
  57. pixeltable/exprs/inline_expr.py +37 -37
  58. pixeltable/exprs/is_null.py +8 -4
  59. pixeltable/exprs/json_mapper.py +120 -54
  60. pixeltable/exprs/json_path.py +90 -60
  61. pixeltable/exprs/literal.py +61 -16
  62. pixeltable/exprs/method_ref.py +7 -6
  63. pixeltable/exprs/object_ref.py +19 -8
  64. pixeltable/exprs/row_builder.py +238 -75
  65. pixeltable/exprs/rowid_ref.py +53 -15
  66. pixeltable/exprs/similarity_expr.py +65 -50
  67. pixeltable/exprs/sql_element_cache.py +5 -5
  68. pixeltable/exprs/string_op.py +107 -0
  69. pixeltable/exprs/type_cast.py +25 -13
  70. pixeltable/exprs/variable.py +2 -2
  71. pixeltable/func/__init__.py +9 -5
  72. pixeltable/func/aggregate_function.py +197 -92
  73. pixeltable/func/callable_function.py +119 -35
  74. pixeltable/func/expr_template_function.py +101 -48
  75. pixeltable/func/function.py +375 -62
  76. pixeltable/func/function_registry.py +20 -19
  77. pixeltable/func/globals.py +6 -5
  78. pixeltable/func/mcp.py +74 -0
  79. pixeltable/func/query_template_function.py +151 -35
  80. pixeltable/func/signature.py +178 -49
  81. pixeltable/func/tools.py +164 -0
  82. pixeltable/func/udf.py +176 -53
  83. pixeltable/functions/__init__.py +44 -4
  84. pixeltable/functions/anthropic.py +226 -47
  85. pixeltable/functions/audio.py +148 -11
  86. pixeltable/functions/bedrock.py +137 -0
  87. pixeltable/functions/date.py +188 -0
  88. pixeltable/functions/deepseek.py +113 -0
  89. pixeltable/functions/document.py +81 -0
  90. pixeltable/functions/fal.py +76 -0
  91. pixeltable/functions/fireworks.py +72 -20
  92. pixeltable/functions/gemini.py +249 -0
  93. pixeltable/functions/globals.py +208 -53
  94. pixeltable/functions/groq.py +108 -0
  95. pixeltable/functions/huggingface.py +1088 -95
  96. pixeltable/functions/image.py +155 -84
  97. pixeltable/functions/json.py +8 -11
  98. pixeltable/functions/llama_cpp.py +31 -19
  99. pixeltable/functions/math.py +169 -0
  100. pixeltable/functions/mistralai.py +50 -75
  101. pixeltable/functions/net.py +70 -0
  102. pixeltable/functions/ollama.py +29 -36
  103. pixeltable/functions/openai.py +548 -160
  104. pixeltable/functions/openrouter.py +143 -0
  105. pixeltable/functions/replicate.py +15 -14
  106. pixeltable/functions/reve.py +250 -0
  107. pixeltable/functions/string.py +310 -85
  108. pixeltable/functions/timestamp.py +37 -19
  109. pixeltable/functions/together.py +77 -120
  110. pixeltable/functions/twelvelabs.py +188 -0
  111. pixeltable/functions/util.py +7 -2
  112. pixeltable/functions/uuid.py +30 -0
  113. pixeltable/functions/video.py +1528 -117
  114. pixeltable/functions/vision.py +26 -26
  115. pixeltable/functions/voyageai.py +289 -0
  116. pixeltable/functions/whisper.py +19 -10
  117. pixeltable/functions/whisperx.py +179 -0
  118. pixeltable/functions/yolox.py +112 -0
  119. pixeltable/globals.py +716 -236
  120. pixeltable/index/__init__.py +3 -1
  121. pixeltable/index/base.py +17 -21
  122. pixeltable/index/btree.py +32 -22
  123. pixeltable/index/embedding_index.py +155 -92
  124. pixeltable/io/__init__.py +12 -7
  125. pixeltable/io/datarows.py +140 -0
  126. pixeltable/io/external_store.py +83 -125
  127. pixeltable/io/fiftyone.py +24 -33
  128. pixeltable/io/globals.py +47 -182
  129. pixeltable/io/hf_datasets.py +96 -127
  130. pixeltable/io/label_studio.py +171 -156
  131. pixeltable/io/lancedb.py +3 -0
  132. pixeltable/io/pandas.py +136 -115
  133. pixeltable/io/parquet.py +40 -153
  134. pixeltable/io/table_data_conduit.py +702 -0
  135. pixeltable/io/utils.py +100 -0
  136. pixeltable/iterators/__init__.py +8 -4
  137. pixeltable/iterators/audio.py +207 -0
  138. pixeltable/iterators/base.py +9 -3
  139. pixeltable/iterators/document.py +144 -87
  140. pixeltable/iterators/image.py +17 -38
  141. pixeltable/iterators/string.py +15 -12
  142. pixeltable/iterators/video.py +523 -127
  143. pixeltable/metadata/__init__.py +33 -8
  144. pixeltable/metadata/converters/convert_10.py +2 -3
  145. pixeltable/metadata/converters/convert_13.py +2 -2
  146. pixeltable/metadata/converters/convert_15.py +15 -11
  147. pixeltable/metadata/converters/convert_16.py +4 -5
  148. pixeltable/metadata/converters/convert_17.py +4 -5
  149. pixeltable/metadata/converters/convert_18.py +4 -6
  150. pixeltable/metadata/converters/convert_19.py +6 -9
  151. pixeltable/metadata/converters/convert_20.py +3 -6
  152. pixeltable/metadata/converters/convert_21.py +6 -8
  153. pixeltable/metadata/converters/convert_22.py +3 -2
  154. pixeltable/metadata/converters/convert_23.py +33 -0
  155. pixeltable/metadata/converters/convert_24.py +55 -0
  156. pixeltable/metadata/converters/convert_25.py +19 -0
  157. pixeltable/metadata/converters/convert_26.py +23 -0
  158. pixeltable/metadata/converters/convert_27.py +29 -0
  159. pixeltable/metadata/converters/convert_28.py +13 -0
  160. pixeltable/metadata/converters/convert_29.py +110 -0
  161. pixeltable/metadata/converters/convert_30.py +63 -0
  162. pixeltable/metadata/converters/convert_31.py +11 -0
  163. pixeltable/metadata/converters/convert_32.py +15 -0
  164. pixeltable/metadata/converters/convert_33.py +17 -0
  165. pixeltable/metadata/converters/convert_34.py +21 -0
  166. pixeltable/metadata/converters/convert_35.py +9 -0
  167. pixeltable/metadata/converters/convert_36.py +38 -0
  168. pixeltable/metadata/converters/convert_37.py +15 -0
  169. pixeltable/metadata/converters/convert_38.py +39 -0
  170. pixeltable/metadata/converters/convert_39.py +124 -0
  171. pixeltable/metadata/converters/convert_40.py +73 -0
  172. pixeltable/metadata/converters/convert_41.py +12 -0
  173. pixeltable/metadata/converters/convert_42.py +9 -0
  174. pixeltable/metadata/converters/convert_43.py +44 -0
  175. pixeltable/metadata/converters/util.py +44 -18
  176. pixeltable/metadata/notes.py +21 -0
  177. pixeltable/metadata/schema.py +185 -42
  178. pixeltable/metadata/utils.py +74 -0
  179. pixeltable/mypy/__init__.py +3 -0
  180. pixeltable/mypy/mypy_plugin.py +123 -0
  181. pixeltable/plan.py +616 -225
  182. pixeltable/share/__init__.py +3 -0
  183. pixeltable/share/packager.py +797 -0
  184. pixeltable/share/protocol/__init__.py +33 -0
  185. pixeltable/share/protocol/common.py +165 -0
  186. pixeltable/share/protocol/operation_types.py +33 -0
  187. pixeltable/share/protocol/replica.py +119 -0
  188. pixeltable/share/publish.py +349 -0
  189. pixeltable/store.py +398 -232
  190. pixeltable/type_system.py +730 -267
  191. pixeltable/utils/__init__.py +40 -0
  192. pixeltable/utils/arrow.py +201 -29
  193. pixeltable/utils/av.py +298 -0
  194. pixeltable/utils/azure_store.py +346 -0
  195. pixeltable/utils/coco.py +26 -27
  196. pixeltable/utils/code.py +4 -4
  197. pixeltable/utils/console_output.py +46 -0
  198. pixeltable/utils/coroutine.py +24 -0
  199. pixeltable/utils/dbms.py +92 -0
  200. pixeltable/utils/description_helper.py +11 -12
  201. pixeltable/utils/documents.py +60 -61
  202. pixeltable/utils/exception_handler.py +36 -0
  203. pixeltable/utils/filecache.py +38 -22
  204. pixeltable/utils/formatter.py +88 -51
  205. pixeltable/utils/gcs_store.py +295 -0
  206. pixeltable/utils/http.py +133 -0
  207. pixeltable/utils/http_server.py +14 -13
  208. pixeltable/utils/iceberg.py +13 -0
  209. pixeltable/utils/image.py +17 -0
  210. pixeltable/utils/lancedb.py +90 -0
  211. pixeltable/utils/local_store.py +322 -0
  212. pixeltable/utils/misc.py +5 -0
  213. pixeltable/utils/object_stores.py +573 -0
  214. pixeltable/utils/pydantic.py +60 -0
  215. pixeltable/utils/pytorch.py +20 -20
  216. pixeltable/utils/s3_store.py +527 -0
  217. pixeltable/utils/sql.py +32 -5
  218. pixeltable/utils/system.py +30 -0
  219. pixeltable/utils/transactional_directory.py +4 -3
  220. pixeltable-0.5.7.dist-info/METADATA +579 -0
  221. pixeltable-0.5.7.dist-info/RECORD +227 -0
  222. {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
  223. pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
  224. pixeltable/__version__.py +0 -3
  225. pixeltable/catalog/named_function.py +0 -36
  226. pixeltable/catalog/path_dict.py +0 -141
  227. pixeltable/dataframe.py +0 -894
  228. pixeltable/exec/expr_eval_node.py +0 -232
  229. pixeltable/ext/__init__.py +0 -14
  230. pixeltable/ext/functions/__init__.py +0 -8
  231. pixeltable/ext/functions/whisperx.py +0 -77
  232. pixeltable/ext/functions/yolox.py +0 -157
  233. pixeltable/tool/create_test_db_dump.py +0 -311
  234. pixeltable/tool/create_test_video.py +0 -81
  235. pixeltable/tool/doc_plugins/griffe.py +0 -50
  236. pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
  237. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
  238. pixeltable/tool/embed_udf.py +0 -9
  239. pixeltable/tool/mypy_plugin.py +0 -55
  240. pixeltable/utils/media_store.py +0 -76
  241. pixeltable/utils/s3.py +0 -16
  242. pixeltable-0.2.26.dist-info/METADATA +0 -400
  243. pixeltable-0.2.26.dist-info/RECORD +0 -156
  244. pixeltable-0.2.26.dist-info/entry_points.txt +0 -3
  245. {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
@@ -1,49 +1,253 @@
1
1
  """
2
- Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs)
2
+ Pixeltable UDFs
3
3
  that wrap various endpoints from the OpenAI API. In order to use them, you must
4
4
  first `pip install openai` and configure your OpenAI credentials, as described in
5
- the [Working with OpenAI](https://pixeltable.readme.io/docs/working-with-openai) tutorial.
5
+ the [Working with OpenAI](https://docs.pixeltable.com/notebooks/integrations/working-with-openai) tutorial.
6
6
  """
7
7
 
8
8
  import base64
9
+ import datetime
9
10
  import io
11
+ import json
12
+ import logging
13
+ import math
10
14
  import pathlib
11
- import uuid
12
- from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
15
+ import re
16
+ from typing import TYPE_CHECKING, Any, Callable, Type
13
17
 
18
+ import httpx
14
19
  import numpy as np
15
- import PIL.Image
16
- import tenacity
20
+ import PIL
17
21
 
18
22
  import pixeltable as pxt
19
- from pixeltable import env
20
- from pixeltable.func import Batch
23
+ from pixeltable import env, exprs, type_system as ts
24
+ from pixeltable.config import Config
25
+ from pixeltable.func import Batch, Tools
21
26
  from pixeltable.utils.code import local_public_names
27
+ from pixeltable.utils.local_store import TempStore
28
+ from pixeltable.utils.system import set_file_descriptor_limit
22
29
 
23
30
  if TYPE_CHECKING:
24
31
  import openai
25
32
 
33
+ _logger = logging.getLogger('pixeltable')
34
+
26
35
 
27
36
  @env.register_client('openai')
28
- def _(api_key: str) -> 'openai.OpenAI':
37
+ def _(api_key: str, base_url: str | None = None, api_version: str | None = None) -> 'openai.AsyncOpenAI':
29
38
  import openai
30
- return openai.OpenAI(api_key=api_key)
39
+
40
+ max_connections = Config.get().get_int_value('openai.max_connections') or 2000
41
+ max_keepalive_connections = Config.get().get_int_value('openai.max_keepalive_connections') or 100
42
+ set_file_descriptor_limit(max_connections * 2)
43
+ default_query = None if api_version is None else {'api-version': api_version}
44
+
45
+ # Pixeltable scheduler's retry logic takes into account the rate limit-related response headers, so in theory we can
46
+ # benefit from disabling retries in the OpenAI client (max_retries=0). However to do that, we need to get smarter
47
+ # about idempotency keys and possibly more.
48
+ return openai.AsyncOpenAI(
49
+ api_key=api_key,
50
+ base_url=base_url,
51
+ default_query=default_query,
52
+ # recommended to increase limits for async client to avoid connection errors
53
+ http_client=httpx.AsyncClient(
54
+ limits=httpx.Limits(max_keepalive_connections=max_keepalive_connections, max_connections=max_connections),
55
+ # HTTP1 tends to perform better on this kind of workloads
56
+ http2=False,
57
+ http1=True,
58
+ ),
59
+ )
31
60
 
32
61
 
33
- def _openai_client() -> 'openai.OpenAI':
62
+ def _openai_client() -> 'openai.AsyncOpenAI':
34
63
  return env.Env.get().get_client('openai')
35
64
 
36
65
 
37
- # Exponential backoff decorator using tenacity.
38
- # TODO(aaron-siegel): Right now this hardwires random exponential backoff with defaults suggested
39
- # by OpenAI. Should we investigate making this more customizable in the future?
40
- def _retry(fn: Callable) -> Callable:
41
- import openai
42
- return tenacity.retry(
43
- retry=tenacity.retry_if_exception_type(openai.RateLimitError),
44
- wait=tenacity.wait_random_exponential(multiplier=1, max=60),
45
- stop=tenacity.stop_after_attempt(20),
46
- )(fn)
66
+ # models that share rate limits; see https://platform.openai.com/settings/organization/limits for details
67
+ _shared_rate_limits = {
68
+ 'gpt-4-turbo': [
69
+ 'gpt-4-turbo',
70
+ 'gpt-4-turbo-latest',
71
+ 'gpt-4-turbo-2024-04-09',
72
+ 'gpt-4-turbo-preview',
73
+ 'gpt-4-0125-preview',
74
+ 'gpt-4-1106-preview',
75
+ ],
76
+ 'gpt-4o': [
77
+ 'gpt-4o',
78
+ 'gpt-4o-latest',
79
+ 'gpt-4o-2024-05-13',
80
+ 'gpt-4o-2024-08-06',
81
+ 'gpt-4o-2024-11-20',
82
+ 'gpt-4o-audio-preview',
83
+ 'gpt-4o-audio-preview-2024-10-01',
84
+ 'gpt-4o-audio-preview-2024-12-17',
85
+ ],
86
+ 'gpt-4o-mini': [
87
+ 'gpt-4o-mini',
88
+ 'gpt-4o-mini-latest',
89
+ 'gpt-4o-mini-2024-07-18',
90
+ 'gpt-4o-mini-audio-preview',
91
+ 'gpt-4o-mini-audio-preview-2024-12-17',
92
+ ],
93
+ 'gpt-4o-mini-realtime-preview': [
94
+ 'gpt-4o-mini-realtime-preview',
95
+ 'gpt-4o-mini-realtime-preview-latest',
96
+ 'gpt-4o-mini-realtime-preview-2024-12-17',
97
+ ],
98
+ }
99
+
100
+
101
+ def _rate_limits_pool(model: str) -> str:
102
+ for model_family, models in _shared_rate_limits.items():
103
+ if model in models:
104
+ return f'rate-limits:openai:{model_family}'
105
+ return f'rate-limits:openai:{model}'
106
+
107
+
108
+ def _parse_header_duration(duration_str: str) -> float | None:
109
+ """Parses the value of x-ratelimit-reset-* header into seconds.
110
+
111
+ Returns None if the input cannot be parsed.
112
+
113
+ Real life examples of header values:
114
+ * '1m33.792s'
115
+ * '857ms'
116
+ * '0s'
117
+ * '47.874s'
118
+ * '156h58m48.601s'
119
+ """
120
+ if duration_str is None or duration_str.strip() == '':
121
+ return None
122
+ units = {
123
+ 86400: r'(\d+)d', # days
124
+ 3600: r'(\d+)h', # hours
125
+ 60: r'(\d+)m(?:[^s]|$)', # minutes
126
+ 1: r'([\d.]+)s', # seconds
127
+ 0.001: r'(\d+)ms', # millis
128
+ }
129
+ seconds = None
130
+ for unit_value, pattern in units.items():
131
+ match = re.search(pattern, duration_str)
132
+ if match:
133
+ seconds = seconds or 0.0
134
+ seconds += float(match.group(1)) * unit_value
135
+ _logger.debug(f'Parsed duration header value "{duration_str}" into {seconds} seconds')
136
+ return seconds
137
+
138
+
139
+ def _get_header_info(
140
+ headers: httpx.Headers,
141
+ ) -> tuple[tuple[int, int, datetime.datetime] | None, tuple[int, int, datetime.datetime] | None]:
142
+ """Parses rate limit related headers"""
143
+ # Requests and project-requests are two separate limits of requests per minute. project-requests headers will be
144
+ # present if an RPM limit is configured on the project limit.
145
+ requests_info = _get_resource_info(headers, 'requests')
146
+ requests_fraction_remaining = _fract_remaining(requests_info)
147
+ project_requests_info = _get_resource_info(headers, 'project-requests')
148
+ project_requests_fraction_remaining = _fract_remaining(project_requests_info)
149
+
150
+ # If both limit infos are present, pick the one with the least percentage remaining
151
+ best_requests_info = requests_info or project_requests_info
152
+ if (
153
+ requests_fraction_remaining is not None
154
+ and project_requests_fraction_remaining is not None
155
+ and project_requests_fraction_remaining < requests_fraction_remaining
156
+ ):
157
+ best_requests_info = project_requests_info
158
+
159
+ # Same story with tokens
160
+ tokens_info = _get_resource_info(headers, 'tokens')
161
+ tokens_fraction_remaining = _fract_remaining(tokens_info)
162
+ project_tokens_info = _get_resource_info(headers, 'project-tokens')
163
+ project_tokens_fraction_remaining = _fract_remaining(project_tokens_info)
164
+
165
+ best_tokens_info = tokens_info or project_tokens_info
166
+ if (
167
+ tokens_fraction_remaining is not None
168
+ and project_tokens_fraction_remaining is not None
169
+ and project_tokens_fraction_remaining < tokens_fraction_remaining
170
+ ):
171
+ best_tokens_info = project_tokens_info
172
+
173
+ if best_requests_info is None or best_tokens_info is None:
174
+ _logger.debug(f'get_header_info(): incomplete rate limit info: {headers}')
175
+
176
+ return best_requests_info, best_tokens_info
177
+
178
+
179
+ def _get_resource_info(headers: httpx.Headers, resource: str) -> tuple[int, int, datetime.datetime] | None:
180
+ remaining_str = headers.get(f'x-ratelimit-remaining-{resource}')
181
+ if remaining_str is None:
182
+ return None
183
+ remaining = int(remaining_str)
184
+ limit_str = headers.get(f'x-ratelimit-limit-{resource}')
185
+ limit = int(limit_str) if limit_str is not None else None
186
+ reset_str = headers.get(f'x-ratelimit-reset-{resource}')
187
+ reset_in_seconds = _parse_header_duration(reset_str) or 5.0 # Default to 5 seconds
188
+ reset_ts = datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(seconds=reset_in_seconds)
189
+ return (limit, remaining, reset_ts)
190
+
191
+
192
+ def _fract_remaining(resource_info: tuple[int, int, datetime.datetime] | None) -> float | None:
193
+ if resource_info is None:
194
+ return None
195
+ limit, remaining, _ = resource_info
196
+ if limit is None or remaining is None:
197
+ return None
198
+ return remaining / limit
199
+
200
+
201
+ class OpenAIRateLimitsInfo(env.RateLimitsInfo):
202
+ retryable_errors: tuple[Type[Exception], ...]
203
+
204
+ def __init__(self, get_request_resources: Callable[..., dict[str, int]]):
205
+ super().__init__(get_request_resources)
206
+ import openai
207
+
208
+ self.retryable_errors = (
209
+ # ConnectionError: we occasionally see this error when the AsyncConnectionPool is trying to close
210
+ # expired connections
211
+ # (AsyncConnectionPool._close_expired_connections() fails with ConnectionError when executing
212
+ # 'await connection.aclose()', which is very likely a bug in AsyncConnectionPool)
213
+ openai.APIConnectionError,
214
+ # the following errors are retryable according to OpenAI's API documentation
215
+ openai.RateLimitError,
216
+ openai.APITimeoutError,
217
+ openai.UnprocessableEntityError,
218
+ openai.InternalServerError,
219
+ )
220
+
221
+ def record_exc(self, request_ts: datetime.datetime, exc: Exception) -> None:
222
+ import openai
223
+
224
+ _ = isinstance(exc, openai.APIError)
225
+ if not isinstance(exc, openai.APIError) or not hasattr(exc, 'response') or not hasattr(exc.response, 'headers'):
226
+ return
227
+
228
+ requests_info, tokens_info = _get_header_info(exc.response.headers)
229
+ _logger.debug(
230
+ f'record_exc(): request_ts: {request_ts}, requests_info={requests_info} tokens_info={tokens_info}'
231
+ )
232
+ self.record(request_ts=request_ts, requests=requests_info, tokens=tokens_info)
233
+ self.has_exc = True
234
+
235
+ def _retry_delay_from_exception(self, exc: Exception) -> float | None:
236
+ try:
237
+ retry_after_str = exc.response.headers.get('retry-after') # type: ignore
238
+ except AttributeError:
239
+ return None
240
+ if retry_after_str is not None and re.fullmatch(r'\d{1,4}', retry_after_str):
241
+ return float(retry_after_str)
242
+ return None
243
+
244
+ def get_retry_delay(self, exc: Exception, attempt: int) -> float | None:
245
+ import openai
246
+
247
+ if not isinstance(exc, self.retryable_errors):
248
+ return None
249
+ assert isinstance(exc, openai.APIError)
250
+ return self._retry_delay_from_exception(exc) or super().get_retry_delay(exc, attempt)
47
251
 
48
252
 
49
253
  #####################################
@@ -51,14 +255,16 @@ def _retry(fn: Callable) -> Callable:
51
255
 
52
256
 
53
257
  @pxt.udf
54
- def speech(
55
- input: str, *, model: str, voice: str, response_format: Optional[str] = None, speed: Optional[float] = None
56
- ) -> pxt.Audio:
258
+ async def speech(input: str, *, model: str, voice: str, model_kwargs: dict[str, Any] | None = None) -> pxt.Audio:
57
259
  """
58
260
  Generates audio from the input text.
59
261
 
60
262
  Equivalent to the OpenAI `audio/speech` API endpoint.
61
- For additional details, see: [https://platform.openai.com/docs/guides/text-to-speech](https://platform.openai.com/docs/guides/text-to-speech)
263
+ For additional details, see: <https://platform.openai.com/docs/guides/text-to-speech>
264
+
265
+ Request throttling:
266
+ Applies the rate limit set in the config (section `openai.rate_limits`; use the model id as the key). If no rate
267
+ limit is configured, uses a default of 600 RPM.
62
268
 
63
269
  __Requirements:__
64
270
 
@@ -69,8 +275,8 @@ def speech(
69
275
  model: The model to use for speech synthesis.
70
276
  voice: The voice profile to use for speech synthesis. Supported options include:
71
277
  `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`.
72
-
73
- For details on the other parameters, see: [https://platform.openai.com/docs/api-reference/audio/createSpeech](https://platform.openai.com/docs/api-reference/audio/createSpeech)
278
+ model_kwargs: Additional keyword args for the OpenAI `audio/speech` API. For details on the available
279
+ parameters, see: <https://platform.openai.com/docs/api-reference/audio/createSpeech>
74
280
 
75
281
  Returns:
76
282
  An audio file containing the synthesized speech.
@@ -79,31 +285,29 @@ def speech(
79
285
  Add a computed column that applies the model `tts-1` to an existing Pixeltable column `tbl.text`
80
286
  of the table `tbl`:
81
287
 
82
- >>> tbl['audio'] = speech(tbl.text, model='tts-1', voice='nova')
288
+ >>> tbl.add_computed_column(audio=speech(tbl.text, model='tts-1', voice='nova'))
83
289
  """
84
- content = _retry(_openai_client().audio.speech.create)(
85
- input=input, model=model, voice=voice, response_format=_opt(response_format), speed=_opt(speed)
86
- )
87
- ext = response_format or 'mp3'
88
- output_filename = str(env.Env.get().tmp_dir / f'{uuid.uuid4()}.{ext}')
290
+ if model_kwargs is None:
291
+ model_kwargs = {}
292
+
293
+ content = await _openai_client().audio.speech.create(input=input, model=model, voice=voice, **model_kwargs)
294
+ ext = model_kwargs.get('response_format', 'mp3')
295
+ output_filename = str(TempStore.create_path(extension=f'.{ext}'))
89
296
  content.write_to_file(output_filename)
90
297
  return output_filename
91
298
 
92
299
 
93
300
  @pxt.udf
94
- def transcriptions(
95
- audio: pxt.Audio,
96
- *,
97
- model: str,
98
- language: Optional[str] = None,
99
- prompt: Optional[str] = None,
100
- temperature: Optional[float] = None,
101
- ) -> dict:
301
+ async def transcriptions(audio: pxt.Audio, *, model: str, model_kwargs: dict[str, Any] | None = None) -> dict:
102
302
  """
103
303
  Transcribes audio into the input language.
104
304
 
105
305
  Equivalent to the OpenAI `audio/transcriptions` API endpoint.
106
- For additional details, see: [https://platform.openai.com/docs/guides/speech-to-text](https://platform.openai.com/docs/guides/speech-to-text)
306
+ For additional details, see: <https://platform.openai.com/docs/guides/speech-to-text>
307
+
308
+ Request throttling:
309
+ Applies the rate limit set in the config (section `openai.rate_limits`; use the model id as the key). If no rate
310
+ limit is configured, uses a default of 600 RPM.
107
311
 
108
312
  __Requirements:__
109
313
 
@@ -112,8 +316,8 @@ def transcriptions(
112
316
  Args:
113
317
  audio: The audio to transcribe.
114
318
  model: The model to use for speech transcription.
115
-
116
- For details on the other parameters, see: [https://platform.openai.com/docs/api-reference/audio/createTranscription](https://platform.openai.com/docs/api-reference/audio/createTranscription)
319
+ model_kwargs: Additional keyword args for the OpenAI `audio/transcriptions` API. For details on the available
320
+ parameters, see: <https://platform.openai.com/docs/api-reference/audio/createTranscription>
117
321
 
118
322
  Returns:
119
323
  A dictionary containing the transcription and other metadata.
@@ -122,28 +326,27 @@ def transcriptions(
122
326
  Add a computed column that applies the model `whisper-1` to an existing Pixeltable column `tbl.audio`
123
327
  of the table `tbl`:
124
328
 
125
- >>> tbl['transcription'] = transcriptions(tbl.audio, model='whisper-1', language='en')
329
+ >>> tbl.add_computed_column(transcription=transcriptions(tbl.audio, model='whisper-1', language='en'))
126
330
  """
331
+ if model_kwargs is None:
332
+ model_kwargs = {}
333
+
127
334
  file = pathlib.Path(audio)
128
- transcription = _retry(_openai_client().audio.transcriptions.create)(
129
- file=file, model=model, language=_opt(language), prompt=_opt(prompt), temperature=_opt(temperature)
130
- )
335
+ transcription = await _openai_client().audio.transcriptions.create(file=file, model=model, **model_kwargs)
131
336
  return transcription.dict()
132
337
 
133
338
 
134
339
  @pxt.udf
135
- def translations(
136
- audio: pxt.Audio,
137
- *,
138
- model: str,
139
- prompt: Optional[str] = None,
140
- temperature: Optional[float] = None
141
- ) -> dict:
340
+ async def translations(audio: pxt.Audio, *, model: str, model_kwargs: dict[str, Any] | None = None) -> dict:
142
341
  """
143
342
  Translates audio into English.
144
343
 
145
344
  Equivalent to the OpenAI `audio/translations` API endpoint.
146
- For additional details, see: [https://platform.openai.com/docs/guides/speech-to-text](https://platform.openai.com/docs/guides/speech-to-text)
345
+ For additional details, see: <https://platform.openai.com/docs/guides/speech-to-text>
346
+
347
+ Request throttling:
348
+ Applies the rate limit set in the config (section `openai.rate_limits`; use the model id as the key). If no rate
349
+ limit is configured, uses a default of 600 RPM.
147
350
 
148
351
  __Requirements:__
149
352
 
@@ -152,8 +355,8 @@ def translations(
152
355
  Args:
153
356
  audio: The audio to translate.
154
357
  model: The model to use for speech transcription and translation.
155
-
156
- For details on the other parameters, see: [https://platform.openai.com/docs/api-reference/audio/createTranslation](https://platform.openai.com/docs/api-reference/audio/createTranslation)
358
+ model_kwargs: Additional keyword args for the OpenAI `audio/translations` API. For details on the available
359
+ parameters, see: <https://platform.openai.com/docs/api-reference/audio/createTranslation>
157
360
 
158
361
  Returns:
159
362
  A dictionary containing the translation and other metadata.
@@ -162,12 +365,13 @@ def translations(
162
365
  Add a computed column that applies the model `whisper-1` to an existing Pixeltable column `tbl.audio`
163
366
  of the table `tbl`:
164
367
 
165
- >>> tbl['translation'] = translations(tbl.audio, model='whisper-1', language='en')
368
+ >>> tbl.add_computed_column(translation=translations(tbl.audio, model='whisper-1', language='en'))
166
369
  """
370
+ if model_kwargs is None:
371
+ model_kwargs = {}
372
+
167
373
  file = pathlib.Path(audio)
168
- translation = _retry(_openai_client().audio.translations.create)(
169
- file=file, model=model, prompt=_opt(prompt), temperature=_opt(temperature)
170
- )
374
+ translation = await _openai_client().audio.translations.create(file=file, model=model, **model_kwargs)
171
375
  return translation.dict()
172
376
 
173
377
 
@@ -175,32 +379,75 @@ def translations(
175
379
  # Chat Endpoints
176
380
 
177
381
 
382
+ def _default_max_tokens(model: str) -> int:
383
+ if (
384
+ _is_model_family(model, 'gpt-4o-realtime')
385
+ or _is_model_family(model, 'gpt-4o-mini-realtime')
386
+ or _is_model_family(model, 'gpt-4-turbo')
387
+ or _is_model_family(model, 'gpt-3.5-turbo')
388
+ ):
389
+ return 4096
390
+ if _is_model_family(model, 'gpt-4'):
391
+ return 8192 # All other gpt-4 models (will not match on gpt-4o models)
392
+ if _is_model_family(model, 'gpt-4o') or _is_model_family(model, 'gpt-4.5-preview'):
393
+ return 16384 # All other gpt-4o / gpt-4.5 models
394
+ if _is_model_family(model, 'o1-preview'):
395
+ return 32768
396
+ if _is_model_family(model, 'o1-mini'):
397
+ return 65536
398
+ if _is_model_family(model, 'o1') or _is_model_family(model, 'o3'):
399
+ return 100000 # All other o1 / o3 models
400
+ return 100000 # global default
401
+
402
+
403
+ def _is_model_family(model: str, family: str) -> bool:
404
+ # `model.startswith(family)` would be a simpler match, but increases the risk of false positives.
405
+ # We use a slightly more complicated criterion to make things a little less error prone.
406
+ return model == family or model.startswith(f'{family}-')
407
+
408
+
409
+ def _chat_completions_get_request_resources(
410
+ messages: list, model: str, model_kwargs: dict[str, Any] | None
411
+ ) -> dict[str, int]:
412
+ if model_kwargs is None:
413
+ model_kwargs = {}
414
+
415
+ max_completion_tokens = model_kwargs.get('max_completion_tokens')
416
+ max_tokens = model_kwargs.get('max_tokens')
417
+ n = model_kwargs.get('n')
418
+
419
+ completion_tokens = (n or 1) * (max_completion_tokens or max_tokens or _default_max_tokens(model))
420
+
421
+ num_tokens = 0.0
422
+ for message in messages:
423
+ num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
424
+ for key, value in message.items():
425
+ num_tokens += len(value) / 4
426
+ if key == 'name': # if there's a name, the role is omitted
427
+ num_tokens -= 1 # role is always required and always 1 token
428
+ num_tokens += 2 # every reply is primed with <im_start>assistant
429
+ return {'requests': 1, 'tokens': int(num_tokens) + completion_tokens}
430
+
431
+
178
432
  @pxt.udf
179
- def chat_completions(
433
+ async def chat_completions(
180
434
  messages: list,
181
435
  *,
182
436
  model: str,
183
- frequency_penalty: Optional[float] = None,
184
- logit_bias: Optional[dict[str, int]] = None,
185
- logprobs: Optional[bool] = None,
186
- top_logprobs: Optional[int] = None,
187
- max_tokens: Optional[int] = None,
188
- n: Optional[int] = None,
189
- presence_penalty: Optional[float] = None,
190
- response_format: Optional[dict] = None,
191
- seed: Optional[int] = None,
192
- stop: Optional[list[str]] = None,
193
- temperature: Optional[float] = None,
194
- top_p: Optional[float] = None,
195
- tools: Optional[list[dict]] = None,
196
- tool_choice: Optional[dict] = None,
197
- user: Optional[str] = None,
437
+ model_kwargs: dict[str, Any] | None = None,
438
+ tools: list[dict[str, Any]] | None = None,
439
+ tool_choice: dict[str, Any] | None = None,
440
+ _runtime_ctx: env.RuntimeCtx | None = None,
198
441
  ) -> dict:
199
442
  """
200
443
  Creates a model response for the given chat conversation.
201
444
 
202
445
  Equivalent to the OpenAI `chat/completions` API endpoint.
203
- For additional details, see: [https://platform.openai.com/docs/guides/chat-completions](https://platform.openai.com/docs/guides/chat-completions)
446
+ For additional details, see: <https://platform.openai.com/docs/guides/chat-completions>
447
+
448
+ Request throttling:
449
+ Uses the rate limit-related headers returned by the API to throttle requests adaptively, based on available
450
+ request and token capacity. No configuration is necessary.
204
451
 
205
452
  __Requirements:__
206
453
 
@@ -209,8 +456,8 @@ def chat_completions(
209
456
  Args:
210
457
  messages: A list of messages to use for chat completion, as described in the OpenAI API documentation.
211
458
  model: The model to use for chat completion.
212
-
213
- For details on the other parameters, see: [https://platform.openai.com/docs/api-reference/chat](https://platform.openai.com/docs/api-reference/chat)
459
+ model_kwargs: Additional keyword args for the OpenAI `chat/completions` API. For details on the available
460
+ parameters, see: <https://platform.openai.com/docs/api-reference/chat/create>
214
461
 
215
462
  Returns:
216
463
  A dictionary containing the response and other metadata.
@@ -220,40 +467,101 @@ def chat_completions(
220
467
  of the table `tbl`:
221
468
 
222
469
  >>> messages = [
223
- {'role': 'system', 'content': 'You are a helpful assistant.'},
224
- {'role': 'user', 'content': tbl.prompt}
225
- ]
226
- tbl['response'] = chat_completions(messages, model='gpt-4o-mini')
470
+ ... {'role': 'system', 'content': 'You are a helpful assistant.'},
471
+ ... {'role': 'user', 'content': tbl.prompt}
472
+ ... ]
473
+ >>> tbl.add_computed_column(response=chat_completions(messages, model='gpt-4o-mini'))
227
474
  """
228
- result = _retry(_openai_client().chat.completions.create)(
229
- messages=messages,
230
- model=model,
231
- frequency_penalty=_opt(frequency_penalty),
232
- logit_bias=_opt(logit_bias),
233
- logprobs=_opt(logprobs),
234
- top_logprobs=_opt(top_logprobs),
235
- max_tokens=_opt(max_tokens),
236
- n=_opt(n),
237
- presence_penalty=_opt(presence_penalty),
238
- response_format=_opt(response_format),
239
- seed=_opt(seed),
240
- stop=_opt(stop),
241
- temperature=_opt(temperature),
242
- top_p=_opt(top_p),
243
- tools=_opt(tools),
244
- tool_choice=_opt(tool_choice),
245
- user=_opt(user),
475
+ if model_kwargs is None:
476
+ model_kwargs = {}
477
+
478
+ if tools is not None:
479
+ model_kwargs['tools'] = [{'type': 'function', 'function': tool} for tool in tools]
480
+
481
+ if tool_choice is not None:
482
+ if tool_choice['auto']:
483
+ model_kwargs['tool_choice'] = 'auto'
484
+ elif tool_choice['required']:
485
+ model_kwargs['tool_choice'] = 'required'
486
+ else:
487
+ assert tool_choice['tool'] is not None
488
+ model_kwargs['tool_choice'] = {'type': 'function', 'function': {'name': tool_choice['tool']}}
489
+
490
+ if tool_choice is not None and not tool_choice['parallel_tool_calls']:
491
+ model_kwargs['parallel_tool_calls'] = False
492
+
493
+ # make sure the pool info exists prior to making the request
494
+ resource_pool = _rate_limits_pool(model)
495
+ rate_limits_info = env.Env.get().get_resource_pool_info(
496
+ resource_pool, lambda: OpenAIRateLimitsInfo(_chat_completions_get_request_resources)
246
497
  )
247
- return result.dict()
498
+
499
+ request_ts = datetime.datetime.now(tz=datetime.timezone.utc)
500
+ result = await _openai_client().chat.completions.with_raw_response.create(
501
+ messages=messages, model=model, **model_kwargs
502
+ )
503
+
504
+ requests_info, tokens_info = _get_header_info(result.headers)
505
+ is_retry = _runtime_ctx is not None and _runtime_ctx.is_retry
506
+ rate_limits_info.record(request_ts=request_ts, requests=requests_info, tokens=tokens_info, reset_exc=is_retry)
507
+
508
+ return json.loads(result.text)
509
+
510
+
511
+ def _vision_get_request_resources(
512
+ prompt: str, image: PIL.Image.Image, model: str, model_kwargs: dict[str, Any] | None = None
513
+ ) -> dict[str, int]:
514
+ if model_kwargs is None:
515
+ model_kwargs = {}
516
+
517
+ max_completion_tokens = model_kwargs.get('max_completion_tokens')
518
+ max_tokens = model_kwargs.get('max_tokens')
519
+ n = model_kwargs.get('n')
520
+
521
+ completion_tokens = (n or 1) * (max_completion_tokens or max_tokens or _default_max_tokens(model))
522
+ prompt_tokens = len(prompt) / 4
523
+
524
+ # calculate image tokens based on
525
+ # https://platform.openai.com/docs/guides/vision/calculating-costs#calculating-costs
526
+ # assuming detail='high' (which appears to be the default, according to community forum posts)
527
+
528
+ # number of 512x512 crops; ceil(): partial crops still count as full crops
529
+ crops_width = math.ceil(image.width / 512)
530
+ crops_height = math.ceil(image.height / 512)
531
+ total_crops = crops_width * crops_height
532
+
533
+ base_tokens = 85 # base cost for the initial 512x512 overview
534
+ crop_tokens = 170 # cost per additional 512x512 crop
535
+ img_tokens = base_tokens + (crop_tokens * total_crops)
536
+
537
+ total_tokens = (
538
+ prompt_tokens
539
+ + img_tokens
540
+ + completion_tokens
541
+ + 4 # for <im_start>{role/name}\n{content}<im_end>\n
542
+ + 2 # for reply's <im_start>assistant
543
+ )
544
+ return {'requests': 1, 'tokens': int(total_tokens)}
248
545
 
249
546
 
250
547
  @pxt.udf
251
- def vision(prompt: str, image: PIL.Image.Image, *, model: str) -> str:
548
+ async def vision(
549
+ prompt: str,
550
+ image: PIL.Image.Image,
551
+ *,
552
+ model: str,
553
+ model_kwargs: dict[str, Any] | None = None,
554
+ _runtime_ctx: env.RuntimeCtx | None = None,
555
+ ) -> str:
252
556
  """
253
557
  Analyzes an image with the OpenAI vision capability. This is a convenience function that takes an image and
254
558
  prompt, and constructs a chat completion request that utilizes OpenAI vision.
255
559
 
256
- For additional details, see: [https://platform.openai.com/docs/guides/vision](https://platform.openai.com/docs/guides/vision)
560
+ For additional details, see: <https://platform.openai.com/docs/guides/vision>
561
+
562
+ Request throttling:
563
+ Uses the rate limit-related headers returned by the API to throttle requests adaptively, based on available
564
+ request and token capacity. No configuration is necessary.
257
565
 
258
566
  __Requirements:__
259
567
 
@@ -271,8 +579,11 @@ def vision(prompt: str, image: PIL.Image.Image, *, model: str) -> str:
271
579
  Add a computed column that applies the model `gpt-4o-mini` to an existing Pixeltable column `tbl.image`
272
580
  of the table `tbl`:
273
581
 
274
- >>> tbl['response'] = vision("What's in this image?", tbl.image, model='gpt-4o-mini')
582
+ >>> tbl.add_computed_column(response=vision("What's in this image?", tbl.image, model='gpt-4o-mini'))
275
583
  """
584
+ if model_kwargs is None:
585
+ model_kwargs = {}
586
+
276
587
  # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
277
588
  bytes_arr = io.BytesIO()
278
589
  image.save(bytes_arr, format='png')
@@ -287,8 +598,27 @@ def vision(prompt: str, image: PIL.Image.Image, *, model: str) -> str:
287
598
  ],
288
599
  }
289
600
  ]
290
- result = _retry(_openai_client().chat.completions.create)(messages=messages, model=model)
291
- return result.choices[0].message.content
601
+
602
+ # make sure the pool info exists prior to making the request
603
+ resource_pool = _rate_limits_pool(model)
604
+ rate_limits_info = env.Env.get().get_resource_pool_info(
605
+ resource_pool, lambda: OpenAIRateLimitsInfo(_vision_get_request_resources)
606
+ )
607
+
608
+ request_ts = datetime.datetime.now(tz=datetime.timezone.utc)
609
+ result = await _openai_client().chat.completions.with_raw_response.create(
610
+ messages=messages, # type: ignore
611
+ model=model,
612
+ **model_kwargs,
613
+ )
614
+
615
+ # _logger.debug(f'vision(): headers={result.headers}')
616
+ requests_info, tokens_info = _get_header_info(result.headers)
617
+ is_retry = _runtime_ctx is not None and _runtime_ctx.is_retry
618
+ rate_limits_info.record(request_ts=request_ts, requests=requests_info, tokens=tokens_info, reset_exc=is_retry)
619
+
620
+ result = json.loads(result.text)
621
+ return result['choices'][0]['message']['content']
292
622
 
293
623
 
294
624
  #####################################
@@ -301,15 +631,28 @@ _embedding_dimensions_cache: dict[str, int] = {
301
631
  }
302
632
 
303
633
 
634
+ def _embeddings_get_request_resources(input: list[str]) -> dict[str, int]:
635
+ input_len = sum(len(s) for s in input)
636
+ return {'requests': 1, 'tokens': int(input_len / 4)}
637
+
638
+
304
639
  @pxt.udf(batch_size=32)
305
- def embeddings(
306
- input: Batch[str], *, model: str, dimensions: Optional[int] = None, user: Optional[str] = None
640
+ async def embeddings(
641
+ input: Batch[str],
642
+ *,
643
+ model: str,
644
+ model_kwargs: dict[str, Any] | None = None,
645
+ _runtime_ctx: env.RuntimeCtx | None = None,
307
646
  ) -> Batch[pxt.Array[(None,), pxt.Float]]:
308
647
  """
309
648
  Creates an embedding vector representing the input text.
310
649
 
311
650
  Equivalent to the OpenAI `embeddings` API endpoint.
312
- For additional details, see: [https://platform.openai.com/docs/guides/embeddings](https://platform.openai.com/docs/guides/embeddings)
651
+ For additional details, see: <https://platform.openai.com/docs/guides/embeddings>
652
+
653
+ Request throttling:
654
+ Uses the rate limit-related headers returned by the API to throttle requests adaptively, based on available
655
+ request and token capacity. No configuration is necessary.
313
656
 
314
657
  __Requirements:__
315
658
 
@@ -318,10 +661,8 @@ def embeddings(
318
661
  Args:
319
662
  input: The text to embed.
320
663
  model: The model to use for the embedding.
321
- dimensions: The vector length of the embedding. If not specified, Pixeltable will use
322
- a default value based on the model.
323
-
324
- For details on the other parameters, see: [https://platform.openai.com/docs/api-reference/embeddings](https://platform.openai.com/docs/api-reference/embeddings)
664
+ model_kwargs: Additional keyword args for the OpenAI `embeddings` API. For details on the available
665
+ parameters, see: <https://platform.openai.com/docs/api-reference/embeddings>
325
666
 
326
667
  Returns:
327
668
  An array representing the application of the given embedding to `input`.
@@ -330,22 +671,41 @@ def embeddings(
330
671
  Add a computed column that applies the model `text-embedding-3-small` to an existing
331
672
  Pixeltable column `tbl.text` of the table `tbl`:
332
673
 
333
- >>> tbl['embed'] = embeddings(tbl.text, model='text-embedding-3-small')
674
+ >>> tbl.add_computed_column(embed=embeddings(tbl.text, model='text-embedding-3-small'))
675
+
676
+ Add an embedding index to an existing column `text`, using the model `text-embedding-3-small`:
677
+
678
+ >>> tbl.add_embedding_index(embedding=embeddings.using(model='text-embedding-3-small'))
334
679
  """
335
- result = _retry(_openai_client().embeddings.create)(
336
- input=input, model=model, dimensions=_opt(dimensions), user=_opt(user), encoding_format='float'
680
+ if model_kwargs is None:
681
+ model_kwargs = {}
682
+
683
+ _logger.debug(f'embeddings: batch_size={len(input)}')
684
+ resource_pool = _rate_limits_pool(model)
685
+ rate_limits_info = env.Env.get().get_resource_pool_info(
686
+ resource_pool, lambda: OpenAIRateLimitsInfo(_embeddings_get_request_resources)
687
+ )
688
+ request_ts = datetime.datetime.now(tz=datetime.timezone.utc)
689
+ result = await _openai_client().embeddings.with_raw_response.create(
690
+ input=input, model=model, encoding_format='float', **model_kwargs
337
691
  )
338
- return [np.array(data.embedding, dtype=np.float64) for data in result.data]
692
+ requests_info, tokens_info = _get_header_info(result.headers)
693
+ is_retry = _runtime_ctx is not None and _runtime_ctx.is_retry
694
+ rate_limits_info.record(request_ts=request_ts, requests=requests_info, tokens=tokens_info, reset_exc=is_retry)
695
+ return [np.array(data['embedding'], dtype=np.float64) for data in json.loads(result.content)['data']]
339
696
 
340
697
 
341
698
  @embeddings.conditional_return_type
342
- def _(model: str, dimensions: Optional[int] = None) -> pxt.ArrayType:
699
+ def _(model: str, model_kwargs: dict[str, Any] | None = None) -> ts.ArrayType:
700
+ dimensions: int | None = None
701
+ if model_kwargs is not None:
702
+ dimensions = model_kwargs.get('dimensions')
343
703
  if dimensions is None:
344
704
  if model not in _embedding_dimensions_cache:
345
705
  # TODO: find some other way to retrieve a sample
346
- return pxt.ArrayType((None,), dtype=pxt.FloatType(), nullable=False)
347
- dimensions = _embedding_dimensions_cache.get(model, None)
348
- return pxt.ArrayType((dimensions,), dtype=pxt.FloatType(), nullable=False)
706
+ return ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False)
707
+ dimensions = _embedding_dimensions_cache.get(model)
708
+ return ts.ArrayType((dimensions,), dtype=ts.FloatType(), nullable=False)
349
709
 
350
710
 
351
711
  #####################################
@@ -353,20 +713,18 @@ def _(model: str, dimensions: Optional[int] = None) -> pxt.ArrayType:
353
713
 
354
714
 
355
715
  @pxt.udf
356
- def image_generations(
357
- prompt: str,
358
- *,
359
- model: Optional[str] = None,
360
- quality: Optional[str] = None,
361
- size: Optional[str] = None,
362
- style: Optional[str] = None,
363
- user: Optional[str] = None,
716
+ async def image_generations(
717
+ prompt: str, *, model: str = 'dall-e-2', model_kwargs: dict[str, Any] | None = None
364
718
  ) -> PIL.Image.Image:
365
719
  """
366
720
  Creates an image given a prompt.
367
721
 
368
722
  Equivalent to the OpenAI `images/generations` API endpoint.
369
- For additional details, see: [https://platform.openai.com/docs/guides/images](https://platform.openai.com/docs/guides/images)
723
+ For additional details, see: <https://platform.openai.com/docs/guides/images>
724
+
725
+ Request throttling:
726
+ Applies the rate limit set in the config (section `openai.rate_limits`; use the model id as the key). If no rate
727
+ limit is configured, uses a default of 600 RPM.
370
728
 
371
729
  __Requirements:__
372
730
 
@@ -375,8 +733,8 @@ def image_generations(
375
733
  Args:
376
734
  prompt: Prompt for the image.
377
735
  model: The model to use for the generations.
378
-
379
- For details on the other parameters, see: [https://platform.openai.com/docs/api-reference/images/create](https://platform.openai.com/docs/api-reference/images/create)
736
+ model_kwargs: Additional keyword args for the OpenAI `images/generations` API. For details on the available
737
+ parameters, see: <https://platform.openai.com/docs/api-reference/images/create>
380
738
 
381
739
  Returns:
382
740
  The generated image.
@@ -385,17 +743,14 @@ def image_generations(
385
743
  Add a computed column that applies the model `dall-e-2` to an existing
386
744
  Pixeltable column `tbl.text` of the table `tbl`:
387
745
 
388
- >>> tbl['gen_image'] = image_generations(tbl.text, model='dall-e-2')
746
+ >>> tbl.add_computed_column(gen_image=image_generations(tbl.text, model='dall-e-2'))
389
747
  """
748
+ if model_kwargs is None:
749
+ model_kwargs = {}
750
+
390
751
  # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
391
- result = _retry(_openai_client().images.generate)(
392
- prompt=prompt,
393
- model=_opt(model),
394
- quality=_opt(quality),
395
- size=_opt(size),
396
- style=_opt(style),
397
- user=_opt(user),
398
- response_format='b64_json',
752
+ result = await _openai_client().images.generate(
753
+ prompt=prompt, model=model, response_format='b64_json', **model_kwargs
399
754
  )
400
755
  b64_str = result.data[0].b64_json
401
756
  b64_bytes = base64.b64decode(b64_str)
@@ -405,17 +760,19 @@ def image_generations(
405
760
 
406
761
 
407
762
  @image_generations.conditional_return_type
408
- def _(size: Optional[str] = None) -> pxt.ImageType:
409
- if size is None:
410
- return pxt.ImageType(size=(1024, 1024))
763
+ def _(model_kwargs: dict[str, Any] | None = None) -> ts.ImageType:
764
+ if model_kwargs is None or 'size' not in model_kwargs:
765
+ # default size is 1024x1024
766
+ return ts.ImageType(size=(1024, 1024))
767
+ size = model_kwargs['size']
411
768
  x_pos = size.find('x')
412
769
  if x_pos == -1:
413
- return pxt.ImageType()
770
+ return ts.ImageType()
414
771
  try:
415
772
  width, height = int(size[:x_pos]), int(size[x_pos + 1 :])
416
773
  except ValueError:
417
- return pxt.ImageType()
418
- return pxt.ImageType(size=(width, height))
774
+ return ts.ImageType()
775
+ return ts.ImageType(size=(width, height))
419
776
 
420
777
 
421
778
  #####################################
@@ -423,12 +780,16 @@ def _(size: Optional[str] = None) -> pxt.ImageType:
423
780
 
424
781
 
425
782
  @pxt.udf
426
- def moderations(input: str, *, model: Optional[str] = None) -> dict:
783
+ async def moderations(input: str, *, model: str = 'omni-moderation-latest') -> dict:
427
784
  """
428
785
  Classifies if text is potentially harmful.
429
786
 
430
787
  Equivalent to the OpenAI `moderations` API endpoint.
431
- For additional details, see: [https://platform.openai.com/docs/guides/moderation](https://platform.openai.com/docs/guides/moderation)
788
+ For additional details, see: <https://platform.openai.com/docs/guides/moderation>
789
+
790
+ Request throttling:
791
+ Applies the rate limit set in the config (section `openai.rate_limits`; use the model id as the key). If no rate
792
+ limit is configured, uses a default of 600 RPM.
432
793
 
433
794
  __Requirements:__
434
795
 
@@ -438,7 +799,7 @@ def moderations(input: str, *, model: Optional[str] = None) -> dict:
438
799
  input: Text to analyze with the moderations model.
439
800
  model: The model to use for moderations.
440
801
 
441
- For details on the other parameters, see: [https://platform.openai.com/docs/api-reference/moderations](https://platform.openai.com/docs/api-reference/moderations)
802
+ For details on the other parameters, see: <https://platform.openai.com/docs/api-reference/moderations>
442
803
 
443
804
  Returns:
444
805
  Details of the moderations results.
@@ -447,22 +808,49 @@ def moderations(input: str, *, model: Optional[str] = None) -> dict:
447
808
  Add a computed column that applies the model `text-moderation-stable` to an existing
448
809
  Pixeltable column `tbl.input` of the table `tbl`:
449
810
 
450
- >>> tbl['moderations'] = moderations(tbl.text, model='text-moderation-stable')
811
+ >>> tbl.add_computed_column(moderations=moderations(tbl.text, model='text-moderation-stable'))
451
812
  """
452
- result = _retry(_openai_client().moderations.create)(input=input, model=_opt(model))
813
+ result = await _openai_client().moderations.create(input=input, model=model)
453
814
  return result.dict()
454
815
 
455
816
 
456
- _T = TypeVar('_T')
817
+ @speech.resource_pool
818
+ @transcriptions.resource_pool
819
+ @translations.resource_pool
820
+ @image_generations.resource_pool
821
+ @moderations.resource_pool
822
+ def _(model: str) -> str:
823
+ return f'request-rate:openai:{model}'
457
824
 
458
825
 
459
- def _opt(arg: _T) -> Union[_T, 'openai.NotGiven']:
460
- import openai
461
- return arg if arg is not None else openai.NOT_GIVEN
826
+ @chat_completions.resource_pool
827
+ @vision.resource_pool
828
+ @embeddings.resource_pool
829
+ def _(model: str) -> str:
830
+ return _rate_limits_pool(model)
831
+
832
+
833
+ def invoke_tools(tools: Tools, response: exprs.Expr) -> exprs.InlineDict:
834
+ """Converts an OpenAI response dict to Pixeltable tool invocation format and calls `tools._invoke()`."""
835
+ return tools._invoke(_openai_response_to_pxt_tool_calls(response))
836
+
837
+
838
+ @pxt.udf
839
+ def _openai_response_to_pxt_tool_calls(response: dict) -> dict | None:
840
+ if 'tool_calls' not in response['choices'][0]['message'] or response['choices'][0]['message']['tool_calls'] is None:
841
+ return None
842
+ openai_tool_calls = response['choices'][0]['message']['tool_calls']
843
+ pxt_tool_calls: dict[str, list[dict[str, Any]]] = {}
844
+ for tool_call in openai_tool_calls:
845
+ tool_name = tool_call['function']['name']
846
+ if tool_name not in pxt_tool_calls:
847
+ pxt_tool_calls[tool_name] = []
848
+ pxt_tool_calls[tool_name].append({'args': json.loads(tool_call['function']['arguments'])})
849
+ return pxt_tool_calls
462
850
 
463
851
 
464
852
  __all__ = local_public_names(__name__)
465
853
 
466
854
 
467
- def __dir__():
855
+ def __dir__() -> list[str]:
468
856
  return __all__