pixeltable 0.4.0rc1__py3-none-any.whl → 0.4.0rc2__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

@@ -14,6 +14,7 @@ from .exec_node import ExecNode
14
14
 
15
15
  if TYPE_CHECKING:
16
16
  import pixeltable.plan
17
+ from pixeltable.plan import SampleClause
17
18
 
18
19
  _logger = logging.getLogger('pixeltable')
19
20
 
@@ -64,8 +65,12 @@ def print_order_by_clause(clause: OrderByClause) -> str:
64
65
 
65
66
  class SqlNode(ExecNode):
66
67
  """
67
- Materializes data from the store via a Select stmt.
68
+ Materializes data from the store via an SQL statement.
68
69
  This only provides the select list. The subclasses are responsible for the From clause and any additional clauses.
70
+ The pk columns are not included in the select list.
71
+ If set_pk is True, they are added to the end of the result set when creating the SQL statement
72
+ so they can always be referenced as cols[-num_pk_cols:] in the result set.
73
+ The pk_columns consist of the rowid columns of the target table followed by the version number.
69
74
  """
70
75
 
71
76
  tbl: Optional[catalog.TableVersionPath]
@@ -122,6 +127,7 @@ class SqlNode(ExecNode):
122
127
  # we also need to retrieve the pk columns
123
128
  assert tbl is not None
124
129
  self.num_pk_cols = len(tbl.tbl_version.get().store_tbl.pk_columns())
130
+ assert self.num_pk_cols > 1
125
131
 
126
132
  # additional state
127
133
  self.result_cursor = None
@@ -139,15 +145,20 @@ class SqlNode(ExecNode):
139
145
  if tv is not None:
140
146
  assert tv.is_validated
141
147
 
148
+ def _create_pk_cols(self) -> list[sql.Column]:
149
+ """Create a list of pk columns"""
150
+ # we need to retrieve the pk columns
151
+ if self.set_pk:
152
+ assert self.tbl is not None
153
+ assert self.tbl.tbl_version.get().is_validated
154
+ return self.tbl.tbl_version.get().store_tbl.pk_columns()
155
+ return []
156
+
142
157
  def _create_stmt(self) -> sql.Select:
143
158
  """Create Select from local state"""
144
159
 
145
160
  assert self.sql_elements.contains_all(self.select_list)
146
- sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
147
- if self.set_pk:
148
- assert self.tbl is not None
149
- assert self.tbl.tbl_version.get().is_validated
150
- sql_select_list += self.tbl.tbl_version.get().store_tbl.pk_columns()
161
+ sql_select_list = [self.sql_elements.get(e) for e in self.select_list] + self._create_pk_cols()
151
162
  stmt = sql.select(*sql_select_list)
152
163
 
153
164
  where_clause_element = (
@@ -173,9 +184,10 @@ class SqlNode(ExecNode):
173
184
  def _ordering_tbl_ids(self) -> set[UUID]:
174
185
  return exprs.Expr.all_tbl_ids(e for e, _ in self.order_by_clause)
175
186
 
176
- def to_cte(self) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
187
+ def to_cte(self, keep_pk: bool = False) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
177
188
  """
178
- Returns a CTE that materializes the output of this node plus a mapping from select list expr to output column
189
+ Creates a CTE that materializes the output of this node plus a mapping from select list expr to output column.
190
+ keep_pk: if True, the PK columns are included in the CTE Select statement
179
191
 
180
192
  Returns:
181
193
  (CTE, dict from Expr to output column)
@@ -183,11 +195,13 @@ class SqlNode(ExecNode):
183
195
  if self.py_filter is not None:
184
196
  # the filter needs to run in Python
185
197
  return None
186
- self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
187
198
  if self.cte is None:
199
+ if not keep_pk:
200
+ self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
188
201
  self.cte = self._create_stmt().cte()
189
- assert len(self.cte.c) == len(self.select_list)
190
- return self.cte, exprs.ExprDict(zip(self.select_list, self.cte.c))
202
+ pk_count = self.num_pk_cols if self.set_pk else 0
203
+ assert len(self.select_list) + pk_count == len(self.cte.c)
204
+ return self.cte, exprs.ExprDict(zip(self.select_list, self.cte.c)) # skip pk cols
191
205
 
192
206
  @classmethod
193
207
  def retarget_rowid_refs(cls, target: catalog.TableVersionPath, expr_seq: Iterable[exprs.Expr]) -> None:
@@ -293,7 +307,9 @@ class SqlNode(ExecNode):
293
307
  stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
294
308
  _logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
295
309
  except Exception:
296
- pass
310
+ # log something if we can't log the compiled stmt
311
+ stmt_str = repr(stmt)
312
+ _logger.debug(f'SqlLookupNode proto-stmt:\n{stmt_str}')
297
313
  self._log_explain(stmt)
298
314
 
299
315
  conn = Env.get().conn
@@ -510,3 +526,127 @@ class SqlJoinNode(SqlNode):
510
526
  full=join_clause == plan.JoinType.FULL_OUTER,
511
527
  )
512
528
  return stmt
529
+
530
+
531
+ class SqlSampleNode(SqlNode):
532
+ """
533
+ Returns rows from a stratified sample with N samples per strata.
534
+ """
535
+
536
+ stratify_exprs: Optional[list[exprs.Expr]]
537
+ n_samples: Optional[int]
538
+ fraction_samples: Optional[float]
539
+ seed: int
540
+ input_cte: Optional[sql.CTE]
541
+ pk_count: int
542
+
543
+ def __init__(
544
+ self,
545
+ row_builder: exprs.RowBuilder,
546
+ input: SqlNode,
547
+ select_list: Iterable[exprs.Expr],
548
+ stratify_exprs: Optional[list[exprs.Expr]] = None,
549
+ sample_clause: Optional['SampleClause'] = None,
550
+ ):
551
+ """
552
+ Args:
553
+ select_list: can contain calls to AggregateFunctions
554
+ stratify_exprs: list of expressions to group by
555
+ n: number of samples per strata
556
+ """
557
+ self.input_cte, input_col_map = input.to_cte(keep_pk=True)
558
+ self.pk_count = input.num_pk_cols
559
+ assert self.pk_count > 1
560
+ sql_elements = exprs.SqlElementCache(input_col_map)
561
+ super().__init__(input.tbl, row_builder, select_list, sql_elements, set_pk=True)
562
+ self.stratify_exprs = stratify_exprs
563
+ self.n_samples = sample_clause.n
564
+ self.n_per_stratum = sample_clause.n_per_stratum
565
+ self.fraction_samples = sample_clause.fraction
566
+ self.seed = sample_clause.seed if sample_clause.seed is not None else 0
567
+
568
+ @classmethod
569
+ def key_sql_expr(cls, seed: sql.ColumnElement, sql_cols: Iterable[sql.ColumnElement]) -> sql.ColumnElement:
570
+ """Construct expression which is the ordering key for rows to be sampled
571
+ General SQL form is:
572
+ - MD5(<seed::text> [ + '___' + <rowid_col_val>::text]+
573
+ """
574
+ sql_expr: sql.ColumnElement = sql.cast(seed, sql.Text)
575
+ for e in sql_cols:
576
+ sql_expr = sql_expr + sql.literal_column("'___'") + sql.cast(e, sql.Text)
577
+ sql_expr = sql.func.md5(sql_expr)
578
+ return sql_expr
579
+
580
+ def _create_order_by(self, cte: sql.CTE) -> sql.ColumnElement:
581
+ """Create an expression for randomly ordering rows with a given seed"""
582
+ rowid_cols = [*cte.c[-self.pk_count : -1]] # exclude the version column
583
+ assert len(rowid_cols) > 0
584
+ return self.key_sql_expr(sql.literal_column(str(self.seed)), rowid_cols)
585
+
586
+ def _create_stmt(self) -> sql.Select:
587
+ if self.fraction_samples is not None:
588
+ return self._create_stmt_fraction(self.fraction_samples)
589
+ return self._create_stmt_n(self.n_samples, self.n_per_stratum)
590
+
591
+ def _create_stmt_n(self, n: Optional[int], n_per_stratum: Optional[int]) -> sql.Select:
592
+ """Create a Select stmt that returns n samples across all strata"""
593
+ sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
594
+ order_by = self._create_order_by(self.input_cte)
595
+
596
+ # Create a list of all columns plus the rank
597
+ # Get all columns from the input CTE dynamically
598
+ select_columns = [*self.input_cte.c]
599
+ select_columns.append(
600
+ sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
601
+ )
602
+ row_rank_cte = sql.select(*select_columns).select_from(self.input_cte).cte('row_rank_cte')
603
+
604
+ final_columns = [*row_rank_cte.c[:-1]] # exclude the rank column
605
+ if n_per_stratum is not None:
606
+ return sql.select(*final_columns).filter(row_rank_cte.c.rank <= n_per_stratum)
607
+ else:
608
+ secondary_order = self._create_order_by(row_rank_cte)
609
+ return sql.select(*final_columns).order_by(row_rank_cte.c.rank, secondary_order).limit(n)
610
+
611
+ def _create_stmt_fraction(self, fraction_samples: float) -> sql.Select:
612
+ """Create a Select stmt that returns a fraction of the rows per strata"""
613
+
614
+ # Build the strata count CTE
615
+ # Produces a table of the form:
616
+ # ([stratify_exprs], s_s_size)
617
+ # where s_s_size is the number of samples to take from each stratum
618
+ sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
619
+ per_strata_count_cte = (
620
+ sql.select(
621
+ *sql_strata_exprs,
622
+ sql.func.ceil(fraction_samples * sql.func.count(1).cast(sql.Integer)).label('s_s_size'),
623
+ )
624
+ .select_from(self.input_cte)
625
+ .group_by(*sql_strata_exprs)
626
+ .cte('per_strata_count_cte')
627
+ )
628
+
629
+ # Build a CTE that ranks the rows within each stratum
630
+ # Include all columns from the input CTE dynamically
631
+ order_by = self._create_order_by(self.input_cte)
632
+ select_columns = [*self.input_cte.c]
633
+ select_columns.append(
634
+ sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
635
+ )
636
+ row_rank_cte = sql.select(*select_columns).select_from(self.input_cte).cte('row_rank_cte')
637
+
638
+ # Build the join criterion dynamically to accommodate any number of group by columns
639
+ join_c = sql.true()
640
+ for col in per_strata_count_cte.c[:-1]:
641
+ join_c &= row_rank_cte.c[col.name].isnot_distinct_from(col)
642
+
643
+ # Join srcp with per_strata_count_cte to limit returns to the requested fraction of rows
644
+ final_columns = [*row_rank_cte.c[:-1]] # exclude the rank column
645
+ stmt = (
646
+ sql.select(*final_columns)
647
+ .select_from(row_rank_cte)
648
+ .join(per_strata_count_cte, join_c)
649
+ .where(row_rank_cte.c.rank <= per_strata_count_cte.c.s_s_size)
650
+ )
651
+
652
+ return stmt
@@ -214,6 +214,7 @@ class DataRow:
214
214
  """Assign in-memory cell value
215
215
  This allows overwriting
216
216
  """
217
+ assert isinstance(idx, int)
217
218
  assert self.excs[idx] is None
218
219
 
219
220
  if (idx in self.img_slot_idxs or idx in self.media_slot_idxs) and isinstance(val, str):
@@ -253,14 +254,15 @@ class DataRow:
253
254
  assert self.excs[index] is None
254
255
  if self.file_paths[index] is None:
255
256
  if filepath is not None:
256
- # we want to save this to a file
257
- self.file_paths[index] = filepath
258
- self.file_urls[index] = urllib.parse.urljoin('file:', urllib.request.pathname2url(filepath))
259
257
  image = self.vals[index]
260
258
  assert isinstance(image, PIL.Image.Image)
261
259
  # Default to JPEG unless the image has a transparency layer (which isn't supported by JPEG).
262
260
  # In that case, use WebP instead.
263
261
  format = 'webp' if image.has_transparency_data else 'jpeg'
262
+ if not filepath.endswith(f'.{format}'):
263
+ filepath += f'.{format}'
264
+ self.file_paths[index] = filepath
265
+ self.file_urls[index] = urllib.parse.urljoin('file:', urllib.request.pathname2url(filepath))
264
266
  image.save(filepath, format=format)
265
267
  else:
266
268
  # we discard the content of this cell
pixeltable/exprs/expr.py CHANGED
@@ -276,6 +276,13 @@ class Expr(abc.ABC):
276
276
  tbl_versions = {tbl_version.id: tbl_version.get() for tbl_version in tbl.get_tbl_versions()}
277
277
  return self._retarget(tbl_versions)
278
278
 
279
+ @classmethod
280
+ def retarget_list(cls, expr_list: list[Expr], tbl: catalog.TableVersionPath) -> None:
281
+ """Retarget ColumnRefs in expr_list to the specific TableVersions in tbl."""
282
+ tbl_versions = {tbl_version.id: tbl_version.get() for tbl_version in tbl.get_tbl_versions()}
283
+ for i in range(len(expr_list)):
284
+ expr_list[i] = expr_list[i]._retarget(tbl_versions)
285
+
279
286
  def _retarget(self, tbl_versions: dict[UUID, catalog.TableVersion]) -> Self:
280
287
  for i in range(len(self.components)):
281
288
  self.components[i] = self.components[i]._retarget(tbl_versions)
@@ -16,6 +16,8 @@ from .sql_element_cache import SqlElementCache
16
16
 
17
17
 
18
18
  class Literal(Expr):
19
+ val: Any
20
+
19
21
  def __init__(self, val: Any, col_type: Optional[ts.ColumnType] = None):
20
22
  if col_type is not None:
21
23
  val = col_type.create_literal(val)
pixeltable/func/tools.py CHANGED
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
2
2
 
3
3
  import pydantic
4
4
 
5
- import pixeltable.exceptions as excs
5
+ from pixeltable import exceptions as excs
6
6
 
7
7
  from .function import Function
8
8
  from .signature import Parameter
@@ -103,7 +103,6 @@ def invoke_tools(tools: pxt.func.Tools, response: exprs.Expr) -> exprs.InlineDic
103
103
 
104
104
  @pxt.udf
105
105
  def _gemini_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
106
- print(response)
107
106
  pxt_tool_calls: dict[str, list[dict]] = {}
108
107
  for part in response['candidates'][0]['content']['parts']:
109
108
  tool_call = part.get('function_call')
pixeltable/globals.py CHANGED
@@ -249,13 +249,17 @@ def create_view(
249
249
  where: Optional[exprs.Expr] = None
250
250
  if isinstance(base, catalog.Table):
251
251
  tbl_version_path = base._tbl_version_path
252
+ sample_clause = None
252
253
  elif isinstance(base, DataFrame):
253
254
  base._validate_mutable('create_view', allow_select=True)
254
255
  if len(base._from_clause.tbls) > 1:
255
256
  raise excs.Error('Cannot create a view of a join')
256
257
  tbl_version_path = base._from_clause.tbls[0]
257
258
  where = base.where_clause
259
+ sample_clause = base.sample_clause
258
260
  select_list = base.select_list
261
+ if sample_clause is not None and not is_snapshot and not sample_clause.is_repeatable:
262
+ raise excs.Error('Non-snapshot views cannot be created with non-fractional or stratified sampling')
259
263
  else:
260
264
  raise excs.Error('`base` must be an instance of `Table` or `DataFrame`')
261
265
  assert isinstance(base, (catalog.Table, DataFrame))
@@ -280,6 +284,7 @@ def create_view(
280
284
  tbl_version_path,
281
285
  select_list=select_list,
282
286
  where=where,
287
+ sample_clause=sample_clause,
283
288
  additional_columns=additional_columns,
284
289
  is_snapshot=is_snapshot,
285
290
  iterator=iterator,
@@ -8,15 +8,17 @@ from typing import Callable
8
8
  import sqlalchemy as sql
9
9
  from sqlalchemy import orm
10
10
 
11
+ import pixeltable as pxt
12
+ import pixeltable.exceptions as excs
11
13
  from pixeltable.utils.console_output import ConsoleLogger
12
14
 
13
15
  from .schema import SystemInfo, SystemInfoMd
14
16
 
15
17
  _console_logger = ConsoleLogger(logging.getLogger('pixeltable'))
16
-
18
+ _logger = logging.getLogger('pixeltable')
17
19
 
18
20
  # current version of the metadata; this is incremented whenever the metadata schema changes
19
- VERSION = 36
21
+ VERSION = 37
20
22
 
21
23
 
22
24
  def create_system_info(engine: sql.engine.Engine) -> None:
@@ -55,6 +57,13 @@ def upgrade_md(engine: sql.engine.Engine) -> None:
55
57
  system_info = session.query(SystemInfo).one().md
56
58
  md_version = system_info['schema_version']
57
59
  assert isinstance(md_version, int)
60
+ _logger.info(f'Current database version: {md_version}, installed version: {VERSION}')
61
+ if md_version > VERSION:
62
+ raise excs.Error(
63
+ 'This Pixeltable database was created with a newer Pixeltable version '
64
+ f'than the one currently installed ({pxt.__version__}).\n'
65
+ 'Please update to the latest Pixeltable version by running: pip install --upgrade pixeltable'
66
+ )
58
67
  if md_version == VERSION:
59
68
  return
60
69
  while md_version < VERSION:
@@ -0,0 +1,38 @@
1
+ import logging
2
+ from typing import Any, Optional
3
+ from uuid import UUID
4
+
5
+ import sqlalchemy as sql
6
+
7
+ from pixeltable.metadata import register_converter
8
+ from pixeltable.metadata.converters.util import convert_table_md
9
+
10
+ _logger = logging.getLogger('pixeltable')
11
+
12
+
13
+ @register_converter(version=36)
14
+ def _(engine: sql.engine.Engine) -> None:
15
+ convert_table_md(engine, table_md_updater=__update_table_md, substitution_fn=__substitute_md)
16
+
17
+
18
+ def __update_table_md(table_md: dict, table_id: UUID) -> None:
19
+ """Update the view metadata to add the sample_clause field if it is missing
20
+
21
+ Args:
22
+ table_md (dict): copy of the original table metadata. this gets updated in place.
23
+ table_id (UUID): the table id
24
+
25
+ """
26
+ if table_md['view_md'] is None:
27
+ return
28
+ if 'sample_clause' not in table_md['view_md']:
29
+ table_md['view_md']['sample_clause'] = None
30
+ _logger.info(f'Updating view metadata for table: {table_id}')
31
+
32
+
33
+ def __substitute_md(k: Optional[str], v: Any) -> Optional[tuple[Optional[str], Any]]:
34
+ if isinstance(v, dict) and (v.get('_classname') == 'DataFrame'):
35
+ if 'sample_clause' not in v:
36
+ v['sample_clause'] = None
37
+ return k, v
38
+ return None
@@ -2,6 +2,7 @@
2
2
  # rather than as a comment, so that the existence of a description can be enforced by
3
3
  # the unit tests when new versions are added.
4
4
  VERSION_NOTES = {
5
+ 37: 'Add support for the sample() method on DataFrames',
5
6
  36: 'Added Table.lock_dummy',
6
7
  35: 'Track reference_tbl in ColumnRef',
7
8
  34: 'Set default value for is_pk field in column metadata to False',
@@ -147,6 +147,9 @@ class ViewMd:
147
147
  # filter predicate applied to the base table; view-only
148
148
  predicate: Optional[dict[str, Any]]
149
149
 
150
+ # sampling predicate applied to the base table; view-only
151
+ sample_clause: Optional[dict[str, Any]]
152
+
150
153
  # ComponentIterator subclass; only for component views
151
154
  iterator_class_fqn: Optional[str]
152
155