pixeltable 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic; see the package registry page for more details.

Files changed (56)
  1. pixeltable/__version__.py +2 -2
  2. pixeltable/catalog/insertable_table.py +3 -3
  3. pixeltable/catalog/table.py +2 -2
  4. pixeltable/catalog/table_version.py +3 -2
  5. pixeltable/catalog/view.py +1 -1
  6. pixeltable/dataframe.py +52 -27
  7. pixeltable/env.py +109 -4
  8. pixeltable/exec/__init__.py +1 -1
  9. pixeltable/exec/aggregation_node.py +3 -3
  10. pixeltable/exec/cache_prefetch_node.py +13 -7
  11. pixeltable/exec/component_iteration_node.py +3 -9
  12. pixeltable/exec/data_row_batch.py +17 -5
  13. pixeltable/exec/exec_node.py +32 -12
  14. pixeltable/exec/expr_eval/__init__.py +1 -0
  15. pixeltable/exec/expr_eval/evaluators.py +240 -0
  16. pixeltable/exec/expr_eval/expr_eval_node.py +408 -0
  17. pixeltable/exec/expr_eval/globals.py +113 -0
  18. pixeltable/exec/expr_eval/row_buffer.py +76 -0
  19. pixeltable/exec/expr_eval/schedulers.py +240 -0
  20. pixeltable/exec/in_memory_data_node.py +2 -2
  21. pixeltable/exec/row_update_node.py +14 -14
  22. pixeltable/exec/sql_node.py +2 -2
  23. pixeltable/exprs/column_ref.py +5 -1
  24. pixeltable/exprs/data_row.py +50 -40
  25. pixeltable/exprs/expr.py +57 -12
  26. pixeltable/exprs/function_call.py +54 -19
  27. pixeltable/exprs/inline_expr.py +12 -21
  28. pixeltable/exprs/literal.py +25 -8
  29. pixeltable/exprs/row_builder.py +25 -2
  30. pixeltable/func/aggregate_function.py +4 -0
  31. pixeltable/func/callable_function.py +54 -4
  32. pixeltable/func/expr_template_function.py +5 -1
  33. pixeltable/func/function.py +48 -7
  34. pixeltable/func/query_template_function.py +16 -7
  35. pixeltable/func/udf.py +7 -1
  36. pixeltable/functions/__init__.py +1 -1
  37. pixeltable/functions/anthropic.py +97 -21
  38. pixeltable/functions/gemini.py +2 -6
  39. pixeltable/functions/openai.py +219 -28
  40. pixeltable/globals.py +2 -3
  41. pixeltable/io/hf_datasets.py +1 -1
  42. pixeltable/io/label_studio.py +5 -5
  43. pixeltable/io/parquet.py +1 -1
  44. pixeltable/metadata/__init__.py +2 -1
  45. pixeltable/plan.py +24 -9
  46. pixeltable/store.py +6 -0
  47. pixeltable/type_system.py +73 -36
  48. pixeltable/utils/arrow.py +3 -8
  49. pixeltable/utils/console_output.py +41 -0
  50. pixeltable/utils/filecache.py +1 -1
  51. {pixeltable-0.3.0.dist-info → pixeltable-0.3.2.dist-info}/METADATA +4 -1
  52. {pixeltable-0.3.0.dist-info → pixeltable-0.3.2.dist-info}/RECORD +55 -49
  53. pixeltable/exec/expr_eval_node.py +0 -232
  54. {pixeltable-0.3.0.dist-info → pixeltable-0.3.2.dist-info}/LICENSE +0 -0
  55. {pixeltable-0.3.0.dist-info → pixeltable-0.3.2.dist-info}/WHEEL +0 -0
  56. {pixeltable-0.3.0.dist-info → pixeltable-0.3.2.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,240 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import datetime
5
+ import itertools
6
+ import logging
7
+ import sys
8
+ from typing import Iterator, Any, Optional, Callable, cast
9
+
10
+ from pixeltable import exprs
11
+ from pixeltable import func
12
+ from .globals import Dispatcher, Evaluator, FnCallArgs
13
+
14
# module logger; shares the package-wide 'pixeltable' logger
_logger = logging.getLogger(name='pixeltable')
15
+
16
+
17
class DefaultExprEvaluator(Evaluator):
    """
    Fallback evaluator: materializes an arbitrary Expr by calling Expr.eval() on each row.

    Each call to schedule() results in exactly one task, which works through the given rows sequentially.

    TODO:
    - parallelize via Ray
    """
    e: exprs.Expr

    def __init__(self, e: exprs.Expr, dispatcher: Dispatcher):
        super().__init__(dispatcher)
        self.e = e

    def schedule(self, rows: list[exprs.DataRow], slot_idx: int) -> None:
        """Start (and register with the dispatcher) a single task that evaluates self.e for all of `rows`."""
        assert self.e.slot_idx >= 0
        self.dispatcher.register_task(asyncio.create_task(self.eval(rows)))

    async def eval(self, rows: list[exprs.DataRow]) -> None:
        """
        Evaluate self.e for each row in order.

        A row whose evaluation raises gets the exception recorded in its slot and is handed to dispatch_exc()
        right away; all other rows are dispatched together once the loop finishes. The loop stops early (without
        dispatching anything further) as soon as the dispatcher's exc_event is set.
        """
        target_slot = self.e.slot_idx
        failed_positions: set[int] = set()  # positions in `rows` whose evaluation raised
        for pos, data_row in enumerate(rows):
            assert not data_row.has_val[target_slot] and not data_row.has_exc(target_slot)
            # NOTE(review): Task.cancelled() is only True once the task has finished, so while this loop runs it is
            # the exc_event check that actually stops the work
            if asyncio.current_task().cancelled() or self.dispatcher.exc_event.is_set():
                return
            try:
                self.e.eval(data_row, self.dispatcher.row_builder)
            except Exception as exc:
                exc_tb = sys.exc_info()[2]
                data_row.set_exc(target_slot, exc)
                failed_positions.add(pos)
                self.dispatcher.dispatch_exc([data_row], target_slot, exc_tb)
        self.dispatcher.dispatch(
            [data_row for pos, data_row in enumerate(rows) if pos not in failed_positions])
51
+
52
+
53
class FnCallEvaluator(Evaluator):
    """
    Evaluates function calls:
    - batched functions (sync and async): one task per batch
    - async functions: one task per row
    - the rest: one task per set of rows handed to schedule()

    Batched and async calls whose FunctionCall specifies a resource_pool are handed to that pool's scheduler
    (self.dispatcher.schedulers[resource_pool]) instead of being run as a task directly. This includes the final,
    incomplete batch that is flushed by _close().

    TODO:
    - adaptive batching: finding the optimal batch size based on observed execution times
    """
    fn_call: exprs.FunctionCall
    fn: func.CallableFunction
    scalar_py_fn: Optional[Callable]  # only set for non-batching CallableFunctions

    # only set if fn.is_batched
    call_args_queue: Optional[asyncio.Queue[FnCallArgs]]  # FnCallArgs waiting for execution
    batch_size: Optional[int]

    def __init__(self, fn_call: exprs.FunctionCall, dispatcher: Dispatcher):
        super().__init__(dispatcher)
        self.fn_call = fn_call
        self.fn = cast(func.CallableFunction, fn_call.fn)
        if isinstance(self.fn, func.CallableFunction) and self.fn.is_batched:
            self.call_args_queue = asyncio.Queue[FnCallArgs]()
            # we're not supplying sample arguments there, they're ignored anyway
            self.batch_size = self.fn.get_batch_size()
            self.scalar_py_fn = None
        else:
            self.call_args_queue = None
            self.batch_size = None
            if isinstance(self.fn, func.CallableFunction):
                self.scalar_py_fn = self.fn.py_fn
            else:
                self.scalar_py_fn = None

    def schedule(self, rows: list[exprs.DataRow], slot_idx: int) -> None:
        """
        Arrange for fn_call to be evaluated for `rows`.

        Rows for which make_args() returns None get None in the output slot and are dispatched immediately.
        For batched functions, call args are queued until a full batch is available (or the evaluator is closed);
        full batches are submitted via _submit_batch().
        """
        assert self.fn_call.slot_idx >= 0

        # create FnCallArgs for incoming rows
        skip_rows: list[exprs.DataRow] = []  # skip rows with Nones in non-nullable parameters
        rows_call_args: list[FnCallArgs] = []
        for row in rows:
            args_kwargs = self.fn_call.make_args(row)
            if args_kwargs is None:
                # nothing to do here
                row[self.fn_call.slot_idx] = None
                skip_rows.append(row)
            else:
                args, kwargs = args_kwargs
                rows_call_args.append(FnCallArgs(self.fn_call, [row], args=args, kwargs=kwargs))

        if len(skip_rows) > 0:
            self.dispatcher.dispatch(skip_rows)

        if self.batch_size is not None:
            if not self.is_closed and (len(rows_call_args) + self.call_args_queue.qsize() < self.batch_size):
                # we don't have enough FnCallArgs for a batch, so add them to the queue
                for item in rows_call_args:
                    self.call_args_queue.put_nowait(item)
                return

            # create one task per batch; queued items go first
            combined_call_args = itertools.chain(self._queued_call_args_iter(), rows_call_args)
            while True:
                call_args_batch = list(itertools.islice(combined_call_args, self.batch_size))
                if len(call_args_batch) == 0:
                    break
                if len(call_args_batch) < self.batch_size and not self.is_closed:
                    # we don't have a full batch left: return the rest to the queue
                    assert self.call_args_queue.empty()  # we saw all queued items
                    for item in call_args_batch:
                        self.call_args_queue.put_nowait(item)
                    return

                # turn call_args_batch into a single batched FnCallArgs
                _logger.debug(f'Creating batch of size {len(call_args_batch)} for slot {slot_idx}')
                self._submit_batch(self._create_batch_call_args(call_args_batch))

        elif self.fn.is_async:
            if self.fn_call.resource_pool is not None:
                # hand the call off to the resource pool's scheduler
                scheduler = self.dispatcher.schedulers[self.fn_call.resource_pool]
                for item in rows_call_args:
                    scheduler.submit(item)
            else:
                # create one task per call
                for item in rows_call_args:
                    task = asyncio.create_task(self.eval_async(item))
                    self.dispatcher.register_task(task)

        elif len(rows_call_args) > 0:
            # create a single task for all rows (nothing to do if every row was skipped)
            task = asyncio.create_task(self.eval(rows_call_args))
            self.dispatcher.register_task(task)

    def _submit_batch(self, batched_call_args: FnCallArgs) -> None:
        """Run a batched FnCallArgs: via the resource pool's scheduler if fn_call has one, else as a new task"""
        if self.fn_call.resource_pool is not None:
            # hand the call off to the resource pool's scheduler
            scheduler = self.dispatcher.schedulers[self.fn_call.resource_pool]
            scheduler.submit(batched_call_args)
        else:
            task = asyncio.create_task(self.eval_batch(batched_call_args))
            self.dispatcher.register_task(task)

    def _queued_call_args_iter(self) -> Iterator[FnCallArgs]:
        """Drain call_args_queue, yielding its items in FIFO order"""
        while not self.call_args_queue.empty():
            yield self.call_args_queue.get_nowait()

    def _create_batch_call_args(self, call_args: list[FnCallArgs]) -> FnCallArgs:
        """
        Roll call_args into a single batched FnCallArgs.

        The per-call arguments are transposed into per-parameter lists: batch_args[j][i] is positional argument j
        of call i, and batch_kwargs[k][i] is keyword argument k of call i (None where a call didn't supply it).
        """
        batch_args: list[list[Optional[Any]]] = [[None] * len(call_args) for _ in range(len(self.fn_call.args))]
        batch_kwargs: dict[str, list[Optional[Any]]] = {k: [None] * len(call_args) for k in self.fn_call.kwargs.keys()}
        assert isinstance(self.fn, func.CallableFunction)
        for i, item in enumerate(call_args):
            for j, arg in enumerate(item.args):
                batch_args[j][i] = arg
            for k, kwarg in item.kwargs.items():
                batch_kwargs[k][i] = kwarg
        return FnCallArgs(
            self.fn_call, [item.row for item in call_args], batch_args=batch_args, batch_kwargs=batch_kwargs)

    async def eval_batch(self, batched_call_args: FnCallArgs) -> None:
        """
        Evaluate a batched call.

        On success, result_batch[i] is stored in row i and all rows are dispatched; on failure, the exception is
        recorded in every row of the batch and the rows are handed to dispatch_exc().
        """
        result_batch: list[Any]
        try:
            if self.fn.is_async:
                result_batch = await self.fn.aexec_batch(
                    *batched_call_args.batch_args, **batched_call_args.batch_kwargs)
            else:
                # check for cancellation before starting something potentially long-running
                if asyncio.current_task().cancelled() or self.dispatcher.exc_event.is_set():
                    return
                result_batch = self.fn.exec_batch(batched_call_args.batch_args, batched_call_args.batch_kwargs)
        except Exception as exc:
            _, _, exc_tb = sys.exc_info()
            for row in batched_call_args.rows:
                row.set_exc(self.fn_call.slot_idx, exc)
            self.dispatcher.dispatch_exc(batched_call_args.rows, self.fn_call.slot_idx, exc_tb)
            return

        for i, row in enumerate(batched_call_args.rows):
            row[self.fn_call.slot_idx] = result_batch[i]
        self.dispatcher.dispatch(batched_call_args.rows)

    async def eval_async(self, call_args: FnCallArgs) -> None:
        """
        Evaluate a single-row call of an async function.

        Only exceptions raised by the function itself are recorded in the row; the subsequent dispatch() happens
        outside the try so that dispatcher errors aren't misattributed to the function.
        """
        assert len(call_args.rows) == 1
        assert not call_args.row.has_val[self.fn_call.slot_idx]
        assert not call_args.row.has_exc(self.fn_call.slot_idx)

        try:
            start_ts = datetime.datetime.now()
            _logger.debug(f'Start evaluating slot {self.fn_call.slot_idx}')
            call_args.row[self.fn_call.slot_idx] = await self.fn.aexec(*call_args.args, **call_args.kwargs)
            end_ts = datetime.datetime.now()
            _logger.debug(f'Evaluated slot {self.fn_call.slot_idx} in {end_ts - start_ts}')
        except Exception as exc:
            # no provider-specific handling here: importing an optional package (eg, anthropic) in this handler
            # would raise ImportError when it isn't installed and mask the actual exception
            _logger.debug(f'Exception while evaluating slot {self.fn_call.slot_idx}: {exc!r}')
            _, _, exc_tb = sys.exc_info()
            call_args.row.set_exc(self.fn_call.slot_idx, exc)
            self.dispatcher.dispatch_exc(call_args.rows, self.fn_call.slot_idx, exc_tb)
        else:
            self.dispatcher.dispatch([call_args.row])

    async def eval(self, call_args_batch: list[FnCallArgs]) -> None:
        """
        Evaluate scalar_py_fn for each single-row FnCallArgs, in order.

        Failing rows are handed to dispatch_exc() immediately; the remaining rows are dispatched together at the end.
        Stops early (without further dispatching) once the dispatcher's exc_event is set.
        """
        rows_with_excs: set[int] = set()  # records idxs into call_args_batch
        for idx, item in enumerate(call_args_batch):
            assert len(item.rows) == 1
            assert not item.row.has_val[self.fn_call.slot_idx]
            assert not item.row.has_exc(self.fn_call.slot_idx)
            # check for cancellation before starting something potentially long-running
            if asyncio.current_task().cancelled() or self.dispatcher.exc_event.is_set():
                return
            try:
                item.row[self.fn_call.slot_idx] = self.scalar_py_fn(*item.args, **item.kwargs)
            except Exception as exc:
                _, _, exc_tb = sys.exc_info()
                item.row.set_exc(self.fn_call.slot_idx, exc)
                rows_with_excs.add(idx)
                self.dispatcher.dispatch_exc(item.rows, self.fn_call.slot_idx, exc_tb)
        self.dispatcher.dispatch(
            [item.row for idx, item in enumerate(call_args_batch) if idx not in rows_with_excs])

    def _close(self) -> None:
        """Submit the incomplete batch of queued FnCallArgs, if any (honoring the resource pool, like schedule())"""
        _logger.debug(f'FnCallEvaluator.close(): slot_idx={self.fn_call.slot_idx}')
        if self.call_args_queue is None or self.call_args_queue.empty():
            return
        batched_call_args = self._create_batch_call_args(list(self._queued_call_args_iter()))
        self._submit_batch(batched_call_args)
@@ -0,0 +1,408 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ import traceback
6
+ from types import TracebackType
7
+ from typing import Iterable, AsyncIterator, Optional, Union
8
+
9
+ import numpy as np
10
+
11
+ import pixeltable.exceptions as excs
12
+ from pixeltable import exprs
13
+ from pixeltable import func
14
+ from .evaluators import DefaultExprEvaluator, FnCallEvaluator
15
+ from .globals import Evaluator, Scheduler
16
+ from .row_buffer import RowBuffer
17
+ from .schedulers import SCHEDULERS
18
+ from ..data_row_batch import DataRowBatch
19
+ from ..exec_node import ExecNode
20
+
21
# module logger; shares the package-wide 'pixeltable' logger
_logger = logging.getLogger(name='pixeltable')
22
+
23
+
24
+
25
class ExprEvalNode(ExecNode):
    """
    Expression evaluation

    Resource management:
    - the execution system tries to limit total memory consumption by limiting the number of rows that are in
      circulation
    - during execution, slots that aren't part of the output are garbage collected as soon as their direct dependents
      are materialized

    TODO:
    - Literal handling: currently, Literal values are copied into slots via the normal evaluation mechanism, which is
      needless overhead; instead: pre-populate Literal slots in _init_row()
    - local model inference on gpu: currently, no attempt is made to ensure that models can fit onto the gpu
      simultaneously, which will cause errors; instead, the execution should be divided into sequential phases, each
      of which only contains a subset of the models which is known to fit onto the gpu simultaneously
    """
    maintain_input_order: bool  # True if we're returning rows in the order we received them from our input
    num_dependencies: np.ndarray  # number of dependencies for our output slots; indexed by slot idx
    outputs: np.ndarray  # bool per slot; True if this slot is part of our output
    slot_evaluators: dict[int, Evaluator]  # key: slot idx
    schedulers: dict[str, Scheduler]  # key: resource pool name
    gc_targets: np.ndarray  # bool per slot; True if this is an intermediate expr (ie, not part of our output)
    eval_ctx: np.ndarray  # bool per slot; EvalCtx.slot_idxs as a mask

    # execution state
    tasks: set[asyncio.Task]  # collects all running tasks to prevent them from getting gc'd
    exc_event: asyncio.Event  # set if an exception needs to be propagated
    error: Optional[Union[excs.Error, excs.ExprEvalError]]  # exception that needs to be propagated
    completed_rows: asyncio.Queue[exprs.DataRow]  # rows that have completed evaluation
    completed_event: asyncio.Event  # set when completed_rows is non-empty
    input_iter: AsyncIterator[DataRowBatch]
    current_input_batch: Optional[DataRowBatch]  # batch from which we're currently consuming rows
    input_row_idx: int  # next row to consume from current_input_batch
    next_input_batch: Optional[DataRowBatch]  # read-ahead input batch
    avail_input_rows: int  # total number across both current_/next_input_batch
    input_complete: bool  # True if we've received all input batches
    num_in_flight: int  # number of dispatched rows that haven't completed
    row_pos_map: Optional[dict[int, int]]  # id(row) -> position of row in input; only set if maintain_input_order
    output_buffer: RowBuffer  # holds rows that are ready to be returned, in order

    # debugging
    num_input_rows: int
    num_output_rows: int

    BATCH_SIZE = 64  # rows per yielded DataRowBatch (the final batch may be smaller)
    MAX_BUFFERED_ROWS = 512  # maximum number of rows that have been dispatched but not yet returned

    def __init__(
        self, row_builder: exprs.RowBuilder, output_exprs: Iterable[exprs.Expr], input_exprs: Iterable[exprs.Expr],
        input: ExecNode, maintain_input_order: bool = True
    ):
        # output_exprs/input_exprs are traversed more than once below; materialize them so that one-shot iterables
        # (eg, generators) aren't silently exhausted after the first pass
        output_exprs = list(output_exprs)
        input_exprs = list(input_exprs)
        super().__init__(row_builder, output_exprs, input_exprs, input)
        self.maintain_input_order = maintain_input_order
        self.num_dependencies = np.sum(row_builder.dependencies, axis=1)
        self.outputs = np.zeros(row_builder.num_materialized, dtype=bool)
        output_slot_idxs = [e.slot_idx for e in output_exprs]
        self.outputs[output_slot_idxs] = True
        self.tasks = set()

        self.gc_targets = np.ones(row_builder.num_materialized, dtype=bool)
        # we need to retain all slots that are part of the output
        self.gc_targets[[e.slot_idx for e in row_builder.output_exprs]] = False

        output_ctx = self.row_builder.create_eval_ctx(output_exprs, exclude=input_exprs)
        self.eval_ctx = np.zeros(row_builder.num_materialized, dtype=bool)
        self.eval_ctx[output_ctx.slot_idxs] = True
        self.error = None

        self.input_iter = self.input.__aiter__()
        self.current_input_batch = None
        self.next_input_batch = None
        self.input_row_idx = 0
        self.avail_input_rows = 0
        self.input_complete = False
        self.num_in_flight = 0
        self.row_pos_map = None
        self.output_buffer = RowBuffer(self.MAX_BUFFERED_ROWS)

        self.num_input_rows = 0
        self.num_output_rows = 0

        self.slot_evaluators = {}
        self.schedulers = {}  # populated by _init_schedulers(), once we're running in the event loop
        self._init_slot_evaluators()

    def set_input_order(self, maintain_input_order: bool) -> None:
        self.maintain_input_order = maintain_input_order

    def _init_slot_evaluators(self) -> None:
        """Create one slot evaluator per materialized slot (resource pool schedulers are created in _init_schedulers())"""
        for slot_idx in range(self.row_builder.num_materialized):
            expr = self.row_builder.unique_exprs[slot_idx]
            if (
                isinstance(expr, exprs.FunctionCall)
                # ExprTemplateFunction and AggregateFunction calls are best handled by FunctionCall.eval()
                and not isinstance(expr.fn, func.ExprTemplateFunction)
                and not isinstance(expr.fn, func.AggregateFunction)
            ):
                self.slot_evaluators[slot_idx] = FnCallEvaluator(expr, self)
            else:
                self.slot_evaluators[slot_idx] = DefaultExprEvaluator(expr, self)

    async def _fetch_input_batch(self) -> None:
        """
        Fetches another batch from our input or sets input_complete to True if there are no more batches.

        - stores the batch in current_input_batch, if not already set, or next_input_batch
        - updates row_pos_map, if needed
        - an excs.Error from the input is recorded in self.error and signaled via exc_event; any other exception
          propagates out of this coroutine (and is re-raised by __aiter__())
        """
        assert not self.input_complete
        try:
            batch = await self.input_iter.__anext__()
            assert self.next_input_batch is None
            if self.current_input_batch is None:
                self.current_input_batch = batch
            else:
                self.next_input_batch = batch
            if self.maintain_input_order:
                for idx, row in enumerate(batch.rows):
                    self.row_pos_map[id(row)] = self.num_input_rows + idx
            self.num_input_rows += len(batch)
            self.avail_input_rows += len(batch)
            _logger.debug(f'adding input: batch_size={len(batch)} #input_rows={self.num_input_rows} #avail={self.avail_input_rows}')
        except StopAsyncIteration:
            self.input_complete = True
            _logger.debug(f'finished input: #input_rows={self.num_input_rows}, #avail={self.avail_input_rows}')
        except excs.Error as err:
            self.error = err
            self.exc_event.set()
        # TODO: should we also handle Exception here and create an excs.Error from it?

    @property
    def total_buffered(self) -> int:
        """Rows that have been dispatched but not yet returned (in flight, completed, or sitting in output_buffer)"""
        return self.num_in_flight + self.completed_rows.qsize() + self.output_buffer.num_rows

    def _dispatch_input_rows(self) -> None:
        """Dispatch the maximum number of input rows, given total_buffered; does not block"""
        if self.avail_input_rows == 0:
            return
        num_rows = min(self.MAX_BUFFERED_ROWS - self.total_buffered, self.avail_input_rows)
        assert num_rows >= 0
        if num_rows == 0:
            return
        assert self.current_input_batch is not None
        avail_current_batch_rows = len(self.current_input_batch) - self.input_row_idx

        rows: list[exprs.DataRow]
        if avail_current_batch_rows > num_rows:
            # we only need rows from current_input_batch
            rows = self.current_input_batch.rows[self.input_row_idx:self.input_row_idx + num_rows]
            self.input_row_idx += num_rows
        else:
            # we need rows from both current_/next_input_batch
            rows = self.current_input_batch.rows[self.input_row_idx:]
            self.current_input_batch = self.next_input_batch
            self.next_input_batch = None
            self.input_row_idx = 0
            num_remaining = num_rows - len(rows)
            if num_remaining > 0:
                rows.extend(self.current_input_batch.rows[:num_remaining])
                self.input_row_idx = num_remaining
        self.avail_input_rows -= num_rows
        self.num_in_flight += num_rows
        self._log_state(f'dispatch input ({num_rows})')

        self._init_input_rows(rows)
        self.dispatch(rows)

    def _log_state(self, prefix: str) -> None:
        _logger.debug(
            f'{prefix}: #in-flight={self.num_in_flight} #complete={self.completed_rows.qsize()} '
            f'#output-buffer={self.output_buffer.num_rows} #ready={self.output_buffer.num_ready} '
            f'total-buffered={self.total_buffered} #avail={self.avail_input_rows} '
            f'#input={self.num_input_rows} #output={self.num_output_rows}'
        )

    def _init_schedulers(self) -> None:
        """Create one scheduler per resource pool referenced by our FnCallEvaluators (first match in SCHEDULERS)"""
        resource_pools = {
            evaluator.fn_call.resource_pool
            for evaluator in self.slot_evaluators.values()
            if isinstance(evaluator, FnCallEvaluator) and evaluator.fn_call.resource_pool is not None
        }
        for pool_name in resource_pools:
            for scheduler in SCHEDULERS:
                if scheduler.matches(pool_name):
                    self.schedulers[pool_name] = scheduler(pool_name, self)
                    break
            if pool_name not in self.schedulers:
                raise RuntimeError(f'No scheduler found for resource pool {pool_name}')

    async def __aiter__(self) -> AsyncIterator[DataRowBatch]:
        """
        Main event loop

        Goals:
        - return completed DataRowBatches as soon as they become available
        - maximize the number of rows in flight in order to maximize parallelism, up to the given limit

        Raises self.error once exc_event is set; exceptions from reading the input that _fetch_input_batch() doesn't
        handle itself are re-raised as well. On exit, all outstanding tasks are cancelled and awaited.
        """
        # initialize completed_rows and events, now that we have the correct event loop
        self.completed_rows = asyncio.Queue[exprs.DataRow]()
        self.exc_event = asyncio.Event()
        self.completed_event = asyncio.Event()
        self._init_schedulers()
        if self.maintain_input_order:
            self.row_pos_map = {}
            self.output_buffer.set_row_pos_map(self.row_pos_map)

        row: exprs.DataRow
        exc_event_aw = asyncio.create_task(self.exc_event.wait(), name='exc_event.wait()')
        input_batch_aw: Optional[asyncio.Task] = None
        completed_aw: Optional[asyncio.Task] = None

        try:
            while True:
                # process completed rows before doing anything else
                while not self.completed_rows.empty():
                    # move completed rows to output buffer
                    while not self.completed_rows.empty():
                        row = self.completed_rows.get_nowait()
                        self.output_buffer.add_row(row)
                        if self.row_pos_map is not None:
                            self.row_pos_map.pop(id(row))

                    self._log_state('processed completed')
                    # return as many batches as we have available
                    while self.output_buffer.num_ready >= self.BATCH_SIZE:
                        batch_rows = self.output_buffer.get_rows(self.BATCH_SIZE)
                        self.num_output_rows += len(batch_rows)
                        # make sure we top up our in-flight rows before yielding
                        self._dispatch_input_rows()
                        self._log_state(f'yielding {len(batch_rows)} rows')
                        yield DataRowBatch(tbl=None, row_builder=self.row_builder, rows=batch_rows)
                        # at this point, we may have more completed rows

                assert self.completed_rows.empty()  # all completed rows should be sitting in output_buffer
                self.completed_event.clear()
                if self.input_complete and self.num_in_flight == 0:
                    # there is no more input and nothing left to wait for
                    assert self.avail_input_rows == 0
                    if self.output_buffer.num_ready > 0:
                        assert self.output_buffer.num_rows == self.output_buffer.num_ready
                        # yield the leftover rows
                        batch_rows = self.output_buffer.get_rows(self.output_buffer.num_ready)
                        self.num_output_rows += len(batch_rows)
                        self._log_state(f'yielding {len(batch_rows)} rows')
                        yield DataRowBatch(tbl=None, row_builder=self.row_builder, rows=batch_rows)

                    assert self.output_buffer.num_rows == 0
                    return

                if self.input_complete and self.avail_input_rows == 0:
                    # no more input rows to dispatch, but we're still waiting for rows to finish:
                    # close all slot evaluators to flush queued rows
                    for evaluator in self.slot_evaluators.values():
                        evaluator.close()

                # we don't have a full batch of rows at this point and need to wait
                aws = {exc_event_aw}  # always wait for an exception
                if self.next_input_batch is None and not self.input_complete:
                    # also wait for another batch if we don't have a read-ahead batch yet
                    if input_batch_aw is None:
                        input_batch_aw = asyncio.create_task(self._fetch_input_batch(), name='_fetch_input_batch()')
                    aws.add(input_batch_aw)
                if self.num_in_flight > 0:
                    # also wait for more rows to complete
                    if completed_aw is None:
                        completed_aw = asyncio.create_task(self.completed_event.wait(), name='completed.wait()')
                    aws.add(completed_aw)
                done, _ = await asyncio.wait(aws, return_when=asyncio.FIRST_COMPLETED)

                if self.exc_event.is_set():
                    # we got an exception that we need to propagate through __iter__()
                    _logger.debug(f'Propagating exception {self.error}')
                    raise self.error
                if completed_aw in done:
                    self._log_state('completed_aw done')
                    completed_aw = None
                if input_batch_aw in done:
                    # re-raise anything _fetch_input_batch() didn't handle itself; otherwise the exception would be
                    # dropped and the next fetch would report the input as exhausted, silently truncating the result
                    input_batch_aw.result()
                    self._dispatch_input_rows()
                    input_batch_aw = None

        finally:
            # task cleanup
            active_tasks = {exc_event_aw}
            if input_batch_aw is not None:
                active_tasks.add(input_batch_aw)
            if completed_aw is not None:
                active_tasks.add(completed_aw)
            active_tasks.update(self.tasks)
            for task in active_tasks:
                if not task.done():
                    task.cancel()
            _ = await asyncio.gather(*active_tasks, return_exceptions=True)

    def _init_input_rows(self, rows: list[exprs.DataRow]) -> None:
        """Set execution state in DataRow: missing_dependents per slot and the mask of output slots still missing"""
        for row in rows:
            # elementwise numpy comparisons
            row.missing_dependents = np.sum(self.row_builder.dependencies[row.has_val == False], axis=0)  # noqa: E712
            row.missing_slots = self.eval_ctx & (row.has_val == False)  # noqa: E712

    def dispatch_exc(self, rows: list[exprs.DataRow], slot_with_exc: int, exc_tb: TracebackType) -> None:
        """
        Propagate exception to main event loop or to dependent slots, depending on ignore_errors

        Without ignore_errors, an ExprEvalError is built from the first row and signaled via exc_event. With
        ignore_errors, the exception is copied into all transitive dependents of slot_with_exc and the rows are
        dispatched as usual.
        """
        if len(rows) == 0 or self.exc_event.is_set():
            return

        if not self.ctx.ignore_errors:
            dependency_idxs = [e.slot_idx for e in self.row_builder.unique_exprs[slot_with_exc].dependencies()]
            first_row = rows[0]
            input_vals = [first_row[idx] for idx in dependency_idxs]
            e = self.row_builder.unique_exprs[slot_with_exc]
            self.error = excs.ExprEvalError(
                e, f'expression {e}', first_row.get_exc(e.slot_idx), exc_tb, input_vals, 0)
            self.exc_event.set()
            return

        for row in rows:
            assert row.has_exc(slot_with_exc)
            exc = row.get_exc(slot_with_exc)
            # propagate exception
            for slot_idx in np.nonzero(self.row_builder.transitive_dependents[slot_with_exc])[0].tolist():
                row.set_exc(slot_idx, exc)
        self.dispatch(rows)

    def dispatch(self, rows: list[exprs.DataRow]) -> None:
        """
        Dispatch rows to slot evaluators, based on materialized dependencies

        For each row: rows without missing output slots are moved to completed_rows; otherwise, missing slots whose
        dependencies are all materialized and that aren't scheduled yet are marked as scheduled. Intermediate slots
        (gc_targets) that no unmaterialized slot depends on anymore are cleared. Ready rows are then handed to each
        slot's evaluator.
        """
        if len(rows) == 0 or self.exc_event.is_set():
            return

        # slots ready for evaluation; rows x slots
        ready_slots = np.zeros((len(rows), self.row_builder.num_materialized), dtype=bool)
        is_complete = np.zeros(len(rows), dtype=bool)
        for i, row in enumerate(rows):
            row.missing_slots &= row.has_val == False  # noqa: E712
            if row.missing_slots.sum() == 0:
                # all output slots have been materialized
                is_complete[i] = True
            else:
                # dependencies of missing slots
                missing_dependencies = self.num_dependencies * row.missing_slots
                # determine ready slots that are not yet materialized and not yet scheduled
                num_mat_dependencies = np.sum(self.row_builder.dependencies * row.has_val, axis=1)
                num_missing = missing_dependencies - num_mat_dependencies
                ready_slots[i] = (num_missing == 0) & (row.is_scheduled == False) & row.missing_slots  # noqa: E712
                row.is_scheduled = row.is_scheduled | ready_slots[i]

            # clear intermediate values that are no longer needed (ie, all dependents are materialized)
            missing_dependents = np.sum(self.row_builder.dependencies[row.has_val == False], axis=0)  # noqa: E712
            gc_targets = (missing_dependents == 0) & (row.missing_dependents > 0) & self.gc_targets
            row.clear(gc_targets)
            row.missing_dependents = missing_dependents

        if np.any(is_complete):
            completed_idxs = list(is_complete.nonzero()[0])
            for i in completed_idxs:
                self.completed_rows.put_nowait(rows[i])
            self.completed_event.set()
            self.num_in_flight -= len(completed_idxs)

        # schedule all ready slots
        for slot_idx in np.sum(ready_slots, axis=0).nonzero()[0]:
            ready_rows_v = ready_slots[:, slot_idx].flatten()
            ready_rows = [rows[i] for i in ready_rows_v.nonzero()[0]]
            _logger.debug(f'Scheduling {len(ready_rows)} rows for slot {slot_idx}')
            self.slot_evaluators[slot_idx].schedule(ready_rows, slot_idx)

    def register_task(self, t: asyncio.Task) -> None:
        """Keep a reference to t until it finishes (so it isn't gc'd) and watch it for unhandled exceptions"""
        self.tasks.add(t)
        t.add_done_callback(self._done_cb)

    def _done_cb(self, t: asyncio.Task) -> None:
        """Done callback for registered tasks: drop the reference; turn an unhandled exception into self.error"""
        self.tasks.discard(t)
        # end the main loop if we had an unhandled exception
        try:
            t.result()
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            stack_trace = traceback.format_exc()
            self.error = excs.Error(f'Exception in task: {exc}\n{stack_trace}')
            self.exc_event.set()