pixeltable 0.4.0rc1__py3-none-any.whl → 0.4.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic.
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/catalog.py +4 -0
- pixeltable/catalog/table.py +16 -0
- pixeltable/catalog/table_version.py +17 -2
- pixeltable/catalog/view.py +24 -1
- pixeltable/dataframe.py +185 -9
- pixeltable/env.py +2 -0
- pixeltable/exec/__init__.py +1 -1
- pixeltable/exec/expr_eval/evaluators.py +4 -1
- pixeltable/exec/sql_node.py +152 -12
- pixeltable/exprs/data_row.py +5 -3
- pixeltable/exprs/expr.py +7 -0
- pixeltable/exprs/literal.py +2 -0
- pixeltable/func/tools.py +1 -1
- pixeltable/functions/anthropic.py +19 -45
- pixeltable/functions/deepseek.py +19 -38
- pixeltable/functions/fireworks.py +9 -18
- pixeltable/functions/gemini.py +2 -3
- pixeltable/functions/llama_cpp.py +6 -6
- pixeltable/functions/mistralai.py +15 -41
- pixeltable/functions/ollama.py +1 -1
- pixeltable/functions/openai.py +82 -165
- pixeltable/functions/together.py +22 -80
- pixeltable/globals.py +5 -0
- pixeltable/metadata/__init__.py +11 -2
- pixeltable/metadata/converters/convert_36.py +38 -0
- pixeltable/metadata/notes.py +1 -0
- pixeltable/metadata/schema.py +3 -0
- pixeltable/plan.py +217 -10
- pixeltable/share/packager.py +115 -6
- pixeltable/utils/formatter.py +64 -42
- pixeltable/utils/sample.py +25 -0
- {pixeltable-0.4.0rc1.dist-info → pixeltable-0.4.0rc3.dist-info}/METADATA +2 -1
- {pixeltable-0.4.0rc1.dist-info → pixeltable-0.4.0rc3.dist-info}/RECORD +37 -35
- {pixeltable-0.4.0rc1.dist-info → pixeltable-0.4.0rc3.dist-info}/LICENSE +0 -0
- {pixeltable-0.4.0rc1.dist-info → pixeltable-0.4.0rc3.dist-info}/WHEEL +0 -0
- {pixeltable-0.4.0rc1.dist-info → pixeltable-0.4.0rc3.dist-info}/entry_points.txt +0 -0
pixeltable/functions/together.py
CHANGED
@@ -7,7 +7,7 @@ the [Working with Together AI](https://pixeltable.readme.io/docs/together-ai) tu

 import base64
 import io
-from typing import TYPE_CHECKING, Callable, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar

 import numpy as np
 import PIL.Image
@@ -50,21 +50,7 @@ def _retry(fn: Callable[..., T]) -> Callable[..., T]:


 @pxt.udf(resource_pool='request-rate:together:chat')
-async def completions(
-    prompt: str,
-    *,
-    model: str,
-    max_tokens: Optional[int] = None,
-    stop: Optional[list] = None,
-    temperature: Optional[float] = None,
-    top_p: Optional[float] = None,
-    top_k: Optional[int] = None,
-    repetition_penalty: Optional[float] = None,
-    logprobs: Optional[int] = None,
-    echo: Optional[bool] = None,
-    n: Optional[int] = None,
-    safety_model: Optional[str] = None,
-) -> dict:
+async def completions(prompt: str, *, model: str, model_kwargs: Optional[dict[str, Any]] = None) -> dict:
     """
     Generate completions based on a given prompt using a specified model.

@@ -82,8 +68,8 @@ async def completions(
     Args:
         prompt: A string providing context for the model to complete.
         model: The name of the model to query.
-
-
+        model_kwargs: Additional keyword arguments for the Together `completions` API.
+            For details on the available parameters, see: <https://docs.together.ai/reference/completions-1>

     Returns:
         A dictionary containing the response and other metadata.
@@ -94,41 +80,16 @@ async def completions(

         >>> tbl.add_computed_column(response=completions(tbl.prompt, model='mistralai/Mixtral-8x7B-v0.1'))
     """
-
-
-
-
-        stop=stop,
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty,
-        logprobs=logprobs,
-        echo=echo,
-        n=n,
-        safety_model=safety_model,
-    )
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    result = await _together_client().completions.create(prompt=prompt, model=model, **model_kwargs)
     return result.dict()


 @pxt.udf(resource_pool='request-rate:together:chat')
 async def chat_completions(
-    messages: list[dict[str, str]],
-    *,
-    model: str,
-    max_tokens: Optional[int] = None,
-    stop: Optional[list[str]] = None,
-    temperature: Optional[float] = None,
-    top_p: Optional[float] = None,
-    top_k: Optional[int] = None,
-    repetition_penalty: Optional[float] = None,
-    logprobs: Optional[int] = None,
-    echo: Optional[bool] = None,
-    n: Optional[int] = None,
-    safety_model: Optional[str] = None,
-    response_format: Optional[dict] = None,
-    tools: Optional[dict] = None,
-    tool_choice: Optional[dict] = None,
+    messages: list[dict[str, str]], *, model: str, model_kwargs: Optional[dict[str, Any]] = None
 ) -> dict:
     """
     Generate chat completions based on a given prompt using a specified model.
@@ -147,8 +108,8 @@ async def chat_completions(
     Args:
         messages: A list of messages comprising the conversation so far.
         model: The name of the model to query.
-
-
+        model_kwargs: Additional keyword arguments for the Together `chat/completions` API.
+            For details on the available parameters, see: <https://docs.together.ai/reference/chat-completions-1>

     Returns:
         A dictionary containing the response and other metadata.
@@ -160,23 +121,10 @@ async def chat_completions(
         >>> messages = [{'role': 'user', 'content': tbl.prompt}]
         ... tbl.add_computed_column(response=chat_completions(messages, model='mistralai/Mixtral-8x7B-v0.1'))
     """
-
-
-
-
-        stop=stop,
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty,
-        logprobs=logprobs,
-        echo=echo,
-        n=n,
-        safety_model=safety_model,
-        response_format=response_format,
-        tools=tools,
-        tool_choice=tool_choice,
-    )
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    result = await _together_client().chat.completions.create(messages=messages, model=model, **model_kwargs)
     return result.dict()


@@ -236,14 +184,7 @@ def _(model: str) -> ts.ArrayType:

 @pxt.udf(resource_pool='request-rate:together:images')
 async def image_generations(
-    prompt: str,
-    *,
-    model: str,
-    steps: Optional[int] = None,
-    seed: Optional[int] = None,
-    height: Optional[int] = None,
-    width: Optional[int] = None,
-    negative_prompt: Optional[str] = None,
+    prompt: str, *, model: str, model_kwargs: Optional[dict[str, Any]] = None
 ) -> PIL.Image.Image:
     """
     Generate images based on a given prompt using a specified model.
@@ -262,8 +203,8 @@ async def image_generations(
     Args:
         prompt: A description of the desired images.
        model: The model to use for image generation.
-
-
+        model_kwargs: Additional keyword args for the Together `images/generations` API.
+            For details on the available parameters, see: <https://docs.together.ai/reference/post_images-generations>

     Returns:
         The generated image.
@@ -276,9 +217,10 @@ async def image_generations(
         ...     response=image_generations(tbl.prompt, model='stabilityai/stable-diffusion-xl-base-1.0')
         ... )
     """
-
-
-
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    result = await _together_client().images.generate(prompt=prompt, model=model, **model_kwargs)
     if result.data[0].b64_json is not None:
         b64_bytes = base64.b64decode(result.data[0].b64_json)
         img = PIL.Image.open(io.BytesIO(b64_bytes))
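The three Together UDFs above now take a single optional `model_kwargs` dict in place of the long list of per-parameter keyword arguments. A minimal usage sketch (the table name is hypothetical; `max_tokens` and `temperature` are simply forwarded to the Together `chat/completions` endpoint via `model_kwargs`):

    import pixeltable as pxt
    from pixeltable.functions.together import chat_completions

    tbl = pxt.get_table('chatbot_demo')  # hypothetical table with a string 'prompt' column
    messages = [{'role': 'user', 'content': tbl.prompt}]
    tbl.add_computed_column(
        response=chat_completions(
            messages,
            model='mistralai/Mixtral-8x7B-v0.1',
            model_kwargs={'max_tokens': 300, 'temperature': 0.7},  # forwarded verbatim to the API
        )
    )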
pixeltable/globals.py
CHANGED
@@ -249,13 +249,17 @@ def create_view(
     where: Optional[exprs.Expr] = None
     if isinstance(base, catalog.Table):
         tbl_version_path = base._tbl_version_path
+        sample_clause = None
     elif isinstance(base, DataFrame):
         base._validate_mutable('create_view', allow_select=True)
         if len(base._from_clause.tbls) > 1:
             raise excs.Error('Cannot create a view of a join')
         tbl_version_path = base._from_clause.tbls[0]
         where = base.where_clause
+        sample_clause = base.sample_clause
         select_list = base.select_list
+        if sample_clause is not None and not is_snapshot and not sample_clause.is_repeatable:
+            raise excs.Error('Non-snapshot views cannot be created with non-fractional or stratified sampling')
     else:
         raise excs.Error('`base` must be an instance of `Table` or `DataFrame`')
     assert isinstance(base, (catalog.Table, DataFrame))
@@ -280,6 +284,7 @@ def create_view(
         tbl_version_path,
         select_list=select_list,
         where=where,
+        sample_clause=sample_clause,
         additional_columns=additional_columns,
         is_snapshot=is_snapshot,
         iterator=iterator,
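`create_view()` now forwards a DataFrame's sampling specification to the view and rejects non-repeatable samples (stratified or fixed-size `n`) unless the view is a snapshot. An illustrative sketch, assuming `DataFrame.sample()` accepts `fraction`, `n`, and `seed` keywords mirroring the `SampleClause` fields introduced in this release (its exact signature lives in `pixeltable/dataframe.py`, whose diff is not shown above); table and column names are hypothetical:

    import pixeltable as pxt

    tbl = pxt.get_table('films')          # hypothetical base table
    df = tbl.where(tbl.year >= 2000)      # hypothetical 'year' column

    # Fractional sampling is repeatable, so a regular view is allowed:
    pxt.create_view('films_10pct', df.sample(fraction=0.1, seed=42))

    # A fixed-size sample is not repeatable and must be materialized as a snapshot:
    pxt.create_view('films_100', df.sample(n=100, seed=42), is_snapshot=True)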
pixeltable/metadata/__init__.py
CHANGED
@@ -8,15 +8,17 @@ from typing import Callable
 import sqlalchemy as sql
 from sqlalchemy import orm

+import pixeltable as pxt
+import pixeltable.exceptions as excs
 from pixeltable.utils.console_output import ConsoleLogger

 from .schema import SystemInfo, SystemInfoMd

 _console_logger = ConsoleLogger(logging.getLogger('pixeltable'))
-
+_logger = logging.getLogger('pixeltable')

 # current version of the metadata; this is incremented whenever the metadata schema changes
-VERSION = 36
+VERSION = 37


 def create_system_info(engine: sql.engine.Engine) -> None:
@@ -55,6 +57,13 @@ def upgrade_md(engine: sql.engine.Engine) -> None:
         system_info = session.query(SystemInfo).one().md
         md_version = system_info['schema_version']
         assert isinstance(md_version, int)
+        _logger.info(f'Current database version: {md_version}, installed version: {VERSION}')
+        if md_version > VERSION:
+            raise excs.Error(
+                'This Pixeltable database was created with a newer Pixeltable version '
+                f'than the one currently installed ({pxt.__version__}).\n'
+                'Please update to the latest Pixeltable version by running: pip install --upgrade pixeltable'
+            )
         if md_version == VERSION:
             return
         while md_version < VERSION:
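`upgrade_md()` now refuses to run against a database written by a newer Pixeltable release instead of attempting a downgrade. A minimal sketch of the check with illustrative version numbers (the real code raises `excs.Error` and logs both versions first):

    VERSION = 37     # metadata version understood by the installed package
    md_version = 38  # schema_version recorded in the database by a newer release

    if md_version > VERSION:
        raise RuntimeError('database requires a newer pixeltable; run: pip install --upgrade pixeltable')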
pixeltable/metadata/converters/convert_36.py
ADDED
@@ -0,0 +1,38 @@
+import logging
+from typing import Any, Optional
+from uuid import UUID
+
+import sqlalchemy as sql
+
+from pixeltable.metadata import register_converter
+from pixeltable.metadata.converters.util import convert_table_md
+
+_logger = logging.getLogger('pixeltable')
+
+
+@register_converter(version=36)
+def _(engine: sql.engine.Engine) -> None:
+    convert_table_md(engine, table_md_updater=__update_table_md, substitution_fn=__substitute_md)
+
+
+def __update_table_md(table_md: dict, table_id: UUID) -> None:
+    """Update the view metadata to add the sample_clause field if it is missing
+
+    Args:
+        table_md (dict): copy of the original table metadata. this gets updated in place.
+        table_id (UUID): the table id
+
+    """
+    if table_md['view_md'] is None:
+        return
+    if 'sample_clause' not in table_md['view_md']:
+        table_md['view_md']['sample_clause'] = None
+        _logger.info(f'Updating view metadata for table: {table_id}')
+
+
+def __substitute_md(k: Optional[str], v: Any) -> Optional[tuple[Optional[str], Any]]:
+    if isinstance(v, dict) and (v.get('_classname') == 'DataFrame'):
+        if 'sample_clause' not in v:
+            v['sample_clause'] = None
+            return k, v
+    return None
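For illustration, this is the effect of the version-36 converter on stored view metadata; field values other than `sample_clause` are placeholders, and only the `sample_clause` handling is taken from the code above:

    # View metadata as stored under schema version 36 (abridged, placeholder values):
    view_md = {'predicate': None, 'iterator_class_fqn': None}

    # __update_table_md backfills the new field with a null default:
    if 'sample_clause' not in view_md:
        view_md['sample_clause'] = None

    print(view_md)
    # {'predicate': None, 'iterator_class_fqn': None, 'sample_clause': None}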
pixeltable/metadata/notes.py
CHANGED
@@ -2,6 +2,7 @@
 # rather than as a comment, so that the existence of a description can be enforced by
 # the unit tests when new versions are added.
 VERSION_NOTES = {
+    37: 'Add support for the sample() method on DataFrames',
     36: 'Added Table.lock_dummy',
     35: 'Track reference_tbl in ColumnRef',
     34: 'Set default value for is_pk field in column metadata to False',
pixeltable/metadata/schema.py
CHANGED
@@ -147,6 +147,9 @@ class ViewMd:
     # filter predicate applied to the base table; view-only
     predicate: Optional[dict[str, Any]]

+    # sampling predicate applied to the base table; view-only
+    sample_clause: Optional[dict[str, Any]]
+
     # ComponentIterator subclass; only for component views
     iterator_class_fqn: Optional[str]

pixeltable/plan.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import dataclasses
 import enum
 from textwrap import dedent
-from typing import Any, Iterable, Literal, Optional, Sequence
+from typing import Any, Iterable, Literal, NamedTuple, Optional, Sequence
 from uuid import UUID

 import sqlalchemy as sql
@@ -12,6 +12,7 @@ import pixeltable as pxt
 from pixeltable import catalog, exceptions as excs, exec, exprs
 from pixeltable.catalog import Column, TableVersionHandle
 from pixeltable.exec.sql_node import OrderByClause, OrderByItem, combine_order_by_clauses, print_order_by_clause
+from pixeltable.utils.sample import sample_key


 def _is_agg_fn_call(e: exprs.Expr) -> bool:
@@ -75,6 +76,98 @@ class FromClause:
     tbls: list[catalog.TableVersionPath]
     join_clauses: list[JoinClause] = dataclasses.field(default_factory=list)

+    @property
+    def _first_tbl(self) -> catalog.TableVersionPath:
+        assert len(self.tbls) == 1
+        return self.tbls[0]
+
+
+@dataclasses.dataclass
+class SampleClause:
+    """Defines a sampling clause for a table."""
+
+    version: Optional[int]
+    n: Optional[int]
+    n_per_stratum: Optional[int]
+    fraction: Optional[float]
+    seed: Optional[int]
+    stratify_exprs: Optional[list[exprs.Expr]]
+
+    # This seed value is used if one is not supplied
+    DEFAULT_SEED = 0
+
+    # The version of the hashing algorithm used for ordering and fractional sampling.
+    CURRENT_VERSION = 1
+
+    def __post_init__(self) -> None:
+        """If no version was provided, provide the default version"""
+        if self.version is None:
+            self.version = self.CURRENT_VERSION
+        if self.seed is None:
+            self.seed = self.DEFAULT_SEED
+
+    @property
+    def is_stratified(self) -> bool:
+        """Check if the sampling is stratified"""
+        return self.stratify_exprs is not None and len(self.stratify_exprs) > 0
+
+    @property
+    def is_repeatable(self) -> bool:
+        """Return true if the same rows will continue to be sampled if source rows are added or deleted."""
+        return not self.is_stratified and self.fraction is not None
+
+    def display_str(self, inline: bool = False) -> str:
+        return str(self)
+
+    def as_dict(self) -> dict:
+        """Return a dictionary representation of the object"""
+        d = dataclasses.asdict(self)
+        d['_classname'] = self.__class__.__name__
+        if self.is_stratified:
+            d['stratify_exprs'] = [e.as_dict() for e in self.stratify_exprs]
+        return d
+
+    @classmethod
+    def from_dict(cls, d: dict) -> SampleClause:
+        """Create a SampleClause from a dictionary representation"""
+        d_cleaned = {key: value for key, value in d.items() if key != '_classname'}
+        s = cls(**d_cleaned)
+        if s.is_stratified:
+            s.stratify_exprs = [exprs.Expr.from_dict(e) for e in d_cleaned.get('stratify_exprs', [])]
+        return s
+
+    def __repr__(self) -> str:
+        s = ','.join(e.display_str(inline=True) for e in self.stratify_exprs)
+        return (
+            f'sample_{self.version}(n={self.n}, n_per_stratum={self.n_per_stratum}, '
+            f'fraction={self.fraction}, seed={self.seed}, [{s}])'
+        )
+
+    @classmethod
+    def fraction_to_md5_hex(cls, fraction: float) -> str:
+        """Return the string representation of an approximation (to ~1e-9) of a fraction of the total space
+        of md5 hash values.
+        This is used for fractional sampling.
+        """
+        # Maximum count for the upper 32 bits of MD5: 2^32
+        max_md5_value = (2**32) - 1
+
+        # Calculate the fraction of this value
+        threshold_int = max_md5_value * int(1_000_000_000 * fraction) // 1_000_000_000
+
+        # Convert to hexadecimal string with padding
+        return format(threshold_int, '08x') + 'ffffffffffffffffffffffff'
+
+
+class SamplingClauses(NamedTuple):
+    """Clauses provided when rewriting a SampleClause"""
+
+    where: exprs.Expr
+    group_by_clause: Optional[list[exprs.Expr]]
+    order_by_clause: Optional[list[tuple[exprs.Expr, bool]]]
+    limit: Optional[exprs.Expr]
+    sample_clause: Optional[SampleClause]
+

 class Analyzer:
     """
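As a worked example of `SampleClause.fraction_to_md5_hex` above, a fraction of 0.25 maps to a threshold string covering roughly a quarter of the 128-bit MD5 space (the arithmetic mirrors the method body):

    fraction = 0.25
    max_md5_value = (2**32) - 1  # upper 32 bits of an MD5 digest
    threshold_int = max_md5_value * int(1_000_000_000 * fraction) // 1_000_000_000
    threshold = format(threshold_int, '08x') + 'ffffffffffffffffffffffff'
    print(threshold)  # '3fffffff' followed by 24 'f' characters
    # Rows whose MD5-based sample key sorts lexically below this threshold are kept: ~25% of rows.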
@@ -260,7 +353,7 @@ class Planner:
     # TODO: create an exec.CountNode and change this to create_count_plan()
     @classmethod
     def create_count_stmt(cls, tbl: catalog.TableVersionPath, where_clause: Optional[exprs.Expr] = None) -> sql.Select:
-        stmt = sql.select(sql.func.count())
+        stmt = sql.select(sql.func.count().label('all_count'))
         refd_tbl_ids: set[UUID] = set()
         if where_clause is not None:
             analyzer = cls.analyze(tbl, where_clause)
@@ -322,6 +415,13 @@
         )
         return plan

+    @classmethod
+    def rowid_columns(cls, target: TableVersionHandle, num_rowid_cols: Optional[int] = None) -> list[exprs.Expr]:
+        """Return list of RowidRef for the given number of associated rowids"""
+        if num_rowid_cols is None:
+            num_rowid_cols = target.get().num_rowid_columns()
+        return [exprs.RowidRef(target, i) for i in range(num_rowid_cols)]
+
     @classmethod
     def create_df_insert_plan(
         cls, tbl: catalog.TableVersion, df: 'pxt.DataFrame', ignore_errors: bool
@@ -591,7 +691,24 @@
         # 2. for component views: iterator args
         iterator_args = [target.iterator_args] if target.iterator_args is not None else []

-
+        # If this contains a sample specification, modify / create where, group_by, order_by, and limit clauses
+        from_clause = FromClause(tbls=[view.base])
+        where, group_by_clause, order_by_clause, limit, sample_clause = cls.create_sample_clauses(
+            from_clause, target.sample_clause, target.predicate, None, [], None
+        )
+
+        # if we're propagating an insert, we only want to see those base rows that were created for the current version
+        base_analyzer = Analyzer(
+            from_clause,
+            iterator_args,
+            where_clause=where,
+            group_by_clause=group_by_clause,
+            order_by_clause=order_by_clause,
+        )
+        row_builder = exprs.RowBuilder(base_analyzer.all_exprs, stored_cols, [])
+
+        if target.sample_clause is not None and base_analyzer.filter is not None:
+            raise excs.Error(f'Filter {base_analyzer.filter} not expressible in SQL')

         # execution plan:
         # 1. materialize exprs computed from the base that are needed for stored view columns
@@ -603,13 +720,22 @@
             for e in row_builder.default_eval_ctx.target_exprs
             if e.is_bound_by([view]) and not e.is_bound_by([view.base])
         ]
-
-
+
+        # Create a new analyzer reflecting exactly what is required from the base table
+        base_analyzer = Analyzer(
+            from_clause,
+            base_output_exprs,
+            where_clause=where,
+            group_by_clause=group_by_clause,
+            order_by_clause=order_by_clause,
+        )
         base_eval_ctx = row_builder.create_eval_ctx(base_analyzer.all_exprs)
         plan = cls._create_query_plan(
             row_builder=row_builder,
             analyzer=base_analyzer,
             eval_ctx=base_eval_ctx,
+            limit=limit,
+            sample_clause=sample_clause,
             with_pk=True,
             exact_version_only=view.get_bases() if propagates_insert else [],
         )
@@ -692,6 +818,62 @@
         prefetch_node = exec.CachePrefetchNode(tbl_id, file_col_info, input_node)
         return prefetch_node

+    @classmethod
+    def create_sample_clauses(
+        cls,
+        from_clause: FromClause,
+        sample_clause: SampleClause,
+        where_clause: Optional[exprs.Expr],
+        group_by_clause: Optional[list[exprs.Expr]],
+        order_by_clause: Optional[list[tuple[exprs.Expr, bool]]],
+        limit: Optional[exprs.Expr],
+    ) -> SamplingClauses:
+        """tuple[
+        exprs.Expr,
+        Optional[list[exprs.Expr]],
+        Optional[list[tuple[exprs.Expr, bool]]],
+        Optional[exprs.Expr],
+        Optional[SampleClause],
+        ]:"""
+        """Construct clauses required for sampling under various conditions.
+        If there is no sampling, then return the original clauses.
+        If the sample is stratified, then return only the group by clause. The rest of the
+        mechanism for stratified sampling is provided by the SampleSqlNode.
+        If the sample is non-stratified, then rewrite the query to accommodate the supplied where clause,
+        and provide the other clauses required for sampling
+        """
+
+        # If no sample clause, return the original clauses
+        if sample_clause is None:
+            return SamplingClauses(where_clause, group_by_clause, order_by_clause, limit, None)
+
+        # If the sample clause is stratified, create a group by clause
+        if sample_clause.is_stratified:
+            group_by = sample_clause.stratify_exprs
+            # Note that limit is not possible here
+            return SamplingClauses(where_clause, group_by, order_by_clause, None, sample_clause)
+
+        else:
+            # If non-stratified sampling, construct a where clause, order_by, and limit clauses
+            # Construct an expression for sorting rows and limiting row counts
+            s_key = sample_key(
+                exprs.Literal(sample_clause.seed), *cls.rowid_columns(from_clause._first_tbl.tbl_version)
+            )
+
+            # Construct a suitable where clause
+            where = where_clause
+            if sample_clause.fraction is not None:
+                fraction_md5_hex = exprs.Expr.from_object(
+                    sample_clause.fraction_to_md5_hex(float(sample_clause.fraction))
+                )
+                f_where = s_key < fraction_md5_hex
+                where = where & f_where if where is not None else f_where
+
+            order_by: list[tuple[exprs.Expr, bool]] = [(s_key, True)]
+            limit = exprs.Literal(sample_clause.n)
+            # Note that group_by is not possible here
+            return SamplingClauses(where, None, order_by, limit, None)
+
     @classmethod
     def create_query_plan(
         cls,
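To summarize the rewrite performed by `create_sample_clauses` for a non-stratified, fixed-size sample (a sketch in terms of the returned `SamplingClauses`; the concrete `sample_key` expression lives in `pixeltable/utils/sample.py` and is not shown in this diff):

    # SampleClause(version=None, n=100, n_per_stratum=None, fraction=None, seed=42, stratify_exprs=None)
    # is rewritten into:
    #   where:         unchanged (a key < threshold filter is added only when fraction is set)
    #   group_by:      None
    #   order_by:      [(sample_key(Literal(42), <rowid columns>), True)]  # deterministic hash ordering
    #   limit:         Literal(100)
    #   sample_clause: None  (the sample is expressed entirely through the rewritten clauses)
    # A stratified sample instead keeps its SampleClause and is executed by exec.SqlSampleNode.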
@@ -701,6 +883,7 @@
         group_by_clause: Optional[list[exprs.Expr]] = None,
         order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None,
         limit: Optional[exprs.Expr] = None,
+        sample_clause: Optional[SampleClause] = None,
         ignore_errors: bool = False,
         exact_version_only: Optional[list[catalog.TableVersionHandle]] = None,
     ) -> exec.ExecNode:
@@ -714,14 +897,22 @@
             order_by_clause = []
         if exact_version_only is None:
             exact_version_only = []
+
+        # Modify clauses to include sample clause
+        where, group_by_clause, order_by_clause, limit, sample = cls.create_sample_clauses(
+            from_clause, sample_clause, where_clause, group_by_clause, order_by_clause, limit
+        )
+
         analyzer = Analyzer(
             from_clause,
             select_list,
-            where_clause=where_clause,
+            where_clause=where,
             group_by_clause=group_by_clause,
             order_by_clause=order_by_clause,
         )
         row_builder = exprs.RowBuilder(analyzer.all_exprs, [], [])
+        if sample_clause is not None and analyzer.filter is not None:
+            raise excs.Error(f'Filter {analyzer.filter} not expressible in SQL')

         analyzer.finalize(row_builder)
         # select_list: we need to materialize everything that's been collected
@@ -732,6 +923,7 @@
             analyzer=analyzer,
             eval_ctx=eval_ctx,
             limit=limit,
+            sample_clause=sample,
             with_pk=True,
             exact_version_only=exact_version_only,
         )
@@ -747,6 +939,7 @@
         analyzer: Analyzer,
         eval_ctx: exprs.RowBuilder.EvalCtx,
         limit: Optional[exprs.Expr] = None,
+        sample_clause: Optional[SampleClause] = None,
         with_pk: bool = False,
         exact_version_only: Optional[list[catalog.TableVersionHandle]] = None,
     ) -> exec.ExecNode:
@@ -857,12 +1050,26 @@
             sql_elements.contains_all(analyzer.select_list)
             and sql_elements.contains_all(analyzer.grouping_exprs)
             and isinstance(plan, exec.SqlNode)
-            and plan.to_cte() is not None
+            and plan.to_cte(keep_pk=(sample_clause is not None)) is not None
         ):
-
-
-
+            if sample_clause is not None:
+                plan = exec.SqlSampleNode(
+                    row_builder,
+                    input=plan,
+                    select_list=analyzer.select_list,
+                    stratify_exprs=analyzer.group_by_clause,
+                    sample_clause=sample_clause,
+                )
+            else:
+                plan = exec.SqlAggregationNode(
+                    row_builder,
+                    input=plan,
+                    select_list=analyzer.select_list,
+                    group_by_items=analyzer.group_by_clause,
+                )
         else:
+            if sample_clause is not None:
+                raise excs.Error('Sample clause not supported with Python aggregation')
             input_sql_node = plan.get_node(exec.SqlNode)
             assert combined_ordering is not None
             input_sql_node.set_order_by(combined_ordering)