pixeltable 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.

Potentially problematic release: this version of pixeltable might be problematic.

Files changed (47)
  1. pixeltable/__version__.py +2 -2
  2. pixeltable/catalog/table_version.py +2 -1
  3. pixeltable/dataframe.py +52 -27
  4. pixeltable/env.py +92 -4
  5. pixeltable/exec/__init__.py +1 -1
  6. pixeltable/exec/aggregation_node.py +3 -3
  7. pixeltable/exec/cache_prefetch_node.py +13 -7
  8. pixeltable/exec/component_iteration_node.py +3 -9
  9. pixeltable/exec/data_row_batch.py +17 -5
  10. pixeltable/exec/exec_node.py +32 -12
  11. pixeltable/exec/expr_eval/__init__.py +1 -0
  12. pixeltable/exec/expr_eval/evaluators.py +245 -0
  13. pixeltable/exec/expr_eval/expr_eval_node.py +404 -0
  14. pixeltable/exec/expr_eval/globals.py +114 -0
  15. pixeltable/exec/expr_eval/row_buffer.py +76 -0
  16. pixeltable/exec/expr_eval/schedulers.py +232 -0
  17. pixeltable/exec/in_memory_data_node.py +2 -2
  18. pixeltable/exec/row_update_node.py +14 -14
  19. pixeltable/exec/sql_node.py +2 -2
  20. pixeltable/exprs/column_ref.py +5 -1
  21. pixeltable/exprs/data_row.py +50 -40
  22. pixeltable/exprs/expr.py +57 -12
  23. pixeltable/exprs/function_call.py +54 -19
  24. pixeltable/exprs/inline_expr.py +12 -21
  25. pixeltable/exprs/literal.py +25 -8
  26. pixeltable/exprs/row_builder.py +23 -0
  27. pixeltable/func/aggregate_function.py +4 -0
  28. pixeltable/func/callable_function.py +54 -4
  29. pixeltable/func/expr_template_function.py +5 -1
  30. pixeltable/func/function.py +48 -7
  31. pixeltable/func/query_template_function.py +16 -7
  32. pixeltable/func/udf.py +7 -1
  33. pixeltable/functions/__init__.py +1 -1
  34. pixeltable/functions/anthropic.py +95 -21
  35. pixeltable/functions/gemini.py +2 -6
  36. pixeltable/functions/openai.py +207 -28
  37. pixeltable/globals.py +1 -1
  38. pixeltable/plan.py +24 -9
  39. pixeltable/store.py +6 -0
  40. pixeltable/type_system.py +3 -3
  41. pixeltable/utils/arrow.py +3 -3
  42. {pixeltable-0.3.0.dist-info → pixeltable-0.3.1.dist-info}/METADATA +3 -1
  43. {pixeltable-0.3.0.dist-info → pixeltable-0.3.1.dist-info}/RECORD +46 -41
  44. pixeltable/exec/expr_eval_node.py +0 -232
  45. {pixeltable-0.3.0.dist-info → pixeltable-0.3.1.dist-info}/LICENSE +0 -0
  46. {pixeltable-0.3.0.dist-info → pixeltable-0.3.1.dist-info}/WHEEL +0 -0
  47. {pixeltable-0.3.0.dist-info → pixeltable-0.3.1.dist-info}/entry_points.txt +0 -0
pixeltable/__version__.py CHANGED
@@ -1,3 +1,3 @@
  # These version placeholders will be replaced during build.
- __version__ = "0.3.0"
- __version_tuple__ = (0, 3, 0)
+ __version__ = "0.3.1"
+ __version_tuple__ = (0, 3, 1)
pixeltable/catalog/table_version.py CHANGED
@@ -734,7 +734,8 @@ class TableVersion:
          if conn is None:
              with Env.get().engine.begin() as conn:
                  return self._insert(
-                     plan, conn, time.time(), print_stats=print_stats, rowids=rowids(), abort_on_exc=fail_on_exception)
+                     plan, conn, time.time(), print_stats=print_stats, rowids=rowids(),
+                     abort_on_exc=fail_on_exception)
          else:
              return self._insert(
                  plan, conn, time.time(), print_stats=print_stats, rowids=rowids(), abort_on_exc=fail_on_exception)
pixeltable/dataframe.py CHANGED
@@ -8,18 +8,15 @@ import json
  import logging
  import traceback
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, Optional, Sequence, Union, Literal
+ from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, Optional, Sequence, Union, AsyncIterator, NoReturn

+ import numpy as np
  import pandas as pd
- import pandas.io.formats.style
  import sqlalchemy as sql

- import pixeltable.catalog as catalog
  import pixeltable.exceptions as excs
- import pixeltable.exprs as exprs
  import pixeltable.type_system as ts
- from pixeltable import exec
- from pixeltable import plan
+ from pixeltable import catalog, exec, exprs, plan
  from pixeltable.catalog import is_valid_identifier
  from pixeltable.catalog.globals import UpdateStatus
  from pixeltable.env import Env
@@ -29,6 +26,7 @@ from pixeltable.utils.formatter import Formatter

  if TYPE_CHECKING:
      import torch
+     import torch.utils.data

  __all__ = ['DataFrame']

@@ -268,6 +266,20 @@ class DataFrame:
          else:
              yield from exec_plan(conn)

+     async def _aexec(self, conn: sql.engine.Connection) -> AsyncIterator[exprs.DataRow]:
+         """Run the query and return rows as a generator.
+         This function must not modify the state of the DataFrame, otherwise it breaks dataset caching.
+         """
+         plan = self._create_query_plan()
+         plan.ctx.set_conn(conn)
+         plan.open()
+         try:
+             async for row_batch in plan:
+                 for row in row_batch:
+                     yield row
+         finally:
+             plan.close()
+
      def _create_query_plan(self) -> exec.ExecNode:
          # construct a group-by clause if we're grouping by a table
          group_by_clause: Optional[list[exprs.Expr]] = None
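
Note: `_aexec()` is the async counterpart of `_exec()`: it opens the plan, drives it through the new `__aiter__` protocol, and yields individual rows. A minimal sketch of how a caller might consume it (the helper is hypothetical; it assumes a `DataFrame` `df` and an open SQLAlchemy connection `conn`, and `_aexec` is internal API):

    import asyncio

    async def first_n_rows(df, conn, n: int) -> list:
        # pull at most n DataRows from the async generator
        rows = []
        async for row in df._aexec(conn):
            rows.append(row)
            if len(rows) == n:
                break
        return rows

    # rows = asyncio.run(first_n_rows(df, conn, 10))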
@@ -392,26 +404,29 @@ class DataFrame:
              group_by_clause=group_by_clause, grouping_tbl=self.grouping_tbl,
              order_by_clause=order_by_clause, limit=self.limit_val)

+     def _raise_expr_eval_err(self, e: excs.ExprEvalError) -> NoReturn:
+         msg = f'In row {e.row_num} the {e.expr_msg} encountered exception ' f'{type(e.exc).__name__}:\n{str(e.exc)}'
+         if len(e.input_vals) > 0:
+             input_msgs = [
+                 f"'{d}' = {d.col_type.print_value(e.input_vals[i])}" for i, d in enumerate(e.expr.dependencies())
+             ]
+             msg += f'\nwith {", ".join(input_msgs)}'
+         assert e.exc_tb is not None
+         stack_trace = traceback.format_tb(e.exc_tb)
+         if len(stack_trace) > 2:
+             # append a stack trace if the exception happened in user code
+             # (frame 0 is ExprEvaluator and frame 1 is some expr's eval())
+             nl = '\n'
+             # [-1:1:-1]: leave out entries 0 and 1 and reverse order, so that the most recent frame is at the top
+             msg += f'\nStack:\n{nl.join(stack_trace[-1:1:-1])}'
+         raise excs.Error(msg)
+
      def _output_row_iterator(self, conn: Optional[sql.engine.Connection] = None) -> Iterator[list]:
          try:
              for data_row in self._exec(conn):
                  yield [data_row[e.slot_idx] for e in self._select_list_exprs]
          except excs.ExprEvalError as e:
-             msg = f'In row {e.row_num} the {e.expr_msg} encountered exception ' f'{type(e.exc).__name__}:\n{str(e.exc)}'
-             if len(e.input_vals) > 0:
-                 input_msgs = [
-                     f"'{d}' = {d.col_type.print_value(e.input_vals[i])}" for i, d in enumerate(e.expr.dependencies())
-                 ]
-                 msg += f'\nwith {", ".join(input_msgs)}'
-             assert e.exc_tb is not None
-             stack_trace = traceback.format_tb(e.exc_tb)
-             if len(stack_trace) > 2:
-                 # append a stack trace if the exception happened in user code
-                 # (frame 0 is ExprEvaluator and frame 1 is some expr's eval()
-                 nl = '\n'
-                 # [-1:0:-1]: leave out entry 0 and reverse order, so that the most recent frame is at the top
-                 msg += f'\nStack:\n{nl.join(stack_trace[-1:1:-1])}'
-             raise excs.Error(msg)
+             self._raise_expr_eval_err(e)
          except sql.exc.DBAPIError as e:
              raise excs.Error(f'Error during SQL execution:\n{e}')

@@ -421,6 +436,18 @@ class DataFrame:
      def _collect(self, conn: Optional[sql.engine.Connection] = None) -> DataFrameResultSet:
          return DataFrameResultSet(list(self._output_row_iterator(conn)), self.schema)

+     async def _acollect(self, conn: sql.engine.Connection) -> DataFrameResultSet:
+         try:
+             result = [
+                 [row[e.slot_idx] for e in self._select_list_exprs]
+                 async for row in self._aexec(conn)
+             ]
+             return DataFrameResultSet(result, self.schema)
+         except excs.ExprEvalError as e:
+             self._raise_expr_eval_err(e)
+         except sql.exc.DBAPIError as e:
+             raise excs.Error(f'Error during SQL execution:\n{e}')
+
      def count(self) -> int:
          """Return the number of rows in the DataFrame.

@@ -540,10 +567,10 @@ class DataFrame:
          for raw_expr, name in base_list:
              if isinstance(raw_expr, exprs.Expr):
                  select_list.append((raw_expr, name))
-             elif isinstance(raw_expr, dict):
-                 select_list.append((exprs.InlineDict(raw_expr), name))
-             elif isinstance(raw_expr, list):
-                 select_list.append((exprs.InlineList(raw_expr), name))
+             elif isinstance(raw_expr, (dict, list, tuple)):
+                 select_list.append((exprs.Expr.from_object(raw_expr), name))
+             elif isinstance(raw_expr, np.ndarray):
+                 select_list.append((exprs.Expr.from_array(raw_expr), name))
              else:
                  select_list.append((exprs.Literal(raw_expr), name))
              expr = select_list[-1][0]
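
Note: `select()` now routes dict, list, and tuple literals through `Expr.from_object()` and numpy arrays through `Expr.from_array()`, where previously only dicts and lists (via `InlineDict`/`InlineList`) were handled. An illustrative sketch (the table and its columns are hypothetical):

    import numpy as np
    import pixeltable as pxt

    t = pxt.get_table('films')  # hypothetical table with columns 'title' and 'year'
    df = t.select(
        meta={'title': t.title, 'year': t.year},  # dict -> Expr.from_object()
        pair=(t.title, t.year),                   # tuple -> Expr.from_object()
        arr=np.zeros(3),                          # ndarray -> Expr.from_array()
    )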
@@ -1031,8 +1058,6 @@ class DataFrame:
          else:
              return write_coco_dataset(self, dest_path)

-     # TODO Factor this out into a separate module.
-     # The return type is unresolvable, but torch can't be imported since it's an optional dependency.
      def to_pytorch_dataset(self, image_format: str = 'pt') -> 'torch.utils.data.IterableDataset':
          """
          Convert the dataframe to a pytorch IterableDataset suitable for parallel loading
pixeltable/env.py CHANGED
@@ -1,5 +1,6 @@
  from __future__ import annotations

+ from abc import abstractmethod
  import datetime
  import glob
  import http.server
@@ -15,9 +16,9 @@ import sys
  import threading
  import uuid
  import warnings
- from dataclasses import dataclass
+ from dataclasses import dataclass, field
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar
+ from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Type
  from zoneinfo import ZoneInfo, ZoneInfoNotFoundError

  import pixeltable_pgserver
@@ -33,6 +34,11 @@ if TYPE_CHECKING:
      import spacy


+ _logger = logging.getLogger('pixeltable')
+
+ T = TypeVar('T')
+
+
  class Env:
      """
      Store for runtime globals.
@@ -70,6 +76,8 @@ class Env:
      _stdout_handler: logging.StreamHandler
      _initialized: bool

+     _resource_pool_info: dict[str, Any]
+
      @classmethod
      def get(cls) -> Env:
          if cls._instance is None:
@@ -121,6 +129,8 @@ class Env:
          self._stdout_handler.setFormatter(logging.Formatter(self._log_fmt_str))
          self._initialized = False

+         self._resource_pool_info = {}
+
      @property
      def config(self) -> Config:
          assert self._config is not None
@@ -609,6 +619,16 @@ class Env:
      def create_tmp_path(self, extension: str = '') -> Path:
          return self._tmp_dir / f'{uuid.uuid4()}{extension}'

+
+     # def get_resource_pool_info(self, pool_id: str, pool_info_cls: Optional[Type[T]]) -> T:
+     def get_resource_pool_info(self, pool_id: str, make_pool_info: Optional[Callable[[], T]] = None) -> T:
+         """Returns the info object for the given id, creating it if necessary."""
+         info = self._resource_pool_info.get(pool_id)
+         if info is None and make_pool_info is not None:
+             info = make_pool_info()
+             self._resource_pool_info[pool_id] = info
+         return info
+
      @property
      def home(self) -> Path:
          assert self._home is not None
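
Note: `get_resource_pool_info()` is a get-or-create lookup keyed by pool id: the factory runs only on the first call, and subsequent calls return the cached object. A short sketch under assumed names (`MyPoolInfo` is hypothetical); passing the class works because `make_pool_info` only needs to be a zero-argument callable:

    from dataclasses import dataclass

    @dataclass
    class MyPoolInfo:  # hypothetical per-pool state object
        requests_in_flight: int = 0

    env = Env.get()
    info = env.get_resource_pool_info('api:example-model', MyPoolInfo)  # factory invoked here
    info.requests_in_flight += 1
    assert env.get_resource_pool_info('api:example-model') is info      # cached thereafter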
@@ -686,8 +706,6 @@ class Config:
      """
      __config: dict[str, Any]

-     T = TypeVar('T')
-
      @classmethod
      def from_file(cls, path: Path) -> Config:
          """
@@ -767,3 +785,73 @@ class PackageInfo:
      is_installed: bool
      library_name: str  # pypi library name (may be different from package name)
      version: Optional[list[int]] = None  # installed version, as a list of components (such as [3,0,2] for "3.0.2")
+
+
+ TIME_FORMAT = '%H:%M.%S %f'
+
+
+ @dataclass
+ class RateLimitsInfo:
+     """
+     Abstract base class for resource pools made up of rate limits for different resources.
+
+     Rate limits and currently remaining resources are periodically reported via record().
+
+     Subclasses provide operational customization via:
+     - get_retry_delay()
+     - get_request_resources(self, ...) -> dict[str, int],
+       with parameters that are a subset of those of the udf that creates the subclass's instance
+     """
+
+     # get_request_resources:
+     # - returns estimated resources needed for a specific request (ie, a single udf call) as a dict (key: resource name)
+     # - parameters are a subset of those of the udf
+     # - this is not a class method because the signature depends on the instantiating udf
+     get_request_resources: Callable[..., dict[str, int]]
+
+     resource_limits: dict[str, RateLimitInfo] = field(default_factory=dict)
+
+     def is_initialized(self) -> bool:
+         return len(self.resource_limits) > 0
+
+     def reset(self) -> None:
+         self.resource_limits.clear()
+
+     def record(self, **kwargs) -> None:
+         now = datetime.datetime.now(tz=datetime.timezone.utc)
+         if len(self.resource_limits) == 0:
+             self.resource_limits = {k: RateLimitInfo(k, now, *v) for k, v in kwargs.items() if v is not None}
+             # TODO: remove
+             for info in self.resource_limits.values():
+                 _logger.debug(f'Init {info.resource} rate limit: rem={info.remaining} reset={info.reset_at.strftime(TIME_FORMAT)} delta={(info.reset_at - now).total_seconds()}')
+         else:
+             for k, v in kwargs.items():
+                 if v is not None:
+                     self.resource_limits[k].update(now, *v)
+
+     @abstractmethod
+     def get_retry_delay(self, exc: Exception) -> Optional[float]:
+         """Returns number of seconds to wait before retry, or None if not retryable"""
+         pass
+
+
+ @dataclass
+ class RateLimitInfo:
+     """Container for rate limit-related information for a single resource."""
+     resource: str
+     recorded_at: datetime.datetime
+     limit: int
+     remaining: int
+     reset_at: datetime.datetime
+
+     def update(self, recorded_at: datetime.datetime, limit: int, remaining: int, reset_at: datetime.datetime) -> None:
+         # we always update everything, even though responses may come back out-of-order: we can't use reset_at to
+         # determine order, because it doesn't increase monotonically (the reset duration shortens as output_tokens
+         # are freed up - going from max to actual)
+         self.recorded_at = recorded_at
+         self.limit = limit
+         self.remaining = remaining
+         reset_delta = reset_at - self.reset_at
+         self.reset_at = reset_at
+         # TODO: remove
+         _logger.debug(f'Update {self.resource} rate limit: rem={self.remaining} reset={self.reset_at.strftime(TIME_FORMAT)} reset_delta={reset_delta.total_seconds()} recorded_delta={(self.reset_at - recorded_at).total_seconds()}')
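
Note: `RateLimitsInfo` is the new base for rate-limit resource pools: a subclass supplies `get_retry_delay()` and injects a per-request resource estimator whose signature mirrors the instantiating udf. A minimal sketch of what a subclass might look like (the subclass, the estimator, and the numbers are hypothetical, not part of this release):

    import datetime
    from typing import Optional

    class MyServiceRateLimits(RateLimitsInfo):  # hypothetical subclass
        def get_retry_delay(self, exc: Exception) -> Optional[float]:
            # retry throttling errors after a fixed delay; everything else is not retryable
            return 1.0 if getattr(exc, 'status_code', None) == 429 else None

    def estimate_resources(prompt: str, max_tokens: int = 256) -> dict[str, int]:
        # hypothetical estimate: one request plus a rough token count
        return {'requests': 1, 'tokens': len(prompt) // 4 + max_tokens}

    pool = MyServiceRateLimits(get_request_resources=estimate_resources)
    now = datetime.datetime.now(tz=datetime.timezone.utc)
    # record() takes one (limit, remaining, reset_at) tuple per resource
    pool.record(requests=(1000, 999, now + datetime.timedelta(seconds=60)))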
pixeltable/exec/__init__.py CHANGED
@@ -4,7 +4,7 @@ from .component_iteration_node import ComponentIterationNode
  from .data_row_batch import DataRowBatch
  from .exec_context import ExecContext
  from .exec_node import ExecNode
- from .expr_eval_node import ExprEvalNode
  from .in_memory_data_node import InMemoryDataNode
  from .row_update_node import RowUpdateNode
  from .sql_node import SqlLookupNode, SqlScanNode, SqlAggregationNode, SqlNode, SqlJoinNode
+ from .expr_eval import ExprEvalNode
pixeltable/exec/aggregation_node.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations

  import logging
  import sys
- from typing import Any, Iterable, Iterator, Optional, cast
+ from typing import Any, Iterable, Iterator, Optional, cast, AsyncIterator

  import pixeltable.catalog as catalog
  import pixeltable.exceptions as excs
@@ -60,11 +60,11 @@ class AggregationNode(ExecNode):
          input_vals = [row[d.slot_idx] for d in fn_call.dependencies()]
          raise excs.ExprEvalError(fn_call, expr_msg, e, exc_tb, input_vals, row_num)

-     def __iter__(self) -> Iterator[DataRowBatch]:
+     async def __aiter__(self) -> AsyncIterator[DataRowBatch]:
          prev_row: Optional[exprs.DataRow] = None
          current_group: Optional[list[Any]] = None  # the values of the group-by exprs
          num_input_rows = 0
-         for row_batch in self.input:
+         async for row_batch in self.input:
              num_input_rows += len(row_batch)
              for row in row_batch:
                  group = [row[e.slot_idx] for e in self.group_by] if self.group_by is not None else None
pixeltable/exec/cache_prefetch_node.py CHANGED
@@ -9,7 +9,7 @@ import urllib.request
  from collections import deque
  from concurrent import futures
  from pathlib import Path
- from typing import Optional, Any, Iterator
+ from typing import Optional, Any, Iterator, AsyncIterator
  from uuid import UUID

  import pixeltable.env as env
@@ -79,12 +79,12 @@ class CachePrefetchNode(ExecNode):
          self.input_finished = False
          self.row_idx = itertools.count() if retain_input_order else itertools.repeat(None)

-     def __iter__(self) -> Iterator[DataRowBatch]:
-         input_iter = iter(self.input)
+     async def __aiter__(self) -> AsyncIterator[DataRowBatch]:
+         input_iter = self.input.__aiter__()
          with futures.ThreadPoolExecutor(max_workers=self.NUM_EXECUTOR_THREADS) as executor:
              # we create enough in-flight requests to fill the first batch
              while not self.input_finished and self.__num_pending_rows() < self.BATCH_SIZE:
-                 self.__submit_input_batch(input_iter, executor)
+                 await self.__submit_input_batch(input_iter, executor)

              while True:
                  # try to assemble a full batch of output rows
@@ -93,7 +93,7 @@

                  # try to create enough in-flight requests to fill the next batch
                  while not self.input_finished and self.__num_pending_rows() < self.BATCH_SIZE:
-                     self.__submit_input_batch(input_iter, executor)
+                     await self.__submit_input_batch(input_iter, executor)

                  if len(self.ready_rows) > 0:
                      # create DataRowBatch from the first BATCH_SIZE ready rows
@@ -163,9 +163,15 @@
              self.__add_ready_row(row, state.idx)
              _logger.debug(f'row {state.idx} is ready (ready_batch_size={self.__ready_prefix_len()})')

-     def __submit_input_batch(self, input: Iterator[DataRowBatch], executor: futures.ThreadPoolExecutor) -> None:
+     async def __submit_input_batch(
+         self, input: AsyncIterator[DataRowBatch], executor: futures.ThreadPoolExecutor
+     ) -> None:
          assert not self.input_finished
-         input_batch = next(input, None)
+         input_batch: Optional[DataRowBatch]
+         try:
+             input_batch = await input.__anext__()
+         except StopAsyncIteration:
+             input_batch = None
          if input_batch is None:
              self.input_finished = True
              return
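
Note: the rewritten `__submit_input_batch()` spells out `__anext__()` plus `except StopAsyncIteration` because `next(it, None)` has no direct async equivalent here (the `anext()` builtin with a default argument only exists on Python 3.10+). The same pattern as a standalone helper, for illustration:

    from typing import AsyncIterator, Optional, TypeVar

    T = TypeVar('T')

    async def anext_or_none(it: AsyncIterator[T]) -> Optional[T]:
        # async counterpart of next(it, None)
        try:
            return await it.__anext__()
        except StopAsyncIteration:
            return None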
pixeltable/exec/component_iteration_node.py CHANGED
@@ -1,5 +1,5 @@
  import inspect
- from typing import Iterator, Optional
+ from typing import Iterator, Optional, AsyncIterator

  import pixeltable.catalog as catalog
  import pixeltable.exceptions as excs
@@ -37,11 +37,10 @@ class ComponentIterationNode(ExecNode):
              e.col.name: e.slot_idx for e in self.row_builder.unique_exprs
              if isinstance(e, exprs.ColumnRef) and e.col.name in self.iterator_output_fields
          }
-         self.__output: Optional[Iterator[DataRowBatch]] = None

-     def __output_batches(self) -> Iterator[DataRowBatch]:
+     async def __aiter__(self) -> AsyncIterator[DataRowBatch]:
          output_batch = DataRowBatch(self.view, self.row_builder)
-         for input_batch in self.input:
+         async for input_batch in self.input:
              for input_row in input_batch:
                  self.row_builder.eval(input_row, self.iterator_args_ctx)
                  iterator_args = input_row[self.iterator_args.slot_idx]
@@ -93,8 +92,3 @@
              raise excs.Error(
                  f'Invalid output of {self.view.iterator_cls.__name__}: '
                  f'missing fields {", ".join(missing_fields)}')
-
-     def __next__(self) -> DataRowBatch:
-         if self.__output is None:
-             self.__output = self.__output_batches()
-         return next(self.__output)
pixeltable/exec/data_row_batch.py CHANGED
@@ -21,7 +21,14 @@ class DataRowBatch:
      array_slot_idxs: list[int]
      rows: list[exprs.DataRow]

-     def __init__(self, tbl: Optional[catalog.TableVersion], row_builder: exprs.RowBuilder, len: int = 0):
+     def __init__(
+         self, tbl: Optional[catalog.TableVersion], row_builder: exprs.RowBuilder, num_rows: Optional[int] = None,
+         rows: Optional[list[exprs.DataRow]] = None
+     ):
+         """
+         Requires either num_rows or rows to be specified, but not both.
+         """
+         assert num_rows is None or rows is None
          self.tbl = tbl
          self.row_builder = row_builder
          self.img_slot_idxs = [e.slot_idx for e in row_builder.unique_exprs if e.col_type.is_image_type()]
@@ -31,10 +38,15 @@
              if e.col_type.is_media_type() and not e.col_type.is_image_type()
          ]
          self.array_slot_idxs = [e.slot_idx for e in row_builder.unique_exprs if e.col_type.is_array_type()]
-         self.rows = [
-             exprs.DataRow(row_builder.num_materialized, self.img_slot_idxs, self.media_slot_idxs, self.array_slot_idxs)
-             for _ in range(len)
-         ]
+         if rows is not None:
+             self.rows = rows
+         else:
+             if num_rows is None:
+                 num_rows = 0
+             self.rows = [
+                 exprs.DataRow(row_builder.num_materialized, self.img_slot_idxs, self.media_slot_idxs, self.array_slot_idxs)
+                 for _ in range(num_rows)
+             ]

      def add_row(self, row: Optional[exprs.DataRow] = None) -> exprs.DataRow:
          if row is None:
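
Note: the `DataRowBatch` constructor now has two mutually exclusive modes: allocate `num_rows` fresh `DataRow`s (the old `len` parameter, renamed so it no longer shadows the builtin), or adopt an existing `rows` list, e.g. to regroup rows into new batches. A short sketch, assuming `tbl` and `row_builder` already exist:

    batch = DataRowBatch(tbl, row_builder, num_rows=8)               # 8 freshly allocated rows
    sub_batch = DataRowBatch(tbl, row_builder, rows=batch.rows[:4])  # adopt existing rows
    # passing both num_rows and rows would trip the constructor's assertion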
pixeltable/exec/exec_node.py CHANGED
@@ -1,13 +1,16 @@
  from __future__ import annotations

  import abc
- from typing import TYPE_CHECKING, Iterable, Iterator, Optional, TypeVar
+ import asyncio
+ import logging
+ import sys
+ from typing import Iterable, Iterator, Optional, TypeVar, AsyncIterator

  import pixeltable.exprs as exprs
-
  from .data_row_batch import DataRowBatch
  from .exec_context import ExecContext

+ _logger = logging.getLogger('pixeltable')

  class ExecNode(abc.ABC):
      """Base class of all execution nodes"""
@@ -17,7 +20,6 @@ class ExecNode(abc.ABC):
      flushed_img_slots: list[int]  # idxs of image slots of our output_exprs dependencies
      stored_img_cols: list[exprs.ColumnSlotIdx]
      ctx: Optional[ExecContext]
-     __iter: Optional[Iterator[DataRowBatch]]

      def __init__(
          self, row_builder: exprs.RowBuilder, output_exprs: Iterable[exprs.Expr],
@@ -34,7 +36,6 @@
          ]
          self.stored_img_cols = []
          self.ctx = None  # all nodes of a tree share the same context
-         self.__iter = None

      def set_ctx(self, ctx: ExecContext) -> None:
          self.ctx = ctx
@@ -47,15 +48,34 @@
          if self.input is not None:
              self.input.set_stored_img_cols(stored_img_cols)

-     # TODO: make this an abstractmethod when __next__() is removed
-     def __iter__(self) -> Iterator[DataRowBatch]:
-         return self
+     @abc.abstractmethod
+     def __aiter__(self) -> AsyncIterator[DataRowBatch]:
+         pass

-     # TODO: remove this and switch every subclass over to implementing __iter__
-     def __next__(self) -> DataRowBatch:
-         if self.__iter is None:
-             self.__iter = iter(self)
-         return next(self.__iter)
+     def __iter__(self) -> Iterator[DataRowBatch]:
+         try:
+             # check if we are already in an event loop (eg, Jupyter's); if so, patch it to allow nested event loops
+             _ = asyncio.get_event_loop()
+             import nest_asyncio  # type: ignore
+             nest_asyncio.apply()
+         except RuntimeError:
+             pass
+
+         loop = asyncio.new_event_loop()
+         asyncio.set_event_loop(loop)
+
+         if 'pytest' in sys.modules:
+             loop.set_debug(True)
+
+         aiter = self.__aiter__()
+         try:
+             while True:
+                 batch: DataRowBatch = loop.run_until_complete(aiter.__anext__())
+                 yield batch
+         except StopAsyncIteration:
+             pass
+         finally:
+             loop.close()

      def open(self) -> None:
          """Bottom-up initialization of nodes for execution. Must be called before __next__."""
pixeltable/exec/expr_eval/__init__.py ADDED
@@ -0,0 +1 @@
+ from .expr_eval_node import ExprEvalNode