pixeltable 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (120)
  1. pixeltable/__init__.py +7 -19
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +7 -7
  4. pixeltable/catalog/column.py +37 -11
  5. pixeltable/catalog/globals.py +21 -0
  6. pixeltable/catalog/insertable_table.py +6 -4
  7. pixeltable/catalog/table.py +227 -148
  8. pixeltable/catalog/table_version.py +66 -28
  9. pixeltable/catalog/table_version_path.py +0 -8
  10. pixeltable/catalog/view.py +18 -19
  11. pixeltable/dataframe.py +16 -32
  12. pixeltable/env.py +6 -1
  13. pixeltable/exec/__init__.py +1 -2
  14. pixeltable/exec/aggregation_node.py +27 -17
  15. pixeltable/exec/cache_prefetch_node.py +1 -1
  16. pixeltable/exec/data_row_batch.py +9 -26
  17. pixeltable/exec/exec_node.py +36 -7
  18. pixeltable/exec/expr_eval_node.py +19 -11
  19. pixeltable/exec/in_memory_data_node.py +14 -11
  20. pixeltable/exec/sql_node.py +266 -138
  21. pixeltable/exprs/__init__.py +1 -0
  22. pixeltable/exprs/arithmetic_expr.py +3 -1
  23. pixeltable/exprs/array_slice.py +7 -7
  24. pixeltable/exprs/column_property_ref.py +37 -10
  25. pixeltable/exprs/column_ref.py +93 -14
  26. pixeltable/exprs/comparison.py +5 -5
  27. pixeltable/exprs/compound_predicate.py +8 -7
  28. pixeltable/exprs/data_row.py +56 -36
  29. pixeltable/exprs/expr.py +65 -63
  30. pixeltable/exprs/expr_dict.py +55 -0
  31. pixeltable/exprs/expr_set.py +26 -15
  32. pixeltable/exprs/function_call.py +53 -24
  33. pixeltable/exprs/globals.py +4 -1
  34. pixeltable/exprs/in_predicate.py +8 -7
  35. pixeltable/exprs/inline_expr.py +4 -4
  36. pixeltable/exprs/is_null.py +4 -4
  37. pixeltable/exprs/json_mapper.py +11 -12
  38. pixeltable/exprs/json_path.py +5 -10
  39. pixeltable/exprs/literal.py +5 -5
  40. pixeltable/exprs/method_ref.py +5 -4
  41. pixeltable/exprs/object_ref.py +2 -1
  42. pixeltable/exprs/row_builder.py +88 -36
  43. pixeltable/exprs/rowid_ref.py +14 -13
  44. pixeltable/exprs/similarity_expr.py +12 -7
  45. pixeltable/exprs/sql_element_cache.py +12 -6
  46. pixeltable/exprs/type_cast.py +8 -6
  47. pixeltable/exprs/variable.py +5 -4
  48. pixeltable/ext/functions/whisperx.py +7 -2
  49. pixeltable/func/aggregate_function.py +1 -1
  50. pixeltable/func/callable_function.py +2 -2
  51. pixeltable/func/function.py +11 -10
  52. pixeltable/func/function_registry.py +6 -7
  53. pixeltable/func/query_template_function.py +11 -12
  54. pixeltable/func/signature.py +17 -15
  55. pixeltable/func/udf.py +0 -4
  56. pixeltable/functions/__init__.py +2 -2
  57. pixeltable/functions/audio.py +4 -6
  58. pixeltable/functions/globals.py +84 -42
  59. pixeltable/functions/huggingface.py +31 -34
  60. pixeltable/functions/image.py +59 -45
  61. pixeltable/functions/json.py +0 -1
  62. pixeltable/functions/llama_cpp.py +106 -0
  63. pixeltable/functions/mistralai.py +2 -2
  64. pixeltable/functions/ollama.py +147 -0
  65. pixeltable/functions/openai.py +22 -25
  66. pixeltable/functions/replicate.py +72 -0
  67. pixeltable/functions/string.py +59 -50
  68. pixeltable/functions/timestamp.py +20 -20
  69. pixeltable/functions/together.py +2 -2
  70. pixeltable/functions/video.py +11 -20
  71. pixeltable/functions/whisper.py +2 -20
  72. pixeltable/globals.py +65 -74
  73. pixeltable/index/base.py +2 -2
  74. pixeltable/index/btree.py +20 -7
  75. pixeltable/index/embedding_index.py +12 -14
  76. pixeltable/io/__init__.py +1 -2
  77. pixeltable/io/external_store.py +11 -5
  78. pixeltable/io/fiftyone.py +178 -0
  79. pixeltable/io/globals.py +98 -2
  80. pixeltable/io/hf_datasets.py +1 -1
  81. pixeltable/io/label_studio.py +6 -6
  82. pixeltable/io/parquet.py +14 -13
  83. pixeltable/iterators/base.py +3 -2
  84. pixeltable/iterators/document.py +10 -8
  85. pixeltable/iterators/video.py +126 -60
  86. pixeltable/metadata/__init__.py +4 -3
  87. pixeltable/metadata/converters/convert_14.py +4 -2
  88. pixeltable/metadata/converters/convert_15.py +1 -1
  89. pixeltable/metadata/converters/convert_19.py +1 -0
  90. pixeltable/metadata/converters/convert_20.py +1 -1
  91. pixeltable/metadata/converters/convert_21.py +34 -0
  92. pixeltable/metadata/converters/util.py +54 -12
  93. pixeltable/metadata/notes.py +1 -0
  94. pixeltable/metadata/schema.py +40 -21
  95. pixeltable/plan.py +149 -165
  96. pixeltable/py.typed +0 -0
  97. pixeltable/store.py +57 -37
  98. pixeltable/tool/create_test_db_dump.py +6 -6
  99. pixeltable/tool/create_test_video.py +1 -1
  100. pixeltable/tool/doc_plugins/griffe.py +3 -34
  101. pixeltable/tool/embed_udf.py +1 -1
  102. pixeltable/tool/mypy_plugin.py +55 -0
  103. pixeltable/type_system.py +260 -61
  104. pixeltable/utils/arrow.py +10 -9
  105. pixeltable/utils/coco.py +4 -4
  106. pixeltable/utils/documents.py +16 -2
  107. pixeltable/utils/filecache.py +9 -9
  108. pixeltable/utils/formatter.py +10 -11
  109. pixeltable/utils/http_server.py +2 -5
  110. pixeltable/utils/media_store.py +6 -6
  111. pixeltable/utils/pytorch.py +10 -11
  112. pixeltable/utils/sql.py +2 -1
  113. {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/METADATA +50 -13
  114. pixeltable-0.2.22.dist-info/RECORD +153 -0
  115. pixeltable/exec/media_validation_node.py +0 -43
  116. pixeltable/utils/help.py +0 -11
  117. pixeltable-0.2.20.dist-info/RECORD +0 -147
  118. {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/LICENSE +0 -0
  119. {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/WHEEL +0 -0
  120. {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/entry_points.txt +0 -0
@@ -6,7 +6,7 @@ import inspect
6
6
  import logging
7
7
  import time
8
8
  import uuid
9
- from typing import TYPE_CHECKING, Any, Iterable, Optional
9
+ from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, Optional
10
10
  from uuid import UUID
11
11
 
12
12
  import sqlalchemy as sql
@@ -26,7 +26,7 @@ from pixeltable.utils.media_store import MediaStore
26
26
 
27
27
  from ..func.globals import resolve_symbol
28
28
  from .column import Column
29
- from .globals import _POS_COLUMN_NAME, _ROWID_COLUMN_NAME, UpdateStatus, is_valid_identifier
29
+ from .globals import _POS_COLUMN_NAME, _ROWID_COLUMN_NAME, UpdateStatus, is_valid_identifier, MediaValidation
30
30
 
31
31
  if TYPE_CHECKING:
32
32
  from pixeltable import exec, store
@@ -53,6 +53,7 @@ class TableVersion:
53
53
  name: str
54
54
  version: int
55
55
  comment: str
56
+ media_validation: MediaValidation
56
57
  num_retained_versions: int
57
58
  schema_version: int
58
59
  view_md: Optional[schema.ViewMd]
@@ -109,6 +110,7 @@ class TableVersion:
109
110
  self.view_md = tbl_md.view_md # save this as-is, it's needed for _create_md()
110
111
  is_view = tbl_md.view_md is not None
111
112
  self.is_snapshot = (is_view and tbl_md.view_md.is_snapshot) or bool(is_snapshot)
113
+ self.media_validation = MediaValidation[schema_version_md.media_validation.upper()]
112
114
  # a mutable TableVersion doesn't have a static version
113
115
  self.effective_version = self.version if self.is_snapshot else None
114
116
 
@@ -182,7 +184,7 @@ class TableVersion:
182
184
  @classmethod
183
185
  def create(
184
186
  cls, session: orm.Session, dir_id: UUID, name: str, cols: list[Column], num_retained_versions: int,
185
- comment: str, base_path: Optional[pxt.catalog.TableVersionPath] = None,
187
+ comment: str, media_validation: MediaValidation, base_path: Optional[pxt.catalog.TableVersionPath] = None,
186
188
  view_md: Optional[schema.ViewMd] = None
187
189
  ) -> tuple[UUID, Optional[TableVersion]]:
188
190
  # assign ids
@@ -214,11 +216,17 @@ class TableVersion:
214
216
  tbl_id=tbl_record.id, version=0, md=dataclasses.asdict(table_version_md))
215
217
 
216
218
  # create schema.TableSchemaVersion
217
- schema_col_md = {col.id: schema.SchemaColumn(pos=pos, name=col.name) for pos, col in enumerate(cols)}
219
+ schema_col_md: dict[int, schema.SchemaColumn] = {}
220
+ for pos, col in enumerate(cols):
221
+ md = schema.SchemaColumn(
222
+ pos=pos, name=col.name,
223
+ media_validation=col._media_validation.name.lower() if col._media_validation is not None else None)
224
+ schema_col_md[col.id] = md
218
225
 
219
226
  schema_version_md = schema.TableSchemaVersionMd(
220
227
  schema_version=0, preceding_schema_version=None, columns=schema_col_md,
221
- num_retained_versions=num_retained_versions, comment=comment)
228
+ num_retained_versions=num_retained_versions, comment=comment,
229
+ media_validation=media_validation.name.lower())
222
230
  schema_version_record = schema.TableSchemaVersion(
223
231
  tbl_id=tbl_record.id, schema_version=0, md=dataclasses.asdict(schema_version_md))
224
232
 
@@ -285,10 +293,15 @@ class TableVersion:
285
293
  self.cols_by_name = {}
286
294
  self.cols_by_id = {}
287
295
  for col_md in tbl_md.column_md.values():
288
- col_name = schema_version_md.columns[col_md.id].name if col_md.id in schema_version_md.columns else None
296
+ schema_col_md = schema_version_md.columns[col_md.id] if col_md.id in schema_version_md.columns else None
297
+ col_name = schema_col_md.name if schema_col_md is not None else None
298
+ media_val = (
299
+ MediaValidation[schema_col_md.media_validation.upper()]
300
+ if schema_col_md is not None and schema_col_md.media_validation is not None else None
301
+ )
289
302
  col = Column(
290
303
  col_id=col_md.id, name=col_name, col_type=ts.ColumnType.from_dict(col_md.col_type),
291
- is_pk=col_md.is_pk, stored=col_md.stored,
304
+ is_pk=col_md.is_pk, stored=col_md.stored, media_validation=media_val,
292
305
  schema_version_add=col_md.schema_version_add, schema_version_drop=col_md.schema_version_drop,
293
306
  value_expr_dict=col_md.value_expr)
294
307
  col.tbl = self
@@ -349,7 +362,8 @@ class TableVersion:
349
362
  self.store_tbl = StoreTable(self)
350
363
 
351
364
  def _update_md(
352
- self, timestamp: float, conn: sql.engine.Connection, update_tbl_version: bool = True, preceding_schema_version: Optional[int] = None
365
+ self, timestamp: float, conn: sql.engine.Connection, update_tbl_version: bool = True,
366
+ preceding_schema_version: Optional[int] = None
353
367
  ) -> None:
354
368
  """Writes table metadata to the database.
355
369
 
@@ -453,7 +467,9 @@ class TableVersion:
453
467
  self.idxs_by_name[idx_name] = idx_info
454
468
 
455
469
  # add the columns and update the metadata
456
- status = self._add_columns([val_col, undo_col], conn)
470
+ # TODO support on_error='abort' for indices; it's tricky because of the way metadata changes are entangled
471
+ # with the database operations
472
+ status = self._add_columns([val_col, undo_col], conn, print_stats=False, on_error='ignore')
457
473
  # now create the index structure
458
474
  idx.create_index(self._store_idx_name(idx_id), val_col, conn)
459
475
 
@@ -478,7 +494,7 @@ class TableVersion:
478
494
  self._update_md(time.time(), conn, preceding_schema_version=preceding_schema_version)
479
495
  _logger.info(f'Dropped index {idx_md.name} on table {self.name}')
480
496
 
481
- def add_column(self, col: Column, print_stats: bool = False) -> UpdateStatus:
497
+ def add_column(self, col: Column, print_stats: bool, on_error: Literal['abort', 'ignore']) -> UpdateStatus:
482
498
  """Adds a column to the table.
483
499
  """
484
500
  assert not self.is_snapshot
@@ -498,9 +514,8 @@ class TableVersion:
498
514
  preceding_schema_version = self.schema_version
499
515
  self.schema_version = self.version
500
516
  with Env.get().engine.begin() as conn:
501
- status = self._add_columns([col], conn, print_stats=print_stats)
517
+ status = self._add_columns([col], conn, print_stats=print_stats, on_error=on_error)
502
518
  _ = self._add_default_index(col, conn)
503
- # TODO: what to do about errors?
504
519
  self._update_md(time.time(), conn, preceding_schema_version=preceding_schema_version)
505
520
  _logger.info(f'Added column {col.name} to table {self.name}, new version: {self.version}')
506
521
 
@@ -512,7 +527,13 @@ class TableVersion:
512
527
  _logger.info(f'Column {col.name}: {msg}')
513
528
  return status
514
529
 
515
- def _add_columns(self, cols: Iterable[Column], conn: sql.engine.Connection, print_stats: bool = False) -> UpdateStatus:
530
+ def _add_columns(
531
+ self,
532
+ cols: Iterable[Column],
533
+ conn: sql.engine.Connection,
534
+ print_stats: bool,
535
+ on_error: Literal['abort', 'ignore']
536
+ ) -> UpdateStatus:
516
537
  """Add and populate columns within the current transaction"""
517
538
  cols = list(cols)
518
539
  row_count = self.store_tbl.count(conn=conn)
@@ -550,10 +571,14 @@ class TableVersion:
550
571
  try:
551
572
  plan.ctx.set_conn(conn)
552
573
  plan.open()
553
- num_excs = self.store_tbl.load_column(col, plan, value_expr_slot_idx, conn)
574
+ try:
575
+ num_excs = self.store_tbl.load_column(col, plan, value_expr_slot_idx, conn, on_error)
576
+ except sql.exc.DBAPIError as exc:
577
+ # Wrap the DBAPIError in an excs.Error to unify processing in the subsequent except block
578
+ raise excs.Error(f'SQL error during execution of computed column `{col.name}`:\n{exc}') from exc
554
579
  if num_excs > 0:
555
580
  cols_with_excs.append(col)
556
- except sql.exc.DBAPIError as e:
581
+ except excs.Error as exc:
557
582
  self.cols.pop()
558
583
  for col in cols:
559
584
  # remove columns that we already added
@@ -564,7 +589,7 @@ class TableVersion:
564
589
  del self.cols_by_id[col.id]
565
590
  # we need to re-initialize the sqlalchemy schema
566
591
  self.store_tbl.create_sa_tbl()
567
- raise excs.Error(f'Error during SQL execution:\n{e}')
592
+ raise exc
568
593
  finally:
569
594
  plan.close()
570
595
 
@@ -689,21 +714,32 @@ class TableVersion:
689
714
  plan = Planner.create_insert_plan(self, rows, ignore_errors=not fail_on_exception)
690
715
  else:
691
716
  plan = Planner.create_df_insert_plan(self, df, ignore_errors=not fail_on_exception)
717
+
718
+ # this is a base table; we generate rowids during the insert
719
+ def rowids() -> Iterator[int]:
720
+ while True:
721
+ rowid = self.next_rowid
722
+ self.next_rowid += 1
723
+ yield rowid
724
+
692
725
  if conn is None:
693
726
  with Env.get().engine.begin() as conn:
694
- return self._insert(plan, conn, time.time(), print_stats)
727
+ return self._insert(
728
+ plan, conn, time.time(), print_stats=print_stats, rowids=rowids(), abort_on_exc=fail_on_exception)
695
729
  else:
696
- return self._insert(plan, conn, time.time(), print_stats)
730
+ return self._insert(
731
+ plan, conn, time.time(), print_stats=print_stats, rowids=rowids(), abort_on_exc=fail_on_exception)
697
732
 
698
733
  def _insert(
699
- self, exec_plan: 'exec.ExecNode', conn: sql.engine.Connection, timestamp: float, print_stats: bool = False,
734
+ self, exec_plan: 'exec.ExecNode', conn: sql.engine.Connection, timestamp: float, *,
735
+ rowids: Optional[Iterator[int]] = None, print_stats: bool = False, abort_on_exc: bool = False
700
736
  ) -> UpdateStatus:
701
737
  """Insert rows produced by exec_plan and propagate to views"""
702
738
  # we're creating a new version
703
739
  self.version += 1
704
740
  result = UpdateStatus()
705
- num_rows, num_excs, cols_with_excs = self.store_tbl.insert_rows(exec_plan, conn, v_min=self.version)
706
- self.next_rowid = num_rows
741
+ num_rows, num_excs, cols_with_excs = self.store_tbl.insert_rows(
742
+ exec_plan, conn, v_min=self.version, rowids=rowids, abort_on_exc=abort_on_exc)
707
743
  result.num_rows = num_rows
708
744
  result.num_excs = num_excs
709
745
  result.num_computed_values += exec_plan.ctx.num_computed_exprs * num_rows
@@ -714,7 +750,7 @@ class TableVersion:
714
750
  for view in self.mutable_views:
715
751
  from pixeltable.plan import Planner
716
752
  plan, _ = Planner.create_view_load_plan(view.path, propagates_insert=True)
717
- status = view._insert(plan, conn, timestamp, print_stats)
753
+ status = view._insert(plan, conn, timestamp, print_stats=print_stats)
718
754
  result.num_rows += status.num_rows
719
755
  result.num_excs += status.num_excs
720
756
  result.num_computed_values += status.num_computed_values
@@ -751,9 +787,7 @@ class TableVersion:
751
787
  raise excs.Error(f'Filter {analysis_info.filter} not expressible in SQL')
752
788
 
753
789
  with Env.get().engine.begin() as conn:
754
- plan, updated_cols, recomputed_cols = (
755
- Planner.create_update_plan(self.path, update_spec, [], where, cascade)
756
- )
790
+ plan, updated_cols, recomputed_cols = Planner.create_update_plan(self.path, update_spec, [], where, cascade)
757
791
  from pixeltable.exprs import SqlElementCache
758
792
  result = self.propagate_update(
759
793
  plan, where.sql_expr(SqlElementCache()) if where is not None else None, recomputed_cols,
@@ -1185,7 +1219,8 @@ class TableVersion:
1185
1219
  name=self.name, current_version=self.version, current_schema_version=self.schema_version,
1186
1220
  next_col_id=self.next_col_id, next_idx_id=self.next_idx_id, next_row_id=self.next_rowid,
1187
1221
  column_md=self._create_column_md(self.cols), index_md=self.idx_md,
1188
- external_stores=self._create_stores_md(self.external_stores.values()), view_md=self.view_md)
1222
+ external_stores=self._create_stores_md(self.external_stores.values()), view_md=self.view_md,
1223
+ )
1189
1224
 
1190
1225
  def _create_version_md(self, timestamp: float) -> schema.TableVersionMd:
1191
1226
  return schema.TableVersionMd(created_at=timestamp, version=self.version, schema_version=self.schema_version)
@@ -1193,11 +1228,14 @@ class TableVersion:
1193
1228
  def _create_schema_version_md(self, preceding_schema_version: int) -> schema.TableSchemaVersionMd:
1194
1229
  column_md: dict[int, schema.SchemaColumn] = {}
1195
1230
  for pos, col in enumerate(self.cols_by_name.values()):
1196
- column_md[col.id] = schema.SchemaColumn(pos=pos, name=col.name)
1231
+ column_md[col.id] = schema.SchemaColumn(
1232
+ pos=pos, name=col.name,
1233
+ media_validation=col._media_validation.name.lower() if col._media_validation is not None else None)
1197
1234
  # preceding_schema_version to be set by the caller
1198
1235
  return schema.TableSchemaVersionMd(
1199
1236
  schema_version=self.schema_version, preceding_schema_version=preceding_schema_version,
1200
- columns=column_md, num_retained_versions=self.num_retained_versions, comment=self.comment)
1237
+ columns=column_md, num_retained_versions=self.num_retained_versions, comment=self.comment,
1238
+ media_validation=self.media_validation.name.lower())
1201
1239
 
1202
1240
  def as_dict(self) -> dict:
1203
1241
  return {'id': str(self.id), 'effective_version': self.effective_version}
@@ -91,14 +91,6 @@ class TableVersionPath:
91
91
  col = self.tbl_version.cols_by_name[col_name]
92
92
  return ColumnRef(col)
93
93
 
94
- def __getitem__(self, index: object) -> Union[exprs.ColumnRef, pxt.DataFrame]:
95
- """Return a ColumnRef for the given column name, or a DataFrame for the given slice.
96
- """
97
- if isinstance(index, str):
98
- # basically <tbl>.<colname>
99
- return self.__getattr__(index)
100
- return pxt.DataFrame(self).__getitem__(index)
101
-
102
94
  def columns(self) -> list[Column]:
103
95
  """Return all user columns visible in this tbl version path, including columns from bases"""
104
96
  result = list(self.tbl_version.cols_by_name.values())
@@ -2,24 +2,21 @@ from __future__ import annotations
2
2
 
3
3
  import inspect
4
4
  import logging
5
- from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Type
5
+ from typing import TYPE_CHECKING, Any, Iterable, Optional
6
6
  from uuid import UUID
7
7
 
8
8
  import sqlalchemy.orm as orm
9
9
 
10
- import pixeltable.catalog as catalog
11
10
  import pixeltable.exceptions as excs
12
- import pixeltable.exprs as exprs
13
- import pixeltable.func as func
14
11
  import pixeltable.metadata.schema as md_schema
12
+ import pixeltable.type_system as ts
13
+ from pixeltable import catalog, exprs, func
15
14
  from pixeltable.env import Env
16
- from pixeltable.exceptions import Error
17
15
  from pixeltable.iterators import ComponentIterator
18
- from pixeltable.type_system import IntType, InvalidType
19
16
 
20
17
  from .catalog import Catalog
21
18
  from .column import Column
22
- from .globals import _POS_COLUMN_NAME, UpdateStatus
19
+ from .globals import _POS_COLUMN_NAME, UpdateStatus, MediaValidation
23
20
  from .table import Table
24
21
  from .table_version import TableVersion
25
22
  from .table_version_path import TableVersionPath
@@ -52,11 +49,12 @@ class View(Table):
52
49
 
53
50
  @classmethod
54
51
  def _create(
55
- cls, dir_id: UUID, name: str, base: TableVersionPath, schema: Dict[str, Any],
56
- predicate: 'pxt.exprs.Expr', is_snapshot: bool, num_retained_versions: int, comment: str,
57
- iterator_cls: Optional[Type[ComponentIterator]], iterator_args: Optional[Dict]
52
+ cls, dir_id: UUID, name: str, base: TableVersionPath, additional_columns: dict[str, Any],
53
+ predicate: Optional['pxt.exprs.Expr'], is_snapshot: bool, num_retained_versions: int, comment: str,
54
+ media_validation: MediaValidation,
55
+ iterator_cls: Optional[type[ComponentIterator]], iterator_args: Optional[dict]
58
56
  ) -> View:
59
- columns = cls._create_columns(schema)
57
+ columns = cls._create_columns(additional_columns)
60
58
  cls._verify_schema(columns)
61
59
 
62
60
  # verify that filter can be evaluated in the context of the base
@@ -92,17 +90,17 @@ class View(Table):
92
90
  func.Parameter(param_name, param_type, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
93
91
  for param_name, param_type in iterator_cls.input_schema().items()
94
92
  ]
95
- sig = func.Signature(InvalidType(), params)
93
+ sig = func.Signature(ts.InvalidType(), params)
96
94
  from pixeltable.exprs import FunctionCall
97
95
  FunctionCall.normalize_args(iterator_cls.__name__, sig, bound_args)
98
96
  except TypeError as e:
99
- raise Error(f'Cannot instantiate iterator with given arguments: {e}')
97
+ raise excs.Error(f'Cannot instantiate iterator with given arguments: {e}')
100
98
 
101
99
  # prepend pos and output_schema columns to cols:
102
100
  # a component view exposes the pos column of its rowid;
103
101
  # we create that column here, so it gets assigned a column id;
104
102
  # stored=False: it is not stored separately (it's already stored as part of the rowid)
105
- iterator_cols = [Column(_POS_COLUMN_NAME, IntType(), stored=False)]
103
+ iterator_cols = [Column(_POS_COLUMN_NAME, ts.IntType(), stored=False)]
106
104
  output_dict, unstored_cols = iterator_cls.output_schema(**bound_args)
107
105
  iterator_cols.extend([
108
106
  Column(col_name, col_type, stored=col_name not in unstored_cols)
@@ -112,12 +110,12 @@ class View(Table):
112
110
  iterator_col_names = {col.name for col in iterator_cols}
113
111
  for col in columns:
114
112
  if col.name in iterator_col_names:
115
- raise Error(f'Duplicate name: column {col.name} is already present in the iterator output schema')
113
+ raise excs.Error(f'Duplicate name: column {col.name} is already present in the iterator output schema')
116
114
  columns = iterator_cols + columns
117
115
 
118
116
  with orm.Session(Env.get().engine, future=True) as session:
119
117
  from pixeltable.exprs import InlineDict
120
- iterator_args_expr = InlineDict(iterator_args) if iterator_args is not None else None
118
+ iterator_args_expr: exprs.Expr = InlineDict(iterator_args) if iterator_args is not None else None
121
119
  iterator_class_fqn = f'{iterator_cls.__module__}.{iterator_cls.__name__}' if iterator_cls is not None \
122
120
  else None
123
121
  base_version_path = cls._get_snapshot_path(base) if is_snapshot else base
@@ -142,7 +140,8 @@ class View(Table):
142
140
  iterator_args=iterator_args_expr.as_dict() if iterator_args_expr is not None else None)
143
141
 
144
142
  id, tbl_version = TableVersion.create(
145
- session, dir_id, name, columns, num_retained_versions, comment, base_path=base_version_path, view_md=view_md)
143
+ session, dir_id, name, columns, num_retained_versions, comment, media_validation=media_validation,
144
+ base_path=base_version_path, view_md=view_md)
146
145
  if tbl_version is None:
147
146
  # this is purely a snapshot: we use the base's tbl version path
148
147
  view = cls(id, dir_id, name, base_version_path, base.tbl_id(), snapshot_only=True)
@@ -168,11 +167,11 @@ class View(Table):
168
167
 
169
168
  @classmethod
170
169
  def _verify_column(
171
- cls, col: Column, existing_column_names: Set[str], existing_query_names: Optional[Set[str]] = None
170
+ cls, col: Column, existing_column_names: set[str], existing_query_names: Optional[set[str]] = None
172
171
  ) -> None:
173
172
  # make sure that columns are nullable or have a default
174
173
  if not col.col_type.nullable and not col.is_computed:
175
- raise Error(f'Column {col.name}: non-computed columns in views must be nullable')
174
+ raise excs.Error(f'Column {col.name}: non-computed columns in views must be nullable')
176
175
  super()._verify_column(col, existing_column_names, existing_query_names)
177
176
 
178
177
  @classmethod
pixeltable/dataframe.py CHANGED
@@ -8,7 +8,7 @@ import logging
8
8
  import mimetypes
9
9
  import traceback
10
10
  from pathlib import Path
11
- from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, Iterator, List, Optional, Set, Tuple
11
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, Iterator, List, Optional, Sequence, Set, Tuple, Union
12
12
 
13
13
  import pandas as pd
14
14
  import pandas.io.formats.style
@@ -97,8 +97,8 @@ class DataFrameResultSet:
97
97
  return self._rows[index[0]][col_idx]
98
98
  raise excs.Error(f'Bad index: {index}')
99
99
 
100
- def __iter__(self) -> DataFrameResultSetIterator:
101
- return DataFrameResultSetIterator(self)
100
+ def __iter__(self) -> Iterator[dict[str, Any]]:
101
+ return (self._row_to_dict(i) for i in range(len(self)))
102
102
 
103
103
  def __eq__(self, other):
104
104
  if not isinstance(other, DataFrameResultSet):
@@ -106,19 +106,6 @@ class DataFrameResultSet:
106
106
  return self.to_pandas().equals(other.to_pandas())
107
107
 
108
108
 
109
- class DataFrameResultSetIterator:
110
- def __init__(self, result_set: DataFrameResultSet):
111
- self._result_set = result_set
112
- self._idx = 0
113
-
114
- def __next__(self) -> Dict[str, Any]:
115
- if self._idx >= len(self._result_set):
116
- raise StopIteration
117
- row = self._result_set._row_to_dict(self._idx)
118
- self._idx += 1
119
- return row
120
-
121
-
122
109
  # # TODO: remove this; it's only here as a reminder that we still need to call release() in the current implementation
123
110
  # class AnalysisInfo:
124
111
  # def __init__(self, tbl: catalog.TableVersion):
@@ -296,7 +283,7 @@ class DataFrame:
296
283
 
297
284
  def _create_query_plan(self) -> exec.ExecNode:
298
285
  # construct a group-by clause if we're grouping by a table
299
- group_by_clause: List[exprs.Expr] = []
286
+ group_by_clause: Optional[list[exprs.Expr]] = None
300
287
  if self.grouping_tbl is not None:
301
288
  assert self.group_by_clause is None
302
289
  num_rowid_cols = len(self.grouping_tbl.store_tbl.rowid_columns())
@@ -315,8 +302,8 @@ class DataFrame:
315
302
  where_clause=self.where_clause,
316
303
  group_by_clause=group_by_clause,
317
304
  order_by_clause=self.order_by_clause if self.order_by_clause is not None else [],
318
- limit=self.limit_val if self.limit_val is not None else 0,
319
- ) # limit_val == 0: no limit_val
305
+ limit=self.limit_val
306
+ )
320
307
 
321
308
 
322
309
  def show(self, n: int = 20) -> DataFrameResultSet:
@@ -384,15 +371,10 @@ class DataFrame:
384
371
  group_by_clause=group_by_clause, grouping_tbl=self.grouping_tbl,
385
372
  order_by_clause=order_by_clause, limit=self.limit_val)
386
373
 
387
- def collect(self) -> DataFrameResultSet:
388
- return self._collect()
389
-
390
- def _collect(self, conn: Optional[sql.engine.Connection] = None) -> DataFrameResultSet:
374
+ def _output_row_iterator(self, conn: Optional[sql.engine.Connection] = None) -> Iterator[list]:
391
375
  try:
392
- result_rows = []
393
376
  for data_row in self._exec(conn):
394
- result_row = [data_row[e.slot_idx] for e in self._select_list_exprs]
395
- result_rows.append(result_row)
377
+ yield [data_row[e.slot_idx] for e in self._select_list_exprs]
396
378
  except excs.ExprEvalError as e:
397
379
  msg = f'In row {e.row_num} the {e.expr_msg} encountered exception ' f'{type(e.exc).__name__}:\n{str(e.exc)}'
398
380
  if len(e.input_vals) > 0:
@@ -412,7 +394,11 @@ class DataFrame:
412
394
  except sql.exc.DBAPIError as e:
413
395
  raise excs.Error(f'Error during SQL execution:\n{e}')
414
396
 
415
- return DataFrameResultSet(result_rows, self.schema)
397
+ def collect(self) -> DataFrameResultSet:
398
+ return self._collect()
399
+
400
+ def _collect(self, conn: Optional[sql.engine.Connection] = None) -> DataFrameResultSet:
401
+ return DataFrameResultSet(list(self._output_row_iterator(conn)), self.schema)
416
402
 
417
403
  def count(self) -> int:
418
404
  from pixeltable.plan import Planner
@@ -629,17 +615,15 @@ class DataFrame:
629
615
  if self.limit_val is not None:
630
616
  raise excs.Error(f'Cannot use `{op_name}` after `limit`')
631
617
 
632
- def __getitem__(self, index: object) -> DataFrame:
618
+ def __getitem__(self, index: Union[exprs.Expr, Sequence[exprs.Expr]]) -> DataFrame:
633
619
  """
634
620
  Allowed:
635
621
  - [List[Expr]]/[Tuple[Expr]]: setting the select list
636
622
  - [Expr]: setting a single-col select list
637
623
  """
638
- if isinstance(index, tuple):
639
- index = list(index)
640
624
  if isinstance(index, exprs.Expr):
641
- index = [index]
642
- if isinstance(index, list):
625
+ return self.select(index)
626
+ if isinstance(index, Sequence):
643
627
  return self.select(*index)
644
628
  raise TypeError(f'Invalid index type: {type(index)}')
645
629
 
pixeltable/env.py CHANGED
@@ -342,7 +342,7 @@ class Env:
342
342
 
343
343
  if create_db:
344
344
  from pixeltable.metadata import schema
345
- schema.Base.metadata.create_all(self._sa_engine)
345
+ schema.base_metadata.create_all(self._sa_engine)
346
346
  metadata.create_system_info(self._sa_engine)
347
347
 
348
348
  print(f'Connected to Pixeltable database at: {self.db_url}')
@@ -494,13 +494,18 @@ class Env:
494
494
  self.__register_package('anthropic')
495
495
  self.__register_package('boto3')
496
496
  self.__register_package('datasets')
497
+ self.__register_package('fiftyone')
497
498
  self.__register_package('fireworks', library_name='fireworks-ai')
499
+ self.__register_package('huggingface_hub', library_name='huggingface-hub')
498
500
  self.__register_package('label_studio_sdk', library_name='label-studio-sdk')
501
+ self.__register_package('llama_cpp', library_name='llama-cpp-python')
499
502
  self.__register_package('mistralai')
500
503
  self.__register_package('mistune')
504
+ self.__register_package('ollama')
501
505
  self.__register_package('openai')
502
506
  self.__register_package('openpyxl')
503
507
  self.__register_package('pyarrow')
508
+ self.__register_package('replicate')
504
509
  self.__register_package('sentence_transformers', library_name='sentence-transformers')
505
510
  self.__register_package('spacy')
506
511
  self.__register_package('tiktoken')
@@ -6,6 +6,5 @@ from .exec_context import ExecContext
6
6
  from .exec_node import ExecNode
7
7
  from .expr_eval_node import ExprEvalNode
8
8
  from .in_memory_data_node import InMemoryDataNode
9
- from .media_validation_node import MediaValidationNode
10
9
  from .row_update_node import RowUpdateNode
11
- from .sql_node import SqlLookupNode, SqlScanNode
10
+ from .sql_node import SqlLookupNode, SqlScanNode, SqlAggregationNode, SqlNode
@@ -2,28 +2,43 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  import sys
5
- from typing import Iterable, List, Optional, Any
5
+ from typing import Any, Iterable, Iterator, Optional, cast
6
6
 
7
7
  import pixeltable.catalog as catalog
8
8
  import pixeltable.exceptions as excs
9
9
  import pixeltable.exprs as exprs
10
+
10
11
  from .data_row_batch import DataRowBatch
11
12
  from .exec_node import ExecNode
12
13
 
13
14
  _logger = logging.getLogger('pixeltable')
14
15
 
15
16
  class AggregationNode(ExecNode):
17
+ """
18
+ In-memory aggregation for UDAs.
19
+
20
+ At the moment, this returns all results in a single DataRowBatch.
21
+ """
22
+ group_by: Optional[list[exprs.Expr]]
23
+ input_exprs: list[exprs.Expr]
24
+ agg_fn_eval_ctx: exprs.RowBuilder.EvalCtx
25
+ agg_fn_calls: list[exprs.FunctionCall]
26
+ output_batch: DataRowBatch
27
+
16
28
  def __init__(
17
- self, tbl: catalog.TableVersion, row_builder: exprs.RowBuilder, group_by: List[exprs.Expr],
18
- agg_fn_calls: List[exprs.FunctionCall], input_exprs: Iterable[exprs.Expr], input: ExecNode
29
+ self, tbl: catalog.TableVersion, row_builder: exprs.RowBuilder, group_by: Optional[list[exprs.Expr]],
30
+ agg_fn_calls: list[exprs.FunctionCall], input_exprs: Iterable[exprs.Expr], input: ExecNode
19
31
  ):
20
- super().__init__(row_builder, group_by + agg_fn_calls, input_exprs, input)
32
+ output_exprs: list[exprs.Expr] = [] if group_by is None else list(group_by)
33
+ output_exprs.extend(agg_fn_calls)
34
+ super().__init__(row_builder, output_exprs, input_exprs, input)
21
35
  self.input = input
22
36
  self.group_by = group_by
23
37
  self.input_exprs = list(input_exprs)
24
- self.agg_fn_eval_ctx = row_builder.create_eval_ctx(agg_fn_calls, exclude=input_exprs)
38
+ self.agg_fn_eval_ctx = row_builder.create_eval_ctx(agg_fn_calls, exclude=self.input_exprs)
25
39
  # we need to make sure to refer to the same exprs that RowBuilder.eval() will use
26
- self.agg_fn_calls = self.agg_fn_eval_ctx.target_exprs
40
+ self.agg_fn_calls = [cast(exprs.FunctionCall, e) for e in self.agg_fn_eval_ctx.target_exprs]
41
+ # create output_batch here, rather than in __iter__(), so we don't need to remember tbl and row_builder
27
42
  self.output_batch = DataRowBatch(tbl, row_builder, 0)
28
43
 
29
44
  def _reset_agg_state(self, row_num: int) -> None:
@@ -45,17 +60,14 @@ class AggregationNode(ExecNode):
45
60
  input_vals = [row[d.slot_idx] for d in fn_call.dependencies()]
46
61
  raise excs.ExprEvalError(fn_call, expr_msg, e, exc_tb, input_vals, row_num)
47
62
 
48
- def __next__(self) -> DataRowBatch:
49
- if self.output_batch is None:
50
- raise StopIteration
51
-
63
+ def __iter__(self) -> Iterator[DataRowBatch]:
52
64
  prev_row: Optional[exprs.DataRow] = None
53
- current_group: Optional[List[Any]] = None # the values of the group-by exprs
65
+ current_group: Optional[list[Any]] = None # the values of the group-by exprs
54
66
  num_input_rows = 0
55
67
  for row_batch in self.input:
56
68
  num_input_rows += len(row_batch)
57
69
  for row in row_batch:
58
- group = [row[e.slot_idx] for e in self.group_by]
70
+ group = [row[e.slot_idx] for e in self.group_by] if self.group_by is not None else None
59
71
  if current_group is None:
60
72
  current_group = group
61
73
  self._reset_agg_state(0)
@@ -71,9 +83,7 @@ class AggregationNode(ExecNode):
71
83
  self.row_builder.eval(prev_row, self.agg_fn_eval_ctx, profile=self.ctx.profile)
72
84
  self.output_batch.add_row(prev_row)
73
85
 
74
- result = self.output_batch
75
- result.flush_imgs(None, self.stored_img_cols, self.flushed_img_slots)
76
- self.output_batch = None
77
- _logger.debug(f'AggregateNode: consumed {num_input_rows} rows, returning {len(result.rows)} rows')
78
- return result
86
+ self.output_batch.flush_imgs(None, self.stored_img_cols, self.flushed_img_slots)
87
+ _logger.debug(f'AggregateNode: consumed {num_input_rows} rows, returning {len(self.output_batch.rows)} rows')
88
+ yield self.output_batch
79
89
 
@@ -79,7 +79,7 @@ class CachePrefetchNode(ExecNode):
79
79
 
80
80
  return input_batch
81
81
 
82
- def _fetch_url(self, row: exprs.DataRow, slot_idx: int) -> Optional[str]:
82
+ def _fetch_url(self, row: exprs.DataRow, slot_idx: int) -> Optional[Path]:
83
83
  """Fetches a remote URL into Env.tmp_dir and returns its path"""
84
84
  url = row.file_urls[slot_idx]
85
85
  parsed = urllib.parse.urlparse(url)
@@ -14,6 +14,13 @@ class DataRowBatch:
14
14
 
15
15
  Contains the metadata needed to initialize DataRows.
16
16
  """
17
+ tbl: Optional[catalog.TableVersion]
18
+ row_builder: exprs.RowBuilder
19
+ img_slot_idxs: list[int]
20
+ media_slot_idxs: list[int] # non-image media slots
21
+ array_slot_idxs: list[int]
22
+ rows: list[exprs.DataRow]
23
+
17
24
  def __init__(self, tbl: Optional[catalog.TableVersion], row_builder: exprs.RowBuilder, len: int = 0):
18
25
  self.tbl = tbl
19
26
  self.row_builder = row_builder
@@ -39,17 +46,10 @@ class DataRowBatch:
39
46
  def pop_row(self) -> exprs.DataRow:
40
47
  return self.rows.pop()
41
48
 
42
- def set_row_ids(self, row_ids: List[int]) -> None:
43
- """Sets pks for rows in batch"""
44
- assert self.tbl is not None
45
- assert len(row_ids) == len(self.rows)
46
- for row, row_id in zip(self.rows, row_ids):
47
- row.set_pk((row_id, self.tbl))
48
-
49
49
  def __len__(self) -> int:
50
50
  return len(self.rows)
51
51
 
52
- def __getitem__(self, index: object) -> exprs.DataRow:
52
+ def __getitem__(self, index: int) -> exprs.DataRow:
53
53
  return self.rows[index]
54
54
 
55
55
  def flush_imgs(
@@ -74,21 +74,4 @@ class DataRowBatch:
74
74
  row.flush_img(slot_idx)
75
75
 
76
76
  def __iter__(self) -> Iterator[exprs.DataRow]:
77
- return DataRowBatchIterator(self)
78
-
79
-
80
- class DataRowBatchIterator:
81
- """
82
- Iterator over a DataRowBatch.
83
- """
84
- def __init__(self, batch: DataRowBatch):
85
- self.row_batch = batch
86
- self.index = 0
87
-
88
- def __next__(self) -> exprs.DataRow:
89
- if self.index >= len(self.row_batch.rows):
90
- raise StopIteration
91
- row = self.row_batch.rows[self.index]
92
- self.index += 1
93
- return row
94
-
77
+ return iter(self.rows)