pixeltable 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic.
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/insertable_table.py +3 -3
- pixeltable/catalog/table.py +2 -2
- pixeltable/catalog/table_version.py +3 -2
- pixeltable/catalog/view.py +1 -1
- pixeltable/dataframe.py +52 -27
- pixeltable/env.py +109 -4
- pixeltable/exec/__init__.py +1 -1
- pixeltable/exec/aggregation_node.py +3 -3
- pixeltable/exec/cache_prefetch_node.py +13 -7
- pixeltable/exec/component_iteration_node.py +3 -9
- pixeltable/exec/data_row_batch.py +17 -5
- pixeltable/exec/exec_node.py +32 -12
- pixeltable/exec/expr_eval/__init__.py +1 -0
- pixeltable/exec/expr_eval/evaluators.py +240 -0
- pixeltable/exec/expr_eval/expr_eval_node.py +408 -0
- pixeltable/exec/expr_eval/globals.py +113 -0
- pixeltable/exec/expr_eval/row_buffer.py +76 -0
- pixeltable/exec/expr_eval/schedulers.py +240 -0
- pixeltable/exec/in_memory_data_node.py +2 -2
- pixeltable/exec/row_update_node.py +14 -14
- pixeltable/exec/sql_node.py +2 -2
- pixeltable/exprs/column_ref.py +5 -1
- pixeltable/exprs/data_row.py +50 -40
- pixeltable/exprs/expr.py +57 -12
- pixeltable/exprs/function_call.py +54 -19
- pixeltable/exprs/inline_expr.py +12 -21
- pixeltable/exprs/literal.py +25 -8
- pixeltable/exprs/row_builder.py +25 -2
- pixeltable/func/aggregate_function.py +4 -0
- pixeltable/func/callable_function.py +54 -4
- pixeltable/func/expr_template_function.py +5 -1
- pixeltable/func/function.py +48 -7
- pixeltable/func/query_template_function.py +16 -7
- pixeltable/func/udf.py +7 -1
- pixeltable/functions/__init__.py +1 -1
- pixeltable/functions/anthropic.py +97 -21
- pixeltable/functions/gemini.py +2 -6
- pixeltable/functions/openai.py +219 -28
- pixeltable/globals.py +2 -3
- pixeltable/io/hf_datasets.py +1 -1
- pixeltable/io/label_studio.py +5 -5
- pixeltable/io/parquet.py +1 -1
- pixeltable/metadata/__init__.py +2 -1
- pixeltable/plan.py +24 -9
- pixeltable/store.py +6 -0
- pixeltable/type_system.py +73 -36
- pixeltable/utils/arrow.py +3 -8
- pixeltable/utils/console_output.py +41 -0
- pixeltable/utils/filecache.py +1 -1
- {pixeltable-0.3.0.dist-info → pixeltable-0.3.2.dist-info}/METADATA +4 -1
- {pixeltable-0.3.0.dist-info → pixeltable-0.3.2.dist-info}/RECORD +55 -49
- pixeltable/exec/expr_eval_node.py +0 -232
- {pixeltable-0.3.0.dist-info → pixeltable-0.3.2.dist-info}/LICENSE +0 -0
- {pixeltable-0.3.0.dist-info → pixeltable-0.3.2.dist-info}/WHEEL +0 -0
- {pixeltable-0.3.0.dist-info → pixeltable-0.3.2.dist-info}/entry_points.txt +0 -0
pixeltable/functions/openai.py
CHANGED
@@ -6,14 +6,18 @@ the [Working with OpenAI](https://pixeltable.readme.io/docs/working-with-openai)
 """
 
 import base64
+import datetime
 import io
 import json
+import logging
 import pathlib
+import re
 import uuid
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union, cast, Any, Type
 
-import numpy as np
 import PIL.Image
+import httpx
+import numpy as np
 import tenacity
 
 import pixeltable as pxt
@@ -24,15 +28,28 @@ from pixeltable.utils.code import local_public_names
 if TYPE_CHECKING:
     import openai
 
+_logger = logging.getLogger('pixeltable')
+
 
 @env.register_client('openai')
-def _(api_key: str) -> 'openai.OpenAI':
+def _(api_key: str) -> tuple['openai.OpenAI', 'openai.AsyncOpenAI']:
     import openai
-    return
+    return (
+        openai.OpenAI(api_key=api_key),
+        openai.AsyncOpenAI(
+            api_key=api_key,
+            # recommended to increase limits for async client to avoid connection errors
+            http_client=httpx.AsyncClient(limits=httpx.Limits(max_keepalive_connections=100, max_connections=500)),
+        )
+    )
 
 
 def _openai_client() -> 'openai.OpenAI':
-    return env.Env.get().get_client('openai')
+    return env.Env.get().get_client('openai')[0]
+
+
+def _async_openai_client() -> 'openai.AsyncOpenAI':
+    return env.Env.get().get_client('openai')[1]
 
 
 # Exponential backoff decorator using tenacity.
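Note: the factory registered for 'openai' now returns a (sync, async) client pair; the sync client continues to back the endpoints that remain synchronous, while the async client serves the new async UDFs below. A minimal standalone sketch of the same pattern (make_clients is an illustrative name, not part of pixeltable's API):

import httpx
import openai

def make_clients(api_key: str) -> tuple[openai.OpenAI, openai.AsyncOpenAI]:
    # a larger connection pool helps avoid connection errors under high request concurrency
    limits = httpx.Limits(max_keepalive_connections=100, max_connections=500)
    return (
        openai.OpenAI(api_key=api_key),
        openai.AsyncOpenAI(api_key=api_key, http_client=httpx.AsyncClient(limits=limits)),
    )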
@@ -47,13 +64,138 @@ def _retry(fn: Callable) -> Callable:
     )(fn)
 
 
+# models that share rate limits; see https://platform.openai.com/settings/organization/limits for details
+_shared_rate_limits = {
+    'gpt-4-turbo': [
+        'gpt-4-turbo',
+        'gpt-4-turbo-latest',
+        'gpt-4-turbo-2024-04-09',
+        'gpt-4-turbo-preview',
+        'gpt-4-0125-preview',
+        'gpt-4-1106-preview'
+    ],
+    'gpt-4o': [
+        'gpt-4o',
+        'gpt-4o-latest',
+        'gpt-4o-2024-05-13',
+        'gpt-4o-2024-08-06',
+        'gpt-4o-2024-11-20',
+        'gpt-4o-audio-preview',
+        'gpt-4o-audio-preview-2024-10-01',
+        'gpt-4o-audio-preview-2024-12-17'
+    ],
+    'gpt-4o-mini': [
+        'gpt-4o-mini',
+        'gpt-4o-mini-latest',
+        'gpt-4o-mini-2024-07-18',
+        'gpt-4o-mini-audio-preview',
+        'gpt-4o-mini-audio-preview-2024-12-17'
+    ],
+    'gpt-4o-mini-realtime-preview': [
+        'gpt-4o-mini-realtime-preview',
+        'gpt-4o-mini-realtime-preview-latest',
+        'gpt-4o-mini-realtime-preview-2024-12-17'
+    ]
+}
+
+
+def _resource_pool(model: str) -> str:
+    for model_family, models in _shared_rate_limits.items():
+        if model in models:
+            return f'rate-limits:openai:{model_family}'
+    return f'rate-limits:openai:{model}'
+
+
+class OpenAIRateLimitsInfo(env.RateLimitsInfo):
+    retryable_errors: tuple[Type[Exception], ...]
+
+    def __init__(self, get_request_resources: Callable[..., dict[str, int]]):
+        super().__init__(get_request_resources)
+        import openai
+        self.retryable_errors = (
+            # ConnectionError: we occasionally see this error when the AsyncConnectionPool is trying to close
+            # expired connections
+            # (AsyncConnectionPool._close_expired_connections() fails with ConnectionError when executing
+            # 'await connection.aclose()', which is potentially a bug in AsyncConnectionPool)
+            openai.APIConnectionError,
+
+            # the following errors are retryable according to OpenAI's API documentation
+            openai.RateLimitError,
+            openai.APITimeoutError,
+            openai.UnprocessableEntityError,
+            openai.InternalServerError,
+        )
+
+    def get_retry_delay(self, exc: Exception) -> Optional[float]:
+        import openai
+
+        if not isinstance(exc, self.retryable_errors):
+            return None
+        assert isinstance(exc, openai.APIError)
+        return 1.0
+
+
+# RE pattern for duration in '*-reset' headers;
+# examples: 1d2h3ms, 4m5.6s; # fractional seconds can be reported as 0.5s or 500ms
+_header_duration_pattern = re.compile(r'(?:(\d+)d)?(?:(\d+)h)?(?:(\d+)ms)|(?:(\d+)m)?(?:([\d.]+)s)?')
+
+
+def _parse_header_duration(duration_str):
+    match = _header_duration_pattern.match(duration_str)
+    if not match:
+        raise ValueError("Invalid duration format")
+
+    days = int(match.group(1) or 0)
+    hours = int(match.group(2) or 0)
+    milliseconds = int(match.group(3) or 0)
+    minutes = int(match.group(4) or 0)
+    seconds = float(match.group(5) or 0)
+
+    return datetime.timedelta(
+        days=days,
+        hours=hours,
+        minutes=minutes,
+        seconds=seconds,
+        milliseconds=milliseconds
+    )
+
+
+def _get_header_info(
+    headers: httpx.Headers, *, requests: bool = True, tokens: bool = True
+) -> tuple[Optional[tuple[int, int, datetime.datetime]], Optional[tuple[int, int, datetime.datetime]]]:
+    assert requests or tokens
+    now = datetime.datetime.now(tz=datetime.timezone.utc)
+
+    requests_info: Optional[tuple[int, int, datetime.datetime]] = None
+    if requests:
+        requests_limit_str = headers.get('x-ratelimit-limit-requests')
+        requests_limit = int(requests_limit_str) if requests_limit_str is not None else None
+        requests_remaining_str = headers.get('x-ratelimit-remaining-requests')
+        requests_remaining = int(requests_remaining_str) if requests_remaining_str is not None else None
+        requests_reset_str = headers.get('x-ratelimit-reset-requests')
+        requests_reset_ts = now + _parse_header_duration(requests_reset_str)
+        requests_info = (requests_limit, requests_remaining, requests_reset_ts)
+
+    tokens_info: Optional[tuple[int, int, datetime.datetime]] = None
+    if tokens:
+        tokens_limit_str = headers.get('x-ratelimit-limit-tokens')
+        tokens_limit = int(tokens_limit_str) if tokens_limit_str is not None else None
+        tokens_remaining_str = headers.get('x-ratelimit-remaining-tokens')
+        tokens_remaining = int(tokens_remaining_str) if tokens_remaining_str is not None else None
+        tokens_reset_str = headers.get('x-ratelimit-reset-tokens')
+        tokens_reset_ts = now + _parse_header_duration(tokens_reset_str)
+        tokens_info = (tokens_limit, tokens_remaining, tokens_reset_ts)
+
+    return requests_info, tokens_info
+
+
 #####################################
 # Audio Endpoints
 
 
 @pxt.udf
 def speech(
-
+    input: str, *, model: str, voice: str, response_format: Optional[str] = None, speed: Optional[float] = None
 ) -> pxt.Audio:
     """
     Generates audio from the input text.
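Note: scheduler state is keyed by resource pool, so model aliases that share an OpenAI rate limit (per the _shared_rate_limits table) map to a single pool, and the '*-reset' headers are parsed into timedeltas. A few illustrative values, derived from the definitions above:

_resource_pool('gpt-4o-2024-08-06')  # 'rate-limits:openai:gpt-4o' (alias of the gpt-4o family)
_resource_pool('gpt-3.5-turbo')      # 'rate-limits:openai:gpt-3.5-turbo' (no shared family: one pool per model)
_parse_header_duration('1m30s')      # datetime.timedelta(seconds=90)
_parse_header_duration('500ms')      # datetime.timedelta(milliseconds=500)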
@@ -176,8 +318,24 @@ def translations(
 # Chat Endpoints
 
 
+def _chat_completions_get_request_resources(
+    messages: list, max_tokens: Optional[int], n: Optional[int]
+) -> dict[str, int]:
+    completion_tokens = n * max_tokens
+
+    num_tokens = 0.0
+    for message in messages:
+        num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
+        for key, value in message.items():
+            num_tokens += len(value) / 4
+            if key == "name":  # if there's a name, the role is omitted
+                num_tokens -= 1  # role is always required and always 1 token
+    num_tokens += 2  # every reply is primed with <im_start>assistant
+    return {'requests': 1, 'tokens': int(num_tokens) + completion_tokens}
+
+
 @pxt.udf
-def chat_completions(
+async def chat_completions(
     messages: list,
     *,
     model: str,
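Note: the estimator above approximates prompt tokens at four characters per token plus per-message overhead, and reserves n * max_tokens for the completion (which requires concrete defaults for max_tokens and n, now 1024 and 1). A worked example, derived from the code above:

_chat_completions_get_request_resources(
    messages=[{'role': 'user', 'content': 'Hello, world!'}], max_tokens=1024, n=1
)
# {'requests': 1, 'tokens': 1034}:
#   4 (message overhead) + len('user')/4 + len('Hello, world!')/4 + 2 (reply primer) = 10.25 -> 10 prompt tokens,
#   plus 1 * 1024 reserved completion tokens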
@@ -185,8 +343,8 @@ def chat_completions(
     logit_bias: Optional[dict[str, int]] = None,
     logprobs: Optional[bool] = None,
     top_logprobs: Optional[int] = None,
-    max_tokens: Optional[int] =
-    n: Optional[int] =
+    max_tokens: Optional[int] = 1024,
+    n: Optional[int] = 1,
     presence_penalty: Optional[float] = None,
     response_format: Optional[dict] = None,
     seed: Optional[int] = None,
@@ -226,7 +384,6 @@ def chat_completions(
     ]
     tbl['response'] = chat_completions(messages, model='gpt-4o-mini')
     """
-
     if tools is not None:
         tools = [
             {
@@ -253,7 +410,13 @@ def chat_completions(
     if tool_choice is not None and not tool_choice['parallel_tool_calls']:
         extra_body = {'parallel_tool_calls': False}
 
-
+    # make sure the pool info exists prior to making the request
+    resource_pool = _resource_pool(model)
+    rate_limits_info = env.Env.get().get_resource_pool_info(
+        resource_pool, lambda: OpenAIRateLimitsInfo(_chat_completions_get_request_resources))
+
+    # cast(Any, ...): avoid mypy errors
+    result = await _async_openai_client().chat.completions.with_raw_response.create(
         messages=messages,
         model=model,
         frequency_penalty=_opt(frequency_penalty),
@@ -263,17 +426,22 @@ def chat_completions(
         max_tokens=_opt(max_tokens),
         n=_opt(n),
         presence_penalty=_opt(presence_penalty),
-        response_format=_opt(response_format),
+        response_format=_opt(cast(Any, response_format)),
         seed=_opt(seed),
         stop=_opt(stop),
         temperature=_opt(temperature),
         top_p=_opt(top_p),
-        tools=_opt(tools),
-        tool_choice=_opt(tool_choice_),
+        tools=_opt(cast(Any, tools)),
+        tool_choice=_opt(cast(Any, tool_choice_)),
         user=_opt(user),
+        timeout=10,
         extra_body=extra_body,
     )
-
+
+    requests_info, tokens_info = _get_header_info(result.headers)
+    rate_limits_info.record(requests=requests_info, tokens=tokens_info)
+
+    return json.loads(result.text)
 
 
 @pxt.udf
@@ -330,8 +498,13 @@ _embedding_dimensions_cache: dict[str, int] = {
 }
 
 
+def _embeddings_get_request_resources(input: list[str]) -> dict[str, int]:
+    input_len = sum(len(s) for s in input)
+    return {'requests': 1, 'tokens': int(input_len / 4)}
+
+
 @pxt.udf(batch_size=32)
-def embeddings(
+async def embeddings(
     input: Batch[str], *, model: str, dimensions: Optional[int] = None, user: Optional[str] = None
 ) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
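Note: the embeddings estimator applies the same four-characters-per-token heuristic across the whole batch. For example, per the definition above:

_embeddings_get_request_resources(['hello world', 'a second string'])
# {'requests': 1, 'tokens': 6}  -- (11 + 15) characters / 4 = 6.5 -> 6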
@@ -361,10 +534,16 @@ def embeddings(
 
     >>> tbl['embed'] = embeddings(tbl.text, model='text-embedding-3-small')
     """
-
+    _logger.debug(f'embeddings: batch_size={len(input)}')
+    resource_pool = _resource_pool(model)
+    rate_limits_info = env.Env.get().get_resource_pool_info(
+        resource_pool, lambda: OpenAIRateLimitsInfo(_embeddings_get_request_resources))
+    result = await _async_openai_client().embeddings.with_raw_response.create(
         input=input, model=model, dimensions=_opt(dimensions), user=_opt(user), encoding_format='float'
     )
-
+    requests_info, tokens_info = _get_header_info(result.headers)
+    rate_limits_info.record(requests=requests_info, tokens=tokens_info)
+    return [np.array(data['embedding'], dtype=np.float64) for data in json.loads(result.content)['data']]
 
 
 @embeddings.conditional_return_type
@@ -385,7 +564,7 @@ def _(model: str, dimensions: Optional[int] = None) -> pxt.ArrayType:
 def image_generations(
     prompt: str,
     *,
-    model:
+    model: str = 'dall-e-2',
     quality: Optional[str] = None,
     size: Optional[str] = None,
     style: Optional[str] = None,
@@ -441,7 +620,7 @@ def _(size: Optional[str] = None) -> pxt.ImageType:
     if x_pos == -1:
         return pxt.ImageType()
     try:
-        width, height = int(size[:x_pos]), int(size[x_pos + 1
+        width, height = int(size[:x_pos]), int(size[x_pos + 1:])
     except ValueError:
         return pxt.ImageType()
     return pxt.ImageType(size=(width, height))
@@ -452,7 +631,7 @@ def _(size: Optional[str] = None) -> pxt.ImageType:
 
 
 @pxt.udf
-def moderations(input: str, *, model:
+def moderations(input: str, *, model: str = 'omni-moderation-latest') -> dict:
     """
     Classifies if text is potentially harmful.
 
@@ -482,6 +661,18 @@ def moderations(input: str, *, model: Optional[str] = None) -> dict:
     return result.dict()
 
 
+# @speech.resource_pool
+# @transcriptions.resource_pool
+# @translations.resource_pool
+@chat_completions.resource_pool
+# @vision.resource_pool
+@embeddings.resource_pool
+# @image_generations.resource_pool
+# @moderations.resource_pool
+def _(model: str) -> str:
+    return _resource_pool(model)
+
+
 def invoke_tools(tools: Tools, response: exprs.Expr) -> exprs.InlineDict:
     """Converts an OpenAI response dict to Pixeltable tool invocation format and calls `tools._invoke()`."""
     return tools._invoke(_openai_response_to_pxt_tool_calls(response))
@@ -489,15 +680,15 @@ def invoke_tools(tools: Tools, response: exprs.Expr) -> exprs.InlineDict:
 
 @pxt.udf
 def _openai_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
+    if 'tool_calls' not in response['choices'][0]['message'] or response['choices'][0]['message']['tool_calls'] is None:
+        return None
     openai_tool_calls = response['choices'][0]['message']['tool_calls']
-
-
-    tool_call['function']['
-    'args': json.loads(tool_call['function']['arguments'])
-    }
-    for tool_call in openai_tool_calls
+    return {
+        tool_call['function']['name']: {
+            'args': json.loads(tool_call['function']['arguments'])
         }
-
+        for tool_call in openai_tool_calls
+    }
 
 
 _T = TypeVar('_T')
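Note: _openai_response_to_pxt_tool_calls now returns None for responses without tool calls (instead of failing on the missing key) and rebuilds the tool-call mapping as a dict comprehension. For a hypothetical response fragment, the underlying function maps:

{'choices': [{'message': {'tool_calls': [
    {'function': {'name': 'get_weather', 'arguments': '{"city": "SF"}'}}
]}}]}
# to: {'get_weather': {'args': {'city': 'SF'}}}
# a response whose message has no (or null) 'tool_calls' now yields None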
pixeltable/globals.py
CHANGED
@@ -606,8 +606,7 @@ def create_dir(path_str: str, if_exists: Literal['error', 'ignore', 'replace', '
         dir = catalog.Dir(dir_record.id, parent._id, path.name)
         cat.paths[path] = dir
         session.commit()
-
-        print(f'Created directory `{path_str}`.')
+        Env.get().console_logger.info(f'Created directory `{path_str}`.')
         return dir
 
 def drop_dir(path_str: str, force: bool = False, if_not_exists: Literal['error', 'ignore'] = 'error') -> None:
@@ -817,4 +816,4 @@ def configure_logging(
 
 
 def array(elements: Iterable) -> exprs.Expr:
-    return exprs.
+    return exprs.Expr.from_array(elements)
pixeltable/io/hf_datasets.py
CHANGED
@@ -13,7 +13,7 @@ from pixeltable import exceptions as excs
 if typing.TYPE_CHECKING:
     import datasets  # type: ignore[import-untyped]
 
-_logger = logging.getLogger(
+_logger = logging.getLogger('pixeltable')
 
 # use 100MB as the batch size limit for loading a huggingface dataset into pixeltable.
 # The primary goal is to bound memory use, regardless of dataset size.
pixeltable/io/label_studio.py
CHANGED
@@ -230,7 +230,7 @@ class LabelStudioProject(Project):
             self.project.create_predictions(predictions)
             tasks_created += 1
 
-
+        env.Env.get().console_logger.info(f'Created {tasks_created} new task(s) in {self}.')
 
         sync_status = SyncStatus(external_rows_created=tasks_created)
 
@@ -330,7 +330,7 @@ class LabelStudioProject(Project):
             if len(page) > 0:
                 self.project.import_tasks(page)
 
-
+        env.Env.get().console_logger.info(f'Created {tasks_created} new task(s) and updated {tasks_updated} existing task(s) in {self}.')
 
         sync_status = SyncStatus(external_rows_created=tasks_created, external_rows_updated=tasks_updated)
 
@@ -363,7 +363,7 @@ class LabelStudioProject(Project):
 
         if len(tasks_to_delete) > 0:
             self.project.delete_tasks(tasks_to_delete)
-
+            env.Env.get().console_logger.info(f'Deleted {len(tasks_to_delete)} tasks(s) in {self} that are no longer present in Pixeltable.')
 
         # Remove them from the `existing_tasks` dict so that future updates are applied correctly
         for rowid in deleted_rowids:
@@ -406,7 +406,7 @@ class LabelStudioProject(Project):
                 assert ancestor._base is not None
                 ancestor = ancestor._base
             update_status = ancestor.batch_update(updates)
-
+            env.Env.get().console_logger.info(f'Updated annotation(s) from {len(updates)} task(s) in {self}.')
             return SyncStatus(pxt_rows_updated=update_status.num_rows, num_excs=update_status.num_excs)
         else:
             return SyncStatus.empty()
@@ -529,7 +529,7 @@ class LabelStudioProject(Project):
         """
         title = self.project_title
        _label_studio_client().delete_project(self.project_id)
-
+        env.Env.get().console_logger.info(f'Deleted Label Studio project: {title}')
 
     def __eq__(self, other) -> bool:
         return isinstance(other, LabelStudioProject) and self.project_id == other.project_id
pixeltable/io/parquet.py
CHANGED
@@ -23,7 +23,7 @@ if typing.TYPE_CHECKING:
     import pyarrow as pa
     import pixeltable as pxt
 
-_logger = logging.getLogger(
+_logger = logging.getLogger('pixeltable')
 
 
 def _write_batch(value_batch: dict[str, deque], schema: pa.Schema, output_path: Path) -> None:
pixeltable/metadata/__init__.py
CHANGED
@@ -47,7 +47,8 @@ def upgrade_md(engine: sql.engine.Engine) -> None:
     while md_version < VERSION:
         if md_version not in converter_cbs:
             raise RuntimeError(f'No metadata converter for version {md_version}')
-
+        from pixeltable.env import Env
+        Env.get().console_logger.info(f'Converting metadata from version {md_version} to {md_version + 1}')
         converter_cbs[md_version](engine)
         md_version += 1
     # update system info
pixeltable/plan.py
CHANGED
@@ -5,6 +5,7 @@ import enum
 from typing import Any, Iterable, Optional, Sequence, Literal
 from uuid import UUID
 
+
 import sqlalchemy as sql
 
 import pixeltable as pxt
@@ -166,10 +167,13 @@ class Analyzer:
             raise excs.Error(
                 f'Invalid non-aggregate expression in aggregate query: {self.select_list[is_agg_output.index(False)]}')
 
-        # check that filter doesn't contain aggregates
+        # check that Where clause and filter doesn't contain aggregates
+        if self.sql_where_clause is not None:
+            if any(_is_agg_fn_call(e) for e in self.sql_where_clause.subexprs(expr_class=exprs.FunctionCall)):
+                raise excs.Error(f'where() cannot contain aggregate functions: {self.sql_where_clause}')
         if self.filter is not None:
             if any(_is_agg_fn_call(e) for e in self.filter.subexprs(expr_class=exprs.FunctionCall)):
-                raise excs.Error(f'
+                raise excs.Error(f'where() cannot contain aggregate functions: {self.filter}')
 
         # check that grouping exprs don't contain aggregates and can be expressed as SQL (we perform sort-based
         # aggregation and rely on the SqlScanNode returning data in the correct order)
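Note: the Analyzer now rejects aggregates in the Where clause on both paths, i.e. the predicate pushed down to SQL (sql_where_clause) as well as the Python-evaluated filter, rather than only the latter. A hedged sketch of the query shape this guards against (table and column names are hypothetical; assumes pixeltable's sum aggregate):

import pixeltable as pxt
t = pxt.get_table('sales')
t.where(pxt.functions.sum(t.amount) > 0).collect()
# raises: where() cannot contain aggregate functions: ...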
@@ -283,7 +287,8 @@ class Planner:
         computed_exprs = row_builder.output_exprs - row_builder.input_exprs
         if len(computed_exprs) > 0:
             # add an ExprEvalNode when there are exprs to compute
-            plan = exec.ExprEvalNode(
+            plan = exec.ExprEvalNode(
+                row_builder, computed_exprs, plan.output_exprs, input=plan, maintain_input_order=False)
 
         stored_col_info = row_builder.output_slot_idxs()
         stored_img_col_info = [info for info in stored_col_info if info.col.col_type.is_image_type()]
@@ -548,7 +553,7 @@ class Planner:
             plan = exec.ComponentIterationNode(target, plan)
         if len(view_output_exprs) > 0:
             plan = exec.ExprEvalNode(
-                row_builder, output_exprs=view_output_exprs, input_exprs=base_output_exprs,input=plan)
+                row_builder, output_exprs=view_output_exprs, input_exprs=base_output_exprs, input=plan)
 
         stored_img_col_info = [info for info in row_builder.output_slot_idxs() if info.col.col_type.is_image_type()]
         plan.set_stored_img_cols(stored_img_col_info)
@@ -750,10 +755,12 @@ class Planner:
             ctx.batch_size = 16
 
         # do aggregation in SQL if all agg exprs can be translated
-        if (
-
-
-
+        if (
+            sql_elements.contains_all(analyzer.select_list)
+            and sql_elements.contains_all(analyzer.grouping_exprs)
+            and isinstance(plan, exec.SqlNode)
+            and plan.to_cte() is not None
+        ):
             plan = exec.SqlAggregationNode(
                 row_builder, input=plan, select_list=analyzer.select_list, group_by_items=analyzer.group_by_clause)
         else:
@@ -770,14 +777,22 @@ class Planner:
             # we need an ExprEvalNode to evaluate the remaining output exprs
             plan = exec.ExprEvalNode(row_builder, eval_ctx.target_exprs, sql_exprs, input=plan)
         # we're returning everything to the user, so we might as well do it in a single batch
+        # TODO: return smaller batches in order to increase inter-ExecNode parallelism
         ctx.batch_size = 0
 
+        sql_node = plan.get_node(exec.SqlNode)
         if len(analyzer.order_by_clause) > 0:
             # we have the last SqlNode we created produce the ordering
-            sql_node = plan.get_node(exec.SqlNode)
             assert sql_node is not None
             sql_node.set_order_by(analyzer.order_by_clause)
 
+        # if we don't need an ordered result, tell the ExprEvalNode not to maintain input order (which allows us to
+        # return batches earlier)
+        if sql_node is not None and len(sql_node.order_by_clause) == 0:
+            expr_eval_node = plan.get_node(exec.ExprEvalNode)
+            if expr_eval_node is not None:
+                expr_eval_node.set_input_order(False)
+
         if limit is not None:
             plan.set_limit(limit)
 
pixeltable/store.py
CHANGED
@@ -229,6 +229,7 @@ class StoreBase:
             sql.exc.DBAPIError if there was a SQL error during execution
             excs.Error if on_error='abort' and there was an exception during row evaluation
         """
+        assert col.tbl.id == self.tbl_version.id
        num_excs = 0
         num_rows = 0
 
@@ -249,6 +250,7 @@ class StoreBase:
 
         try:
             # insert rows from exec_plan into temp table
+            # TODO: unify the table row construction logic with RowBuilder.create_table_row()
             for row_batch in exec_plan:
                 num_rows += len(row_batch)
                 tbl_rows: list[dict[str, Any]] = []
@@ -272,6 +274,10 @@ class StoreBase:
                         tbl_row[col.sa_errortype_col.name] = error_type
                         tbl_row[col.sa_errormsg_col.name] = error_msg
                     else:
+                        if col.col_type.is_image_type() and result_row.file_urls[value_expr_slot_idx] is None:
+                            # we have yet to store this image
+                            filepath = str(MediaStore.prepare_media_path(col.tbl.id, col.id, col.tbl.version))
+                            result_row.flush_img(value_expr_slot_idx, filepath)
                         val = result_row.get_stored_val(value_expr_slot_idx, col.sa_col.type)
                         if col.col_type.is_media_type():
                             val = self._move_tmp_media_file(val, col, result_row.pk[-1])