pixeltable 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/table_version.py +2 -1
- pixeltable/dataframe.py +52 -27
- pixeltable/env.py +92 -4
- pixeltable/exec/__init__.py +1 -1
- pixeltable/exec/aggregation_node.py +3 -3
- pixeltable/exec/cache_prefetch_node.py +13 -7
- pixeltable/exec/component_iteration_node.py +3 -9
- pixeltable/exec/data_row_batch.py +17 -5
- pixeltable/exec/exec_node.py +32 -12
- pixeltable/exec/expr_eval/__init__.py +1 -0
- pixeltable/exec/expr_eval/evaluators.py +245 -0
- pixeltable/exec/expr_eval/expr_eval_node.py +404 -0
- pixeltable/exec/expr_eval/globals.py +114 -0
- pixeltable/exec/expr_eval/row_buffer.py +76 -0
- pixeltable/exec/expr_eval/schedulers.py +232 -0
- pixeltable/exec/in_memory_data_node.py +2 -2
- pixeltable/exec/row_update_node.py +14 -14
- pixeltable/exec/sql_node.py +2 -2
- pixeltable/exprs/column_ref.py +5 -1
- pixeltable/exprs/data_row.py +50 -40
- pixeltable/exprs/expr.py +57 -12
- pixeltable/exprs/function_call.py +54 -19
- pixeltable/exprs/inline_expr.py +12 -21
- pixeltable/exprs/literal.py +25 -8
- pixeltable/exprs/row_builder.py +23 -0
- pixeltable/func/aggregate_function.py +4 -0
- pixeltable/func/callable_function.py +54 -4
- pixeltable/func/expr_template_function.py +5 -1
- pixeltable/func/function.py +48 -7
- pixeltable/func/query_template_function.py +16 -7
- pixeltable/func/udf.py +7 -1
- pixeltable/functions/__init__.py +1 -1
- pixeltable/functions/anthropic.py +95 -21
- pixeltable/functions/gemini.py +2 -6
- pixeltable/functions/openai.py +207 -28
- pixeltable/globals.py +1 -1
- pixeltable/plan.py +24 -9
- pixeltable/store.py +6 -0
- pixeltable/type_system.py +3 -3
- pixeltable/utils/arrow.py +3 -3
- {pixeltable-0.3.0.dist-info → pixeltable-0.3.1.dist-info}/METADATA +3 -1
- {pixeltable-0.3.0.dist-info → pixeltable-0.3.1.dist-info}/RECORD +46 -41
- pixeltable/exec/expr_eval_node.py +0 -232
- {pixeltable-0.3.0.dist-info → pixeltable-0.3.1.dist-info}/LICENSE +0 -0
- {pixeltable-0.3.0.dist-info → pixeltable-0.3.1.dist-info}/WHEEL +0 -0
- {pixeltable-0.3.0.dist-info → pixeltable-0.3.1.dist-info}/entry_points.txt +0 -0
pixeltable/functions/openai.py
CHANGED
|
@@ -6,14 +6,18 @@ the [Working with OpenAI](https://pixeltable.readme.io/docs/working-with-openai)
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
import base64
|
|
9
|
+
import datetime
|
|
9
10
|
import io
|
|
10
11
|
import json
|
|
12
|
+
import logging
|
|
11
13
|
import pathlib
|
|
14
|
+
import re
|
|
12
15
|
import uuid
|
|
13
|
-
from typing import TYPE_CHECKING,
|
|
16
|
+
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union, cast, Any, Type
|
|
14
17
|
|
|
15
|
-
import numpy as np
|
|
16
18
|
import PIL.Image
|
|
19
|
+
import httpx
|
|
20
|
+
import numpy as np
|
|
17
21
|
import tenacity
|
|
18
22
|
|
|
19
23
|
import pixeltable as pxt
|
|
@@ -24,15 +28,28 @@ from pixeltable.utils.code import local_public_names
|
|
|
24
28
|
if TYPE_CHECKING:
|
|
25
29
|
import openai
|
|
26
30
|
|
|
31
|
+
_logger = logging.getLogger('pixeltable')
|
|
32
|
+
|
|
27
33
|
|
|
28
34
|
@env.register_client('openai')
|
|
29
|
-
def _(api_key: str) -> 'openai.OpenAI':
|
|
35
|
+
def _(api_key: str) -> tuple['openai.OpenAI', 'openai.AsyncOpenAI']:
|
|
30
36
|
import openai
|
|
31
|
-
return
|
|
37
|
+
return (
|
|
38
|
+
openai.OpenAI(api_key=api_key),
|
|
39
|
+
openai.AsyncOpenAI(
|
|
40
|
+
api_key=api_key,
|
|
41
|
+
# recommended to increase limits for async client to avoid connection errors
|
|
42
|
+
http_client=httpx.AsyncClient(limits=httpx.Limits(max_keepalive_connections=100, max_connections=500)),
|
|
43
|
+
)
|
|
44
|
+
)
|
|
32
45
|
|
|
33
46
|
|
|
34
47
|
def _openai_client() -> 'openai.OpenAI':
|
|
35
|
-
return env.Env.get().get_client('openai')
|
|
48
|
+
return env.Env.get().get_client('openai')[0]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _async_openai_client() -> 'openai.AsyncOpenAI':
|
|
52
|
+
return env.Env.get().get_client('openai')[1]
|
|
36
53
|
|
|
37
54
|
|
|
38
55
|
# Exponential backoff decorator using tenacity.
|
|
@@ -47,13 +64,128 @@ def _retry(fn: Callable) -> Callable:
|
|
|
47
64
|
)(fn)
|
|
48
65
|
|
|
49
66
|
|
|
67
|
+
# models that share rate limits; see https://platform.openai.com/settings/organization/limits for details
|
|
68
|
+
_shared_rate_limits = {
|
|
69
|
+
'gpt-4-turbo': [
|
|
70
|
+
'gpt-4-turbo',
|
|
71
|
+
'gpt-4-turbo-latest',
|
|
72
|
+
'gpt-4-turbo-2024-04-09',
|
|
73
|
+
'gpt-4-turbo-preview',
|
|
74
|
+
'gpt-4-0125-preview',
|
|
75
|
+
'gpt-4-1106-preview'
|
|
76
|
+
],
|
|
77
|
+
'gpt-4o': [
|
|
78
|
+
'gpt-4o',
|
|
79
|
+
'gpt-4o-latest',
|
|
80
|
+
'gpt-4o-2024-05-13',
|
|
81
|
+
'gpt-4o-2024-08-06',
|
|
82
|
+
'gpt-4o-2024-11-20',
|
|
83
|
+
'gpt-4o-audio-preview',
|
|
84
|
+
'gpt-4o-audio-preview-2024-10-01',
|
|
85
|
+
'gpt-4o-audio-preview-2024-12-17'
|
|
86
|
+
],
|
|
87
|
+
'gpt-4o-mini': [
|
|
88
|
+
'gpt-4o-mini',
|
|
89
|
+
'gpt-4o-mini-latest',
|
|
90
|
+
'gpt-4o-mini-2024-07-18',
|
|
91
|
+
'gpt-4o-mini-audio-preview',
|
|
92
|
+
'gpt-4o-mini-audio-preview-2024-12-17'
|
|
93
|
+
],
|
|
94
|
+
'gpt-4o-mini-realtime-preview': [
|
|
95
|
+
'gpt-4o-mini-realtime-preview',
|
|
96
|
+
'gpt-4o-mini-realtime-preview-latest',
|
|
97
|
+
'gpt-4o-mini-realtime-preview-2024-12-17'
|
|
98
|
+
]
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _resource_pool(model: str) -> str:
|
|
103
|
+
for model_family, models in _shared_rate_limits.items():
|
|
104
|
+
if model in models:
|
|
105
|
+
return f'rate-limits:openai:{model_family}'
|
|
106
|
+
return f'rate-limits:openai:{model}'
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class OpenAIRateLimitsInfo(env.RateLimitsInfo):
|
|
110
|
+
retryable_errors: tuple[Type[Exception], ...]
|
|
111
|
+
|
|
112
|
+
def __init__(self, get_request_resources: Callable[..., dict[str, int]]):
|
|
113
|
+
super().__init__(get_request_resources)
|
|
114
|
+
import openai
|
|
115
|
+
self.retryable_errors = (
|
|
116
|
+
openai.RateLimitError, openai.APITimeoutError, openai.UnprocessableEntityError, openai.InternalServerError
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
def get_retry_delay(self, exc: Exception) -> Optional[float]:
|
|
120
|
+
import openai
|
|
121
|
+
|
|
122
|
+
if not isinstance(exc, self.retryable_errors):
|
|
123
|
+
return None
|
|
124
|
+
assert isinstance(exc, openai.APIError)
|
|
125
|
+
return 1.0
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# RE pattern for duration in '*-reset' headers;
|
|
129
|
+
# examples: 1d2h3ms, 4m5.6s; # fractional seconds can be reported as 0.5s or 500ms
|
|
130
|
+
_header_duration_pattern = re.compile(r'(?:(\d+)d)?(?:(\d+)h)?(?:(\d+)ms)|(?:(\d+)m)?(?:([\d.]+)s)?')
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _parse_header_duration(duration_str):
|
|
134
|
+
match = _header_duration_pattern.match(duration_str)
|
|
135
|
+
if not match:
|
|
136
|
+
raise ValueError("Invalid duration format")
|
|
137
|
+
|
|
138
|
+
days = int(match.group(1) or 0)
|
|
139
|
+
hours = int(match.group(2) or 0)
|
|
140
|
+
milliseconds = int(match.group(3) or 0)
|
|
141
|
+
minutes = int(match.group(4) or 0)
|
|
142
|
+
seconds = float(match.group(5) or 0)
|
|
143
|
+
|
|
144
|
+
return datetime.timedelta(
|
|
145
|
+
days=days,
|
|
146
|
+
hours=hours,
|
|
147
|
+
minutes=minutes,
|
|
148
|
+
seconds=seconds,
|
|
149
|
+
milliseconds=milliseconds
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _get_header_info(
|
|
154
|
+
headers: httpx.Headers, *, requests: bool = True, tokens: bool = True
|
|
155
|
+
) -> tuple[Optional[tuple[int, int, datetime.datetime]], Optional[tuple[int, int, datetime.datetime]]]:
|
|
156
|
+
assert requests or tokens
|
|
157
|
+
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
|
158
|
+
|
|
159
|
+
requests_info: Optional[tuple[int, int, datetime.datetime]] = None
|
|
160
|
+
if requests:
|
|
161
|
+
requests_limit_str = headers.get('x-ratelimit-limit-requests')
|
|
162
|
+
requests_limit = int(requests_limit_str) if requests_limit_str is not None else None
|
|
163
|
+
requests_remaining_str = headers.get('x-ratelimit-remaining-requests')
|
|
164
|
+
requests_remaining = int(requests_remaining_str) if requests_remaining_str is not None else None
|
|
165
|
+
requests_reset_str = headers.get('x-ratelimit-reset-requests')
|
|
166
|
+
requests_reset_ts = now + _parse_header_duration(requests_reset_str)
|
|
167
|
+
requests_info = (requests_limit, requests_remaining, requests_reset_ts)
|
|
168
|
+
|
|
169
|
+
tokens_info: Optional[tuple[int, int, datetime.datetime]] = None
|
|
170
|
+
if tokens:
|
|
171
|
+
tokens_limit_str = headers.get('x-ratelimit-limit-tokens')
|
|
172
|
+
tokens_limit = int(tokens_limit_str) if tokens_limit_str is not None else None
|
|
173
|
+
tokens_remaining_str = headers.get('x-ratelimit-remaining-tokens')
|
|
174
|
+
tokens_remaining = int(tokens_remaining_str) if tokens_remaining_str is not None else None
|
|
175
|
+
tokens_reset_str = headers.get('x-ratelimit-reset-tokens')
|
|
176
|
+
tokens_reset_ts = now + _parse_header_duration(tokens_reset_str)
|
|
177
|
+
tokens_info = (tokens_limit, tokens_remaining, tokens_reset_ts)
|
|
178
|
+
|
|
179
|
+
return requests_info, tokens_info
|
|
180
|
+
|
|
181
|
+
|
|
50
182
|
#####################################
|
|
51
183
|
# Audio Endpoints
|
|
52
184
|
|
|
53
185
|
|
|
54
186
|
@pxt.udf
|
|
55
187
|
def speech(
|
|
56
|
-
|
|
188
|
+
input: str, *, model: str, voice: str, response_format: Optional[str] = None, speed: Optional[float] = None
|
|
57
189
|
) -> pxt.Audio:
|
|
58
190
|
"""
|
|
59
191
|
Generates audio from the input text.
|
|
@@ -176,8 +308,24 @@ def translations(
|
|
|
176
308
|
# Chat Endpoints
|
|
177
309
|
|
|
178
310
|
|
|
311
|
+
def _chat_completions_get_request_resources(
|
|
312
|
+
messages: list, max_tokens: Optional[int], n: Optional[int]
|
|
313
|
+
) -> dict[str, int]:
|
|
314
|
+
completion_tokens = n * max_tokens
|
|
315
|
+
|
|
316
|
+
num_tokens = 0.0
|
|
317
|
+
for message in messages:
|
|
318
|
+
num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
|
|
319
|
+
for key, value in message.items():
|
|
320
|
+
num_tokens += len(value) / 4
|
|
321
|
+
if key == "name": # if there's a name, the role is omitted
|
|
322
|
+
num_tokens -= 1 # role is always required and always 1 token
|
|
323
|
+
num_tokens += 2 # every reply is primed with <im_start>assistant
|
|
324
|
+
return {'requests': 1, 'tokens': int(num_tokens) + completion_tokens}
|
|
325
|
+
|
|
326
|
+
|
|
179
327
|
@pxt.udf
|
|
180
|
-
def chat_completions(
|
|
328
|
+
async def chat_completions(
|
|
181
329
|
messages: list,
|
|
182
330
|
*,
|
|
183
331
|
model: str,
|
|
@@ -185,8 +333,8 @@ def chat_completions(
|
|
|
185
333
|
logit_bias: Optional[dict[str, int]] = None,
|
|
186
334
|
logprobs: Optional[bool] = None,
|
|
187
335
|
top_logprobs: Optional[int] = None,
|
|
188
|
-
max_tokens: Optional[int] =
|
|
189
|
-
n: Optional[int] =
|
|
336
|
+
max_tokens: Optional[int] = 1024,
|
|
337
|
+
n: Optional[int] = 1,
|
|
190
338
|
presence_penalty: Optional[float] = None,
|
|
191
339
|
response_format: Optional[dict] = None,
|
|
192
340
|
seed: Optional[int] = None,
|
|
@@ -226,7 +374,6 @@ def chat_completions(
|
|
|
226
374
|
]
|
|
227
375
|
tbl['response'] = chat_completions(messages, model='gpt-4o-mini')
|
|
228
376
|
"""
|
|
229
|
-
|
|
230
377
|
if tools is not None:
|
|
231
378
|
tools = [
|
|
232
379
|
{
|
|
@@ -253,7 +400,8 @@ def chat_completions(
|
|
|
253
400
|
if tool_choice is not None and not tool_choice['parallel_tool_calls']:
|
|
254
401
|
extra_body = {'parallel_tool_calls': False}
|
|
255
402
|
|
|
256
|
-
|
|
403
|
+
# cast(Any, ...): avoid mypy errors
|
|
404
|
+
result = await _async_openai_client().chat.completions.with_raw_response.create(
|
|
257
405
|
messages=messages,
|
|
258
406
|
model=model,
|
|
259
407
|
frequency_penalty=_opt(frequency_penalty),
|
|
@@ -263,17 +411,25 @@ def chat_completions(
|
|
|
263
411
|
max_tokens=_opt(max_tokens),
|
|
264
412
|
n=_opt(n),
|
|
265
413
|
presence_penalty=_opt(presence_penalty),
|
|
266
|
-
response_format=_opt(response_format),
|
|
414
|
+
response_format=_opt(cast(Any, response_format)),
|
|
267
415
|
seed=_opt(seed),
|
|
268
416
|
stop=_opt(stop),
|
|
269
417
|
temperature=_opt(temperature),
|
|
270
418
|
top_p=_opt(top_p),
|
|
271
|
-
tools=_opt(tools),
|
|
272
|
-
tool_choice=_opt(tool_choice_),
|
|
419
|
+
tools=_opt(cast(Any, tools)),
|
|
420
|
+
tool_choice=_opt(cast(Any, tool_choice_)),
|
|
273
421
|
user=_opt(user),
|
|
422
|
+
timeout=10,
|
|
274
423
|
extra_body=extra_body,
|
|
275
424
|
)
|
|
276
|
-
|
|
425
|
+
|
|
426
|
+
resource_pool = _resource_pool(model)
|
|
427
|
+
requests_info, tokens_info = _get_header_info(result.headers)
|
|
428
|
+
rate_limits_info = env.Env.get().get_resource_pool_info(resource_pool, lambda: OpenAIRateLimitsInfo(
|
|
429
|
+
_chat_completions_get_request_resources))
|
|
430
|
+
rate_limits_info.record(requests=requests_info, tokens=tokens_info)
|
|
431
|
+
|
|
432
|
+
return json.loads(result.text)
|
|
277
433
|
|
|
278
434
|
|
|
279
435
|
@pxt.udf
|
|
@@ -330,8 +486,13 @@ _embedding_dimensions_cache: dict[str, int] = {
|
|
|
330
486
|
}
|
|
331
487
|
|
|
332
488
|
|
|
489
|
+
def _embeddings_get_request_resources(input: list[str]) -> dict[str, int]:
|
|
490
|
+
input_len = sum(len(s) for s in input)
|
|
491
|
+
return {'requests': 1, 'tokens': int(input_len / 4)}
|
|
492
|
+
|
|
493
|
+
|
|
333
494
|
@pxt.udf(batch_size=32)
|
|
334
|
-
def embeddings(
|
|
495
|
+
async def embeddings(
|
|
335
496
|
input: Batch[str], *, model: str, dimensions: Optional[int] = None, user: Optional[str] = None
|
|
336
497
|
) -> Batch[pxt.Array[(None,), pxt.Float]]:
|
|
337
498
|
"""
|
|
@@ -361,10 +522,16 @@ def embeddings(
|
|
|
361
522
|
|
|
362
523
|
>>> tbl['embed'] = embeddings(tbl.text, model='text-embedding-3-small')
|
|
363
524
|
"""
|
|
364
|
-
|
|
525
|
+
_logger.debug(f'embeddings: batch_size={len(input)}')
|
|
526
|
+
result = await _async_openai_client().embeddings.with_raw_response.create(
|
|
365
527
|
input=input, model=model, dimensions=_opt(dimensions), user=_opt(user), encoding_format='float'
|
|
366
528
|
)
|
|
367
|
-
|
|
529
|
+
resource_pool = _resource_pool(model)
|
|
530
|
+
requests_info, tokens_info = _get_header_info(result.headers)
|
|
531
|
+
rate_limits_info = env.Env.get().get_resource_pool_info(
|
|
532
|
+
resource_pool, lambda: OpenAIRateLimitsInfo(_embeddings_get_request_resources))
|
|
533
|
+
rate_limits_info.record(requests=requests_info, tokens=tokens_info)
|
|
534
|
+
return [np.array(data['embedding'], dtype=np.float64) for data in json.loads(result.content)['data']]
|
|
368
535
|
|
|
369
536
|
|
|
370
537
|
@embeddings.conditional_return_type
|
|
@@ -385,7 +552,7 @@ def _(model: str, dimensions: Optional[int] = None) -> pxt.ArrayType:
|
|
|
385
552
|
def image_generations(
|
|
386
553
|
prompt: str,
|
|
387
554
|
*,
|
|
388
|
-
model:
|
|
555
|
+
model: str = 'dall-e-2',
|
|
389
556
|
quality: Optional[str] = None,
|
|
390
557
|
size: Optional[str] = None,
|
|
391
558
|
style: Optional[str] = None,
|
|
@@ -441,7 +608,7 @@ def _(size: Optional[str] = None) -> pxt.ImageType:
|
|
|
441
608
|
if x_pos == -1:
|
|
442
609
|
return pxt.ImageType()
|
|
443
610
|
try:
|
|
444
|
-
width, height = int(size[:x_pos]), int(size[x_pos + 1
|
|
611
|
+
width, height = int(size[:x_pos]), int(size[x_pos + 1:])
|
|
445
612
|
except ValueError:
|
|
446
613
|
return pxt.ImageType()
|
|
447
614
|
return pxt.ImageType(size=(width, height))
|
|
@@ -452,7 +619,7 @@ def _(size: Optional[str] = None) -> pxt.ImageType:
|
|
|
452
619
|
|
|
453
620
|
|
|
454
621
|
@pxt.udf
|
|
455
|
-
def moderations(input: str, *, model:
|
|
622
|
+
def moderations(input: str, *, model: str = 'omni-moderation-latest') -> dict:
|
|
456
623
|
"""
|
|
457
624
|
Classifies if text is potentially harmful.
|
|
458
625
|
|
|
@@ -482,6 +649,18 @@ def moderations(input: str, *, model: Optional[str] = None) -> dict:
|
|
|
482
649
|
return result.dict()
|
|
483
650
|
|
|
484
651
|
|
|
652
|
+
# @speech.resource_pool
|
|
653
|
+
# @transcriptions.resource_pool
|
|
654
|
+
# @translations.resource_pool
|
|
655
|
+
@chat_completions.resource_pool
|
|
656
|
+
# @vision.resource_pool
|
|
657
|
+
@embeddings.resource_pool
|
|
658
|
+
# @image_generations.resource_pool
|
|
659
|
+
# @moderations.resource_pool
|
|
660
|
+
def _(model: str) -> str:
|
|
661
|
+
return _resource_pool(model)
|
|
662
|
+
|
|
663
|
+
|
|
485
664
|
def invoke_tools(tools: Tools, response: exprs.Expr) -> exprs.InlineDict:
|
|
486
665
|
"""Converts an OpenAI response dict to Pixeltable tool invocation format and calls `tools._invoke()`."""
|
|
487
666
|
return tools._invoke(_openai_response_to_pxt_tool_calls(response))
|
|
@@ -489,15 +668,15 @@ def invoke_tools(tools: Tools, response: exprs.Expr) -> exprs.InlineDict:
|
|
|
489
668
|
|
|
490
669
|
@pxt.udf
|
|
491
670
|
def _openai_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
|
|
671
|
+
if 'tool_calls' not in response['choices'][0]['message'] or response['choices'][0]['message']['tool_calls'] is None:
|
|
672
|
+
return None
|
|
492
673
|
openai_tool_calls = response['choices'][0]['message']['tool_calls']
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
tool_call['function']['
|
|
496
|
-
'args': json.loads(tool_call['function']['arguments'])
|
|
497
|
-
}
|
|
498
|
-
for tool_call in openai_tool_calls
|
|
674
|
+
return {
|
|
675
|
+
tool_call['function']['name']: {
|
|
676
|
+
'args': json.loads(tool_call['function']['arguments'])
|
|
499
677
|
}
|
|
500
|
-
|
|
678
|
+
for tool_call in openai_tool_calls
|
|
679
|
+
}
|
|
501
680
|
|
|
502
681
|
|
|
503
682
|
_T = TypeVar('_T')
|
pixeltable/globals.py
CHANGED
pixeltable/plan.py
CHANGED
|
@@ -5,6 +5,7 @@ import enum
|
|
|
5
5
|
from typing import Any, Iterable, Optional, Sequence, Literal
|
|
6
6
|
from uuid import UUID
|
|
7
7
|
|
|
8
|
+
|
|
8
9
|
import sqlalchemy as sql
|
|
9
10
|
|
|
10
11
|
import pixeltable as pxt
|
|
@@ -166,10 +167,13 @@ class Analyzer:
|
|
|
166
167
|
raise excs.Error(
|
|
167
168
|
f'Invalid non-aggregate expression in aggregate query: {self.select_list[is_agg_output.index(False)]}')
|
|
168
169
|
|
|
169
|
-
# check that filter doesn't contain aggregates
|
|
170
|
+
# check that Where clause and filter doesn't contain aggregates
|
|
171
|
+
if self.sql_where_clause is not None:
|
|
172
|
+
if any(_is_agg_fn_call(e) for e in self.sql_where_clause.subexprs(expr_class=exprs.FunctionCall)):
|
|
173
|
+
raise excs.Error(f'where() cannot contain aggregate functions: {self.sql_where_clause}')
|
|
170
174
|
if self.filter is not None:
|
|
171
175
|
if any(_is_agg_fn_call(e) for e in self.filter.subexprs(expr_class=exprs.FunctionCall)):
|
|
172
|
-
raise excs.Error(f'
|
|
176
|
+
raise excs.Error(f'where() cannot contain aggregate functions: {self.filter}')
|
|
173
177
|
|
|
174
178
|
# check that grouping exprs don't contain aggregates and can be expressed as SQL (we perform sort-based
|
|
175
179
|
# aggregation and rely on the SqlScanNode returning data in the correct order)
|
|
@@ -283,7 +287,8 @@ class Planner:
|
|
|
283
287
|
computed_exprs = row_builder.output_exprs - row_builder.input_exprs
|
|
284
288
|
if len(computed_exprs) > 0:
|
|
285
289
|
# add an ExprEvalNode when there are exprs to compute
|
|
286
|
-
plan = exec.ExprEvalNode(
|
|
290
|
+
plan = exec.ExprEvalNode(
|
|
291
|
+
row_builder, computed_exprs, plan.output_exprs, input=plan, maintain_input_order=False)
|
|
287
292
|
|
|
288
293
|
stored_col_info = row_builder.output_slot_idxs()
|
|
289
294
|
stored_img_col_info = [info for info in stored_col_info if info.col.col_type.is_image_type()]
|
|
@@ -548,7 +553,7 @@ class Planner:
|
|
|
548
553
|
plan = exec.ComponentIterationNode(target, plan)
|
|
549
554
|
if len(view_output_exprs) > 0:
|
|
550
555
|
plan = exec.ExprEvalNode(
|
|
551
|
-
row_builder, output_exprs=view_output_exprs, input_exprs=base_output_exprs,input=plan)
|
|
556
|
+
row_builder, output_exprs=view_output_exprs, input_exprs=base_output_exprs, input=plan)
|
|
552
557
|
|
|
553
558
|
stored_img_col_info = [info for info in row_builder.output_slot_idxs() if info.col.col_type.is_image_type()]
|
|
554
559
|
plan.set_stored_img_cols(stored_img_col_info)
|
|
@@ -750,10 +755,12 @@ class Planner:
|
|
|
750
755
|
ctx.batch_size = 16
|
|
751
756
|
|
|
752
757
|
# do aggregation in SQL if all agg exprs can be translated
|
|
753
|
-
if (
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
758
|
+
if (
|
|
759
|
+
sql_elements.contains_all(analyzer.select_list)
|
|
760
|
+
and sql_elements.contains_all(analyzer.grouping_exprs)
|
|
761
|
+
and isinstance(plan, exec.SqlNode)
|
|
762
|
+
and plan.to_cte() is not None
|
|
763
|
+
):
|
|
757
764
|
plan = exec.SqlAggregationNode(
|
|
758
765
|
row_builder, input=plan, select_list=analyzer.select_list, group_by_items=analyzer.group_by_clause)
|
|
759
766
|
else:
|
|
@@ -770,14 +777,22 @@ class Planner:
|
|
|
770
777
|
# we need an ExprEvalNode to evaluate the remaining output exprs
|
|
771
778
|
plan = exec.ExprEvalNode(row_builder, eval_ctx.target_exprs, sql_exprs, input=plan)
|
|
772
779
|
# we're returning everything to the user, so we might as well do it in a single batch
|
|
780
|
+
# TODO: return smaller batches in order to increase inter-ExecNode parallelism
|
|
773
781
|
ctx.batch_size = 0
|
|
774
782
|
|
|
783
|
+
sql_node = plan.get_node(exec.SqlNode)
|
|
775
784
|
if len(analyzer.order_by_clause) > 0:
|
|
776
785
|
# we have the last SqlNode we created produce the ordering
|
|
777
|
-
sql_node = plan.get_node(exec.SqlNode)
|
|
778
786
|
assert sql_node is not None
|
|
779
787
|
sql_node.set_order_by(analyzer.order_by_clause)
|
|
780
788
|
|
|
789
|
+
# if we don't need an ordered result, tell the ExprEvalNode not to maintain input order (which allows us to
|
|
790
|
+
# return batches earlier)
|
|
791
|
+
if sql_node is not None and len(sql_node.order_by_clause) == 0:
|
|
792
|
+
expr_eval_node = plan.get_node(exec.ExprEvalNode)
|
|
793
|
+
if expr_eval_node is not None:
|
|
794
|
+
expr_eval_node.set_input_order(False)
|
|
795
|
+
|
|
781
796
|
if limit is not None:
|
|
782
797
|
plan.set_limit(limit)
|
|
783
798
|
|
pixeltable/store.py
CHANGED
|
@@ -229,6 +229,7 @@ class StoreBase:
|
|
|
229
229
|
sql.exc.DBAPIError if there was a SQL error during execution
|
|
230
230
|
excs.Error if on_error='abort' and there was an exception during row evaluation
|
|
231
231
|
"""
|
|
232
|
+
assert col.tbl.id == self.tbl_version.id
|
|
232
233
|
num_excs = 0
|
|
233
234
|
num_rows = 0
|
|
234
235
|
|
|
@@ -249,6 +250,7 @@ class StoreBase:
|
|
|
249
250
|
|
|
250
251
|
try:
|
|
251
252
|
# insert rows from exec_plan into temp table
|
|
253
|
+
# TODO: unify the table row construction logic with RowBuilder.create_table_row()
|
|
252
254
|
for row_batch in exec_plan:
|
|
253
255
|
num_rows += len(row_batch)
|
|
254
256
|
tbl_rows: list[dict[str, Any]] = []
|
|
@@ -272,6 +274,10 @@ class StoreBase:
|
|
|
272
274
|
tbl_row[col.sa_errortype_col.name] = error_type
|
|
273
275
|
tbl_row[col.sa_errormsg_col.name] = error_msg
|
|
274
276
|
else:
|
|
277
|
+
if col.col_type.is_image_type() and result_row.file_urls[value_expr_slot_idx] is None:
|
|
278
|
+
# we have yet to store this image
|
|
279
|
+
filepath = str(MediaStore.prepare_media_path(col.tbl.id, col.id, col.tbl.version))
|
|
280
|
+
result_row.flush_img(value_expr_slot_idx, filepath)
|
|
275
281
|
val = result_row.get_stored_val(value_expr_slot_idx, col.sa_col.type)
|
|
276
282
|
if col.col_type.is_media_type():
|
|
277
283
|
val = self._move_tmp_media_file(val, col, result_row.pk[-1])
|
pixeltable/type_system.py
CHANGED
|
@@ -246,8 +246,8 @@ class ColumnType:
|
|
|
246
246
|
col_type = ArrayType.from_literal(val, nullable=nullable)
|
|
247
247
|
if col_type is not None:
|
|
248
248
|
return col_type
|
|
249
|
-
|
|
250
|
-
if isinstance(val,
|
|
249
|
+
# this could still be json-serializable
|
|
250
|
+
if isinstance(val, (list, tuple, dict, np.ndarray, pydantic.BaseModel)):
|
|
251
251
|
try:
|
|
252
252
|
JsonType().validate_literal(val)
|
|
253
253
|
return JsonType(nullable=nullable)
|
|
@@ -866,7 +866,7 @@ class ArrayType(ColumnType):
|
|
|
866
866
|
continue
|
|
867
867
|
if n1 != n2:
|
|
868
868
|
return False
|
|
869
|
-
return val.dtype
|
|
869
|
+
return np.issubdtype(val.dtype, self.numpy_dtype())
|
|
870
870
|
|
|
871
871
|
def _to_json_schema(self) -> dict[str, Any]:
|
|
872
872
|
return {
|
pixeltable/utils/arrow.py
CHANGED
|
@@ -75,7 +75,7 @@ def to_arrow_schema(pixeltable_schema: dict[str, Any]) -> pa.Schema:
|
|
|
75
75
|
return pa.schema((name, to_arrow_type(typ)) for name, typ in pixeltable_schema.items()) # type: ignore[misc]
|
|
76
76
|
|
|
77
77
|
|
|
78
|
-
def to_pydict(batch: pa.RecordBatch) -> dict[str, Union[list, np.ndarray]]:
|
|
78
|
+
def to_pydict(batch: Union[pa.Table, pa.RecordBatch]) -> dict[str, Union[list, np.ndarray]]:
|
|
79
79
|
"""Convert a RecordBatch to a dictionary of lists, unlike pa.lib.RecordBatch.to_pydict,
|
|
80
80
|
this function will not convert numpy arrays to lists, and will preserve the original numpy dtype.
|
|
81
81
|
"""
|
|
@@ -84,7 +84,7 @@ def to_pydict(batch: pa.RecordBatch) -> dict[str, Union[list, np.ndarray]]:
|
|
|
84
84
|
col = batch.column(k)
|
|
85
85
|
if isinstance(col.type, pa.FixedShapeTensorType):
|
|
86
86
|
# treat array columns as numpy arrays to easily preserve numpy type
|
|
87
|
-
out[name] = col.to_numpy(zero_copy_only=False)
|
|
87
|
+
out[name] = col.to_numpy(zero_copy_only=False) # type: ignore[call-arg]
|
|
88
88
|
else:
|
|
89
89
|
# for the rest, use pydict to preserve python types
|
|
90
90
|
out[name] = col.to_pylist()
|
|
@@ -92,7 +92,7 @@ def to_pydict(batch: pa.RecordBatch) -> dict[str, Union[list, np.ndarray]]:
|
|
|
92
92
|
return out
|
|
93
93
|
|
|
94
94
|
|
|
95
|
-
def iter_tuples(batch: pa.RecordBatch) -> Iterator[dict[str, Any]]:
|
|
95
|
+
def iter_tuples(batch: Union[pa.Table, pa.RecordBatch]) -> Iterator[dict[str, Any]]:
|
|
96
96
|
"""Convert a RecordBatch to an iterator of dictionaries. also works with pa.Table and pa.RowGroup"""
|
|
97
97
|
pydict = to_pydict(batch)
|
|
98
98
|
assert len(pydict) > 0, 'empty record batch'
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: pixeltable
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.1
|
|
4
4
|
Summary: AI Data Infrastructure: Declarative, Multimodal, and Incremental
|
|
5
5
|
Home-page: https://pixeltable.com/
|
|
6
6
|
License: Apache-2.0
|
|
@@ -27,11 +27,13 @@ Requires-Dist: av (>=10.0.0)
|
|
|
27
27
|
Requires-Dist: beautifulsoup4 (>=4.0.0,<5.0.0)
|
|
28
28
|
Requires-Dist: cloudpickle (>=2.2.1,<3.0.0)
|
|
29
29
|
Requires-Dist: ftfy (>=6.2.0,<7.0.0)
|
|
30
|
+
Requires-Dist: httpx (>=0.27)
|
|
30
31
|
Requires-Dist: jinja2 (>=3.1.3,<4.0.0)
|
|
31
32
|
Requires-Dist: jmespath (>=1.0.1,<2.0.0)
|
|
32
33
|
Requires-Dist: jsonschema (>=4.1.0)
|
|
33
34
|
Requires-Dist: lxml (>=5.0)
|
|
34
35
|
Requires-Dist: more-itertools (>=10.2,<11.0)
|
|
36
|
+
Requires-Dist: nest_asyncio (>=1.5)
|
|
35
37
|
Requires-Dist: numpy (>=1.25,<2.0)
|
|
36
38
|
Requires-Dist: pandas (>=2.0,<3.0)
|
|
37
39
|
Requires-Dist: pgvector (>=0.2.1,<0.3.0)
|