pixeltable: pixeltable-0.3.12-py3-none-any.whl → pixeltable-0.3.14-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. pixeltable/__init__.py +2 -27
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/catalog.py +9 -7
  4. pixeltable/catalog/column.py +6 -2
  5. pixeltable/catalog/dir.py +2 -1
  6. pixeltable/catalog/insertable_table.py +11 -0
  7. pixeltable/catalog/schema_object.py +2 -1
  8. pixeltable/catalog/table.py +27 -38
  9. pixeltable/catalog/table_version.py +19 -0
  10. pixeltable/catalog/table_version_path.py +7 -0
  11. pixeltable/catalog/view.py +31 -0
  12. pixeltable/dataframe.py +50 -7
  13. pixeltable/env.py +1 -1
  14. pixeltable/exceptions.py +20 -2
  15. pixeltable/exec/aggregation_node.py +14 -0
  16. pixeltable/exec/cache_prefetch_node.py +1 -1
  17. pixeltable/exec/expr_eval/evaluators.py +0 -4
  18. pixeltable/exec/expr_eval/expr_eval_node.py +1 -2
  19. pixeltable/exec/sql_node.py +3 -2
  20. pixeltable/exprs/column_ref.py +42 -17
  21. pixeltable/exprs/data_row.py +3 -0
  22. pixeltable/exprs/globals.py +1 -1
  23. pixeltable/exprs/literal.py +11 -1
  24. pixeltable/exprs/rowid_ref.py +4 -1
  25. pixeltable/exprs/similarity_expr.py +1 -1
  26. pixeltable/func/function.py +1 -1
  27. pixeltable/func/udf.py +1 -1
  28. pixeltable/functions/__init__.py +2 -0
  29. pixeltable/functions/anthropic.py +1 -1
  30. pixeltable/functions/bedrock.py +130 -0
  31. pixeltable/functions/date.py +185 -0
  32. pixeltable/functions/gemini.py +22 -20
  33. pixeltable/functions/globals.py +1 -16
  34. pixeltable/functions/huggingface.py +7 -6
  35. pixeltable/functions/image.py +15 -16
  36. pixeltable/functions/json.py +2 -1
  37. pixeltable/functions/math.py +40 -0
  38. pixeltable/functions/mistralai.py +3 -2
  39. pixeltable/functions/openai.py +9 -8
  40. pixeltable/functions/string.py +1 -2
  41. pixeltable/functions/together.py +4 -3
  42. pixeltable/functions/video.py +2 -2
  43. pixeltable/globals.py +26 -9
  44. pixeltable/io/datarows.py +4 -3
  45. pixeltable/io/hf_datasets.py +2 -2
  46. pixeltable/io/label_studio.py +17 -17
  47. pixeltable/io/pandas.py +29 -16
  48. pixeltable/io/parquet.py +2 -0
  49. pixeltable/io/table_data_conduit.py +8 -2
  50. pixeltable/metadata/__init__.py +1 -1
  51. pixeltable/metadata/converters/convert_19.py +2 -2
  52. pixeltable/metadata/converters/convert_34.py +21 -0
  53. pixeltable/metadata/notes.py +1 -0
  54. pixeltable/plan.py +12 -5
  55. pixeltable/share/__init__.py +1 -1
  56. pixeltable/share/packager.py +219 -119
  57. pixeltable/share/publish.py +61 -16
  58. pixeltable/store.py +45 -20
  59. pixeltable/type_system.py +46 -2
  60. pixeltable/utils/arrow.py +8 -2
  61. pixeltable/utils/pytorch.py +4 -0
  62. {pixeltable-0.3.12.dist-info → pixeltable-0.3.14.dist-info}/METADATA +2 -4
  63. {pixeltable-0.3.12.dist-info → pixeltable-0.3.14.dist-info}/RECORD +66 -63
  64. {pixeltable-0.3.12.dist-info → pixeltable-0.3.14.dist-info}/WHEEL +1 -1
  65. {pixeltable-0.3.12.dist-info → pixeltable-0.3.14.dist-info}/LICENSE +0 -0
  66. {pixeltable-0.3.12.dist-info → pixeltable-0.3.14.dist-info}/entry_points.txt +0 -0
pixeltable/__init__.py CHANGED
@@ -9,6 +9,7 @@ from .globals import (
9
9
  array,
10
10
  configure_logging,
11
11
  create_dir,
12
+ create_replica,
12
13
  create_snapshot,
13
14
  create_table,
14
15
  create_view,
@@ -20,36 +21,10 @@ from .globals import (
20
21
  list_functions,
21
22
  list_tables,
22
23
  move,
23
- publish_snapshot,
24
24
  tool,
25
25
  tools,
26
26
  )
27
- from .type_system import (
28
- Array,
29
- ArrayType,
30
- Audio,
31
- AudioType,
32
- Bool,
33
- BoolType,
34
- ColumnType,
35
- Document,
36
- DocumentType,
37
- Float,
38
- FloatType,
39
- Image,
40
- ImageType,
41
- Int,
42
- IntType,
43
- Json,
44
- JsonType,
45
- Required,
46
- String,
47
- StringType,
48
- Timestamp,
49
- TimestampType,
50
- Video,
51
- VideoType,
52
- )
27
+ from .type_system import Array, Audio, Bool, Date, Document, Float, Image, Int, Json, Required, String, Timestamp, Video
53
28
 
54
29
  # This import must go last to avoid circular imports.
55
30
  from . import ext, functions, io, iterators # isort: skip
pixeltable/__version__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  # These version placeholders will be replaced during build.
2
- __version__ = '0.3.12'
3
- __version_tuple__ = (0, 3, 12)
2
+ __version__ = '0.3.14'
3
+ __version_tuple__ = (0, 3, 14)
@@ -432,7 +432,9 @@ class Catalog:
432
432
  return view
433
433
 
434
434
  @_retry_loop
435
- def create_replica(self, path: Path, md: list[schema.FullTableMd], if_exists: IfExistsParam) -> Table:
435
+ def create_replica(
436
+ self, path: Path, md: list[schema.FullTableMd], if_exists: IfExistsParam = IfExistsParam.ERROR
437
+ ) -> Table:
436
438
  """
437
439
  Creates table, table_version, and table_schema_version records for a replica with the given metadata.
438
440
  The metadata should be presented in standard "ancestor order", with the table being replicated at
@@ -458,11 +460,11 @@ class Catalog:
458
460
  # TODO: Handle concurrency in create_replica()
459
461
  existing = Catalog.get().get_table_by_id(tbl_id)
460
462
  if existing is not None:
461
- existing_path = Path(existing._path(), allow_system_paths=True)
463
+ existing_path = Path(existing._path, allow_system_paths=True)
462
464
  # It does exist. If it's a non-system table, that's an error: it's already been replicated.
463
465
  if not existing_path.is_system_path:
464
466
  raise excs.Error(
465
- f'That table has already been replicated as {existing._path()!r}. \n'
467
+ f'That table has already been replicated as {existing._path!r}. \n'
466
468
  f'Drop the existing replica if you wish to re-create it.'
467
469
  )
468
470
  # If it's a system table, then this means it was created at some point as the ancestor of some other
@@ -487,7 +489,7 @@ class Catalog:
487
489
  # The table already exists in the catalog. The existing path might be a system path (if the table
488
490
  # was created as an anonymous base table of some other table), or it might not (if it's a snapshot
489
491
  # that was directly replicated by the user at some point). In either case, use the existing path.
490
- replica_path = Path(replica._path(), allow_system_paths=True)
492
+ replica_path = Path(replica._path, allow_system_paths=True)
491
493
 
492
494
  # Store the metadata; it could be a new version (in which case a new record will be created) or a
493
495
  # known version (in which case the newly received metadata will be validated as identical).
@@ -619,11 +621,11 @@ class Catalog:
619
621
  msg: str
620
622
  if is_replace:
621
623
  msg = (
622
- f'{obj_type_str} {tbl._path()} already exists and has dependents. '
624
+ f'{obj_type_str} {tbl._path} already exists and has dependents. '
623
625
  "Use `if_exists='replace_force'` to replace it."
624
626
  )
625
627
  else:
626
- msg = f'{obj_type_str} {tbl._path()} has dependents.'
628
+ msg = f'{obj_type_str} {tbl._path} has dependents.'
627
629
  raise excs.Error(msg)
628
630
 
629
631
  for view_id in view_ids:
@@ -634,7 +636,7 @@ class Catalog:
634
636
  tbl._drop()
635
637
  assert tbl._id in self._tbls
636
638
  del self._tbls[tbl._id]
637
- _logger.info(f'Dropped table `{tbl._path()}`.')
639
+ _logger.info(f'Dropped table `{tbl._path}`.')
638
640
 
639
641
  @_retry_loop
640
642
  def create_dir(self, path: Path, if_exists: IfExistsParam, parents: bool) -> Dir:
@@ -16,6 +16,7 @@ from .globals import MediaValidation, is_valid_identifier
16
16
  if TYPE_CHECKING:
17
17
  from .table_version import TableVersion
18
18
  from .table_version_handle import TableVersionHandle
19
+ from .table_version_path import TableVersionPath
19
20
 
20
21
  _logger = logging.getLogger('pixeltable')
21
22
 
@@ -170,9 +171,12 @@ class Column:
170
171
  )
171
172
  return len(window_fn_calls) > 0
172
173
 
173
- def get_idx_info(self) -> dict[str, 'TableVersion.IndexInfo']:
174
+ # TODO: This should be moved out of `Column` (its presence in `Column` doesn't anticipate indices being defined on
175
+ # multiple dependents)
176
+ def get_idx_info(self, reference_tbl: Optional['TableVersionPath'] = None) -> dict[str, 'TableVersion.IndexInfo']:
174
177
  assert self.tbl is not None
175
- return {name: info for name, info in self.tbl.get().idxs_by_name.items() if info.col == self}
178
+ tbl = reference_tbl.tbl_version if reference_tbl is not None else self.tbl
179
+ return {name: info for name, info in tbl.get().idxs_by_name.items() if info.col == self}
176
180
 
177
181
  @property
178
182
  def is_computed(self) -> bool:
pixeltable/catalog/dir.py CHANGED
@@ -38,12 +38,13 @@ class Dir(SchemaObject):
38
38
  def _display_name(cls) -> str:
39
39
  return 'directory'
40
40
 
41
+ @property
41
42
  def _path(self) -> str:
42
43
  """Returns the path to this schema object."""
43
44
  if self._dir_id is None:
44
45
  # we're the root dir
45
46
  return ''
46
- return super()._path()
47
+ return super()._path
47
48
 
48
49
  def _move(self, new_name: str, new_dir_id: UUID) -> None:
49
50
  # print(
@@ -228,3 +228,14 @@ class InsertableTable(Table):
228
228
  """
229
229
  with Env.get().begin_xact():
230
230
  return self._tbl_version.get().delete(where=where)
231
+
232
+ @property
233
+ def _base_table(self) -> Optional['Table']:
234
+ return None
235
+
236
+ @property
237
+ def _effective_base_versions(self) -> list[Optional[int]]:
238
+ return []
239
+
240
+ def _table_descriptor(self) -> str:
241
+ return f'Table {self._path!r}'
@@ -33,6 +33,7 @@ class SchemaObject:
33
33
  return None
34
34
  return Catalog.get().get_dir(self._dir_id)
35
35
 
36
+ @property
36
37
  def _path(self) -> str:
37
38
  """Returns the path to this schema object."""
38
39
  from .catalog import Catalog
@@ -44,7 +45,7 @@ class SchemaObject:
44
45
 
45
46
  def get_metadata(self) -> dict[str, Any]:
46
47
  """Returns metadata associated with this schema object."""
47
- return {'name': self._name, 'path': self._path()}
48
+ return {'name': self._name, 'path': self._path}
48
49
 
49
50
  @classmethod
50
51
  @abstractmethod
@@ -109,7 +109,7 @@ class Table(SchemaObject):
109
109
  self._check_is_dropped()
110
110
  with env.Env.get().begin_xact():
111
111
  md = super().get_metadata()
112
- md['base'] = self._base._path() if self._base is not None else None
112
+ md['base'] = self._base_table._path if self._base_table is not None else None
113
113
  md['schema'] = self._schema
114
114
  md['is_replica'] = self._tbl_version.get().is_replica
115
115
  md['version'] = self._version
@@ -146,7 +146,7 @@ class Table(SchemaObject):
146
146
  col = self._tbl_version_path.get_column(name)
147
147
  if col is None:
148
148
  raise AttributeError(f'Column {name!r} unknown')
149
- return ColumnRef(col)
149
+ return ColumnRef(col, reference_tbl=self._tbl_version_path)
150
150
 
151
151
  def __getitem__(self, name: str) -> 'exprs.ColumnRef':
152
152
  """Return a ColumnRef for the given name."""
@@ -165,7 +165,7 @@ class Table(SchemaObject):
165
165
  """
166
166
  self._check_is_dropped()
167
167
  with env.Env.get().begin_xact():
168
- return [t._path() for t in self._get_views(recursive=recursive)]
168
+ return [t._path for t in self._get_views(recursive=recursive)]
169
169
 
170
170
  def _get_views(self, *, recursive: bool = True) -> list['Table']:
171
171
  cat = catalog.Catalog.get()
@@ -220,6 +220,10 @@ class Table(SchemaObject):
220
220
  """
221
221
  return self._df().group_by(*items)
222
222
 
223
+ def distinct(self) -> 'pxt.DataFrame':
224
+ """Remove duplicate rows from table."""
225
+ return self._df().distinct()
226
+
223
227
  def limit(self, n: int) -> 'pxt.DataFrame':
224
228
  return self._df().limit(n)
225
229
 
@@ -255,28 +259,30 @@ class Table(SchemaObject):
255
259
  return {c.name: c.col_type for c in self._tbl_version_path.columns()}
256
260
 
257
261
  @property
258
- def _base(self) -> Optional['Table']:
259
- """
260
- The base table of this `Table`. If this table is a view, returns the `Table`
261
- from which it was derived. Otherwise, returns `None`.
262
- """
263
- if self._tbl_version_path.base is None:
264
- return None
265
- base_id = self._tbl_version_path.base.tbl_version.id
266
- return catalog.Catalog.get().get_table_by_id(base_id)
262
+ def base_table(self) -> Optional['Table']:
263
+ with env.Env.get().begin_xact():
264
+ return self._base_table
267
265
 
268
266
  @property
269
- def _bases(self) -> list['Table']:
270
- """
271
- The ancestor list of bases of this table, starting with its immediate base.
272
- """
267
+ @abc.abstractmethod
268
+ def _base_table(self) -> Optional['Table']:
269
+ """The base's Table instance"""
270
+
271
+ @property
272
+ def _base_tables(self) -> list['Table']:
273
+ """The ancestor list of bases of this table, starting with its immediate base."""
273
274
  bases = []
274
- base = self._base
275
+ base = self._base_table
275
276
  while base is not None:
276
277
  bases.append(base)
277
- base = base._base
278
+ base = base._base_table
278
279
  return bases
279
280
 
281
+ @property
282
+ @abc.abstractmethod
283
+ def _effective_base_versions(self) -> list[Optional[int]]:
284
+ """The effective versions of the ancestor bases, starting with its immediate base."""
285
+
280
286
  @property
281
287
  def _comment(self) -> str:
282
288
  return self._tbl_version.get().comment
@@ -300,7 +306,7 @@ class Table(SchemaObject):
300
306
  Constructs a list of descriptors for this table that can be pretty-printed.
301
307
  """
302
308
  helper = DescriptionHelper()
303
- helper.append(self._title_descriptor())
309
+ helper.append(self._table_descriptor())
304
310
  helper.append(self._col_descriptor())
305
311
  idxs = self._index_descriptor()
306
312
  if not idxs.empty:
@@ -312,15 +318,6 @@ class Table(SchemaObject):
312
318
  helper.append(f'COMMENT: {self._comment}')
313
319
  return helper
314
320
 
315
- def _title_descriptor(self) -> str:
316
- title: str
317
- if self._base is None:
318
- title = f'Table\n{self._path()!r}'
319
- else:
320
- title = f'View\n{self._path()!r}'
321
- title += f'\n(of {self.__bases_to_desc()})'
322
- return title
323
-
324
321
  def _col_descriptor(self, columns: Optional[list[str]] = None) -> pd.DataFrame:
325
322
  return pd.DataFrame(
326
323
  {
@@ -332,14 +329,6 @@ class Table(SchemaObject):
332
329
  if columns is None or col.name in columns
333
330
  )
334
331
 
335
- def __bases_to_desc(self) -> str:
336
- bases = self._bases
337
- assert len(bases) >= 1
338
- if len(bases) <= 2:
339
- return ', '.join(repr(b._path()) for b in bases)
340
- else:
341
- return f'{bases[0]._path()!r}, ..., {bases[-1]._path()!r}'
342
-
343
332
  def _index_descriptor(self, columns: Optional[list[str]] = None) -> pd.DataFrame:
344
333
  from pixeltable import index
345
334
 
@@ -373,9 +362,9 @@ class Table(SchemaObject):
373
362
  """
374
363
  self._check_is_dropped()
375
364
  if getattr(builtins, '__IPYTHON__', False):
376
- from IPython.display import display
365
+ from IPython.display import Markdown, display
377
366
 
378
- display(self._repr_html_())
367
+ display(Markdown(self._repr_html_()))
379
368
  else:
380
369
  print(repr(self))
381
370
 
@@ -202,6 +202,13 @@ class TableVersion:
202
202
 
203
203
  return TableVersionHandle(self.id, self.effective_version, tbl_version=self)
204
204
 
205
+ @property
206
+ def versioned_name(self) -> str:
207
+ if self.effective_version is None:
208
+ return self.name
209
+ else:
210
+ return f'{self.name}:{self.effective_version}'
211
+
205
212
  @classmethod
206
213
  def create(
207
214
  cls,
@@ -314,6 +321,18 @@ class TableVersion:
314
321
  session.add(schema_version_record)
315
322
  return tbl_record.id, tbl_version
316
323
 
324
+ @classmethod
325
+ def create_replica(cls, md: schema.FullTableMd) -> TableVersion:
326
+ tbl_id = UUID(md.tbl_md.tbl_id)
327
+ view_md = md.tbl_md.view_md
328
+ base_path = pxt.catalog.TableVersionPath.from_md(view_md.base_versions) if view_md is not None else None
329
+ base = base_path.tbl_version if base_path is not None else None
330
+ tbl_version = cls(
331
+ tbl_id, md.tbl_md, md.version_md.version, md.schema_version_md, [], base_path=base_path, base=base
332
+ )
333
+ tbl_version.store_tbl.create()
334
+ return tbl_version
335
+
317
336
  def drop(self) -> None:
318
337
  from .catalog import Catalog
319
338
 
@@ -98,6 +98,13 @@ class TableVersionPath:
98
98
  return None
99
99
  return self.base.find_tbl_version(id)
100
100
 
101
+ @property
102
+ def ancestor_paths(self) -> list[TableVersionPath]:
103
+ if self.base is None:
104
+ return [self]
105
+ else:
106
+ return [self, *self.base.ancestor_paths]
107
+
101
108
  def columns(self) -> list[Column]:
102
109
  """Return all user columns visible in this tbl version path, including columns from bases"""
103
110
  result = list(self.tbl_version.get().cols_by_name.values())
@@ -267,3 +267,34 @@ class View(Table):
267
267
 
268
268
  def delete(self, where: Optional[exprs.Expr] = None) -> UpdateStatus:
269
269
  raise excs.Error(f'{self._display_name()} {self._name!r}: cannot delete from view')
270
+
271
+ @property
272
+ def _base_table(self) -> Optional['Table']:
273
+ # if this is a pure snapshot, our tbl_version_path only reflects the base (there is no TableVersion instance
274
+ # for the snapshot itself)
275
+ base_id = self._tbl_version.id if self._snapshot_only else self._tbl_version_path.base.tbl_version.id
276
+ return catalog.Catalog.get().get_table_by_id(base_id)
277
+
278
+ @property
279
+ def _effective_base_versions(self) -> list[Optional[int]]:
280
+ effective_versions = [tv.effective_version for tv in self._tbl_version_path.get_tbl_versions()]
281
+ if self._snapshot_only:
282
+ return effective_versions
283
+ else:
284
+ return effective_versions[1:]
285
+
286
+ def _table_descriptor(self) -> str:
287
+ display_name = 'Snapshot' if self._snapshot_only else 'View'
288
+ result = [f'{display_name} {self._path!r}']
289
+ bases_descrs: list[str] = []
290
+ for base, effective_version in zip(self._base_tables, self._effective_base_versions):
291
+ if effective_version is None:
292
+ bases_descrs.append(f'{base._path!r}')
293
+ else:
294
+ base_descr = f'{base._path}:{effective_version}'
295
+ bases_descrs.append(f'{base_descr!r}')
296
+ result.append(f' (of {", ".join(bases_descrs)})')
297
+
298
+ if self._tbl_version.get().predicate is not None:
299
+ result.append(f'\nWhere: {self._tbl_version.get().predicate!s}')
300
+ return ''.join(result)
pixeltable/dataframe.py CHANGED
@@ -322,6 +322,8 @@ class DataFrame:
322
322
  raise excs.Error('head() cannot be used with order_by()')
323
323
  if self._has_joins():
324
324
  raise excs.Error('head() not supported for joins')
325
+ if self.group_by_clause is not None:
326
+ raise excs.Error('head() cannot be used with group_by()')
325
327
  num_rowid_cols = len(self._first_tbl.tbl_version.get().store_tbl.rowid_columns())
326
328
  order_by_clause = [exprs.RowidRef(self._first_tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
327
329
  return self.order_by(*order_by_clause, asc=True).limit(n).collect()
@@ -345,6 +347,8 @@ class DataFrame:
345
347
  raise excs.Error('tail() cannot be used with order_by()')
346
348
  if self._has_joins():
347
349
  raise excs.Error('tail() not supported for joins')
350
+ if self.group_by_clause is not None:
351
+ raise excs.Error('tail() cannot be used with group_by()')
348
352
  num_rowid_cols = len(self._first_tbl.tbl_version.get().store_tbl.rowid_columns())
349
353
  order_by_clause = [exprs.RowidRef(self._first_tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
350
354
  result = self.order_by(*order_by_clause, asc=False).limit(n).collect()
@@ -454,6 +458,9 @@ class DataFrame:
454
458
  Returns:
455
459
  The number of rows in the DataFrame.
456
460
  """
461
+ if self.group_by_clause is not None:
462
+ raise excs.Error('count() cannot be used with group_by()')
463
+
457
464
  from pixeltable.plan import Planner
458
465
 
459
466
  stmt = Planner.create_count_stmt(self._first_tbl, self.where_clause)
@@ -513,9 +520,9 @@ class DataFrame:
513
520
  (select list, where clause, ...) vertically.
514
521
  """
515
522
  if getattr(builtins, '__IPYTHON__', False):
516
- from IPython.display import display
523
+ from IPython.display import Markdown, display
517
524
 
518
- display(self._repr_html_())
525
+ display(Markdown(self._repr_html_()))
519
526
  else:
520
527
  print(repr(self))
521
528
 
@@ -573,10 +580,21 @@ class DataFrame:
573
580
  raise excs.Error(f'Invalid expression: {raw_expr}')
574
581
  if expr.col_type.is_invalid_type() and not (isinstance(expr, exprs.Literal) and expr.val is None):
575
582
  raise excs.Error(f'Invalid type: {raw_expr}')
583
+ if len(self._from_clause.tbls) == 1:
584
+ # Select expressions need to be retargeted in order to handle snapshots correctly, as in expressions
585
+ # such as `snapshot.select(base_tbl.col)`
586
+ # TODO: For joins involving snapshots, we need a more sophisticated retarget() that can handle
587
+ # multiple TableVersionPaths.
588
+ expr = expr.copy()
589
+ try:
590
+ expr.retarget(self._from_clause.tbls[0])
591
+ except Exception:
592
+ # If retarget() fails, then the succeeding is_bound_by() will raise an error.
593
+ pass
576
594
  if not expr.is_bound_by(self._from_clause.tbls):
577
595
  raise excs.Error(
578
596
  f"Expression '{expr}' cannot be evaluated in the context of this query's tables "
579
- f'({",".join(tbl.tbl_name() for tbl in self._from_clause.tbls)})'
597
+ f'({",".join(tbl.tbl_version.get().versioned_name for tbl in self._from_clause.tbls)})'
580
598
  )
581
599
  select_list.append((expr, name))
582
600
 
@@ -823,16 +841,18 @@ class DataFrame:
823
841
  grouping_tbl: Optional[catalog.TableVersion] = None
824
842
  group_by_clause: Optional[list[exprs.Expr]] = None
825
843
  for item in grouping_items:
826
- if isinstance(item, catalog.Table):
844
+ if isinstance(item, (catalog.Table, catalog.TableVersion)):
827
845
  if len(grouping_items) > 1:
828
846
  raise excs.Error('group_by(): only one table can be specified')
829
847
  if len(self._from_clause.tbls) > 1:
830
848
  raise excs.Error('group_by() with Table not supported for joins')
849
+ grouping_tbl = item if isinstance(item, catalog.TableVersion) else item._tbl_version.get()
831
850
  # we need to make sure that the grouping table is a base of self.tbl
832
- base = self._first_tbl.find_tbl_version(item._tbl_version_path.tbl_id())
851
+ base = self._first_tbl.find_tbl_version(grouping_tbl.id)
833
852
  if base is None or base.id == self._first_tbl.tbl_id():
834
- raise excs.Error(f'group_by(): {item._name} is not a base table of {self._first_tbl.tbl_name()}')
835
- grouping_tbl = item._tbl_version_path.tbl_version.get()
853
+ raise excs.Error(
854
+ f'group_by(): {grouping_tbl.name} is not a base table of {self._first_tbl.tbl_name()}'
855
+ )
836
856
  break
837
857
  if not isinstance(item, exprs.Expr):
838
858
  raise excs.Error(f'Invalid expression in group_by(): {item}')
@@ -848,6 +868,29 @@ class DataFrame:
848
868
  limit=self.limit_val,
849
869
  )
850
870
 
871
+ def distinct(self) -> DataFrame:
872
+ """
873
+ Remove duplicate rows from this DataFrame.
874
+
875
+ Note that grouping will be applied to the rows based on the select clause of this Dataframe.
876
+ In the absence of a select clause, by default, all columns are selected in the grouping.
877
+
878
+ Examples:
879
+ Select unique addresses from table `addresses`.
880
+
881
+ >>> results = addresses.distinct()
882
+
883
+ Select unique cities in table `addresses`
884
+
885
+ >>> results = addresses.city.distinct()
886
+
887
+ Select unique locations (street, city) in the state of `CA`
888
+
889
+ >>> results = addresses.select(addresses.street, addresses.city).where(addresses.state == 'CA').distinct()
890
+ """
891
+ exps, _ = self._normalize_select_list(self._from_clause.tbls, self.select_list)
892
+ return self.group_by(*exps)
893
+
851
894
  def order_by(self, *expr_list: exprs.Expr, asc: bool = True) -> DataFrame:
852
895
  """Add an order-by clause to this DataFrame.
853
896
 
pixeltable/env.py CHANGED
@@ -610,7 +610,7 @@ class Env:
610
610
  self.__register_package('datasets')
611
611
  self.__register_package('fiftyone')
612
612
  self.__register_package('fireworks', library_name='fireworks-ai')
613
- self.__register_package('google.generativeai', library_name='google-generativeai')
613
+ self.__register_package('google.genai', library_name='google-genai')
614
614
  self.__register_package('huggingface_hub', library_name='huggingface-hub')
615
615
  self.__register_package('label_studio_sdk', library_name='label-studio-sdk')
616
616
  self.__register_package('llama_cpp', library_name='llama-cpp-python')
pixeltable/exceptions.py CHANGED
@@ -1,4 +1,3 @@
1
- from dataclasses import dataclass
2
1
  from types import TracebackType
3
2
  from typing import TYPE_CHECKING, Any
4
3
 
@@ -10,7 +9,6 @@ class Error(Exception):
10
9
  pass
11
10
 
12
11
 
13
- @dataclass
14
12
  class ExprEvalError(Exception):
15
13
  expr: 'exprs.Expr'
16
14
  expr_msg: str
@@ -19,6 +17,26 @@ class ExprEvalError(Exception):
19
17
  input_vals: list[Any]
20
18
  row_num: int
21
19
 
20
+ def __init__(
21
+ self,
22
+ expr: 'exprs.Expr',
23
+ expr_msg: str,
24
+ exc: Exception,
25
+ exc_tb: TracebackType,
26
+ input_vals: list[Any],
27
+ row_num: int,
28
+ ) -> None:
29
+ exct = type(exc)
30
+ super().__init__(
31
+ f'Expression evaluation failed with an error of type `{exct.__module__}.{exct.__qualname__}`:\n{expr}'
32
+ )
33
+ self.expr = expr
34
+ self.expr_msg = expr_msg
35
+ self.exc = exc
36
+ self.exc_tb = exc_tb
37
+ self.input_vals = input_vals
38
+ self.row_num = row_num
39
+
22
40
 
23
41
  class PixeltableWarning(Warning):
24
42
  pass
@@ -24,6 +24,7 @@ class AggregationNode(ExecNode):
24
24
  agg_fn_eval_ctx: exprs.RowBuilder.EvalCtx
25
25
  agg_fn_calls: list[exprs.FunctionCall]
26
26
  output_batch: DataRowBatch
27
+ limit: Optional[int]
27
28
 
28
29
  def __init__(
29
30
  self,
@@ -45,6 +46,11 @@ class AggregationNode(ExecNode):
45
46
  self.agg_fn_calls = [cast(exprs.FunctionCall, e) for e in self.agg_fn_eval_ctx.target_exprs]
46
47
  # create output_batch here, rather than in __iter__(), so we don't need to remember tbl and row_builder
47
48
  self.output_batch = DataRowBatch(tbl, row_builder, 0)
49
+ self.limit = None
50
+
51
+ def set_limit(self, limit: int) -> None:
52
+ # we can't propagate the limit to our input
53
+ self.limit = limit
48
54
 
49
55
  def _reset_agg_state(self, row_num: int) -> None:
50
56
  for fn_call in self.agg_fn_calls:
@@ -69,21 +75,29 @@ class AggregationNode(ExecNode):
69
75
  prev_row: Optional[exprs.DataRow] = None
70
76
  current_group: Optional[list[Any]] = None # the values of the group-by exprs
71
77
  num_input_rows = 0
78
+ num_output_rows = 0
72
79
  async for row_batch in self.input:
73
80
  num_input_rows += len(row_batch)
74
81
  for row in row_batch:
75
82
  group = [row[e.slot_idx] for e in self.group_by] if self.group_by is not None else None
83
+
76
84
  if current_group is None:
77
85
  current_group = group
78
86
  self._reset_agg_state(0)
87
+
79
88
  if group != current_group:
80
89
  # we're entering a new group, emit a row for the previous one
81
90
  self.row_builder.eval(prev_row, self.agg_fn_eval_ctx, profile=self.ctx.profile)
82
91
  self.output_batch.add_row(prev_row)
92
+ num_output_rows += 1
93
+ if self.limit is not None and num_output_rows == self.limit:
94
+ yield self.output_batch
95
+ return
83
96
  current_group = group
84
97
  self._reset_agg_state(0)
85
98
  self._update_agg_state(row, 0)
86
99
  prev_row = row
100
+
87
101
  if prev_row is not None:
88
102
  # emit the last group
89
103
  self.row_builder.eval(prev_row, self.agg_fn_eval_ctx, profile=self.ctx.profile)
@@ -167,7 +167,7 @@ class CachePrefetchNode(ExecNode):
167
167
  assert not self.input_finished
168
168
  input_batch: Optional[DataRowBatch]
169
169
  try:
170
- input_batch = await input.__anext__()
170
+ input_batch = await anext(input)
171
171
  except StopAsyncIteration:
172
172
  input_batch = None
173
173
  if input_batch is None:
@@ -208,10 +208,6 @@ class FnCallEvaluator(Evaluator):
208
208
  _logger.debug(f'Evaluated slot {self.fn_call.slot_idx} in {end_ts - start_ts}')
209
209
  self.dispatcher.dispatch([call_args.row], self.exec_ctx)
210
210
  except Exception as exc:
211
- import anthropic
212
-
213
- if isinstance(exc, anthropic.RateLimitError):
214
- _logger.debug(f'RateLimitError: {exc}')
215
211
  _, _, exc_tb = sys.exc_info()
216
212
  call_args.row.set_exc(self.fn_call.slot_idx, exc)
217
213
  self.dispatcher.dispatch_exc(call_args.rows, self.fn_call.slot_idx, exc_tb, self.exec_ctx)
@@ -115,7 +115,7 @@ class ExprEvalNode(ExecNode):
115
115
  """
116
116
  assert not self.input_complete
117
117
  try:
118
- batch = await self.input_iter.__anext__()
118
+ batch = await anext(self.input_iter)
119
119
  assert self.next_input_batch is None
120
120
  if self.current_input_batch is None:
121
121
  self.current_input_batch = batch
@@ -282,7 +282,6 @@ class ExprEvalNode(ExecNode):
282
282
 
283
283
  if self.exc_event.is_set():
284
284
  # we got an exception that we need to propagate through __iter__()
285
- _logger.debug(f'Propagating exception {self.error}')
286
285
  if isinstance(self.error, excs.ExprEvalError):
287
286
  raise self.error from self.error.exc
288
287
  else: