pixeltable 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/insertable_table.py +3 -3
- pixeltable/catalog/table.py +2 -2
- pixeltable/catalog/table_version.py +3 -2
- pixeltable/catalog/view.py +1 -1
- pixeltable/dataframe.py +52 -27
- pixeltable/env.py +109 -4
- pixeltable/exec/__init__.py +1 -1
- pixeltable/exec/aggregation_node.py +3 -3
- pixeltable/exec/cache_prefetch_node.py +13 -7
- pixeltable/exec/component_iteration_node.py +3 -9
- pixeltable/exec/data_row_batch.py +17 -5
- pixeltable/exec/exec_node.py +32 -12
- pixeltable/exec/expr_eval/__init__.py +1 -0
- pixeltable/exec/expr_eval/evaluators.py +240 -0
- pixeltable/exec/expr_eval/expr_eval_node.py +408 -0
- pixeltable/exec/expr_eval/globals.py +113 -0
- pixeltable/exec/expr_eval/row_buffer.py +76 -0
- pixeltable/exec/expr_eval/schedulers.py +240 -0
- pixeltable/exec/in_memory_data_node.py +2 -2
- pixeltable/exec/row_update_node.py +14 -14
- pixeltable/exec/sql_node.py +2 -2
- pixeltable/exprs/column_ref.py +5 -1
- pixeltable/exprs/data_row.py +50 -40
- pixeltable/exprs/expr.py +57 -12
- pixeltable/exprs/function_call.py +54 -19
- pixeltable/exprs/inline_expr.py +12 -21
- pixeltable/exprs/literal.py +25 -8
- pixeltable/exprs/row_builder.py +25 -2
- pixeltable/func/aggregate_function.py +4 -0
- pixeltable/func/callable_function.py +54 -4
- pixeltable/func/expr_template_function.py +5 -1
- pixeltable/func/function.py +48 -7
- pixeltable/func/query_template_function.py +16 -7
- pixeltable/func/udf.py +7 -1
- pixeltable/functions/__init__.py +1 -1
- pixeltable/functions/anthropic.py +97 -21
- pixeltable/functions/gemini.py +2 -6
- pixeltable/functions/openai.py +219 -28
- pixeltable/globals.py +2 -3
- pixeltable/io/hf_datasets.py +1 -1
- pixeltable/io/label_studio.py +5 -5
- pixeltable/io/parquet.py +1 -1
- pixeltable/metadata/__init__.py +2 -1
- pixeltable/plan.py +24 -9
- pixeltable/store.py +6 -0
- pixeltable/type_system.py +73 -36
- pixeltable/utils/arrow.py +3 -8
- pixeltable/utils/console_output.py +41 -0
- pixeltable/utils/filecache.py +1 -1
- {pixeltable-0.3.0.dist-info → pixeltable-0.3.2.dist-info}/METADATA +4 -1
- {pixeltable-0.3.0.dist-info → pixeltable-0.3.2.dist-info}/RECORD +55 -49
- pixeltable/exec/expr_eval_node.py +0 -232
- {pixeltable-0.3.0.dist-info → pixeltable-0.3.2.dist-info}/LICENSE +0 -0
- {pixeltable-0.3.0.dist-info → pixeltable-0.3.2.dist-info}/WHEEL +0 -0
- {pixeltable-0.3.0.dist-info → pixeltable-0.3.2.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import datetime
|
|
5
|
+
import itertools
|
|
6
|
+
import logging
|
|
7
|
+
import sys
|
|
8
|
+
from typing import Iterator, Any, Optional, Callable, cast
|
|
9
|
+
|
|
10
|
+
from pixeltable import exprs
|
|
11
|
+
from pixeltable import func
|
|
12
|
+
from .globals import Dispatcher, Evaluator, FnCallArgs
|
|
13
|
+
|
|
14
|
+
_logger = logging.getLogger('pixeltable')
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class DefaultExprEvaluator(Evaluator):
    """
    Evaluator for expressions that are materialized by calling Expr.eval() directly.

    Each call to schedule() results in exactly one task, which works through the supplied rows sequentially.

    TODO:
    - parallelize via Ray
    """
    e: exprs.Expr

    def __init__(self, e: exprs.Expr, dispatcher: Dispatcher):
        super().__init__(dispatcher)
        self.e = e

    def schedule(self, rows: list[exprs.DataRow], slot_idx: int) -> None:
        """Start a single task that evaluates self.e for all of rows and register it with the dispatcher"""
        assert self.e.slot_idx >= 0
        self.dispatcher.register_task(asyncio.create_task(self.eval(rows)))

    async def eval(self, rows: list[exprs.DataRow]) -> None:
        """
        Evaluate self.e for each row in turn.

        - a row whose evaluation raises gets the exception recorded in its slot and is handed to
          dispatcher.dispatch_exc() immediately
        - all remaining rows are handed to dispatcher.dispatch() together, once the loop has finished
        - if the task was cancelled or the dispatcher's exc_event is set, evaluation stops and nothing further is
          dispatched
        """
        target_slot = self.e.slot_idx
        evaluated_rows: list[exprs.DataRow] = []  # rows that evaluated without an exception, in input order
        for row in rows:
            assert not (row.has_val[target_slot] or row.has_exc(target_slot))
            if asyncio.current_task().cancelled() or self.dispatcher.exc_event.is_set():
                return
            try:
                self.e.eval(row, self.dispatcher.row_builder)
            except Exception as exc:
                exc_tb = sys.exc_info()[2]
                row.set_exc(target_slot, exc)
                self.dispatcher.dispatch_exc([row], target_slot, exc_tb)
            else:
                evaluated_rows.append(row)
        self.dispatcher.dispatch(evaluated_rows)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class FnCallEvaluator(Evaluator):
    """
    Evaluates function calls:
    - batched functions (sync and async): one task per batch
    - async functions: one task per row
    - the rest: one task per set of rows handed to schedule()

    Calls whose FunctionCall names a resource pool are not run as tasks of this evaluator; they are submitted to
    that pool's scheduler (dispatcher.schedulers) instead.

    TODO:
    - adaptive batching: finding the optimal batch size based on observed execution times
    """
    fn_call: exprs.FunctionCall
    fn: func.CallableFunction
    scalar_py_fn: Optional[Callable]  # only set for non-batching CallableFunctions

    # only set if fn.is_batched
    call_args_queue: Optional[asyncio.Queue[FnCallArgs]]  # FnCallArgs waiting for execution
    batch_size: Optional[int]

    def __init__(self, fn_call: exprs.FunctionCall, dispatcher: Dispatcher):
        super().__init__(dispatcher)
        self.fn_call = fn_call
        self.fn = cast(func.CallableFunction, fn_call.fn)
        # the cast above is for the type checker only; fn_call.fn may be a different kind of Function at runtime
        is_callable_fn = isinstance(self.fn, func.CallableFunction)
        if is_callable_fn and self.fn.is_batched:
            self.call_args_queue = asyncio.Queue[FnCallArgs]()
            # we're not supplying sample arguments there, they're ignored anyway
            self.batch_size = self.fn.get_batch_size()
            self.scalar_py_fn = None
        else:
            self.call_args_queue = None
            self.batch_size = None
            self.scalar_py_fn = self.fn.py_fn if is_callable_fn else None

    def schedule(self, rows: list[exprs.DataRow], slot_idx: int) -> None:
        """
        Create the tasks (or scheduler submissions) needed to evaluate fn_call for rows.

        Rows for which fn_call.make_args() returns None are not evaluated: their slot is set to None and they are
        dispatched right away. For batched functions, call args that don't fill a complete batch are parked in
        call_args_queue until more rows arrive or the evaluator is closed.

        slot_idx is only used for logging.
        """
        assert self.fn_call.slot_idx >= 0

        # create FnCallArgs for incoming rows
        skip_rows: list[exprs.DataRow] = []  # skip rows with Nones in non-nullable parameters
        rows_call_args: list[FnCallArgs] = []
        for row in rows:
            args_kwargs = self.fn_call.make_args(row)
            if args_kwargs is None:
                # nothing to do here
                row[self.fn_call.slot_idx] = None
                skip_rows.append(row)
            else:
                args, kwargs = args_kwargs
                rows_call_args.append(FnCallArgs(self.fn_call, [row], args=args, kwargs=kwargs))

        if len(skip_rows) > 0:
            self.dispatcher.dispatch(skip_rows)

        if self.batch_size is not None:
            if not self.is_closed and (len(rows_call_args) + self.call_args_queue.qsize() < self.batch_size):
                # we don't have enough FnCallArgs for a batch, so add them to the queue
                for item in rows_call_args:
                    self.call_args_queue.put_nowait(item)
                return

            # create one task per batch; previously queued items go first
            combined_call_args = itertools.chain(self._queued_call_args_iter(), rows_call_args)
            while True:
                call_args_batch = list(itertools.islice(combined_call_args, self.batch_size))
                if len(call_args_batch) == 0:
                    break
                if len(call_args_batch) < self.batch_size and not self.is_closed:
                    # we don't have a full batch left: return the rest to the queue
                    assert self.call_args_queue.empty()  # we saw all queued items
                    for item in call_args_batch:
                        self.call_args_queue.put_nowait(item)
                    return

                # turn call_args_batch into a single batched FnCallArgs
                _logger.debug(f'Creating batch of size {len(call_args_batch)} for slot {slot_idx}')
                batched_call_args = self._create_batch_call_args(call_args_batch)
                if self.fn_call.resource_pool is not None:
                    # hand the call off to the resource pool's scheduler
                    scheduler = self.dispatcher.schedulers[self.fn_call.resource_pool]
                    scheduler.submit(batched_call_args)
                else:
                    task = asyncio.create_task(self.eval_batch(batched_call_args))
                    self.dispatcher.register_task(task)

        elif self.fn.is_async:
            if self.fn_call.resource_pool is not None:
                # hand the call off to the resource pool's scheduler
                scheduler = self.dispatcher.schedulers[self.fn_call.resource_pool]
                for item in rows_call_args:
                    scheduler.submit(item)
            else:
                # create one task per call
                for item in rows_call_args:
                    task = asyncio.create_task(self.eval_async(item))
                    self.dispatcher.register_task(task)

        else:
            # create a single task for all rows
            task = asyncio.create_task(self.eval(rows_call_args))
            self.dispatcher.register_task(task)

    def _queued_call_args_iter(self) -> Iterator[FnCallArgs]:
        """Drain call_args_queue, yielding items in queue order until it is empty"""
        while not self.call_args_queue.empty():
            yield self.call_args_queue.get_nowait()

    def _create_batch_call_args(self, call_args: list[FnCallArgs]) -> FnCallArgs:
        """Roll call_args into a single batched FnCallArgs"""
        num_calls = len(call_args)
        # one list per positional/keyword parameter of fn_call, each holding one entry per call
        batch_args: list[list[Optional[Any]]] = [[None] * num_calls for _ in range(len(self.fn_call.args))]
        batch_kwargs: dict[str, list[Optional[Any]]] = {k: [None] * num_calls for k in self.fn_call.kwargs}
        assert isinstance(self.fn, func.CallableFunction)
        for i, item in enumerate(call_args):
            for j, arg in enumerate(item.args):
                batch_args[j][i] = arg
            for k, v in item.kwargs.items():
                batch_kwargs[k][i] = v
        return FnCallArgs(
            self.fn_call, [item.row for item in call_args], batch_args=batch_args, batch_kwargs=batch_kwargs)

    async def eval_batch(self, batched_call_args: FnCallArgs) -> None:
        """
        Execute one batched call and store result i in the slot of row i.

        If the call raises, the exception is recorded for every row of the batch and the rows are handed to
        dispatcher.dispatch_exc(); otherwise they are handed to dispatcher.dispatch().
        """
        result_batch: list[Any]
        try:
            if self.fn.is_async:
                result_batch = await self.fn.aexec_batch(
                    *batched_call_args.batch_args, **batched_call_args.batch_kwargs)
            else:
                # check for cancellation before starting something potentially long-running
                if asyncio.current_task().cancelled() or self.dispatcher.exc_event.is_set():
                    return
                result_batch = self.fn.exec_batch(batched_call_args.batch_args, batched_call_args.batch_kwargs)
        except Exception as exc:
            _, _, exc_tb = sys.exc_info()
            for row in batched_call_args.rows:
                row.set_exc(self.fn_call.slot_idx, exc)
            self.dispatcher.dispatch_exc(batched_call_args.rows, self.fn_call.slot_idx, exc_tb)
            return

        # indexed access (rather than zip) so that a result batch that is too short raises instead of being
        # silently truncated
        for i, row in enumerate(batched_call_args.rows):
            row[self.fn_call.slot_idx] = result_batch[i]
        self.dispatcher.dispatch(batched_call_args.rows)

    async def eval_async(self, call_args: FnCallArgs) -> None:
        """
        Await a single async call and dispatch its row.

        On failure, the exception is recorded in the row's slot and the row is handed to
        dispatcher.dispatch_exc().
        """
        assert len(call_args.rows) == 1
        assert not call_args.row.has_val[self.fn_call.slot_idx]
        assert not call_args.row.has_exc(self.fn_call.slot_idx)

        try:
            start_ts = datetime.datetime.now()
            _logger.debug(f'Start evaluating slot {self.fn_call.slot_idx}')
            call_args.row[self.fn_call.slot_idx] = await self.fn.aexec(*call_args.args, **call_args.kwargs)
            end_ts = datetime.datetime.now()
            _logger.debug(f'Evaluated slot {self.fn_call.slot_idx} in {end_ts - start_ts}')
            self.dispatcher.dispatch([call_args.row])
        except Exception as exc:
            # Debug aid for rate-limit errors. The exception is matched by class name: importing a specific
            # provider SDK here would itself raise when that optional package isn't installed, and the row's
            # actual exception would then never be recorded or dispatched.
            if type(exc).__name__ == 'RateLimitError':
                _logger.debug(f'RateLimitError: {exc}')
            _, _, exc_tb = sys.exc_info()
            call_args.row.set_exc(self.fn_call.slot_idx, exc)
            self.dispatcher.dispatch_exc(call_args.rows, self.fn_call.slot_idx, exc_tb)

    async def eval(self, call_args_batch: list[FnCallArgs]) -> None:
        """
        Call scalar_py_fn synchronously for each item of call_args_batch.

        Rows whose call raises are handed to dispatcher.dispatch_exc() individually; all other rows are handed to
        dispatcher.dispatch() together at the end. If the task was cancelled or the dispatcher's exc_event is set,
        evaluation stops and nothing further is dispatched.
        """
        rows_with_excs: set[int] = set()  # records idxs into call_args_batch
        for idx, item in enumerate(call_args_batch):
            assert len(item.rows) == 1
            assert not item.row.has_val[self.fn_call.slot_idx]
            assert not item.row.has_exc(self.fn_call.slot_idx)
            # check for cancellation before starting something potentially long-running
            if asyncio.current_task().cancelled() or self.dispatcher.exc_event.is_set():
                return
            try:
                item.row[self.fn_call.slot_idx] = self.scalar_py_fn(*item.args, **item.kwargs)
            except Exception as exc:
                _, _, exc_tb = sys.exc_info()
                item.row.set_exc(self.fn_call.slot_idx, exc)
                rows_with_excs.add(idx)
                self.dispatcher.dispatch_exc(item.rows, self.fn_call.slot_idx, exc_tb)
        self.dispatcher.dispatch(
            [item.row for idx, item in enumerate(call_args_batch) if idx not in rows_with_excs])

    def _close(self) -> None:
        """Create a task for the incomplete batch of queued FnCallArgs, if any"""
        _logger.debug(f'FnCallEvaluator.close(): slot_idx={self.fn_call.slot_idx}')
        if self.call_args_queue is None or self.call_args_queue.empty():
            return
        batched_call_args = self._create_batch_call_args(list(self._queued_call_args_iter()))
        # NOTE(review): unlike schedule(), the leftover batch is run directly even when fn_call.resource_pool is
        # set -- confirm that bypassing the pool's scheduler here is intended
        task = asyncio.create_task(self.eval_batch(batched_call_args))
        self.dispatcher.register_task(task)
|
|
@@ -0,0 +1,408 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
import traceback
|
|
6
|
+
from types import TracebackType
|
|
7
|
+
from typing import Iterable, AsyncIterator, Optional, Union
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
import pixeltable.exceptions as excs
|
|
12
|
+
from pixeltable import exprs
|
|
13
|
+
from pixeltable import func
|
|
14
|
+
from .evaluators import DefaultExprEvaluator, FnCallEvaluator
|
|
15
|
+
from .globals import Evaluator, Scheduler
|
|
16
|
+
from .row_buffer import RowBuffer
|
|
17
|
+
from .schedulers import SCHEDULERS
|
|
18
|
+
from ..data_row_batch import DataRowBatch
|
|
19
|
+
from ..exec_node import ExecNode
|
|
20
|
+
|
|
21
|
+
_logger = logging.getLogger('pixeltable')
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ExprEvalNode(ExecNode):
    """
    Expression evaluation

    Resource management:
    - the execution system tries to limit total memory consumption by limiting the number of rows that are in
      circulation
    - during execution, slots that aren't part of the output are garbage collected as soon as their direct dependents
      are materialized

    TODO:
    - Literal handling: currently, Literal values are copied into slots via the normal evaluation mechanism, which is
      needless overhead; instead: pre-populate Literal slots in _init_row()
    - local model inference on gpu: currently, no attempt is made to ensure that models can fit onto the gpu
      simultaneously, which will cause errors; instead, the execution should be divided into sequential phases, each
      of which only contains a subset of the models which is known to fit onto the gpu simultaneously
    """
    maintain_input_order: bool  # True if we're returning rows in the order we received them from our input
    num_dependencies: np.ndarray  # number of dependencies for our output slots; indexed by slot idx
    outputs: np.ndarray  # bool per slot; True if this slot is part of our output
    slot_evaluators: dict[int, Evaluator]  # key: slot idx
    schedulers: dict[str, Scheduler]  # key: resource pool name
    gc_targets: np.ndarray  # bool per slot; True if this is an intermediate expr (ie, not part of our output)
    eval_ctx: np.ndarray  # bool per slot; EvalCtx.slot_idxs as a mask

    # execution state
    tasks: set[asyncio.Task]  # collects all running tasks to prevent them from getting gc'd
    exc_event: asyncio.Event  # set if an exception needs to be propagated
    error: Optional[Union[excs.Error, excs.ExprEvalError]]  # exception that needs to be propagated
    completed_rows: asyncio.Queue[exprs.DataRow]  # rows that have completed evaluation
    completed_event: asyncio.Event  # set when completed_rows is non-empty
    input_iter: AsyncIterator[DataRowBatch]
    current_input_batch: Optional[DataRowBatch]  # batch from which we're currently consuming rows
    input_row_idx: int  # next row to consume from current_input_batch
    next_input_batch: Optional[DataRowBatch]  # read-ahead input batch
    avail_input_rows: int  # total number across both current_/next_input_batch
    input_complete: bool  # True if we've received all input batches
    num_in_flight: int  # number of dispatched rows that haven't completed
    row_pos_map: Optional[dict[int, int]]  # id(row) -> position of row in input; only set if maintain_input_order
    output_buffer: RowBuffer  # holds rows that are ready to be returned, in order

    # debugging
    num_input_rows: int
    num_output_rows: int

    BATCH_SIZE = 64
    MAX_BUFFERED_ROWS = 512  # maximum number of rows that have been dispatched but not yet returned

    def __init__(
        self, row_builder: exprs.RowBuilder, output_exprs: Iterable[exprs.Expr], input_exprs: Iterable[exprs.Expr],
        input: ExecNode, maintain_input_order: bool = True
    ):
        super().__init__(row_builder, output_exprs, input_exprs, input)
        self.maintain_input_order = maintain_input_order
        self.num_dependencies = np.sum(row_builder.dependencies, axis=1)
        self.outputs = np.zeros(row_builder.num_materialized, dtype=bool)
        output_slot_idxs = [e.slot_idx for e in output_exprs]
        self.outputs[output_slot_idxs] = True
        self.tasks = set()

        self.gc_targets = np.ones(row_builder.num_materialized, dtype=bool)
        # we need to retain all slots that are part of the output
        self.gc_targets[[e.slot_idx for e in row_builder.output_exprs]] = False

        output_ctx = self.row_builder.create_eval_ctx(output_exprs, exclude=input_exprs)
        self.eval_ctx = np.zeros(row_builder.num_materialized, dtype=bool)
        self.eval_ctx[output_ctx.slot_idxs] = True
        self.error = None

        # completed_rows, exc_event and completed_event are created in __aiter__()
        self.input_iter = self.input.__aiter__()
        self.current_input_batch = None
        self.next_input_batch = None
        self.input_row_idx = 0
        self.avail_input_rows = 0
        self.input_complete = False
        self.num_in_flight = 0
        self.row_pos_map = None
        self.output_buffer = RowBuffer(self.MAX_BUFFERED_ROWS)

        self.num_input_rows = 0
        self.num_output_rows = 0

        self.slot_evaluators = {}
        self.schedulers = {}
        self._init_slot_evaluators()

    def set_input_order(self, maintain_input_order: bool) -> None:
        self.maintain_input_order = maintain_input_order

    def _init_slot_evaluators(self) -> None:
        """Create one evaluator per materialized slot (schedulers are created in _init_schedulers())"""
        for slot_idx in range(self.row_builder.num_materialized):
            expr = self.row_builder.unique_exprs[slot_idx]
            if (
                isinstance(expr, exprs.FunctionCall)
                # ExprTemplateFunction and AggregateFunction calls are best handled by FunctionCall.eval()
                and not isinstance(expr.fn, func.ExprTemplateFunction)
                and not isinstance(expr.fn, func.AggregateFunction)
            ):
                self.slot_evaluators[slot_idx] = FnCallEvaluator(expr, self)
            else:
                self.slot_evaluators[slot_idx] = DefaultExprEvaluator(expr, self)

    async def _fetch_input_batch(self) -> None:
        """
        Fetches another batch from our input or sets input_complete to True if there are no more batches.

        - stores the batch in current_input_batch, if not already set, or next_input_batch
        - updates row_pos_map, if needed
        - any exception raised while fetching is stored in self.error and signalled via exc_event, so that the
          main loop propagates it
        """
        assert not self.input_complete
        try:
            batch = await self.input_iter.__anext__()
            assert self.next_input_batch is None
            if self.current_input_batch is None:
                self.current_input_batch = batch
            else:
                self.next_input_batch = batch
            if self.maintain_input_order:
                for idx, row in enumerate(batch.rows):
                    self.row_pos_map[id(row)] = self.num_input_rows + idx
            self.num_input_rows += len(batch)
            self.avail_input_rows += len(batch)
            _logger.debug(f'adding input: batch_size={len(batch)} #input_rows={self.num_input_rows} #avail={self.avail_input_rows}')
        except StopAsyncIteration:
            self.input_complete = True
            _logger.debug(f'finished input: #input_rows={self.num_input_rows}, #avail={self.avail_input_rows}')
        except (excs.Error, excs.ExprEvalError) as err:
            self.error = err
            self.exc_event.set()
        except Exception as exc:
            # This coroutine runs as a task whose result is never retrieved by the main loop. Without this
            # handler, the exception would be dropped and the next fetch would report the end of the input,
            # silently truncating the output.
            self.error = excs.Error(f'Exception while fetching input batch: {exc}\n{traceback.format_exc()}')
            self.exc_event.set()

    @property
    def total_buffered(self) -> int:
        """Number of rows that have been dispatched but not yet returned (in flight, completed or buffered)"""
        return self.num_in_flight + self.completed_rows.qsize() + self.output_buffer.num_rows

    def _dispatch_input_rows(self) -> None:
        """Dispatch the maximum number of input rows, given total_buffered; does not block"""
        if self.avail_input_rows == 0:
            return
        num_rows = min(self.MAX_BUFFERED_ROWS - self.total_buffered, self.avail_input_rows)
        assert num_rows >= 0
        if num_rows == 0:
            return
        assert self.current_input_batch is not None
        avail_current_batch_rows = len(self.current_input_batch) - self.input_row_idx

        rows: list[exprs.DataRow]
        if avail_current_batch_rows > num_rows:
            # we only need rows from current_input_batch
            rows = self.current_input_batch.rows[self.input_row_idx:self.input_row_idx + num_rows]
            self.input_row_idx += num_rows
        else:
            # we need rows from both current_/next_input_batch
            rows = self.current_input_batch.rows[self.input_row_idx:]
            # current_input_batch is used up: the read-ahead batch (possibly None) takes its place
            self.current_input_batch = self.next_input_batch
            self.next_input_batch = None
            self.input_row_idx = 0
            num_remaining = num_rows - len(rows)
            if num_remaining > 0:
                rows.extend(self.current_input_batch.rows[:num_remaining])
                self.input_row_idx = num_remaining
        self.avail_input_rows -= num_rows
        self.num_in_flight += num_rows
        self._log_state(f'dispatch input ({num_rows})')

        self._init_input_rows(rows)
        self.dispatch(rows)

    def _log_state(self, prefix: str) -> None:
        """Log the row counters at debug level, prefixed with prefix"""
        _logger.debug(
            f'{prefix}: #in-flight={self.num_in_flight} #complete={self.completed_rows.qsize()} '
            f'#output-buffer={self.output_buffer.num_rows} #ready={self.output_buffer.num_ready} '
            f'total-buffered={self.total_buffered} #avail={self.avail_input_rows} '
            f'#input={self.num_input_rows} #output={self.num_output_rows}'
        )

    def _init_schedulers(self) -> None:
        """
        Create a scheduler for every resource pool referenced by our FnCallEvaluators.

        Raises RuntimeError if no class in SCHEDULERS matches a pool name.
        """
        resource_pools = {
            evaluator.fn_call.resource_pool
            for evaluator in self.slot_evaluators.values()
            if isinstance(evaluator, FnCallEvaluator) and evaluator.fn_call.resource_pool is not None
        }
        for pool_name in resource_pools:
            # the first matching scheduler class wins
            for scheduler in SCHEDULERS:
                if scheduler.matches(pool_name):
                    self.schedulers[pool_name] = scheduler(pool_name, self)
                    break
            if pool_name not in self.schedulers:
                raise RuntimeError(f'No scheduler found for resource pool {pool_name}')

    async def __aiter__(self) -> AsyncIterator[DataRowBatch]:
        """
        Main event loop

        Goals:
        - return completed DataRowBatches as soon as they become available
        - maximize the number of rows in flight in order to maximize parallelism, up to the given limit
        """
        # initialize completed_rows and events, now that we have the correct event loop
        self.completed_rows = asyncio.Queue[exprs.DataRow]()
        self.exc_event = asyncio.Event()
        self.completed_event = asyncio.Event()
        self._init_schedulers()
        if self.maintain_input_order:
            self.row_pos_map = {}
            self.output_buffer.set_row_pos_map(self.row_pos_map)

        row: exprs.DataRow
        exc_event_aw = asyncio.create_task(self.exc_event.wait(), name='exc_event.wait()')
        input_batch_aw: Optional[asyncio.Task] = None
        completed_aw: Optional[asyncio.Task] = None

        try:
            while True:
                # process completed rows before doing anything else
                while not self.completed_rows.empty():
                    # move completed rows to output buffer
                    while not self.completed_rows.empty():
                        row = self.completed_rows.get_nowait()
                        self.output_buffer.add_row(row)
                        if self.row_pos_map is not None:
                            self.row_pos_map.pop(id(row))

                    self._log_state('processed completed')
                    # return as many batches as we have available
                    while self.output_buffer.num_ready >= self.BATCH_SIZE:
                        batch_rows = self.output_buffer.get_rows(self.BATCH_SIZE)
                        self.num_output_rows += len(batch_rows)
                        # make sure we top up our in-flight rows before yielding
                        self._dispatch_input_rows()
                        self._log_state(f'yielding {len(batch_rows)} rows')
                        yield DataRowBatch(tbl=None, row_builder=self.row_builder, rows=batch_rows)
                        # at this point, we may have more completed rows

                assert self.completed_rows.empty()  # all completed rows should be sitting in output_buffer
                self.completed_event.clear()
                if self.input_complete and self.num_in_flight == 0:
                    # there is no more input and nothing left to wait for
                    assert self.avail_input_rows == 0
                    if self.output_buffer.num_ready > 0:
                        assert self.output_buffer.num_rows == self.output_buffer.num_ready
                        # yield the leftover rows
                        batch_rows = self.output_buffer.get_rows(self.output_buffer.num_ready)
                        self.num_output_rows += len(batch_rows)
                        self._log_state(f'yielding {len(batch_rows)} rows')
                        yield DataRowBatch(tbl=None, row_builder=self.row_builder, rows=batch_rows)

                    assert self.output_buffer.num_rows == 0
                    return

                if self.input_complete and self.avail_input_rows == 0:
                    # no more input rows to dispatch, but we're still waiting for rows to finish:
                    # close all slot evaluators to flush queued rows
                    for evaluator in self.slot_evaluators.values():
                        evaluator.close()

                # we don't have a full batch of rows at this point and need to wait
                aws = {exc_event_aw}  # always wait for an exception
                if self.next_input_batch is None and not self.input_complete:
                    # also wait for another batch if we don't have a read-ahead batch yet
                    if input_batch_aw is None:
                        input_batch_aw = asyncio.create_task(self._fetch_input_batch(), name='_fetch_input_batch()')
                    aws.add(input_batch_aw)
                if self.num_in_flight > 0:
                    # also wait for more rows to complete
                    if completed_aw is None:
                        completed_aw = asyncio.create_task(self.completed_event.wait(), name='completed.wait()')
                    aws.add(completed_aw)
                done, pending = await asyncio.wait(aws, return_when=asyncio.FIRST_COMPLETED)

                if self.exc_event.is_set():
                    # we got an exception that we need to propagate through __aiter__()
                    _logger.debug(f'Propagating exception {self.error}')
                    raise self.error
                if completed_aw in done:
                    self._log_state('completed_aw done')
                    completed_aw = None
                if input_batch_aw in done:
                    self._dispatch_input_rows()
                    input_batch_aw = None

        finally:
            # task cleanup: cancel whatever is still running and wait for it to finish
            active_tasks = {exc_event_aw}
            if input_batch_aw is not None:
                active_tasks.add(input_batch_aw)
            if completed_aw is not None:
                active_tasks.add(completed_aw)
            active_tasks.update(self.tasks)
            for task in active_tasks:
                if not task.done():
                    task.cancel()
            _ = await asyncio.gather(*active_tasks, return_exceptions=True)

    def _init_input_rows(self, rows: list[exprs.DataRow]) -> None:
        """Set execution state in DataRow"""
        for row in rows:
            # has_val is compared elementwise (== False) to build a mask of the slots without a value
            row.missing_dependents = np.sum(self.row_builder.dependencies[row.has_val == False], axis=0)
            row.missing_slots = self.eval_ctx & (row.has_val == False)

    def dispatch_exc(self, rows: list[exprs.DataRow], slot_with_exc: int, exc_tb: TracebackType) -> None:
        """Propagate exception to main event loop or to dependent slots, depending on ignore_errors"""
        if len(rows) == 0 or self.exc_event.is_set():
            return

        if not self.ctx.ignore_errors:
            # abort execution: report the error for the first row only and wake up the main loop
            dependency_idxs = [e.slot_idx for e in self.row_builder.unique_exprs[slot_with_exc].dependencies()]
            first_row = rows[0]
            input_vals = [first_row[idx] for idx in dependency_idxs]
            e = self.row_builder.unique_exprs[slot_with_exc]
            self.error = excs.ExprEvalError(
                e, f'expression {e}', first_row.get_exc(e.slot_idx), exc_tb, input_vals, 0)
            self.exc_event.set()
            return

        for row in rows:
            assert row.has_exc(slot_with_exc)
            exc = row.get_exc(slot_with_exc)
            # propagate exception
            for slot_idx in np.nonzero(self.row_builder.transitive_dependents[slot_with_exc])[0].tolist():
                row.set_exc(slot_idx, exc)
        self.dispatch(rows)

    def dispatch(self, rows: list[exprs.DataRow]) -> None:
        """Dispatch rows to slot evaluators, based on materialized dependencies"""
        if len(rows) == 0 or self.exc_event.is_set():
            return

        # slots ready for evaluation; rows x slots
        ready_slots = np.zeros((len(rows), self.row_builder.num_materialized), dtype=bool)
        completed_rows = np.zeros(len(rows), dtype=bool)
        for i, row in enumerate(rows):
            row.missing_slots &= row.has_val == False
            if row.missing_slots.sum() == 0:
                # all output slots have been materialized
                completed_rows[i] = True
            else:
                # dependencies of missing slots
                missing_dependencies = self.num_dependencies * row.missing_slots
                # determine ready slots that are not yet materialized and not yet scheduled
                num_mat_dependencies = np.sum(self.row_builder.dependencies * row.has_val, axis=1)
                num_missing = missing_dependencies - num_mat_dependencies
                ready_slots[i] = (num_missing == 0) & (row.is_scheduled == False) & row.missing_slots
                row.is_scheduled = row.is_scheduled | ready_slots[i]

            # clear intermediate values that are no longer needed (ie, all dependents are materialized)
            missing_dependents = np.sum(self.row_builder.dependencies[row.has_val == False], axis=0)
            gc_targets = (missing_dependents == 0) & (row.missing_dependents > 0) & self.gc_targets
            row.clear(gc_targets)
            row.missing_dependents = missing_dependents

        if np.any(completed_rows):
            completed_idxs = list(completed_rows.nonzero()[0])
            for i in completed_idxs:
                self.completed_rows.put_nowait(rows[i])
            self.completed_event.set()
            self.num_in_flight -= len(completed_idxs)

        # schedule all ready slots
        for slot_idx in np.sum(ready_slots, axis=0).nonzero()[0]:
            ready_rows_v = ready_slots[:, slot_idx].flatten()
            ready_rows = [rows[i] for i in ready_rows_v.nonzero()[0]]
            _logger.debug(f'Scheduling {len(ready_rows)} rows for slot {slot_idx}')
            self.slot_evaluators[slot_idx].schedule(ready_rows, slot_idx)

    def register_task(self, t: asyncio.Task) -> None:
        """Keep a reference to t until it is done and watch it for unhandled exceptions"""
        self.tasks.add(t)
        t.add_done_callback(self._done_cb)

    def _done_cb(self, t: asyncio.Task) -> None:
        """Done-callback for registered tasks: drop the task and report an unhandled exception, if any"""
        self.tasks.discard(t)
        # end the main loop if we had an unhandled exception
        try:
            t.result()
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            stack_trace = traceback.format_exc()
            self.error = excs.Error(f'Exception in task: {exc}\n{stack_trace}')
            self.exc_event.set()
|