pixeltable 0.3.15__py3-none-any.whl → 0.4.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (58)
  1. pixeltable/__version__.py +2 -2
  2. pixeltable/catalog/catalog.py +296 -105
  3. pixeltable/catalog/column.py +10 -8
  4. pixeltable/catalog/dir.py +1 -2
  5. pixeltable/catalog/insertable_table.py +25 -20
  6. pixeltable/catalog/schema_object.py +3 -6
  7. pixeltable/catalog/table.py +261 -189
  8. pixeltable/catalog/table_version.py +333 -202
  9. pixeltable/catalog/table_version_handle.py +15 -2
  10. pixeltable/catalog/table_version_path.py +60 -14
  11. pixeltable/catalog/view.py +38 -6
  12. pixeltable/dataframe.py +196 -18
  13. pixeltable/env.py +4 -4
  14. pixeltable/exec/__init__.py +1 -1
  15. pixeltable/exec/expr_eval/evaluators.py +4 -1
  16. pixeltable/exec/in_memory_data_node.py +1 -1
  17. pixeltable/exec/sql_node.py +171 -22
  18. pixeltable/exprs/column_property_ref.py +15 -6
  19. pixeltable/exprs/column_ref.py +32 -11
  20. pixeltable/exprs/comparison.py +1 -1
  21. pixeltable/exprs/data_row.py +5 -3
  22. pixeltable/exprs/expr.py +7 -0
  23. pixeltable/exprs/literal.py +2 -0
  24. pixeltable/exprs/row_builder.py +4 -6
  25. pixeltable/exprs/rowid_ref.py +8 -0
  26. pixeltable/exprs/similarity_expr.py +1 -0
  27. pixeltable/func/query_template_function.py +1 -1
  28. pixeltable/func/tools.py +1 -1
  29. pixeltable/functions/gemini.py +0 -1
  30. pixeltable/functions/string.py +212 -58
  31. pixeltable/globals.py +12 -4
  32. pixeltable/index/base.py +5 -0
  33. pixeltable/index/btree.py +5 -0
  34. pixeltable/index/embedding_index.py +5 -0
  35. pixeltable/io/external_store.py +8 -29
  36. pixeltable/io/label_studio.py +1 -1
  37. pixeltable/io/parquet.py +2 -2
  38. pixeltable/io/table_data_conduit.py +0 -31
  39. pixeltable/metadata/__init__.py +11 -2
  40. pixeltable/metadata/converters/convert_13.py +2 -2
  41. pixeltable/metadata/converters/convert_30.py +6 -11
  42. pixeltable/metadata/converters/convert_35.py +9 -0
  43. pixeltable/metadata/converters/convert_36.py +38 -0
  44. pixeltable/metadata/converters/util.py +3 -9
  45. pixeltable/metadata/notes.py +2 -0
  46. pixeltable/metadata/schema.py +8 -1
  47. pixeltable/plan.py +221 -14
  48. pixeltable/share/packager.py +137 -13
  49. pixeltable/share/publish.py +2 -2
  50. pixeltable/store.py +19 -13
  51. pixeltable/utils/dbms.py +1 -1
  52. pixeltable/utils/formatter.py +64 -42
  53. pixeltable/utils/sample.py +25 -0
  54. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0rc2.dist-info}/METADATA +2 -1
  55. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0rc2.dist-info}/RECORD +58 -55
  56. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0rc2.dist-info}/LICENSE +0 -0
  57. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0rc2.dist-info}/WHEEL +0 -0
  58. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0rc2.dist-info}/entry_points.txt +0 -0
pixeltable/io/table_data_conduit.py CHANGED
@@ -23,7 +23,6 @@ from .utils import normalize_schema_names
 
 _logger = logging.getLogger('pixeltable')
 
-# ---------------------------------------------------------------------------------------------------------
 
 if TYPE_CHECKING:
     import datasets  # type: ignore[import-untyped]
@@ -46,9 +45,6 @@ class TableDataConduitFormat(str, enum.Enum):
         return False
 
 
-# ---------------------------------------------------------------------------------------------------------
-
-
 @dataclass
 class TableDataConduit:
     source: TableDataSource
@@ -129,9 +125,6 @@ class TableDataConduit:
            raise excs.Error(f'Missing required column(s) ({", ".join(missing_cols)})')
 
 
-# ---------------------------------------------------------------------------------------------------------
-
-
 class DFTableDataConduit(TableDataConduit):
     pxt_df: pxt.DataFrame = None
 
@@ -155,9 +148,6 @@ class DFTableDataConduit(TableDataConduit):
         self.check_source_columns_are_insertable(self.pxt_df.schema.keys())
 
 
-# ---------------------------------------------------------------------------------------------------------
-
-
 class RowDataTableDataConduit(TableDataConduit):
     raw_rows: Optional[RowData] = None
     disable_mapping: bool = True
@@ -235,9 +225,6 @@ class RowDataTableDataConduit(TableDataConduit):
            yield self.valid_rows
 
 
-# ---------------------------------------------------------------------------------------------------------
-
-
 class PandasTableDataConduit(TableDataConduit):
     pd_df: pd.DataFrame = None
     batch_count: int = 0
@@ -293,9 +280,6 @@ class PandasTableDataConduit(TableDataConduit):
            yield self.valid_rows
 
 
-# ---------------------------------------------------------------------------------------------------------
-
-
 class CSVTableDataConduit(TableDataConduit):
     @classmethod
     def from_tds(cls, tds: TableDataConduit) -> 'PandasTableDataConduit':
@@ -307,9 +291,6 @@ class CSVTableDataConduit(TableDataConduit):
         return PandasTableDataConduit.from_tds(t)
 
 
-# ---------------------------------------------------------------------------------------------------------
-
-
 class ExcelTableDataConduit(TableDataConduit):
     @classmethod
     def from_tds(cls, tds: TableDataConduit) -> 'PandasTableDataConduit':
@@ -321,9 +302,6 @@ class ExcelTableDataConduit(TableDataConduit):
         return PandasTableDataConduit.from_tds(t)
 
 
-# ---------------------------------------------------------------------------------------------------------
-
-
 class JsonTableDataConduit(TableDataConduit):
     @classmethod
     def from_tds(cls, tds: TableDataConduit) -> RowDataTableDataConduit:
@@ -346,9 +324,6 @@ class JsonTableDataConduit(TableDataConduit):
         return t2
 
 
-# ---------------------------------------------------------------------------------------------------------
-
-
 class HFTableDataConduit(TableDataConduit):
     hf_ds: Optional[Union[datasets.Dataset, datasets.DatasetDict]] = None
     column_name_for_split: Optional[str] = None
@@ -478,9 +453,6 @@ class HFTableDataConduit(TableDataConduit):
            yield batch
 
 
-# ---------------------------------------------------------------------------------------------------------
-
-
 class ParquetTableDataConduit(TableDataConduit):
     pq_ds: Optional[ParquetDataset] = None
 
@@ -542,9 +514,6 @@ class ParquetTableDataConduit(TableDataConduit):
            raise e
 
 
-# ---------------------------------------------------------------------------------------------------------
-
-
 class UnkTableDataConduit(TableDataConduit):
     """Source type is not known at the time of creation"""
 
pixeltable/metadata/__init__.py CHANGED
@@ -8,15 +8,17 @@ from typing import Callable
 import sqlalchemy as sql
 from sqlalchemy import orm
 
+import pixeltable as pxt
+import pixeltable.exceptions as excs
 from pixeltable.utils.console_output import ConsoleLogger
 
 from .schema import SystemInfo, SystemInfoMd
 
 _console_logger = ConsoleLogger(logging.getLogger('pixeltable'))
-
+_logger = logging.getLogger('pixeltable')
 
 # current version of the metadata; this is incremented whenever the metadata schema changes
-VERSION = 35
+VERSION = 37
 
 
 def create_system_info(engine: sql.engine.Engine) -> None:
@@ -55,6 +57,13 @@ def upgrade_md(engine: sql.engine.Engine) -> None:
         system_info = session.query(SystemInfo).one().md
         md_version = system_info['schema_version']
         assert isinstance(md_version, int)
+        _logger.info(f'Current database version: {md_version}, installed version: {VERSION}')
+        if md_version > VERSION:
+            raise excs.Error(
+                'This Pixeltable database was created with a newer Pixeltable version '
+                f'than the one currently installed ({pxt.__version__}).\n'
+                'Please update to the latest Pixeltable version by running: pip install --upgrade pixeltable'
+            )
         if md_version == VERSION:
             return
         while md_version < VERSION:
pixeltable/metadata/converters/convert_13.py CHANGED
@@ -12,9 +12,9 @@ _logger = logging.getLogger('pixeltable')
 @register_converter(version=13)
 def _(engine: sql.engine.Engine) -> None:
     with engine.begin() as conn:
-        for row in conn.execute(sql.select(Table)):
+        for row in conn.execute(sql.select(Table.id, Table.md)):
            id = row[0]
-            md = row[2]
+            md = row[1]
            updated_md = __update_md(md)
            if updated_md != md:
                _logger.info(f'Updating schema for table: {id}')
pixeltable/metadata/converters/convert_30.py CHANGED
@@ -1,33 +1,28 @@
 import copy
+from uuid import UUID
 
 import sqlalchemy as sql
 
 from pixeltable.metadata import register_converter
 from pixeltable.metadata.converters.util import (
-    convert_table_record,
+    convert_table_md,
     convert_table_schema_version_record,
     convert_table_version_record,
 )
-from pixeltable.metadata.schema import Table, TableSchemaVersion, TableVersion
+from pixeltable.metadata.schema import TableSchemaVersion, TableVersion
 
 
 @register_converter(version=30)
 def _(engine: sql.engine.Engine) -> None:
-    convert_table_record(engine, table_record_updater=__update_table_record)
+    convert_table_md(engine, table_md_updater=__update_table_md)
     convert_table_version_record(engine, table_version_record_updater=__update_table_version_record)
     convert_table_schema_version_record(
         engine, table_schema_version_record_updater=__update_table_schema_version_record
     )
 
 
-def __update_table_record(record: Table) -> None:
-    """
-    Update TableMd with table_id
-    """
-    assert isinstance(record.md, dict)
-    md = copy.copy(record.md)
-    md['tbl_id'] = str(record.id)
-    record.md = md
+def __update_table_md(md: dict, tbl_id: UUID) -> None:
+    md['tbl_id'] = str(tbl_id)
 
 
 def __update_table_version_record(record: TableVersion) -> None:
pixeltable/metadata/converters/convert_35.py ADDED
@@ -0,0 +1,9 @@
+import sqlalchemy as sql
+
+from pixeltable.metadata import register_converter
+
+
+@register_converter(version=35)
+def _(engine: sql.engine.Engine) -> None:
+    with engine.begin() as conn:
+        conn.execute(sql.text('ALTER TABLE tables ADD COLUMN lock_dummy int8'))
pixeltable/metadata/converters/convert_36.py ADDED
@@ -0,0 +1,38 @@
+import logging
+from typing import Any, Optional
+from uuid import UUID
+
+import sqlalchemy as sql
+
+from pixeltable.metadata import register_converter
+from pixeltable.metadata.converters.util import convert_table_md
+
+_logger = logging.getLogger('pixeltable')
+
+
+@register_converter(version=36)
+def _(engine: sql.engine.Engine) -> None:
+    convert_table_md(engine, table_md_updater=__update_table_md, substitution_fn=__substitute_md)
+
+
+def __update_table_md(table_md: dict, table_id: UUID) -> None:
+    """Update the view metadata to add the sample_clause field if it is missing.
+
+    Args:
+        table_md (dict): copy of the original table metadata; updated in place.
+        table_id (UUID): the table id.
+
+    """
+    if table_md['view_md'] is None:
+        return
+    if 'sample_clause' not in table_md['view_md']:
+        table_md['view_md']['sample_clause'] = None
+        _logger.info(f'Updating view metadata for table: {table_id}')
+
+
+def __substitute_md(k: Optional[str], v: Any) -> Optional[tuple[Optional[str], Any]]:
+    if isinstance(v, dict) and (v.get('_classname') == 'DataFrame'):
+        if 'sample_clause' not in v:
+            v['sample_clause'] = None
+        return k, v
+    return None
pixeltable/metadata/converters/util.py CHANGED
@@ -33,9 +33,10 @@ def convert_table_md(
     the original entry will be replaced, and the traversal will continue with `v'`.
     """
     with engine.begin() as conn:
-        for row in conn.execute(sql.select(Table)):
+        # avoid a SELECT * here, which breaks when we add new columns to Table
+        for row in conn.execute(sql.select(Table.id, Table.md)):
            tbl_id = row[0]
-            table_md = row[2]
+            table_md = row[1]
            assert isinstance(table_md, dict)
            updated_table_md = copy.deepcopy(table_md)
            if table_md_updater is not None:
@@ -145,13 +146,6 @@ def __update_schema_column(table_schema_version_md: dict, schema_column_updater:
         schema_column_updater(schema_col)
 
 
-def convert_table_record(engine: sql.engine.Engine, table_record_updater: Optional[Callable[[Table], None]]) -> None:
-    with sql.orm.Session(engine, future=True) as session:
-        for record in session.query(Table).all():
-            table_record_updater(record)
-        session.commit()
-
-
 def convert_table_version_record(
     engine: sql.engine.Engine, table_version_record_updater: Optional[Callable[[TableVersion], None]]
 ) -> None:
pixeltable/metadata/notes.py CHANGED
@@ -2,6 +2,8 @@
 # rather than as a comment, so that the existence of a description can be enforced by
 # the unit tests when new versions are added.
 VERSION_NOTES = {
+    37: 'Add support for the sample() method on DataFrames',
+    36: 'Added Table.lock_dummy',
     35: 'Track reference_tbl in ColumnRef',
     34: 'Set default value for is_pk field in column metadata to False',
     33: 'Add is_replica field to table metadata',
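The two new version notes correspond to the sampling feature wired through `pixeltable/plan.py` below. A minimal usage sketch; the keyword names are inferred from the `SampleClause` fields (`n`, `n_per_stratum`, `fraction`, `seed`, stratification) and are an assumption, not something this diff confirms:

```python
import pixeltable as pxt

t = pxt.get_table('films')  # hypothetical existing table

# Hypothetical keywords: they mirror SampleClause(n, n_per_stratum, fraction,
# seed, stratify_exprs) and are not confirmed by this diff.
fixed_size = t.select().sample(n=100, seed=42)        # pseudo-random 100-row sample
one_tenth = t.select().sample(fraction=0.1, seed=42)  # repeatable 10% sample
```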
pixeltable/metadata/schema.py CHANGED
@@ -84,7 +84,8 @@ class Dir(Base):
     )
     parent_id: orm.Mapped[uuid.UUID] = orm.mapped_column(UUID(as_uuid=True), ForeignKey('dirs.id'), nullable=True)
     md: orm.Mapped[dict[str, Any]] = orm.mapped_column(JSONB, nullable=False)  # DirMd
-    # This field is updated to synchronize database operations across multiple sessions
+
+    # used to force acquisition of an X-lock via an Update stmt
     lock_dummy: orm.Mapped[int] = orm.mapped_column(BigInteger, nullable=True)
 
 
@@ -146,6 +147,9 @@ class ViewMd:
     # filter predicate applied to the base table; view-only
     predicate: Optional[dict[str, Any]]
 
+    # sampling clause applied to the base table; view-only
+    sample_clause: Optional[dict[str, Any]]
+
     # ComponentIterator subclass; only for component views
     iterator_class_fqn: Optional[str]
 
@@ -200,6 +204,9 @@ class Table(Base):
     dir_id: orm.Mapped[uuid.UUID] = orm.mapped_column(UUID(as_uuid=True), ForeignKey('dirs.id'), nullable=False)
     md: orm.Mapped[dict[str, Any]] = orm.mapped_column(JSONB, nullable=False)  # TableMd
 
+    # used to force acquisition of an X-lock via an Update stmt
+    lock_dummy: orm.Mapped[int] = orm.mapped_column(BigInteger, nullable=True)
+
 
 @dataclasses.dataclass
 class TableVersionMd:
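Both `Dir.lock_dummy` and the new `Table.lock_dummy` exist only so that a transaction can force a row-level exclusive lock through an UPDATE statement instead of `SELECT ... FOR UPDATE`. A minimal sketch of that pattern, assuming Postgres and SQLAlchemy; this illustrates the technique and is not code from this release:

```python
import sqlalchemy as sql

def x_lock_table_row(conn: sql.engine.Connection, tbl_id: str) -> None:
    # The UPDATE takes an exclusive row lock that is held until the enclosing
    # transaction commits or rolls back; the value written is irrelevant.
    conn.execute(
        sql.text('UPDATE tables SET lock_dummy = 1 WHERE id = :id'),
        {'id': tbl_id},
    )
```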
pixeltable/plan.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import dataclasses
 import enum
 from textwrap import dedent
-from typing import Any, Iterable, Literal, Optional, Sequence
+from typing import Any, Iterable, Literal, NamedTuple, Optional, Sequence
 from uuid import UUID
 
 import sqlalchemy as sql
@@ -12,6 +12,7 @@ import pixeltable as pxt
 from pixeltable import catalog, exceptions as excs, exec, exprs
 from pixeltable.catalog import Column, TableVersionHandle
 from pixeltable.exec.sql_node import OrderByClause, OrderByItem, combine_order_by_clauses, print_order_by_clause
+from pixeltable.utils.sample import sample_key
 
 
 def _is_agg_fn_call(e: exprs.Expr) -> bool:
@@ -75,6 +76,98 @@ class FromClause:
     tbls: list[catalog.TableVersionPath]
     join_clauses: list[JoinClause] = dataclasses.field(default_factory=list)
 
+    @property
+    def _first_tbl(self) -> catalog.TableVersionPath:
+        assert len(self.tbls) == 1
+        return self.tbls[0]
+
+
+@dataclasses.dataclass
+class SampleClause:
+    """Defines a sampling clause for a table."""
+
+    version: Optional[int]
+    n: Optional[int]
+    n_per_stratum: Optional[int]
+    fraction: Optional[float]
+    seed: Optional[int]
+    stratify_exprs: Optional[list[exprs.Expr]]
+
+    # This seed value is used if one is not supplied
+    DEFAULT_SEED = 0
+
+    # The version of the hashing algorithm used for ordering and fractional sampling
+    CURRENT_VERSION = 1
+
+    def __post_init__(self) -> None:
+        """If no version or seed was provided, fall back to the defaults"""
+        if self.version is None:
+            self.version = self.CURRENT_VERSION
+        if self.seed is None:
+            self.seed = self.DEFAULT_SEED
+
+    @property
+    def is_stratified(self) -> bool:
+        """Check if the sampling is stratified"""
+        return self.stratify_exprs is not None and len(self.stratify_exprs) > 0
+
+    @property
+    def is_repeatable(self) -> bool:
+        """Return True if the same rows will continue to be sampled if source rows are added or deleted."""
+        return not self.is_stratified and self.fraction is not None
+
+    def display_str(self, inline: bool = False) -> str:
+        return str(self)
+
+    def as_dict(self) -> dict:
+        """Return a dictionary representation of the object"""
+        d = dataclasses.asdict(self)
+        d['_classname'] = self.__class__.__name__
+        if self.is_stratified:
+            d['stratify_exprs'] = [e.as_dict() for e in self.stratify_exprs]
+        return d
+
+    @classmethod
+    def from_dict(cls, d: dict) -> SampleClause:
+        """Create a SampleClause from a dictionary representation"""
+        d_cleaned = {key: value for key, value in d.items() if key != '_classname'}
+        s = cls(**d_cleaned)
+        if s.is_stratified:
+            s.stratify_exprs = [exprs.Expr.from_dict(e) for e in d_cleaned.get('stratify_exprs', [])]
+        return s
+
+    def __repr__(self) -> str:
+        s = ','.join(e.display_str(inline=True) for e in self.stratify_exprs)
+        return (
+            f'sample_{self.version}(n={self.n}, n_per_stratum={self.n_per_stratum}, '
+            f'fraction={self.fraction}, seed={self.seed}, [{s}])'
+        )
+
+    @classmethod
+    def fraction_to_md5_hex(cls, fraction: float) -> str:
+        """Return the string representation of an approximation (to ~1e-9) of a fraction of the total space
+        of md5 hash values.
+        This is used for fractional sampling.
+        """
+        # maximum value of the upper 32 bits of MD5: 2^32 - 1
+        max_md5_value = (2**32) - 1
+
+        # calculate the fraction of this value
+        threshold_int = max_md5_value * int(1_000_000_000 * fraction) // 1_000_000_000
+
+        # convert to a hexadecimal string, padded out to the full 128-bit width
+        return format(threshold_int, '08x') + 'ffffffffffffffffffffffff'
+
+
+class SamplingClauses(NamedTuple):
+    """Clauses produced when rewriting a SampleClause"""
+
+    where: Optional[exprs.Expr]
+    group_by_clause: Optional[list[exprs.Expr]]
+    order_by_clause: Optional[list[tuple[exprs.Expr, bool]]]
+    limit: Optional[exprs.Expr]
+    sample_clause: Optional[SampleClause]
+
 
 class Analyzer:
     """
@@ -260,7 +353,7 @@ class Planner:
     # TODO: create an exec.CountNode and change this to create_count_plan()
     @classmethod
     def create_count_stmt(cls, tbl: catalog.TableVersionPath, where_clause: Optional[exprs.Expr] = None) -> sql.Select:
-        stmt = sql.select(sql.func.count())
+        stmt = sql.select(sql.func.count().label('all_count'))
         refd_tbl_ids: set[UUID] = set()
         if where_clause is not None:
            analyzer = cls.analyze(tbl, where_clause)
@@ -289,7 +382,7 @@
 
         # create InMemoryDataNode for 'rows'
         plan: exec.ExecNode = exec.InMemoryDataNode(
-            TableVersionHandle(tbl.id, tbl.effective_version), rows, row_builder, tbl.next_rowid
+            TableVersionHandle(tbl.id, tbl.effective_version), rows, row_builder, tbl.next_row_id
         )
 
         media_input_col_info = [
@@ -322,6 +415,13 @@
         )
         return plan
 
+    @classmethod
+    def rowid_columns(cls, target: TableVersionHandle, num_rowid_cols: Optional[int] = None) -> list[exprs.Expr]:
+        """Return a list of RowidRefs for the given number of associated rowids"""
+        if num_rowid_cols is None:
+            num_rowid_cols = target.get().num_rowid_columns()
+        return [exprs.RowidRef(target, i) for i in range(num_rowid_cols)]
+
     @classmethod
     def create_df_insert_plan(
         cls, tbl: catalog.TableVersion, df: 'pxt.DataFrame', ignore_errors: bool
@@ -385,7 +485,7 @@
 
         cls.__check_valid_columns(tbl.tbl_version.get(), recomputed_cols, 'updated in')
 
-        recomputed_base_cols = {col for col in recomputed_cols if col.tbl == tbl.tbl_version}
+        recomputed_base_cols = {col for col in recomputed_cols if col.tbl.id == tbl.tbl_version.id}
         copied_cols = [
            col
            for col in target.cols_by_id.values()
@@ -409,7 +509,7 @@
         for i, col in enumerate(all_base_cols):
            plan.row_builder.add_table_column(col, select_list[i].slot_idx)
         recomputed_user_cols = [c for c in recomputed_cols if c.name is not None]
-        return plan, [f'{c.tbl.get().name}.{c.name}' for c in updated_cols + recomputed_user_cols], recomputed_user_cols
+        return plan, [f'{c.tbl.name}.{c.name}' for c in updated_cols + recomputed_user_cols], recomputed_user_cols
 
     @classmethod
     def __check_valid_columns(
@@ -465,7 +565,7 @@
         recomputed_cols.update(idx_val_cols)
         # we only need to recompute stored columns (unstored ones are substituted away)
         recomputed_cols = {c for c in recomputed_cols if c.is_stored}
-        recomputed_base_cols = {col for col in recomputed_cols if col.tbl == target}
+        recomputed_base_cols = {col for col in recomputed_cols if col.tbl.id == target.id}
         copied_cols = [
            col
            for col in target.cols_by_id.values()
@@ -591,7 +691,24 @@
         # 2. for component views: iterator args
         iterator_args = [target.iterator_args] if target.iterator_args is not None else []
 
-        row_builder = exprs.RowBuilder(iterator_args, stored_cols, [])
+        # if this view contains a sample specification, modify/create where, group_by, order_by, and limit clauses
+        from_clause = FromClause(tbls=[view.base])
+        where, group_by_clause, order_by_clause, limit, sample_clause = cls.create_sample_clauses(
+            from_clause, target.sample_clause, target.predicate, None, [], None
+        )
+
+        # if we're propagating an insert, we only want to see those base rows that were created for the current version
+        base_analyzer = Analyzer(
+            from_clause,
+            iterator_args,
+            where_clause=where,
+            group_by_clause=group_by_clause,
+            order_by_clause=order_by_clause,
+        )
+        row_builder = exprs.RowBuilder(base_analyzer.all_exprs, stored_cols, [])
+
+        if target.sample_clause is not None and base_analyzer.filter is not None:
+            raise excs.Error(f'Filter {base_analyzer.filter} not expressible in SQL')
 
         # execution plan:
         # 1. materialize exprs computed from the base that are needed for stored view columns
@@ -603,13 +720,22 @@
            for e in row_builder.default_eval_ctx.target_exprs
            if e.is_bound_by([view]) and not e.is_bound_by([view.base])
         ]
-        # if we're propagating an insert, we only want to see those base rows that were created for the current version
-        base_analyzer = Analyzer(FromClause(tbls=[view.base]), base_output_exprs, where_clause=target.predicate)
+
+        # create a new analyzer reflecting exactly what is required from the base table
+        base_analyzer = Analyzer(
+            from_clause,
+            base_output_exprs,
+            where_clause=where,
+            group_by_clause=group_by_clause,
+            order_by_clause=order_by_clause,
+        )
         base_eval_ctx = row_builder.create_eval_ctx(base_analyzer.all_exprs)
         plan = cls._create_query_plan(
            row_builder=row_builder,
            analyzer=base_analyzer,
            eval_ctx=base_eval_ctx,
+            limit=limit,
+            sample_clause=sample_clause,
            with_pk=True,
            exact_version_only=view.get_bases() if propagates_insert else [],
         )
@@ -692,6 +818,62 @@ class Planner:
         prefetch_node = exec.CachePrefetchNode(tbl_id, file_col_info, input_node)
         return prefetch_node
 
+    @classmethod
+    def create_sample_clauses(
+        cls,
+        from_clause: FromClause,
+        sample_clause: Optional[SampleClause],
+        where_clause: Optional[exprs.Expr],
+        group_by_clause: Optional[list[exprs.Expr]],
+        order_by_clause: Optional[list[tuple[exprs.Expr, bool]]],
+        limit: Optional[exprs.Expr],
+    ) -> SamplingClauses:
+        """Construct the clauses required for sampling under various conditions.
+
+        If there is no sampling, return the original clauses.
+        If the sample is stratified, return only the group_by clause; the rest of the mechanism
+        for stratified sampling is provided by the SqlSampleNode.
+        If the sample is non-stratified, rewrite the query to accommodate the supplied where clause,
+        and provide the other clauses required for sampling.
+        """
+        # if there is no sample clause, return the original clauses
+        if sample_clause is None:
+            return SamplingClauses(where_clause, group_by_clause, order_by_clause, limit, None)
+
+        if sample_clause.is_stratified:
+            # stratified sampling: the stratification exprs become the group_by clause
+            group_by = sample_clause.stratify_exprs
+            # note that a limit is not possible here
+            return SamplingClauses(where_clause, group_by, order_by_clause, None, sample_clause)
+        else:
+            # non-stratified sampling: construct where, order_by, and limit clauses.
+            # construct an expression for ordering rows and limiting row counts
+            s_key = sample_key(
+                exprs.Literal(sample_clause.seed), *cls.rowid_columns(from_clause._first_tbl.tbl_version)
+            )
+
+            # construct a suitable where clause
+            where = where_clause
+            if sample_clause.fraction is not None:
+                fraction_md5_hex = exprs.Expr.from_object(
+                    sample_clause.fraction_to_md5_hex(float(sample_clause.fraction))
+                )
+                f_where = s_key < fraction_md5_hex
+                where = where & f_where if where is not None else f_where
+
+            order_by: list[tuple[exprs.Expr, bool]] = [(s_key, True)]
+            limit = exprs.Literal(sample_clause.n)
+            # note that a group_by is not possible here
+            return SamplingClauses(where, None, order_by, limit, None)
+
     @classmethod
     def create_query_plan(
         cls,
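For the non-stratified case, the rewrite above turns sampling into ordinary SQL clauses keyed on a seeded hash of each row's id: the fraction filter, the deterministic ordering, and the row limit are all expressed through that key. A sketch of the resulting SQL shape, assuming a single rowid column; the concrete key expression comes from `pixeltable.utils.sample.sample_key`, and the names here are illustrative:

```python
# Equivalent SQL shape for sample(fraction=0.1, n=100, seed=42) over a table 'tbl'
# with one rowid column (illustrative; the real key is built by sample_key()):
SAMPLE_SQL = """
SELECT *
FROM tbl
WHERE md5('42' || rowid) < '19999999ffffffffffffffffffffffff'  -- fraction filter
ORDER BY md5('42' || rowid)                                    -- seeded pseudo-random order
LIMIT 100                                                      -- only when n is given
"""
```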
@@ -701,6 +883,7 @@
         group_by_clause: Optional[list[exprs.Expr]] = None,
         order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None,
         limit: Optional[exprs.Expr] = None,
+        sample_clause: Optional[SampleClause] = None,
         ignore_errors: bool = False,
         exact_version_only: Optional[list[catalog.TableVersionHandle]] = None,
     ) -> exec.ExecNode:
@@ -714,14 +897,22 @@
            order_by_clause = []
         if exact_version_only is None:
            exact_version_only = []
+
+        # modify clauses to include the sample clause
+        where, group_by_clause, order_by_clause, limit, sample = cls.create_sample_clauses(
+            from_clause, sample_clause, where_clause, group_by_clause, order_by_clause, limit
+        )
+
         analyzer = Analyzer(
            from_clause,
            select_list,
-            where_clause=where_clause,
+            where_clause=where,
            group_by_clause=group_by_clause,
            order_by_clause=order_by_clause,
         )
         row_builder = exprs.RowBuilder(analyzer.all_exprs, [], [])
+        if sample_clause is not None and analyzer.filter is not None:
+            raise excs.Error(f'Filter {analyzer.filter} not expressible in SQL')
 
         analyzer.finalize(row_builder)
         # select_list: we need to materialize everything that's been collected
@@ -732,6 +923,7 @@
            analyzer=analyzer,
            eval_ctx=eval_ctx,
            limit=limit,
+            sample_clause=sample,
            with_pk=True,
            exact_version_only=exact_version_only,
         )
@@ -747,6 +939,7 @@
         analyzer: Analyzer,
         eval_ctx: exprs.RowBuilder.EvalCtx,
         limit: Optional[exprs.Expr] = None,
+        sample_clause: Optional[SampleClause] = None,
         with_pk: bool = False,
         exact_version_only: Optional[list[catalog.TableVersionHandle]] = None,
     ) -> exec.ExecNode:
@@ -857,12 +1050,26 @@
            sql_elements.contains_all(analyzer.select_list)
            and sql_elements.contains_all(analyzer.grouping_exprs)
            and isinstance(plan, exec.SqlNode)
-            and plan.to_cte() is not None
+            and plan.to_cte(keep_pk=(sample_clause is not None)) is not None
         ):
-            plan = exec.SqlAggregationNode(
-                row_builder, input=plan, select_list=analyzer.select_list, group_by_items=analyzer.group_by_clause
-            )
+            if sample_clause is not None:
+                plan = exec.SqlSampleNode(
+                    row_builder,
+                    input=plan,
+                    select_list=analyzer.select_list,
+                    stratify_exprs=analyzer.group_by_clause,
+                    sample_clause=sample_clause,
+                )
+            else:
+                plan = exec.SqlAggregationNode(
+                    row_builder,
+                    input=plan,
+                    select_list=analyzer.select_list,
+                    group_by_items=analyzer.group_by_clause,
+                )
         else:
+            if sample_clause is not None:
+                raise excs.Error('Sample clause not supported with Python aggregation')
            input_sql_node = plan.get_node(exec.SqlNode)
            assert combined_ordering is not None
            input_sql_node.set_order_by(combined_ordering)