pixeltable 0.3.11__py3-none-any.whl → 0.3.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (35)
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/catalog.py +309 -59
  4. pixeltable/catalog/globals.py +5 -5
  5. pixeltable/catalog/insertable_table.py +2 -1
  6. pixeltable/catalog/path.py +13 -6
  7. pixeltable/catalog/table.py +8 -6
  8. pixeltable/catalog/table_version.py +100 -72
  9. pixeltable/catalog/view.py +4 -9
  10. pixeltable/exec/sql_node.py +0 -1
  11. pixeltable/exprs/json_path.py +1 -5
  12. pixeltable/func/__init__.py +1 -1
  13. pixeltable/func/aggregate_function.py +1 -1
  14. pixeltable/func/callable_function.py +1 -1
  15. pixeltable/func/expr_template_function.py +2 -2
  16. pixeltable/func/function.py +3 -4
  17. pixeltable/func/query_template_function.py +87 -4
  18. pixeltable/func/tools.py +1 -1
  19. pixeltable/globals.py +7 -2
  20. pixeltable/metadata/__init__.py +1 -1
  21. pixeltable/metadata/converters/convert_31.py +11 -0
  22. pixeltable/metadata/converters/convert_32.py +15 -0
  23. pixeltable/metadata/converters/convert_33.py +17 -0
  24. pixeltable/metadata/notes.py +3 -0
  25. pixeltable/metadata/schema.py +26 -1
  26. pixeltable/plan.py +2 -3
  27. pixeltable/share/packager.py +8 -24
  28. pixeltable/share/publish.py +20 -9
  29. pixeltable/store.py +7 -4
  30. pixeltable/utils/exception_handler.py +59 -0
  31. {pixeltable-0.3.11.dist-info → pixeltable-0.3.12.dist-info}/METADATA +1 -1
  32. {pixeltable-0.3.11.dist-info → pixeltable-0.3.12.dist-info}/RECORD +35 -31
  33. {pixeltable-0.3.11.dist-info → pixeltable-0.3.12.dist-info}/LICENSE +0 -0
  34. {pixeltable-0.3.11.dist-info → pixeltable-0.3.12.dist-info}/WHEEL +0 -0
  35. {pixeltable-0.3.11.dist-info → pixeltable-0.3.12.dist-info}/entry_points.txt +0 -0
@@ -95,6 +95,7 @@ class Table(SchemaObject):
95
95
  'col1': StringType(),
96
96
  'col2': IntType(),
97
97
  },
98
+ 'is_replica': False,
98
99
  'version': 22,
99
100
  'schema_version': 1,
100
101
  'comment': '',
@@ -110,6 +111,7 @@ class Table(SchemaObject):
110
111
  md = super().get_metadata()
111
112
  md['base'] = self._base._path() if self._base is not None else None
112
113
  md['schema'] = self._schema
114
+ md['is_replica'] = self._tbl_version.get().is_replica
113
115
  md['version'] = self._version
114
116
  md['schema_version'] = self._tbl_version.get().schema_version
115
117
  md['comment'] = self._comment
@@ -139,14 +141,14 @@ class Table(SchemaObject):
139
141
  if self._is_dropped:
140
142
  raise excs.Error(f'{self._display_name()} {self._name} has been dropped')
141
143
 
142
- def __getattr__(self, name: str) -> 'pxt.exprs.ColumnRef':
144
+ def __getattr__(self, name: str) -> 'exprs.ColumnRef':
143
145
  """Return a ColumnRef for the given name."""
144
146
  col = self._tbl_version_path.get_column(name)
145
147
  if col is None:
146
148
  raise AttributeError(f'Column {name!r} unknown')
147
149
  return ColumnRef(col)
148
150
 
149
- def __getitem__(self, name: str) -> 'pxt.exprs.ColumnRef':
151
+ def __getitem__(self, name: str) -> 'exprs.ColumnRef':
150
152
  """Return a ColumnRef for the given name."""
151
153
  return getattr(self, name)
152
154
 
@@ -689,7 +691,7 @@ class Table(SchemaObject):
689
691
  for name, spec in schema.items():
690
692
  col_type: Optional[ts.ColumnType] = None
691
693
  value_expr: Optional[exprs.Expr] = None
692
- primary_key: Optional[bool] = None
694
+ primary_key: bool = False
693
695
  media_validation: Optional[catalog.MediaValidation] = None
694
696
  stored = True
695
697
 
@@ -711,7 +713,7 @@ class Table(SchemaObject):
711
713
  value_expr = value_expr.copy()
712
714
  value_expr.bind_rel_paths()
713
715
  stored = spec.get('stored', True)
714
- primary_key = spec.get('primary_key')
716
+ primary_key = spec.get('primary_key', False)
715
717
  media_validation_str = spec.get('media_validation')
716
718
  media_validation = (
717
719
  catalog.MediaValidation[media_validation_str.upper()] if media_validation_str is not None else None
@@ -1282,7 +1284,7 @@ class Table(SchemaObject):
1282
1284
  raise NotImplementedError
1283
1285
 
1284
1286
  def update(
1285
- self, value_spec: dict[str, Any], where: Optional['pxt.exprs.Expr'] = None, cascade: bool = True
1287
+ self, value_spec: dict[str, Any], where: Optional['exprs.Expr'] = None, cascade: bool = True
1286
1288
  ) -> UpdateStatus:
1287
1289
  """Update rows in this table.
1288
1290
 
@@ -1383,7 +1385,7 @@ class Table(SchemaObject):
1383
1385
  FileCache.get().emit_eviction_warnings()
1384
1386
  return status
1385
1387
 
1386
- def delete(self, where: Optional['pxt.exprs.Expr'] = None) -> UpdateStatus:
1388
+ def delete(self, where: Optional['exprs.Expr'] = None) -> UpdateStatus:
1387
1389
  """Delete rows in this table.
1388
1390
 
1389
1391
  Args:
@@ -5,7 +5,7 @@ import importlib
5
5
  import logging
6
6
  import time
7
7
  import uuid
8
- from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, Optional
8
+ from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, Optional, Tuple
9
9
  from uuid import UUID
10
10
 
11
11
  import jsonschema.exceptions
@@ -18,6 +18,7 @@ from pixeltable import exprs, index
18
18
  from pixeltable.env import Env
19
19
  from pixeltable.iterators import ComponentIterator
20
20
  from pixeltable.metadata import schema
21
+ from pixeltable.utils.exception_handler import run_cleanup_on_exception
21
22
  from pixeltable.utils.filecache import FileCache
22
23
  from pixeltable.utils.media_store import MediaStore
23
24
 
@@ -55,6 +56,7 @@ class TableVersion:
55
56
  name: str
56
57
  user: Optional[str]
57
58
  effective_version: Optional[int]
59
+ is_replica: bool
58
60
  version: int
59
61
  comment: str
60
62
  media_validation: MediaValidation
@@ -111,6 +113,7 @@ class TableVersion:
111
113
  self.user = tbl_md.user
112
114
  self.effective_version = effective_version
113
115
  self.version = tbl_md.current_version if effective_version is None else effective_version
116
+ self.is_replica = tbl_md.is_replica
114
117
  self.comment = schema_version_md.comment
115
118
  self.num_retained_versions = schema_version_md.num_retained_versions
116
119
  self.schema_version = schema_version_md.schema_version
@@ -232,6 +235,7 @@ class TableVersion:
232
235
  tbl_id=str(tbl_id),
233
236
  name=name,
234
237
  user=user,
238
+ is_replica=False,
235
239
  current_version=0,
236
240
  current_schema_version=0,
237
241
  next_col_id=len(cols),
@@ -310,24 +314,16 @@ class TableVersion:
310
314
  session.add(schema_version_record)
311
315
  return tbl_record.id, tbl_version
312
316
 
313
- @classmethod
314
- def delete_md(cls, tbl_id: UUID) -> None:
315
- conn = Env.get().conn
316
- conn.execute(sql.delete(schema.TableSchemaVersion.__table__).where(schema.TableSchemaVersion.tbl_id == tbl_id))
317
- conn.execute(sql.delete(schema.TableVersion.__table__).where(schema.TableVersion.tbl_id == tbl_id))
318
- conn.execute(sql.delete(schema.Table.__table__).where(schema.Table.id == tbl_id))
319
-
320
317
  def drop(self) -> None:
318
+ from .catalog import Catalog
319
+
320
+ cat = Catalog.get()
321
321
  # delete this table and all associated data
322
322
  MediaStore.delete(self.id)
323
323
  FileCache.get().clear(tbl_id=self.id)
324
- self.delete_md(self.id)
324
+ cat.delete_tbl_md(self.id)
325
325
  self.store_tbl.drop()
326
-
327
326
  # de-register table version from catalog
328
- from .catalog import Catalog
329
-
330
- cat = Catalog.get()
331
327
  cat.remove_tbl_version(self)
332
328
 
333
329
  def _init_schema(self, tbl_md: schema.TableMd, schema_version_md: schema.TableSchemaVersionMd) -> None:
@@ -381,7 +377,7 @@ class TableVersion:
381
377
 
382
378
  # make sure to traverse columns ordered by position = order in which cols were created;
383
379
  # this guarantees that references always point backwards
384
- if col_md.value_expr is not None:
380
+ if not self.is_snapshot and col_md.value_expr is not None:
385
381
  self._record_refd_columns(col)
386
382
 
387
383
  def _init_idxs(self, tbl_md: schema.TableMd) -> None:
@@ -437,29 +433,15 @@ class TableVersion:
437
433
  specified preceding schema version
438
434
  """
439
435
  assert update_tbl_version or preceding_schema_version is None
436
+ from pixeltable.catalog import Catalog
440
437
 
441
- conn = Env.get().conn
442
- conn.execute(
443
- sql.update(schema.Table.__table__)
444
- .values({schema.Table.md: dataclasses.asdict(self._create_tbl_md())})
445
- .where(schema.Table.id == self.id)
438
+ tbl_md = self._create_tbl_md()
439
+ version_md = self._create_version_md(timestamp) if update_tbl_version else None
440
+ schema_version_md = (
441
+ self._create_schema_version_md(preceding_schema_version) if preceding_schema_version is not None else None
446
442
  )
447
443
 
448
- if update_tbl_version:
449
- version_md = self._create_version_md(timestamp)
450
- conn.execute(
451
- sql.insert(schema.TableVersion.__table__).values(
452
- tbl_id=self.id, version=self.version, md=dataclasses.asdict(version_md)
453
- )
454
- )
455
-
456
- if preceding_schema_version is not None:
457
- schema_version_md = self._create_schema_version_md(preceding_schema_version)
458
- conn.execute(
459
- sql.insert(schema.TableSchemaVersion.__table__).values(
460
- tbl_id=self.id, schema_version=self.schema_version, md=dataclasses.asdict(schema_version_md)
461
- )
462
- )
444
+ Catalog.get().store_tbl_md(self.id, tbl_md, version_md, schema_version_md)
463
445
 
464
446
  def ensure_md_loaded(self) -> None:
465
447
  """Ensure that table metadata is loaded."""
@@ -480,33 +462,36 @@ class TableVersion:
480
462
  _logger.info(f'Added index {idx_name} on column {col.name} to table {self.name}')
481
463
  return status
482
464
 
483
- def _add_default_index(self, col: Column) -> Optional[UpdateStatus]:
484
- """Add a B-tree index on this column if it has a compatible type"""
465
+ def _is_btree_indexable(self, col: Column) -> bool:
485
466
  if not col.stored:
486
467
  # if the column is intentionally not stored, we want to avoid the overhead of an index
487
- return None
468
+ return False
488
469
  # Skip index for stored media columns produced by an iterator
489
470
  if col.col_type.is_media_type() and self.is_iterator_column(col):
490
- return None
471
+ return False
491
472
  if not col.col_type.is_scalar_type() and not (col.col_type.is_media_type() and not col.is_computed):
492
473
  # wrong type for a B-tree
493
- return None
494
- if col.col_type.is_bool_type():
474
+ return False
475
+ if col.col_type.is_bool_type(): # noqa : SIM103 Suppress `Return the negated condition directly` check
495
476
  # B-trees on bools aren't useful
477
+ return False
478
+ return True
479
+
480
+ def _add_default_index(self, col: Column) -> Optional[UpdateStatus]:
481
+ """Add a B-tree index on this column if it has a compatible type"""
482
+ if not self._is_btree_indexable(col):
496
483
  return None
497
484
  status = self._add_index(col, idx_name=None, idx=index.BtreeIndex(col))
498
485
  return status
499
486
 
500
- def _add_index(self, col: Column, idx_name: Optional[str], idx: index.IndexBase) -> UpdateStatus:
487
+ def _create_index_columns(self, idx: index.IndexBase) -> Tuple[Column, Column]:
488
+ """Create value and undo columns for the given index.
489
+ Args:
490
+ idx: index for which columns will be created.
491
+ Returns:
492
+ A tuple containing the value column and the undo column.
493
+ """
501
494
  assert not self.is_snapshot
502
- idx_id = self.next_idx_id
503
- self.next_idx_id += 1
504
- if idx_name is None:
505
- idx_name = f'idx{idx_id}'
506
- else:
507
- assert is_valid_identifier(idx_name)
508
- assert idx_name not in [i.name for i in self.idx_md.values()]
509
-
510
495
  # add the index value and undo columns (which need to be nullable)
511
496
  val_col = Column(
512
497
  col_id=self.next_col_id,
@@ -535,7 +520,19 @@ class TableVersion:
535
520
  undo_col.tbl = self.create_handle()
536
521
  undo_col.col_type = undo_col.col_type.copy(nullable=True)
537
522
  self.next_col_id += 1
523
+ return val_col, undo_col
538
524
 
525
+ def _create_index(
526
+ self, col: Column, val_col: Column, undo_col: Column, idx_name: Optional[str], idx: index.IndexBase
527
+ ) -> None:
528
+ """Create the given index along with index md"""
529
+ idx_id = self.next_idx_id
530
+ self.next_idx_id += 1
531
+ if idx_name is None:
532
+ idx_name = f'idx{idx_id}'
533
+ else:
534
+ assert is_valid_identifier(idx_name)
535
+ assert idx_name not in [i.name for i in self.idx_md.values()]
539
536
  # create and register the index metadata
540
537
  idx_cls = type(idx)
541
538
  idx_md = schema.IndexMd(
@@ -553,14 +550,27 @@ class TableVersion:
553
550
  idx_info = self.IndexInfo(id=idx_id, name=idx_name, idx=idx, col=col, val_col=val_col, undo_col=undo_col)
554
551
  self.idx_md[idx_id] = idx_md
555
552
  self.idxs_by_name[idx_name] = idx_info
553
+ try:
554
+ idx.create_index(self._store_idx_name(idx_id), val_col)
555
+ finally:
556
556
 
557
+ def cleanup_index() -> None:
558
+ """Delete the newly added in-memory index structure"""
559
+ del self.idxs_by_name[idx_name]
560
+ del self.idx_md[idx_id]
561
+ self.next_idx_id = idx_id
562
+
563
+ # Run cleanup only if there has been an exception; otherwise, skip cleanup.
564
+ run_cleanup_on_exception(cleanup_index)
565
+
566
+ def _add_index(self, col: Column, idx_name: Optional[str], idx: index.IndexBase) -> UpdateStatus:
567
+ val_col, undo_vol = self._create_index_columns(idx)
557
568
  # add the columns and update the metadata
558
569
  # TODO support on_error='abort' for indices; it's tricky because of the way metadata changes are entangled
559
570
  # with the database operations
560
- status = self._add_columns([val_col, undo_col], print_stats=False, on_error='ignore')
571
+ status = self._add_columns([val_col, undo_vol], print_stats=False, on_error='ignore')
561
572
  # now create the index structure
562
- idx.create_index(self._store_idx_name(idx_id), val_col)
563
-
573
+ self._create_index(col, val_col, undo_vol, idx_name, idx)
564
574
  return status
565
575
 
566
576
  def drop_index(self, idx_id: int) -> None:
@@ -601,9 +611,21 @@ class TableVersion:
601
611
  self.version += 1
602
612
  preceding_schema_version = self.schema_version
603
613
  self.schema_version = self.version
604
- status = self._add_columns(cols, print_stats=print_stats, on_error=on_error)
614
+ index_cols: dict[Column, tuple[index.BtreeIndex, Column, Column]] = {}
615
+ all_cols: list[Column] = []
605
616
  for col in cols:
606
- _ = self._add_default_index(col)
617
+ all_cols.append(col)
618
+ if self._is_btree_indexable(col):
619
+ idx = index.BtreeIndex(col)
620
+ val_col, undo_col = self._create_index_columns(idx)
621
+ index_cols[col] = (idx, val_col, undo_col)
622
+ all_cols.append(val_col)
623
+ all_cols.append(undo_col)
624
+ # Add all columns
625
+ status = self._add_columns(all_cols, print_stats=print_stats, on_error=on_error)
626
+ # Create indices and their mds
627
+ for col, (idx, val_col, undo_col) in index_cols.items():
628
+ self._create_index(col, val_col, undo_col, idx_name=None, idx=idx)
607
629
  self._update_md(time.time(), preceding_schema_version=preceding_schema_version)
608
630
  _logger.info(f'Added columns {[col.name for col in cols]} to table {self.name}, new version: {self.version}')
609
631
 
@@ -619,9 +641,9 @@ class TableVersion:
619
641
  self, cols: Iterable[Column], print_stats: bool, on_error: Literal['abort', 'ignore']
620
642
  ) -> UpdateStatus:
621
643
  """Add and populate columns within the current transaction"""
622
- cols = list(cols)
644
+ cols_to_add = list(cols)
623
645
  row_count = self.store_tbl.count()
624
- for col in cols:
646
+ for col in cols_to_add:
625
647
  if not col.col_type.nullable and not col.is_computed and row_count > 0:
626
648
  raise excs.Error(
627
649
  f'Cannot add non-nullable column {col.name!r} to table {self.name!r} with existing rows'
@@ -629,7 +651,8 @@ class TableVersion:
629
651
 
630
652
  num_excs = 0
631
653
  cols_with_excs: list[Column] = []
632
- for col in cols:
654
+ for col in cols_to_add:
655
+ excs_per_col = 0
633
656
  col.schema_version_add = self.schema_version
634
657
  # add the column to the lookup structures now, rather than after the store changes executed successfully,
635
658
  # because it might be referenced by the next column's value_expr
@@ -652,29 +675,32 @@ class TableVersion:
652
675
 
653
676
  plan, value_expr_slot_idx = Planner.create_add_column_plan(self.path, col)
654
677
  plan.ctx.num_rows = row_count
655
-
656
678
  try:
657
679
  plan.open()
658
680
  try:
659
- num_excs = self.store_tbl.load_column(col, plan, value_expr_slot_idx, on_error)
681
+ excs_per_col = self.store_tbl.load_column(col, plan, value_expr_slot_idx, on_error)
660
682
  except sql.exc.DBAPIError as exc:
661
683
  # Wrap the DBAPIError in an excs.Error to unify processing in the subsequent except block
662
684
  raise excs.Error(f'SQL error during execution of computed column `{col.name}`:\n{exc}') from exc
663
- if num_excs > 0:
685
+ if excs_per_col > 0:
664
686
  cols_with_excs.append(col)
665
- except excs.Error as exc:
666
- self.cols.pop()
667
- for c in cols:
668
- # remove columns that we already added
669
- if c.id not in self.cols_by_id:
670
- continue
671
- if c.name is not None:
672
- del self.cols_by_name[c.name]
673
- del self.cols_by_id[c.id]
674
- # we need to re-initialize the sqlalchemy schema
675
- self.store_tbl.create_sa_tbl()
676
- raise exc
687
+ num_excs += excs_per_col
677
688
  finally:
689
+ # Ensure cleanup occurs if an exception or keyboard interruption happens during `load_column()`.
690
+ def cleanup_on_error() -> None:
691
+ """Delete columns that are added as part of current add_columns operation and re-initialize
692
+ the sqlalchemy schema"""
693
+ self.cols = [col for col in self.cols if col not in cols_to_add]
694
+ for col in cols_to_add:
695
+ # remove columns that we already added
696
+ if col.id in self.cols_by_id:
697
+ del self.cols_by_id[col.id]
698
+ if col.name is not None and col.name in self.cols_by_name:
699
+ del self.cols_by_name[col.name]
700
+ self.store_tbl.create_sa_tbl()
701
+
702
+ # Run cleanup only if there has been an exception; otherwise, skip cleanup.
703
+ run_cleanup_on_exception(cleanup_on_error)
678
704
  plan.close()
679
705
 
680
706
  if print_stats:
@@ -1298,6 +1324,7 @@ class TableVersion:
1298
1324
  column_md: dict[int, schema.ColumnMd] = {}
1299
1325
  for col in cols:
1300
1326
  value_expr_dict = col.value_expr.as_dict() if col.value_expr is not None else None
1327
+ assert col.is_pk is not None
1301
1328
  column_md[col.id] = schema.ColumnMd(
1302
1329
  id=col.id,
1303
1330
  col_type=col.col_type.as_dict(),
@@ -1320,6 +1347,7 @@ class TableVersion:
1320
1347
  tbl_id=str(self.id),
1321
1348
  name=self.name,
1322
1349
  user=self.user,
1350
+ is_replica=self.is_replica,
1323
1351
  current_version=self.version,
1324
1352
  current_schema_version=self.schema_version,
1325
1353
  next_col_id=self.next_col_id,
@@ -8,7 +8,7 @@ from uuid import UUID
8
8
  import pixeltable.exceptions as excs
9
9
  import pixeltable.metadata.schema as md_schema
10
10
  import pixeltable.type_system as ts
11
- from pixeltable import exprs, func
11
+ from pixeltable import catalog, exprs, func
12
12
  from pixeltable.env import Env
13
13
  from pixeltable.iterators import ComponentIterator
14
14
 
@@ -20,7 +20,7 @@ from .table_version_handle import TableVersionHandle
20
20
  from .table_version_path import TableVersionPath
21
21
 
22
22
  if TYPE_CHECKING:
23
- import pixeltable as pxt
23
+ from pixeltable.globals import TableDataSource
24
24
 
25
25
  _logger = logging.getLogger('pixeltable')
26
26
 
@@ -65,7 +65,7 @@ class View(Table):
65
65
  base: TableVersionPath,
66
66
  select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]],
67
67
  additional_columns: dict[str, Any],
68
- predicate: Optional['pxt.exprs.Expr'],
68
+ predicate: Optional['exprs.Expr'],
69
69
  is_snapshot: bool,
70
70
  num_retained_versions: int,
71
71
  comment: str,
@@ -242,7 +242,7 @@ class View(Table):
242
242
  # there is not TableVersion to drop
243
243
  self._check_is_dropped()
244
244
  self.is_dropped = True
245
- TableVersion.delete_md(self._id)
245
+ catalog.Catalog.get().delete_tbl_md(self._id)
246
246
  else:
247
247
  super()._drop()
248
248
 
@@ -252,11 +252,6 @@ class View(Table):
252
252
  md['is_snapshot'] = self._tbl_version_path.is_snapshot()
253
253
  return md
254
254
 
255
- if TYPE_CHECKING:
256
- import datasets # type: ignore[import-untyped]
257
-
258
- from pixeltable.globals import RowData, TableDataSource
259
-
260
255
  def insert(
261
256
  self,
262
257
  source: Optional[TableDataSource] = None,
@@ -103,7 +103,6 @@ class SqlNode(ExecNode):
103
103
  # create Select stmt
104
104
  self.sql_elements = sql_elements
105
105
  self.tbl = tbl
106
- assert all(not isinstance(e, exprs.Literal) for e in select_list) # we're never asked to materialize literals
107
106
  self.select_list = exprs.ExprSet(select_list)
108
107
  # unstored iter columns: we also need to retrieve whatever is needed to materialize the iter args
109
108
  for iter_arg in row_builder.unstored_iter_args.values():
@@ -5,7 +5,6 @@ from typing import Any, Optional, Union
5
5
  import jmespath
6
6
  import sqlalchemy as sql
7
7
 
8
- import pixeltable as pxt
9
8
  from pixeltable import catalog, exceptions as excs, type_system as ts
10
9
 
11
10
  from .data_row import DataRow
@@ -19,10 +18,7 @@ from .sql_element_cache import SqlElementCache
19
18
 
20
19
  class JsonPath(Expr):
21
20
  def __init__(
22
- self,
23
- anchor: Optional['pxt.exprs.Expr'],
24
- path_elements: Optional[list[Union[str, int, slice]]] = None,
25
- scope_idx: int = 0,
21
+ self, anchor: Optional[Expr], path_elements: Optional[list[Union[str, int, slice]]] = None, scope_idx: int = 0
26
22
  ) -> None:
27
23
  """
28
24
  anchor can be None, in which case this is a relative JsonPath and the anchor is set later via set_anchor().
@@ -5,7 +5,7 @@ from .callable_function import CallableFunction
5
5
  from .expr_template_function import ExprTemplateFunction
6
6
  from .function import Function, InvalidFunction
7
7
  from .function_registry import FunctionRegistry
8
- from .query_template_function import QueryTemplateFunction, query
8
+ from .query_template_function import QueryTemplateFunction, query, retrieval_udf
9
9
  from .signature import Batch, Parameter, Signature
10
10
  from .tools import Tool, ToolChoice, Tools
11
11
  from .udf import expr_udf, make_function, udf
@@ -159,7 +159,7 @@ class AggregateFunction(Function):
159
159
  self.init_param_names.append(init_param_names)
160
160
  return self
161
161
 
162
- def _docstring(self) -> Optional[str]:
162
+ def comment(self) -> Optional[str]:
163
163
  return inspect.getdoc(self.agg_classes[0])
164
164
 
165
165
  def help_str(self) -> str:
@@ -60,7 +60,7 @@ class CallableFunction(Function):
60
60
  def is_async(self) -> bool:
61
61
  return inspect.iscoroutinefunction(self.py_fn)
62
62
 
63
- def _docstring(self) -> Optional[str]:
63
+ def comment(self) -> Optional[str]:
64
64
  return inspect.getdoc(self.py_fns[0])
65
65
 
66
66
  @property
@@ -95,9 +95,9 @@ class ExprTemplateFunction(Function):
95
95
  )
96
96
  return substituted_expr.col_type
97
97
 
98
- def _docstring(self) -> Optional[str]:
98
+ def comment(self) -> Optional[str]:
99
99
  if isinstance(self.templates[0].expr, exprs.FunctionCall):
100
- return self.templates[0].expr.fn._docstring()
100
+ return self.templates[0].expr.fn.comment()
101
101
  return None
102
102
 
103
103
  def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
@@ -10,8 +10,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, cast
10
10
  import sqlalchemy as sql
11
11
  from typing_extensions import Self
12
12
 
13
- import pixeltable.exceptions as excs
14
- import pixeltable.type_system as ts
13
+ from pixeltable import exceptions as excs, type_system as ts
15
14
 
16
15
  from .globals import resolve_symbol
17
16
  from .signature import Signature
@@ -106,11 +105,11 @@ class Function(ABC):
106
105
  @abstractmethod
107
106
  def is_async(self) -> bool: ...
108
107
 
109
- def _docstring(self) -> Optional[str]:
108
+ def comment(self) -> Optional[str]:
110
109
  return None
111
110
 
112
111
  def help_str(self) -> str:
113
- docstring = self._docstring()
112
+ docstring = self.comment()
114
113
  display = self.display_name + str(self.signatures[0])
115
114
  if docstring is None:
116
115
  return display
@@ -1,9 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
- from typing import TYPE_CHECKING, Any, Callable, Optional, overload
4
+ from functools import reduce
5
+ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Union, overload
5
6
 
6
- from pixeltable import exprs, type_system as ts
7
+ from pixeltable import catalog, exceptions as excs, exprs, func, type_system as ts
7
8
 
8
9
  from .function import Function
9
10
  from .signature import Signature
@@ -17,6 +18,7 @@ class QueryTemplateFunction(Function):
17
18
 
18
19
  template_df: Optional['DataFrame']
19
20
  self_name: Optional[str]
21
+ _comment: Optional[str]
20
22
 
21
23
  @classmethod
22
24
  def create(
@@ -34,15 +36,21 @@ class QueryTemplateFunction(Function):
34
36
  assert isinstance(template_df, DataFrame)
35
37
  # we take params and return json
36
38
  sig = Signature(return_type=ts.JsonType(), parameters=params)
37
- return QueryTemplateFunction(template_df, sig, path=path, name=name)
39
+ return QueryTemplateFunction(template_df, sig, path=path, name=name, comment=inspect.getdoc(template_callable))
38
40
 
39
41
  def __init__(
40
- self, template_df: Optional['DataFrame'], sig: Signature, path: Optional[str] = None, name: Optional[str] = None
42
+ self,
43
+ template_df: Optional['DataFrame'],
44
+ sig: Signature,
45
+ path: Optional[str] = None,
46
+ name: Optional[str] = None,
47
+ comment: Optional[str] = None,
41
48
  ):
42
49
  assert sig is not None
43
50
  super().__init__([sig], self_path=path)
44
51
  self.self_name = name
45
52
  self.template_df = template_df
53
+ self._comment = comment
46
54
 
47
55
  def _update_as_overload_resolution(self, signature_idx: int) -> None:
48
56
  pass # only one signature supported for QueryTemplateFunction
@@ -74,6 +82,9 @@ class QueryTemplateFunction(Function):
74
82
  def name(self) -> str:
75
83
  return self.self_name
76
84
 
85
+ def comment(self) -> Optional[str]:
86
+ return self._comment
87
+
77
88
  def _as_dict(self) -> dict:
78
89
  return {'name': self.name, 'signature': self.signature.as_dict(), 'df': self.template_df.as_dict()}
79
90
 
@@ -112,3 +123,75 @@ def query(*args: Any, **kwargs: Any) -> Any:
112
123
  else:
113
124
  assert len(args) == 0 and len(kwargs) == 1 and 'param_types' in kwargs
114
125
  return lambda py_fn: make_query_template(py_fn, kwargs['param_types'])
126
+
127
+
128
+ def retrieval_udf(
129
+ table: catalog.Table,
130
+ name: Optional[str] = None,
131
+ description: Optional[str] = None,
132
+ parameters: Optional[Iterable[Union[str, exprs.ColumnRef]]] = None,
133
+ limit: Optional[int] = 10,
134
+ ) -> func.QueryTemplateFunction:
135
+ """
136
+ Constructs a retrieval UDF for the given table. The retrieval UDF is a UDF whose parameters are
137
+ columns of the table and whose return value is a list of rows from the table. The return value of
138
+ ```python
139
+ f(col1=x, col2=y, ...)
140
+ ```
141
+ will be a list of all rows from the table that match the specified arguments.
142
+
143
+ Args:
144
+ table: The table to use as the dataset for the retrieval tool.
145
+ name: The name of the tool. If not specified, then the name of the table will be used by default.
146
+ description: The description of the tool. If not specified, then a default description will be generated.
147
+ parameters: The columns of the table to use as parameters. If not specified, all data columns
148
+ (non-computed columns) will be used as parameters.
149
+
150
+ All of the specified parameters will be required parameters of the tool, regardless of their status
151
+ as columns.
152
+ limit: The maximum number of rows to return. If not specified, then all matching rows will be returned.
153
+
154
+ Returns:
155
+ A list of dictionaries containing data from the table, one per row that matches the input arguments.
156
+ If there are no matching rows, an empty list will be returned.
157
+ """
158
+ # Argument validation
159
+ col_refs: list[exprs.ColumnRef]
160
+ if parameters is None:
161
+ col_refs = [table[col_name] for col_name in table.columns if not table[col_name].col.is_computed]
162
+ else:
163
+ for param in parameters:
164
+ if isinstance(param, str) and param not in table.columns:
165
+ raise excs.Error(f'The specified parameter {param!r} is not a column of the table {table._path!r}')
166
+ col_refs = [table[param] if isinstance(param, str) else param for param in parameters]
167
+
168
+ if len(col_refs) == 0:
169
+ raise excs.Error('Parameter list cannot be empty.')
170
+
171
+ # Construct the dataframe
172
+ predicates = [col_ref == exprs.Variable(col_ref.col.name, col_ref.col.col_type) for col_ref in col_refs]
173
+ where_clause = reduce(lambda c1, c2: c1 & c2, predicates)
174
+ df = table.select().where(where_clause)
175
+ if limit is not None:
176
+ df = df.limit(limit)
177
+
178
+ # Construct the signature
179
+ query_params = [
180
+ func.Parameter(col_ref.col.name, col_ref.col.col_type, inspect.Parameter.POSITIONAL_OR_KEYWORD)
181
+ for col_ref in col_refs
182
+ ]
183
+ query_signature = func.Signature(return_type=ts.JsonType(), parameters=query_params)
184
+
185
+ # Construct a name and/or description if not provided
186
+ if name is None:
187
+ name = table._name
188
+ if description is None:
189
+ description = (
190
+ f'Retrieves an entry from the dataset {name!r} that matches the given parameters.\n\nParameters:\n'
191
+ )
192
+ description += '\n'.join(
193
+ [f' {col_ref.col.name}: of type `{col_ref.col.col_type._to_base_str()}`' for col_ref in col_refs]
194
+ )
195
+
196
+ fn = func.QueryTemplateFunction(df, query_signature, name=name, comment=description)
197
+ return fn
pixeltable/func/tools.py CHANGED
@@ -39,7 +39,7 @@ class Tool(pydantic.BaseModel):
39
39
  def ser_model(self) -> dict[str, Any]:
40
40
  return {
41
41
  'name': self.name or self.fn.name,
42
- 'description': self.description or self.fn._docstring(),
42
+ 'description': self.description or self.fn.comment(),
43
43
  'parameters': {
44
44
  'type': 'object',
45
45
  'properties': {param.name: param.col_type._to_json_schema() for param in self.parameters.values()},