pixeltable 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (140) hide show
  1. pixeltable/__init__.py +21 -4
  2. pixeltable/catalog/__init__.py +13 -0
  3. pixeltable/catalog/catalog.py +159 -0
  4. pixeltable/catalog/column.py +200 -0
  5. pixeltable/catalog/dir.py +32 -0
  6. pixeltable/catalog/globals.py +33 -0
  7. pixeltable/catalog/insertable_table.py +191 -0
  8. pixeltable/catalog/named_function.py +36 -0
  9. pixeltable/catalog/path.py +58 -0
  10. pixeltable/catalog/path_dict.py +139 -0
  11. pixeltable/catalog/schema_object.py +39 -0
  12. pixeltable/catalog/table.py +581 -0
  13. pixeltable/catalog/table_version.py +749 -0
  14. pixeltable/catalog/table_version_path.py +133 -0
  15. pixeltable/catalog/view.py +203 -0
  16. pixeltable/client.py +520 -31
  17. pixeltable/dataframe.py +540 -349
  18. pixeltable/env.py +373 -48
  19. pixeltable/exceptions.py +12 -21
  20. pixeltable/exec/__init__.py +9 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +113 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +95 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +69 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +225 -0
  31. pixeltable/exprs/__init__.py +24 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +105 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +187 -0
  39. pixeltable/exprs/expr.py +586 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +380 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +115 -0
  44. pixeltable/exprs/image_similarity_predicate.py +58 -0
  45. pixeltable/exprs/inline_array.py +107 -0
  46. pixeltable/exprs/inline_dict.py +101 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +54 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +355 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/type_cast.py +53 -0
  56. pixeltable/exprs/variable.py +45 -0
  57. pixeltable/func/__init__.py +9 -0
  58. pixeltable/func/aggregate_function.py +194 -0
  59. pixeltable/func/batched_function.py +53 -0
  60. pixeltable/func/callable_function.py +69 -0
  61. pixeltable/func/expr_template_function.py +82 -0
  62. pixeltable/func/function.py +110 -0
  63. pixeltable/func/function_registry.py +227 -0
  64. pixeltable/func/globals.py +36 -0
  65. pixeltable/func/nos_function.py +202 -0
  66. pixeltable/func/signature.py +166 -0
  67. pixeltable/func/udf.py +163 -0
  68. pixeltable/functions/__init__.py +52 -103
  69. pixeltable/functions/eval.py +216 -0
  70. pixeltable/functions/fireworks.py +61 -0
  71. pixeltable/functions/huggingface.py +120 -0
  72. pixeltable/functions/image.py +16 -0
  73. pixeltable/functions/openai.py +88 -0
  74. pixeltable/functions/pil/image.py +148 -7
  75. pixeltable/functions/string.py +13 -0
  76. pixeltable/functions/together.py +27 -0
  77. pixeltable/functions/util.py +41 -0
  78. pixeltable/functions/video.py +62 -0
  79. pixeltable/iterators/__init__.py +3 -0
  80. pixeltable/iterators/base.py +48 -0
  81. pixeltable/iterators/document.py +311 -0
  82. pixeltable/iterators/video.py +89 -0
  83. pixeltable/metadata/__init__.py +54 -0
  84. pixeltable/metadata/converters/convert_10.py +18 -0
  85. pixeltable/metadata/schema.py +211 -0
  86. pixeltable/plan.py +656 -0
  87. pixeltable/store.py +413 -182
  88. pixeltable/tests/conftest.py +143 -86
  89. pixeltable/tests/test_audio.py +65 -0
  90. pixeltable/tests/test_catalog.py +27 -0
  91. pixeltable/tests/test_client.py +14 -14
  92. pixeltable/tests/test_component_view.py +372 -0
  93. pixeltable/tests/test_dataframe.py +433 -0
  94. pixeltable/tests/test_dirs.py +78 -62
  95. pixeltable/tests/test_document.py +117 -0
  96. pixeltable/tests/test_exprs.py +591 -135
  97. pixeltable/tests/test_function.py +297 -67
  98. pixeltable/tests/test_functions.py +283 -1
  99. pixeltable/tests/test_migration.py +43 -0
  100. pixeltable/tests/test_nos.py +54 -0
  101. pixeltable/tests/test_snapshot.py +208 -0
  102. pixeltable/tests/test_table.py +1086 -258
  103. pixeltable/tests/test_transactional_directory.py +42 -0
  104. pixeltable/tests/test_types.py +5 -11
  105. pixeltable/tests/test_video.py +149 -34
  106. pixeltable/tests/test_view.py +530 -0
  107. pixeltable/tests/utils.py +186 -45
  108. pixeltable/tool/create_test_db_dump.py +149 -0
  109. pixeltable/type_system.py +490 -133
  110. pixeltable/utils/__init__.py +17 -46
  111. pixeltable/utils/clip.py +12 -15
  112. pixeltable/utils/coco.py +136 -0
  113. pixeltable/utils/documents.py +39 -0
  114. pixeltable/utils/filecache.py +195 -0
  115. pixeltable/utils/help.py +11 -0
  116. pixeltable/utils/media_store.py +76 -0
  117. pixeltable/utils/parquet.py +126 -0
  118. pixeltable/utils/pytorch.py +172 -0
  119. pixeltable/utils/s3.py +13 -0
  120. pixeltable/utils/sql.py +17 -0
  121. pixeltable/utils/transactional_directory.py +35 -0
  122. pixeltable-0.2.0.dist-info/LICENSE +18 -0
  123. pixeltable-0.2.0.dist-info/METADATA +117 -0
  124. pixeltable-0.2.0.dist-info/RECORD +125 -0
  125. {pixeltable-0.1.2.dist-info → pixeltable-0.2.0.dist-info}/WHEEL +1 -1
  126. pixeltable/catalog.py +0 -1421
  127. pixeltable/exprs.py +0 -1745
  128. pixeltable/function.py +0 -269
  129. pixeltable/functions/clip.py +0 -10
  130. pixeltable/functions/pil/__init__.py +0 -23
  131. pixeltable/functions/tf.py +0 -21
  132. pixeltable/index.py +0 -57
  133. pixeltable/tests/test_dict.py +0 -24
  134. pixeltable/tests/test_tf.py +0 -69
  135. pixeltable/tf.py +0 -33
  136. pixeltable/utils/tf.py +0 -33
  137. pixeltable/utils/video.py +0 -32
  138. pixeltable-0.1.2.dist-info/LICENSE +0 -201
  139. pixeltable-0.1.2.dist-info/METADATA +0 -89
  140. pixeltable-0.1.2.dist-info/RECORD +0 -37
pixeltable/dataframe.py CHANGED
@@ -1,24 +1,35 @@
1
+ from __future__ import annotations
2
+
1
3
  import base64
4
+ import copy
5
+ import hashlib
2
6
  import io
3
- import os
4
- from typing import List, Optional, Any, Dict, Generator, Tuple, Set
7
+ import json
8
+ import logging
9
+ import mimetypes
10
+ import traceback
5
11
  from pathlib import Path
6
- from dataclasses import dataclass, field
12
+ from typing import List, Optional, Any, Dict, Generator, Tuple, Set
13
+
7
14
  import pandas as pd
15
+ import pandas.io.formats.style
8
16
  import sqlalchemy as sql
9
17
  from PIL import Image
10
- import copy
11
18
 
12
- from pixeltable import catalog
19
+ import pixeltable.catalog as catalog
20
+ import pixeltable.exceptions as excs
21
+ import pixeltable.exprs as exprs
22
+ import pixeltable.type_system as ts
23
+ from pixeltable.catalog import is_valid_identifier
13
24
  from pixeltable.env import Env
25
+ from pixeltable.plan import Planner
14
26
  from pixeltable.type_system import ColumnType
15
- from pixeltable import exprs
16
- from pixeltable import exceptions as exc
17
27
 
18
28
  __all__ = [
19
29
  'DataFrame'
20
30
  ]
21
31
 
32
+ _logger = logging.getLogger('pixeltable')
22
33
 
23
34
  def _format_img(img: object) -> str:
24
35
  """
@@ -28,360 +39,479 @@ def _format_img(img: object) -> str:
28
39
  with io.BytesIO() as buffer:
29
40
  img.save(buffer, 'jpeg')
30
41
  img_base64 = base64.b64encode(buffer.getvalue()).decode()
31
- return f'<img src="data:image/jpeg;base64,{img_base64}">'
32
-
33
- def _format_video(video_file_path: str) -> str:
34
- # turn absolute video_file_path into relative path, absolute paths don't work
35
- p = Path(video_file_path)
36
- root = Path(os.getcwd())
37
- #print(root)
38
- #return f'<video controls><source src="{video_file_path}" type="video/mp4"></video>'
39
- try:
40
- rel_path = p.relative_to(root)
41
- return f'<video controls><source src="{rel_path}" type="video/mp4"></video>'
42
- except ValueError:
43
- # display path as string
44
- return video_file_path
42
+ return f'<div style=\'width:200px;\'><img src="data:image/jpeg;base64,{img_base64}"></div>'
43
+
44
+ def _create_source_tag(file_path: str) -> str:
45
+ abs_path = Path(file_path)
46
+ assert abs_path.is_absolute()
47
+ src_url = f'{Env.get().http_address}/{abs_path}'
48
+ mime = mimetypes.guess_type(src_url)[0]
49
+ # if mime is None, the attribute string would not be valid html.
50
+ mime_attr = f'type="{mime}"' if mime is not None else ''
51
+ return f'<source src="{src_url}" {mime_attr} />'
52
+
53
+ def _format_video(file_path: str) -> str:
54
+ return f'<video controls>{_create_source_tag(file_path)}</video>'
55
+
56
+ def _format_audio(file_path: str) -> str:
57
+ return f'<audio controls>{_create_source_tag(file_path)}</audio>'
45
58
 
46
59
  class DataFrameResultSet:
47
- def __init__(self, rows: List[List], col_names: List[str], col_types: List[ColumnType]):
48
- self.rows = rows
49
- self.col_names = col_names
50
- self.col_types = col_types
60
+ def __init__(self, rows: List[List[Any]], col_names: List[str], col_types: List[ColumnType]):
61
+ self._rows = rows
62
+ self._col_names = col_names
63
+ self._col_types = col_types
64
+ self._formatters = {
65
+ ts.ImageType: _format_img,
66
+ ts.VideoType: _format_video,
67
+ ts.AudioType: _format_audio,
68
+ }
51
69
 
52
70
  def __len__(self) -> int:
53
- return len(self.rows)
71
+ return len(self._rows)
72
+
73
+ def column_names(self) -> List[str]:
74
+ return self._col_names
75
+
76
+ def column_types(self) -> List[ColumnType]:
77
+ return self._col_types
78
+
79
+ def __repr__(self) -> str:
80
+ return self.to_pandas().__repr__()
54
81
 
55
82
  def _repr_html_(self) -> str:
56
- img_col_idxs = [i for i, col_type in enumerate(self.col_types) if col_type.is_image_type()]
57
- video_col_idxs = [i for i, col_type in enumerate(self.col_types) if col_type.is_video_type()]
58
- formatters = {self.col_names[i]: _format_img for i in img_col_idxs}
59
- formatters.update({self.col_names[i]: _format_video for i in video_col_idxs})
60
- # escape=False: make sure <img> tags stay intact
83
+ formatters = {
84
+ col_name: self._formatters[col_type.__class__]
85
+ for col_name, col_type in zip(self._col_names, self._col_types)
86
+ if col_type.__class__ in self._formatters
87
+ }
88
+
61
89
  # TODO: why does mypy complain about formatters having an incorrect type?
62
90
  return self.to_pandas().to_html(formatters=formatters, escape=False, index=False) # type: ignore[arg-type]
63
91
 
64
92
  def __str__(self) -> str:
65
93
  return self.to_pandas().to_string()
66
94
 
67
- def to_pandas(self) -> pd.DataFrame:
68
- return pd.DataFrame.from_records(self.rows, columns=self.col_names)
95
+ def _reverse(self) -> None:
96
+ """Reverse order of rows"""
97
+ self._rows.reverse()
69
98
 
70
- def __getitem__(self, index: Any) -> Any:
71
- if isinstance(index, tuple):
72
- if len(index) != 2 or not isinstance(index[0], int) or not isinstance(index[1], int):
73
- raise exc.OperationalError(f'Bad index: {index}')
74
- return self.rows[index[0]][index[1]]
99
+ def to_pandas(self) -> pd.DataFrame:
100
+ return pd.DataFrame.from_records(self._rows, columns=self._col_names)
75
101
 
102
+ def _row_to_dict(self, row_idx: int) -> Dict[str, Any]:
103
+ return {self._col_names[i]: self._rows[row_idx][i] for i in range(len(self._col_names))}
76
104
 
105
+ def __getitem__(self, index: Any) -> Any:
106
+ if isinstance(index, str):
107
+ if index not in self._col_names:
108
+ raise excs.Error(f'Invalid column name: {index}')
109
+ col_idx = self._col_names.index(index)
110
+ return [row[col_idx] for row in self._rows]
111
+ if isinstance(index, int):
112
+ return self._row_to_dict(index)
113
+ if isinstance(index, tuple) and len(index) == 2:
114
+ if not isinstance(index[0], int) or not (isinstance(index[1], str) or isinstance(index[1], int)):
115
+ raise excs.Error(f'Bad index, expected [<row idx>, <column name | column index>]: {index}')
116
+ if isinstance(index[1], str) and index[1] not in self._col_names:
117
+ raise excs.Error(f'Invalid column name: {index[1]}')
118
+ col_idx = self._col_names.index(index[1]) if isinstance(index[1], str) else index[1]
119
+ return self._rows[index[0]][col_idx]
120
+ raise excs.Error(f'Bad index: {index}')
121
+
122
+ def __iter__(self) -> DataFrameResultSetIterator:
123
+ return DataFrameResultSetIterator(self)
124
+
125
+ def __eq__(self, other):
126
+ if not isinstance(other, DataFrameResultSet):
127
+ return False
128
+ return self.to_pandas().equals(other.to_pandas())
129
+
130
+
131
+ class DataFrameResultSetIterator:
132
+ def __init__(self, result_set: DataFrameResultSet):
133
+ self._result_set = result_set
134
+ self._idx = 0
135
+
136
+ def __next__(self) -> Dict[str, Any]:
137
+ if self._idx >= len(self._result_set):
138
+ raise StopIteration
139
+ row = self._result_set._row_to_dict(self._idx)
140
+ self._idx += 1
141
+ return row
142
+
143
+
144
+ # TODO: remove this; it's only here as a reminder that we still need to call release() in the current implementation
77
145
  class AnalysisInfo:
78
- def __init__(self):
146
+ def __init__(self, tbl: catalog.TableVersion):
147
+ self.tbl = tbl
79
148
  # output of the SQL scan stage
80
149
  self.sql_scan_output_exprs: List[exprs.Expr] = []
81
150
  # output of the agg stage
82
151
  self.agg_output_exprs: List[exprs.Expr] = []
83
- # select list providing the input to the SQL scan stage
84
- self.sql_select_list: List[sql.sql.expression.ClauseElement] = []
85
152
  # Where clause of the Select stmt of the SQL scan stage
86
- self.sql_where_clause: Optional[sql.sql.expression.ClauseElement] = None
153
+ self.sql_where_clause: Optional[sql.ClauseElement] = None
87
154
  # filter predicate applied to input rows of the SQL scan stage
88
155
  self.filter: Optional[exprs.Predicate] = None
89
156
  self.similarity_clause: Optional[exprs.ImageSimilarityPredicate] = None
90
157
  self.agg_fn_calls: List[exprs.FunctionCall] = [] # derived from unique_exprs
158
+ self.has_frame_col: bool = False # True if we're referencing the frame col
91
159
 
92
- self.unique_exprs = exprs.UniqueExprSet()
93
- self.next_data_row_idx = 0
160
+ self.evaluator: Optional[exprs.Evaluator] = None
161
+ self.sql_scan_eval_ctx: List[exprs.Expr] = [] # needed to materialize output of SQL scan stage
162
+ self.agg_eval_ctx: List[exprs.Expr] = [] # needed to materialize output of agg stage
163
+ self.filter_eval_ctx: List[exprs.Expr] = []
164
+ self.group_by_eval_ctx: List[exprs.Expr] = []
94
165
 
95
- @property
96
- def num_materialized(self) -> int:
97
- return self.next_data_row_idx
98
-
99
- def assign_idxs(self, expr_list: List[exprs.Expr]) -> None:
166
+ def finalize_exec(self) -> None:
100
167
  """
101
- Assign data/sql_row_idx to exprs in expr_list and all their subcomponents.
102
- An expr with to_sql() != None is assumed to be materialized fully via SQL; its components
103
- aren't materialized and don't receive idxs.
168
+ Call release() on all collected Exprs.
104
169
  """
105
- for e in expr_list:
106
- self._assign_idxs_aux(e)
107
- self.agg_fn_calls = [e for e in self.unique_exprs if isinstance(e, exprs.FunctionCall) and e.is_agg_fn_call]
108
-
109
- def _assign_idxs_aux(self, expr: exprs.Expr) -> None:
110
- if not self.unique_exprs.add(expr):
111
- # nothing left to do
112
- return
113
-
114
- sql_expr = expr.sql_expr()
115
- # if this can be materialized via SQL we don't need to look at its components;
116
- # we special-case Literals because we don't want to have to materialize them via SQL
117
- if sql_expr is not None and not isinstance(expr, exprs.Literal):
118
- assert expr.data_row_idx < 0
119
- expr.data_row_idx = self.next_data_row_idx
120
- self.next_data_row_idx += 1
121
- expr.sql_row_idx = len(self.sql_select_list)
122
- self.sql_select_list.append(sql_expr)
123
- return
170
+ exprs.Expr.release_list(self.sql_scan_output_exprs)
171
+ exprs.Expr.release_list(self.agg_output_exprs)
172
+ if self.filter is not None:
173
+ self.filter.release()
124
174
 
125
- # expr value needs to be computed via Expr.eval()
126
- for c in expr.components:
127
- self._assign_idxs_aux(c)
128
- assert expr.data_row_idx < 0
129
- expr.data_row_idx = self.next_data_row_idx
130
- self.next_data_row_idx += 1
131
175
 
132
176
 
133
177
  class DataFrame:
134
178
  def __init__(
135
- self, tbl: catalog.Table,
136
- select_list: Optional[List[exprs.Expr]] = None,
137
- where_clause: Optional[exprs.Predicate] = None):
179
+ self, tbl: catalog.TableVersionPath,
180
+ select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]] = None,
181
+ where_clause: Optional[exprs.Predicate] = None,
182
+ group_by_clause: Optional[List[exprs.Expr]] = None,
183
+ grouping_tbl: Optional[catalog.TableVersion] = None,
184
+ order_by_clause: Optional[List[Tuple[exprs.Expr, bool]]] = None, # List[(expr, asc)]
185
+ limit: Optional[int] = None):
138
186
  self.tbl = tbl
139
- # self.select_list and self.where_clause contain execution state and therefore cannot be shared
140
- self.select_list: Optional[List[exprs.Expr]] = None # None: implies all cols
141
- if select_list is not None:
142
- self.select_list = [e.copy() for e in select_list]
143
- self.where_clause: Optional[exprs.Predicate] = None
144
- if where_clause is not None:
145
- self.where_clause = where_clause.copy()
146
- self.group_by_clause: Optional[List[exprs.Expr]] = None
147
- self.analysis_info: Optional[AnalysisInfo] = None
148
-
149
- def analyze(self) -> None:
150
- """
151
- Populates self.analysis_info.
152
- """
153
- info = self.analysis_info = AnalysisInfo()
154
- if self.where_clause is not None:
155
- info.sql_where_clause, info.filter = self.where_clause.extract_sql_predicate()
156
- if info.filter is not None:
157
- similarity_clauses, info.filter = info.filter.split_conjuncts(
158
- lambda e: isinstance(e, exprs.ImageSimilarityPredicate))
159
- if len(similarity_clauses) > 1:
160
- raise exc.OperationalError(f'More than one nearest() or matches() not supported')
161
- if len(similarity_clauses) == 1:
162
- info.similarity_clause = similarity_clauses[0]
163
- img_col = info.similarity_clause.img_col_ref.col
164
- if not img_col.is_indexed:
165
- raise exc.OperationalError(
166
- f'nearest()/matches() not available for unindexed column {img_col.name}')
167
-
168
- if info.filter is not None:
169
- info.assign_idxs([info.filter])
170
- if len(self.group_by_clause) > 0:
171
- info.assign_idxs(self.group_by_clause)
172
- for e in self.group_by_clause:
173
- self._analyze_group_by(e, True)
174
- info.assign_idxs(self.select_list)
175
- grouping_expr_idxs = set([e.data_row_idx for e in self.group_by_clause])
176
- item_is_agg = [self._analyze_select_list(e, grouping_expr_idxs)[0] for e in self.select_list]
177
-
178
- if self.is_agg():
179
- # this is an aggregation
180
- if item_is_agg.count(False) > 0:
181
- raise exc.Error(f'Invalid non-aggregate in select list: {self.select_list[item_is_agg.find(False)]}')
182
- # the agg stage materializes select list items that haven't already been provided by SQL
183
- info.agg_output_exprs = [e for e in self.select_list if e.sql_row_idx == -1]
184
- # our sql scan stage needs to materialize: grouping exprs, arguments of agg fn calls
185
- info.sql_scan_output_exprs = copy.copy(self.group_by_clause)
186
- unique_args: Set[int] = set()
187
- for fn_call in info.agg_fn_calls:
188
- for c in fn_call.components:
189
- unique_args.add(c.data_row_idx)
190
- all_exprs = {e.data_row_idx: e for e in info.unique_exprs}
191
- info.sql_scan_output_exprs.extend([all_exprs[idx] for idx in unique_args])
192
- else:
193
- info.sql_scan_output_exprs = self.select_list
194
187
 
195
- def is_agg(self) -> bool:
196
- return len(self.group_by_clause) > 0 \
197
- or (self.analysis_info is not None and len(self.analysis_info.agg_fn_calls) > 0)
198
-
199
- def _is_agg_fn_call(self, e: exprs.Expr) -> bool:
200
- return isinstance(e, exprs.FunctionCall) and e.is_agg_fn_call
201
-
202
- def _analyze_group_by(self, e: exprs.Expr, check_sql: bool) -> None:
188
+ # select list logic
189
+ DataFrame._select_list_check_rep(select_list) # check select list without expansion
190
+ # exprs contain execution state and therefore cannot be shared
191
+ select_list = copy.deepcopy(select_list)
192
+ select_list_exprs, column_names = DataFrame._normalize_select_list(tbl, select_list)
193
+ DataFrame._select_list_check_rep(list(zip(select_list_exprs, column_names)))
194
+ # check select list after expansion to catch early
195
+ # the following two lists are always non empty, even if select list is None.
196
+ self._select_list_exprs = select_list_exprs
197
+ self._column_names = column_names
198
+ self.select_list = select_list
199
+
200
+ self.where_clause = copy.deepcopy(where_clause)
201
+ assert group_by_clause is None or grouping_tbl is None
202
+ self.group_by_clause = copy.deepcopy(group_by_clause)
203
+ self.grouping_tbl = grouping_tbl
204
+ self.order_by_clause = copy.deepcopy(order_by_clause)
205
+ self.limit_val = limit
206
+
207
+ @classmethod
208
+ def _select_list_check_rep(cls,
209
+ select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]],
210
+ ) -> None:
211
+ """Validate basic select list types.
203
212
  """
204
- Make sure that group-by exprs don't contain aggregates.
205
- """
206
- if e.sql_row_idx == -1 and check_sql:
207
- raise exc.Error(f'Invalid grouping expr, needs to be expressible in SQL: {e}')
208
- if self._is_agg_fn_call(e):
209
- raise exc.Error(f'Cannot group by aggregate function: {e}')
210
- for c in e.components:
211
- self._analyze_group_by(c, False)
212
-
213
- def _analyze_select_list(self, e: exprs.Expr, grouping_exprs: Set[int]) -> Tuple[bool, bool]:
213
+ if select_list is None: # basic check for valid select list
214
+ return
215
+
216
+ assert len(select_list) > 0
217
+ for ent in select_list:
218
+ assert isinstance(ent, tuple)
219
+ assert len(ent) == 2
220
+ assert isinstance(ent[0], exprs.Expr)
221
+ assert ent[1] is None or isinstance(ent[1], str)
222
+ if isinstance(ent[1], str):
223
+ assert is_valid_identifier(ent[1])
224
+
225
+ @classmethod
226
+ def _normalize_select_list(cls,
227
+ tbl: catalog.TableVersionPath,
228
+ select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]],
229
+ ) -> Tuple[List[exprs.Expr], List[str]]:
214
230
  """
215
- Analyzes select list item. Returns (list item is output of agg stage, item is output of scan stage).
216
- Collects agg fn calls in self.analysis_info.
231
+ Expand select list information with all columns and their names
232
+ Returns:
233
+ a pair composed of the list of expressions and the list of corresponding names
217
234
  """
218
- if e.data_row_idx in grouping_exprs:
219
- return True, True
220
- elif self._is_agg_fn_call(e):
221
- for c in e.components:
222
- _, is_scan_output = self._analyze_select_list(c, grouping_exprs)
223
- if not is_scan_output:
224
- raise exc.Error(f'Invalid nested aggregates: {e}')
225
- return True, False
226
- elif isinstance(e, exprs.Literal):
227
- return True, True
228
- elif isinstance(e, exprs.ColumnRef):
229
- # we already know that this isn't a grouping expr
230
- return False, True
235
+ if select_list is None:
236
+ expanded_list = [(exprs.ColumnRef(col), None) for col in tbl.columns()]
231
237
  else:
232
- # an expression such as <grouping expr 1> + <grouping expr 2> can be the output of both
233
- # the agg stage and the scan stage
234
- component_is_agg: List[bool] = []
235
- component_is_scan: List[bool] = []
236
- for c in e.components:
237
- is_agg, is_scan = self._analyze_select_list(c, grouping_exprs)
238
- component_is_agg.append(is_agg)
239
- component_is_scan.append(is_scan)
240
- is_agg = component_is_agg.count(True) == len(e.components)
241
- is_scan = component_is_scan.count(True) == len(e.components)
242
- if not is_agg and not is_scan:
243
- raise exc.Error(f'Invalid expression, mixes aggregate with non-aggregate: {e}')
244
- return is_agg, is_scan
245
-
246
- def exec(self, n: int = 20, select_pk: bool = False) -> Generator[List[Any], None, None]:
247
- """
248
- Returned value: list of select list values.
249
- If select_pk == True, also selects the primary key of the storage table (which is rowid and v_min).
238
+ expanded_list = select_list
239
+
240
+ out_exprs : List[exprs.Expr] = []
241
+ out_names : List[str] = [] # keep track of order
242
+ seen_out_names : set[str] = set() # use to check for duplicates in loop, avoid square complexity
243
+ for i, (expr, name) in enumerate(expanded_list):
244
+ if name is None:
245
+ # use default, add suffix if needed so default adds no duplicates
246
+ default_name = expr.default_column_name()
247
+ if default_name is not None:
248
+ column_name = default_name
249
+ if default_name in seen_out_names:
250
+ # already used, then add suffix until unique name is found
251
+ for j in range(1, len(out_names)+1):
252
+ column_name = f'{default_name}_{j}'
253
+ if column_name not in seen_out_names:
254
+ break
255
+ else: # no default name, eg some expressions
256
+ column_name = f'col_{i}'
257
+ else: # user provided name, no attempt to rename
258
+ column_name = name
259
+
260
+ out_exprs.append(expr)
261
+ out_names.append(column_name)
262
+ seen_out_names.add(column_name)
263
+ assert len(out_exprs) == len(out_names)
264
+ assert set(out_names) == seen_out_names
265
+ return out_exprs, out_names
266
+
267
+ def _exec(self) -> Generator[exprs.DataRow, None, None]:
268
+ """Run the query and return rows as a generator.
269
+ This function must not modify the state of the DataFrame, otherwise it breaks dataset caching.
250
270
  """
251
- if self.select_list is None:
252
- self.select_list = [exprs.ColumnRef(col) for col in self.tbl.columns]
253
- if self.group_by_clause is None:
254
- self.group_by_clause = []
255
- for item in self.select_list:
271
+ # construct a group-by clause if we're grouping by a table
272
+ group_by_clause: List[exprs.Expr] = []
273
+ if self.grouping_tbl is not None:
274
+ assert self.group_by_clause is None
275
+ num_rowid_cols = len(self.grouping_tbl.store_tbl.rowid_columns())
276
+ # the grouping table must be a base of self.tbl
277
+ assert num_rowid_cols <= len(self.tbl.tbl_version.store_tbl.rowid_columns())
278
+ group_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
279
+ elif self.group_by_clause is not None:
280
+ group_by_clause = self.group_by_clause
281
+
282
+ for item in self._select_list_exprs:
256
283
  item.bind_rel_paths(None)
257
- if self.analysis_info is None:
258
- self.analyze()
259
- if self.analysis_info.similarity_clause is not None and n > 100:
260
- raise exc.OperationalError(f'nearest()/matches() requires show(n <= 100): n={n}')
261
-
262
- # determine order_by clause for window functions or grouping, if present
263
- window_fn_calls = [
264
- e for e in self.analysis_info.unique_exprs
265
- if isinstance(e, exprs.FunctionCall) and e.is_window_fn_call
266
- ]
267
- if len(window_fn_calls) > 0 and self.is_agg():
268
- raise exc.Error(f'Cannot combine window functions with non-windowed aggregation')
269
- order_by_exprs: List[exprs.Expr] = []
270
- # TODO: check compatibility of window clauses
271
- if len(window_fn_calls) > 0:
272
- order_by_exprs = window_fn_calls[0].get_window_sort_exprs()
273
- elif self.is_agg():
274
- # TODO: collect aggs with order-by and analyze for compatibility
275
- order_by_exprs = self.group_by_clause + self.analysis_info.agg_fn_calls[0].get_agg_order_by()
276
- order_by_clause = [e.sql_expr() for e in order_by_exprs]
277
- for i in range(len(order_by_exprs)):
278
- if order_by_clause[i] is None:
279
- raise exc.Error(f'order_by element cannot be expressed in SQL: {order_by_exprs[i]}')
280
-
281
- idx_rowids: List[int] = [] # rowids returned by index lookup
282
- if self.analysis_info.similarity_clause is not None:
283
- # do index lookup
284
- assert self.analysis_info.similarity_clause.img_col_ref.col.idx is not None
285
- embed = self.analysis_info.similarity_clause.embedding()
286
- idx_rowids = self.analysis_info.similarity_clause.img_col_ref.col.idx.search(embed, n, self.tbl.valid_rowids)
287
-
288
- with Env.get().engine.connect() as conn:
289
- stmt = self._create_select_stmt(
290
- self.analysis_info.sql_select_list, self.analysis_info.sql_where_clause, idx_rowids, select_pk,
291
- order_by_clause)
292
- num_rows = 0
293
- sql_scan_evaluator = exprs.ExprEvaluator(
294
- self.analysis_info.sql_scan_output_exprs, self.analysis_info.filter)
295
- agg_evaluator = exprs.ExprEvaluator(self.analysis_info.agg_output_exprs, None)
296
-
297
- current_group: Optional[List[Any]] = None # for grouping agg, the values of the group-by exprs
298
- for row in conn.execute(stmt):
299
- sql_row = row._data
300
- data_row: List[Any] = [None] * self.analysis_info.num_materialized
301
- if not sql_scan_evaluator.eval(sql_row, data_row):
302
- continue
303
-
304
- # copy select list results into contiguous array
305
- result_row: Optional[List[Any]] = None
306
- if self.is_agg():
307
- group = [data_row[e.data_row_idx] for e in self.group_by_clause]
308
- if current_group is None:
309
- current_group = group
310
- if group != current_group:
311
- # we're entering a new group, emit a row for the last one
312
- agg_evaluator.eval(last_sql_row, last_data_row)
313
- result_row = [last_data_row[e.data_row_idx] for e in self.select_list]
314
- current_group = group
315
- for fn_call in self.analysis_info.agg_fn_calls:
316
- fn_call.reset_agg()
317
- for fn_call in self.analysis_info.agg_fn_calls:
318
- fn_call.update(data_row)
319
- else:
320
- result_row = [data_row[e.data_row_idx] for e in self.select_list]
321
- if select_pk:
322
- result_row.extend(sql_row[-2:])
323
-
324
- last_data_row = data_row
325
- last_sql_row = row._data
326
- if result_row is not None:
327
- yield result_row
328
- num_rows += 1
329
- if n > 0 and num_rows == n:
330
- break
331
-
332
- if self.is_agg():
333
- # we need to emit the output row for the current group
334
- agg_evaluator.eval(sql_row, data_row)
335
- result_row = [data_row[e.data_row_idx] for e in self.select_list]
336
- yield result_row
284
+ plan = Planner.create_query_plan(
285
+ self.tbl, self._select_list_exprs, where_clause=self.where_clause, group_by_clause=group_by_clause,
286
+ order_by_clause=self.order_by_clause if self.order_by_clause is not None else [],
287
+ limit=self.limit_val if self.limit_val is not None else 0) # limit_val == 0: no limit_val
288
+
289
+ with Env.get().engine.begin() as conn:
290
+ plan.ctx.conn = conn
291
+ plan.open()
292
+ try:
293
+ for row_batch in plan:
294
+ for data_row in row_batch:
295
+ yield data_row
296
+ finally:
297
+ plan.close()
298
+ return
337
299
 
338
300
  def show(self, n: int = 20) -> DataFrameResultSet:
339
- data_rows = [row for row in self.exec(n)]
340
- col_names = [expr.display_name() for expr in self.select_list]
341
- # replace ''
342
- col_names = [n if n != '' else f'col_{i}' for i, n in enumerate(col_names)]
343
- return DataFrameResultSet(data_rows, col_names, [expr.col_type for expr in self.select_list])
301
+ assert n is not None
302
+ return self.limit(n).collect()
303
+
304
+ def head(self, n: int = 10) -> DataFrameResultSet:
305
+ if self.order_by_clause is not None:
306
+ raise excs.Error(f'head() cannot be used with order_by()')
307
+ num_rowid_cols = len(self.tbl.tbl_version.store_tbl.rowid_columns())
308
+ order_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
309
+ return self.order_by(*order_by_clause, asc=True).limit(n).collect()
310
+
311
+ def tail(self, n: int = 10) -> DataFrameResultSet:
312
+ if self.order_by_clause is not None:
313
+ raise excs.Error(f'tail() cannot be used with order_by()')
314
+ num_rowid_cols = len(self.tbl.tbl_version.store_tbl.rowid_columns())
315
+ order_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
316
+ result = self.order_by(*order_by_clause, asc=False).limit(n).collect()
317
+ result._reverse()
318
+ return result
319
+
320
+ def get_column_names(self) -> List[str]:
321
+ return self._column_names
322
+
323
+ def get_column_types(self) -> List[ColumnType]:
324
+ return [expr.col_type for expr in self._select_list_exprs]
325
+
326
+ def collect(self) -> DataFrameResultSet:
327
+ try:
328
+ result_rows = []
329
+ for data_row in self._exec():
330
+ result_row = [data_row[e.slot_idx] for e in self._select_list_exprs]
331
+ result_rows.append(result_row)
332
+ except excs.ExprEvalError as e:
333
+ msg = (f'In row {e.row_num} the {e.expr_msg} encountered exception '
334
+ f'{type(e.exc).__name__}:\n{str(e.exc)}')
335
+ if len(e.input_vals) > 0:
336
+ input_msgs = [
337
+ f"'{d}' = {d.col_type.print_value(e.input_vals[i])}"
338
+ for i, d in enumerate(e.expr.dependencies())
339
+ ]
340
+ msg += f'\nwith {", ".join(input_msgs)}'
341
+ assert e.exc_tb is not None
342
+ stack_trace = traceback.format_tb(e.exc_tb)
343
+ if len(stack_trace) > 2:
344
+ # append a stack trace if the exception happened in user code
345
+ # (frame 0 is ExprEvaluator and frame 1 is some expr's eval()
346
+ nl = '\n'
347
+ # [-1:0:-1]: leave out entry 0 and reverse order, so that the most recent frame is at the top
348
+ msg += f'\nStack:\n{nl.join(stack_trace[-1:1:-1])}'
349
+ raise excs.Error(msg)
350
+ except sql.exc.DBAPIError as e:
351
+ raise excs.Error(f'Error during SQL execution:\n{e}')
352
+
353
+ col_types = self.get_column_types()
354
+ return DataFrameResultSet(result_rows, self._column_names, col_types)
344
355
 
345
356
  def count(self) -> int:
346
- """
347
- TODO: implement as part of DataFrame.agg()
348
- """
349
- stmt = sql.select(sql.func.count('*')).select_from(self.tbl.sa_tbl) \
350
- .where(self.tbl.v_min_col <= self.tbl.version) \
351
- .where(self.tbl.v_max_col > self.tbl.version)
352
- if self.where_clause is not None:
353
- sql_where_clause = self.where_clause.sql_expr()
354
- assert sql_where_clause is not None
355
- stmt = stmt.where(sql_where_clause)
357
+ from pixeltable.plan import Planner
358
+ stmt = Planner.create_count_stmt(self.tbl, self.where_clause)
356
359
  with Env.get().engine.connect() as conn:
357
360
  result: int = conn.execute(stmt).scalar_one()
358
361
  assert isinstance(result, int)
359
362
  return result
360
363
 
361
- def categorical_map(self) -> Dict[str, int]:
364
+ def _description(self) -> pd.DataFrame:
365
+ """see DataFrame.describe()"""
366
+ heading_vals: List[str] = []
367
+ info_vals: List[str] = []
368
+ if self.select_list is not None:
369
+ assert len(self.select_list) > 0
370
+ heading_vals.append('Select')
371
+ heading_vals.extend([''] * (len(self.select_list) - 1))
372
+ info_vals.extend(self.get_column_names())
373
+ if self.where_clause is not None:
374
+ heading_vals.append('Where')
375
+ info_vals.append(self.where_clause.display_str(inline=False))
376
+ if self.group_by_clause is not None:
377
+ heading_vals.append('Group By')
378
+ heading_vals.extend([''] * (len(self.group_by_clause) - 1))
379
+ info_vals.extend([e.display_str(inline=False) for e in self.group_by_clause])
380
+ if self.order_by_clause is not None:
381
+ heading_vals.append('Order By')
382
+ heading_vals.extend([''] * (len(self.order_by_clause) - 1))
383
+ info_vals.extend([
384
+ f'{e[0].display_str(inline=False)} {"asc" if e[1] else "desc"}' for e in self.order_by_clause
385
+ ])
386
+ if self.limit_val is not None:
387
+ heading_vals.append('Limit')
388
+ info_vals.append(str(self.limit_val))
389
+ assert len(heading_vals) > 0
390
+ assert len(info_vals) > 0
391
+ assert len(heading_vals) == len(info_vals)
392
+ return pd.DataFrame({'Heading': heading_vals, 'Info': info_vals})
393
+
394
+ def _description_html(self) -> pandas.io.formats.style.Styler:
395
+ """Return the description in an ipython-friendly manner."""
396
+ pd_df = self._description()
397
+ # white-space: pre-wrap: print \n as newline
398
+ # th: center-align headings
399
+ return pd_df.style.set_properties(**{'white-space': 'pre-wrap', 'text-align': 'left'}) \
400
+ .set_table_styles([dict(selector='th', props=[('text-align', 'center')])]) \
401
+ .hide(axis='index').hide(axis='columns')
402
+
403
+ def describe(self) -> None:
362
404
  """
363
- Return map of distinct values in string ColumnRef to increasing integers.
364
- TODO: implement as part of DataFrame.agg()
405
+ Prints a tabular description of this DataFrame.
406
+ The description has two columns, heading and info, which list the contents of each 'component'
407
+ (select list, where clause, ...) vertically.
365
408
  """
366
- if self.select_list is None or len(self.select_list) != 1 \
367
- or not isinstance(self.select_list[0], exprs.ColumnRef) \
368
- or not self.select_list[0].col_type.is_string_type():
369
- raise exc.OperationalError(f'categoricals_map() can only be applied to an individual string column')
370
- assert isinstance(self.select_list[0], exprs.ColumnRef)
371
- col = self.select_list[0].col
372
- stmt = sql.select(sql.distinct(col.sa_col)) \
373
- .where(self.tbl.v_min_col <= self.tbl.version) \
374
- .where(self.tbl.v_max_col > self.tbl.version) \
375
- .order_by(col.sa_col)
376
- if self.where_clause is not None:
377
- sql_where_clause = self.where_clause.sql_expr()
378
- assert sql_where_clause is not None
379
- stmt = stmt.where(sql_where_clause)
380
- with Env.get().engine.connect() as conn:
381
- result = {row._data[0]: i for i, row in enumerate(conn.execute(stmt))}
382
- return result
409
+ try:
410
+ __IPYTHON__
411
+ from IPython.display import display
412
+ display(self._description_html())
413
+ except NameError:
414
+ print(self.__repr__())
383
415
 
384
- def __getitem__(self, index: object) -> 'DataFrame':
416
+ def __repr__(self) -> str:
417
+ return self._description().to_string(header=False, index=False)
418
+
419
+ def _repr_html_(self) -> str:
420
+ return self._description_html()._repr_html_()
421
+
422
+ def select(self, *items: Any, **named_items : Any) -> DataFrame:
423
+ if self.select_list is not None:
424
+ raise excs.Error(f'Select list already specified')
425
+ for (name, _) in named_items.items():
426
+ if not isinstance(name, str) or not is_valid_identifier(name):
427
+ raise excs.Error(f'Invalid name: {name}')
428
+ base_list = [(expr, None) for expr in items] + [(expr, k) for (k, expr) in named_items.items()]
429
+ if len(base_list) == 0:
430
+ raise excs.Error(f'Empty select list')
431
+
432
+ # analyze select list; wrap literals with the corresponding expressions
433
+ select_list = []
434
+ for raw_expr, name in base_list:
435
+ if isinstance(raw_expr, exprs.Expr):
436
+ select_list.append((raw_expr, name))
437
+ elif isinstance(raw_expr, dict):
438
+ select_list.append((exprs.InlineDict(raw_expr), name))
439
+ elif isinstance(raw_expr, list):
440
+ select_list.append((exprs.InlineArray(raw_expr), name))
441
+ else:
442
+ select_list.append((exprs.Literal(raw_expr), name))
443
+ expr = select_list[-1][0]
444
+ if expr.col_type.is_invalid_type():
445
+ raise excs.Error(f'Invalid type: {raw_expr}')
446
+ # TODO: check that ColumnRefs in expr refer to self.tbl
447
+
448
+ # check user provided names do not conflict among themselves
449
+ # or with auto-generated ones
450
+ seen: Set[str] = set()
451
+ _, names = DataFrame._normalize_select_list(self.tbl, select_list)
452
+ for name in names:
453
+ if name in seen:
454
+ repeated_names = [j for j, x in enumerate(names) if x == name]
455
+ pretty = ', '.join(map(str, repeated_names))
456
+ raise excs.Error(f'Repeated column name "{name}" in select() at positions: {pretty}')
457
+ seen.add(name)
458
+
459
+ return DataFrame(
460
+ self.tbl, select_list=select_list, where_clause=self.where_clause, group_by_clause=self.group_by_clause,
461
+ grouping_tbl=self.grouping_tbl, order_by_clause=self.order_by_clause, limit=self.limit_val)
462
+
463
+ def where(self, pred: exprs.Predicate) -> DataFrame:
464
+ return DataFrame(
465
+ self.tbl, select_list=self.select_list, where_clause=pred, group_by_clause=self.group_by_clause,
466
+ grouping_tbl=self.grouping_tbl, order_by_clause=self.order_by_clause, limit=self.limit_val)
467
+
468
+ def group_by(self, *grouping_items: Any) -> DataFrame:
469
+ """Add a group-by clause to this DataFrame.
470
+ Variants:
471
+ - group_by(<base table>): group a component view by their respective base table rows
472
+ - group_by(<expr>, ...): group by the given expressions
473
+ """
474
+ if self.group_by_clause is not None:
475
+ raise excs.Error(f'Group-by already specified')
476
+ grouping_tbl: Optional[catalog.TableVersion] = None
477
+ group_by_clause: Optional[List[exprs.Expr]] = None
478
+ for item in grouping_items:
479
+ if isinstance(item, catalog.Table):
480
+ if len(grouping_items) > 1:
481
+ raise excs.Error(f'group_by(): only one table can be specified')
482
+ # we need to make sure that the grouping table is a base of self.tbl
483
+ base = self.tbl.find_tbl_version(item.tbl_version_path.tbl_id())
484
+ if base is None or base.id == self.tbl.tbl_id():
485
+ raise excs.Error(f'group_by(): {item.name} is not a base table of {self.tbl.tbl_name()}')
486
+ grouping_tbl = item.tbl_version_path.tbl_version
487
+ break
488
+ if not isinstance(item, exprs.Expr):
489
+ raise excs.Error(f'Invalid expression in group_by(): {item}')
490
+ if grouping_tbl is None:
491
+ group_by_clause = list(grouping_items)
492
+ return DataFrame(
493
+ self.tbl, select_list=self.select_list, where_clause=self.where_clause, group_by_clause=group_by_clause,
494
+ grouping_tbl=grouping_tbl, order_by_clause=self.order_by_clause, limit=self.limit_val)
495
+
496
+ def order_by(self, *expr_list: exprs.Expr, asc: bool = True) -> DataFrame:
497
+ for e in expr_list:
498
+ if not isinstance(e, exprs.Expr):
499
+ raise excs.Error(f'Invalid expression in order_by(): {e}')
500
+ order_by_clause = self.order_by_clause if self.order_by_clause is not None else []
501
+ order_by_clause.extend([(e.copy(), asc) for e in expr_list])
502
+ return DataFrame(
503
+ self.tbl, select_list=self.select_list, where_clause=self.where_clause,
504
+ group_by_clause=self.group_by_clause, grouping_tbl=self.grouping_tbl, order_by_clause=order_by_clause,
505
+ limit=self.limit_val)
506
+
507
+ def limit(self, n: int) -> DataFrame:
508
+ assert n is not None and isinstance(n, int)
509
+ return DataFrame(
510
+ self.tbl, select_list=self.select_list, where_clause=self.where_clause,
511
+ group_by_clause=self.group_by_clause, grouping_tbl=self.grouping_tbl, order_by_clause=self.order_by_clause,
512
+ limit=n)
513
+
514
+ def __getitem__(self, index: object) -> DataFrame:
385
515
  """
386
516
  Allowed:
387
517
  - [<Predicate>]: filter operation
@@ -389,52 +519,113 @@ class DataFrame:
389
519
  - [Expr]: setting a single-col select list
390
520
  """
391
521
  if isinstance(index, exprs.Predicate):
392
- return DataFrame(self.tbl, select_list=self.select_list, where_clause=index)
522
+ return self.where(index)
393
523
  if isinstance(index, tuple):
394
524
  index = list(index)
395
525
  if isinstance(index, exprs.Expr):
396
526
  index = [index]
397
527
  if isinstance(index, list):
398
- if self.select_list is not None:
399
- raise exc.OperationalError(f'[] for column selection is only allowed once')
400
- # analyze select list; wrap literals with the corresponding expressions and update it in place
401
- for i in range(len(index)):
402
- expr = index[i]
403
- if isinstance(expr, dict):
404
- index[i] = expr = exprs.InlineDict(expr)
405
- if isinstance(expr, list):
406
- index[i] = expr = exprs.InlineArray(tuple(expr))
407
- if not isinstance(expr, exprs.Expr):
408
- raise exc.OperationalError(f'Invalid expression in []: {expr}')
409
- if expr.col_type.is_invalid_type():
410
- raise exc.OperationalError(f'Invalid type: {expr}')
411
- # TODO: check that ColumnRefs in expr refer to self.tbl
412
- return DataFrame(self.tbl, select_list=index, where_clause=self.where_clause)
528
+ return self.select(*index)
413
529
  raise TypeError(f'Invalid index type: {type(index)}')
530
+
531
+ def _as_dict(self) -> Dict[str, Any]:
532
+ """
533
+ Returns:
534
+ Dictionary representing this dataframe.
535
+ """
536
+ tbl_versions = self.tbl.get_tbl_versions()
537
+ d = {
538
+ '_classname': 'DataFrame',
539
+ 'tbl_ids': [str(t.id) for t in tbl_versions],
540
+ 'tbl_versions': [t.version for t in tbl_versions],
541
+ 'select_list':
542
+ [(e.as_dict(), name) for (e, name) in self.select_list] if self.select_list is not None else None,
543
+ 'where_clause': self.where_clause.as_dict() if self.where_clause is not None else None,
544
+ 'group_by_clause':
545
+ [e.as_dict() for e in self.group_by_clause] if self.group_by_clause is not None else None,
546
+ 'order_by_clause':
547
+ [(e.as_dict(), asc) for (e,asc) in self.order_by_clause] if self.order_by_clause is not None else None,
548
+ 'limit_val': self.limit_val,
549
+ }
550
+ return d
551
+
552
+ def to_coco_dataset(self) -> Path:
553
+ """Convert the dataframe to a COCO dataset.
554
+ This dataframe must return a single json-typed output column in the following format:
555
+ {
556
+ 'image': PIL.Image.Image,
557
+ 'annotations': [
558
+ {
559
+ 'bbox': [x: int, y: int, w: int, h: int],
560
+ 'category': str | int,
561
+ },
562
+ ...
563
+ ],
564
+ }
565
+
566
+ Returns:
567
+ Path to the COCO dataset file.
568
+ """
569
+ from pixeltable.utils.coco import write_coco_dataset
570
+
571
+ summary_string = json.dumps(self._as_dict())
572
+ cache_key = hashlib.sha256(summary_string.encode()).hexdigest()
573
+
574
+ dest_path = (Env.get().dataset_cache_dir / f'coco_{cache_key}')
575
+ if dest_path.exists():
576
+ assert dest_path.is_dir()
577
+ data_file_path = dest_path / 'data.json'
578
+ assert data_file_path.exists()
579
+ assert data_file_path.is_file()
580
+ return data_file_path
581
+ else:
582
+ return write_coco_dataset(self, dest_path)
414
583
 
415
- def group_by(self, *expr_list: Tuple[exprs.Expr]) -> 'DataFrame':
416
- for e in expr_list:
417
- if not isinstance(e, exprs.Expr):
418
- raise exc.Error(f'Invalid expr in group_by(): {e}')
419
- self.group_by_clause = [e.copy() for e in expr_list]
420
- return self
421
-
422
- def _create_select_stmt(
423
- self, select_list: List[sql.sql.expression.ClauseElement],
424
- where_clause: Optional[sql.sql.expression.ClauseElement],
425
- valid_rowids: List[int],
426
- select_pk: bool,
427
- order_by_exprs: List[sql.sql.expression.ClauseElement]
428
- ) -> sql.sql.expression.Select:
429
- pk_cols = [self.tbl.rowid_col, self.tbl.v_min_col] if select_pk else []
430
- # we add pk_cols at the end so that the already-computed sql row indices remain correct
431
- stmt = sql.select(*select_list, *pk_cols) \
432
- .where(self.tbl.v_min_col <= self.tbl.version) \
433
- .where(self.tbl.v_max_col > self.tbl.version)
434
- if where_clause is not None:
435
- stmt = stmt.where(where_clause)
436
- if len(valid_rowids) > 0:
437
- stmt = stmt.where(self.tbl.rowid_col.in_(valid_rowids))
438
- if len(order_by_exprs) > 0:
439
- stmt = stmt.order_by(*order_by_exprs)
440
- return stmt
584
+ # TODO Factor this out into a separate module.
585
+ # The return type is unresolvable, but torch can't be imported since it's an optional dependency.
586
+ def to_pytorch_dataset(self, image_format: str = 'pt') -> 'torch.utils.data.IterableDataset':
587
+ """
588
+ Convert the dataframe to a pytorch IterableDataset suitable for parallel loading
589
+ with torch.utils.data.DataLoader.
590
+
591
+ This method requires pyarrow >= 13, torch and torchvision to work.
592
+
593
+ This method serializes data so it can be read from disk efficiently and repeatedly without
594
+ re-executing the query. This data is cached to disk for future re-use.
595
+
596
+ Args:
597
+ image_format: format of the images. Can be 'pt' (pytorch tensor) or 'np' (numpy array).
598
+ 'np' means image columns return as an RGB uint8 array of shape HxWxC.
599
+ 'pt' means image columns return as a CxHxW tensor with values in [0,1] and type torch.float32.
600
+ (the format output by torchvision.transforms.ToTensor())
601
+
602
+ Returns:
603
+ A pytorch IterableDataset: Columns become fields of the dataset, where rows are returned as a dictionary
604
+ compatible with torch.utils.data.DataLoader default collation.
605
+
606
+ Constraints:
607
+ The default collate_fn for torch.data.util.DataLoader cannot represent null values as part of a
608
+ pytorch tensor when forming batches. These values will raise an exception while running the dataloader.
609
+
610
+ If you have them, you can work around None values by providing your custom collate_fn to the DataLoader
611
+ (and have your model handle it). Or, if these are not meaningful values within a minibtach, you can
612
+ modify or remove any such values through selections and filters prior to calling to_pytorch_dataset().
613
+ """
614
+ # check dependencies
615
+ Env.get().require_package('pyarrow', [13])
616
+ Env.get().require_package('torch')
617
+ Env.get().require_package('torchvision')
618
+
619
+ from pixeltable.utils.parquet import save_parquet # pylint: disable=import-outside-toplevel
620
+ from pixeltable.utils.pytorch import PixeltablePytorchDataset # pylint: disable=import-outside-toplevel
621
+
622
+ summary_string = json.dumps(self._as_dict())
623
+ cache_key = hashlib.sha256(summary_string.encode()).hexdigest()
624
+
625
+ dest_path = (Env.get().dataset_cache_dir / f'df_{cache_key}').with_suffix('.parquet') # pylint: disable = protected-access
626
+ if dest_path.exists(): # fast path: use cache
627
+ assert dest_path.is_dir()
628
+ else:
629
+ save_parquet(self, dest_path)
630
+
631
+ return PixeltablePytorchDataset(path=dest_path, image_format=image_format)