pixeltable 0.4.0rc2__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.

Potentially problematic release: this version of pixeltable might be problematic.

Files changed (59)
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +9 -1
  4. pixeltable/catalog/catalog.py +333 -99
  5. pixeltable/catalog/column.py +28 -26
  6. pixeltable/catalog/globals.py +12 -0
  7. pixeltable/catalog/insertable_table.py +8 -8
  8. pixeltable/catalog/schema_object.py +6 -0
  9. pixeltable/catalog/table.py +111 -116
  10. pixeltable/catalog/table_version.py +36 -50
  11. pixeltable/catalog/table_version_handle.py +4 -1
  12. pixeltable/catalog/table_version_path.py +28 -4
  13. pixeltable/catalog/view.py +10 -18
  14. pixeltable/config.py +4 -0
  15. pixeltable/dataframe.py +10 -9
  16. pixeltable/env.py +5 -11
  17. pixeltable/exceptions.py +6 -0
  18. pixeltable/exec/exec_node.py +2 -0
  19. pixeltable/exec/expr_eval/expr_eval_node.py +4 -4
  20. pixeltable/exec/sql_node.py +47 -30
  21. pixeltable/exprs/column_property_ref.py +2 -1
  22. pixeltable/exprs/column_ref.py +7 -6
  23. pixeltable/exprs/expr.py +4 -4
  24. pixeltable/func/__init__.py +1 -0
  25. pixeltable/func/mcp.py +74 -0
  26. pixeltable/func/query_template_function.py +4 -2
  27. pixeltable/func/tools.py +12 -2
  28. pixeltable/func/udf.py +2 -2
  29. pixeltable/functions/__init__.py +1 -0
  30. pixeltable/functions/anthropic.py +19 -45
  31. pixeltable/functions/deepseek.py +19 -38
  32. pixeltable/functions/fireworks.py +9 -18
  33. pixeltable/functions/gemini.py +2 -2
  34. pixeltable/functions/groq.py +108 -0
  35. pixeltable/functions/huggingface.py +8 -6
  36. pixeltable/functions/llama_cpp.py +6 -6
  37. pixeltable/functions/mistralai.py +16 -53
  38. pixeltable/functions/ollama.py +1 -1
  39. pixeltable/functions/openai.py +82 -170
  40. pixeltable/functions/replicate.py +2 -2
  41. pixeltable/functions/together.py +22 -80
  42. pixeltable/functions/util.py +6 -1
  43. pixeltable/globals.py +0 -2
  44. pixeltable/io/external_store.py +2 -2
  45. pixeltable/io/label_studio.py +4 -4
  46. pixeltable/io/table_data_conduit.py +1 -1
  47. pixeltable/metadata/__init__.py +1 -1
  48. pixeltable/metadata/converters/convert_37.py +15 -0
  49. pixeltable/metadata/notes.py +1 -0
  50. pixeltable/metadata/schema.py +5 -0
  51. pixeltable/plan.py +37 -121
  52. pixeltable/share/packager.py +2 -2
  53. pixeltable/type_system.py +30 -0
  54. {pixeltable-0.4.0rc2.dist-info → pixeltable-0.4.1.dist-info}/METADATA +1 -1
  55. {pixeltable-0.4.0rc2.dist-info → pixeltable-0.4.1.dist-info}/RECORD +58 -56
  56. pixeltable/utils/sample.py +0 -25
  57. {pixeltable-0.4.0rc2.dist-info → pixeltable-0.4.1.dist-info}/LICENSE +0 -0
  58. {pixeltable-0.4.0rc2.dist-info → pixeltable-0.4.1.dist-info}/WHEEL +0 -0
  59. {pixeltable-0.4.0rc2.dist-info → pixeltable-0.4.1.dist-info}/entry_points.txt +0 -0
pixeltable/catalog/table_version.py CHANGED
@@ -167,18 +167,6 @@ class TableVersion:
         self.idxs_by_name = {}
         self.external_stores = {}

-    def init(self) -> None:
-        """
-        Initialize schema-related in-memory metadata separately, now that this TableVersion instance is visible
-        in Catalog.
-        """
-        from .catalog import Catalog
-
-        assert (self.id, self.effective_version) in Catalog.get()._tbl_versions
-        self._init_schema()
-        # init external stores; this needs to happen after the schema is created
-        self._init_external_stores()
-
     def __hash__(self) -> int:
         return hash(self.id)

@@ -234,6 +222,7 @@ class TableVersion:
             next_col_id=len(cols),
             next_idx_id=0,
             next_row_id=0,
+            view_sn=0,
             column_md=column_md,
             index_md={},
             external_stores=[],
@@ -342,24 +331,39 @@ class TableVersion:
         return tbl_version

     def drop(self) -> None:
-        from .catalog import Catalog
-
         if self.is_view and self.is_mutable:
             # update mutable_views
+            # TODO: invalidate base to force reload
             from .table_version_handle import TableVersionHandle

             assert self.base is not None
             if self.base.get().is_mutable:
                 self.base.get().mutable_views.remove(TableVersionHandle.create(self))

-        cat = Catalog.get()
+        # cat = Catalog.get()
         # delete this table and all associated data
         MediaStore.delete(self.id)
         FileCache.get().clear(tbl_id=self.id)
-        cat.delete_tbl_md(self.id)
+        # cat.delete_tbl_md(self.id)
         self.store_tbl.drop()
         # de-register table version from catalog
-        cat.remove_tbl_version(self)
+        # cat.remove_tbl_version(self)
+
+    def init(self) -> None:
+        """
+        Initialize schema-related in-memory metadata separately, now that this TableVersion instance is visible
+        in Catalog.
+        """
+        from .catalog import Catalog
+
+        cat = Catalog.get()
+        assert (self.id, self.effective_version) in cat._tbl_versions
+        self._init_schema()
+        if not self.is_snapshot:
+            cat.record_column_dependencies(self)
+
+        # init external stores; this needs to happen after the schema is created
+        self._init_external_stores()

     def _init_schema(self) -> None:
         # create columns first, so the indices can reference them
@@ -369,6 +373,10 @@
         # create the sa schema only after creating the columns and indices
         self._init_sa_schema()

+        # create value_exprs after everything else has been initialized
+        for col in self.cols_by_id.values():
+            col.init_value_expr()
+
     def _init_cols(self) -> None:
         """Initialize self.cols with the columns visible in our effective version"""
         self.cols = []
@@ -395,6 +403,7 @@
                 schema_version_add=col_md.schema_version_add,
                 schema_version_drop=col_md.schema_version_drop,
                 value_expr_dict=col_md.value_expr,
+                tbl=self,
             )
             col.tbl = self
             self.cols.append(col)
@@ -410,10 +419,10 @@
                 self.cols_by_name[col.name] = col
             self.cols_by_id[col.id] = col

-            # make sure to traverse columns ordered by position = order in which cols were created;
-            # this guarantees that references always point backwards
-            if not self.is_snapshot and col_md.value_expr is not None:
-                self._record_refd_columns(col)
+            # # make sure to traverse columns ordered by position = order in which cols were created;
+            # # this guarantees that references always point backwards
+            # if not self.is_snapshot and col_md.value_expr is not None:
+            #     self._record_refd_columns(col)

     def _init_idxs(self) -> None:
         # self.idx_md = tbl_md.index_md
@@ -482,11 +491,6 @@
             self.id, self._tbl_md, version_md, self._schema_version_md if new_schema_version else None
         )

-    def ensure_md_loaded(self) -> None:
-        """Ensure that table metadata is loaded."""
-        for col in self.cols_by_id.values():
-            _ = col.value_expr
-
     def _store_idx_name(self, idx_id: int) -> str:
         """Return name of index in the store, which needs to be globally unique"""
         return f'idx_{self.id.hex}_{idx_id}'
@@ -700,9 +704,6 @@
         if col.name is not None:
             self.cols_by_name[col.name] = col
         self.cols_by_id[col.id] = col
-        if col.value_expr is not None:
-            col.check_value_expr()
-            self._record_refd_columns(col)

         # also add to stored md
         self._tbl_md.column_md[col.id] = schema.ColumnMd(
@@ -760,9 +761,11 @@
             run_cleanup_on_exception(cleanup_on_error)
             plan.close()

+        pxt.catalog.Catalog.get().record_column_dependencies(self)
+
         if print_stats:
             plan.ctx.profile.print(num_rows=row_count)
-        # TODO(mkornacker): what to do about system columns with exceptions?
+        # TODO: what to do about system columns with exceptions?
         return UpdateStatus(
             num_rows=row_count,
             num_computed_values=row_count,
@@ -805,13 +808,6 @@
         assert not self.is_snapshot

         for col in cols:
-            if col.value_expr is not None:
-                # update Column.dependent_cols
-                for c in self.cols:
-                    if c == col:
-                        break
-                    c.dependent_cols.discard(col)
-
             col.schema_version_drop = self.schema_version
             if col.name is not None:
                 assert col.name in self.cols_by_name
@@ -828,6 +824,7 @@
             schema_col.pos = pos

         self.store_tbl.create_sa_tbl()
+        pxt.catalog.Catalog.get().record_column_dependencies(self)

     def rename_column(self, old_name: str, new_name: str) -> None:
         """Rename a column."""
@@ -1458,18 +1455,6 @@
         names = [c.name for c in self.cols_by_name.values() if c.is_computed]
         return names

-    def _record_refd_columns(self, col: Column) -> None:
-        """Update Column.dependent_cols for all cols referenced in col.value_expr."""
-        from pixeltable import exprs
-
-        if col.value_expr_dict is not None:
-            # if we have a value_expr_dict, use that instead of instantiating the value_expr
-            refd_cols = exprs.Expr.get_refd_columns(col.value_expr_dict)
-        else:
-            refd_cols = [e.col for e in col.value_expr.subexprs(expr_class=exprs.ColumnRef)]
-        for refd_col in refd_cols:
-            refd_col.dependent_cols.add(col)
-
     def get_idx_val_columns(self, cols: Iterable[Column]) -> set[Column]:
         result = {info.val_col for col in cols for info in col.get_idx_info().values()}
         return result
@@ -1478,7 +1463,8 @@
         """
         Return the set of columns that transitively depend on any of the given ones.
         """
-        result = {dependent_col for col in cols for dependent_col in col.dependent_cols}
+        cat = pxt.catalog.Catalog.get()
+        result = set().union(*[cat.get_column_dependents(col.tbl.id, col.id) for col in cols])
        if len(result) > 0:
             result.update(self.get_dependent_columns(result))
         return result
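
The union above collects one level of dependents; the surrounding recursion yields the transitive closure of the dependency relation. A minimal standalone sketch of the same computation, iterative rather than recursive (`direct_dependents` is an illustrative stand-in for the catalog lookup, not pixeltable's API):

from typing import Callable, Iterable, Set

def dependent_columns(cols: Iterable[str], direct_dependents: Callable[[str], Set[str]]) -> Set[str]:
    # breadth-first expansion until no new dependents are discovered
    result: Set[str] = set()
    frontier = set(cols)
    while frontier:
        new = {d for c in frontier for d in direct_dependents(c)} - result
        result |= new
        frontier = new
    return result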
pixeltable/catalog/table_version_handle.py CHANGED
@@ -34,6 +34,10 @@ class TableVersionHandle:
     def __hash__(self) -> int:
         return hash((self.id, self.effective_version))

+    @property
+    def is_snapshot(self) -> bool:
+        return self.effective_version is not None
+
     @classmethod
     def create(cls, tbl_version: TableVersion) -> TableVersionHandle:
         return cls(tbl_version.id, tbl_version.effective_version, tbl_version)
@@ -53,7 +57,6 @@
         else:
             self._tbl_version = Catalog.get().get_tbl_version(self.id, self.effective_version)
             if self.effective_version is None:
-                # make sure we don't see a discarded instance of a live TableVersion
                 tvs = list(Catalog.get()._tbl_versions.values())
                 assert self._tbl_version in tvs
         return self._tbl_version
pixeltable/catalog/table_version_path.py CHANGED
@@ -8,6 +8,7 @@ from pixeltable.env import Env
 from pixeltable.metadata import schema

 from .column import Column
+from .globals import MediaValidation
 from .table_version import TableVersion
 from .table_version_handle import TableVersionHandle

@@ -83,6 +84,7 @@ class TableVersionPath:
         if self.base is not None:
             self.base.clear_cached_md()

+    @property
     def tbl_id(self) -> UUID:
         """Return the id of the table/view that this path represents"""
         return self.tbl_version.id
@@ -92,6 +94,11 @@
         self.refresh_cached_md()
         return self._cached_tbl_version.version

+    def schema_version(self) -> int:
+        """Return the schema version of the table/view that this path represents"""
+        self.refresh_cached_md()
+        return self._cached_tbl_version.schema_version
+
     def tbl_name(self) -> str:
         """Return the name of the table/view that this path represents"""
         self.refresh_cached_md()
@@ -103,10 +110,7 @@

     def is_snapshot(self) -> bool:
         """Return True if this is a path of snapshot versions"""
-        self.refresh_cached_md()
-        if not self._cached_tbl_version.is_snapshot:
-            return False
-        return self.base.is_snapshot() if self.base is not None else True
+        return self.tbl_version.is_snapshot

     def is_view(self) -> bool:
         self.refresh_cached_md()
@@ -116,10 +120,30 @@
         self.refresh_cached_md()
         return self._cached_tbl_version.is_component_view

+    def is_replica(self) -> bool:
+        self.refresh_cached_md()
+        return self._cached_tbl_version.is_replica
+
+    def is_mutable(self) -> bool:
+        self.refresh_cached_md()
+        return self._cached_tbl_version.is_mutable
+
     def is_insertable(self) -> bool:
         self.refresh_cached_md()
         return self._cached_tbl_version.is_insertable

+    def comment(self) -> str:
+        self.refresh_cached_md()
+        return self._cached_tbl_version.comment
+
+    def num_retained_versions(self) -> int:
+        self.refresh_cached_md()
+        return self._cached_tbl_version.num_retained_versions
+
+    def media_validation(self) -> MediaValidation:
+        self.refresh_cached_md()
+        return self._cached_tbl_version.media_validation
+
     def get_tbl_versions(self) -> list[TableVersionHandle]:
         """Return all tbl versions"""
         if self.base is None:
pixeltable/catalog/view.py CHANGED
@@ -41,6 +41,8 @@ class View(Table):
     def __init__(self, id: UUID, dir_id: UUID, name: str, tbl_version_path: TableVersionPath, snapshot_only: bool):
         super().__init__(id, dir_id, name, tbl_version_path)
         self._snapshot_only = snapshot_only
+        if not snapshot_only:
+            self._tbl_version = tbl_version_path.tbl_version

     @classmethod
     def _display_name(cls) -> str:
@@ -267,17 +269,8 @@
             base=cls._get_snapshot_path(tbl_version_path.base) if tbl_version_path.base is not None else None,
         )

-    def _drop(self) -> None:
-        if self._snapshot_only:
-            # there is not TableVersion to drop
-            self._check_is_dropped()
-            self.is_dropped = True
-            catalog.Catalog.get().delete_tbl_md(self._id)
-        else:
-            super()._drop()
-
-    def get_metadata(self) -> dict[str, Any]:
-        md = super().get_metadata()
+    def _get_metadata(self) -> dict[str, Any]:
+        md = super()._get_metadata()
         md['is_view'] = True
         md['is_snapshot'] = self._tbl_version_path.is_snapshot()
         return md
@@ -298,11 +291,10 @@
     def delete(self, where: Optional[exprs.Expr] = None) -> UpdateStatus:
         raise excs.Error(f'{self._display_name()} {self._name!r}: cannot delete from view')

-    @property
-    def _base_table(self) -> Optional['Table']:
+    def _get_base_table(self) -> Optional['Table']:
         # if this is a pure snapshot, our tbl_version_path only reflects the base (there is no TableVersion instance
         # for the snapshot itself)
-        base_id = self._tbl_version.id if self._snapshot_only else self._tbl_version_path.base.tbl_version.id
+        base_id = self._tbl_version_path.tbl_id if self._snapshot_only else self._tbl_version_path.base.tbl_id
         return catalog.Catalog.get().get_table_by_id(base_id)

     @property
@@ -317,7 +309,7 @@
         display_name = 'Snapshot' if self._snapshot_only else 'View'
         result = [f'{display_name} {self._path()!r}']
         bases_descrs: list[str] = []
-        for base, effective_version in zip(self._base_tables, self._effective_base_versions):
+        for base, effective_version in zip(self._get_base_tables(), self._effective_base_versions):
             if effective_version is None:
                 bases_descrs.append(f'{base._path()!r}')
             else:
@@ -325,8 +317,8 @@
                 bases_descrs.append(f'{base_descr!r}')
         result.append(f' (of {", ".join(bases_descrs)})')

-        if self._tbl_version.get().predicate is not None:
-            result.append(f'\nWhere: {self._tbl_version.get().predicate!s}')
-        if self._tbl_version.get().sample_clause is not None:
+        if self._tbl_version_path.tbl_version.get().predicate is not None:
+            result.append(f'\nWhere: {self._tbl_version_path.tbl_version.get().predicate!s}')
+        if self._tbl_version_path.tbl_version.get().sample_clause is not None:
             result.append(f'\nSample: {self._tbl_version.get().sample_clause!s}')
         return ''.join(result)
pixeltable/config.py CHANGED
@@ -86,6 +86,10 @@ class Config:
             return None

         try:
+            if expected_type is bool and isinstance(value, str):
+                if value.lower() not in ('true', 'false'):
+                    raise excs.Error(f'Invalid value for configuration parameter {section}.{key}: {value}')
+                return value.lower() == 'true'  # type: ignore[return-value]
             return expected_type(value)  # type: ignore[call-arg]
         except ValueError as exc:
             raise excs.Error(f'Invalid value for configuration parameter {section}.{key}: {value}') from exc
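
The special case exists because calling `bool()` on a string tests truthiness, not meaning: every non-empty string, including 'false', is truthy. A standalone sketch of the parsing rule the hunk introduces:

def parse_bool(value: str) -> bool:
    # bool('false') is True, so boolean config values must be compared
    # against the literal strings 'true'/'false' instead
    v = value.lower()
    if v not in ('true', 'false'):
        raise ValueError(f'not a boolean: {value!r}')
    return v == 'true'

assert bool('false') is True       # the trap the hunk above avoids
assert parse_bool('False') is False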
pixeltable/dataframe.py CHANGED
@@ -475,7 +475,9 @@
             raise excs.Error(msg) from e

     def _output_row_iterator(self) -> Iterator[list]:
-        with Catalog.get().begin_xact(for_write=False):
+        # TODO: extend begin_xact() to accept multiple TVPs for joins
+        single_tbl = self._first_tbl if len(self._from_clause.tbls) == 1 else None
+        with Catalog.get().begin_xact(tbl=single_tbl, for_write=False):
             try:
                 for data_row in self._exec():
                     yield [data_row[e.slot_idx] for e in self._select_list_exprs]
@@ -507,7 +509,7 @@

         from pixeltable.plan import Planner

-        with Catalog.get().begin_xact(for_write=False) as conn:
+        with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=False) as conn:
             stmt = Planner.create_count_stmt(self._first_tbl, self.where_clause)
             result: int = conn.execute(stmt).scalar_one()
             assert isinstance(result, int)
@@ -903,7 +905,7 @@
             grouping_tbl = item if isinstance(item, catalog.TableVersion) else item._tbl_version.get()
             # we need to make sure that the grouping table is a base of self.tbl
             base = self._first_tbl.find_tbl_version(grouping_tbl.id)
-            if base is None or base.id == self._first_tbl.tbl_id():
+            if base is None or base.id == self._first_tbl.tbl_id:
                 raise excs.Error(
                     f'group_by(): {grouping_tbl.name} is not a base table of {self._first_tbl.tbl_name()}'
                 )
@@ -1161,8 +1163,7 @@
         >>> df = person.where(t.year == 2014).update({'age': 30})
         """
         self._validate_mutable('update', False)
-        tbl_id = self._first_tbl.tbl_id()
-        with Catalog.get().begin_xact(tbl_id=tbl_id, for_write=True):
+        with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=True, lock_mutable_tree=True):
             return self._first_tbl.tbl_version.get().update(value_spec, where=self.where_clause, cascade=cascade)

     def delete(self) -> UpdateStatus:
@@ -1185,8 +1186,7 @@
         self._validate_mutable('delete', False)
         if not self._first_tbl.is_insertable():
             raise excs.Error('Cannot delete from view')
-        tbl_id = self._first_tbl.tbl_id()
-        with Catalog.get().begin_xact(tbl_id=tbl_id, for_write=True):
+        with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=True, lock_mutable_tree=True):
             return self._first_tbl.tbl_version.get().delete(where=self.where_clause)

     def _validate_mutable(self, op_name: str, allow_select: bool) -> None:
@@ -1307,7 +1307,8 @@
             assert data_file_path.is_file()
             return data_file_path
         else:
-            with Catalog.get().begin_xact(for_write=False):
+            # TODO: extend begin_xact() to accept multiple TVPs for joins
+            with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=False):
                 return write_coco_dataset(self, dest_path)

     def to_pytorch_dataset(self, image_format: str = 'pt') -> 'torch.utils.data.IterableDataset':
@@ -1352,7 +1353,7 @@
         if dest_path.exists():  # fast path: use cache
             assert dest_path.is_dir()
         else:
-            with Catalog.get().begin_xact(for_write=False):
+            with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=False):
                 export_parquet(self, dest_path, inline_images=True)

         return PixeltablePytorchDataset(path=dest_path, image_format=image_format)
pixeltable/env.py CHANGED
@@ -10,7 +10,6 @@ import logging
 import os
 import platform
 import shutil
-import subprocess
 import sys
 import threading
 import uuid
@@ -611,9 +610,11 @@
         self.__register_package('fiftyone')
         self.__register_package('fireworks', library_name='fireworks-ai')
         self.__register_package('google.genai', library_name='google-genai')
+        self.__register_package('groq')
         self.__register_package('huggingface_hub', library_name='huggingface-hub')
         self.__register_package('label_studio_sdk', library_name='label-studio-sdk')
         self.__register_package('llama_cpp', library_name='llama-cpp-python')
+        self.__register_package('mcp')
         self.__register_package('mistralai')
         self.__register_package('mistune')
         self.__register_package('ollama')
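
The two new registrations correspond to the new functions/groq.py and func/mcp.py modules. A minimal sketch of how this style of optional-dependency registration typically works (`register_package` below is illustrative; pixeltable's `__register_package` is internal and may differ):

import importlib.util
from typing import Optional

_available: dict[str, bool] = {}

def register_package(module_name: str, library_name: Optional[str] = None) -> None:
    # library_name is the pip distribution name when it differs from the import
    # name, e.g. module 'llama_cpp' ships as the 'llama-cpp-python' distribution
    dist = library_name or module_name
    _available[dist] = importlib.util.find_spec(module_name) is not None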
@@ -746,18 +747,11 @@
         have no sub-dependencies (in fact, this is how spaCy normally manages its model resources).
         """
         import spacy
-        from spacy.cli.download import get_model_filename
+        from spacy.cli.download import download

         spacy_model = 'en_core_web_sm'
-        spacy_model_version = '3.7.1'
-        filename = get_model_filename(spacy_model, spacy_model_version, sdist=False)
-        url = f'{spacy.about.__download_url__}/{filename}'
-        # Try to `pip install` the model. We set check=False; if the pip command fails, it's not necessarily
-        # a problem, because the model might have been installed on a previous attempt.
-        self._logger.info(f'Ensuring spaCy model is installed: {filename}')
-        ret = subprocess.run([sys.executable, '-m', 'pip', 'install', '-qU', url], check=False)
-        if ret.returncode != 0:
-            self._logger.warning(f'pip install failed for spaCy model: {filename}')
+        self._logger.info(f'Ensuring spaCy model is installed: {spacy_model}')
+        download(spacy_model)
         self._logger.info(f'Loading spaCy model: {spacy_model}')
         try:
             self._spacy_nlp = spacy.load(spacy_model)
pixeltable/exceptions.py CHANGED
@@ -10,6 +10,12 @@ class Error(Exception):


 class ExprEvalError(Exception):
+    """
+    Used during query execution to signal expr evaluation failures.
+
+    NOT A USER-FACING EXCEPTION. All ExprEvalError instances need to be converted into Error instances.
+    """
+
     expr: 'exprs.Expr'
     expr_msg: str
     exc: Exception
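
The new docstring describes a boundary pattern: failures travel through the executor as ExprEvalError and are translated into the user-facing Error before reaching the caller. A hedged sketch of that pattern (the handler is illustrative, not pixeltable's actual code):

from pixeltable.exceptions import Error, ExprEvalError

def run_plan(plan) -> None:
    try:
        plan.execute()  # hypothetical entry point that may raise ExprEvalError
    except ExprEvalError as e:
        # surface the failure as the public exception type, keeping the original cause
        raise Error(f'error evaluating {e.expr_msg}') from e.exc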
pixeltable/exec/exec_node.py CHANGED
@@ -73,6 +73,8 @@ class ExecNode(abc.ABC):
         except RuntimeError:
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
+            # we set a deliberately long duration to avoid warnings getting printed to the console in debug mode
+            loop.slow_callback_duration = 3600

         if _logger.isEnabledFor(logging.DEBUG):
             loop.set_debug(True)
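
For context: in asyncio debug mode, any callback or task step that runs longer than loop.slow_callback_duration (0.1 seconds by default) is logged as a warning, so raising the threshold to an hour effectively silences those messages. A small self-contained demonstration:

import asyncio
import time

async def blocking_step() -> None:
    time.sleep(0.2)  # blocks the loop; debug mode would normally warn here

loop = asyncio.new_event_loop()
loop.set_debug(True)
loop.slow_callback_duration = 3600  # default is 0.1s; this suppresses the warning
loop.run_until_complete(blocking_step())
loop.close()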
pixeltable/exec/expr_eval/expr_eval_node.py CHANGED
@@ -49,7 +49,7 @@ class ExprEvalNode(ExecNode):
     # execution state
     tasks: set[asyncio.Task]  # collects all running tasks to prevent them from getting gc'd
     exc_event: asyncio.Event  # set if an exception needs to be propagated
-    error: Optional[Union[excs.Error, excs.ExprEvalError]]  # exception that needs to be propagated
+    error: Optional[Union[Exception]]  # exception that needs to be propagated
     completed_rows: asyncio.Queue[exprs.DataRow]  # rows that have completed evaluation
     completed_event: asyncio.Event  # set when completed_rows is non-empty
     input_iter: AsyncIterator[DataRowBatch]
@@ -133,10 +133,10 @@
             except StopAsyncIteration:
                 self.input_complete = True
                 _logger.debug(f'finished input: #input_rows={self.num_input_rows}, #avail={self.avail_input_rows}')
-            except excs.Error as err:
-                self.error = err
+            # make sure to pass DBAPIError through, so the transaction handling logic sees it
+            except Exception as exc:
+                self.error = exc
                 self.exc_event.set()
-                # TODO: should we also handle Exception here and create an excs.Error from it?

     @property
     def total_buffered(self) -> int:
pixeltable/exec/sql_node.py CHANGED
@@ -308,8 +308,7 @@
             _logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
         except Exception:
             # log something if we can't log the compiled stmt
-            stmt_str = repr(stmt)
-            _logger.debug(f'SqlLookupNode proto-stmt:\n{stmt_str}')
+            _logger.debug(f'SqlLookupNode proto-stmt:\n{stmt}')
         self._log_explain(stmt)

         conn = Env.get().conn
@@ -530,40 +529,39 @@

 class SqlSampleNode(SqlNode):
     """
-    Returns rows from a stratified sample with N samples per strata.
+    Returns rows sampled from the input node.
     """

-    stratify_exprs: Optional[list[exprs.Expr]]
-    n_samples: Optional[int]
-    fraction_samples: Optional[float]
-    seed: int
     input_cte: Optional[sql.CTE]
     pk_count: int
+    stratify_exprs: Optional[list[exprs.Expr]]
+    sample_clause: 'SampleClause'

     def __init__(
         self,
         row_builder: exprs.RowBuilder,
         input: SqlNode,
         select_list: Iterable[exprs.Expr],
-        stratify_exprs: Optional[list[exprs.Expr]] = None,
-        sample_clause: Optional['SampleClause'] = None,
+        sample_clause: 'SampleClause',
+        stratify_exprs: list[exprs.Expr],
     ):
         """
         Args:
+            input: SqlNode to sample from
             select_list: can contain calls to AggregateFunctions
-            stratify_exprs: list of expressions to group by
-            n: number of samples per strata
+            sample_clause: specifies the sampling method
+            stratify_exprs: Analyzer processed list of expressions to stratify by.
         """
+        assert isinstance(input, SqlNode)
         self.input_cte, input_col_map = input.to_cte(keep_pk=True)
         self.pk_count = input.num_pk_cols
         assert self.pk_count > 1
         sql_elements = exprs.SqlElementCache(input_col_map)
+        assert sql_elements.contains_all(stratify_exprs)
         super().__init__(input.tbl, row_builder, select_list, sql_elements, set_pk=True)
         self.stratify_exprs = stratify_exprs
-        self.n_samples = sample_clause.n
-        self.n_per_stratum = sample_clause.n_per_stratum
-        self.fraction_samples = sample_clause.fraction
-        self.seed = sample_clause.seed if sample_clause.seed is not None else 0
+        self.sample_clause = sample_clause
+        assert isinstance(self.sample_clause.seed, int)

     @classmethod
     def key_sql_expr(cls, seed: sql.ColumnElement, sql_cols: Iterable[sql.ColumnElement]) -> sql.ColumnElement:
@@ -573,25 +571,44 @@
         """
         sql_expr: sql.ColumnElement = sql.cast(seed, sql.Text)
         for e in sql_cols:
-            sql_expr = sql_expr + sql.literal_column("'___'") + sql.cast(e, sql.Text)
+            # Quotes are required below to guarantee that the string is properly presented in SQL
+            sql_expr = sql_expr + sql.literal_column("'___'", sql.Text) + sql.cast(e, sql.Text)
         sql_expr = sql.func.md5(sql_expr)
         return sql_expr

-    def _create_order_by(self, cte: sql.CTE) -> sql.ColumnElement:
+    def _create_key_sql(self, cte: sql.CTE) -> sql.ColumnElement:
         """Create an expression for randomly ordering rows with a given seed"""
         rowid_cols = [*cte.c[-self.pk_count : -1]]  # exclude the version column
         assert len(rowid_cols) > 0
-        return self.key_sql_expr(sql.literal_column(str(self.seed)), rowid_cols)
+        return self.key_sql_expr(sql.literal_column(str(self.sample_clause.seed)), rowid_cols)

     def _create_stmt(self) -> sql.Select:
-        if self.fraction_samples is not None:
-            return self._create_stmt_fraction(self.fraction_samples)
-        return self._create_stmt_n(self.n_samples, self.n_per_stratum)
+        from pixeltable.plan import SampleClause
+
+        if self.sample_clause.fraction is not None:
+            if len(self.stratify_exprs) == 0:
+                # If non-stratified sampling, construct a where clause, order_by, and limit clauses
+                s_key = self._create_key_sql(self.input_cte)
+
+                # Construct a suitable where clause
+                fraction_sql = sql.cast(SampleClause.fraction_to_md5_hex(float(self.sample_clause.fraction)), sql.Text)
+                order_by = self._create_key_sql(self.input_cte)
+                return sql.select(*self.input_cte.c).where(s_key < fraction_sql).order_by(order_by)
+
+            return self._create_stmt_stratified_fraction(self.sample_clause.fraction)
+        else:
+            if len(self.stratify_exprs) == 0:
+                # No stratification, just return n samples from the input CTE
+                order_by = self._create_key_sql(self.input_cte)
+                return sql.select(*self.input_cte.c).order_by(order_by).limit(self.sample_clause.n)
+
+            return self._create_stmt_stratified_n(self.sample_clause.n, self.sample_clause.n_per_stratum)
+
+    def _create_stmt_stratified_n(self, n: Optional[int], n_per_stratum: Optional[int]) -> sql.Select:
+        """Create a Select stmt that returns n samples across all strata or n_per_stratum samples per stratum"""

-    def _create_stmt_n(self, n: Optional[int], n_per_stratum: Optional[int]) -> sql.Select:
-        """Create a Select stmt that returns n samples across all strata"""
         sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
-        order_by = self._create_order_by(self.input_cte)
+        order_by = self._create_key_sql(self.input_cte)

         # Create a list of all columns plus the rank
         # Get all columns from the input CTE dynamically
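
Both the ordering key and the fraction filter rely on the same trick: md5(seed || '___' || rowid) behaves like a uniform 128-bit hash, so sorting by it gives a reproducible shuffle, and comparing it to a fixed hex threshold keeps each row with probability equal to the requested fraction. A standalone sketch of the idea (the `fraction_to_md5_hex` body below is an assumption for illustration, not pixeltable's implementation):

import hashlib

def sample_key(seed: int, rowid: int) -> str:
    # mirrors the SQL md5(seed || '___' || rowid): a deterministic pseudo-random key
    return hashlib.md5(f'{seed}___{rowid}'.encode()).hexdigest()

def fraction_to_md5_hex(fraction: float) -> str:
    # hex threshold below which a uniform 128-bit key falls with probability `fraction`
    return f'{int(fraction * (1 << 128)):032x}'

rows = range(10_000)
sampled = [r for r in rows if sample_key(42, r) < fraction_to_md5_hex(0.1)]
# roughly 1,000 rows; re-running with seed 42 selects exactly the same rows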
@@ -605,15 +622,15 @@
         if n_per_stratum is not None:
             return sql.select(*final_columns).filter(row_rank_cte.c.rank <= n_per_stratum)
         else:
-            secondary_order = self._create_order_by(row_rank_cte)
+            secondary_order = self._create_key_sql(row_rank_cte)
             return sql.select(*final_columns).order_by(row_rank_cte.c.rank, secondary_order).limit(n)

-    def _create_stmt_fraction(self, fraction_samples: float) -> sql.Select:
+    def _create_stmt_stratified_fraction(self, fraction_samples: float) -> sql.Select:
         """Create a Select stmt that returns a fraction of the rows per strata"""

         # Build the strata count CTE
         # Produces a table of the form:
-        # ([stratify_exprs], s_s_size)
+        # (*stratify_exprs, s_s_size)
         # where s_s_size is the number of samples to take from each stratum
         sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
         per_strata_count_cte = (
@@ -628,19 +645,19 @@

         # Build a CTE that ranks the rows within each stratum
         # Include all columns from the input CTE dynamically
-        order_by = self._create_order_by(self.input_cte)
+        order_by = self._create_key_sql(self.input_cte)
         select_columns = [*self.input_cte.c]
         select_columns.append(
             sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
         )
         row_rank_cte = sql.select(*select_columns).select_from(self.input_cte).cte('row_rank_cte')

-        # Build the join criterion dynamically to accommodate any number of group by columns
+        # Build the join criterion dynamically to accommodate any number of stratify_by expressions
         join_c = sql.true()
         for col in per_strata_count_cte.c[:-1]:
             join_c &= row_rank_cte.c[col.name].isnot_distinct_from(col)

-        # Join srcp with per_strata_count_cte to limit returns to the requested fraction of rows
+        # Join with per_strata_count_cte to limit returns to the requested fraction of rows
         final_columns = [*row_rank_cte.c[:-1]]  # exclude the rank column
         stmt = (
             sql.select(*final_columns)
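
The two CTEs above implement textbook stratified sampling in SQL: count each stratum to derive its quota, rank rows within each stratum by the deterministic key, then keep rows whose rank falls within the quota. A hedged sketch of the same query shape, written directly against an illustrative SQLAlchemy table (the table and its `category` column are assumptions, not pixeltable's schema):

import sqlalchemy as sql

def stratified_fraction_stmt(t: sql.Table, key: sql.ColumnElement, fraction: float) -> sql.Select:
    # quota per stratum: ceil(count * fraction)
    counts = (
        sql.select(t.c.category, sql.func.ceil(sql.func.count() * fraction).label('quota'))
        .group_by(t.c.category)
        .cte('counts')
    )
    # rank rows within each stratum by the deterministic pseudo-random key
    ranked = sql.select(
        *t.c, sql.func.row_number().over(partition_by=t.c.category, order_by=key).label('rank')
    ).cte('ranked')
    # keep rows whose in-stratum rank falls within the stratum's quota
    return (
        sql.select(*[c for c in ranked.c if c.name != 'rank'])
        .join(counts, ranked.c.category == counts.c.category)
        .where(ranked.c.rank <= counts.c.quota)
    )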