pixeltable 0.2.25__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (97)
  1. pixeltable/__init__.py +2 -2
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +1 -1
  4. pixeltable/catalog/dir.py +6 -0
  5. pixeltable/catalog/globals.py +25 -0
  6. pixeltable/catalog/named_function.py +4 -0
  7. pixeltable/catalog/path_dict.py +37 -11
  8. pixeltable/catalog/schema_object.py +6 -0
  9. pixeltable/catalog/table.py +421 -231
  10. pixeltable/catalog/table_version.py +22 -8
  11. pixeltable/catalog/view.py +5 -7
  12. pixeltable/dataframe.py +439 -105
  13. pixeltable/env.py +19 -5
  14. pixeltable/exec/__init__.py +1 -1
  15. pixeltable/exec/exec_node.py +6 -7
  16. pixeltable/exec/expr_eval_node.py +1 -1
  17. pixeltable/exec/sql_node.py +92 -45
  18. pixeltable/exprs/__init__.py +1 -0
  19. pixeltable/exprs/arithmetic_expr.py +1 -1
  20. pixeltable/exprs/array_slice.py +1 -1
  21. pixeltable/exprs/column_property_ref.py +1 -1
  22. pixeltable/exprs/column_ref.py +29 -2
  23. pixeltable/exprs/comparison.py +1 -1
  24. pixeltable/exprs/compound_predicate.py +1 -1
  25. pixeltable/exprs/expr.py +12 -5
  26. pixeltable/exprs/expr_set.py +8 -0
  27. pixeltable/exprs/function_call.py +147 -39
  28. pixeltable/exprs/in_predicate.py +1 -1
  29. pixeltable/exprs/inline_expr.py +25 -5
  30. pixeltable/exprs/is_null.py +1 -1
  31. pixeltable/exprs/json_mapper.py +1 -1
  32. pixeltable/exprs/json_path.py +1 -1
  33. pixeltable/exprs/method_ref.py +1 -1
  34. pixeltable/exprs/row_builder.py +1 -1
  35. pixeltable/exprs/rowid_ref.py +1 -1
  36. pixeltable/exprs/similarity_expr.py +14 -7
  37. pixeltable/exprs/sql_element_cache.py +4 -0
  38. pixeltable/exprs/type_cast.py +2 -2
  39. pixeltable/exprs/variable.py +3 -0
  40. pixeltable/func/__init__.py +5 -4
  41. pixeltable/func/aggregate_function.py +151 -68
  42. pixeltable/func/callable_function.py +48 -16
  43. pixeltable/func/expr_template_function.py +64 -23
  44. pixeltable/func/function.py +195 -27
  45. pixeltable/func/function_registry.py +2 -1
  46. pixeltable/func/query_template_function.py +51 -9
  47. pixeltable/func/signature.py +64 -7
  48. pixeltable/func/tools.py +153 -0
  49. pixeltable/func/udf.py +57 -35
  50. pixeltable/functions/__init__.py +2 -2
  51. pixeltable/functions/anthropic.py +51 -4
  52. pixeltable/functions/gemini.py +85 -0
  53. pixeltable/functions/globals.py +54 -34
  54. pixeltable/functions/huggingface.py +10 -28
  55. pixeltable/functions/json.py +3 -8
  56. pixeltable/functions/math.py +67 -0
  57. pixeltable/functions/ollama.py +8 -8
  58. pixeltable/functions/openai.py +51 -4
  59. pixeltable/functions/timestamp.py +1 -1
  60. pixeltable/functions/video.py +3 -9
  61. pixeltable/functions/vision.py +1 -1
  62. pixeltable/globals.py +354 -80
  63. pixeltable/index/embedding_index.py +106 -34
  64. pixeltable/io/__init__.py +1 -1
  65. pixeltable/io/label_studio.py +1 -1
  66. pixeltable/io/parquet.py +39 -19
  67. pixeltable/iterators/document.py +12 -0
  68. pixeltable/metadata/__init__.py +1 -1
  69. pixeltable/metadata/converters/convert_16.py +2 -1
  70. pixeltable/metadata/converters/convert_17.py +2 -1
  71. pixeltable/metadata/converters/convert_22.py +17 -0
  72. pixeltable/metadata/converters/convert_23.py +35 -0
  73. pixeltable/metadata/converters/convert_24.py +56 -0
  74. pixeltable/metadata/converters/convert_25.py +19 -0
  75. pixeltable/metadata/converters/util.py +4 -2
  76. pixeltable/metadata/notes.py +4 -0
  77. pixeltable/metadata/schema.py +1 -0
  78. pixeltable/plan.py +128 -50
  79. pixeltable/store.py +1 -1
  80. pixeltable/type_system.py +196 -54
  81. pixeltable/utils/arrow.py +8 -3
  82. pixeltable/utils/description_helper.py +89 -0
  83. pixeltable/utils/documents.py +14 -0
  84. {pixeltable-0.2.25.dist-info → pixeltable-0.3.0.dist-info}/METADATA +30 -20
  85. pixeltable-0.3.0.dist-info/RECORD +155 -0
  86. {pixeltable-0.2.25.dist-info → pixeltable-0.3.0.dist-info}/WHEEL +1 -1
  87. pixeltable-0.3.0.dist-info/entry_points.txt +3 -0
  88. pixeltable/tool/create_test_db_dump.py +0 -311
  89. pixeltable/tool/create_test_video.py +0 -81
  90. pixeltable/tool/doc_plugins/griffe.py +0 -50
  91. pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
  92. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
  93. pixeltable/tool/embed_udf.py +0 -9
  94. pixeltable/tool/mypy_plugin.py +0 -55
  95. pixeltable-0.2.25.dist-info/RECORD +0 -154
  96. pixeltable-0.2.25.dist-info/entry_points.txt +0 -3
  97. {pixeltable-0.2.25.dist-info → pixeltable-0.3.0.dist-info}/LICENSE +0 -0
pixeltable/env.py CHANGED
@@ -8,6 +8,7 @@ import importlib.util
8
8
  import inspect
9
9
  import logging
10
10
  import os
11
+ import platform
11
12
  import shutil
12
13
  import subprocess
13
14
  import sys
@@ -275,6 +276,7 @@ class Env:
275
276
  if self._config.get_bool_value('hide_warnings'):
276
277
  # Disable more warnings
277
278
  warnings.simplefilter('ignore', category=UserWarning)
279
+ warnings.simplefilter('ignore', category=FutureWarning)
278
280
 
279
281
  # configure _logger to log to a file
280
282
  self._logfilename = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + '.log'
@@ -311,8 +313,12 @@ class Env:
311
313
  self._db_name = os.environ.get('PIXELTABLE_DB', 'pixeltable')
312
314
  self._pgdata_dir = Path(os.environ.get('PIXELTABLE_PGDATA', str(self._home / 'pgdata')))
313
315
 
314
- # in pixeltable_pgserver.get_server(): cleanup_mode=None will leave db on for debugging purposes
315
- self._db_server = pixeltable_pgserver.get_server(self._pgdata_dir, cleanup_mode=None)
316
+ # cleanup_mode=None will leave the postgres process running after Python exits
317
+ # cleanup_mode='stop' will terminate the postgres process when Python exits
318
+ # On Windows, we need cleanup_mode='stop' because child processes are killed automatically when the parent
319
+ # process (such as Terminal or VSCode) exits, potentially leaving it in an unusable state.
320
+ cleanup_mode = 'stop' if platform.system() == 'Windows' else None
321
+ self._db_server = pixeltable_pgserver.get_server(self._pgdata_dir, cleanup_mode=cleanup_mode)
316
322
  self._db_url = self._db_server.get_uri(database=self._db_name, driver='psycopg')
317
323
 
318
324
  tz_name = self.config.get_string_value('time_zone')
@@ -357,7 +363,7 @@ class Env:
357
363
  self.db_url,
358
364
  echo=echo,
359
365
  future=True,
360
- isolation_level='AUTOCOMMIT',
366
+ isolation_level='REPEATABLE READ',
361
367
  connect_args=connect_args,
362
368
  )
363
369
  self._logger.info(f'Created SQLAlchemy engine at: {self.db_url}')
@@ -496,6 +502,7 @@ class Env:
496
502
  self.__register_package('datasets')
497
503
  self.__register_package('fiftyone')
498
504
  self.__register_package('fireworks', library_name='fireworks-ai')
505
+ self.__register_package('google.generativeai', library_name='google-generativeai')
499
506
  self.__register_package('huggingface_hub', library_name='huggingface-hub')
500
507
  self.__register_package('label_studio_sdk', library_name='label-studio-sdk')
501
508
  self.__register_package('llama_cpp', library_name='llama-cpp-python')
@@ -505,6 +512,7 @@ class Env:
505
512
  self.__register_package('openai')
506
513
  self.__register_package('openpyxl')
507
514
  self.__register_package('pyarrow')
515
+ self.__register_package('pydantic')
508
516
  self.__register_package('replicate')
509
517
  self.__register_package('sentencepiece')
510
518
  self.__register_package('sentence_transformers', library_name='sentence-transformers')
@@ -520,8 +528,14 @@ class Env:
520
528
  self.__register_package('yolox', library_name='git+https://github.com/Megvii-BaseDetection/YOLOX@ac58e0a')
521
529
 
522
530
  def __register_package(self, package_name: str, library_name: Optional[str] = None) -> None:
531
+ is_installed: bool
532
+ try:
533
+ is_installed = importlib.util.find_spec(package_name) is not None
534
+ except ModuleNotFoundError:
535
+ # This can happen if the parent of `package_name` is not installed.
536
+ is_installed = False
523
537
  self.__optional_packages[package_name] = PackageInfo(
524
- is_installed=importlib.util.find_spec(package_name) is not None,
538
+ is_installed=is_installed,
525
539
  library_name=library_name or package_name # defaults to package_name unless specified otherwise
526
540
  )
527
541
 
@@ -577,7 +591,7 @@ class Env:
577
591
  self._logger.info(f'Ensuring spaCy model is installed: {filename}')
578
592
  ret = subprocess.run([sys.executable, '-m', 'pip', 'install', '-qU', url], check=False)
579
593
  if ret.returncode != 0:
580
- self._logger.warn(f'pip install failed for spaCy model: {filename}')
594
+ self._logger.warning(f'pip install failed for spaCy model: {filename}')
581
595
  try:
582
596
  self._logger.info(f'Loading spaCy model: {spacy_model}')
583
597
  self._spacy_nlp = spacy.load(spacy_model)
@@ -7,4 +7,4 @@ from .exec_node import ExecNode
7
7
  from .expr_eval_node import ExprEvalNode
8
8
  from .in_memory_data_node import InMemoryDataNode
9
9
  from .row_update_node import RowUpdateNode
10
- from .sql_node import SqlLookupNode, SqlScanNode, SqlAggregationNode, SqlNode
10
+ from .sql_node import SqlLookupNode, SqlScanNode, SqlAggregationNode, SqlNode, SqlJoinNode
@@ -1,15 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import abc
4
- from typing import TYPE_CHECKING, Iterable, Iterator, Optional
4
+ from typing import TYPE_CHECKING, Iterable, Iterator, Optional, TypeVar
5
5
 
6
6
  import pixeltable.exprs as exprs
7
7
 
8
8
  from .data_row_batch import DataRowBatch
9
9
  from .exec_context import ExecContext
10
10
 
11
- if TYPE_CHECKING:
12
- from pixeltable import exec
13
11
 
14
12
  class ExecNode(abc.ABC):
15
13
  """Base class of all execution nodes"""
@@ -77,12 +75,13 @@ class ExecNode(abc.ABC):
77
75
  def _close(self) -> None:
78
76
  pass
79
77
 
80
- def get_sql_node(self) -> Optional['exec.SqlNode']:
81
- from .sql_node import SqlNode
82
- if isinstance(self, SqlNode):
78
+ T = TypeVar('T', bound='ExecNode')
79
+
80
+ def get_node(self, node_class: type[T]) -> Optional[T]:
81
+ if isinstance(self, node_class):
83
82
  return self
84
83
  if self.input is not None:
85
- return self.input.get_sql_node()
84
+ return self.input.get_node(node_class)
86
85
  return None
87
86
 
88
87
  def set_limit(self, limit: int) -> None:
@@ -208,7 +208,7 @@ class ExprEvalNode(ExecNode):
208
208
  }
209
209
  start_ts = time.perf_counter()
210
210
  assert isinstance(fn_call.fn, CallableFunction)
211
- result_batch = fn_call.fn.exec_batch(*call_args, **call_kwargs)
211
+ result_batch = fn_call.fn.exec_batch(call_args, call_kwargs)
212
212
  self.ctx.profile.eval_time[fn_call.slot_idx] += time.perf_counter() - start_ts
213
213
  self.ctx.profile.eval_count[fn_call.slot_idx] += num_ext_batch_rows
214
214
 
@@ -1,17 +1,19 @@
1
1
  import logging
2
2
  import warnings
3
3
  from decimal import Decimal
4
- from typing import Iterable, Iterator, NamedTuple, Optional
4
+ from typing import Iterable, Iterator, NamedTuple, Optional, TYPE_CHECKING, Sequence
5
5
  from uuid import UUID
6
6
 
7
7
  import sqlalchemy as sql
8
8
 
9
9
  import pixeltable.catalog as catalog
10
10
  import pixeltable.exprs as exprs
11
-
12
11
  from .data_row_batch import DataRowBatch
13
12
  from .exec_node import ExecNode
14
13
 
14
+ if TYPE_CHECKING:
15
+ import pixeltable.plan
16
+
15
17
  _logger = logging.getLogger('pixeltable')
16
18
 
17
19
 
@@ -67,12 +69,17 @@ class SqlNode(ExecNode):
67
69
  select_list: exprs.ExprSet
68
70
  set_pk: bool
69
71
  num_pk_cols: int
70
- filter: Optional[exprs.Expr]
71
- filter_eval_ctx: Optional[exprs.RowBuilder.EvalCtx]
72
+ py_filter: Optional[exprs.Expr] # a predicate that can only be run in Python
73
+ py_filter_eval_ctx: Optional[exprs.RowBuilder.EvalCtx]
72
74
  cte: Optional[sql.CTE]
73
75
  sql_elements: exprs.SqlElementCache
74
- limit: Optional[int]
76
+
77
+ # where_clause/-_element: allow subclass to set one or the other (but not both)
78
+ where_clause: Optional[exprs.Expr]
79
+ where_clause_element: Optional[sql.ColumnElement]
80
+
75
81
  order_by_clause: OrderByClause
82
+ limit: Optional[int]
76
83
 
77
84
  def __init__(
78
85
  self, tbl: Optional[catalog.TableVersionPath], row_builder: exprs.RowBuilder,
@@ -89,6 +96,7 @@ class SqlNode(ExecNode):
89
96
  # create Select stmt
90
97
  self.sql_elements = sql_elements
91
98
  self.tbl = tbl
99
+ assert all(not isinstance(e, exprs.Literal) for e in select_list) # we're never asked to materialize literals
92
100
  self.select_list = exprs.ExprSet(select_list)
93
101
  # unstored iter columns: we also need to retrieve whatever is needed to materialize the iter args
94
102
  for iter_arg in row_builder.unstored_iter_args.values():
@@ -112,10 +120,12 @@ class SqlNode(ExecNode):
112
120
  # additional state
113
121
  self.result_cursor = None
114
122
  # the filter is provided by the subclass
115
- self.filter = None
116
- self.filter_eval_ctx = None
123
+ self.py_filter = None
124
+ self.py_filter_eval_ctx = None
117
125
  self.cte = None
118
126
  self.limit = None
127
+ self.where_clause = None
128
+ self.where_clause_element = None
119
129
  self.order_by_clause = []
120
130
 
121
131
  def _create_stmt(self) -> sql.Select:
@@ -124,9 +134,16 @@ class SqlNode(ExecNode):
124
134
  assert self.sql_elements.contains_all(self.select_list)
125
135
  sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
126
136
  if self.set_pk:
137
+ assert self.tbl is not None
127
138
  sql_select_list += self.tbl.tbl_version.store_tbl.pk_columns()
128
139
  stmt = sql.select(*sql_select_list)
129
140
 
141
+ where_clause_element = (
142
+ self.sql_elements.get(self.where_clause) if self.where_clause is not None else self.where_clause_element
143
+ )
144
+ if where_clause_element is not None:
145
+ stmt = stmt.where(where_clause_element)
146
+
130
147
  order_by_clause: list[sql.ColumnElement] = []
131
148
  for e, asc in self.order_by_clause:
132
149
  if isinstance(e, exprs.SimilarityExpr):
@@ -135,7 +152,7 @@ class SqlNode(ExecNode):
135
152
  order_by_clause.append(self.sql_elements.get(e).desc() if asc is False else self.sql_elements.get(e))
136
153
  stmt = stmt.order_by(*order_by_clause)
137
154
 
138
- if self.filter is None and self.limit is not None:
155
+ if self.py_filter is None and self.limit is not None:
139
156
  # if we don't have a Python filter, we can apply the limit to stmt
140
157
  stmt = stmt.limit(self.limit)
141
158
 
@@ -151,7 +168,7 @@ class SqlNode(ExecNode):
151
168
  Returns:
152
169
  (CTE, dict from Expr to output column)
153
170
  """
154
- if self.filter is not None:
171
+ if self.py_filter is not None:
155
172
  # the filter needs to run in Python
156
173
  return None
157
174
  self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
@@ -215,8 +232,17 @@ class SqlNode(ExecNode):
215
232
  prev_tbl = tbl
216
233
  return stmt
217
234
 
218
- def add_order_by(self, ordering: OrderByClause) -> None:
219
- """Add Order By clause to stmt"""
235
+ def set_where(self, where_clause: exprs.Expr) -> None:
236
+ assert self.where_clause_element is None
237
+ self.where_clause = where_clause
238
+
239
+ def set_py_filter(self, py_filter: exprs.Expr) -> None:
240
+ assert self.py_filter is None
241
+ self.py_filter = py_filter
242
+ self.py_filter_eval_ctx = self.row_builder.create_eval_ctx([py_filter], exclude=self.select_list)
243
+
244
+ def set_order_by(self, ordering: OrderByClause) -> None:
245
+ """Add Order By clause"""
220
246
  if self.tbl is not None:
221
247
  # change rowid refs against a base table to rowid refs against the target table, so that we minimize
222
248
  # the number of tables that need to be joined to the target table
@@ -236,7 +262,7 @@ class SqlNode(ExecNode):
236
262
  explain_str = '\n'.join([str(row) for row in explain_result])
237
263
  _logger.debug(f'SqlScanNode explain:\n{explain_str}')
238
264
  except Exception as e:
239
- _logger.warning(f'EXPLAIN failed')
265
+ _logger.warning(f'EXPLAIN failed with error: {e}')
240
266
 
241
267
  def __iter__(self) -> Iterator[DataRowBatch]:
242
268
  # run the query; do this here rather than in _open(), exceptions are only expected during iteration
@@ -280,10 +306,10 @@ class SqlNode(ExecNode):
280
306
  else:
281
307
  output_row[slot_idx] = sql_row[i]
282
308
 
283
- if self.filter is not None:
309
+ if self.py_filter is not None:
284
310
  # evaluate filter
285
- self.row_builder.eval(output_row, self.filter_eval_ctx, profile=self.ctx.profile)
286
- if self.filter is not None and not output_row[self.filter.slot_idx]:
311
+ self.row_builder.eval(output_row, self.py_filter_eval_ctx, profile=self.ctx.profile)
312
+ if self.py_filter is not None and not output_row[self.py_filter.slot_idx]:
287
313
  # we re-use this row for the next sql row since it didn't pass the filter
288
314
  output_row = output_batch.pop_row()
289
315
  output_row.clear()
@@ -315,21 +341,16 @@ class SqlScanNode(SqlNode):
315
341
 
316
342
  Supports filtering and ordering.
317
343
  """
318
- where_clause: Optional[exprs.Expr]
319
344
  exact_version_only: list[catalog.TableVersion]
320
345
 
321
346
  def __init__(
322
- self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
323
- select_list: Iterable[exprs.Expr],
324
- where_clause: Optional[exprs.Expr] = None, filter: Optional[exprs.Expr] = None,
325
- set_pk: bool = False, exact_version_only: Optional[list[catalog.TableVersion]] = None
347
+ self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
348
+ select_list: Iterable[exprs.Expr],
349
+ set_pk: bool = False, exact_version_only: Optional[list[catalog.TableVersion]] = None
326
350
  ):
327
351
  """
328
352
  Args:
329
353
  select_list: output of the query
330
- sql_where_clause: SQL Where clause
331
- filter: additional Where-clause predicate that can't be evaluated via SQL
332
- limit: max number of rows to return: 0 = no limit
333
354
  set_pk: if True, sets the primary for each DataRow
334
355
  exact_version_only: tables for which we only want to see rows created at the current version
335
356
  """
@@ -338,12 +359,7 @@ class SqlScanNode(SqlNode):
338
359
  # create Select stmt
339
360
  if exact_version_only is None:
340
361
  exact_version_only = []
341
- target = tbl.tbl_version # the stored table we're scanning
342
- self.filter = filter
343
- self.filter_eval_ctx = \
344
- row_builder.create_eval_ctx([filter], exclude=select_list) if filter is not None else None
345
362
 
346
- self.where_clause = where_clause
347
363
  self.exact_version_only = exact_version_only
348
364
 
349
365
  def _create_stmt(self) -> sql.Select:
@@ -352,12 +368,6 @@ class SqlScanNode(SqlNode):
352
368
  refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | where_clause_tbl_ids | self._ordering_tbl_ids()
353
369
  stmt = self.create_from_clause(
354
370
  self.tbl, stmt, refd_tbl_ids, exact_version_only={t.id for t in self.exact_version_only})
355
-
356
- if self.where_clause is not None:
357
- sql_where_clause = self.sql_elements.get(self.where_clause)
358
- assert sql_where_clause is not None
359
- stmt = stmt.where(sql_where_clause)
360
-
361
371
  return stmt
362
372
 
363
373
 
@@ -366,11 +376,9 @@ class SqlLookupNode(SqlNode):
366
376
  Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
367
377
  """
368
378
 
369
- where_clause: sql.ColumnElement
370
-
371
379
  def __init__(
372
- self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
373
- select_list: Iterable[exprs.Expr], sa_key_cols: list[sql.Column], key_vals: list[tuple],
380
+ self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
381
+ select_list: Iterable[exprs.Expr], sa_key_cols: list[sql.Column], key_vals: list[tuple],
374
382
  ):
375
383
  """
376
384
  Args:
@@ -381,15 +389,15 @@ class SqlLookupNode(SqlNode):
381
389
  sql_elements = exprs.SqlElementCache()
382
390
  super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=True)
383
391
  # Where clause: (key-col-1, key-col-2, ...) IN ((val-1, val-2, ...), ...)
384
- self.where_clause = sql.tuple_(*sa_key_cols).in_(key_vals)
392
+ self.where_clause_element = sql.tuple_(*sa_key_cols).in_(key_vals)
385
393
 
386
394
  def _create_stmt(self) -> sql.Select:
387
395
  stmt = super()._create_stmt()
388
396
  refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | self._ordering_tbl_ids()
389
397
  stmt = self.create_from_clause(self.tbl, stmt, refd_tbl_ids)
390
- stmt = stmt.where(self.where_clause)
391
398
  return stmt
392
399
 
400
+
393
401
  class SqlAggregationNode(SqlNode):
394
402
  """
395
403
  Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
@@ -398,11 +406,11 @@ class SqlAggregationNode(SqlNode):
398
406
  group_by_items: Optional[list[exprs.Expr]]
399
407
 
400
408
  def __init__(
401
- self, row_builder: exprs.RowBuilder,
402
- input: SqlNode,
403
- select_list: Iterable[exprs.Expr],
404
- group_by_items: Optional[list[exprs.Expr]] = None,
405
- limit: Optional[int] = None, exact_version_only: Optional[list[catalog.TableVersion]] = None
409
+ self, row_builder: exprs.RowBuilder,
410
+ input: SqlNode,
411
+ select_list: Iterable[exprs.Expr],
412
+ group_by_items: Optional[list[exprs.Expr]] = None,
413
+ limit: Optional[int] = None, exact_version_only: Optional[list[catalog.TableVersion]] = None
406
414
  ):
407
415
  """
408
416
  Args:
@@ -422,3 +430,42 @@ class SqlAggregationNode(SqlNode):
422
430
  assert all(e is not None for e in sql_group_by_items)
423
431
  stmt = stmt.group_by(*sql_group_by_items)
424
432
  return stmt
433
+
434
+
435
+ class SqlJoinNode(SqlNode):
436
+ """
437
+ Materializes data from the store via a Select ... From ... that contains joins
438
+ """
439
+ input_ctes: list[sql.CTE]
440
+ join_clauses: list['pixeltable.plan.JoinClause']
441
+
442
+ def __init__(
443
+ self, row_builder: exprs.RowBuilder,
444
+ inputs: Sequence[SqlNode], join_clauses: list['pixeltable.plan.JoinClause'], select_list: Iterable[exprs.Expr]
445
+ ):
446
+ assert len(inputs) > 1
447
+ assert len(inputs) == len(join_clauses) + 1
448
+ self.input_ctes = []
449
+ self.join_clauses = join_clauses
450
+ sql_elements = exprs.SqlElementCache()
451
+ for input_node in inputs:
452
+ input_cte, input_col_map = input_node.to_cte()
453
+ self.input_ctes.append(input_cte)
454
+ sql_elements.extend(input_col_map)
455
+ super().__init__(None, row_builder, select_list, sql_elements)
456
+
457
+ def _create_stmt(self) -> sql.Select:
458
+ from pixeltable import plan
459
+ stmt = super()._create_stmt()
460
+ stmt = stmt.select_from(self.input_ctes[0])
461
+ for i in range(len(self.join_clauses)):
462
+ join_clause = self.join_clauses[i]
463
+ on_clause = (
464
+ self.sql_elements.get(join_clause.join_predicate) if join_clause.join_type != plan.JoinType.CROSS
465
+ else sql.sql.expression.literal(True)
466
+ )
467
+ is_outer = join_clause.join_type == plan.JoinType.LEFT or join_clause.join_type == plan.JoinType.FULL_OUTER
468
+ stmt = stmt.join(
469
+ self.input_ctes[i + 1], onclause=on_clause, isouter=is_outer,
470
+ full=join_clause == plan.JoinType.FULL_OUTER)
471
+ return stmt
@@ -23,3 +23,4 @@ from .similarity_expr import SimilarityExpr
23
23
  from .sql_element_cache import SqlElementCache
24
24
  from .type_cast import TypeCast
25
25
  from .variable import Variable
26
+ from .globals import ComparisonOperator, LogicalOperator, ArithmeticOperator
@@ -35,7 +35,7 @@ class ArithmeticExpr(Expr):
35
35
 
36
36
  self.id = self._create_id()
37
37
 
38
- def __str__(self) -> str:
38
+ def __repr__(self) -> str:
39
39
  # add parentheses around operands that are ArithmeticExprs to express precedence
40
40
  op1_str = f'({self._op1})' if isinstance(self._op1, ArithmeticExpr) else str(self._op1)
41
41
  op2_str = f'({self._op2})' if isinstance(self._op2, ArithmeticExpr) else str(self._op2)
@@ -23,7 +23,7 @@ class ArraySlice(Expr):
23
23
  self.index = index
24
24
  self.id = self._create_id()
25
25
 
26
- def __str__(self) -> str:
26
+ def __repr__(self) -> str:
27
27
  index_strs: list[str] = []
28
28
  for el in self.index:
29
29
  if isinstance(el, int):
@@ -46,7 +46,7 @@ class ColumnPropertyRef(Expr):
46
46
  assert isinstance(col_ref, ColumnRef)
47
47
  return col_ref
48
48
 
49
- def __str__(self) -> str:
49
+ def __repr__(self) -> str:
50
50
  return f'{self._col_ref}.{self.prop.name.lower()}'
51
51
 
52
52
  def is_error_prop(self) -> bool:
@@ -5,10 +5,12 @@ from uuid import UUID
5
5
 
6
6
  import sqlalchemy as sql
7
7
 
8
+ import pixeltable as pxt
8
9
  import pixeltable.catalog as catalog
9
10
  import pixeltable.exceptions as excs
10
11
  import pixeltable.iterators as iters
11
12
 
13
+ from ..utils.description_helper import DescriptionHelper
12
14
  from .data_row import DataRow
13
15
  from .expr import Expr
14
16
  from .row_builder import RowBuilder
@@ -126,6 +128,22 @@ class ColumnRef(Expr):
126
128
  def _equals(self, other: ColumnRef) -> bool:
127
129
  return self.col == other.col and self.perform_validation == other.perform_validation
128
130
 
131
+ def _df(self) -> 'pxt.dataframe.DataFrame':
132
+ tbl = catalog.Catalog.get().tbls[self.col.tbl.id]
133
+ return tbl.select(self)
134
+
135
+ def show(self, *args, **kwargs) -> 'pxt.dataframe.DataFrameResultSet':
136
+ return self._df().show(*args, **kwargs)
137
+
138
+ def head(self, *args, **kwargs) -> 'pxt.dataframe.DataFrameResultSet':
139
+ return self._df().head(*args, **kwargs)
140
+
141
+ def tail(self, *args, **kwargs) -> 'pxt.dataframe.DataFrameResultSet':
142
+ return self._df().tail(*args, **kwargs)
143
+
144
+ def count(self) -> int:
145
+ return self._df().count()
146
+
129
147
  def __str__(self) -> str:
130
148
  if self.col.name is None:
131
149
  return f'<unnamed column {self.col.id}>'
@@ -133,11 +151,20 @@ class ColumnRef(Expr):
133
151
  return self.col.name
134
152
 
135
153
  def __repr__(self) -> str:
136
- return f'ColumnRef({self.col!r})'
154
+ return self._descriptors().to_string()
137
155
 
138
156
  def _repr_html_(self) -> str:
157
+ return self._descriptors().to_html()
158
+
159
+ def _descriptors(self) -> DescriptionHelper:
139
160
  tbl = catalog.Catalog.get().tbls[self.col.tbl.id]
140
- return tbl._description_html(cols=[self.col])._repr_html_() # type: ignore[attr-defined]
161
+ helper = DescriptionHelper()
162
+ helper.append(f'Column\n{self.col.name!r}\n(of table {tbl._path!r})')
163
+ helper.append(tbl._col_descriptor([self.col.name]))
164
+ idxs = tbl._index_descriptor([self.col.name])
165
+ if len(idxs) > 0:
166
+ helper.append(idxs)
167
+ return helper
141
168
 
142
169
  def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
143
170
  return None if self.perform_validation else self.col.sa_col
@@ -49,7 +49,7 @@ class Comparison(Expr):
49
49
 
50
50
  self.id = self._create_id()
51
51
 
52
- def __str__(self) -> str:
52
+ def __repr__(self) -> str:
53
53
  return f'{self._op1} {self.operator} {self._op2}'
54
54
 
55
55
  def _equals(self, other: Comparison) -> bool:
@@ -30,7 +30,7 @@ class CompoundPredicate(Expr):
30
30
 
31
31
  self.id = self._create_id()
32
32
 
33
- def __str__(self) -> str:
33
+ def __repr__(self) -> str:
34
34
  if self.operator == LogicalOperator.NOT:
35
35
  return f'~({self.components[0]})'
36
36
  return f' {self.operator} '.join([f'({e})' for e in self.components])
pixeltable/exprs/expr.py CHANGED
@@ -190,6 +190,7 @@ class Expr(abc.ABC):
190
190
  return new.copy()
191
191
  for i in range(len(self.components)):
192
192
  self.components[i] = self.components[i].substitute(spec)
193
+ self.id = self._create_id()
193
194
  return self
194
195
 
195
196
  @classmethod
@@ -216,12 +217,12 @@ class Expr(abc.ABC):
216
217
  return result
217
218
  result = result.substitute({ref: ref.col.value_expr for ref in target_col_refs})
218
219
 
219
- def is_bound_by(self, tbl: catalog.TableVersionPath) -> bool:
220
- """Returns True if this expr can be evaluated in the context of tbl."""
220
+ def is_bound_by(self, tbls: list[catalog.TableVersionPath]) -> bool:
221
+ """Returns True if this expr can be evaluated in the context of tbls."""
221
222
  from .column_ref import ColumnRef
222
223
  col_refs = self.subexprs(ColumnRef)
223
224
  for col_ref in col_refs:
224
- if not tbl.has_column(col_ref.col):
225
+ if not any(tbl.has_column(col_ref.col) for tbl in tbls):
225
226
  return False
226
227
  return True
227
228
 
@@ -235,7 +236,7 @@ class Expr(abc.ABC):
235
236
  self.components[i] = self.components[i]._retarget(tbl_versions)
236
237
  return self
237
238
 
238
- def __str__(self) -> str:
239
+ def __repr__(self) -> str:
239
240
  return f'<Expression of type {type(self)}>'
240
241
 
241
242
  def display_str(self, inline: bool = True) -> str:
@@ -450,7 +451,13 @@ class Expr(abc.ABC):
450
451
 
451
452
  def astype(self, new_type: Union[ts.ColumnType, type, _AnnotatedAlias]) -> 'exprs.TypeCast':
452
453
  from pixeltable.exprs import TypeCast
453
- return TypeCast(self, ts.ColumnType.normalize_type(new_type))
454
+ # Interpret the type argument the same way we would if given in a schema
455
+ col_type = ts.ColumnType.normalize_type(new_type, nullable_default=True, allow_builtin_types=False)
456
+ if not self.col_type.nullable:
457
+ # This expression is non-nullable; we can prove that the output is non-nullable, regardless of
458
+ # whether new_type is given as nullable.
459
+ col_type = col_type.copy(nullable=False)
460
+ return TypeCast(self, col_type)
454
461
 
455
462
  def apply(self, fn: Callable, *, col_type: Union[ts.ColumnType, type, _AnnotatedAlias, None] = None) -> 'exprs.FunctionCall':
456
463
  if col_type is not None:
@@ -60,6 +60,14 @@ class ExprSet(Generic[T]):
60
60
  def __le__(self, other: ExprSet[T]) -> bool:
61
61
  return other.issuperset(self)
62
62
 
63
+ def union(self, *others: Iterable[T]) -> ExprSet[T]:
64
+ result = ExprSet(self.exprs.values())
65
+ result.update(*others)
66
+ return result
67
+
68
+ def __or__(self, other: ExprSet[T]) -> ExprSet[T]:
69
+ return self.union(other)
70
+
63
71
  def difference(self, *others: Iterable[T]) -> ExprSet[T]:
64
72
  id_diff = set(self.exprs.keys()).difference(e.id for other_set in others for e in other_set)
65
73
  return ExprSet(self.exprs[id] for id in id_diff)