pixeltable 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff compares the contents of two publicly available versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- pixeltable/__init__.py +2 -2
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/__init__.py +1 -1
- pixeltable/catalog/column.py +41 -29
- pixeltable/catalog/globals.py +18 -0
- pixeltable/catalog/insertable_table.py +30 -10
- pixeltable/catalog/table.py +198 -86
- pixeltable/catalog/table_version.py +47 -53
- pixeltable/catalog/table_version_path.py +2 -2
- pixeltable/catalog/view.py +17 -18
- pixeltable/dataframe.py +27 -36
- pixeltable/env.py +7 -0
- pixeltable/exec/__init__.py +0 -1
- pixeltable/exec/aggregation_node.py +6 -3
- pixeltable/exec/cache_prefetch_node.py +189 -43
- pixeltable/exec/data_row_batch.py +5 -22
- pixeltable/exec/exec_context.py +2 -2
- pixeltable/exec/exec_node.py +3 -2
- pixeltable/exec/expr_eval_node.py +23 -16
- pixeltable/exec/in_memory_data_node.py +6 -3
- pixeltable/exec/sql_node.py +24 -25
- pixeltable/exprs/arithmetic_expr.py +12 -5
- pixeltable/exprs/array_slice.py +7 -7
- pixeltable/exprs/column_property_ref.py +37 -10
- pixeltable/exprs/column_ref.py +97 -14
- pixeltable/exprs/comparison.py +10 -5
- pixeltable/exprs/compound_predicate.py +8 -7
- pixeltable/exprs/data_row.py +27 -18
- pixeltable/exprs/expr.py +53 -52
- pixeltable/exprs/expr_set.py +5 -0
- pixeltable/exprs/function_call.py +32 -16
- pixeltable/exprs/globals.py +4 -1
- pixeltable/exprs/in_predicate.py +8 -7
- pixeltable/exprs/inline_expr.py +4 -4
- pixeltable/exprs/is_null.py +4 -4
- pixeltable/exprs/json_mapper.py +11 -12
- pixeltable/exprs/json_path.py +6 -11
- pixeltable/exprs/literal.py +5 -5
- pixeltable/exprs/method_ref.py +5 -4
- pixeltable/exprs/object_ref.py +2 -1
- pixeltable/exprs/row_builder.py +88 -36
- pixeltable/exprs/rowid_ref.py +12 -11
- pixeltable/exprs/similarity_expr.py +12 -7
- pixeltable/exprs/sql_element_cache.py +7 -5
- pixeltable/exprs/type_cast.py +8 -6
- pixeltable/exprs/variable.py +5 -4
- pixeltable/func/aggregate_function.py +9 -9
- pixeltable/func/expr_template_function.py +6 -5
- pixeltable/func/function.py +11 -10
- pixeltable/func/udf.py +6 -11
- pixeltable/functions/__init__.py +2 -2
- pixeltable/functions/globals.py +5 -7
- pixeltable/functions/huggingface.py +155 -45
- pixeltable/functions/llama_cpp.py +107 -0
- pixeltable/functions/mistralai.py +1 -1
- pixeltable/functions/ollama.py +147 -0
- pixeltable/functions/openai.py +1 -1
- pixeltable/functions/replicate.py +72 -0
- pixeltable/functions/string.py +9 -0
- pixeltable/functions/together.py +1 -1
- pixeltable/functions/util.py +5 -2
- pixeltable/globals.py +67 -26
- pixeltable/index/btree.py +16 -3
- pixeltable/index/embedding_index.py +4 -4
- pixeltable/io/__init__.py +1 -2
- pixeltable/io/fiftyone.py +178 -0
- pixeltable/io/globals.py +96 -2
- pixeltable/iterators/base.py +3 -2
- pixeltable/iterators/document.py +1 -1
- pixeltable/iterators/video.py +120 -63
- pixeltable/metadata/__init__.py +1 -1
- pixeltable/metadata/converters/convert_21.py +34 -0
- pixeltable/metadata/converters/util.py +45 -4
- pixeltable/metadata/notes.py +1 -0
- pixeltable/metadata/schema.py +8 -0
- pixeltable/plan.py +17 -15
- pixeltable/py.typed +0 -0
- pixeltable/store.py +7 -2
- pixeltable/tool/create_test_db_dump.py +1 -1
- pixeltable/tool/create_test_video.py +1 -1
- pixeltable/tool/embed_udf.py +1 -1
- pixeltable/tool/mypy_plugin.py +28 -5
- pixeltable/type_system.py +100 -36
- pixeltable/utils/coco.py +5 -5
- pixeltable/utils/documents.py +15 -1
- pixeltable/utils/formatter.py +12 -13
- pixeltable/utils/s3.py +6 -3
- {pixeltable-0.2.21.dist-info → pixeltable-0.2.23.dist-info}/METADATA +158 -49
- pixeltable-0.2.23.dist-info/RECORD +153 -0
- pixeltable/exec/media_validation_node.py +0 -43
- pixeltable-0.2.21.dist-info/RECORD +0 -148
- {pixeltable-0.2.21.dist-info → pixeltable-0.2.23.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.21.dist-info → pixeltable-0.2.23.dist-info}/WHEEL +0 -0
- {pixeltable-0.2.21.dist-info → pixeltable-0.2.23.dist-info}/entry_points.txt +0 -0
pixeltable/exec/cache_prefetch_node.py
CHANGED

@@ -1,87 +1,226 @@
 from __future__ import annotations
 
-import
+import dataclasses
+import itertools
 import logging
 import threading
 import urllib.parse
 import urllib.request
-from collections import
+from collections import deque
+from concurrent import futures
 from pathlib import Path
-from typing import
+from typing import Optional, Any, Iterator
 from uuid import UUID
 
 import pixeltable.env as env
 import pixeltable.exceptions as excs
 import pixeltable.exprs as exprs
+from pixeltable import catalog
 from pixeltable.utils.filecache import FileCache
+
 from .data_row_batch import DataRowBatch
 from .exec_node import ExecNode
 
 _logger = logging.getLogger('pixeltable')
 
+
 class CachePrefetchNode(ExecNode):
     """Brings files with external URLs into the cache
 
     TODO:
-    - maintain a queue of row batches, in order to overlap download and evaluation
     - adapting the number of download threads at runtime to maximize throughput
     """
-
-
+    BATCH_SIZE = 16
+    NUM_EXECUTOR_THREADS = 16
+
+    retain_input_order: bool  # if True, return rows in the exact order they were received
+    file_col_info: list[exprs.ColumnSlotIdx]
+    boto_client: Optional[Any]
+    boto_client_lock: threading.Lock
+
+    # execution state
+    batch_tbl_version: Optional[catalog.TableVersion]  # needed to construct output batches
+    num_returned_rows: int
+
+    # ready_rows: rows that are ready to be returned, ordered by row idx;
+    # the implied row idx of ready_rows[0] is num_returned_rows
+    ready_rows: deque[Optional[exprs.DataRow]]
+
+    in_flight_rows: dict[int, CachePrefetchNode.RowState]  # rows with in-flight urls; id(row) -> RowState
+    in_flight_requests: dict[futures.Future, str]  # in-flight requests for urls; future -> URL
+    in_flight_urls: dict[str, list[tuple[exprs.DataRow, exprs.ColumnSlotIdx]]]  # URL -> [(row, info)]
+    input_finished: bool
+    row_idx: Iterator[Optional[int]]
+
+    @dataclasses.dataclass
+    class RowState:
+        row: exprs.DataRow
+        idx: Optional[int]  # position in input stream; None if we don't retain input order
+        num_missing: int  # number of missing URLs in this row
+
+    def __init__(
+            self, tbl_id: UUID, file_col_info: list[exprs.ColumnSlotIdx], input: ExecNode,
+            retain_input_order: bool = True):
+        # input_/output_exprs=[]: we don't have anything to evaluate
         super().__init__(input.row_builder, [], [], input)
-        self.
+        self.retain_input_order = retain_input_order
         self.file_col_info = file_col_info
 
         # clients for specific services are constructed as needed, because it's time-consuming
-        self.boto_client
+        self.boto_client = None
         self.boto_client_lock = threading.Lock()
 
-
-
+        self.batch_tbl_version = None
+        self.num_returned_rows = 0
+        self.ready_rows = deque()
+        self.in_flight_rows = {}
+        self.in_flight_requests = {}
+        self.in_flight_urls = {}
+        self.input_finished = False
+        self.row_idx = itertools.count() if retain_input_order else itertools.repeat(None)
+
+    def __iter__(self) -> Iterator[DataRowBatch]:
+        input_iter = iter(self.input)
+        with futures.ThreadPoolExecutor(max_workers=self.NUM_EXECUTOR_THREADS) as executor:
+            # we create enough in-flight requests to fill the first batch
+            while not self.input_finished and self.__num_pending_rows() < self.BATCH_SIZE:
+                self.__submit_input_batch(input_iter, executor)
+
+            while True:
+                # try to assemble a full batch of output rows
+                if not self.__has_ready_batch() and len(self.in_flight_requests) > 0:
+                    self.__wait_for_requests()
+
+                # try to create enough in-flight requests to fill the next batch
+                while not self.input_finished and self.__num_pending_rows() < self.BATCH_SIZE:
+                    self.__submit_input_batch(input_iter, executor)
+
+                if len(self.ready_rows) > 0:
+                    # create DataRowBatch from the first BATCH_SIZE ready rows
+                    batch = DataRowBatch(self.batch_tbl_version, self.row_builder)
+                    rows = [self.ready_rows.popleft() for _ in range(min(self.BATCH_SIZE, len(self.ready_rows)))]
+                    for row in rows:
+                        assert row is not None
+                        batch.add_row(row)
+                    self.num_returned_rows += len(rows)
+                    _logger.debug(f'returning {len(rows)} rows')
+                    yield batch
+
+                if self.input_finished and self.__num_pending_rows() == 0:
+                    return
+
+    def __num_pending_rows(self) -> int:
+        return len(self.in_flight_rows) + len(self.ready_rows)
+
+    def __has_ready_batch(self) -> bool:
+        """True if there are >= BATCH_SIZES entries in ready_rows and the first BATCH_SIZE ones are all non-None"""
+        return (
+            sum(int(row is not None) for row in itertools.islice(self.ready_rows, self.BATCH_SIZE)) == self.BATCH_SIZE
+        )
+
+    def __ready_prefix_len(self) -> int:
+        """Length of the non-None prefix of ready_rows (= what we can return right now)"""
+        return sum(1 for _ in itertools.takewhile(lambda x: x is not None, self.ready_rows))
+
+    def __add_ready_row(self, row: exprs.DataRow, row_idx: Optional[int]) -> None:
+        if row_idx is None:
+            self.ready_rows.append(row)
+        else:
+            # extend ready_rows to accommodate row_idx
+            idx = row_idx - self.num_returned_rows
+            if idx >= len(self.ready_rows):
+                self.ready_rows.extend([None] * (idx - len(self.ready_rows) + 1))
+            self.ready_rows[idx] = row
+
+    def __wait_for_requests(self) -> None:
+        """Wait for in-flight requests to complete until we have a full batch of rows"""
+        file_cache = FileCache.get()
+        _logger.debug(f'waiting for requests; ready_batch_size={self.__ready_prefix_len()}')
+        while not self.__has_ready_batch() and len(self.in_flight_requests) > 0:
+            done, _ = futures.wait(self.in_flight_requests, return_when=futures.FIRST_COMPLETED)
+            for f in done:
+                url = self.in_flight_requests.pop(f)
+                tmp_path, exc = f.result()
+                local_path: Optional[Path] = None
+                if tmp_path is not None:
+                    # register the file with the cache for the first column in which it's missing
+                    assert url in self.in_flight_urls
+                    _, info = self.in_flight_urls[url][0]
+                    local_path = file_cache.add(info.col.tbl.id, info.col.id, url, tmp_path)
+                    _logger.debug(f'cached {url} as {local_path}')
+
+                # add the local path/exception to the slots that reference the url
+                for row, info in self.in_flight_urls.pop(url):
+                    if exc is not None:
+                        self.row_builder.set_exc(row, info.slot_idx, exc)
+                    else:
+                        assert local_path is not None
+                        row.set_file_path(info.slot_idx, str(local_path))
+                    state = self.in_flight_rows[id(row)]
+                    state.num_missing -= 1
+                    if state.num_missing == 0:
+                        del self.in_flight_rows[id(row)]
+                        self.__add_ready_row(row, state.idx)
+                        _logger.debug(f'row {state.idx} is ready (ready_batch_size={self.__ready_prefix_len()})')
+
+    def __submit_input_batch(self, input: Iterator[DataRowBatch], executor: futures.ThreadPoolExecutor) -> None:
+        assert not self.input_finished
+        input_batch = next(input, None)
+        if input_batch is None:
+            self.input_finished = True
+            return
+        if self.batch_tbl_version is None:
+            self.batch_tbl_version = input_batch.tbl
 
-        # collect external URLs that aren't already cached, and set DataRow.file_paths for those that are
         file_cache = FileCache.get()
-
-
+
+        # URLs from this input batch that aren't already in the file cache;
+        # we use a list to make sure we submit urls in the order in which they appear in the output, which minimizes
+        # the time it takes to get the next batch together
+        cache_misses: list[str] = []
+
+        url_pos: dict[str, int] = {}  # url -> row_idx; used for logging
         for row in input_batch:
+            # identify missing local files in input batch, or fill in their paths if they're already cached
+            num_missing = 0
+            row_idx = next(self.row_idx)
+
             for info in self.file_col_info:
                 url = row.file_urls[info.slot_idx]
                 if url is None or row.file_paths[info.slot_idx] is not None:
                     # nothing to do
                     continue
-
-
+                locations = self.in_flight_urls.get(url)
+                if locations is not None:
+                    # we've already seen this
+                    locations.append((row, info))
+                    num_missing += 1
                     continue
+
                 local_path = file_cache.lookup(url)
                 if local_path is None:
-                    cache_misses.append(
-
+                    cache_misses.append(url)
+                    self.in_flight_urls[url] = [(row, info)]
+                    num_missing += 1
+                    if url not in url_pos:
+                        url_pos[url] = row_idx
                 else:
                     row.set_file_path(info.slot_idx, str(local_path))
 
-
-
-
-
-        for row, info in cache_misses:
-            futures[executor.submit(self._fetch_url, row, info.slot_idx)] = (row, info)
-        for future in concurrent.futures.as_completed(futures):
-            # TODO: does this need to deal with recoverable errors (such as retry after throttling)?
-            tmp_path = future.result()
-            if tmp_path is None:
-                continue
-            row, info = futures[future]
-            url = row.file_urls[info.slot_idx]
-            local_path = file_cache.add(self.tbl_id, info.col.id, url, tmp_path)
-            _logger.debug(f'PrefetchNode: cached {url} as {local_path}')
-            for row in missing_url_rows[url]:
-                row.set_file_path(info.slot_idx, str(local_path))
+            if num_missing > 0:
+                self.in_flight_rows[id(row)] = self.RowState(row, row_idx, num_missing)
+            else:
+                self.__add_ready_row(row, row_idx)
 
-
+        _logger.debug(f'submitting {len(cache_misses)} urls')
+        for url in cache_misses:
+            f = executor.submit(self.__fetch_url, url)
+            _logger.debug(f'submitted {url} for idx {url_pos[url]}')
+            self.in_flight_requests[f] = url
 
-    def
+    def __fetch_url(self, url: str) -> tuple[Optional[Path], Optional[Exception]]:
         """Fetches a remote URL into Env.tmp_dir and returns its path"""
-        url =
+        _logger.debug(f'fetching url={url} thread_name={threading.current_thread().name}')
         parsed = urllib.parse.urlparse(url)
         # Use len(parsed.scheme) > 1 here to ensure we're not being passed
         # a Windows filename
@@ -93,24 +232,31 @@ class CachePrefetchNode(ExecNode):
         extension = p.suffix
         tmp_path = env.Env.get().create_tmp_path(extension=extension)
         try:
+            _logger.debug(f'Downloading {url} to {tmp_path}')
            if parsed.scheme == 's3':
                from pixeltable.utils.s3 import get_client
                with self.boto_client_lock:
                    if self.boto_client is None:
-
-
+                        config = {
+                            'max_pool_connections': self.NUM_EXECUTOR_THREADS + 4,  # +4: leave some headroom
+                            'connect_timeout': 5,
+                            'read_timeout': 30,
+                            'retries': {'max_attempts': 3, 'mode': 'adaptive'},
+                        }
+                        self.boto_client = get_client(**config)
+                self.boto_client.download_file(parsed.netloc, parsed.path.lstrip('/'), str(tmp_path))
            elif parsed.scheme == 'http' or parsed.scheme == 'https':
                with urllib.request.urlopen(url) as resp, open(tmp_path, 'wb') as f:
                    data = resp.read()
                    f.write(data)
            else:
                assert False, f'Unsupported URL scheme: {parsed.scheme}'
-
+            _logger.debug(f'Downloaded {url} to {tmp_path}')
+            return tmp_path, None
        except Exception as e:
            # we want to add the file url to the exception message
            exc = excs.Error(f'Failed to download {url}: {e}')
+            _logger.debug(f'Failed to download {url}: {e}', exc_info=e)
            if not self.ctx.ignore_errors:
                raise exc from None  # suppress original exception
-
-
+            return None, exc
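The rewritten `CachePrefetchNode` above overlaps downloads with evaluation while still being able to return rows in input order when `retain_input_order` is set: a completed row is slotted into `ready_rows` at position `row_idx - num_returned_rows`, rows still in flight leave `None` placeholders, and batches are only emitted from the fully populated prefix. A minimal, self-contained sketch of that bookkeeping pattern (illustrative only, not code from the package; `add_ready`, `pop_batch`, and `BATCH_SIZE = 4` are invented for the example):

```python
from collections import deque
from typing import Optional

BATCH_SIZE = 4

ready: deque[Optional[str]] = deque()  # slot i holds the item with index num_returned + i
num_returned = 0

def add_ready(item: str, idx: int) -> None:
    # place a completed item at its input position, padding with None for items still in flight
    pos = idx - num_returned
    if pos >= len(ready):
        ready.extend([None] * (pos - len(ready) + 1))
    ready[pos] = item

def pop_batch() -> list[str]:
    # return only the leading fully-populated prefix, up to BATCH_SIZE items
    global num_returned
    batch: list[str] = []
    while ready and ready[0] is not None and len(batch) < BATCH_SIZE:
        batch.append(ready.popleft())
    num_returned += len(batch)
    return batch

# completions arrive out of order, but the batch comes back in input order
for idx in [2, 0, 1, 3]:
    add_ready(f'row{idx}', idx)
print(pop_batch())  # ['row0', 'row1', 'row2', 'row3']
```

When ordering is not required, the node simply appends completed rows, which is what the `itertools.repeat(None)` row-index iterator in the constructor provides.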
pixeltable/exec/data_row_batch.py
CHANGED

@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import
+from typing import Iterator, Optional
 import logging
 
 import pixeltable.exprs as exprs
@@ -49,12 +49,12 @@ class DataRowBatch:
     def __len__(self) -> int:
         return len(self.rows)
 
-    def __getitem__(self, index:
+    def __getitem__(self, index: int) -> exprs.DataRow:
         return self.rows[index]
 
     def flush_imgs(
-            self, idx_range: Optional[slice] = None, stored_img_info: Optional[
-            flushed_slot_idxs: Optional[
+            self, idx_range: Optional[slice] = None, stored_img_info: Optional[list[exprs.ColumnSlotIdx]] = None,
+            flushed_slot_idxs: Optional[list[int]] = None
     ) -> None:
         """Flushes images in the given range of rows."""
         assert self.tbl is not None
@@ -74,21 +74,4 @@ class DataRowBatch:
                 row.flush_img(slot_idx)
 
     def __iter__(self) -> Iterator[exprs.DataRow]:
-        return
-
-
-class DataRowBatchIterator:
-    """
-    Iterator over a DataRowBatch.
-    """
-    def __init__(self, batch: DataRowBatch):
-        self.row_batch = batch
-        self.index = 0
-
-    def __next__(self) -> exprs.DataRow:
-        if self.index >= len(self.row_batch.rows):
-            raise StopIteration
-        row = self.row_batch.rows[self.index]
-        self.index += 1
-        return row
-
+        return iter(self.rows)
pixeltable/exec/exec_context.py
CHANGED

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional
 
 import sqlalchemy as sql
 
@@ -8,7 +8,7 @@ class ExecContext:
     """Class for execution runtime constants"""
     def __init__(
             self, row_builder: exprs.RowBuilder, *, show_pbar: bool = False, batch_size: int = 0,
-            pk_clause: Optional[
+            pk_clause: Optional[list[sql.ClauseElement]] = None, num_computed_exprs: int = 0,
            ignore_errors: bool = False
     ):
         self.show_pbar = show_pbar
pixeltable/exec/exec_node.py
CHANGED

@@ -1,9 +1,10 @@
 from __future__ import annotations
 
 import abc
-from typing import
+from typing import TYPE_CHECKING, Iterable, Iterator, Optional
 
 import pixeltable.exprs as exprs
+
 from .data_row_batch import DataRowBatch
 from .exec_context import ExecContext
 
@@ -42,7 +43,7 @@ class ExecNode(abc.ABC):
         if self.input is not None:
             self.input.set_ctx(ctx)
 
-    def set_stored_img_cols(self, stored_img_cols:
+    def set_stored_img_cols(self, stored_img_cols: list[exprs.ColumnSlotIdx]) -> None:
         self.stored_img_cols = stored_img_cols
         # propagate batch size to the source
         if self.input is not None:
pixeltable/exec/expr_eval_node.py
CHANGED

@@ -3,11 +3,11 @@ import sys
 import time
 import warnings
 from dataclasses import dataclass
-from typing import Iterable,
+from typing import Iterable, Optional
 
 from tqdm import TqdmWarning, tqdm
 
-
+from pixeltable import exprs
 from pixeltable.func import CallableFunction
 
 from .data_row_batch import DataRowBatch
@@ -22,10 +22,10 @@ class ExprEvalNode(ExecNode):
     @dataclass
     class Cohort:
         """List of exprs that form an evaluation context and contain calls to at most one external function"""
-
+        exprs_: list[exprs.Expr]
         batched_fn: Optional[CallableFunction]
-        segment_ctxs:
-        target_slot_idxs:
+        segment_ctxs: list['exprs.RowBuilder.EvalCtx']
+        target_slot_idxs: list[int]
         batch_size: int = 8
 
     def __init__(
@@ -38,7 +38,7 @@ class ExprEvalNode(ExecNode):
         # we're only materializing exprs that are not already in the input
         self.target_exprs = [e for e in output_exprs if e.slot_idx not in input_slot_idxs]
         self.pbar: Optional[tqdm] = None
-        self.cohorts:
+        self.cohorts: list[ExprEvalNode.Cohort] = []
         self._create_cohorts()
 
     def __next__(self) -> DataRowBatch:
@@ -83,11 +83,13 @@ class ExprEvalNode(ExecNode):
         all_exprs = self.row_builder.get_dependencies(self.target_exprs)
         # break up all_exprs into cohorts such that each cohort contains calls to at most one external function;
         # seed the cohorts with only the ext fn calls
-        cohorts:
+        cohorts: list[list[exprs.Expr]] = []
         current_batched_fn: Optional[CallableFunction] = None
         for e in all_exprs:
             if not self._is_batched_fn_call(e):
                 continue
+            assert isinstance(e, exprs.FunctionCall)
+            assert isinstance(e.fn, CallableFunction)
             if current_batched_fn is None or current_batched_fn != e.fn:
                 # create a new cohort
                 cohorts.append([])
@@ -96,9 +98,9 @@ class ExprEvalNode(ExecNode):
 
         # expand the cohorts to include all exprs that are in the same evaluation context as the external calls;
         # cohorts are evaluated in order, so we can exclude the target slots from preceding cohorts and input slots
-        exclude = set(
-        all_target_slot_idxs = set(
-        target_slot_idxs:
+        exclude = set(e.slot_idx for e in self.input_exprs)
+        all_target_slot_idxs = set(e.slot_idx for e in self.target_exprs)
+        target_slot_idxs: list[list[int]] = []  # the ones materialized by each cohort
         for i in range(len(cohorts)):
             cohorts[i] = self.row_builder.get_dependencies(
                 cohorts[i], exclude=[self.row_builder.unique_exprs[slot_idx] for slot_idx in exclude])
@@ -106,7 +108,7 @@ class ExprEvalNode(ExecNode):
                 [e.slot_idx for e in cohorts[i] if e.slot_idx in all_target_slot_idxs])
             exclude.update(target_slot_idxs[-1])
 
-        all_cohort_slot_idxs = set(
+        all_cohort_slot_idxs = set(e.slot_idx for cohort in cohorts for e in cohort)
         remaining_slot_idxs = set(all_target_slot_idxs) - all_cohort_slot_idxs
         if len(remaining_slot_idxs) > 0:
             cohorts.append(self.row_builder.get_dependencies(
@@ -164,11 +166,12 @@ class ExprEvalNode(ExecNode):
                         rows[row_idx], segment_ctx, self.ctx.profile, ignore_errors=self.ctx.ignore_errors)
             else:
                 fn_call = segment_ctx.exprs[0]
+                assert isinstance(fn_call, exprs.FunctionCall)
                 # make a batched external function call
-                arg_batches = [[] for _ in range(len(fn_call.args))]
-                kwarg_batches = {k: [] for k in fn_call.kwargs.keys()}
+                arg_batches: list[list[exprs.Expr]] = [[] for _ in range(len(fn_call.args))]
+                kwarg_batches: dict[str, list[exprs.Expr]] = {k: [] for k in fn_call.kwargs.keys()}
 
-                valid_batch_idxs:
+                valid_batch_idxs: list[int] = []  # rows with exceptions are not valid
                 for row_idx in range(batch_start_idx, batch_start_idx + num_batch_rows):
                     row = rows[row_idx]
                     if row.has_exc(fn_call.slot_idx):
@@ -176,12 +179,15 @@ class ExprEvalNode(ExecNode):
                         continue
                     valid_batch_idxs.append(row_idx)
                     args, kwargs = fn_call._make_args(row)
-
-
+                    for i in range(len(args)):
+                        arg_batches[i].append(args[i])
+                    for k in kwargs.keys():
+                        kwarg_batches[k].append(kwargs[k])
                 num_valid_batch_rows = len(valid_batch_idxs)
 
                 if ext_batch_size is None:
                     # we need to choose a batch size based on the args
+                    assert isinstance(fn_call.fn, CallableFunction)
                     sample_args = [arg_batches[i][0] for i in range(len(arg_batches))]
                     ext_batch_size = fn_call.fn.get_batch_size(*sample_args)
 
@@ -201,6 +207,7 @@ class ExprEvalNode(ExecNode):
                         for k in kwarg_batches.keys()
                     }
                    start_ts = time.perf_counter()
+                    assert isinstance(fn_call.fn, CallableFunction)
                    result_batch = fn_call.fn.exec_batch(*call_args, **call_kwargs)
                    self.ctx.profile.eval_time[fn_call.slot_idx] += time.perf_counter() - start_ts
                    self.ctx.profile.eval_count[fn_call.slot_idx] += num_ext_batch_rows
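The `ExprEvalNode` hunks above mostly tighten the typing around batched external function calls, but they also show the underlying pattern: per-row arguments are transposed into per-parameter batches, rows that already carry an exception are skipped (tracked via `valid_batch_idxs`), and a single batched call then evaluates everything at once before results are scattered back by position. A small standalone sketch of that transpose-and-scatter pattern (hypothetical data and `exec_batch`; not package code):

```python
from typing import Any

# stand-ins for rows: positional args plus an exception flag set by an upstream failure
rows: list[dict[str, Any]] = [
    {'args': (1, 2), 'exc': None},
    {'args': (3, 4), 'exc': ValueError('bad row')},  # skipped: already failed upstream
    {'args': (5, 6), 'exc': None},
]

def exec_batch(xs: list[int], ys: list[int]) -> list[int]:
    # one call evaluates the whole batch (e.g. a model forward pass)
    return [x + y for x, y in zip(xs, ys)]

arg_batches: list[list[Any]] = [[], []]  # one column per positional parameter
valid_idxs: list[int] = []               # which rows actually went into the batch

for i, row in enumerate(rows):
    if row['exc'] is not None:
        continue
    valid_idxs.append(i)
    for j, arg in enumerate(row['args']):
        arg_batches[j].append(arg)

results = exec_batch(*arg_batches)
for i, res in zip(valid_idxs, results):  # scatter results back to their source rows
    rows[i]['result'] = res

print([row.get('result') for row in rows])  # [3, None, 11]
```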
pixeltable/exec/in_memory_data_node.py
CHANGED

@@ -1,5 +1,5 @@
 import logging
-from typing import Any,
+from typing import Any, Iterator, Optional
 
 import pixeltable.catalog as catalog
 import pixeltable.exprs as exprs
@@ -23,12 +23,15 @@ class InMemoryDataNode(ExecNode):
     start_row_id: int
     output_rows: Optional[DataRowBatch]
 
+    # output_exprs is declared in the superclass, but we redeclare it here with a more specific type
+    output_exprs: list[exprs.ColumnRef]
+
     def __init__(
             self, tbl: catalog.TableVersion, rows: list[dict[str, Any]],
             row_builder: exprs.RowBuilder, start_row_id: int,
     ):
-        # we materialize
-        output_exprs =
+        # we materialize the input slots
+        output_exprs = list(row_builder.input_exprs)
         super().__init__(row_builder, output_exprs, [], None)
         assert tbl.is_insertable()
         self.tbl = tbl
pixeltable/exec/sql_node.py
CHANGED

@@ -1,13 +1,14 @@
 import logging
 import warnings
 from decimal import Decimal
-from typing import
+from typing import Iterable, Iterator, NamedTuple, Optional
 from uuid import UUID
 
 import sqlalchemy as sql
 
 import pixeltable.catalog as catalog
 import pixeltable.exprs as exprs
+
 from .data_row_batch import DataRowBatch
 from .exec_node import ExecNode
 
@@ -100,7 +101,7 @@ class SqlNode(ExecNode):
         # minimize the number of tables that need to be joined to the target table
         self.retarget_rowid_refs(tbl, self.select_list)
 
-        assert self.sql_elements.
+        assert self.sql_elements.contains_all(self.select_list)
         self.set_pk = set_pk
         self.num_pk_cols = 0
         if set_pk:
@@ -120,13 +121,13 @@ class SqlNode(ExecNode):
     def _create_stmt(self) -> sql.Select:
         """Create Select from local state"""
 
-        assert self.sql_elements.
+        assert self.sql_elements.contains_all(self.select_list)
         sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
         if self.set_pk:
             sql_select_list += self.tbl.tbl_version.store_tbl.pk_columns()
         stmt = sql.select(*sql_select_list)
 
-        order_by_clause: list[sql.
+        order_by_clause: list[sql.ColumnElement] = []
         for e, asc in self.order_by_clause:
             if isinstance(e, exprs.SimilarityExpr):
                 order_by_clause.append(e.as_order_by_clause(asc))
@@ -141,7 +142,7 @@ class SqlNode(ExecNode):
         return stmt
 
     def _ordering_tbl_ids(self) -> set[UUID]:
-        return exprs.Expr.
+        return exprs.Expr.all_tbl_ids(e for e, _ in self.order_by_clause)
 
     def to_cte(self) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
         """
@@ -182,9 +183,9 @@ class SqlNode(ExecNode):
         """
         # we need to include at least the root
         if refd_tbl_ids is None:
-            refd_tbl_ids =
+            refd_tbl_ids = set()
         if exact_version_only is None:
-            exact_version_only =
+            exact_version_only = set()
         candidates = tbl.get_tbl_versions()
         assert len(candidates) > 0
         joined_tbls: list[catalog.TableVersion] = [candidates[0]]
@@ -193,6 +194,7 @@ class SqlNode(ExecNode):
                 joined_tbls.append(tbl)
 
         first = True
+        prev_tbl: catalog.TableVersion
         for tbl in joined_tbls[::-1]:
             if first:
                 stmt = stmt.select_from(tbl.store_tbl.sa_tbl)
@@ -239,22 +241,19 @@
     def __iter__(self) -> Iterator[DataRowBatch]:
         # run the query; do this here rather than in _open(), exceptions are only expected during iteration
         assert self.ctx.conn is not None
-
-
-
-
-
-
-
-
-
-
-
-
-
-            pass
-        except Exception as e:
-            raise e
+        with warnings.catch_warnings(record=True) as w:
+            stmt = self._create_stmt()
+            try:
+                # log stmt, if possible
+                stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
+                _logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
+            except Exception:
+                pass
+            self._log_explain(stmt)
+
+            result_cursor = self.ctx.conn.execute(stmt)
+            for warning in w:
+                pass
 
         tbl_version = self.tbl.tbl_version if self.tbl is not None else None
         output_batch = DataRowBatch(tbl_version, self.row_builder)
@@ -350,7 +349,7 @@ class SqlScanNode(SqlNode):
     def _create_stmt(self) -> sql.Select:
         stmt = super()._create_stmt()
         where_clause_tbl_ids = self.where_clause.tbl_ids() if self.where_clause is not None else set()
-        refd_tbl_ids = exprs.Expr.
+        refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | where_clause_tbl_ids | self._ordering_tbl_ids()
         stmt = self.create_from_clause(
             self.tbl, stmt, refd_tbl_ids, exact_version_only={t.id for t in self.exact_version_only})
 
@@ -386,7 +385,7 @@ class SqlLookupNode(SqlNode):
 
     def _create_stmt(self) -> sql.Select:
         stmt = super()._create_stmt()
-        refd_tbl_ids = exprs.Expr.
+        refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | self._ordering_tbl_ids()
         stmt = self.create_from_clause(self.tbl, stmt, refd_tbl_ids)
         stmt = stmt.where(self.where_clause)
         return stmt
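The new `SqlNode.__iter__` above compiles the statement with literal binds purely so the full SQL can be logged, guards that compilation with a try/except (literal rendering is best-effort and can fail for some parameter types), and captures warnings raised inside the block while the query is prepared and executed. A hedged, standalone illustration of that logging pattern with SQLAlchemy (table, column, and logger names are invented for the example):

```python
import logging
import warnings

import sqlalchemy as sql

logging.basicConfig(level=logging.DEBUG)
_logger = logging.getLogger('example')

metadata = sql.MetaData()
t = sql.Table('media', metadata, sql.Column('id', sql.Integer), sql.Column('url', sql.String))
stmt = sql.select(t.c.id, t.c.url).where(t.c.id > 5)

with warnings.catch_warnings(record=True) as w:
    try:
        # literal_binds inlines parameter values into the SQL string; guard it, since
        # some bind types cannot be rendered literally
        stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
        _logger.debug('stmt:\n%s', stmt_str)
    except Exception:
        pass
    # ... execute stmt on a connection here; warnings raised inside the block end up in `w`

print(f'captured {len(w)} warnings')
```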