pixeltable 0.4.0rc3__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (58)
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +11 -2
  4. pixeltable/catalog/catalog.py +407 -119
  5. pixeltable/catalog/column.py +38 -26
  6. pixeltable/catalog/globals.py +130 -15
  7. pixeltable/catalog/insertable_table.py +10 -9
  8. pixeltable/catalog/schema_object.py +6 -0
  9. pixeltable/catalog/table.py +245 -119
  10. pixeltable/catalog/table_version.py +142 -116
  11. pixeltable/catalog/table_version_handle.py +30 -2
  12. pixeltable/catalog/table_version_path.py +28 -4
  13. pixeltable/catalog/view.py +14 -20
  14. pixeltable/config.py +4 -0
  15. pixeltable/dataframe.py +10 -9
  16. pixeltable/env.py +5 -11
  17. pixeltable/exceptions.py +6 -0
  18. pixeltable/exec/exec_node.py +2 -0
  19. pixeltable/exec/expr_eval/expr_eval_node.py +4 -4
  20. pixeltable/exec/sql_node.py +47 -30
  21. pixeltable/exprs/column_property_ref.py +2 -10
  22. pixeltable/exprs/column_ref.py +24 -21
  23. pixeltable/exprs/data_row.py +9 -0
  24. pixeltable/exprs/expr.py +4 -4
  25. pixeltable/exprs/row_builder.py +44 -13
  26. pixeltable/func/__init__.py +1 -0
  27. pixeltable/func/mcp.py +74 -0
  28. pixeltable/func/query_template_function.py +4 -2
  29. pixeltable/func/tools.py +12 -2
  30. pixeltable/func/udf.py +2 -2
  31. pixeltable/functions/__init__.py +1 -0
  32. pixeltable/functions/groq.py +108 -0
  33. pixeltable/functions/huggingface.py +8 -6
  34. pixeltable/functions/mistralai.py +2 -13
  35. pixeltable/functions/openai.py +1 -6
  36. pixeltable/functions/replicate.py +2 -2
  37. pixeltable/functions/util.py +6 -1
  38. pixeltable/globals.py +0 -2
  39. pixeltable/io/external_store.py +81 -54
  40. pixeltable/io/globals.py +1 -1
  41. pixeltable/io/label_studio.py +49 -45
  42. pixeltable/io/table_data_conduit.py +1 -1
  43. pixeltable/metadata/__init__.py +1 -1
  44. pixeltable/metadata/converters/convert_37.py +15 -0
  45. pixeltable/metadata/converters/convert_38.py +39 -0
  46. pixeltable/metadata/notes.py +2 -0
  47. pixeltable/metadata/schema.py +5 -0
  48. pixeltable/metadata/utils.py +78 -0
  49. pixeltable/plan.py +59 -139
  50. pixeltable/share/packager.py +2 -2
  51. pixeltable/store.py +114 -103
  52. pixeltable/type_system.py +30 -0
  53. {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.2.dist-info}/METADATA +1 -1
  54. {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.2.dist-info}/RECORD +57 -53
  55. pixeltable/utils/sample.py +0 -25
  56. {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.2.dist-info}/LICENSE +0 -0
  57. {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.2.dist-info}/WHEEL +0 -0
  58. {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.2.dist-info}/entry_points.txt +0 -0
@@ -41,6 +41,8 @@ class View(Table):
41
41
  def __init__(self, id: UUID, dir_id: UUID, name: str, tbl_version_path: TableVersionPath, snapshot_only: bool):
42
42
  super().__init__(id, dir_id, name, tbl_version_path)
43
43
  self._snapshot_only = snapshot_only
44
+ if not snapshot_only:
45
+ self._tbl_version = tbl_version_path.tbl_version
44
46
 
45
47
  @classmethod
46
48
  def _display_name(cls) -> str:
@@ -227,7 +229,7 @@ class View(Table):
227
229
 
228
230
  try:
229
231
  plan, _ = Planner.create_view_load_plan(view._tbl_version_path)
230
- num_rows, num_excs, _ = tbl_version.store_tbl.insert_rows(plan, v_min=tbl_version.version)
232
+ _, status = tbl_version.store_tbl.insert_rows(plan, v_min=tbl_version.version)
231
233
  except:
232
234
  # we need to remove the orphaned TableVersion instance
233
235
  del catalog.Catalog.get()._tbl_versions[tbl_version.id, tbl_version.effective_version]
@@ -236,7 +238,9 @@ class View(Table):
236
238
  # also remove tbl_version from the base
237
239
  base_tbl_version.mutable_views.remove(TableVersionHandle.create(tbl_version))
238
240
  raise
239
- Env.get().console_logger.info(f'Created view `{name}` with {num_rows} rows, {num_excs} exceptions.')
241
+ Env.get().console_logger.info(
242
+ f'Created view `{name}` with {status.num_rows} rows, {status.num_excs} exceptions.'
243
+ )
240
244
 
241
245
  session.commit()
242
246
  return view
@@ -267,17 +271,8 @@ class View(Table):
267
271
  base=cls._get_snapshot_path(tbl_version_path.base) if tbl_version_path.base is not None else None,
268
272
  )
269
273
 
270
- def _drop(self) -> None:
271
- if self._snapshot_only:
272
- # there is not TableVersion to drop
273
- self._check_is_dropped()
274
- self.is_dropped = True
275
- catalog.Catalog.get().delete_tbl_md(self._id)
276
- else:
277
- super()._drop()
278
-
279
- def get_metadata(self) -> dict[str, Any]:
280
- md = super().get_metadata()
274
+ def _get_metadata(self) -> dict[str, Any]:
275
+ md = super()._get_metadata()
281
276
  md['is_view'] = True
282
277
  md['is_snapshot'] = self._tbl_version_path.is_snapshot()
283
278
  return md
@@ -298,11 +293,10 @@ class View(Table):
298
293
  def delete(self, where: Optional[exprs.Expr] = None) -> UpdateStatus:
299
294
  raise excs.Error(f'{self._display_name()} {self._name!r}: cannot delete from view')
300
295
 
301
- @property
302
- def _base_table(self) -> Optional['Table']:
296
+ def _get_base_table(self) -> Optional['Table']:
303
297
  # if this is a pure snapshot, our tbl_version_path only reflects the base (there is no TableVersion instance
304
298
  # for the snapshot itself)
305
- base_id = self._tbl_version.id if self._snapshot_only else self._tbl_version_path.base.tbl_version.id
299
+ base_id = self._tbl_version_path.tbl_id if self._snapshot_only else self._tbl_version_path.base.tbl_id
306
300
  return catalog.Catalog.get().get_table_by_id(base_id)
307
301
 
308
302
  @property
@@ -317,7 +311,7 @@ class View(Table):
317
311
  display_name = 'Snapshot' if self._snapshot_only else 'View'
318
312
  result = [f'{display_name} {self._path()!r}']
319
313
  bases_descrs: list[str] = []
320
- for base, effective_version in zip(self._base_tables, self._effective_base_versions):
314
+ for base, effective_version in zip(self._get_base_tables(), self._effective_base_versions):
321
315
  if effective_version is None:
322
316
  bases_descrs.append(f'{base._path()!r}')
323
317
  else:
@@ -325,8 +319,8 @@ class View(Table):
325
319
  bases_descrs.append(f'{base_descr!r}')
326
320
  result.append(f' (of {", ".join(bases_descrs)})')
327
321
 
328
- if self._tbl_version.get().predicate is not None:
329
- result.append(f'\nWhere: {self._tbl_version.get().predicate!s}')
330
- if self._tbl_version.get().sample_clause is not None:
322
+ if self._tbl_version_path.tbl_version.get().predicate is not None:
323
+ result.append(f'\nWhere: {self._tbl_version_path.tbl_version.get().predicate!s}')
324
+ if self._tbl_version_path.tbl_version.get().sample_clause is not None:
331
325
  result.append(f'\nSample: {self._tbl_version.get().sample_clause!s}')
332
326
  return ''.join(result)
pixeltable/config.py CHANGED
@@ -86,6 +86,10 @@ class Config:
86
86
  return None
87
87
 
88
88
  try:
89
+ if expected_type is bool and isinstance(value, str):
90
+ if value.lower() not in ('true', 'false'):
91
+ raise excs.Error(f'Invalid value for configuration parameter {section}.{key}: {value}')
92
+ return value.lower() == 'true' # type: ignore[return-value]
89
93
  return expected_type(value) # type: ignore[call-arg]
90
94
  except ValueError as exc:
91
95
  raise excs.Error(f'Invalid value for configuration parameter {section}.{key}: {value}') from exc
pixeltable/dataframe.py CHANGED
@@ -475,7 +475,9 @@ class DataFrame:
475
475
  raise excs.Error(msg) from e
476
476
 
477
477
  def _output_row_iterator(self) -> Iterator[list]:
478
- with Catalog.get().begin_xact(for_write=False):
478
+ # TODO: extend begin_xact() to accept multiple TVPs for joins
479
+ single_tbl = self._first_tbl if len(self._from_clause.tbls) == 1 else None
480
+ with Catalog.get().begin_xact(tbl=single_tbl, for_write=False):
479
481
  try:
480
482
  for data_row in self._exec():
481
483
  yield [data_row[e.slot_idx] for e in self._select_list_exprs]
@@ -507,7 +509,7 @@ class DataFrame:
507
509
 
508
510
  from pixeltable.plan import Planner
509
511
 
510
- with Catalog.get().begin_xact(for_write=False) as conn:
512
+ with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=False) as conn:
511
513
  stmt = Planner.create_count_stmt(self._first_tbl, self.where_clause)
512
514
  result: int = conn.execute(stmt).scalar_one()
513
515
  assert isinstance(result, int)
@@ -903,7 +905,7 @@ class DataFrame:
903
905
  grouping_tbl = item if isinstance(item, catalog.TableVersion) else item._tbl_version.get()
904
906
  # we need to make sure that the grouping table is a base of self.tbl
905
907
  base = self._first_tbl.find_tbl_version(grouping_tbl.id)
906
- if base is None or base.id == self._first_tbl.tbl_id():
908
+ if base is None or base.id == self._first_tbl.tbl_id:
907
909
  raise excs.Error(
908
910
  f'group_by(): {grouping_tbl.name} is not a base table of {self._first_tbl.tbl_name()}'
909
911
  )
@@ -1161,8 +1163,7 @@ class DataFrame:
1161
1163
  >>> df = person.where(t.year == 2014).update({'age': 30})
1162
1164
  """
1163
1165
  self._validate_mutable('update', False)
1164
- tbl_id = self._first_tbl.tbl_id()
1165
- with Catalog.get().begin_xact(tbl_id=tbl_id, for_write=True):
1166
+ with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=True, lock_mutable_tree=True):
1166
1167
  return self._first_tbl.tbl_version.get().update(value_spec, where=self.where_clause, cascade=cascade)
1167
1168
 
1168
1169
  def delete(self) -> UpdateStatus:
@@ -1185,8 +1186,7 @@ class DataFrame:
1185
1186
  self._validate_mutable('delete', False)
1186
1187
  if not self._first_tbl.is_insertable():
1187
1188
  raise excs.Error('Cannot delete from view')
1188
- tbl_id = self._first_tbl.tbl_id()
1189
- with Catalog.get().begin_xact(tbl_id=tbl_id, for_write=True):
1189
+ with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=True, lock_mutable_tree=True):
1190
1190
  return self._first_tbl.tbl_version.get().delete(where=self.where_clause)
1191
1191
 
1192
1192
  def _validate_mutable(self, op_name: str, allow_select: bool) -> None:
@@ -1307,7 +1307,8 @@ class DataFrame:
1307
1307
  assert data_file_path.is_file()
1308
1308
  return data_file_path
1309
1309
  else:
1310
- with Catalog.get().begin_xact(for_write=False):
1310
+ # TODO: extend begin_xact() to accept multiple TVPs for joins
1311
+ with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=False):
1311
1312
  return write_coco_dataset(self, dest_path)
1312
1313
 
1313
1314
  def to_pytorch_dataset(self, image_format: str = 'pt') -> 'torch.utils.data.IterableDataset':
@@ -1352,7 +1353,7 @@ class DataFrame:
1352
1353
  if dest_path.exists(): # fast path: use cache
1353
1354
  assert dest_path.is_dir()
1354
1355
  else:
1355
- with Catalog.get().begin_xact(for_write=False):
1356
+ with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=False):
1356
1357
  export_parquet(self, dest_path, inline_images=True)
1357
1358
 
1358
1359
  return PixeltablePytorchDataset(path=dest_path, image_format=image_format)
pixeltable/env.py CHANGED
@@ -10,7 +10,6 @@ import logging
10
10
  import os
11
11
  import platform
12
12
  import shutil
13
- import subprocess
14
13
  import sys
15
14
  import threading
16
15
  import uuid
@@ -611,9 +610,11 @@ class Env:
611
610
  self.__register_package('fiftyone')
612
611
  self.__register_package('fireworks', library_name='fireworks-ai')
613
612
  self.__register_package('google.genai', library_name='google-genai')
613
+ self.__register_package('groq')
614
614
  self.__register_package('huggingface_hub', library_name='huggingface-hub')
615
615
  self.__register_package('label_studio_sdk', library_name='label-studio-sdk')
616
616
  self.__register_package('llama_cpp', library_name='llama-cpp-python')
617
+ self.__register_package('mcp')
617
618
  self.__register_package('mistralai')
618
619
  self.__register_package('mistune')
619
620
  self.__register_package('ollama')
@@ -746,18 +747,11 @@ class Env:
746
747
  have no sub-dependencies (in fact, this is how spaCy normally manages its model resources).
747
748
  """
748
749
  import spacy
749
- from spacy.cli.download import get_model_filename
750
+ from spacy.cli.download import download
750
751
 
751
752
  spacy_model = 'en_core_web_sm'
752
- spacy_model_version = '3.7.1'
753
- filename = get_model_filename(spacy_model, spacy_model_version, sdist=False)
754
- url = f'{spacy.about.__download_url__}/{filename}'
755
- # Try to `pip install` the model. We set check=False; if the pip command fails, it's not necessarily
756
- # a problem, because the model might have been installed on a previous attempt.
757
- self._logger.info(f'Ensuring spaCy model is installed: {filename}')
758
- ret = subprocess.run([sys.executable, '-m', 'pip', 'install', '-qU', url], check=False)
759
- if ret.returncode != 0:
760
- self._logger.warning(f'pip install failed for spaCy model: {filename}')
753
+ self._logger.info(f'Ensuring spaCy model is installed: {spacy_model}')
754
+ download(spacy_model)
761
755
  self._logger.info(f'Loading spaCy model: {spacy_model}')
762
756
  try:
763
757
  self._spacy_nlp = spacy.load(spacy_model)
pixeltable/exceptions.py CHANGED
@@ -10,6 +10,12 @@ class Error(Exception):
10
10
 
11
11
 
12
12
  class ExprEvalError(Exception):
13
+ """
14
+ Used during query execution to signal expr evaluation failures.
15
+
16
+ NOT A USER-FACING EXCEPTION. All ExprEvalError instances need to be converted into Error instances.
17
+ """
18
+
13
19
  expr: 'exprs.Expr'
14
20
  expr_msg: str
15
21
  exc: Exception
@@ -73,6 +73,8 @@ class ExecNode(abc.ABC):
73
73
  except RuntimeError:
74
74
  loop = asyncio.new_event_loop()
75
75
  asyncio.set_event_loop(loop)
76
+ # we set a deliberately long duration to avoid warnings getting printed to the console in debug mode
77
+ loop.slow_callback_duration = 3600
76
78
 
77
79
  if _logger.isEnabledFor(logging.DEBUG):
78
80
  loop.set_debug(True)
@@ -49,7 +49,7 @@ class ExprEvalNode(ExecNode):
49
49
  # execution state
50
50
  tasks: set[asyncio.Task] # collects all running tasks to prevent them from getting gc'd
51
51
  exc_event: asyncio.Event # set if an exception needs to be propagated
52
- error: Optional[Union[excs.Error, excs.ExprEvalError]] # exception that needs to be propagated
52
+ error: Optional[Union[Exception]] # exception that needs to be propagated
53
53
  completed_rows: asyncio.Queue[exprs.DataRow] # rows that have completed evaluation
54
54
  completed_event: asyncio.Event # set when completed_rows is non-empty
55
55
  input_iter: AsyncIterator[DataRowBatch]
@@ -133,10 +133,10 @@ class ExprEvalNode(ExecNode):
133
133
  except StopAsyncIteration:
134
134
  self.input_complete = True
135
135
  _logger.debug(f'finished input: #input_rows={self.num_input_rows}, #avail={self.avail_input_rows}')
136
- except excs.Error as err:
137
- self.error = err
136
+ # make sure to pass DBAPIError through, so the transaction handling logic sees it
137
+ except Exception as exc:
138
+ self.error = exc
138
139
  self.exc_event.set()
139
- # TODO: should we also handle Exception here and create an excs.Error from it?
140
140
 
141
141
  @property
142
142
  def total_buffered(self) -> int:
@@ -308,8 +308,7 @@ class SqlNode(ExecNode):
308
308
  _logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
309
309
  except Exception:
310
310
  # log something if we can't log the compiled stmt
311
- stmt_str = repr(stmt)
312
- _logger.debug(f'SqlLookupNode proto-stmt:\n{stmt_str}')
311
+ _logger.debug(f'SqlLookupNode proto-stmt:\n{stmt}')
313
312
  self._log_explain(stmt)
314
313
 
315
314
  conn = Env.get().conn
@@ -530,40 +529,39 @@ class SqlJoinNode(SqlNode):
530
529
 
531
530
  class SqlSampleNode(SqlNode):
532
531
  """
533
- Returns rows from a stratified sample with N samples per strata.
532
+ Returns rows sampled from the input node.
534
533
  """
535
534
 
536
- stratify_exprs: Optional[list[exprs.Expr]]
537
- n_samples: Optional[int]
538
- fraction_samples: Optional[float]
539
- seed: int
540
535
  input_cte: Optional[sql.CTE]
541
536
  pk_count: int
537
+ stratify_exprs: Optional[list[exprs.Expr]]
538
+ sample_clause: 'SampleClause'
542
539
 
543
540
  def __init__(
544
541
  self,
545
542
  row_builder: exprs.RowBuilder,
546
543
  input: SqlNode,
547
544
  select_list: Iterable[exprs.Expr],
548
- stratify_exprs: Optional[list[exprs.Expr]] = None,
549
- sample_clause: Optional['SampleClause'] = None,
545
+ sample_clause: 'SampleClause',
546
+ stratify_exprs: list[exprs.Expr],
550
547
  ):
551
548
  """
552
549
  Args:
550
+ input: SqlNode to sample from
553
551
  select_list: can contain calls to AggregateFunctions
554
- stratify_exprs: list of expressions to group by
555
- n: number of samples per strata
552
+ sample_clause: specifies the sampling method
553
+ stratify_exprs: Analyzer processed list of expressions to stratify by.
556
554
  """
555
+ assert isinstance(input, SqlNode)
557
556
  self.input_cte, input_col_map = input.to_cte(keep_pk=True)
558
557
  self.pk_count = input.num_pk_cols
559
558
  assert self.pk_count > 1
560
559
  sql_elements = exprs.SqlElementCache(input_col_map)
560
+ assert sql_elements.contains_all(stratify_exprs)
561
561
  super().__init__(input.tbl, row_builder, select_list, sql_elements, set_pk=True)
562
562
  self.stratify_exprs = stratify_exprs
563
- self.n_samples = sample_clause.n
564
- self.n_per_stratum = sample_clause.n_per_stratum
565
- self.fraction_samples = sample_clause.fraction
566
- self.seed = sample_clause.seed if sample_clause.seed is not None else 0
563
+ self.sample_clause = sample_clause
564
+ assert isinstance(self.sample_clause.seed, int)
567
565
 
568
566
  @classmethod
569
567
  def key_sql_expr(cls, seed: sql.ColumnElement, sql_cols: Iterable[sql.ColumnElement]) -> sql.ColumnElement:
@@ -573,25 +571,44 @@ class SqlSampleNode(SqlNode):
573
571
  """
574
572
  sql_expr: sql.ColumnElement = sql.cast(seed, sql.Text)
575
573
  for e in sql_cols:
576
- sql_expr = sql_expr + sql.literal_column("'___'") + sql.cast(e, sql.Text)
574
+ # Quotes are required below to guarantee that the string is properly presented in SQL
575
+ sql_expr = sql_expr + sql.literal_column("'___'", sql.Text) + sql.cast(e, sql.Text)
577
576
  sql_expr = sql.func.md5(sql_expr)
578
577
  return sql_expr
579
578
 
580
- def _create_order_by(self, cte: sql.CTE) -> sql.ColumnElement:
579
+ def _create_key_sql(self, cte: sql.CTE) -> sql.ColumnElement:
581
580
  """Create an expression for randomly ordering rows with a given seed"""
582
581
  rowid_cols = [*cte.c[-self.pk_count : -1]] # exclude the version column
583
582
  assert len(rowid_cols) > 0
584
- return self.key_sql_expr(sql.literal_column(str(self.seed)), rowid_cols)
583
+ return self.key_sql_expr(sql.literal_column(str(self.sample_clause.seed)), rowid_cols)
585
584
 
586
585
  def _create_stmt(self) -> sql.Select:
587
- if self.fraction_samples is not None:
588
- return self._create_stmt_fraction(self.fraction_samples)
589
- return self._create_stmt_n(self.n_samples, self.n_per_stratum)
586
+ from pixeltable.plan import SampleClause
587
+
588
+ if self.sample_clause.fraction is not None:
589
+ if len(self.stratify_exprs) == 0:
590
+ # If non-stratified sampling, construct a where clause, order_by, and limit clauses
591
+ s_key = self._create_key_sql(self.input_cte)
592
+
593
+ # Construct a suitable where clause
594
+ fraction_sql = sql.cast(SampleClause.fraction_to_md5_hex(float(self.sample_clause.fraction)), sql.Text)
595
+ order_by = self._create_key_sql(self.input_cte)
596
+ return sql.select(*self.input_cte.c).where(s_key < fraction_sql).order_by(order_by)
597
+
598
+ return self._create_stmt_stratified_fraction(self.sample_clause.fraction)
599
+ else:
600
+ if len(self.stratify_exprs) == 0:
601
+ # No stratification, just return n samples from the input CTE
602
+ order_by = self._create_key_sql(self.input_cte)
603
+ return sql.select(*self.input_cte.c).order_by(order_by).limit(self.sample_clause.n)
604
+
605
+ return self._create_stmt_stratified_n(self.sample_clause.n, self.sample_clause.n_per_stratum)
606
+
607
+ def _create_stmt_stratified_n(self, n: Optional[int], n_per_stratum: Optional[int]) -> sql.Select:
608
+ """Create a Select stmt that returns n samples across all strata or n_per_stratum samples per stratum"""
590
609
 
591
- def _create_stmt_n(self, n: Optional[int], n_per_stratum: Optional[int]) -> sql.Select:
592
- """Create a Select stmt that returns n samples across all strata"""
593
610
  sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
594
- order_by = self._create_order_by(self.input_cte)
611
+ order_by = self._create_key_sql(self.input_cte)
595
612
 
596
613
  # Create a list of all columns plus the rank
597
614
  # Get all columns from the input CTE dynamically
@@ -605,15 +622,15 @@ class SqlSampleNode(SqlNode):
605
622
  if n_per_stratum is not None:
606
623
  return sql.select(*final_columns).filter(row_rank_cte.c.rank <= n_per_stratum)
607
624
  else:
608
- secondary_order = self._create_order_by(row_rank_cte)
625
+ secondary_order = self._create_key_sql(row_rank_cte)
609
626
  return sql.select(*final_columns).order_by(row_rank_cte.c.rank, secondary_order).limit(n)
610
627
 
611
- def _create_stmt_fraction(self, fraction_samples: float) -> sql.Select:
628
+ def _create_stmt_stratified_fraction(self, fraction_samples: float) -> sql.Select:
612
629
  """Create a Select stmt that returns a fraction of the rows per strata"""
613
630
 
614
631
  # Build the strata count CTE
615
632
  # Produces a table of the form:
616
- # ([stratify_exprs], s_s_size)
633
+ # (*stratify_exprs, s_s_size)
617
634
  # where s_s_size is the number of samples to take from each stratum
618
635
  sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
619
636
  per_strata_count_cte = (
@@ -628,19 +645,19 @@ class SqlSampleNode(SqlNode):
628
645
 
629
646
  # Build a CTE that ranks the rows within each stratum
630
647
  # Include all columns from the input CTE dynamically
631
- order_by = self._create_order_by(self.input_cte)
648
+ order_by = self._create_key_sql(self.input_cte)
632
649
  select_columns = [*self.input_cte.c]
633
650
  select_columns.append(
634
651
  sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
635
652
  )
636
653
  row_rank_cte = sql.select(*select_columns).select_from(self.input_cte).cte('row_rank_cte')
637
654
 
638
- # Build the join criterion dynamically to accommodate any number of group by columns
655
+ # Build the join criterion dynamically to accommodate any number of stratify_by expressions
639
656
  join_c = sql.true()
640
657
  for col in per_strata_count_cte.c[:-1]:
641
658
  join_c &= row_rank_cte.c[col.name].isnot_distinct_from(col)
642
659
 
643
- # Join srcp with per_strata_count_cte to limit returns to the requested fraction of rows
660
+ # Join with per_strata_count_cte to limit returns to the requested fraction of rows
644
661
  final_columns = [*row_rank_cte.c[:-1]] # exclude the rank column
645
662
  stmt = (
646
663
  sql.select(*final_columns)
@@ -55,17 +55,9 @@ class ColumnPropertyRef(Expr):
55
55
  return self.prop in (self.Property.ERRORTYPE, self.Property.ERRORMSG)
56
56
 
57
57
  def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ColumnElement]:
58
- if not self._col_ref.col.is_stored:
58
+ if not self._col_ref.col_handle.get().is_stored:
59
59
  return None
60
-
61
- # we need to reestablish that we have the correct Column instance, there could have been a metadata
62
- # reload since init()
63
- # TODO: add an explicit prepare phase (ie, Expr.prepare()) that gives every subclass instance a chance to
64
- # perform runtime checks and update state
65
- tv = self._col_ref.tbl_version.get()
66
- assert tv.is_validated
67
- col = tv.cols_by_id[self._col_ref.col_id]
68
- # TODO: check for column being dropped
60
+ col = self._col_ref.col_handle.get()
69
61
 
70
62
  # the errortype/-msg properties of a read-validated media column need to be extracted from the DataRow
71
63
  if (
@@ -10,6 +10,7 @@ import pixeltable as pxt
10
10
  from pixeltable import catalog, exceptions as excs, iterators as iters
11
11
 
12
12
  from ..utils.description_helper import DescriptionHelper
13
+ from ..utils.filecache import FileCache
13
14
  from .data_row import DataRow
14
15
  from .expr import Expr
15
16
  from .row_builder import RowBuilder
@@ -41,7 +42,8 @@ class ColumnRef(Expr):
41
42
  insert them into the EvalCtxs as needed
42
43
  """
43
44
 
44
- col: catalog.Column
45
+ col: catalog.Column # TODO: merge with col_handle
46
+ col_handle: catalog.ColumnHandle
45
47
  reference_tbl: Optional[catalog.TableVersionPath]
46
48
  is_unstored_iter_col: bool
47
49
  iter_arg_ctx: Optional[RowBuilder.EvalCtx]
@@ -52,10 +54,6 @@ class ColumnRef(Expr):
52
54
  id: int
53
55
  perform_validation: bool # if True, performs media validation
54
56
 
55
- # needed by sql_expr() to re-resolve Column instance after a metadata reload
56
- tbl_version: catalog.TableVersionHandle
57
- col_id: int
58
-
59
57
  def __init__(
60
58
  self,
61
59
  col: catalog.Column,
@@ -66,8 +64,7 @@ class ColumnRef(Expr):
66
64
  assert col.tbl is not None
67
65
  self.col = col
68
66
  self.reference_tbl = reference_tbl
69
- self.tbl_version = catalog.TableVersionHandle(col.tbl.id, col.tbl.effective_version)
70
- self.col_id = col.id
67
+ self.col_handle = catalog.ColumnHandle(col.tbl.handle, col.id)
71
68
 
72
69
  self.is_unstored_iter_col = col.tbl.is_component_view and col.tbl.is_iterator_column(col) and not col.is_stored
73
70
  self.iter_arg_ctx = None
@@ -170,6 +167,20 @@ class ColumnRef(Expr):
170
167
  idx_info = embedding_idx_info
171
168
  return idx_info
172
169
 
170
+ def recompute(self, *, cascade: bool = True, errors_only: bool = False) -> catalog.UpdateStatus:
171
+ cat = catalog.Catalog.get()
172
+ # lock_mutable_tree=True: we need to be able to see whether any transitive view has column dependents
173
+ with cat.begin_xact(tbl=self.reference_tbl, for_write=True, lock_mutable_tree=True):
174
+ tbl_version = self.col_handle.tbl_version.get()
175
+ if tbl_version.id != self.reference_tbl.tbl_id:
176
+ raise excs.Error('Cannot recompute column of a base.')
177
+ if tbl_version.is_snapshot:
178
+ raise excs.Error('Cannot recompute column of a snapshot.')
179
+ col_name = self.col_handle.get().name
180
+ status = tbl_version.recompute_columns([col_name], errors_only=errors_only, cascade=cascade)
181
+ FileCache.get().emit_eviction_warnings()
182
+ return status
183
+
173
184
  def similarity(self, item: Any, *, idx: Optional[str] = None) -> Expr:
174
185
  from .similarity_expr import SimilarityExpr
175
186
 
@@ -239,22 +250,9 @@ class ColumnRef(Expr):
239
250
  return helper
240
251
 
241
252
  def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
242
- # return None if self.perform_validation else self.col.sa_col
243
253
  if self.perform_validation:
244
254
  return None
245
- # we need to reestablish that we have the correct Column instance, there could have been a metadata
246
- # reload since init()
247
- # TODO: add an explicit prepare phase (ie, Expr.prepare()) that gives every subclass instance a chance to
248
- # perform runtime checks and update state
249
- tv = self.tbl_version.get()
250
- assert tv.is_validated
251
- self.col = tv.cols_by_id[self.col_id]
252
- assert self.col.tbl is tv
253
- # TODO: check for column being dropped
254
- # print(
255
- # f'ColumnRef.sql_expr: tbl={tv.id}:{tv.effective_version} sa_tbl={id(self.col.tbl.store_tbl.sa_tbl):x} '
256
- # f'tv={id(tv):x}'
257
- # )
255
+ self.col = self.col_handle.get()
258
256
  return self.col.sa_col
259
257
 
260
258
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
@@ -315,6 +313,11 @@ class ColumnRef(Expr):
315
313
  'perform_validation': self.perform_validation,
316
314
  }
317
315
 
316
+ @classmethod
317
+ def get_column_id(cls, d: dict) -> catalog.QColumnId:
318
+ tbl_id, col_id = UUID(d['tbl_id']), d['col_id']
319
+ return catalog.QColumnId(tbl_id, col_id)
320
+
318
321
  @classmethod
319
322
  def get_column(cls, d: dict) -> catalog.Column:
320
323
  tbl_id, version, col_id = UUID(d['tbl_id']), d['tbl_version'], d['col_id']
@@ -42,6 +42,10 @@ class DataRow:
42
42
  has_val: np.ndarray # of bool
43
43
  excs: np.ndarray # of object
44
44
 
45
+ # If `may_have_exc` is False, then we guarantee that no slot has an exception set. This is used to optimize
46
+ # exception handling under normal operation.
47
+ _may_have_exc: bool
48
+
45
49
  # expr evaluation state; indexed by slot idx
46
50
  missing_slots: np.ndarray # of bool; number of missing dependencies
47
51
  missing_dependents: np.ndarray # of int16; number of missing dependents
@@ -90,6 +94,7 @@ class DataRow:
90
94
  self.vals = np.full(num_slots, None, dtype=object)
91
95
  self.has_val = np.zeros(num_slots, dtype=bool)
92
96
  self.excs = np.full(num_slots, None, dtype=object)
97
+ self._may_have_exc = False
93
98
  self.missing_slots = np.zeros(num_slots, dtype=bool)
94
99
  self.missing_dependents = np.zeros(num_slots, dtype=np.int16)
95
100
  self.is_scheduled = np.zeros(num_slots, dtype=bool)
@@ -136,6 +141,9 @@ class DataRow:
136
141
  """
137
142
  Returns True if an exception has been set for the given slot index, or for any slot index if slot_idx is None
138
143
  """
144
+ if not self._may_have_exc:
145
+ return False
146
+
139
147
  if slot_idx is not None:
140
148
  return self.excs[slot_idx] is not None
141
149
  return (self.excs != None).any()
@@ -154,6 +162,7 @@ class DataRow:
154
162
  def set_exc(self, slot_idx: int, exc: Exception) -> None:
155
163
  assert self.excs[slot_idx] is None
156
164
  self.excs[slot_idx] = exc
165
+ self._may_have_exc = True
157
166
 
158
167
  # an exception means the value is None
159
168
  self.has_val[slot_idx] = True
pixeltable/exprs/expr.py CHANGED
@@ -394,17 +394,17 @@ class Expr(abc.ABC):
394
394
  return {tbl_id for e in exprs_ for tbl_id in e.tbl_ids()}
395
395
 
396
396
  @classmethod
397
- def get_refd_columns(cls, expr_dict: dict[str, Any]) -> list[catalog.Column]:
397
+ def get_refd_column_ids(cls, expr_dict: dict[str, Any]) -> set[catalog.QColumnId]:
398
398
  """Return Columns referenced by expr_dict."""
399
- result: list[catalog.Column] = []
399
+ result: set[catalog.QColumnId] = set()
400
400
  assert '_classname' in expr_dict
401
401
  from .column_ref import ColumnRef
402
402
 
403
403
  if expr_dict['_classname'] == 'ColumnRef':
404
- result.append(ColumnRef.get_column(expr_dict))
404
+ result.add(ColumnRef.get_column_id(expr_dict))
405
405
  if 'components' in expr_dict:
406
406
  for component_dict in expr_dict['components']:
407
- result.extend(cls.get_refd_columns(component_dict))
407
+ result.update(cls.get_refd_column_ids(component_dict))
408
408
  return result
409
409
 
410
410
  def as_literal(self) -> Optional[Expr]: