pixeltable 0.2.20__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__init__.py +7 -19
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/__init__.py +7 -7
- pixeltable/catalog/globals.py +3 -0
- pixeltable/catalog/table.py +208 -145
- pixeltable/catalog/table_version.py +36 -18
- pixeltable/catalog/table_version_path.py +0 -8
- pixeltable/catalog/view.py +3 -3
- pixeltable/dataframe.py +9 -24
- pixeltable/env.py +1 -1
- pixeltable/exec/__init__.py +1 -1
- pixeltable/exec/aggregation_node.py +22 -15
- pixeltable/exec/data_row_batch.py +7 -7
- pixeltable/exec/exec_node.py +35 -7
- pixeltable/exec/expr_eval_node.py +2 -1
- pixeltable/exec/in_memory_data_node.py +9 -9
- pixeltable/exec/sql_node.py +265 -136
- pixeltable/exprs/__init__.py +1 -0
- pixeltable/exprs/data_row.py +30 -19
- pixeltable/exprs/expr.py +15 -14
- pixeltable/exprs/expr_dict.py +55 -0
- pixeltable/exprs/expr_set.py +21 -15
- pixeltable/exprs/function_call.py +21 -8
- pixeltable/exprs/rowid_ref.py +2 -2
- pixeltable/exprs/sql_element_cache.py +5 -1
- pixeltable/ext/functions/whisperx.py +7 -2
- pixeltable/func/callable_function.py +2 -2
- pixeltable/func/function_registry.py +6 -7
- pixeltable/func/query_template_function.py +11 -12
- pixeltable/func/signature.py +17 -15
- pixeltable/func/udf.py +0 -4
- pixeltable/functions/__init__.py +1 -1
- pixeltable/functions/audio.py +4 -6
- pixeltable/functions/globals.py +86 -42
- pixeltable/functions/huggingface.py +12 -14
- pixeltable/functions/image.py +59 -45
- pixeltable/functions/json.py +0 -1
- pixeltable/functions/mistralai.py +2 -2
- pixeltable/functions/openai.py +22 -25
- pixeltable/functions/string.py +50 -50
- pixeltable/functions/timestamp.py +20 -20
- pixeltable/functions/together.py +2 -2
- pixeltable/functions/video.py +11 -20
- pixeltable/functions/whisper.py +2 -20
- pixeltable/globals.py +55 -56
- pixeltable/index/base.py +2 -2
- pixeltable/index/btree.py +7 -7
- pixeltable/index/embedding_index.py +8 -10
- pixeltable/io/external_store.py +11 -5
- pixeltable/io/globals.py +2 -0
- pixeltable/io/hf_datasets.py +1 -1
- pixeltable/io/label_studio.py +6 -6
- pixeltable/io/parquet.py +14 -13
- pixeltable/iterators/document.py +9 -7
- pixeltable/iterators/video.py +10 -1
- pixeltable/metadata/__init__.py +3 -2
- pixeltable/metadata/converters/convert_14.py +4 -2
- pixeltable/metadata/converters/convert_15.py +1 -1
- pixeltable/metadata/converters/convert_19.py +1 -0
- pixeltable/metadata/converters/convert_20.py +1 -1
- pixeltable/metadata/converters/util.py +9 -8
- pixeltable/metadata/schema.py +32 -21
- pixeltable/plan.py +136 -154
- pixeltable/store.py +51 -36
- pixeltable/tool/create_test_db_dump.py +6 -6
- pixeltable/tool/doc_plugins/griffe.py +3 -34
- pixeltable/tool/mypy_plugin.py +32 -0
- pixeltable/type_system.py +243 -60
- pixeltable/utils/arrow.py +10 -9
- pixeltable/utils/coco.py +4 -4
- pixeltable/utils/documents.py +1 -1
- pixeltable/utils/filecache.py +9 -9
- pixeltable/utils/formatter.py +1 -1
- pixeltable/utils/http_server.py +2 -5
- pixeltable/utils/media_store.py +6 -6
- pixeltable/utils/pytorch.py +10 -11
- pixeltable/utils/sql.py +2 -1
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.21.dist-info}/METADATA +6 -5
- pixeltable-0.2.21.dist-info/RECORD +148 -0
- pixeltable/utils/help.py +0 -11
- pixeltable-0.2.20.dist-info/RECORD +0 -147
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.21.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.21.dist-info}/WHEEL +0 -0
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.21.dist-info}/entry_points.txt +0 -0
|
@@ -6,7 +6,7 @@ import inspect
|
|
|
6
6
|
import logging
|
|
7
7
|
import time
|
|
8
8
|
import uuid
|
|
9
|
-
from typing import TYPE_CHECKING, Any, Iterable, Optional
|
|
9
|
+
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, Optional
|
|
10
10
|
from uuid import UUID
|
|
11
11
|
|
|
12
12
|
import sqlalchemy as sql
|
|
@@ -453,7 +453,9 @@ class TableVersion:
|
|
|
453
453
|
self.idxs_by_name[idx_name] = idx_info
|
|
454
454
|
|
|
455
455
|
# add the columns and update the metadata
|
|
456
|
-
|
|
456
|
+
# TODO support on_error='abort' for indices; it's tricky because of the way metadata changes are entangled
|
|
457
|
+
# with the database operations
|
|
458
|
+
status = self._add_columns([val_col, undo_col], conn, print_stats=False, on_error='ignore')
|
|
457
459
|
# now create the index structure
|
|
458
460
|
idx.create_index(self._store_idx_name(idx_id), val_col, conn)
|
|
459
461
|
|
|
@@ -478,7 +480,7 @@ class TableVersion:
|
|
|
478
480
|
self._update_md(time.time(), conn, preceding_schema_version=preceding_schema_version)
|
|
479
481
|
_logger.info(f'Dropped index {idx_md.name} on table {self.name}')
|
|
480
482
|
|
|
481
|
-
def add_column(self, col: Column, print_stats: bool
|
|
483
|
+
def add_column(self, col: Column, print_stats: bool, on_error: Literal['abort', 'ignore']) -> UpdateStatus:
|
|
482
484
|
"""Adds a column to the table.
|
|
483
485
|
"""
|
|
484
486
|
assert not self.is_snapshot
|
|
@@ -498,9 +500,8 @@ class TableVersion:
|
|
|
498
500
|
preceding_schema_version = self.schema_version
|
|
499
501
|
self.schema_version = self.version
|
|
500
502
|
with Env.get().engine.begin() as conn:
|
|
501
|
-
status = self._add_columns([col], conn, print_stats=print_stats)
|
|
503
|
+
status = self._add_columns([col], conn, print_stats=print_stats, on_error=on_error)
|
|
502
504
|
_ = self._add_default_index(col, conn)
|
|
503
|
-
# TODO: what to do about errors?
|
|
504
505
|
self._update_md(time.time(), conn, preceding_schema_version=preceding_schema_version)
|
|
505
506
|
_logger.info(f'Added column {col.name} to table {self.name}, new version: {self.version}')
|
|
506
507
|
|
|
@@ -512,7 +513,13 @@ class TableVersion:
|
|
|
512
513
|
_logger.info(f'Column {col.name}: {msg}')
|
|
513
514
|
return status
|
|
514
515
|
|
|
515
|
-
def _add_columns(
|
|
516
|
+
def _add_columns(
|
|
517
|
+
self,
|
|
518
|
+
cols: Iterable[Column],
|
|
519
|
+
conn: sql.engine.Connection,
|
|
520
|
+
print_stats: bool,
|
|
521
|
+
on_error: Literal['abort', 'ignore']
|
|
522
|
+
) -> UpdateStatus:
|
|
516
523
|
"""Add and populate columns within the current transaction"""
|
|
517
524
|
cols = list(cols)
|
|
518
525
|
row_count = self.store_tbl.count(conn=conn)
|
|
@@ -550,10 +557,14 @@ class TableVersion:
|
|
|
550
557
|
try:
|
|
551
558
|
plan.ctx.set_conn(conn)
|
|
552
559
|
plan.open()
|
|
553
|
-
|
|
560
|
+
try:
|
|
561
|
+
num_excs = self.store_tbl.load_column(col, plan, value_expr_slot_idx, conn, on_error)
|
|
562
|
+
except sql.exc.DBAPIError as exc:
|
|
563
|
+
# Wrap the DBAPIError in an excs.Error to unify processing in the subsequent except block
|
|
564
|
+
raise excs.Error(f'SQL error during execution of computed column `{col.name}`:\n{exc}') from exc
|
|
554
565
|
if num_excs > 0:
|
|
555
566
|
cols_with_excs.append(col)
|
|
556
|
-
except
|
|
567
|
+
except excs.Error as exc:
|
|
557
568
|
self.cols.pop()
|
|
558
569
|
for col in cols:
|
|
559
570
|
# remove columns that we already added
|
|
@@ -564,7 +575,7 @@ class TableVersion:
|
|
|
564
575
|
del self.cols_by_id[col.id]
|
|
565
576
|
# we need to re-initialize the sqlalchemy schema
|
|
566
577
|
self.store_tbl.create_sa_tbl()
|
|
567
|
-
raise
|
|
578
|
+
raise exc
|
|
568
579
|
finally:
|
|
569
580
|
plan.close()
|
|
570
581
|
|
|
@@ -689,21 +700,30 @@ class TableVersion:
|
|
|
689
700
|
plan = Planner.create_insert_plan(self, rows, ignore_errors=not fail_on_exception)
|
|
690
701
|
else:
|
|
691
702
|
plan = Planner.create_df_insert_plan(self, df, ignore_errors=not fail_on_exception)
|
|
703
|
+
|
|
704
|
+
# this is a base table; we generate rowids during the insert
|
|
705
|
+
def rowids() -> Iterator[int]:
|
|
706
|
+
while True:
|
|
707
|
+
rowid = self.next_rowid
|
|
708
|
+
self.next_rowid += 1
|
|
709
|
+
yield rowid
|
|
710
|
+
|
|
692
711
|
if conn is None:
|
|
693
712
|
with Env.get().engine.begin() as conn:
|
|
694
|
-
return self._insert(plan, conn, time.time(), print_stats)
|
|
713
|
+
return self._insert(plan, conn, time.time(), print_stats=print_stats, rowids=rowids())
|
|
695
714
|
else:
|
|
696
|
-
return self._insert(plan, conn, time.time(), print_stats)
|
|
715
|
+
return self._insert(plan, conn, time.time(), print_stats=print_stats, rowids=rowids())
|
|
697
716
|
|
|
698
717
|
def _insert(
|
|
699
|
-
self, exec_plan: 'exec.ExecNode', conn: sql.engine.Connection, timestamp: float,
|
|
718
|
+
self, exec_plan: 'exec.ExecNode', conn: sql.engine.Connection, timestamp: float, *,
|
|
719
|
+
rowids: Optional[Iterator[int]] = None, print_stats: bool = False,
|
|
700
720
|
) -> UpdateStatus:
|
|
701
721
|
"""Insert rows produced by exec_plan and propagate to views"""
|
|
702
722
|
# we're creating a new version
|
|
703
723
|
self.version += 1
|
|
704
724
|
result = UpdateStatus()
|
|
705
|
-
num_rows, num_excs, cols_with_excs = self.store_tbl.insert_rows(
|
|
706
|
-
|
|
725
|
+
num_rows, num_excs, cols_with_excs = self.store_tbl.insert_rows(
|
|
726
|
+
exec_plan, conn, v_min=self.version, rowids=rowids)
|
|
707
727
|
result.num_rows = num_rows
|
|
708
728
|
result.num_excs = num_excs
|
|
709
729
|
result.num_computed_values += exec_plan.ctx.num_computed_exprs * num_rows
|
|
@@ -714,7 +734,7 @@ class TableVersion:
|
|
|
714
734
|
for view in self.mutable_views:
|
|
715
735
|
from pixeltable.plan import Planner
|
|
716
736
|
plan, _ = Planner.create_view_load_plan(view.path, propagates_insert=True)
|
|
717
|
-
status = view._insert(plan, conn, timestamp, print_stats)
|
|
737
|
+
status = view._insert(plan, conn, timestamp, print_stats=print_stats)
|
|
718
738
|
result.num_rows += status.num_rows
|
|
719
739
|
result.num_excs += status.num_excs
|
|
720
740
|
result.num_computed_values += status.num_computed_values
|
|
@@ -751,9 +771,7 @@ class TableVersion:
|
|
|
751
771
|
raise excs.Error(f'Filter {analysis_info.filter} not expressible in SQL')
|
|
752
772
|
|
|
753
773
|
with Env.get().engine.begin() as conn:
|
|
754
|
-
plan, updated_cols, recomputed_cols = (
|
|
755
|
-
Planner.create_update_plan(self.path, update_spec, [], where, cascade)
|
|
756
|
-
)
|
|
774
|
+
plan, updated_cols, recomputed_cols = Planner.create_update_plan(self.path, update_spec, [], where, cascade)
|
|
757
775
|
from pixeltable.exprs import SqlElementCache
|
|
758
776
|
result = self.propagate_update(
|
|
759
777
|
plan, where.sql_expr(SqlElementCache()) if where is not None else None, recomputed_cols,
|
|
@@ -91,14 +91,6 @@ class TableVersionPath:
|
|
|
91
91
|
col = self.tbl_version.cols_by_name[col_name]
|
|
92
92
|
return ColumnRef(col)
|
|
93
93
|
|
|
94
|
-
def __getitem__(self, index: object) -> Union[exprs.ColumnRef, pxt.DataFrame]:
|
|
95
|
-
"""Return a ColumnRef for the given column name, or a DataFrame for the given slice.
|
|
96
|
-
"""
|
|
97
|
-
if isinstance(index, str):
|
|
98
|
-
# basically <tbl>.<colname>
|
|
99
|
-
return self.__getattr__(index)
|
|
100
|
-
return pxt.DataFrame(self).__getitem__(index)
|
|
101
|
-
|
|
102
94
|
def columns(self) -> list[Column]:
|
|
103
95
|
"""Return all user columns visible in this tbl version path, including columns from bases"""
|
|
104
96
|
result = list(self.tbl_version.cols_by_name.values())
|
pixeltable/catalog/view.py
CHANGED
|
@@ -52,11 +52,11 @@ class View(Table):
|
|
|
52
52
|
|
|
53
53
|
@classmethod
|
|
54
54
|
def _create(
|
|
55
|
-
cls, dir_id: UUID, name: str, base: TableVersionPath,
|
|
56
|
-
predicate: 'pxt.exprs.Expr', is_snapshot: bool, num_retained_versions: int, comment: str,
|
|
55
|
+
cls, dir_id: UUID, name: str, base: TableVersionPath, additional_columns: Dict[str, Any],
|
|
56
|
+
predicate: Optional['pxt.exprs.Expr'], is_snapshot: bool, num_retained_versions: int, comment: str,
|
|
57
57
|
iterator_cls: Optional[Type[ComponentIterator]], iterator_args: Optional[Dict]
|
|
58
58
|
) -> View:
|
|
59
|
-
columns = cls._create_columns(
|
|
59
|
+
columns = cls._create_columns(additional_columns)
|
|
60
60
|
cls._verify_schema(columns)
|
|
61
61
|
|
|
62
62
|
# verify that filter can be evaluated in the context of the base
|
pixeltable/dataframe.py
CHANGED
|
@@ -8,7 +8,7 @@ import logging
|
|
|
8
8
|
import mimetypes
|
|
9
9
|
import traceback
|
|
10
10
|
from pathlib import Path
|
|
11
|
-
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, Iterator, List, Optional, Set, Tuple
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, Iterator, List, Optional, Sequence, Set, Tuple, Union
|
|
12
12
|
|
|
13
13
|
import pandas as pd
|
|
14
14
|
import pandas.io.formats.style
|
|
@@ -97,8 +97,8 @@ class DataFrameResultSet:
|
|
|
97
97
|
return self._rows[index[0]][col_idx]
|
|
98
98
|
raise excs.Error(f'Bad index: {index}')
|
|
99
99
|
|
|
100
|
-
def __iter__(self) ->
|
|
101
|
-
return
|
|
100
|
+
def __iter__(self) -> Iterator[dict[str, Any]]:
|
|
101
|
+
return (self._row_to_dict(i) for i in range(len(self)))
|
|
102
102
|
|
|
103
103
|
def __eq__(self, other):
|
|
104
104
|
if not isinstance(other, DataFrameResultSet):
|
|
@@ -106,19 +106,6 @@ class DataFrameResultSet:
|
|
|
106
106
|
return self.to_pandas().equals(other.to_pandas())
|
|
107
107
|
|
|
108
108
|
|
|
109
|
-
class DataFrameResultSetIterator:
|
|
110
|
-
def __init__(self, result_set: DataFrameResultSet):
|
|
111
|
-
self._result_set = result_set
|
|
112
|
-
self._idx = 0
|
|
113
|
-
|
|
114
|
-
def __next__(self) -> Dict[str, Any]:
|
|
115
|
-
if self._idx >= len(self._result_set):
|
|
116
|
-
raise StopIteration
|
|
117
|
-
row = self._result_set._row_to_dict(self._idx)
|
|
118
|
-
self._idx += 1
|
|
119
|
-
return row
|
|
120
|
-
|
|
121
|
-
|
|
122
109
|
# # TODO: remove this; it's only here as a reminder that we still need to call release() in the current implementation
|
|
123
110
|
# class AnalysisInfo:
|
|
124
111
|
# def __init__(self, tbl: catalog.TableVersion):
|
|
@@ -296,7 +283,7 @@ class DataFrame:
|
|
|
296
283
|
|
|
297
284
|
def _create_query_plan(self) -> exec.ExecNode:
|
|
298
285
|
# construct a group-by clause if we're grouping by a table
|
|
299
|
-
group_by_clause:
|
|
286
|
+
group_by_clause: Optional[list[exprs.Expr]] = None
|
|
300
287
|
if self.grouping_tbl is not None:
|
|
301
288
|
assert self.group_by_clause is None
|
|
302
289
|
num_rowid_cols = len(self.grouping_tbl.store_tbl.rowid_columns())
|
|
@@ -315,8 +302,8 @@ class DataFrame:
|
|
|
315
302
|
where_clause=self.where_clause,
|
|
316
303
|
group_by_clause=group_by_clause,
|
|
317
304
|
order_by_clause=self.order_by_clause if self.order_by_clause is not None else [],
|
|
318
|
-
limit=self.limit_val
|
|
319
|
-
)
|
|
305
|
+
limit=self.limit_val
|
|
306
|
+
)
|
|
320
307
|
|
|
321
308
|
|
|
322
309
|
def show(self, n: int = 20) -> DataFrameResultSet:
|
|
@@ -629,17 +616,15 @@ class DataFrame:
|
|
|
629
616
|
if self.limit_val is not None:
|
|
630
617
|
raise excs.Error(f'Cannot use `{op_name}` after `limit`')
|
|
631
618
|
|
|
632
|
-
def __getitem__(self, index:
|
|
619
|
+
def __getitem__(self, index: Union[exprs.Expr, Sequence[exprs.Expr]]) -> DataFrame:
|
|
633
620
|
"""
|
|
634
621
|
Allowed:
|
|
635
622
|
- [List[Expr]]/[Tuple[Expr]]: setting the select list
|
|
636
623
|
- [Expr]: setting a single-col select list
|
|
637
624
|
"""
|
|
638
|
-
if isinstance(index, tuple):
|
|
639
|
-
index = list(index)
|
|
640
625
|
if isinstance(index, exprs.Expr):
|
|
641
|
-
|
|
642
|
-
if isinstance(index,
|
|
626
|
+
return self.select(index)
|
|
627
|
+
if isinstance(index, Sequence):
|
|
643
628
|
return self.select(*index)
|
|
644
629
|
raise TypeError(f'Invalid index type: {type(index)}')
|
|
645
630
|
|
pixeltable/env.py
CHANGED
|
@@ -342,7 +342,7 @@ class Env:
|
|
|
342
342
|
|
|
343
343
|
if create_db:
|
|
344
344
|
from pixeltable.metadata import schema
|
|
345
|
-
schema.
|
|
345
|
+
schema.base_metadata.create_all(self._sa_engine)
|
|
346
346
|
metadata.create_system_info(self._sa_engine)
|
|
347
347
|
|
|
348
348
|
print(f'Connected to Pixeltable database at: {self.db_url}')
|
pixeltable/exec/__init__.py
CHANGED
|
@@ -8,4 +8,4 @@ from .expr_eval_node import ExprEvalNode
|
|
|
8
8
|
from .in_memory_data_node import InMemoryDataNode
|
|
9
9
|
from .media_validation_node import MediaValidationNode
|
|
10
10
|
from .row_update_node import RowUpdateNode
|
|
11
|
-
from .sql_node import SqlLookupNode, SqlScanNode
|
|
11
|
+
from .sql_node import SqlLookupNode, SqlScanNode, SqlAggregationNode, SqlNode
|
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
4
|
import sys
|
|
5
|
-
from typing import Iterable,
|
|
5
|
+
from typing import Iterable, Optional, Any, Iterator
|
|
6
6
|
|
|
7
7
|
import pixeltable.catalog as catalog
|
|
8
8
|
import pixeltable.exceptions as excs
|
|
@@ -13,17 +13,29 @@ from .exec_node import ExecNode
|
|
|
13
13
|
_logger = logging.getLogger('pixeltable')
|
|
14
14
|
|
|
15
15
|
class AggregationNode(ExecNode):
|
|
16
|
+
"""
|
|
17
|
+
In-memory aggregation for UDAs.
|
|
18
|
+
|
|
19
|
+
At the moment, this returns all results in a single DataRowBatch.
|
|
20
|
+
"""
|
|
21
|
+
group_by: Optional[list[exprs.Expr]]
|
|
22
|
+
input_exprs: list[exprs.Expr]
|
|
23
|
+
agg_fn_eval_ctx: exprs.RowBuilder.EvalCtx
|
|
24
|
+
agg_fn_calls: list[exprs.FunctionCall]
|
|
25
|
+
output_batch: DataRowBatch
|
|
26
|
+
|
|
16
27
|
def __init__(
|
|
17
|
-
self, tbl: catalog.TableVersion, row_builder: exprs.RowBuilder, group_by:
|
|
18
|
-
agg_fn_calls:
|
|
28
|
+
self, tbl: catalog.TableVersion, row_builder: exprs.RowBuilder, group_by: Optional[list[exprs.Expr]],
|
|
29
|
+
agg_fn_calls: list[exprs.FunctionCall], input_exprs: Iterable[exprs.Expr], input: ExecNode
|
|
19
30
|
):
|
|
20
31
|
super().__init__(row_builder, group_by + agg_fn_calls, input_exprs, input)
|
|
21
32
|
self.input = input
|
|
22
33
|
self.group_by = group_by
|
|
23
34
|
self.input_exprs = list(input_exprs)
|
|
24
|
-
self.agg_fn_eval_ctx = row_builder.create_eval_ctx(agg_fn_calls, exclude=input_exprs)
|
|
35
|
+
self.agg_fn_eval_ctx = row_builder.create_eval_ctx(agg_fn_calls, exclude=self.input_exprs)
|
|
25
36
|
# we need to make sure to refer to the same exprs that RowBuilder.eval() will use
|
|
26
37
|
self.agg_fn_calls = self.agg_fn_eval_ctx.target_exprs
|
|
38
|
+
# create output_batch here, rather than in __iter__(), so we don't need to remember tbl and row_builder
|
|
27
39
|
self.output_batch = DataRowBatch(tbl, row_builder, 0)
|
|
28
40
|
|
|
29
41
|
def _reset_agg_state(self, row_num: int) -> None:
|
|
@@ -45,17 +57,14 @@ class AggregationNode(ExecNode):
|
|
|
45
57
|
input_vals = [row[d.slot_idx] for d in fn_call.dependencies()]
|
|
46
58
|
raise excs.ExprEvalError(fn_call, expr_msg, e, exc_tb, input_vals, row_num)
|
|
47
59
|
|
|
48
|
-
def
|
|
49
|
-
if self.output_batch is None:
|
|
50
|
-
raise StopIteration
|
|
51
|
-
|
|
60
|
+
def __iter__(self) -> Iterator[DataRowBatch]:
|
|
52
61
|
prev_row: Optional[exprs.DataRow] = None
|
|
53
|
-
current_group: Optional[
|
|
62
|
+
current_group: Optional[list[Any]] = None # the values of the group-by exprs
|
|
54
63
|
num_input_rows = 0
|
|
55
64
|
for row_batch in self.input:
|
|
56
65
|
num_input_rows += len(row_batch)
|
|
57
66
|
for row in row_batch:
|
|
58
|
-
group = [row[e.slot_idx] for e in self.group_by]
|
|
67
|
+
group = [row[e.slot_idx] for e in self.group_by] if self.group_by is not None else None
|
|
59
68
|
if current_group is None:
|
|
60
69
|
current_group = group
|
|
61
70
|
self._reset_agg_state(0)
|
|
@@ -71,9 +80,7 @@ class AggregationNode(ExecNode):
|
|
|
71
80
|
self.row_builder.eval(prev_row, self.agg_fn_eval_ctx, profile=self.ctx.profile)
|
|
72
81
|
self.output_batch.add_row(prev_row)
|
|
73
82
|
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
self.output_batch
|
|
77
|
-
_logger.debug(f'AggregateNode: consumed {num_input_rows} rows, returning {len(result.rows)} rows')
|
|
78
|
-
return result
|
|
83
|
+
self.output_batch.flush_imgs(None, self.stored_img_cols, self.flushed_img_slots)
|
|
84
|
+
_logger.debug(f'AggregateNode: consumed {num_input_rows} rows, returning {len(self.output_batch.rows)} rows')
|
|
85
|
+
yield self.output_batch
|
|
79
86
|
|
|
@@ -14,6 +14,13 @@ class DataRowBatch:
|
|
|
14
14
|
|
|
15
15
|
Contains the metadata needed to initialize DataRows.
|
|
16
16
|
"""
|
|
17
|
+
tbl: Optional[catalog.TableVersion]
|
|
18
|
+
row_builder: exprs.RowBuilder
|
|
19
|
+
img_slot_idxs: list[int]
|
|
20
|
+
media_slot_idxs: list[int] # non-image media slots
|
|
21
|
+
array_slot_idxs: list[int]
|
|
22
|
+
rows: list[exprs.DataRow]
|
|
23
|
+
|
|
17
24
|
def __init__(self, tbl: Optional[catalog.TableVersion], row_builder: exprs.RowBuilder, len: int = 0):
|
|
18
25
|
self.tbl = tbl
|
|
19
26
|
self.row_builder = row_builder
|
|
@@ -39,13 +46,6 @@ class DataRowBatch:
|
|
|
39
46
|
def pop_row(self) -> exprs.DataRow:
|
|
40
47
|
return self.rows.pop()
|
|
41
48
|
|
|
42
|
-
def set_row_ids(self, row_ids: List[int]) -> None:
|
|
43
|
-
"""Sets pks for rows in batch"""
|
|
44
|
-
assert self.tbl is not None
|
|
45
|
-
assert len(row_ids) == len(self.rows)
|
|
46
|
-
for row, row_id in zip(self.rows, row_ids):
|
|
47
|
-
row.set_pk((row_id, self.tbl))
|
|
48
|
-
|
|
49
49
|
def __len__(self) -> int:
|
|
50
50
|
return len(self.rows)
|
|
51
51
|
|
pixeltable/exec/exec_node.py
CHANGED
|
@@ -1,13 +1,25 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
2
|
+
|
|
3
3
|
import abc
|
|
4
|
+
from typing import Iterable, Optional, List, TYPE_CHECKING, Iterator
|
|
4
5
|
|
|
6
|
+
import pixeltable.exprs as exprs
|
|
5
7
|
from .data_row_batch import DataRowBatch
|
|
6
8
|
from .exec_context import ExecContext
|
|
7
|
-
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from pixeltable import exec
|
|
8
12
|
|
|
9
13
|
class ExecNode(abc.ABC):
|
|
10
14
|
"""Base class of all execution nodes"""
|
|
15
|
+
output_exprs: Iterable[exprs.Expr]
|
|
16
|
+
row_builder: exprs.RowBuilder
|
|
17
|
+
input: Optional[ExecNode]
|
|
18
|
+
flushed_img_slots: list[int] # idxs of image slots of our output_exprs dependencies
|
|
19
|
+
stored_img_cols: list[exprs.ColumnSlotIdx]
|
|
20
|
+
ctx: Optional[ExecContext]
|
|
21
|
+
__iter: Optional[Iterator[DataRowBatch]]
|
|
22
|
+
|
|
11
23
|
def __init__(
|
|
12
24
|
self, row_builder: exprs.RowBuilder, output_exprs: Iterable[exprs.Expr],
|
|
13
25
|
input_exprs: Iterable[exprs.Expr], input: Optional[ExecNode] = None):
|
|
@@ -21,8 +33,9 @@ class ExecNode(abc.ABC):
|
|
|
21
33
|
e.slot_idx for e in output_dependencies
|
|
22
34
|
if e.col_type.is_image_type() and e.slot_idx not in output_slot_idxs
|
|
23
35
|
]
|
|
24
|
-
self.stored_img_cols
|
|
25
|
-
self.ctx
|
|
36
|
+
self.stored_img_cols = []
|
|
37
|
+
self.ctx = None # all nodes of a tree share the same context
|
|
38
|
+
self.__iter = None
|
|
26
39
|
|
|
27
40
|
def set_ctx(self, ctx: ExecContext) -> None:
|
|
28
41
|
self.ctx = ctx
|
|
@@ -35,12 +48,15 @@ class ExecNode(abc.ABC):
|
|
|
35
48
|
if self.input is not None:
|
|
36
49
|
self.input.set_stored_img_cols(stored_img_cols)
|
|
37
50
|
|
|
38
|
-
|
|
51
|
+
# TODO: make this an abstractmethod when __next__() is removed
|
|
52
|
+
def __iter__(self) -> Iterator[DataRowBatch]:
|
|
39
53
|
return self
|
|
40
54
|
|
|
41
|
-
|
|
55
|
+
# TODO: remove this and switch every subclass over to implementing __iter__
|
|
42
56
|
def __next__(self) -> DataRowBatch:
|
|
43
|
-
|
|
57
|
+
if self.__iter is None:
|
|
58
|
+
self.__iter = iter(self)
|
|
59
|
+
return next(self.__iter)
|
|
44
60
|
|
|
45
61
|
def open(self) -> None:
|
|
46
62
|
"""Bottom-up initialization of nodes for execution. Must be called before __next__."""
|
|
@@ -60,3 +76,15 @@ class ExecNode(abc.ABC):
|
|
|
60
76
|
def _close(self) -> None:
|
|
61
77
|
pass
|
|
62
78
|
|
|
79
|
+
def get_sql_node(self) -> Optional['exec.SqlNode']:
|
|
80
|
+
from .sql_node import SqlNode
|
|
81
|
+
if isinstance(self, SqlNode):
|
|
82
|
+
return self
|
|
83
|
+
if self.input is not None:
|
|
84
|
+
return self.input.get_sql_node()
|
|
85
|
+
return None
|
|
86
|
+
|
|
87
|
+
def set_limit(self, limit: int) -> None:
|
|
88
|
+
"""Default implementation propagates to input"""
|
|
89
|
+
if self.input is not None:
|
|
90
|
+
self.input.set_limit(limit)
|
|
@@ -5,10 +5,11 @@ import warnings
|
|
|
5
5
|
from dataclasses import dataclass
|
|
6
6
|
from typing import Iterable, List, Optional
|
|
7
7
|
|
|
8
|
-
from tqdm import
|
|
8
|
+
from tqdm import TqdmWarning, tqdm
|
|
9
9
|
|
|
10
10
|
import pixeltable.exprs as exprs
|
|
11
11
|
from pixeltable.func import CallableFunction
|
|
12
|
+
|
|
12
13
|
from .data_row_batch import DataRowBatch
|
|
13
14
|
from .exec_node import ExecNode
|
|
14
15
|
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import logging
|
|
2
|
-
from typing import Any, Optional
|
|
2
|
+
from typing import Any, Optional, Iterator
|
|
3
3
|
|
|
4
4
|
import pixeltable.catalog as catalog
|
|
5
5
|
import pixeltable.exprs as exprs
|
|
@@ -18,6 +18,11 @@ class InMemoryDataNode(ExecNode):
|
|
|
18
18
|
- with the values provided in the input rows
|
|
19
19
|
- if an input row doesn't provide a value, sets the slot to the column default
|
|
20
20
|
"""
|
|
21
|
+
tbl: catalog.TableVersion
|
|
22
|
+
input_rows: list[dict[str, Any]]
|
|
23
|
+
start_row_id: int
|
|
24
|
+
output_rows: Optional[DataRowBatch]
|
|
25
|
+
|
|
21
26
|
def __init__(
|
|
22
27
|
self, tbl: catalog.TableVersion, rows: list[dict[str, Any]],
|
|
23
28
|
row_builder: exprs.RowBuilder, start_row_id: int,
|
|
@@ -29,8 +34,7 @@ class InMemoryDataNode(ExecNode):
|
|
|
29
34
|
self.tbl = tbl
|
|
30
35
|
self.input_rows = rows
|
|
31
36
|
self.start_row_id = start_row_id
|
|
32
|
-
self.
|
|
33
|
-
self.output_rows: Optional[DataRowBatch] = None
|
|
37
|
+
self.output_rows = None
|
|
34
38
|
|
|
35
39
|
def _open(self) -> None:
|
|
36
40
|
"""Create row batch and populate with self.input_rows"""
|
|
@@ -67,12 +71,8 @@ class InMemoryDataNode(ExecNode):
|
|
|
67
71
|
assert col_info is not None
|
|
68
72
|
self.output_rows[row_idx][col_info.slot_idx] = None
|
|
69
73
|
|
|
70
|
-
self.output_rows.set_row_ids([self.start_row_id + i for i in range(len(self.output_rows))])
|
|
71
74
|
self.ctx.num_rows = len(self.output_rows)
|
|
72
75
|
|
|
73
|
-
def
|
|
74
|
-
if self.has_returned_data:
|
|
75
|
-
raise StopIteration
|
|
76
|
-
self.has_returned_data = True
|
|
76
|
+
def __iter__(self) -> Iterator[DataRowBatch]:
|
|
77
77
|
_logger.debug(f'InMemoryDataNode: created row batch with {len(self.output_rows)} output_rows')
|
|
78
|
-
|
|
78
|
+
yield self.output_rows
|