pixeltable 0.2.20__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff compares the contents of two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (84)
  1. pixeltable/__init__.py +7 -19
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +7 -7
  4. pixeltable/catalog/globals.py +3 -0
  5. pixeltable/catalog/table.py +208 -145
  6. pixeltable/catalog/table_version.py +36 -18
  7. pixeltable/catalog/table_version_path.py +0 -8
  8. pixeltable/catalog/view.py +3 -3
  9. pixeltable/dataframe.py +9 -24
  10. pixeltable/env.py +1 -1
  11. pixeltable/exec/__init__.py +1 -1
  12. pixeltable/exec/aggregation_node.py +22 -15
  13. pixeltable/exec/data_row_batch.py +7 -7
  14. pixeltable/exec/exec_node.py +35 -7
  15. pixeltable/exec/expr_eval_node.py +2 -1
  16. pixeltable/exec/in_memory_data_node.py +9 -9
  17. pixeltable/exec/sql_node.py +265 -136
  18. pixeltable/exprs/__init__.py +1 -0
  19. pixeltable/exprs/data_row.py +30 -19
  20. pixeltable/exprs/expr.py +15 -14
  21. pixeltable/exprs/expr_dict.py +55 -0
  22. pixeltable/exprs/expr_set.py +21 -15
  23. pixeltable/exprs/function_call.py +21 -8
  24. pixeltable/exprs/rowid_ref.py +2 -2
  25. pixeltable/exprs/sql_element_cache.py +5 -1
  26. pixeltable/ext/functions/whisperx.py +7 -2
  27. pixeltable/func/callable_function.py +2 -2
  28. pixeltable/func/function_registry.py +6 -7
  29. pixeltable/func/query_template_function.py +11 -12
  30. pixeltable/func/signature.py +17 -15
  31. pixeltable/func/udf.py +0 -4
  32. pixeltable/functions/__init__.py +1 -1
  33. pixeltable/functions/audio.py +4 -6
  34. pixeltable/functions/globals.py +86 -42
  35. pixeltable/functions/huggingface.py +12 -14
  36. pixeltable/functions/image.py +59 -45
  37. pixeltable/functions/json.py +0 -1
  38. pixeltable/functions/mistralai.py +2 -2
  39. pixeltable/functions/openai.py +22 -25
  40. pixeltable/functions/string.py +50 -50
  41. pixeltable/functions/timestamp.py +20 -20
  42. pixeltable/functions/together.py +2 -2
  43. pixeltable/functions/video.py +11 -20
  44. pixeltable/functions/whisper.py +2 -20
  45. pixeltable/globals.py +55 -56
  46. pixeltable/index/base.py +2 -2
  47. pixeltable/index/btree.py +7 -7
  48. pixeltable/index/embedding_index.py +8 -10
  49. pixeltable/io/external_store.py +11 -5
  50. pixeltable/io/globals.py +2 -0
  51. pixeltable/io/hf_datasets.py +1 -1
  52. pixeltable/io/label_studio.py +6 -6
  53. pixeltable/io/parquet.py +14 -13
  54. pixeltable/iterators/document.py +9 -7
  55. pixeltable/iterators/video.py +10 -1
  56. pixeltable/metadata/__init__.py +3 -2
  57. pixeltable/metadata/converters/convert_14.py +4 -2
  58. pixeltable/metadata/converters/convert_15.py +1 -1
  59. pixeltable/metadata/converters/convert_19.py +1 -0
  60. pixeltable/metadata/converters/convert_20.py +1 -1
  61. pixeltable/metadata/converters/util.py +9 -8
  62. pixeltable/metadata/schema.py +32 -21
  63. pixeltable/plan.py +136 -154
  64. pixeltable/store.py +51 -36
  65. pixeltable/tool/create_test_db_dump.py +6 -6
  66. pixeltable/tool/doc_plugins/griffe.py +3 -34
  67. pixeltable/tool/mypy_plugin.py +32 -0
  68. pixeltable/type_system.py +243 -60
  69. pixeltable/utils/arrow.py +10 -9
  70. pixeltable/utils/coco.py +4 -4
  71. pixeltable/utils/documents.py +1 -1
  72. pixeltable/utils/filecache.py +9 -9
  73. pixeltable/utils/formatter.py +1 -1
  74. pixeltable/utils/http_server.py +2 -5
  75. pixeltable/utils/media_store.py +6 -6
  76. pixeltable/utils/pytorch.py +10 -11
  77. pixeltable/utils/sql.py +2 -1
  78. {pixeltable-0.2.20.dist-info → pixeltable-0.2.21.dist-info}/METADATA +6 -5
  79. pixeltable-0.2.21.dist-info/RECORD +148 -0
  80. pixeltable/utils/help.py +0 -11
  81. pixeltable-0.2.20.dist-info/RECORD +0 -147
  82. {pixeltable-0.2.20.dist-info → pixeltable-0.2.21.dist-info}/LICENSE +0 -0
  83. {pixeltable-0.2.20.dist-info → pixeltable-0.2.21.dist-info}/WHEEL +0 -0
  84. {pixeltable-0.2.20.dist-info → pixeltable-0.2.21.dist-info}/entry_points.txt +0 -0
@@ -6,7 +6,7 @@ import inspect
6
6
  import logging
7
7
  import time
8
8
  import uuid
9
- from typing import TYPE_CHECKING, Any, Iterable, Optional
9
+ from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, Optional
10
10
  from uuid import UUID
11
11
 
12
12
  import sqlalchemy as sql
@@ -453,7 +453,9 @@ class TableVersion:
453
453
  self.idxs_by_name[idx_name] = idx_info
454
454
 
455
455
  # add the columns and update the metadata
456
- status = self._add_columns([val_col, undo_col], conn)
456
+ # TODO support on_error='abort' for indices; it's tricky because of the way metadata changes are entangled
457
+ # with the database operations
458
+ status = self._add_columns([val_col, undo_col], conn, print_stats=False, on_error='ignore')
457
459
  # now create the index structure
458
460
  idx.create_index(self._store_idx_name(idx_id), val_col, conn)
459
461
 
@@ -478,7 +480,7 @@ class TableVersion:
478
480
  self._update_md(time.time(), conn, preceding_schema_version=preceding_schema_version)
479
481
  _logger.info(f'Dropped index {idx_md.name} on table {self.name}')
480
482
 
481
- def add_column(self, col: Column, print_stats: bool = False) -> UpdateStatus:
483
+ def add_column(self, col: Column, print_stats: bool, on_error: Literal['abort', 'ignore']) -> UpdateStatus:
482
484
  """Adds a column to the table.
483
485
  """
484
486
  assert not self.is_snapshot
@@ -498,9 +500,8 @@ class TableVersion:
498
500
  preceding_schema_version = self.schema_version
499
501
  self.schema_version = self.version
500
502
  with Env.get().engine.begin() as conn:
501
- status = self._add_columns([col], conn, print_stats=print_stats)
503
+ status = self._add_columns([col], conn, print_stats=print_stats, on_error=on_error)
502
504
  _ = self._add_default_index(col, conn)
503
- # TODO: what to do about errors?
504
505
  self._update_md(time.time(), conn, preceding_schema_version=preceding_schema_version)
505
506
  _logger.info(f'Added column {col.name} to table {self.name}, new version: {self.version}')
506
507
 
@@ -512,7 +513,13 @@ class TableVersion:
512
513
  _logger.info(f'Column {col.name}: {msg}')
513
514
  return status
514
515
 
515
- def _add_columns(self, cols: Iterable[Column], conn: sql.engine.Connection, print_stats: bool = False) -> UpdateStatus:
516
+ def _add_columns(
517
+ self,
518
+ cols: Iterable[Column],
519
+ conn: sql.engine.Connection,
520
+ print_stats: bool,
521
+ on_error: Literal['abort', 'ignore']
522
+ ) -> UpdateStatus:
516
523
  """Add and populate columns within the current transaction"""
517
524
  cols = list(cols)
518
525
  row_count = self.store_tbl.count(conn=conn)
@@ -550,10 +557,14 @@ class TableVersion:
550
557
  try:
551
558
  plan.ctx.set_conn(conn)
552
559
  plan.open()
553
- num_excs = self.store_tbl.load_column(col, plan, value_expr_slot_idx, conn)
560
+ try:
561
+ num_excs = self.store_tbl.load_column(col, plan, value_expr_slot_idx, conn, on_error)
562
+ except sql.exc.DBAPIError as exc:
563
+ # Wrap the DBAPIError in an excs.Error to unify processing in the subsequent except block
564
+ raise excs.Error(f'SQL error during execution of computed column `{col.name}`:\n{exc}') from exc
554
565
  if num_excs > 0:
555
566
  cols_with_excs.append(col)
556
- except sql.exc.DBAPIError as e:
567
+ except excs.Error as exc:
557
568
  self.cols.pop()
558
569
  for col in cols:
559
570
  # remove columns that we already added
@@ -564,7 +575,7 @@ class TableVersion:
564
575
  del self.cols_by_id[col.id]
565
576
  # we need to re-initialize the sqlalchemy schema
566
577
  self.store_tbl.create_sa_tbl()
567
- raise excs.Error(f'Error during SQL execution:\n{e}')
578
+ raise exc
568
579
  finally:
569
580
  plan.close()
570
581
 
@@ -689,21 +700,30 @@ class TableVersion:
689
700
  plan = Planner.create_insert_plan(self, rows, ignore_errors=not fail_on_exception)
690
701
  else:
691
702
  plan = Planner.create_df_insert_plan(self, df, ignore_errors=not fail_on_exception)
703
+
704
+ # this is a base table; we generate rowids during the insert
705
+ def rowids() -> Iterator[int]:
706
+ while True:
707
+ rowid = self.next_rowid
708
+ self.next_rowid += 1
709
+ yield rowid
710
+
692
711
  if conn is None:
693
712
  with Env.get().engine.begin() as conn:
694
- return self._insert(plan, conn, time.time(), print_stats)
713
+ return self._insert(plan, conn, time.time(), print_stats=print_stats, rowids=rowids())
695
714
  else:
696
- return self._insert(plan, conn, time.time(), print_stats)
715
+ return self._insert(plan, conn, time.time(), print_stats=print_stats, rowids=rowids())
697
716
 
698
717
  def _insert(
699
- self, exec_plan: 'exec.ExecNode', conn: sql.engine.Connection, timestamp: float, print_stats: bool = False,
718
+ self, exec_plan: 'exec.ExecNode', conn: sql.engine.Connection, timestamp: float, *,
719
+ rowids: Optional[Iterator[int]] = None, print_stats: bool = False,
700
720
  ) -> UpdateStatus:
701
721
  """Insert rows produced by exec_plan and propagate to views"""
702
722
  # we're creating a new version
703
723
  self.version += 1
704
724
  result = UpdateStatus()
705
- num_rows, num_excs, cols_with_excs = self.store_tbl.insert_rows(exec_plan, conn, v_min=self.version)
706
- self.next_rowid = num_rows
725
+ num_rows, num_excs, cols_with_excs = self.store_tbl.insert_rows(
726
+ exec_plan, conn, v_min=self.version, rowids=rowids)
707
727
  result.num_rows = num_rows
708
728
  result.num_excs = num_excs
709
729
  result.num_computed_values += exec_plan.ctx.num_computed_exprs * num_rows
@@ -714,7 +734,7 @@ class TableVersion:
714
734
  for view in self.mutable_views:
715
735
  from pixeltable.plan import Planner
716
736
  plan, _ = Planner.create_view_load_plan(view.path, propagates_insert=True)
717
- status = view._insert(plan, conn, timestamp, print_stats)
737
+ status = view._insert(plan, conn, timestamp, print_stats=print_stats)
718
738
  result.num_rows += status.num_rows
719
739
  result.num_excs += status.num_excs
720
740
  result.num_computed_values += status.num_computed_values
@@ -751,9 +771,7 @@ class TableVersion:
751
771
  raise excs.Error(f'Filter {analysis_info.filter} not expressible in SQL')
752
772
 
753
773
  with Env.get().engine.begin() as conn:
754
- plan, updated_cols, recomputed_cols = (
755
- Planner.create_update_plan(self.path, update_spec, [], where, cascade)
756
- )
774
+ plan, updated_cols, recomputed_cols = Planner.create_update_plan(self.path, update_spec, [], where, cascade)
757
775
  from pixeltable.exprs import SqlElementCache
758
776
  result = self.propagate_update(
759
777
  plan, where.sql_expr(SqlElementCache()) if where is not None else None, recomputed_cols,
@@ -91,14 +91,6 @@ class TableVersionPath:
91
91
  col = self.tbl_version.cols_by_name[col_name]
92
92
  return ColumnRef(col)
93
93
 
94
- def __getitem__(self, index: object) -> Union[exprs.ColumnRef, pxt.DataFrame]:
95
- """Return a ColumnRef for the given column name, or a DataFrame for the given slice.
96
- """
97
- if isinstance(index, str):
98
- # basically <tbl>.<colname>
99
- return self.__getattr__(index)
100
- return pxt.DataFrame(self).__getitem__(index)
101
-
102
94
  def columns(self) -> list[Column]:
103
95
  """Return all user columns visible in this tbl version path, including columns from bases"""
104
96
  result = list(self.tbl_version.cols_by_name.values())
@@ -52,11 +52,11 @@ class View(Table):
52
52
 
53
53
  @classmethod
54
54
  def _create(
55
- cls, dir_id: UUID, name: str, base: TableVersionPath, schema: Dict[str, Any],
56
- predicate: 'pxt.exprs.Expr', is_snapshot: bool, num_retained_versions: int, comment: str,
55
+ cls, dir_id: UUID, name: str, base: TableVersionPath, additional_columns: Dict[str, Any],
56
+ predicate: Optional['pxt.exprs.Expr'], is_snapshot: bool, num_retained_versions: int, comment: str,
57
57
  iterator_cls: Optional[Type[ComponentIterator]], iterator_args: Optional[Dict]
58
58
  ) -> View:
59
- columns = cls._create_columns(schema)
59
+ columns = cls._create_columns(additional_columns)
60
60
  cls._verify_schema(columns)
61
61
 
62
62
  # verify that filter can be evaluated in the context of the base
pixeltable/dataframe.py CHANGED
@@ -8,7 +8,7 @@ import logging
8
8
  import mimetypes
9
9
  import traceback
10
10
  from pathlib import Path
11
- from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, Iterator, List, Optional, Set, Tuple
11
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, Iterator, List, Optional, Sequence, Set, Tuple, Union
12
12
 
13
13
  import pandas as pd
14
14
  import pandas.io.formats.style
@@ -97,8 +97,8 @@ class DataFrameResultSet:
97
97
  return self._rows[index[0]][col_idx]
98
98
  raise excs.Error(f'Bad index: {index}')
99
99
 
100
- def __iter__(self) -> DataFrameResultSetIterator:
101
- return DataFrameResultSetIterator(self)
100
+ def __iter__(self) -> Iterator[dict[str, Any]]:
101
+ return (self._row_to_dict(i) for i in range(len(self)))
102
102
 
103
103
  def __eq__(self, other):
104
104
  if not isinstance(other, DataFrameResultSet):
@@ -106,19 +106,6 @@ class DataFrameResultSet:
106
106
  return self.to_pandas().equals(other.to_pandas())
107
107
 
108
108
 
109
- class DataFrameResultSetIterator:
110
- def __init__(self, result_set: DataFrameResultSet):
111
- self._result_set = result_set
112
- self._idx = 0
113
-
114
- def __next__(self) -> Dict[str, Any]:
115
- if self._idx >= len(self._result_set):
116
- raise StopIteration
117
- row = self._result_set._row_to_dict(self._idx)
118
- self._idx += 1
119
- return row
120
-
121
-
122
109
  # # TODO: remove this; it's only here as a reminder that we still need to call release() in the current implementation
123
110
  # class AnalysisInfo:
124
111
  # def __init__(self, tbl: catalog.TableVersion):
@@ -296,7 +283,7 @@ class DataFrame:
296
283
 
297
284
  def _create_query_plan(self) -> exec.ExecNode:
298
285
  # construct a group-by clause if we're grouping by a table
299
- group_by_clause: List[exprs.Expr] = []
286
+ group_by_clause: Optional[list[exprs.Expr]] = None
300
287
  if self.grouping_tbl is not None:
301
288
  assert self.group_by_clause is None
302
289
  num_rowid_cols = len(self.grouping_tbl.store_tbl.rowid_columns())
@@ -315,8 +302,8 @@ class DataFrame:
315
302
  where_clause=self.where_clause,
316
303
  group_by_clause=group_by_clause,
317
304
  order_by_clause=self.order_by_clause if self.order_by_clause is not None else [],
318
- limit=self.limit_val if self.limit_val is not None else 0,
319
- ) # limit_val == 0: no limit_val
305
+ limit=self.limit_val
306
+ )
320
307
 
321
308
 
322
309
  def show(self, n: int = 20) -> DataFrameResultSet:
@@ -629,17 +616,15 @@ class DataFrame:
629
616
  if self.limit_val is not None:
630
617
  raise excs.Error(f'Cannot use `{op_name}` after `limit`')
631
618
 
632
- def __getitem__(self, index: object) -> DataFrame:
619
+ def __getitem__(self, index: Union[exprs.Expr, Sequence[exprs.Expr]]) -> DataFrame:
633
620
  """
634
621
  Allowed:
635
622
  - [List[Expr]]/[Tuple[Expr]]: setting the select list
636
623
  - [Expr]: setting a single-col select list
637
624
  """
638
- if isinstance(index, tuple):
639
- index = list(index)
640
625
  if isinstance(index, exprs.Expr):
641
- index = [index]
642
- if isinstance(index, list):
626
+ return self.select(index)
627
+ if isinstance(index, Sequence):
643
628
  return self.select(*index)
644
629
  raise TypeError(f'Invalid index type: {type(index)}')
645
630
 
pixeltable/env.py CHANGED
@@ -342,7 +342,7 @@ class Env:
342
342
 
343
343
  if create_db:
344
344
  from pixeltable.metadata import schema
345
- schema.Base.metadata.create_all(self._sa_engine)
345
+ schema.base_metadata.create_all(self._sa_engine)
346
346
  metadata.create_system_info(self._sa_engine)
347
347
 
348
348
  print(f'Connected to Pixeltable database at: {self.db_url}')
@@ -8,4 +8,4 @@ from .expr_eval_node import ExprEvalNode
8
8
  from .in_memory_data_node import InMemoryDataNode
9
9
  from .media_validation_node import MediaValidationNode
10
10
  from .row_update_node import RowUpdateNode
11
- from .sql_node import SqlLookupNode, SqlScanNode
11
+ from .sql_node import SqlLookupNode, SqlScanNode, SqlAggregationNode, SqlNode
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  import sys
5
- from typing import Iterable, List, Optional, Any
5
+ from typing import Iterable, Optional, Any, Iterator
6
6
 
7
7
  import pixeltable.catalog as catalog
8
8
  import pixeltable.exceptions as excs
@@ -13,17 +13,29 @@ from .exec_node import ExecNode
13
13
  _logger = logging.getLogger('pixeltable')
14
14
 
15
15
  class AggregationNode(ExecNode):
16
+ """
17
+ In-memory aggregation for UDAs.
18
+
19
+ At the moment, this returns all results in a single DataRowBatch.
20
+ """
21
+ group_by: Optional[list[exprs.Expr]]
22
+ input_exprs: list[exprs.Expr]
23
+ agg_fn_eval_ctx: exprs.RowBuilder.EvalCtx
24
+ agg_fn_calls: list[exprs.FunctionCall]
25
+ output_batch: DataRowBatch
26
+
16
27
  def __init__(
17
- self, tbl: catalog.TableVersion, row_builder: exprs.RowBuilder, group_by: List[exprs.Expr],
18
- agg_fn_calls: List[exprs.FunctionCall], input_exprs: Iterable[exprs.Expr], input: ExecNode
28
+ self, tbl: catalog.TableVersion, row_builder: exprs.RowBuilder, group_by: Optional[list[exprs.Expr]],
29
+ agg_fn_calls: list[exprs.FunctionCall], input_exprs: Iterable[exprs.Expr], input: ExecNode
19
30
  ):
20
31
  super().__init__(row_builder, group_by + agg_fn_calls, input_exprs, input)
21
32
  self.input = input
22
33
  self.group_by = group_by
23
34
  self.input_exprs = list(input_exprs)
24
- self.agg_fn_eval_ctx = row_builder.create_eval_ctx(agg_fn_calls, exclude=input_exprs)
35
+ self.agg_fn_eval_ctx = row_builder.create_eval_ctx(agg_fn_calls, exclude=self.input_exprs)
25
36
  # we need to make sure to refer to the same exprs that RowBuilder.eval() will use
26
37
  self.agg_fn_calls = self.agg_fn_eval_ctx.target_exprs
38
+ # create output_batch here, rather than in __iter__(), so we don't need to remember tbl and row_builder
27
39
  self.output_batch = DataRowBatch(tbl, row_builder, 0)
28
40
 
29
41
  def _reset_agg_state(self, row_num: int) -> None:
@@ -45,17 +57,14 @@ class AggregationNode(ExecNode):
45
57
  input_vals = [row[d.slot_idx] for d in fn_call.dependencies()]
46
58
  raise excs.ExprEvalError(fn_call, expr_msg, e, exc_tb, input_vals, row_num)
47
59
 
48
- def __next__(self) -> DataRowBatch:
49
- if self.output_batch is None:
50
- raise StopIteration
51
-
60
+ def __iter__(self) -> Iterator[DataRowBatch]:
52
61
  prev_row: Optional[exprs.DataRow] = None
53
- current_group: Optional[List[Any]] = None # the values of the group-by exprs
62
+ current_group: Optional[list[Any]] = None # the values of the group-by exprs
54
63
  num_input_rows = 0
55
64
  for row_batch in self.input:
56
65
  num_input_rows += len(row_batch)
57
66
  for row in row_batch:
58
- group = [row[e.slot_idx] for e in self.group_by]
67
+ group = [row[e.slot_idx] for e in self.group_by] if self.group_by is not None else None
59
68
  if current_group is None:
60
69
  current_group = group
61
70
  self._reset_agg_state(0)
@@ -71,9 +80,7 @@ class AggregationNode(ExecNode):
71
80
  self.row_builder.eval(prev_row, self.agg_fn_eval_ctx, profile=self.ctx.profile)
72
81
  self.output_batch.add_row(prev_row)
73
82
 
74
- result = self.output_batch
75
- result.flush_imgs(None, self.stored_img_cols, self.flushed_img_slots)
76
- self.output_batch = None
77
- _logger.debug(f'AggregateNode: consumed {num_input_rows} rows, returning {len(result.rows)} rows')
78
- return result
83
+ self.output_batch.flush_imgs(None, self.stored_img_cols, self.flushed_img_slots)
84
+ _logger.debug(f'AggregateNode: consumed {num_input_rows} rows, returning {len(self.output_batch.rows)} rows')
85
+ yield self.output_batch
79
86
 
@@ -14,6 +14,13 @@ class DataRowBatch:
14
14
 
15
15
  Contains the metadata needed to initialize DataRows.
16
16
  """
17
+ tbl: Optional[catalog.TableVersion]
18
+ row_builder: exprs.RowBuilder
19
+ img_slot_idxs: list[int]
20
+ media_slot_idxs: list[int] # non-image media slots
21
+ array_slot_idxs: list[int]
22
+ rows: list[exprs.DataRow]
23
+
17
24
  def __init__(self, tbl: Optional[catalog.TableVersion], row_builder: exprs.RowBuilder, len: int = 0):
18
25
  self.tbl = tbl
19
26
  self.row_builder = row_builder
@@ -39,13 +46,6 @@ class DataRowBatch:
39
46
  def pop_row(self) -> exprs.DataRow:
40
47
  return self.rows.pop()
41
48
 
42
- def set_row_ids(self, row_ids: List[int]) -> None:
43
- """Sets pks for rows in batch"""
44
- assert self.tbl is not None
45
- assert len(row_ids) == len(self.rows)
46
- for row, row_id in zip(self.rows, row_ids):
47
- row.set_pk((row_id, self.tbl))
48
-
49
49
  def __len__(self) -> int:
50
50
  return len(self.rows)
51
51
 
@@ -1,13 +1,25 @@
1
1
  from __future__ import annotations
2
- from typing import Iterable, Optional, List
2
+
3
3
  import abc
4
+ from typing import Iterable, Optional, List, TYPE_CHECKING, Iterator
4
5
 
6
+ import pixeltable.exprs as exprs
5
7
  from .data_row_batch import DataRowBatch
6
8
  from .exec_context import ExecContext
7
- import pixeltable.exprs as exprs
9
+
10
+ if TYPE_CHECKING:
11
+ from pixeltable import exec
8
12
 
9
13
  class ExecNode(abc.ABC):
10
14
  """Base class of all execution nodes"""
15
+ output_exprs: Iterable[exprs.Expr]
16
+ row_builder: exprs.RowBuilder
17
+ input: Optional[ExecNode]
18
+ flushed_img_slots: list[int] # idxs of image slots of our output_exprs dependencies
19
+ stored_img_cols: list[exprs.ColumnSlotIdx]
20
+ ctx: Optional[ExecContext]
21
+ __iter: Optional[Iterator[DataRowBatch]]
22
+
11
23
  def __init__(
12
24
  self, row_builder: exprs.RowBuilder, output_exprs: Iterable[exprs.Expr],
13
25
  input_exprs: Iterable[exprs.Expr], input: Optional[ExecNode] = None):
@@ -21,8 +33,9 @@ class ExecNode(abc.ABC):
21
33
  e.slot_idx for e in output_dependencies
22
34
  if e.col_type.is_image_type() and e.slot_idx not in output_slot_idxs
23
35
  ]
24
- self.stored_img_cols: List[exprs.ColumnSlotIdx] = []
25
- self.ctx: Optional[ExecContext] = None # all nodes of a tree share the same context
36
+ self.stored_img_cols = []
37
+ self.ctx = None # all nodes of a tree share the same context
38
+ self.__iter = None
26
39
 
27
40
  def set_ctx(self, ctx: ExecContext) -> None:
28
41
  self.ctx = ctx
@@ -35,12 +48,15 @@ class ExecNode(abc.ABC):
35
48
  if self.input is not None:
36
49
  self.input.set_stored_img_cols(stored_img_cols)
37
50
 
38
- def __iter__(self):
51
+ # TODO: make this an abstractmethod when __next__() is removed
52
+ def __iter__(self) -> Iterator[DataRowBatch]:
39
53
  return self
40
54
 
41
- @abc.abstractmethod
55
+ # TODO: remove this and switch every subclass over to implementing __iter__
42
56
  def __next__(self) -> DataRowBatch:
43
- pass
57
+ if self.__iter is None:
58
+ self.__iter = iter(self)
59
+ return next(self.__iter)
44
60
 
45
61
  def open(self) -> None:
46
62
  """Bottom-up initialization of nodes for execution. Must be called before __next__."""
@@ -60,3 +76,15 @@ class ExecNode(abc.ABC):
60
76
  def _close(self) -> None:
61
77
  pass
62
78
 
79
+ def get_sql_node(self) -> Optional['exec.SqlNode']:
80
+ from .sql_node import SqlNode
81
+ if isinstance(self, SqlNode):
82
+ return self
83
+ if self.input is not None:
84
+ return self.input.get_sql_node()
85
+ return None
86
+
87
+ def set_limit(self, limit: int) -> None:
88
+ """Default implementation propagates to input"""
89
+ if self.input is not None:
90
+ self.input.set_limit(limit)
@@ -5,10 +5,11 @@ import warnings
5
5
  from dataclasses import dataclass
6
6
  from typing import Iterable, List, Optional
7
7
 
8
- from tqdm import tqdm, TqdmWarning
8
+ from tqdm import TqdmWarning, tqdm
9
9
 
10
10
  import pixeltable.exprs as exprs
11
11
  from pixeltable.func import CallableFunction
12
+
12
13
  from .data_row_batch import DataRowBatch
13
14
  from .exec_node import ExecNode
14
15
 
@@ -1,5 +1,5 @@
1
1
  import logging
2
- from typing import Any, Optional
2
+ from typing import Any, Optional, Iterator
3
3
 
4
4
  import pixeltable.catalog as catalog
5
5
  import pixeltable.exprs as exprs
@@ -18,6 +18,11 @@ class InMemoryDataNode(ExecNode):
18
18
  - with the values provided in the input rows
19
19
  - if an input row doesn't provide a value, sets the slot to the column default
20
20
  """
21
+ tbl: catalog.TableVersion
22
+ input_rows: list[dict[str, Any]]
23
+ start_row_id: int
24
+ output_rows: Optional[DataRowBatch]
25
+
21
26
  def __init__(
22
27
  self, tbl: catalog.TableVersion, rows: list[dict[str, Any]],
23
28
  row_builder: exprs.RowBuilder, start_row_id: int,
@@ -29,8 +34,7 @@ class InMemoryDataNode(ExecNode):
29
34
  self.tbl = tbl
30
35
  self.input_rows = rows
31
36
  self.start_row_id = start_row_id
32
- self.has_returned_data = False
33
- self.output_rows: Optional[DataRowBatch] = None
37
+ self.output_rows = None
34
38
 
35
39
  def _open(self) -> None:
36
40
  """Create row batch and populate with self.input_rows"""
@@ -67,12 +71,8 @@ class InMemoryDataNode(ExecNode):
67
71
  assert col_info is not None
68
72
  self.output_rows[row_idx][col_info.slot_idx] = None
69
73
 
70
- self.output_rows.set_row_ids([self.start_row_id + i for i in range(len(self.output_rows))])
71
74
  self.ctx.num_rows = len(self.output_rows)
72
75
 
73
- def __next__(self) -> DataRowBatch:
74
- if self.has_returned_data:
75
- raise StopIteration
76
- self.has_returned_data = True
76
+ def __iter__(self) -> Iterator[DataRowBatch]:
77
77
  _logger.debug(f'InMemoryDataNode: created row batch with {len(self.output_rows)} output_rows')
78
- return self.output_rows
78
+ yield self.output_rows