pixeltable 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__init__.py +7 -19
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/__init__.py +7 -7
- pixeltable/catalog/column.py +37 -11
- pixeltable/catalog/globals.py +21 -0
- pixeltable/catalog/insertable_table.py +6 -4
- pixeltable/catalog/table.py +227 -148
- pixeltable/catalog/table_version.py +66 -28
- pixeltable/catalog/table_version_path.py +0 -8
- pixeltable/catalog/view.py +18 -19
- pixeltable/dataframe.py +16 -32
- pixeltable/env.py +6 -1
- pixeltable/exec/__init__.py +1 -2
- pixeltable/exec/aggregation_node.py +27 -17
- pixeltable/exec/cache_prefetch_node.py +1 -1
- pixeltable/exec/data_row_batch.py +9 -26
- pixeltable/exec/exec_node.py +36 -7
- pixeltable/exec/expr_eval_node.py +19 -11
- pixeltable/exec/in_memory_data_node.py +14 -11
- pixeltable/exec/sql_node.py +266 -138
- pixeltable/exprs/__init__.py +1 -0
- pixeltable/exprs/arithmetic_expr.py +3 -1
- pixeltable/exprs/array_slice.py +7 -7
- pixeltable/exprs/column_property_ref.py +37 -10
- pixeltable/exprs/column_ref.py +93 -14
- pixeltable/exprs/comparison.py +5 -5
- pixeltable/exprs/compound_predicate.py +8 -7
- pixeltable/exprs/data_row.py +56 -36
- pixeltable/exprs/expr.py +65 -63
- pixeltable/exprs/expr_dict.py +55 -0
- pixeltable/exprs/expr_set.py +26 -15
- pixeltable/exprs/function_call.py +53 -24
- pixeltable/exprs/globals.py +4 -1
- pixeltable/exprs/in_predicate.py +8 -7
- pixeltable/exprs/inline_expr.py +4 -4
- pixeltable/exprs/is_null.py +4 -4
- pixeltable/exprs/json_mapper.py +11 -12
- pixeltable/exprs/json_path.py +5 -10
- pixeltable/exprs/literal.py +5 -5
- pixeltable/exprs/method_ref.py +5 -4
- pixeltable/exprs/object_ref.py +2 -1
- pixeltable/exprs/row_builder.py +88 -36
- pixeltable/exprs/rowid_ref.py +14 -13
- pixeltable/exprs/similarity_expr.py +12 -7
- pixeltable/exprs/sql_element_cache.py +12 -6
- pixeltable/exprs/type_cast.py +8 -6
- pixeltable/exprs/variable.py +5 -4
- pixeltable/ext/functions/whisperx.py +7 -2
- pixeltable/func/aggregate_function.py +1 -1
- pixeltable/func/callable_function.py +2 -2
- pixeltable/func/function.py +11 -10
- pixeltable/func/function_registry.py +6 -7
- pixeltable/func/query_template_function.py +11 -12
- pixeltable/func/signature.py +17 -15
- pixeltable/func/udf.py +0 -4
- pixeltable/functions/__init__.py +2 -2
- pixeltable/functions/audio.py +4 -6
- pixeltable/functions/globals.py +84 -42
- pixeltable/functions/huggingface.py +31 -34
- pixeltable/functions/image.py +59 -45
- pixeltable/functions/json.py +0 -1
- pixeltable/functions/llama_cpp.py +106 -0
- pixeltable/functions/mistralai.py +2 -2
- pixeltable/functions/ollama.py +147 -0
- pixeltable/functions/openai.py +22 -25
- pixeltable/functions/replicate.py +72 -0
- pixeltable/functions/string.py +59 -50
- pixeltable/functions/timestamp.py +20 -20
- pixeltable/functions/together.py +2 -2
- pixeltable/functions/video.py +11 -20
- pixeltable/functions/whisper.py +2 -20
- pixeltable/globals.py +65 -74
- pixeltable/index/base.py +2 -2
- pixeltable/index/btree.py +20 -7
- pixeltable/index/embedding_index.py +12 -14
- pixeltable/io/__init__.py +1 -2
- pixeltable/io/external_store.py +11 -5
- pixeltable/io/fiftyone.py +178 -0
- pixeltable/io/globals.py +98 -2
- pixeltable/io/hf_datasets.py +1 -1
- pixeltable/io/label_studio.py +6 -6
- pixeltable/io/parquet.py +14 -13
- pixeltable/iterators/base.py +3 -2
- pixeltable/iterators/document.py +10 -8
- pixeltable/iterators/video.py +126 -60
- pixeltable/metadata/__init__.py +4 -3
- pixeltable/metadata/converters/convert_14.py +4 -2
- pixeltable/metadata/converters/convert_15.py +1 -1
- pixeltable/metadata/converters/convert_19.py +1 -0
- pixeltable/metadata/converters/convert_20.py +1 -1
- pixeltable/metadata/converters/convert_21.py +34 -0
- pixeltable/metadata/converters/util.py +54 -12
- pixeltable/metadata/notes.py +1 -0
- pixeltable/metadata/schema.py +40 -21
- pixeltable/plan.py +149 -165
- pixeltable/py.typed +0 -0
- pixeltable/store.py +57 -37
- pixeltable/tool/create_test_db_dump.py +6 -6
- pixeltable/tool/create_test_video.py +1 -1
- pixeltable/tool/doc_plugins/griffe.py +3 -34
- pixeltable/tool/embed_udf.py +1 -1
- pixeltable/tool/mypy_plugin.py +55 -0
- pixeltable/type_system.py +260 -61
- pixeltable/utils/arrow.py +10 -9
- pixeltable/utils/coco.py +4 -4
- pixeltable/utils/documents.py +16 -2
- pixeltable/utils/filecache.py +9 -9
- pixeltable/utils/formatter.py +10 -11
- pixeltable/utils/http_server.py +2 -5
- pixeltable/utils/media_store.py +6 -6
- pixeltable/utils/pytorch.py +10 -11
- pixeltable/utils/sql.py +2 -1
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/METADATA +50 -13
- pixeltable-0.2.22.dist-info/RECORD +153 -0
- pixeltable/exec/media_validation_node.py +0 -43
- pixeltable/utils/help.py +0 -11
- pixeltable-0.2.20.dist-info/RECORD +0 -147
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/WHEEL +0 -0
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/entry_points.txt +0 -0
|
@@ -6,7 +6,7 @@ import inspect
|
|
|
6
6
|
import logging
|
|
7
7
|
import time
|
|
8
8
|
import uuid
|
|
9
|
-
from typing import TYPE_CHECKING, Any, Iterable, Optional
|
|
9
|
+
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, Optional
|
|
10
10
|
from uuid import UUID
|
|
11
11
|
|
|
12
12
|
import sqlalchemy as sql
|
|
@@ -26,7 +26,7 @@ from pixeltable.utils.media_store import MediaStore
|
|
|
26
26
|
|
|
27
27
|
from ..func.globals import resolve_symbol
|
|
28
28
|
from .column import Column
|
|
29
|
-
from .globals import _POS_COLUMN_NAME, _ROWID_COLUMN_NAME, UpdateStatus, is_valid_identifier
|
|
29
|
+
from .globals import _POS_COLUMN_NAME, _ROWID_COLUMN_NAME, UpdateStatus, is_valid_identifier, MediaValidation
|
|
30
30
|
|
|
31
31
|
if TYPE_CHECKING:
|
|
32
32
|
from pixeltable import exec, store
|
|
@@ -53,6 +53,7 @@ class TableVersion:
|
|
|
53
53
|
name: str
|
|
54
54
|
version: int
|
|
55
55
|
comment: str
|
|
56
|
+
media_validation: MediaValidation
|
|
56
57
|
num_retained_versions: int
|
|
57
58
|
schema_version: int
|
|
58
59
|
view_md: Optional[schema.ViewMd]
|
|
@@ -109,6 +110,7 @@ class TableVersion:
|
|
|
109
110
|
self.view_md = tbl_md.view_md # save this as-is, it's needed for _create_md()
|
|
110
111
|
is_view = tbl_md.view_md is not None
|
|
111
112
|
self.is_snapshot = (is_view and tbl_md.view_md.is_snapshot) or bool(is_snapshot)
|
|
113
|
+
self.media_validation = MediaValidation[schema_version_md.media_validation.upper()]
|
|
112
114
|
# a mutable TableVersion doesn't have a static version
|
|
113
115
|
self.effective_version = self.version if self.is_snapshot else None
|
|
114
116
|
|
|
@@ -182,7 +184,7 @@ class TableVersion:
|
|
|
182
184
|
@classmethod
|
|
183
185
|
def create(
|
|
184
186
|
cls, session: orm.Session, dir_id: UUID, name: str, cols: list[Column], num_retained_versions: int,
|
|
185
|
-
comment: str, base_path: Optional[pxt.catalog.TableVersionPath] = None,
|
|
187
|
+
comment: str, media_validation: MediaValidation, base_path: Optional[pxt.catalog.TableVersionPath] = None,
|
|
186
188
|
view_md: Optional[schema.ViewMd] = None
|
|
187
189
|
) -> tuple[UUID, Optional[TableVersion]]:
|
|
188
190
|
# assign ids
|
|
@@ -214,11 +216,17 @@ class TableVersion:
|
|
|
214
216
|
tbl_id=tbl_record.id, version=0, md=dataclasses.asdict(table_version_md))
|
|
215
217
|
|
|
216
218
|
# create schema.TableSchemaVersion
|
|
217
|
-
schema_col_md
|
|
219
|
+
schema_col_md: dict[int, schema.SchemaColumn] = {}
|
|
220
|
+
for pos, col in enumerate(cols):
|
|
221
|
+
md = schema.SchemaColumn(
|
|
222
|
+
pos=pos, name=col.name,
|
|
223
|
+
media_validation=col._media_validation.name.lower() if col._media_validation is not None else None)
|
|
224
|
+
schema_col_md[col.id] = md
|
|
218
225
|
|
|
219
226
|
schema_version_md = schema.TableSchemaVersionMd(
|
|
220
227
|
schema_version=0, preceding_schema_version=None, columns=schema_col_md,
|
|
221
|
-
num_retained_versions=num_retained_versions, comment=comment
|
|
228
|
+
num_retained_versions=num_retained_versions, comment=comment,
|
|
229
|
+
media_validation=media_validation.name.lower())
|
|
222
230
|
schema_version_record = schema.TableSchemaVersion(
|
|
223
231
|
tbl_id=tbl_record.id, schema_version=0, md=dataclasses.asdict(schema_version_md))
|
|
224
232
|
|
|
@@ -285,10 +293,15 @@ class TableVersion:
|
|
|
285
293
|
self.cols_by_name = {}
|
|
286
294
|
self.cols_by_id = {}
|
|
287
295
|
for col_md in tbl_md.column_md.values():
|
|
288
|
-
|
|
296
|
+
schema_col_md = schema_version_md.columns[col_md.id] if col_md.id in schema_version_md.columns else None
|
|
297
|
+
col_name = schema_col_md.name if schema_col_md is not None else None
|
|
298
|
+
media_val = (
|
|
299
|
+
MediaValidation[schema_col_md.media_validation.upper()]
|
|
300
|
+
if schema_col_md is not None and schema_col_md.media_validation is not None else None
|
|
301
|
+
)
|
|
289
302
|
col = Column(
|
|
290
303
|
col_id=col_md.id, name=col_name, col_type=ts.ColumnType.from_dict(col_md.col_type),
|
|
291
|
-
is_pk=col_md.is_pk, stored=col_md.stored,
|
|
304
|
+
is_pk=col_md.is_pk, stored=col_md.stored, media_validation=media_val,
|
|
292
305
|
schema_version_add=col_md.schema_version_add, schema_version_drop=col_md.schema_version_drop,
|
|
293
306
|
value_expr_dict=col_md.value_expr)
|
|
294
307
|
col.tbl = self
|
|
@@ -349,7 +362,8 @@ class TableVersion:
|
|
|
349
362
|
self.store_tbl = StoreTable(self)
|
|
350
363
|
|
|
351
364
|
def _update_md(
|
|
352
|
-
|
|
365
|
+
self, timestamp: float, conn: sql.engine.Connection, update_tbl_version: bool = True,
|
|
366
|
+
preceding_schema_version: Optional[int] = None
|
|
353
367
|
) -> None:
|
|
354
368
|
"""Writes table metadata to the database.
|
|
355
369
|
|
|
@@ -453,7 +467,9 @@ class TableVersion:
|
|
|
453
467
|
self.idxs_by_name[idx_name] = idx_info
|
|
454
468
|
|
|
455
469
|
# add the columns and update the metadata
|
|
456
|
-
|
|
470
|
+
# TODO support on_error='abort' for indices; it's tricky because of the way metadata changes are entangled
|
|
471
|
+
# with the database operations
|
|
472
|
+
status = self._add_columns([val_col, undo_col], conn, print_stats=False, on_error='ignore')
|
|
457
473
|
# now create the index structure
|
|
458
474
|
idx.create_index(self._store_idx_name(idx_id), val_col, conn)
|
|
459
475
|
|
|
@@ -478,7 +494,7 @@ class TableVersion:
|
|
|
478
494
|
self._update_md(time.time(), conn, preceding_schema_version=preceding_schema_version)
|
|
479
495
|
_logger.info(f'Dropped index {idx_md.name} on table {self.name}')
|
|
480
496
|
|
|
481
|
-
def add_column(self, col: Column, print_stats: bool
|
|
497
|
+
def add_column(self, col: Column, print_stats: bool, on_error: Literal['abort', 'ignore']) -> UpdateStatus:
|
|
482
498
|
"""Adds a column to the table.
|
|
483
499
|
"""
|
|
484
500
|
assert not self.is_snapshot
|
|
@@ -498,9 +514,8 @@ class TableVersion:
|
|
|
498
514
|
preceding_schema_version = self.schema_version
|
|
499
515
|
self.schema_version = self.version
|
|
500
516
|
with Env.get().engine.begin() as conn:
|
|
501
|
-
status = self._add_columns([col], conn, print_stats=print_stats)
|
|
517
|
+
status = self._add_columns([col], conn, print_stats=print_stats, on_error=on_error)
|
|
502
518
|
_ = self._add_default_index(col, conn)
|
|
503
|
-
# TODO: what to do about errors?
|
|
504
519
|
self._update_md(time.time(), conn, preceding_schema_version=preceding_schema_version)
|
|
505
520
|
_logger.info(f'Added column {col.name} to table {self.name}, new version: {self.version}')
|
|
506
521
|
|
|
@@ -512,7 +527,13 @@ class TableVersion:
|
|
|
512
527
|
_logger.info(f'Column {col.name}: {msg}')
|
|
513
528
|
return status
|
|
514
529
|
|
|
515
|
-
def _add_columns(
|
|
530
|
+
def _add_columns(
|
|
531
|
+
self,
|
|
532
|
+
cols: Iterable[Column],
|
|
533
|
+
conn: sql.engine.Connection,
|
|
534
|
+
print_stats: bool,
|
|
535
|
+
on_error: Literal['abort', 'ignore']
|
|
536
|
+
) -> UpdateStatus:
|
|
516
537
|
"""Add and populate columns within the current transaction"""
|
|
517
538
|
cols = list(cols)
|
|
518
539
|
row_count = self.store_tbl.count(conn=conn)
|
|
@@ -550,10 +571,14 @@ class TableVersion:
|
|
|
550
571
|
try:
|
|
551
572
|
plan.ctx.set_conn(conn)
|
|
552
573
|
plan.open()
|
|
553
|
-
|
|
574
|
+
try:
|
|
575
|
+
num_excs = self.store_tbl.load_column(col, plan, value_expr_slot_idx, conn, on_error)
|
|
576
|
+
except sql.exc.DBAPIError as exc:
|
|
577
|
+
# Wrap the DBAPIError in an excs.Error to unify processing in the subsequent except block
|
|
578
|
+
raise excs.Error(f'SQL error during execution of computed column `{col.name}`:\n{exc}') from exc
|
|
554
579
|
if num_excs > 0:
|
|
555
580
|
cols_with_excs.append(col)
|
|
556
|
-
except
|
|
581
|
+
except excs.Error as exc:
|
|
557
582
|
self.cols.pop()
|
|
558
583
|
for col in cols:
|
|
559
584
|
# remove columns that we already added
|
|
@@ -564,7 +589,7 @@ class TableVersion:
|
|
|
564
589
|
del self.cols_by_id[col.id]
|
|
565
590
|
# we need to re-initialize the sqlalchemy schema
|
|
566
591
|
self.store_tbl.create_sa_tbl()
|
|
567
|
-
raise
|
|
592
|
+
raise exc
|
|
568
593
|
finally:
|
|
569
594
|
plan.close()
|
|
570
595
|
|
|
@@ -689,21 +714,32 @@ class TableVersion:
|
|
|
689
714
|
plan = Planner.create_insert_plan(self, rows, ignore_errors=not fail_on_exception)
|
|
690
715
|
else:
|
|
691
716
|
plan = Planner.create_df_insert_plan(self, df, ignore_errors=not fail_on_exception)
|
|
717
|
+
|
|
718
|
+
# this is a base table; we generate rowids during the insert
|
|
719
|
+
def rowids() -> Iterator[int]:
|
|
720
|
+
while True:
|
|
721
|
+
rowid = self.next_rowid
|
|
722
|
+
self.next_rowid += 1
|
|
723
|
+
yield rowid
|
|
724
|
+
|
|
692
725
|
if conn is None:
|
|
693
726
|
with Env.get().engine.begin() as conn:
|
|
694
|
-
return self._insert(
|
|
727
|
+
return self._insert(
|
|
728
|
+
plan, conn, time.time(), print_stats=print_stats, rowids=rowids(), abort_on_exc=fail_on_exception)
|
|
695
729
|
else:
|
|
696
|
-
return self._insert(
|
|
730
|
+
return self._insert(
|
|
731
|
+
plan, conn, time.time(), print_stats=print_stats, rowids=rowids(), abort_on_exc=fail_on_exception)
|
|
697
732
|
|
|
698
733
|
def _insert(
|
|
699
|
-
self, exec_plan: 'exec.ExecNode', conn: sql.engine.Connection, timestamp: float,
|
|
734
|
+
self, exec_plan: 'exec.ExecNode', conn: sql.engine.Connection, timestamp: float, *,
|
|
735
|
+
rowids: Optional[Iterator[int]] = None, print_stats: bool = False, abort_on_exc: bool = False
|
|
700
736
|
) -> UpdateStatus:
|
|
701
737
|
"""Insert rows produced by exec_plan and propagate to views"""
|
|
702
738
|
# we're creating a new version
|
|
703
739
|
self.version += 1
|
|
704
740
|
result = UpdateStatus()
|
|
705
|
-
num_rows, num_excs, cols_with_excs = self.store_tbl.insert_rows(
|
|
706
|
-
|
|
741
|
+
num_rows, num_excs, cols_with_excs = self.store_tbl.insert_rows(
|
|
742
|
+
exec_plan, conn, v_min=self.version, rowids=rowids, abort_on_exc=abort_on_exc)
|
|
707
743
|
result.num_rows = num_rows
|
|
708
744
|
result.num_excs = num_excs
|
|
709
745
|
result.num_computed_values += exec_plan.ctx.num_computed_exprs * num_rows
|
|
@@ -714,7 +750,7 @@ class TableVersion:
|
|
|
714
750
|
for view in self.mutable_views:
|
|
715
751
|
from pixeltable.plan import Planner
|
|
716
752
|
plan, _ = Planner.create_view_load_plan(view.path, propagates_insert=True)
|
|
717
|
-
status = view._insert(plan, conn, timestamp, print_stats)
|
|
753
|
+
status = view._insert(plan, conn, timestamp, print_stats=print_stats)
|
|
718
754
|
result.num_rows += status.num_rows
|
|
719
755
|
result.num_excs += status.num_excs
|
|
720
756
|
result.num_computed_values += status.num_computed_values
|
|
@@ -751,9 +787,7 @@ class TableVersion:
|
|
|
751
787
|
raise excs.Error(f'Filter {analysis_info.filter} not expressible in SQL')
|
|
752
788
|
|
|
753
789
|
with Env.get().engine.begin() as conn:
|
|
754
|
-
plan, updated_cols, recomputed_cols = (
|
|
755
|
-
Planner.create_update_plan(self.path, update_spec, [], where, cascade)
|
|
756
|
-
)
|
|
790
|
+
plan, updated_cols, recomputed_cols = Planner.create_update_plan(self.path, update_spec, [], where, cascade)
|
|
757
791
|
from pixeltable.exprs import SqlElementCache
|
|
758
792
|
result = self.propagate_update(
|
|
759
793
|
plan, where.sql_expr(SqlElementCache()) if where is not None else None, recomputed_cols,
|
|
@@ -1185,7 +1219,8 @@ class TableVersion:
|
|
|
1185
1219
|
name=self.name, current_version=self.version, current_schema_version=self.schema_version,
|
|
1186
1220
|
next_col_id=self.next_col_id, next_idx_id=self.next_idx_id, next_row_id=self.next_rowid,
|
|
1187
1221
|
column_md=self._create_column_md(self.cols), index_md=self.idx_md,
|
|
1188
|
-
external_stores=self._create_stores_md(self.external_stores.values()), view_md=self.view_md
|
|
1222
|
+
external_stores=self._create_stores_md(self.external_stores.values()), view_md=self.view_md,
|
|
1223
|
+
)
|
|
1189
1224
|
|
|
1190
1225
|
def _create_version_md(self, timestamp: float) -> schema.TableVersionMd:
|
|
1191
1226
|
return schema.TableVersionMd(created_at=timestamp, version=self.version, schema_version=self.schema_version)
|
|
@@ -1193,11 +1228,14 @@ class TableVersion:
|
|
|
1193
1228
|
def _create_schema_version_md(self, preceding_schema_version: int) -> schema.TableSchemaVersionMd:
|
|
1194
1229
|
column_md: dict[int, schema.SchemaColumn] = {}
|
|
1195
1230
|
for pos, col in enumerate(self.cols_by_name.values()):
|
|
1196
|
-
column_md[col.id] = schema.SchemaColumn(
|
|
1231
|
+
column_md[col.id] = schema.SchemaColumn(
|
|
1232
|
+
pos=pos, name=col.name,
|
|
1233
|
+
media_validation=col._media_validation.name.lower() if col._media_validation is not None else None)
|
|
1197
1234
|
# preceding_schema_version to be set by the caller
|
|
1198
1235
|
return schema.TableSchemaVersionMd(
|
|
1199
1236
|
schema_version=self.schema_version, preceding_schema_version=preceding_schema_version,
|
|
1200
|
-
columns=column_md, num_retained_versions=self.num_retained_versions, comment=self.comment
|
|
1237
|
+
columns=column_md, num_retained_versions=self.num_retained_versions, comment=self.comment,
|
|
1238
|
+
media_validation=self.media_validation.name.lower())
|
|
1201
1239
|
|
|
1202
1240
|
def as_dict(self) -> dict:
|
|
1203
1241
|
return {'id': str(self.id), 'effective_version': self.effective_version}
|
|
@@ -91,14 +91,6 @@ class TableVersionPath:
|
|
|
91
91
|
col = self.tbl_version.cols_by_name[col_name]
|
|
92
92
|
return ColumnRef(col)
|
|
93
93
|
|
|
94
|
-
def __getitem__(self, index: object) -> Union[exprs.ColumnRef, pxt.DataFrame]:
|
|
95
|
-
"""Return a ColumnRef for the given column name, or a DataFrame for the given slice.
|
|
96
|
-
"""
|
|
97
|
-
if isinstance(index, str):
|
|
98
|
-
# basically <tbl>.<colname>
|
|
99
|
-
return self.__getattr__(index)
|
|
100
|
-
return pxt.DataFrame(self).__getitem__(index)
|
|
101
|
-
|
|
102
94
|
def columns(self) -> list[Column]:
|
|
103
95
|
"""Return all user columns visible in this tbl version path, including columns from bases"""
|
|
104
96
|
result = list(self.tbl_version.cols_by_name.values())
|
pixeltable/catalog/view.py
CHANGED
|
@@ -2,24 +2,21 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
4
|
import logging
|
|
5
|
-
from typing import TYPE_CHECKING, Any,
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Iterable, Optional
|
|
6
6
|
from uuid import UUID
|
|
7
7
|
|
|
8
8
|
import sqlalchemy.orm as orm
|
|
9
9
|
|
|
10
|
-
import pixeltable.catalog as catalog
|
|
11
10
|
import pixeltable.exceptions as excs
|
|
12
|
-
import pixeltable.exprs as exprs
|
|
13
|
-
import pixeltable.func as func
|
|
14
11
|
import pixeltable.metadata.schema as md_schema
|
|
12
|
+
import pixeltable.type_system as ts
|
|
13
|
+
from pixeltable import catalog, exprs, func
|
|
15
14
|
from pixeltable.env import Env
|
|
16
|
-
from pixeltable.exceptions import Error
|
|
17
15
|
from pixeltable.iterators import ComponentIterator
|
|
18
|
-
from pixeltable.type_system import IntType, InvalidType
|
|
19
16
|
|
|
20
17
|
from .catalog import Catalog
|
|
21
18
|
from .column import Column
|
|
22
|
-
from .globals import _POS_COLUMN_NAME, UpdateStatus
|
|
19
|
+
from .globals import _POS_COLUMN_NAME, UpdateStatus, MediaValidation
|
|
23
20
|
from .table import Table
|
|
24
21
|
from .table_version import TableVersion
|
|
25
22
|
from .table_version_path import TableVersionPath
|
|
@@ -52,11 +49,12 @@ class View(Table):
|
|
|
52
49
|
|
|
53
50
|
@classmethod
|
|
54
51
|
def _create(
|
|
55
|
-
cls, dir_id: UUID, name: str, base: TableVersionPath,
|
|
56
|
-
predicate: 'pxt.exprs.Expr', is_snapshot: bool, num_retained_versions: int, comment: str,
|
|
57
|
-
|
|
52
|
+
cls, dir_id: UUID, name: str, base: TableVersionPath, additional_columns: dict[str, Any],
|
|
53
|
+
predicate: Optional['pxt.exprs.Expr'], is_snapshot: bool, num_retained_versions: int, comment: str,
|
|
54
|
+
media_validation: MediaValidation,
|
|
55
|
+
iterator_cls: Optional[type[ComponentIterator]], iterator_args: Optional[dict]
|
|
58
56
|
) -> View:
|
|
59
|
-
columns = cls._create_columns(
|
|
57
|
+
columns = cls._create_columns(additional_columns)
|
|
60
58
|
cls._verify_schema(columns)
|
|
61
59
|
|
|
62
60
|
# verify that filter can be evaluated in the context of the base
|
|
@@ -92,17 +90,17 @@ class View(Table):
|
|
|
92
90
|
func.Parameter(param_name, param_type, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
|
93
91
|
for param_name, param_type in iterator_cls.input_schema().items()
|
|
94
92
|
]
|
|
95
|
-
sig = func.Signature(InvalidType(), params)
|
|
93
|
+
sig = func.Signature(ts.InvalidType(), params)
|
|
96
94
|
from pixeltable.exprs import FunctionCall
|
|
97
95
|
FunctionCall.normalize_args(iterator_cls.__name__, sig, bound_args)
|
|
98
96
|
except TypeError as e:
|
|
99
|
-
raise Error(f'Cannot instantiate iterator with given arguments: {e}')
|
|
97
|
+
raise excs.Error(f'Cannot instantiate iterator with given arguments: {e}')
|
|
100
98
|
|
|
101
99
|
# prepend pos and output_schema columns to cols:
|
|
102
100
|
# a component view exposes the pos column of its rowid;
|
|
103
101
|
# we create that column here, so it gets assigned a column id;
|
|
104
102
|
# stored=False: it is not stored separately (it's already stored as part of the rowid)
|
|
105
|
-
iterator_cols = [Column(_POS_COLUMN_NAME, IntType(), stored=False)]
|
|
103
|
+
iterator_cols = [Column(_POS_COLUMN_NAME, ts.IntType(), stored=False)]
|
|
106
104
|
output_dict, unstored_cols = iterator_cls.output_schema(**bound_args)
|
|
107
105
|
iterator_cols.extend([
|
|
108
106
|
Column(col_name, col_type, stored=col_name not in unstored_cols)
|
|
@@ -112,12 +110,12 @@ class View(Table):
|
|
|
112
110
|
iterator_col_names = {col.name for col in iterator_cols}
|
|
113
111
|
for col in columns:
|
|
114
112
|
if col.name in iterator_col_names:
|
|
115
|
-
raise Error(f'Duplicate name: column {col.name} is already present in the iterator output schema')
|
|
113
|
+
raise excs.Error(f'Duplicate name: column {col.name} is already present in the iterator output schema')
|
|
116
114
|
columns = iterator_cols + columns
|
|
117
115
|
|
|
118
116
|
with orm.Session(Env.get().engine, future=True) as session:
|
|
119
117
|
from pixeltable.exprs import InlineDict
|
|
120
|
-
iterator_args_expr = InlineDict(iterator_args) if iterator_args is not None else None
|
|
118
|
+
iterator_args_expr: exprs.Expr = InlineDict(iterator_args) if iterator_args is not None else None
|
|
121
119
|
iterator_class_fqn = f'{iterator_cls.__module__}.{iterator_cls.__name__}' if iterator_cls is not None \
|
|
122
120
|
else None
|
|
123
121
|
base_version_path = cls._get_snapshot_path(base) if is_snapshot else base
|
|
@@ -142,7 +140,8 @@ class View(Table):
|
|
|
142
140
|
iterator_args=iterator_args_expr.as_dict() if iterator_args_expr is not None else None)
|
|
143
141
|
|
|
144
142
|
id, tbl_version = TableVersion.create(
|
|
145
|
-
session, dir_id, name, columns, num_retained_versions, comment,
|
|
143
|
+
session, dir_id, name, columns, num_retained_versions, comment, media_validation=media_validation,
|
|
144
|
+
base_path=base_version_path, view_md=view_md)
|
|
146
145
|
if tbl_version is None:
|
|
147
146
|
# this is purely a snapshot: we use the base's tbl version path
|
|
148
147
|
view = cls(id, dir_id, name, base_version_path, base.tbl_id(), snapshot_only=True)
|
|
@@ -168,11 +167,11 @@ class View(Table):
|
|
|
168
167
|
|
|
169
168
|
@classmethod
|
|
170
169
|
def _verify_column(
|
|
171
|
-
cls, col: Column, existing_column_names:
|
|
170
|
+
cls, col: Column, existing_column_names: set[str], existing_query_names: Optional[set[str]] = None
|
|
172
171
|
) -> None:
|
|
173
172
|
# make sure that columns are nullable or have a default
|
|
174
173
|
if not col.col_type.nullable and not col.is_computed:
|
|
175
|
-
raise Error(f'Column {col.name}: non-computed columns in views must be nullable')
|
|
174
|
+
raise excs.Error(f'Column {col.name}: non-computed columns in views must be nullable')
|
|
176
175
|
super()._verify_column(col, existing_column_names, existing_query_names)
|
|
177
176
|
|
|
178
177
|
@classmethod
|
pixeltable/dataframe.py
CHANGED
|
@@ -8,7 +8,7 @@ import logging
|
|
|
8
8
|
import mimetypes
|
|
9
9
|
import traceback
|
|
10
10
|
from pathlib import Path
|
|
11
|
-
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, Iterator, List, Optional, Set, Tuple
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, Iterator, List, Optional, Sequence, Set, Tuple, Union
|
|
12
12
|
|
|
13
13
|
import pandas as pd
|
|
14
14
|
import pandas.io.formats.style
|
|
@@ -97,8 +97,8 @@ class DataFrameResultSet:
|
|
|
97
97
|
return self._rows[index[0]][col_idx]
|
|
98
98
|
raise excs.Error(f'Bad index: {index}')
|
|
99
99
|
|
|
100
|
-
def __iter__(self) ->
|
|
101
|
-
return
|
|
100
|
+
def __iter__(self) -> Iterator[dict[str, Any]]:
|
|
101
|
+
return (self._row_to_dict(i) for i in range(len(self)))
|
|
102
102
|
|
|
103
103
|
def __eq__(self, other):
|
|
104
104
|
if not isinstance(other, DataFrameResultSet):
|
|
@@ -106,19 +106,6 @@ class DataFrameResultSet:
|
|
|
106
106
|
return self.to_pandas().equals(other.to_pandas())
|
|
107
107
|
|
|
108
108
|
|
|
109
|
-
class DataFrameResultSetIterator:
|
|
110
|
-
def __init__(self, result_set: DataFrameResultSet):
|
|
111
|
-
self._result_set = result_set
|
|
112
|
-
self._idx = 0
|
|
113
|
-
|
|
114
|
-
def __next__(self) -> Dict[str, Any]:
|
|
115
|
-
if self._idx >= len(self._result_set):
|
|
116
|
-
raise StopIteration
|
|
117
|
-
row = self._result_set._row_to_dict(self._idx)
|
|
118
|
-
self._idx += 1
|
|
119
|
-
return row
|
|
120
|
-
|
|
121
|
-
|
|
122
109
|
# # TODO: remove this; it's only here as a reminder that we still need to call release() in the current implementation
|
|
123
110
|
# class AnalysisInfo:
|
|
124
111
|
# def __init__(self, tbl: catalog.TableVersion):
|
|
@@ -296,7 +283,7 @@ class DataFrame:
|
|
|
296
283
|
|
|
297
284
|
def _create_query_plan(self) -> exec.ExecNode:
|
|
298
285
|
# construct a group-by clause if we're grouping by a table
|
|
299
|
-
group_by_clause:
|
|
286
|
+
group_by_clause: Optional[list[exprs.Expr]] = None
|
|
300
287
|
if self.grouping_tbl is not None:
|
|
301
288
|
assert self.group_by_clause is None
|
|
302
289
|
num_rowid_cols = len(self.grouping_tbl.store_tbl.rowid_columns())
|
|
@@ -315,8 +302,8 @@ class DataFrame:
|
|
|
315
302
|
where_clause=self.where_clause,
|
|
316
303
|
group_by_clause=group_by_clause,
|
|
317
304
|
order_by_clause=self.order_by_clause if self.order_by_clause is not None else [],
|
|
318
|
-
limit=self.limit_val
|
|
319
|
-
)
|
|
305
|
+
limit=self.limit_val
|
|
306
|
+
)
|
|
320
307
|
|
|
321
308
|
|
|
322
309
|
def show(self, n: int = 20) -> DataFrameResultSet:
|
|
@@ -384,15 +371,10 @@ class DataFrame:
|
|
|
384
371
|
group_by_clause=group_by_clause, grouping_tbl=self.grouping_tbl,
|
|
385
372
|
order_by_clause=order_by_clause, limit=self.limit_val)
|
|
386
373
|
|
|
387
|
-
def
|
|
388
|
-
return self._collect()
|
|
389
|
-
|
|
390
|
-
def _collect(self, conn: Optional[sql.engine.Connection] = None) -> DataFrameResultSet:
|
|
374
|
+
def _output_row_iterator(self, conn: Optional[sql.engine.Connection] = None) -> Iterator[list]:
|
|
391
375
|
try:
|
|
392
|
-
result_rows = []
|
|
393
376
|
for data_row in self._exec(conn):
|
|
394
|
-
|
|
395
|
-
result_rows.append(result_row)
|
|
377
|
+
yield [data_row[e.slot_idx] for e in self._select_list_exprs]
|
|
396
378
|
except excs.ExprEvalError as e:
|
|
397
379
|
msg = f'In row {e.row_num} the {e.expr_msg} encountered exception ' f'{type(e.exc).__name__}:\n{str(e.exc)}'
|
|
398
380
|
if len(e.input_vals) > 0:
|
|
@@ -412,7 +394,11 @@ class DataFrame:
|
|
|
412
394
|
except sql.exc.DBAPIError as e:
|
|
413
395
|
raise excs.Error(f'Error during SQL execution:\n{e}')
|
|
414
396
|
|
|
415
|
-
|
|
397
|
+
def collect(self) -> DataFrameResultSet:
|
|
398
|
+
return self._collect()
|
|
399
|
+
|
|
400
|
+
def _collect(self, conn: Optional[sql.engine.Connection] = None) -> DataFrameResultSet:
|
|
401
|
+
return DataFrameResultSet(list(self._output_row_iterator(conn)), self.schema)
|
|
416
402
|
|
|
417
403
|
def count(self) -> int:
|
|
418
404
|
from pixeltable.plan import Planner
|
|
@@ -629,17 +615,15 @@ class DataFrame:
|
|
|
629
615
|
if self.limit_val is not None:
|
|
630
616
|
raise excs.Error(f'Cannot use `{op_name}` after `limit`')
|
|
631
617
|
|
|
632
|
-
def __getitem__(self, index:
|
|
618
|
+
def __getitem__(self, index: Union[exprs.Expr, Sequence[exprs.Expr]]) -> DataFrame:
|
|
633
619
|
"""
|
|
634
620
|
Allowed:
|
|
635
621
|
- [List[Expr]]/[Tuple[Expr]]: setting the select list
|
|
636
622
|
- [Expr]: setting a single-col select list
|
|
637
623
|
"""
|
|
638
|
-
if isinstance(index, tuple):
|
|
639
|
-
index = list(index)
|
|
640
624
|
if isinstance(index, exprs.Expr):
|
|
641
|
-
|
|
642
|
-
if isinstance(index,
|
|
625
|
+
return self.select(index)
|
|
626
|
+
if isinstance(index, Sequence):
|
|
643
627
|
return self.select(*index)
|
|
644
628
|
raise TypeError(f'Invalid index type: {type(index)}')
|
|
645
629
|
|
pixeltable/env.py
CHANGED
|
@@ -342,7 +342,7 @@ class Env:
|
|
|
342
342
|
|
|
343
343
|
if create_db:
|
|
344
344
|
from pixeltable.metadata import schema
|
|
345
|
-
schema.
|
|
345
|
+
schema.base_metadata.create_all(self._sa_engine)
|
|
346
346
|
metadata.create_system_info(self._sa_engine)
|
|
347
347
|
|
|
348
348
|
print(f'Connected to Pixeltable database at: {self.db_url}')
|
|
@@ -494,13 +494,18 @@ class Env:
|
|
|
494
494
|
self.__register_package('anthropic')
|
|
495
495
|
self.__register_package('boto3')
|
|
496
496
|
self.__register_package('datasets')
|
|
497
|
+
self.__register_package('fiftyone')
|
|
497
498
|
self.__register_package('fireworks', library_name='fireworks-ai')
|
|
499
|
+
self.__register_package('huggingface_hub', library_name='huggingface-hub')
|
|
498
500
|
self.__register_package('label_studio_sdk', library_name='label-studio-sdk')
|
|
501
|
+
self.__register_package('llama_cpp', library_name='llama-cpp-python')
|
|
499
502
|
self.__register_package('mistralai')
|
|
500
503
|
self.__register_package('mistune')
|
|
504
|
+
self.__register_package('ollama')
|
|
501
505
|
self.__register_package('openai')
|
|
502
506
|
self.__register_package('openpyxl')
|
|
503
507
|
self.__register_package('pyarrow')
|
|
508
|
+
self.__register_package('replicate')
|
|
504
509
|
self.__register_package('sentence_transformers', library_name='sentence-transformers')
|
|
505
510
|
self.__register_package('spacy')
|
|
506
511
|
self.__register_package('tiktoken')
|
pixeltable/exec/__init__.py
CHANGED
|
@@ -6,6 +6,5 @@ from .exec_context import ExecContext
|
|
|
6
6
|
from .exec_node import ExecNode
|
|
7
7
|
from .expr_eval_node import ExprEvalNode
|
|
8
8
|
from .in_memory_data_node import InMemoryDataNode
|
|
9
|
-
from .media_validation_node import MediaValidationNode
|
|
10
9
|
from .row_update_node import RowUpdateNode
|
|
11
|
-
from .sql_node import SqlLookupNode, SqlScanNode
|
|
10
|
+
from .sql_node import SqlLookupNode, SqlScanNode, SqlAggregationNode, SqlNode
|
|
@@ -2,28 +2,43 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
4
|
import sys
|
|
5
|
-
from typing import Iterable,
|
|
5
|
+
from typing import Any, Iterable, Iterator, Optional, cast
|
|
6
6
|
|
|
7
7
|
import pixeltable.catalog as catalog
|
|
8
8
|
import pixeltable.exceptions as excs
|
|
9
9
|
import pixeltable.exprs as exprs
|
|
10
|
+
|
|
10
11
|
from .data_row_batch import DataRowBatch
|
|
11
12
|
from .exec_node import ExecNode
|
|
12
13
|
|
|
13
14
|
_logger = logging.getLogger('pixeltable')
|
|
14
15
|
|
|
15
16
|
class AggregationNode(ExecNode):
|
|
17
|
+
"""
|
|
18
|
+
In-memory aggregation for UDAs.
|
|
19
|
+
|
|
20
|
+
At the moment, this returns all results in a single DataRowBatch.
|
|
21
|
+
"""
|
|
22
|
+
group_by: Optional[list[exprs.Expr]]
|
|
23
|
+
input_exprs: list[exprs.Expr]
|
|
24
|
+
agg_fn_eval_ctx: exprs.RowBuilder.EvalCtx
|
|
25
|
+
agg_fn_calls: list[exprs.FunctionCall]
|
|
26
|
+
output_batch: DataRowBatch
|
|
27
|
+
|
|
16
28
|
def __init__(
|
|
17
|
-
self, tbl: catalog.TableVersion, row_builder: exprs.RowBuilder, group_by:
|
|
18
|
-
agg_fn_calls:
|
|
29
|
+
self, tbl: catalog.TableVersion, row_builder: exprs.RowBuilder, group_by: Optional[list[exprs.Expr]],
|
|
30
|
+
agg_fn_calls: list[exprs.FunctionCall], input_exprs: Iterable[exprs.Expr], input: ExecNode
|
|
19
31
|
):
|
|
20
|
-
|
|
32
|
+
output_exprs: list[exprs.Expr] = [] if group_by is None else list(group_by)
|
|
33
|
+
output_exprs.extend(agg_fn_calls)
|
|
34
|
+
super().__init__(row_builder, output_exprs, input_exprs, input)
|
|
21
35
|
self.input = input
|
|
22
36
|
self.group_by = group_by
|
|
23
37
|
self.input_exprs = list(input_exprs)
|
|
24
|
-
self.agg_fn_eval_ctx = row_builder.create_eval_ctx(agg_fn_calls, exclude=input_exprs)
|
|
38
|
+
self.agg_fn_eval_ctx = row_builder.create_eval_ctx(agg_fn_calls, exclude=self.input_exprs)
|
|
25
39
|
# we need to make sure to refer to the same exprs that RowBuilder.eval() will use
|
|
26
|
-
self.agg_fn_calls = self.agg_fn_eval_ctx.target_exprs
|
|
40
|
+
self.agg_fn_calls = [cast(exprs.FunctionCall, e) for e in self.agg_fn_eval_ctx.target_exprs]
|
|
41
|
+
# create output_batch here, rather than in __iter__(), so we don't need to remember tbl and row_builder
|
|
27
42
|
self.output_batch = DataRowBatch(tbl, row_builder, 0)
|
|
28
43
|
|
|
29
44
|
def _reset_agg_state(self, row_num: int) -> None:
|
|
@@ -45,17 +60,14 @@ class AggregationNode(ExecNode):
|
|
|
45
60
|
input_vals = [row[d.slot_idx] for d in fn_call.dependencies()]
|
|
46
61
|
raise excs.ExprEvalError(fn_call, expr_msg, e, exc_tb, input_vals, row_num)
|
|
47
62
|
|
|
48
|
-
def
|
|
49
|
-
if self.output_batch is None:
|
|
50
|
-
raise StopIteration
|
|
51
|
-
|
|
63
|
+
def __iter__(self) -> Iterator[DataRowBatch]:
|
|
52
64
|
prev_row: Optional[exprs.DataRow] = None
|
|
53
|
-
current_group: Optional[
|
|
65
|
+
current_group: Optional[list[Any]] = None # the values of the group-by exprs
|
|
54
66
|
num_input_rows = 0
|
|
55
67
|
for row_batch in self.input:
|
|
56
68
|
num_input_rows += len(row_batch)
|
|
57
69
|
for row in row_batch:
|
|
58
|
-
group = [row[e.slot_idx] for e in self.group_by]
|
|
70
|
+
group = [row[e.slot_idx] for e in self.group_by] if self.group_by is not None else None
|
|
59
71
|
if current_group is None:
|
|
60
72
|
current_group = group
|
|
61
73
|
self._reset_agg_state(0)
|
|
@@ -71,9 +83,7 @@ class AggregationNode(ExecNode):
|
|
|
71
83
|
self.row_builder.eval(prev_row, self.agg_fn_eval_ctx, profile=self.ctx.profile)
|
|
72
84
|
self.output_batch.add_row(prev_row)
|
|
73
85
|
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
self.output_batch
|
|
77
|
-
_logger.debug(f'AggregateNode: consumed {num_input_rows} rows, returning {len(result.rows)} rows')
|
|
78
|
-
return result
|
|
86
|
+
self.output_batch.flush_imgs(None, self.stored_img_cols, self.flushed_img_slots)
|
|
87
|
+
_logger.debug(f'AggregateNode: consumed {num_input_rows} rows, returning {len(self.output_batch.rows)} rows')
|
|
88
|
+
yield self.output_batch
|
|
79
89
|
|
|
@@ -79,7 +79,7 @@ class CachePrefetchNode(ExecNode):
|
|
|
79
79
|
|
|
80
80
|
return input_batch
|
|
81
81
|
|
|
82
|
-
def _fetch_url(self, row: exprs.DataRow, slot_idx: int) -> Optional[
|
|
82
|
+
def _fetch_url(self, row: exprs.DataRow, slot_idx: int) -> Optional[Path]:
|
|
83
83
|
"""Fetches a remote URL into Env.tmp_dir and returns its path"""
|
|
84
84
|
url = row.file_urls[slot_idx]
|
|
85
85
|
parsed = urllib.parse.urlparse(url)
|
|
@@ -14,6 +14,13 @@ class DataRowBatch:
|
|
|
14
14
|
|
|
15
15
|
Contains the metadata needed to initialize DataRows.
|
|
16
16
|
"""
|
|
17
|
+
tbl: Optional[catalog.TableVersion]
|
|
18
|
+
row_builder: exprs.RowBuilder
|
|
19
|
+
img_slot_idxs: list[int]
|
|
20
|
+
media_slot_idxs: list[int] # non-image media slots
|
|
21
|
+
array_slot_idxs: list[int]
|
|
22
|
+
rows: list[exprs.DataRow]
|
|
23
|
+
|
|
17
24
|
def __init__(self, tbl: Optional[catalog.TableVersion], row_builder: exprs.RowBuilder, len: int = 0):
|
|
18
25
|
self.tbl = tbl
|
|
19
26
|
self.row_builder = row_builder
|
|
@@ -39,17 +46,10 @@ class DataRowBatch:
|
|
|
39
46
|
def pop_row(self) -> exprs.DataRow:
|
|
40
47
|
return self.rows.pop()
|
|
41
48
|
|
|
42
|
-
def set_row_ids(self, row_ids: List[int]) -> None:
|
|
43
|
-
"""Sets pks for rows in batch"""
|
|
44
|
-
assert self.tbl is not None
|
|
45
|
-
assert len(row_ids) == len(self.rows)
|
|
46
|
-
for row, row_id in zip(self.rows, row_ids):
|
|
47
|
-
row.set_pk((row_id, self.tbl))
|
|
48
|
-
|
|
49
49
|
def __len__(self) -> int:
|
|
50
50
|
return len(self.rows)
|
|
51
51
|
|
|
52
|
-
def __getitem__(self, index:
|
|
52
|
+
def __getitem__(self, index: int) -> exprs.DataRow:
|
|
53
53
|
return self.rows[index]
|
|
54
54
|
|
|
55
55
|
def flush_imgs(
|
|
@@ -74,21 +74,4 @@ class DataRowBatch:
|
|
|
74
74
|
row.flush_img(slot_idx)
|
|
75
75
|
|
|
76
76
|
def __iter__(self) -> Iterator[exprs.DataRow]:
|
|
77
|
-
return
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
class DataRowBatchIterator:
|
|
81
|
-
"""
|
|
82
|
-
Iterator over a DataRowBatch.
|
|
83
|
-
"""
|
|
84
|
-
def __init__(self, batch: DataRowBatch):
|
|
85
|
-
self.row_batch = batch
|
|
86
|
-
self.index = 0
|
|
87
|
-
|
|
88
|
-
def __next__(self) -> exprs.DataRow:
|
|
89
|
-
if self.index >= len(self.row_batch.rows):
|
|
90
|
-
raise StopIteration
|
|
91
|
-
row = self.row_batch.rows[self.index]
|
|
92
|
-
self.index += 1
|
|
93
|
-
return row
|
|
94
|
-
|
|
77
|
+
return iter(self.rows)
|