pixeltable 0.4.0rc1__py3-none-any.whl → 0.4.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. See the original registry diff page for more details.

Files changed (37):
  1. pixeltable/__version__.py +2 -2
  2. pixeltable/catalog/catalog.py +4 -0
  3. pixeltable/catalog/table.py +16 -0
  4. pixeltable/catalog/table_version.py +17 -2
  5. pixeltable/catalog/view.py +24 -1
  6. pixeltable/dataframe.py +185 -9
  7. pixeltable/env.py +2 -0
  8. pixeltable/exec/__init__.py +1 -1
  9. pixeltable/exec/expr_eval/evaluators.py +4 -1
  10. pixeltable/exec/sql_node.py +152 -12
  11. pixeltable/exprs/data_row.py +5 -3
  12. pixeltable/exprs/expr.py +7 -0
  13. pixeltable/exprs/literal.py +2 -0
  14. pixeltable/func/tools.py +1 -1
  15. pixeltable/functions/anthropic.py +19 -45
  16. pixeltable/functions/deepseek.py +19 -38
  17. pixeltable/functions/fireworks.py +9 -18
  18. pixeltable/functions/gemini.py +2 -3
  19. pixeltable/functions/llama_cpp.py +6 -6
  20. pixeltable/functions/mistralai.py +15 -41
  21. pixeltable/functions/ollama.py +1 -1
  22. pixeltable/functions/openai.py +82 -165
  23. pixeltable/functions/together.py +22 -80
  24. pixeltable/globals.py +5 -0
  25. pixeltable/metadata/__init__.py +11 -2
  26. pixeltable/metadata/converters/convert_36.py +38 -0
  27. pixeltable/metadata/notes.py +1 -0
  28. pixeltable/metadata/schema.py +3 -0
  29. pixeltable/plan.py +217 -10
  30. pixeltable/share/packager.py +115 -6
  31. pixeltable/utils/formatter.py +64 -42
  32. pixeltable/utils/sample.py +25 -0
  33. {pixeltable-0.4.0rc1.dist-info → pixeltable-0.4.0rc3.dist-info}/METADATA +2 -1
  34. {pixeltable-0.4.0rc1.dist-info → pixeltable-0.4.0rc3.dist-info}/RECORD +37 -35
  35. {pixeltable-0.4.0rc1.dist-info → pixeltable-0.4.0rc3.dist-info}/LICENSE +0 -0
  36. {pixeltable-0.4.0rc1.dist-info → pixeltable-0.4.0rc3.dist-info}/WHEEL +0 -0
  37. {pixeltable-0.4.0rc1.dist-info → pixeltable-0.4.0rc3.dist-info}/entry_points.txt +0 -0
@@ -7,7 +7,7 @@ the [Working with Together AI](https://pixeltable.readme.io/docs/together-ai) tu
7
7
 
8
8
  import base64
9
9
  import io
10
- from typing import TYPE_CHECKING, Callable, Optional, TypeVar
10
+ from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar
11
11
 
12
12
  import numpy as np
13
13
  import PIL.Image
@@ -50,21 +50,7 @@ def _retry(fn: Callable[..., T]) -> Callable[..., T]:
50
50
 
51
51
 
52
52
  @pxt.udf(resource_pool='request-rate:together:chat')
53
- async def completions(
54
- prompt: str,
55
- *,
56
- model: str,
57
- max_tokens: Optional[int] = None,
58
- stop: Optional[list] = None,
59
- temperature: Optional[float] = None,
60
- top_p: Optional[float] = None,
61
- top_k: Optional[int] = None,
62
- repetition_penalty: Optional[float] = None,
63
- logprobs: Optional[int] = None,
64
- echo: Optional[bool] = None,
65
- n: Optional[int] = None,
66
- safety_model: Optional[str] = None,
67
- ) -> dict:
53
+ async def completions(prompt: str, *, model: str, model_kwargs: Optional[dict[str, Any]] = None) -> dict:
68
54
  """
69
55
  Generate completions based on a given prompt using a specified model.
70
56
 
@@ -82,8 +68,8 @@ async def completions(
82
68
  Args:
83
69
  prompt: A string providing context for the model to complete.
84
70
  model: The name of the model to query.
85
-
86
- For details on the other parameters, see: <https://docs.together.ai/reference/completions-1>
71
+ model_kwargs: Additional keyword arguments for the Together `completions` API.
72
+ For details on the available parameters, see: <https://docs.together.ai/reference/completions-1>
87
73
 
88
74
  Returns:
89
75
  A dictionary containing the response and other metadata.
@@ -94,41 +80,16 @@ async def completions(
94
80
 
95
81
  >>> tbl.add_computed_column(response=completions(tbl.prompt, model='mistralai/Mixtral-8x7B-v0.1'))
96
82
  """
97
- result = await _together_client().completions.create(
98
- prompt=prompt,
99
- model=model,
100
- max_tokens=max_tokens,
101
- stop=stop,
102
- temperature=temperature,
103
- top_p=top_p,
104
- top_k=top_k,
105
- repetition_penalty=repetition_penalty,
106
- logprobs=logprobs,
107
- echo=echo,
108
- n=n,
109
- safety_model=safety_model,
110
- )
83
+ if model_kwargs is None:
84
+ model_kwargs = {}
85
+
86
+ result = await _together_client().completions.create(prompt=prompt, model=model, **model_kwargs)
111
87
  return result.dict()
112
88
 
113
89
 
114
90
  @pxt.udf(resource_pool='request-rate:together:chat')
115
91
  async def chat_completions(
116
- messages: list[dict[str, str]],
117
- *,
118
- model: str,
119
- max_tokens: Optional[int] = None,
120
- stop: Optional[list[str]] = None,
121
- temperature: Optional[float] = None,
122
- top_p: Optional[float] = None,
123
- top_k: Optional[int] = None,
124
- repetition_penalty: Optional[float] = None,
125
- logprobs: Optional[int] = None,
126
- echo: Optional[bool] = None,
127
- n: Optional[int] = None,
128
- safety_model: Optional[str] = None,
129
- response_format: Optional[dict] = None,
130
- tools: Optional[dict] = None,
131
- tool_choice: Optional[dict] = None,
92
+ messages: list[dict[str, str]], *, model: str, model_kwargs: Optional[dict[str, Any]] = None
132
93
  ) -> dict:
133
94
  """
134
95
  Generate chat completions based on a given prompt using a specified model.
@@ -147,8 +108,8 @@ async def chat_completions(
147
108
  Args:
148
109
  messages: A list of messages comprising the conversation so far.
149
110
  model: The name of the model to query.
150
-
151
- For details on the other parameters, see: <https://docs.together.ai/reference/chat-completions-1>
111
+ model_kwargs: Additional keyword arguments for the Together `chat/completions` API.
112
+ For details on the available parameters, see: <https://docs.together.ai/reference/chat-completions-1>
152
113
 
153
114
  Returns:
154
115
  A dictionary containing the response and other metadata.
@@ -160,23 +121,10 @@ async def chat_completions(
160
121
  >>> messages = [{'role': 'user', 'content': tbl.prompt}]
161
122
  ... tbl.add_computed_column(response=chat_completions(messages, model='mistralai/Mixtral-8x7B-v0.1'))
162
123
  """
163
- result = await _together_client().chat.completions.create(
164
- messages=messages,
165
- model=model,
166
- max_tokens=max_tokens,
167
- stop=stop,
168
- temperature=temperature,
169
- top_p=top_p,
170
- top_k=top_k,
171
- repetition_penalty=repetition_penalty,
172
- logprobs=logprobs,
173
- echo=echo,
174
- n=n,
175
- safety_model=safety_model,
176
- response_format=response_format,
177
- tools=tools,
178
- tool_choice=tool_choice,
179
- )
124
+ if model_kwargs is None:
125
+ model_kwargs = {}
126
+
127
+ result = await _together_client().chat.completions.create(messages=messages, model=model, **model_kwargs)
180
128
  return result.dict()
181
129
 
182
130
 
@@ -236,14 +184,7 @@ def _(model: str) -> ts.ArrayType:
236
184
 
237
185
  @pxt.udf(resource_pool='request-rate:together:images')
238
186
  async def image_generations(
239
- prompt: str,
240
- *,
241
- model: str,
242
- steps: Optional[int] = None,
243
- seed: Optional[int] = None,
244
- height: Optional[int] = None,
245
- width: Optional[int] = None,
246
- negative_prompt: Optional[str] = None,
187
+ prompt: str, *, model: str, model_kwargs: Optional[dict[str, Any]] = None
247
188
  ) -> PIL.Image.Image:
248
189
  """
249
190
  Generate images based on a given prompt using a specified model.
@@ -262,8 +203,8 @@ async def image_generations(
262
203
  Args:
263
204
  prompt: A description of the desired images.
264
205
  model: The model to use for image generation.
265
-
266
- For details on the other parameters, see: <https://docs.together.ai/reference/post_images-generations>
206
+ model_kwargs: Additional keyword args for the Together `images/generations` API.
207
+ For details on the available parameters, see: <https://docs.together.ai/reference/post_images-generations>
267
208
 
268
209
  Returns:
269
210
  The generated image.
@@ -276,9 +217,10 @@ async def image_generations(
276
217
  ... response=image_generations(tbl.prompt, model='stabilityai/stable-diffusion-xl-base-1.0')
277
218
  ... )
278
219
  """
279
- result = await _together_client().images.generate(
280
- prompt=prompt, model=model, steps=steps, seed=seed, height=height, width=width, negative_prompt=negative_prompt
281
- )
220
+ if model_kwargs is None:
221
+ model_kwargs = {}
222
+
223
+ result = await _together_client().images.generate(prompt=prompt, model=model, **model_kwargs)
282
224
  if result.data[0].b64_json is not None:
283
225
  b64_bytes = base64.b64decode(result.data[0].b64_json)
284
226
  img = PIL.Image.open(io.BytesIO(b64_bytes))
pixeltable/globals.py CHANGED
@@ -249,13 +249,17 @@ def create_view(
249
249
  where: Optional[exprs.Expr] = None
250
250
  if isinstance(base, catalog.Table):
251
251
  tbl_version_path = base._tbl_version_path
252
+ sample_clause = None
252
253
  elif isinstance(base, DataFrame):
253
254
  base._validate_mutable('create_view', allow_select=True)
254
255
  if len(base._from_clause.tbls) > 1:
255
256
  raise excs.Error('Cannot create a view of a join')
256
257
  tbl_version_path = base._from_clause.tbls[0]
257
258
  where = base.where_clause
259
+ sample_clause = base.sample_clause
258
260
  select_list = base.select_list
261
+ if sample_clause is not None and not is_snapshot and not sample_clause.is_repeatable:
262
+ raise excs.Error('Non-snapshot views cannot be created with non-fractional or stratified sampling')
259
263
  else:
260
264
  raise excs.Error('`base` must be an instance of `Table` or `DataFrame`')
261
265
  assert isinstance(base, (catalog.Table, DataFrame))
@@ -280,6 +284,7 @@ def create_view(
280
284
  tbl_version_path,
281
285
  select_list=select_list,
282
286
  where=where,
287
+ sample_clause=sample_clause,
283
288
  additional_columns=additional_columns,
284
289
  is_snapshot=is_snapshot,
285
290
  iterator=iterator,
@@ -8,15 +8,17 @@ from typing import Callable
8
8
  import sqlalchemy as sql
9
9
  from sqlalchemy import orm
10
10
 
11
+ import pixeltable as pxt
12
+ import pixeltable.exceptions as excs
11
13
  from pixeltable.utils.console_output import ConsoleLogger
12
14
 
13
15
  from .schema import SystemInfo, SystemInfoMd
14
16
 
15
17
  _console_logger = ConsoleLogger(logging.getLogger('pixeltable'))
16
-
18
+ _logger = logging.getLogger('pixeltable')
17
19
 
18
20
  # current version of the metadata; this is incremented whenever the metadata schema changes
19
- VERSION = 36
21
+ VERSION = 37
20
22
 
21
23
 
22
24
  def create_system_info(engine: sql.engine.Engine) -> None:
@@ -55,6 +57,13 @@ def upgrade_md(engine: sql.engine.Engine) -> None:
55
57
  system_info = session.query(SystemInfo).one().md
56
58
  md_version = system_info['schema_version']
57
59
  assert isinstance(md_version, int)
60
+ _logger.info(f'Current database version: {md_version}, installed version: {VERSION}')
61
+ if md_version > VERSION:
62
+ raise excs.Error(
63
+ 'This Pixeltable database was created with a newer Pixeltable version '
64
+ f'than the one currently installed ({pxt.__version__}).\n'
65
+ 'Please update to the latest Pixeltable version by running: pip install --upgrade pixeltable'
66
+ )
58
67
  if md_version == VERSION:
59
68
  return
60
69
  while md_version < VERSION:
@@ -0,0 +1,38 @@
1
+ import logging
2
+ from typing import Any, Optional
3
+ from uuid import UUID
4
+
5
+ import sqlalchemy as sql
6
+
7
+ from pixeltable.metadata import register_converter
8
+ from pixeltable.metadata.converters.util import convert_table_md
9
+
10
+ _logger = logging.getLogger('pixeltable')
11
+
12
+
13
+ @register_converter(version=36)
14
+ def _(engine: sql.engine.Engine) -> None:
15
+ convert_table_md(engine, table_md_updater=__update_table_md, substitution_fn=__substitute_md)
16
+
17
+
18
+ def __update_table_md(table_md: dict, table_id: UUID) -> None:
19
+ """Update the view metadata to add the sample_clause field if it is missing
20
+
21
+ Args:
22
+ table_md (dict): copy of the original table metadata. this gets updated in place.
23
+ table_id (UUID): the table id
24
+
25
+ """
26
+ if table_md['view_md'] is None:
27
+ return
28
+ if 'sample_clause' not in table_md['view_md']:
29
+ table_md['view_md']['sample_clause'] = None
30
+ _logger.info(f'Updating view metadata for table: {table_id}')
31
+
32
+
33
+ def __substitute_md(k: Optional[str], v: Any) -> Optional[tuple[Optional[str], Any]]:
34
+ if isinstance(v, dict) and (v.get('_classname') == 'DataFrame'):
35
+ if 'sample_clause' not in v:
36
+ v['sample_clause'] = None
37
+ return k, v
38
+ return None
@@ -2,6 +2,7 @@
2
2
  # rather than as a comment, so that the existence of a description can be enforced by
3
3
  # the unit tests when new versions are added.
4
4
  VERSION_NOTES = {
5
+ 37: 'Add support for the sample() method on DataFrames',
5
6
  36: 'Added Table.lock_dummy',
6
7
  35: 'Track reference_tbl in ColumnRef',
7
8
  34: 'Set default value for is_pk field in column metadata to False',
@@ -147,6 +147,9 @@ class ViewMd:
147
147
  # filter predicate applied to the base table; view-only
148
148
  predicate: Optional[dict[str, Any]]
149
149
 
150
+ # sampling predicate applied to the base table; view-only
151
+ sample_clause: Optional[dict[str, Any]]
152
+
150
153
  # ComponentIterator subclass; only for component views
151
154
  iterator_class_fqn: Optional[str]
152
155
 
pixeltable/plan.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import dataclasses
4
4
  import enum
5
5
  from textwrap import dedent
6
- from typing import Any, Iterable, Literal, Optional, Sequence
6
+ from typing import Any, Iterable, Literal, NamedTuple, Optional, Sequence
7
7
  from uuid import UUID
8
8
 
9
9
  import sqlalchemy as sql
@@ -12,6 +12,7 @@ import pixeltable as pxt
12
12
  from pixeltable import catalog, exceptions as excs, exec, exprs
13
13
  from pixeltable.catalog import Column, TableVersionHandle
14
14
  from pixeltable.exec.sql_node import OrderByClause, OrderByItem, combine_order_by_clauses, print_order_by_clause
15
+ from pixeltable.utils.sample import sample_key
15
16
 
16
17
 
17
18
  def _is_agg_fn_call(e: exprs.Expr) -> bool:
@@ -75,6 +76,98 @@ class FromClause:
75
76
  tbls: list[catalog.TableVersionPath]
76
77
  join_clauses: list[JoinClause] = dataclasses.field(default_factory=list)
77
78
 
79
+ @property
80
+ def _first_tbl(self) -> catalog.TableVersionPath:
81
+ assert len(self.tbls) == 1
82
+ return self.tbls[0]
83
+
84
+
85
+ @dataclasses.dataclass
86
+ class SampleClause:
87
+ """Defines a sampling clause for a table."""
88
+
89
+ version: Optional[int]
90
+ n: Optional[int]
91
+ n_per_stratum: Optional[int]
92
+ fraction: Optional[float]
93
+ seed: Optional[int]
94
+ stratify_exprs: Optional[list[exprs.Expr]]
95
+
96
+ # This seed value is used if one is not supplied
97
+ DEFAULT_SEED = 0
98
+
99
+ # The version of the hashing algorithm used for ordering and fractional sampling.
100
+ CURRENT_VERSION = 1
101
+
102
+ def __post_init__(self) -> None:
103
+ """If no version was provided, provide the default version"""
104
+ if self.version is None:
105
+ self.version = self.CURRENT_VERSION
106
+ if self.seed is None:
107
+ self.seed = self.DEFAULT_SEED
108
+
109
+ @property
110
+ def is_stratified(self) -> bool:
111
+ """Check if the sampling is stratified"""
112
+ return self.stratify_exprs is not None and len(self.stratify_exprs) > 0
113
+
114
+ @property
115
+ def is_repeatable(self) -> bool:
116
+ """Return true if the same rows will continue to be sampled if source rows are added or deleted."""
117
+ return not self.is_stratified and self.fraction is not None
118
+
119
+ def display_str(self, inline: bool = False) -> str:
120
+ return str(self)
121
+
122
+ def as_dict(self) -> dict:
123
+ """Return a dictionary representation of the object"""
124
+ d = dataclasses.asdict(self)
125
+ d['_classname'] = self.__class__.__name__
126
+ if self.is_stratified:
127
+ d['stratify_exprs'] = [e.as_dict() for e in self.stratify_exprs]
128
+ return d
129
+
130
+ @classmethod
131
+ def from_dict(cls, d: dict) -> SampleClause:
132
+ """Create a SampleClause from a dictionary representation"""
133
+ d_cleaned = {key: value for key, value in d.items() if key != '_classname'}
134
+ s = cls(**d_cleaned)
135
+ if s.is_stratified:
136
+ s.stratify_exprs = [exprs.Expr.from_dict(e) for e in d_cleaned.get('stratify_exprs', [])]
137
+ return s
138
+
139
+ def __repr__(self) -> str:
140
+ s = ','.join(e.display_str(inline=True) for e in self.stratify_exprs)
141
+ return (
142
+ f'sample_{self.version}(n={self.n}, n_per_stratum={self.n_per_stratum}, '
143
+ f'fraction={self.fraction}, seed={self.seed}, [{s}])'
144
+ )
145
+
146
+ @classmethod
147
+ def fraction_to_md5_hex(cls, fraction: float) -> str:
148
+ """Return the string representation of an approximation (to ~1e-9) of a fraction of the total space
149
+ of md5 hash values.
150
+ This is used for fractional sampling.
151
+ """
152
+ # Maximum count for the upper 32 bits of MD5: 2^32
153
+ max_md5_value = (2**32) - 1
154
+
155
+ # Calculate the fraction of this value
156
+ threshold_int = max_md5_value * int(1_000_000_000 * fraction) // 1_000_000_000
157
+
158
+ # Convert to hexadecimal string with padding
159
+ return format(threshold_int, '08x') + 'ffffffffffffffffffffffff'
160
+
161
+
162
+ class SamplingClauses(NamedTuple):
163
+ """Clauses provided when rewriting a SampleClause"""
164
+
165
+ where: exprs.Expr
166
+ group_by_clause: Optional[list[exprs.Expr]]
167
+ order_by_clause: Optional[list[tuple[exprs.Expr, bool]]]
168
+ limit: Optional[exprs.Expr]
169
+ sample_clause: Optional[SampleClause]
170
+
78
171
 
79
172
  class Analyzer:
80
173
  """
@@ -260,7 +353,7 @@ class Planner:
260
353
  # TODO: create an exec.CountNode and change this to create_count_plan()
261
354
  @classmethod
262
355
  def create_count_stmt(cls, tbl: catalog.TableVersionPath, where_clause: Optional[exprs.Expr] = None) -> sql.Select:
263
- stmt = sql.select(sql.func.count())
356
+ stmt = sql.select(sql.func.count().label('all_count'))
264
357
  refd_tbl_ids: set[UUID] = set()
265
358
  if where_clause is not None:
266
359
  analyzer = cls.analyze(tbl, where_clause)
@@ -322,6 +415,13 @@ class Planner:
322
415
  )
323
416
  return plan
324
417
 
418
+ @classmethod
419
+ def rowid_columns(cls, target: TableVersionHandle, num_rowid_cols: Optional[int] = None) -> list[exprs.Expr]:
420
+ """Return list of RowidRef for the given number of associated rowids"""
421
+ if num_rowid_cols is None:
422
+ num_rowid_cols = target.get().num_rowid_columns()
423
+ return [exprs.RowidRef(target, i) for i in range(num_rowid_cols)]
424
+
325
425
  @classmethod
326
426
  def create_df_insert_plan(
327
427
  cls, tbl: catalog.TableVersion, df: 'pxt.DataFrame', ignore_errors: bool
@@ -591,7 +691,24 @@ class Planner:
591
691
  # 2. for component views: iterator args
592
692
  iterator_args = [target.iterator_args] if target.iterator_args is not None else []
593
693
 
594
- row_builder = exprs.RowBuilder(iterator_args, stored_cols, [])
694
+ # If this contains a sample specification, modify / create where, group_by, order_by, and limit clauses
695
+ from_clause = FromClause(tbls=[view.base])
696
+ where, group_by_clause, order_by_clause, limit, sample_clause = cls.create_sample_clauses(
697
+ from_clause, target.sample_clause, target.predicate, None, [], None
698
+ )
699
+
700
+ # if we're propagating an insert, we only want to see those base rows that were created for the current version
701
+ base_analyzer = Analyzer(
702
+ from_clause,
703
+ iterator_args,
704
+ where_clause=where,
705
+ group_by_clause=group_by_clause,
706
+ order_by_clause=order_by_clause,
707
+ )
708
+ row_builder = exprs.RowBuilder(base_analyzer.all_exprs, stored_cols, [])
709
+
710
+ if target.sample_clause is not None and base_analyzer.filter is not None:
711
+ raise excs.Error(f'Filter {base_analyzer.filter} not expressible in SQL')
595
712
 
596
713
  # execution plan:
597
714
  # 1. materialize exprs computed from the base that are needed for stored view columns
@@ -603,13 +720,22 @@ class Planner:
603
720
  for e in row_builder.default_eval_ctx.target_exprs
604
721
  if e.is_bound_by([view]) and not e.is_bound_by([view.base])
605
722
  ]
606
- # if we're propagating an insert, we only want to see those base rows that were created for the current version
607
- base_analyzer = Analyzer(FromClause(tbls=[view.base]), base_output_exprs, where_clause=target.predicate)
723
+
724
+ # Create a new analyzer reflecting exactly what is required from the base table
725
+ base_analyzer = Analyzer(
726
+ from_clause,
727
+ base_output_exprs,
728
+ where_clause=where,
729
+ group_by_clause=group_by_clause,
730
+ order_by_clause=order_by_clause,
731
+ )
608
732
  base_eval_ctx = row_builder.create_eval_ctx(base_analyzer.all_exprs)
609
733
  plan = cls._create_query_plan(
610
734
  row_builder=row_builder,
611
735
  analyzer=base_analyzer,
612
736
  eval_ctx=base_eval_ctx,
737
+ limit=limit,
738
+ sample_clause=sample_clause,
613
739
  with_pk=True,
614
740
  exact_version_only=view.get_bases() if propagates_insert else [],
615
741
  )
@@ -692,6 +818,62 @@ class Planner:
692
818
  prefetch_node = exec.CachePrefetchNode(tbl_id, file_col_info, input_node)
693
819
  return prefetch_node
694
820
 
821
+ @classmethod
822
+ def create_sample_clauses(
823
+ cls,
824
+ from_clause: FromClause,
825
+ sample_clause: SampleClause,
826
+ where_clause: Optional[exprs.Expr],
827
+ group_by_clause: Optional[list[exprs.Expr]],
828
+ order_by_clause: Optional[list[tuple[exprs.Expr, bool]]],
829
+ limit: Optional[exprs.Expr],
830
+ ) -> SamplingClauses:
831
+ """tuple[
832
+ exprs.Expr,
833
+ Optional[list[exprs.Expr]],
834
+ Optional[list[tuple[exprs.Expr, bool]]],
835
+ Optional[exprs.Expr],
836
+ Optional[SampleClause],
837
+ ]:"""
838
+ """Construct clauses required for sampling under various conditions.
839
+ If there is no sampling, then return the original clauses.
840
+ If the sample is stratified, then return only the group by clause. The rest of the
841
+ mechanism for stratified sampling is provided by the SampleSqlNode.
842
+ If the sample is non-stratified, then rewrite the query to accommodate the supplied where clause,
843
+ and provide the other clauses required for sampling
844
+ """
845
+
846
+ # If no sample clause, return the original clauses
847
+ if sample_clause is None:
848
+ return SamplingClauses(where_clause, group_by_clause, order_by_clause, limit, None)
849
+
850
+ # If the sample clause is stratified, create a group by clause
851
+ if sample_clause.is_stratified:
852
+ group_by = sample_clause.stratify_exprs
853
+ # Note that limit is not possible here
854
+ return SamplingClauses(where_clause, group_by, order_by_clause, None, sample_clause)
855
+
856
+ else:
857
+ # If non-stratified sampling, construct a where clause, order_by, and limit clauses
858
+ # Construct an expression for sorting rows and limiting row counts
859
+ s_key = sample_key(
860
+ exprs.Literal(sample_clause.seed), *cls.rowid_columns(from_clause._first_tbl.tbl_version)
861
+ )
862
+
863
+ # Construct a suitable where clause
864
+ where = where_clause
865
+ if sample_clause.fraction is not None:
866
+ fraction_md5_hex = exprs.Expr.from_object(
867
+ sample_clause.fraction_to_md5_hex(float(sample_clause.fraction))
868
+ )
869
+ f_where = s_key < fraction_md5_hex
870
+ where = where & f_where if where is not None else f_where
871
+
872
+ order_by: list[tuple[exprs.Expr, bool]] = [(s_key, True)]
873
+ limit = exprs.Literal(sample_clause.n)
874
+ # Note that group_by is not possible here
875
+ return SamplingClauses(where, None, order_by, limit, None)
876
+
695
877
  @classmethod
696
878
  def create_query_plan(
697
879
  cls,
@@ -701,6 +883,7 @@ class Planner:
701
883
  group_by_clause: Optional[list[exprs.Expr]] = None,
702
884
  order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None,
703
885
  limit: Optional[exprs.Expr] = None,
886
+ sample_clause: Optional[SampleClause] = None,
704
887
  ignore_errors: bool = False,
705
888
  exact_version_only: Optional[list[catalog.TableVersionHandle]] = None,
706
889
  ) -> exec.ExecNode:
@@ -714,14 +897,22 @@ class Planner:
714
897
  order_by_clause = []
715
898
  if exact_version_only is None:
716
899
  exact_version_only = []
900
+
901
+ # Modify clauses to include sample clause
902
+ where, group_by_clause, order_by_clause, limit, sample = cls.create_sample_clauses(
903
+ from_clause, sample_clause, where_clause, group_by_clause, order_by_clause, limit
904
+ )
905
+
717
906
  analyzer = Analyzer(
718
907
  from_clause,
719
908
  select_list,
720
- where_clause=where_clause,
909
+ where_clause=where,
721
910
  group_by_clause=group_by_clause,
722
911
  order_by_clause=order_by_clause,
723
912
  )
724
913
  row_builder = exprs.RowBuilder(analyzer.all_exprs, [], [])
914
+ if sample_clause is not None and analyzer.filter is not None:
915
+ raise excs.Error(f'Filter {analyzer.filter} not expressible in SQL')
725
916
 
726
917
  analyzer.finalize(row_builder)
727
918
  # select_list: we need to materialize everything that's been collected
@@ -732,6 +923,7 @@ class Planner:
732
923
  analyzer=analyzer,
733
924
  eval_ctx=eval_ctx,
734
925
  limit=limit,
926
+ sample_clause=sample,
735
927
  with_pk=True,
736
928
  exact_version_only=exact_version_only,
737
929
  )
@@ -747,6 +939,7 @@ class Planner:
747
939
  analyzer: Analyzer,
748
940
  eval_ctx: exprs.RowBuilder.EvalCtx,
749
941
  limit: Optional[exprs.Expr] = None,
942
+ sample_clause: Optional[SampleClause] = None,
750
943
  with_pk: bool = False,
751
944
  exact_version_only: Optional[list[catalog.TableVersionHandle]] = None,
752
945
  ) -> exec.ExecNode:
@@ -857,12 +1050,26 @@ class Planner:
857
1050
  sql_elements.contains_all(analyzer.select_list)
858
1051
  and sql_elements.contains_all(analyzer.grouping_exprs)
859
1052
  and isinstance(plan, exec.SqlNode)
860
- and plan.to_cte() is not None
1053
+ and plan.to_cte(keep_pk=(sample_clause is not None)) is not None
861
1054
  ):
862
- plan = exec.SqlAggregationNode(
863
- row_builder, input=plan, select_list=analyzer.select_list, group_by_items=analyzer.group_by_clause
864
- )
1055
+ if sample_clause is not None:
1056
+ plan = exec.SqlSampleNode(
1057
+ row_builder,
1058
+ input=plan,
1059
+ select_list=analyzer.select_list,
1060
+ stratify_exprs=analyzer.group_by_clause,
1061
+ sample_clause=sample_clause,
1062
+ )
1063
+ else:
1064
+ plan = exec.SqlAggregationNode(
1065
+ row_builder,
1066
+ input=plan,
1067
+ select_list=analyzer.select_list,
1068
+ group_by_items=analyzer.group_by_clause,
1069
+ )
865
1070
  else:
1071
+ if sample_clause is not None:
1072
+ raise excs.Error('Sample clause not supported with Python aggregation')
866
1073
  input_sql_node = plan.get_node(exec.SqlNode)
867
1074
  assert combined_ordering is not None
868
1075
  input_sql_node.set_order_by(combined_ordering)