pixeltable 0.2.26__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in that registry.
Files changed (245)
  1. pixeltable/__init__.py +83 -19
  2. pixeltable/_query.py +1444 -0
  3. pixeltable/_version.py +1 -0
  4. pixeltable/catalog/__init__.py +7 -4
  5. pixeltable/catalog/catalog.py +2394 -119
  6. pixeltable/catalog/column.py +225 -104
  7. pixeltable/catalog/dir.py +38 -9
  8. pixeltable/catalog/globals.py +53 -34
  9. pixeltable/catalog/insertable_table.py +265 -115
  10. pixeltable/catalog/path.py +80 -17
  11. pixeltable/catalog/schema_object.py +28 -43
  12. pixeltable/catalog/table.py +1270 -677
  13. pixeltable/catalog/table_metadata.py +103 -0
  14. pixeltable/catalog/table_version.py +1270 -751
  15. pixeltable/catalog/table_version_handle.py +109 -0
  16. pixeltable/catalog/table_version_path.py +137 -42
  17. pixeltable/catalog/tbl_ops.py +53 -0
  18. pixeltable/catalog/update_status.py +191 -0
  19. pixeltable/catalog/view.py +251 -134
  20. pixeltable/config.py +215 -0
  21. pixeltable/env.py +736 -285
  22. pixeltable/exceptions.py +26 -2
  23. pixeltable/exec/__init__.py +7 -2
  24. pixeltable/exec/aggregation_node.py +39 -21
  25. pixeltable/exec/cache_prefetch_node.py +87 -109
  26. pixeltable/exec/cell_materialization_node.py +268 -0
  27. pixeltable/exec/cell_reconstruction_node.py +168 -0
  28. pixeltable/exec/component_iteration_node.py +25 -28
  29. pixeltable/exec/data_row_batch.py +11 -46
  30. pixeltable/exec/exec_context.py +26 -11
  31. pixeltable/exec/exec_node.py +35 -27
  32. pixeltable/exec/expr_eval/__init__.py +3 -0
  33. pixeltable/exec/expr_eval/evaluators.py +365 -0
  34. pixeltable/exec/expr_eval/expr_eval_node.py +413 -0
  35. pixeltable/exec/expr_eval/globals.py +200 -0
  36. pixeltable/exec/expr_eval/row_buffer.py +74 -0
  37. pixeltable/exec/expr_eval/schedulers.py +413 -0
  38. pixeltable/exec/globals.py +35 -0
  39. pixeltable/exec/in_memory_data_node.py +35 -27
  40. pixeltable/exec/object_store_save_node.py +293 -0
  41. pixeltable/exec/row_update_node.py +44 -29
  42. pixeltable/exec/sql_node.py +414 -115
  43. pixeltable/exprs/__init__.py +8 -5
  44. pixeltable/exprs/arithmetic_expr.py +79 -45
  45. pixeltable/exprs/array_slice.py +5 -5
  46. pixeltable/exprs/column_property_ref.py +40 -26
  47. pixeltable/exprs/column_ref.py +254 -61
  48. pixeltable/exprs/comparison.py +14 -9
  49. pixeltable/exprs/compound_predicate.py +9 -10
  50. pixeltable/exprs/data_row.py +213 -72
  51. pixeltable/exprs/expr.py +270 -104
  52. pixeltable/exprs/expr_dict.py +6 -5
  53. pixeltable/exprs/expr_set.py +20 -11
  54. pixeltable/exprs/function_call.py +383 -284
  55. pixeltable/exprs/globals.py +18 -5
  56. pixeltable/exprs/in_predicate.py +7 -7
  57. pixeltable/exprs/inline_expr.py +37 -37
  58. pixeltable/exprs/is_null.py +8 -4
  59. pixeltable/exprs/json_mapper.py +120 -54
  60. pixeltable/exprs/json_path.py +90 -60
  61. pixeltable/exprs/literal.py +61 -16
  62. pixeltable/exprs/method_ref.py +7 -6
  63. pixeltable/exprs/object_ref.py +19 -8
  64. pixeltable/exprs/row_builder.py +238 -75
  65. pixeltable/exprs/rowid_ref.py +53 -15
  66. pixeltable/exprs/similarity_expr.py +65 -50
  67. pixeltable/exprs/sql_element_cache.py +5 -5
  68. pixeltable/exprs/string_op.py +107 -0
  69. pixeltable/exprs/type_cast.py +25 -13
  70. pixeltable/exprs/variable.py +2 -2
  71. pixeltable/func/__init__.py +9 -5
  72. pixeltable/func/aggregate_function.py +197 -92
  73. pixeltable/func/callable_function.py +119 -35
  74. pixeltable/func/expr_template_function.py +101 -48
  75. pixeltable/func/function.py +375 -62
  76. pixeltable/func/function_registry.py +20 -19
  77. pixeltable/func/globals.py +6 -5
  78. pixeltable/func/mcp.py +74 -0
  79. pixeltable/func/query_template_function.py +151 -35
  80. pixeltable/func/signature.py +178 -49
  81. pixeltable/func/tools.py +164 -0
  82. pixeltable/func/udf.py +176 -53
  83. pixeltable/functions/__init__.py +44 -4
  84. pixeltable/functions/anthropic.py +226 -47
  85. pixeltable/functions/audio.py +148 -11
  86. pixeltable/functions/bedrock.py +137 -0
  87. pixeltable/functions/date.py +188 -0
  88. pixeltable/functions/deepseek.py +113 -0
  89. pixeltable/functions/document.py +81 -0
  90. pixeltable/functions/fal.py +76 -0
  91. pixeltable/functions/fireworks.py +72 -20
  92. pixeltable/functions/gemini.py +249 -0
  93. pixeltable/functions/globals.py +208 -53
  94. pixeltable/functions/groq.py +108 -0
  95. pixeltable/functions/huggingface.py +1088 -95
  96. pixeltable/functions/image.py +155 -84
  97. pixeltable/functions/json.py +8 -11
  98. pixeltable/functions/llama_cpp.py +31 -19
  99. pixeltable/functions/math.py +169 -0
  100. pixeltable/functions/mistralai.py +50 -75
  101. pixeltable/functions/net.py +70 -0
  102. pixeltable/functions/ollama.py +29 -36
  103. pixeltable/functions/openai.py +548 -160
  104. pixeltable/functions/openrouter.py +143 -0
  105. pixeltable/functions/replicate.py +15 -14
  106. pixeltable/functions/reve.py +250 -0
  107. pixeltable/functions/string.py +310 -85
  108. pixeltable/functions/timestamp.py +37 -19
  109. pixeltable/functions/together.py +77 -120
  110. pixeltable/functions/twelvelabs.py +188 -0
  111. pixeltable/functions/util.py +7 -2
  112. pixeltable/functions/uuid.py +30 -0
  113. pixeltable/functions/video.py +1528 -117
  114. pixeltable/functions/vision.py +26 -26
  115. pixeltable/functions/voyageai.py +289 -0
  116. pixeltable/functions/whisper.py +19 -10
  117. pixeltable/functions/whisperx.py +179 -0
  118. pixeltable/functions/yolox.py +112 -0
  119. pixeltable/globals.py +716 -236
  120. pixeltable/index/__init__.py +3 -1
  121. pixeltable/index/base.py +17 -21
  122. pixeltable/index/btree.py +32 -22
  123. pixeltable/index/embedding_index.py +155 -92
  124. pixeltable/io/__init__.py +12 -7
  125. pixeltable/io/datarows.py +140 -0
  126. pixeltable/io/external_store.py +83 -125
  127. pixeltable/io/fiftyone.py +24 -33
  128. pixeltable/io/globals.py +47 -182
  129. pixeltable/io/hf_datasets.py +96 -127
  130. pixeltable/io/label_studio.py +171 -156
  131. pixeltable/io/lancedb.py +3 -0
  132. pixeltable/io/pandas.py +136 -115
  133. pixeltable/io/parquet.py +40 -153
  134. pixeltable/io/table_data_conduit.py +702 -0
  135. pixeltable/io/utils.py +100 -0
  136. pixeltable/iterators/__init__.py +8 -4
  137. pixeltable/iterators/audio.py +207 -0
  138. pixeltable/iterators/base.py +9 -3
  139. pixeltable/iterators/document.py +144 -87
  140. pixeltable/iterators/image.py +17 -38
  141. pixeltable/iterators/string.py +15 -12
  142. pixeltable/iterators/video.py +523 -127
  143. pixeltable/metadata/__init__.py +33 -8
  144. pixeltable/metadata/converters/convert_10.py +2 -3
  145. pixeltable/metadata/converters/convert_13.py +2 -2
  146. pixeltable/metadata/converters/convert_15.py +15 -11
  147. pixeltable/metadata/converters/convert_16.py +4 -5
  148. pixeltable/metadata/converters/convert_17.py +4 -5
  149. pixeltable/metadata/converters/convert_18.py +4 -6
  150. pixeltable/metadata/converters/convert_19.py +6 -9
  151. pixeltable/metadata/converters/convert_20.py +3 -6
  152. pixeltable/metadata/converters/convert_21.py +6 -8
  153. pixeltable/metadata/converters/convert_22.py +3 -2
  154. pixeltable/metadata/converters/convert_23.py +33 -0
  155. pixeltable/metadata/converters/convert_24.py +55 -0
  156. pixeltable/metadata/converters/convert_25.py +19 -0
  157. pixeltable/metadata/converters/convert_26.py +23 -0
  158. pixeltable/metadata/converters/convert_27.py +29 -0
  159. pixeltable/metadata/converters/convert_28.py +13 -0
  160. pixeltable/metadata/converters/convert_29.py +110 -0
  161. pixeltable/metadata/converters/convert_30.py +63 -0
  162. pixeltable/metadata/converters/convert_31.py +11 -0
  163. pixeltable/metadata/converters/convert_32.py +15 -0
  164. pixeltable/metadata/converters/convert_33.py +17 -0
  165. pixeltable/metadata/converters/convert_34.py +21 -0
  166. pixeltable/metadata/converters/convert_35.py +9 -0
  167. pixeltable/metadata/converters/convert_36.py +38 -0
  168. pixeltable/metadata/converters/convert_37.py +15 -0
  169. pixeltable/metadata/converters/convert_38.py +39 -0
  170. pixeltable/metadata/converters/convert_39.py +124 -0
  171. pixeltable/metadata/converters/convert_40.py +73 -0
  172. pixeltable/metadata/converters/convert_41.py +12 -0
  173. pixeltable/metadata/converters/convert_42.py +9 -0
  174. pixeltable/metadata/converters/convert_43.py +44 -0
  175. pixeltable/metadata/converters/util.py +44 -18
  176. pixeltable/metadata/notes.py +21 -0
  177. pixeltable/metadata/schema.py +185 -42
  178. pixeltable/metadata/utils.py +74 -0
  179. pixeltable/mypy/__init__.py +3 -0
  180. pixeltable/mypy/mypy_plugin.py +123 -0
  181. pixeltable/plan.py +616 -225
  182. pixeltable/share/__init__.py +3 -0
  183. pixeltable/share/packager.py +797 -0
  184. pixeltable/share/protocol/__init__.py +33 -0
  185. pixeltable/share/protocol/common.py +165 -0
  186. pixeltable/share/protocol/operation_types.py +33 -0
  187. pixeltable/share/protocol/replica.py +119 -0
  188. pixeltable/share/publish.py +349 -0
  189. pixeltable/store.py +398 -232
  190. pixeltable/type_system.py +730 -267
  191. pixeltable/utils/__init__.py +40 -0
  192. pixeltable/utils/arrow.py +201 -29
  193. pixeltable/utils/av.py +298 -0
  194. pixeltable/utils/azure_store.py +346 -0
  195. pixeltable/utils/coco.py +26 -27
  196. pixeltable/utils/code.py +4 -4
  197. pixeltable/utils/console_output.py +46 -0
  198. pixeltable/utils/coroutine.py +24 -0
  199. pixeltable/utils/dbms.py +92 -0
  200. pixeltable/utils/description_helper.py +11 -12
  201. pixeltable/utils/documents.py +60 -61
  202. pixeltable/utils/exception_handler.py +36 -0
  203. pixeltable/utils/filecache.py +38 -22
  204. pixeltable/utils/formatter.py +88 -51
  205. pixeltable/utils/gcs_store.py +295 -0
  206. pixeltable/utils/http.py +133 -0
  207. pixeltable/utils/http_server.py +14 -13
  208. pixeltable/utils/iceberg.py +13 -0
  209. pixeltable/utils/image.py +17 -0
  210. pixeltable/utils/lancedb.py +90 -0
  211. pixeltable/utils/local_store.py +322 -0
  212. pixeltable/utils/misc.py +5 -0
  213. pixeltable/utils/object_stores.py +573 -0
  214. pixeltable/utils/pydantic.py +60 -0
  215. pixeltable/utils/pytorch.py +20 -20
  216. pixeltable/utils/s3_store.py +527 -0
  217. pixeltable/utils/sql.py +32 -5
  218. pixeltable/utils/system.py +30 -0
  219. pixeltable/utils/transactional_directory.py +4 -3
  220. pixeltable-0.5.7.dist-info/METADATA +579 -0
  221. pixeltable-0.5.7.dist-info/RECORD +227 -0
  222. {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
  223. pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
  224. pixeltable/__version__.py +0 -3
  225. pixeltable/catalog/named_function.py +0 -36
  226. pixeltable/catalog/path_dict.py +0 -141
  227. pixeltable/dataframe.py +0 -894
  228. pixeltable/exec/expr_eval_node.py +0 -232
  229. pixeltable/ext/__init__.py +0 -14
  230. pixeltable/ext/functions/__init__.py +0 -8
  231. pixeltable/ext/functions/whisperx.py +0 -77
  232. pixeltable/ext/functions/yolox.py +0 -157
  233. pixeltable/tool/create_test_db_dump.py +0 -311
  234. pixeltable/tool/create_test_video.py +0 -81
  235. pixeltable/tool/doc_plugins/griffe.py +0 -50
  236. pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
  237. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
  238. pixeltable/tool/embed_udf.py +0 -9
  239. pixeltable/tool/mypy_plugin.py +0 -55
  240. pixeltable/utils/media_store.py +0 -76
  241. pixeltable/utils/s3.py +0 -16
  242. pixeltable-0.2.26.dist-info/METADATA +0 -400
  243. pixeltable-0.2.26.dist-info/RECORD +0 -156
  244. pixeltable-0.2.26.dist-info/entry_points.txt +0 -3
  245. {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
pixeltable/dataframe.py DELETED
@@ -1,894 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import builtins
4
- import copy
5
- import dataclasses
6
- import hashlib
7
- import json
8
- import logging
9
- import traceback
10
- from pathlib import Path
11
- from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, Optional, Sequence, Union, Literal
12
-
13
- import pandas as pd
14
- import pandas.io.formats.style
15
- import sqlalchemy as sql
16
-
17
- import pixeltable.catalog as catalog
18
- import pixeltable.exceptions as excs
19
- import pixeltable.exprs as exprs
20
- import pixeltable.type_system as ts
21
- from pixeltable import exec
22
- from pixeltable import plan
23
- from pixeltable.catalog import is_valid_identifier
24
- from pixeltable.catalog.globals import UpdateStatus
25
- from pixeltable.env import Env
26
- from pixeltable.type_system import ColumnType
27
- from pixeltable.utils.description_helper import DescriptionHelper
28
- from pixeltable.utils.formatter import Formatter
29
-
30
- if TYPE_CHECKING:
31
- import torch
32
-
33
- __all__ = ['DataFrame']
34
-
35
- _logger = logging.getLogger('pixeltable')
36
-
37
-
38
- class DataFrameResultSet:
39
- def __init__(self, rows: list[list[Any]], schema: dict[str, ColumnType]):
40
- self._rows = rows
41
- self._col_names = list(schema.keys())
42
- self.__schema = schema
43
- self.__formatter = Formatter(len(self._rows), len(self._col_names), Env.get().http_address)
44
-
45
- @property
46
- def schema(self) -> dict[str, ColumnType]:
47
- return self.__schema
48
-
49
- def __len__(self) -> int:
50
- return len(self._rows)
51
-
52
- def __repr__(self) -> str:
53
- return self.to_pandas().__repr__()
54
-
55
- def _repr_html_(self) -> str:
56
- formatters: dict[Hashable, Callable[[object], str]] = {}
57
- for col_name, col_type in self.schema.items():
58
- formatter = self.__formatter.get_pandas_formatter(col_type)
59
- if formatter is not None:
60
- formatters[col_name] = formatter
61
- return self.to_pandas().to_html(formatters=formatters, escape=False, index=False)
62
-
63
- def __str__(self) -> str:
64
- return self.to_pandas().to_string()
65
-
66
- def _reverse(self) -> None:
67
- """Reverse order of rows"""
68
- self._rows.reverse()
69
-
70
- def to_pandas(self) -> pd.DataFrame:
71
- return pd.DataFrame.from_records(self._rows, columns=self._col_names)
72
-
73
- def _row_to_dict(self, row_idx: int) -> dict[str, Any]:
74
- return {self._col_names[i]: self._rows[row_idx][i] for i in range(len(self._col_names))}
75
-
76
- def __getitem__(self, index: Any) -> Any:
77
- if isinstance(index, str):
78
- if index not in self._col_names:
79
- raise excs.Error(f'Invalid column name: {index}')
80
- col_idx = self._col_names.index(index)
81
- return [row[col_idx] for row in self._rows]
82
- if isinstance(index, int):
83
- return self._row_to_dict(index)
84
- if isinstance(index, tuple) and len(index) == 2:
85
- if not isinstance(index[0], int) or not (isinstance(index[1], str) or isinstance(index[1], int)):
86
- raise excs.Error(f'Bad index, expected [<row idx>, <column name | column index>]: {index}')
87
- if isinstance(index[1], str) and index[1] not in self._col_names:
88
- raise excs.Error(f'Invalid column name: {index[1]}')
89
- col_idx = self._col_names.index(index[1]) if isinstance(index[1], str) else index[1]
90
- return self._rows[index[0]][col_idx]
91
- raise excs.Error(f'Bad index: {index}')
92
-
93
- def __iter__(self) -> Iterator[dict[str, Any]]:
94
- return (self._row_to_dict(i) for i in range(len(self)))
95
-
96
- def __eq__(self, other):
97
- if not isinstance(other, DataFrameResultSet):
98
- return False
99
- return self.to_pandas().equals(other.to_pandas())
100
-
101
-
102
- # # TODO: remove this; it's only here as a reminder that we still need to call release() in the current implementation
103
- # class AnalysisInfo:
104
- # def __init__(self, tbl: catalog.TableVersion):
105
- # self.tbl = tbl
106
- # # output of the SQL scan stage
107
- # self.sql_scan_output_exprs: list[exprs.Expr] = []
108
- # # output of the agg stage
109
- # self.agg_output_exprs: list[exprs.Expr] = []
110
- # # Where clause of the Select stmt of the SQL scan stage
111
- # self.sql_where_clause: Optional[sql.ClauseElement] = None
112
- # # filter predicate applied to input rows of the SQL scan stage
113
- # self.filter: Optional[exprs.Predicate] = None
114
- # self.similarity_clause: Optional[exprs.ImageSimilarityPredicate] = None
115
- # self.agg_fn_calls: list[exprs.FunctionCall] = [] # derived from unique_exprs
116
- # self.has_frame_col: bool = False # True if we're referencing the frame col
117
- #
118
- # self.evaluator: Optional[exprs.Evaluator] = None
119
- # self.sql_scan_eval_ctx: list[exprs.Expr] = [] # needed to materialize output of SQL scan stage
120
- # self.agg_eval_ctx: list[exprs.Expr] = [] # needed to materialize output of agg stage
121
- # self.filter_eval_ctx: list[exprs.Expr] = []
122
- # self.group_by_eval_ctx: list[exprs.Expr] = []
123
- #
124
- # def finalize_exec(self) -> None:
125
- # """
126
- # Call release() on all collected Exprs.
127
- # """
128
- # exprs.Expr.release_list(self.sql_scan_output_exprs)
129
- # exprs.Expr.release_list(self.agg_output_exprs)
130
- # if self.filter is not None:
131
- # self.filter.release()
132
-
133
-
134
- class DataFrame:
135
- _from_clause: plan.FromClause
136
- _select_list_exprs: list[exprs.Expr]
137
- _schema: dict[str, ts.ColumnType]
138
- select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]]
139
- where_clause: Optional[exprs.Expr]
140
- group_by_clause: Optional[list[exprs.Expr]]
141
- grouping_tbl: Optional[catalog.TableVersion]
142
- order_by_clause: Optional[list[tuple[exprs.Expr, bool]]]
143
- limit_val: Optional[int]
144
-
145
- def __init__(
146
- self,
147
- from_clause: Optional[plan.FromClause] = None,
148
- select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]] = None,
149
- where_clause: Optional[exprs.Expr] = None,
150
- group_by_clause: Optional[list[exprs.Expr]] = None,
151
- grouping_tbl: Optional[catalog.TableVersion] = None,
152
- order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None, # list[(expr, asc)]
153
- limit: Optional[int] = None,
154
- ):
155
- self._from_clause = from_clause
156
-
157
- # exprs contain execution state and therefore cannot be shared
158
- select_list = copy.deepcopy(select_list)
159
- select_list_exprs, column_names = DataFrame._normalize_select_list(self._from_clause.tbls, select_list)
160
- # check select list after expansion to catch early
161
- # the following two lists are always non empty, even if select list is None.
162
- assert len(column_names) == len(select_list_exprs)
163
- self._select_list_exprs = select_list_exprs
164
- self._schema = {column_names[i]: select_list_exprs[i].col_type for i in range(len(column_names))}
165
- self.select_list = select_list
166
-
167
- self.where_clause = copy.deepcopy(where_clause)
168
- assert group_by_clause is None or grouping_tbl is None
169
- self.group_by_clause = copy.deepcopy(group_by_clause)
170
- self.grouping_tbl = grouping_tbl
171
- self.order_by_clause = copy.deepcopy(order_by_clause)
172
- self.limit_val = limit
173
-
174
- @classmethod
175
- def _normalize_select_list(
176
- cls,
177
- tbls: list[catalog.TableVersionPath],
178
- select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]],
179
- ) -> tuple[list[exprs.Expr], list[str]]:
180
- """
181
- Expand select list information with all columns and their names
182
- Returns:
183
- a pair composed of the list of expressions and the list of corresponding names
184
- """
185
- if select_list is None:
186
- select_list = [(exprs.ColumnRef(col), None) for tbl in tbls for col in tbl.columns()]
187
-
188
- out_exprs: list[exprs.Expr] = []
189
- out_names: list[str] = [] # keep track of order
190
- seen_out_names: set[str] = set() # use to check for duplicates in loop, avoid square complexity
191
- for i, (expr, name) in enumerate(select_list):
192
- if name is None:
193
- # use default, add suffix if needed so default adds no duplicates
194
- default_name = expr.default_column_name()
195
- if default_name is not None:
196
- column_name = default_name
197
- if default_name in seen_out_names:
198
- # already used, then add suffix until unique name is found
199
- for j in range(1, len(out_names) + 1):
200
- column_name = f'{default_name}_{j}'
201
- if column_name not in seen_out_names:
202
- break
203
- else: # no default name, eg some expressions
204
- column_name = f'col_{i}'
205
- else: # user provided name, no attempt to rename
206
- column_name = name
207
-
208
- out_exprs.append(expr)
209
- out_names.append(column_name)
210
- seen_out_names.add(column_name)
211
- assert len(out_exprs) == len(out_names)
212
- assert set(out_names) == seen_out_names
213
- return out_exprs, out_names
214
-
215
- @property
216
- def _first_tbl(self) -> catalog.TableVersionPath:
217
- assert len(self._from_clause.tbls) == 1
218
- return self._from_clause.tbls[0]
219
-
220
- def _vars(self) -> dict[str, exprs.Variable]:
221
- """
222
- Return a dict mapping variable name to Variable for all Variables contained in any component of the DataFrame
223
- """
224
- all_exprs: list[exprs.Expr] = []
225
- all_exprs.extend(self._select_list_exprs)
226
- if self.where_clause is not None:
227
- all_exprs.append(self.where_clause)
228
- if self.group_by_clause is not None:
229
- all_exprs.extend(self.group_by_clause)
230
- if self.order_by_clause is not None:
231
- all_exprs.extend([expr for expr, _ in self.order_by_clause])
232
- vars = exprs.Expr.list_subexprs(all_exprs, expr_class=exprs.Variable)
233
- unique_vars: dict[str, exprs.Variable] = {}
234
- for var in vars:
235
- if var.name not in unique_vars:
236
- unique_vars[var.name] = var
237
- else:
238
- if unique_vars[var.name].col_type != var.col_type:
239
- raise excs.Error(f'Multiple definitions of parameter {var.name}')
240
- return unique_vars
241
-
242
- def parameters(self) -> dict[str, ColumnType]:
243
- """Return a dict mapping parameter name to parameter type.
244
-
245
- Parameters are Variables contained in any component of the DataFrame.
246
- """
247
- vars = self._vars()
248
- return {name: var.col_type for name, var in vars.items()}
249
-
250
- def _exec(self, conn: Optional[sql.engine.Connection] = None) -> Iterator[exprs.DataRow]:
251
- """Run the query and return rows as a generator.
252
- This function must not modify the state of the DataFrame, otherwise it breaks dataset caching.
253
- """
254
- plan = self._create_query_plan()
255
-
256
- def exec_plan(conn: sql.engine.Connection) -> Iterator[exprs.DataRow]:
257
- plan.ctx.set_conn(conn)
258
- plan.open()
259
- try:
260
- for row_batch in plan:
261
- yield from row_batch
262
- finally:
263
- plan.close()
264
-
265
- if conn is None:
266
- with Env.get().engine.begin() as conn:
267
- yield from exec_plan(conn)
268
- else:
269
- yield from exec_plan(conn)
270
-
271
- def _create_query_plan(self) -> exec.ExecNode:
272
- # construct a group-by clause if we're grouping by a table
273
- group_by_clause: Optional[list[exprs.Expr]] = None
274
- if self.grouping_tbl is not None:
275
- assert self.group_by_clause is None
276
- num_rowid_cols = len(self.grouping_tbl.store_tbl.rowid_columns())
277
- # the grouping table must be a base of self.tbl
278
- assert num_rowid_cols <= len(self._first_tbl.tbl_version.store_tbl.rowid_columns())
279
- group_by_clause = [exprs.RowidRef(self._first_tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
280
- elif self.group_by_clause is not None:
281
- group_by_clause = self.group_by_clause
282
-
283
- for item in self._select_list_exprs:
284
- item.bind_rel_paths(None)
285
-
286
- return plan.Planner.create_query_plan(
287
- self._from_clause,
288
- self._select_list_exprs,
289
- where_clause=self.where_clause,
290
- group_by_clause=group_by_clause,
291
- order_by_clause=self.order_by_clause if self.order_by_clause is not None else [],
292
- limit=self.limit_val
293
- )
294
-
295
- def _has_joins(self) -> bool:
296
- return len(self._from_clause.join_clauses) > 0
297
-
298
- def show(self, n: int = 20) -> DataFrameResultSet:
299
- assert n is not None
300
- return self.limit(n).collect()
301
-
302
- def head(self, n: int = 10) -> DataFrameResultSet:
303
- if self.order_by_clause is not None:
304
- raise excs.Error(f'head() cannot be used with order_by()')
305
- if self._has_joins():
306
- raise excs.Error(f'head() not supported for joins')
307
- num_rowid_cols = len(self._first_tbl.tbl_version.store_tbl.rowid_columns())
308
- order_by_clause = [exprs.RowidRef(self._first_tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
309
- return self.order_by(*order_by_clause, asc=True).limit(n).collect()
310
-
311
- def tail(self, n: int = 10) -> DataFrameResultSet:
312
- if self.order_by_clause is not None:
313
- raise excs.Error(f'tail() cannot be used with order_by()')
314
- if self._has_joins():
315
- raise excs.Error(f'tail() not supported for joins')
316
- num_rowid_cols = len(self._first_tbl.tbl_version.store_tbl.rowid_columns())
317
- order_by_clause = [exprs.RowidRef(self._first_tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
318
- result = self.order_by(*order_by_clause, asc=False).limit(n).collect()
319
- result._reverse()
320
- return result
321
-
322
- @property
323
- def schema(self) -> dict[str, ColumnType]:
324
- return self._schema
325
-
326
- def bind(self, args: dict[str, Any]) -> DataFrame:
327
- """Bind arguments to parameters and return a new DataFrame."""
328
- # substitute Variables with the corresponding values according to 'args', converted to Literals
329
- select_list_exprs = copy.deepcopy(self._select_list_exprs)
330
- where_clause = copy.deepcopy(self.where_clause)
331
- group_by_clause = copy.deepcopy(self.group_by_clause)
332
- order_by_exprs = [copy.deepcopy(order_by_expr) for order_by_expr, _ in self.order_by_clause] \
333
- if self.order_by_clause is not None else None
334
-
335
- var_exprs: dict[exprs.Expr, exprs.Expr] = {}
336
- vars = self._vars()
337
- for arg_name, arg_val in args.items():
338
- if arg_name not in vars:
339
- # ignore unused variables
340
- continue
341
- var_expr = vars[arg_name]
342
- arg_expr = exprs.Expr.from_object(arg_val)
343
- if arg_expr is None:
344
- raise excs.Error(f'Cannot convert argument {arg_val} to a Pixeltable expression')
345
- var_exprs[var_expr] = arg_expr
346
-
347
- exprs.Expr.list_substitute(select_list_exprs, var_exprs)
348
- if where_clause is not None:
349
- where_clause.substitute(var_exprs)
350
- if group_by_clause is not None:
351
- exprs.Expr.list_substitute(group_by_clause, var_exprs)
352
- if order_by_exprs is not None:
353
- exprs.Expr.list_substitute(order_by_exprs, var_exprs)
354
-
355
- select_list = list(zip(select_list_exprs, self.schema.keys()))
356
- order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None
357
- if order_by_exprs is not None:
358
- order_by_clause = [
359
- (expr, asc) for expr, asc in zip(order_by_exprs, [asc for _, asc in self.order_by_clause])
360
- ]
361
-
362
- return DataFrame(
363
- from_clause=self._from_clause, select_list=select_list, where_clause=where_clause,
364
- group_by_clause=group_by_clause, grouping_tbl=self.grouping_tbl,
365
- order_by_clause=order_by_clause, limit=self.limit_val)
366
-
367
- def _output_row_iterator(self, conn: Optional[sql.engine.Connection] = None) -> Iterator[list]:
368
- try:
369
- for data_row in self._exec(conn):
370
- yield [data_row[e.slot_idx] for e in self._select_list_exprs]
371
- except excs.ExprEvalError as e:
372
- msg = f'In row {e.row_num} the {e.expr_msg} encountered exception ' f'{type(e.exc).__name__}:\n{str(e.exc)}'
373
- if len(e.input_vals) > 0:
374
- input_msgs = [
375
- f"'{d}' = {d.col_type.print_value(e.input_vals[i])}" for i, d in enumerate(e.expr.dependencies())
376
- ]
377
- msg += f'\nwith {", ".join(input_msgs)}'
378
- assert e.exc_tb is not None
379
- stack_trace = traceback.format_tb(e.exc_tb)
380
- if len(stack_trace) > 2:
381
- # append a stack trace if the exception happened in user code
382
- # (frame 0 is ExprEvaluator and frame 1 is some expr's eval()
383
- nl = '\n'
384
- # [-1:0:-1]: leave out entry 0 and reverse order, so that the most recent frame is at the top
385
- msg += f'\nStack:\n{nl.join(stack_trace[-1:1:-1])}'
386
- raise excs.Error(msg)
387
- except sql.exc.DBAPIError as e:
388
- raise excs.Error(f'Error during SQL execution:\n{e}')
389
-
390
- def collect(self) -> DataFrameResultSet:
391
- return self._collect()
392
-
393
- def _collect(self, conn: Optional[sql.engine.Connection] = None) -> DataFrameResultSet:
394
- return DataFrameResultSet(list(self._output_row_iterator(conn)), self.schema)
395
-
396
- def count(self) -> int:
397
- from pixeltable.plan import Planner
398
-
399
- stmt = Planner.create_count_stmt(self._first_tbl, self.where_clause)
400
- with Env.get().engine.connect() as conn:
401
- result: int = conn.execute(stmt).scalar_one()
402
- assert isinstance(result, int)
403
- return result
404
-
405
- def _descriptors(self) -> DescriptionHelper:
406
- helper = DescriptionHelper()
407
- helper.append(self._col_descriptor())
408
- qd = self._query_descriptor()
409
- if not qd.empty:
410
- helper.append(qd, show_index=True, show_header=False)
411
- return helper
412
-
413
- def _col_descriptor(self) -> pd.DataFrame:
414
- return pd.DataFrame([
415
- {
416
- 'Name': name,
417
- 'Type': expr.col_type._to_str(as_schema=True),
418
- 'Expression': expr.display_str(inline=False),
419
- }
420
- for name, expr in zip(self.schema.keys(), self._select_list_exprs)
421
- ])
422
-
423
- def _query_descriptor(self) -> pd.DataFrame:
424
- heading_vals: list[str] = []
425
- info_vals: list[str] = []
426
- heading_vals.append('From')
427
- info_vals.extend(tbl.tbl_name() for tbl in self._from_clause.tbls)
428
- if self.where_clause is not None:
429
- heading_vals.append('Where')
430
- info_vals.append(self.where_clause.display_str(inline=False))
431
- if self.group_by_clause is not None:
432
- heading_vals.append('Group By')
433
- heading_vals.extend([''] * (len(self.group_by_clause) - 1))
434
- info_vals.extend(e.display_str(inline=False) for e in self.group_by_clause)
435
- if self.order_by_clause is not None:
436
- heading_vals.append('Order By')
437
- heading_vals.extend([''] * (len(self.order_by_clause) - 1))
438
- info_vals.extend(
439
- [f'{e[0].display_str(inline=False)} {"asc" if e[1] else "desc"}' for e in self.order_by_clause]
440
- )
441
- if self.limit_val is not None:
442
- heading_vals.append('Limit')
443
- info_vals.append(str(self.limit_val))
444
- assert len(heading_vals) == len(info_vals)
445
- return pd.DataFrame(info_vals, index=heading_vals)
446
-
447
- def describe(self) -> None:
448
- """
449
- Prints a tabular description of this DataFrame.
450
- The description has two columns, heading and info, which list the contents of each 'component'
451
- (select list, where clause, ...) vertically.
452
- """
453
- if getattr(builtins, '__IPYTHON__', False):
454
- from IPython.display import display
455
- display(self._repr_html_())
456
- else:
457
- print(repr(self))
458
-
459
- def __repr__(self) -> str:
460
- return self._descriptors().to_string()
461
-
462
- def _repr_html_(self) -> str:
463
- return self._descriptors().to_html()
464
-
465
- def select(self, *items: Any, **named_items: Any) -> DataFrame:
466
- if self.select_list is not None:
467
- raise excs.Error(f'Select list already specified')
468
- for name, _ in named_items.items():
469
- if not isinstance(name, str) or not is_valid_identifier(name):
470
- raise excs.Error(f'Invalid name: {name}')
471
- base_list = [(expr, None) for expr in items] + [(expr, k) for (k, expr) in named_items.items()]
472
- if len(base_list) == 0:
473
- return self
474
-
475
- # analyze select list; wrap literals with the corresponding expressions
476
- select_list: list[tuple[exprs.Expr, Optional[str]]] = []
477
- for raw_expr, name in base_list:
478
- if isinstance(raw_expr, exprs.Expr):
479
- select_list.append((raw_expr, name))
480
- elif isinstance(raw_expr, dict):
481
- select_list.append((exprs.InlineDict(raw_expr), name))
482
- elif isinstance(raw_expr, list):
483
- select_list.append((exprs.InlineList(raw_expr), name))
484
- else:
485
- select_list.append((exprs.Literal(raw_expr), name))
486
- expr = select_list[-1][0]
487
- if expr.col_type.is_invalid_type():
488
- raise excs.Error(f'Invalid type: {raw_expr}')
489
- if not expr.is_bound_by(self._from_clause.tbls):
490
- raise excs.Error(
491
- f"Expression '{expr}' cannot be evaluated in the context of this query's tables "
492
- f"({','.join(tbl.tbl_name() for tbl in self._from_clause.tbls)})")
493
-
494
- # check user provided names do not conflict among themselves or with auto-generated ones
495
- seen: set[str] = set()
496
- _, names = DataFrame._normalize_select_list(self._from_clause.tbls, select_list)
497
- for name in names:
498
- if name in seen:
499
- repeated_names = [j for j, x in enumerate(names) if x == name]
500
- pretty = ', '.join(map(str, repeated_names))
501
- raise excs.Error(f'Repeated column name "{name}" in select() at positions: {pretty}')
502
- seen.add(name)
503
-
504
- return DataFrame(
505
- from_clause=self._from_clause,
506
- select_list=select_list,
507
- where_clause=self.where_clause,
508
- group_by_clause=self.group_by_clause,
509
- grouping_tbl=self.grouping_tbl,
510
- order_by_clause=self.order_by_clause,
511
- limit=self.limit_val,
512
- )
513
-
514
- def where(self, pred: exprs.Expr) -> DataFrame:
515
- if not isinstance(pred, exprs.Expr):
516
- raise excs.Error(f'Where() requires a Pixeltable expression, but instead got {type(pred)}')
517
- if not pred.col_type.is_bool_type():
518
- raise excs.Error(f'Where(): expression needs to return bool, but instead returns {pred.col_type}')
519
- return DataFrame(
520
- from_clause=self._from_clause,
521
- select_list=self.select_list,
522
- where_clause=pred,
523
- group_by_clause=self.group_by_clause,
524
- grouping_tbl=self.grouping_tbl,
525
- order_by_clause=self.order_by_clause,
526
- limit=self.limit_val,
527
- )
528
-
529
- def _create_join_predicate(
530
- self, other: catalog.TableVersionPath, on: Union[exprs.Expr, Sequence[exprs.ColumnRef]]
531
- ) -> exprs.Expr:
532
- """Verifies user-specified 'on' argument and converts it into a join predicate."""
533
- col_refs: list[exprs.ColumnRef] = []
534
- joined_tbls = self._from_clause.tbls + [other]
535
-
536
- if isinstance(on, exprs.ColumnRef):
537
- on = [on]
538
- elif isinstance(on, exprs.Expr):
539
- if not on.is_bound_by(joined_tbls):
540
- raise excs.Error(f"'on': expression cannot be evaluated in the context of the joined tables: {on}")
541
- if not on.col_type.is_bool_type():
542
- raise excs.Error(f"'on': boolean expression expected, but got {on.col_type}: {on}")
543
- return on
544
- else:
545
- if not isinstance(on, Sequence) or len(on) == 0:
546
- raise excs.Error(
547
- f"'on': must be a sequence of column references or a boolean expression")
548
-
549
- assert isinstance(on, Sequence)
550
- for col_ref in on:
551
- if not isinstance(col_ref, exprs.ColumnRef):
552
- raise excs.Error(
553
- f"'on': must be a sequence of column references or a boolean expression")
554
- if not col_ref.is_bound_by(joined_tbls):
555
- raise excs.Error(f"'on': expression cannot be evaluated in the context of the joined tables: {col_ref}")
556
- col_refs.append(col_ref)
557
-
558
- predicates: list[exprs.Expr] = []
559
- # try to turn ColumnRefs into equality predicates
560
- assert len(col_refs) > 0 and len(joined_tbls) >= 2
561
- for col_ref in col_refs:
562
- # identify the referenced column by name in 'other'
563
- rhs_col = other.get_column(col_ref.col.name, include_bases=True)
564
- if rhs_col is None:
565
- raise excs.Error(f"'on': column {col_ref.col.name!r} not found in joined table")
566
- rhs_col_ref = exprs.ColumnRef(rhs_col)
567
-
568
- lhs_col_ref: Optional[exprs.ColumnRef] = None
569
- if any(tbl.has_column(col_ref.col, include_bases=True) for tbl in self._from_clause.tbls):
570
- # col_ref comes from the existing from_clause, we use that directly
571
- lhs_col_ref = col_ref
572
- else:
573
- # col_ref comes from other, we need to look for a match in the existing from_clause by name
574
- for tbl in self._from_clause.tbls:
575
- col = tbl.get_column(col_ref.col.name, include_bases=True)
576
- if col is None:
577
- continue
578
- if lhs_col_ref is not None:
579
- raise excs.Error(f"'on': ambiguous column reference: {col_ref.col.name!r}")
580
- lhs_col_ref = exprs.ColumnRef(col)
581
- if lhs_col_ref is None:
582
- tbl_names = [tbl.tbl_name() for tbl in self._from_clause.tbls]
583
- raise excs.Error(
584
- f"'on': column {col_ref.col.name!r} not found in any of: {' '.join(tbl_names)}")
585
- pred = exprs.Comparison(exprs.ComparisonOperator.EQ, lhs_col_ref, rhs_col_ref)
586
- predicates.append(pred)
587
-
588
- assert len(predicates) > 0
589
- if len(predicates) == 1:
590
- return predicates[0]
591
- else:
592
- return exprs.CompoundPredicate(operator=exprs.LogicalOperator.AND, operands=predicates)
593
-
594
- def join(
595
- self, other: catalog.Table, on: Optional[Union[exprs.Expr, Sequence[exprs.ColumnRef]]] = None,
596
- how: plan.JoinType.LiteralType = 'inner'
597
- ) -> DataFrame:
598
- """
599
- Join this DataFrame with a table.
600
-
601
- Args:
602
- other: the table to join with
603
- on: the join condition, which can be either a) references to one or more columns or b) a boolean
604
- expression.
605
-
606
- - column references: implies an equality predicate that matches columns in both this
607
- DataFrame and `other` by name.
608
-
609
- - column in `other`: A column with that same name must be present in this DataFrame, and **it must
610
- be unique** (otherwise the join is ambiguous).
611
- - column in this DataFrame: A column with that same name must be present in `other`.
612
-
613
- - boolean expression: The expressions must be valid in the context of the joined tables.
614
- how: the type of join to perform.
615
-
616
- - `'inner'`: only keep rows that have a match in both
617
- - `'left'`: keep all rows from this DataFrame and only matching rows from the other table
618
- - `'right'`: keep all rows from the other table and only matching rows from this DataFrame
619
- - `'full_outer'`: keep all rows from both this DataFrame and the other table
620
- - `'cross'`: Cartesian product; no `on` condition allowed
621
-
622
- Returns:
623
- A new DataFrame.
624
-
625
- Examples:
626
- Perform an inner join between t1 and t2 on the column id:
627
-
628
- >>> join1 = t1.join(t2, on=t2.id)
629
-
630
- Perform a left outer join of join1 with t3, also on id (note that we can't specify `on=t3.id` here,
631
- because that would be ambiguous, since both t1 and t2 have a column named id):
632
-
633
- >>> join2 = join1.join(t3, on=t2.id, how='left')
634
-
635
- Do the same, but now with an explicit join predicate:
636
-
637
- >>> join2 = join1.join(t3, on=t2.id == t3.id, how='left')
638
-
639
- Join t with d, which has a composite primary key (columns pk1 and pk2, with corresponding foreign
640
- key columns d1 and d2 in t):
641
-
642
- >>> df = t.join(d, on=(t.d1 == d.pk1) & (t.d2 == d.pk2), how='left')
643
- """
644
- join_pred: Optional[exprs.Expr]
645
- if how == 'cross':
646
- if on is not None:
647
- raise excs.Error(f"'on' not allowed for cross join")
648
- join_pred = None
649
- else:
650
- if on is None:
651
- raise excs.Error(f"how={how!r} requires 'on'")
652
- join_pred = self._create_join_predicate(other._tbl_version_path, on)
653
- join_clause = plan.JoinClause(join_type=plan.JoinType.validated(how, "'how'"), join_predicate=join_pred)
654
- from_clause = plan.FromClause(
655
- tbls=[*self._from_clause.tbls, other._tbl_version_path],
656
- join_clauses=[*self._from_clause.join_clauses, join_clause])
657
- return DataFrame(
658
- from_clause=from_clause,
659
- select_list=self.select_list, where_clause=self.where_clause,
660
- group_by_clause=self.group_by_clause, grouping_tbl=self.grouping_tbl,
661
- order_by_clause=self.order_by_clause, limit=self.limit_val,
662
- )
663
-
664
- def group_by(self, *grouping_items: Any) -> DataFrame:
665
- """
666
- Add a group-by clause to this DataFrame.
667
- Variants:
668
- - group_by(<base table>): group a component view by their respective base table rows
669
- - group_by(<expr>, ...): group by the given expressions
670
- """
671
- if self.group_by_clause is not None:
672
- raise excs.Error(f'Group-by already specified')
673
- grouping_tbl: Optional[catalog.TableVersion] = None
674
- group_by_clause: Optional[list[exprs.Expr]] = None
675
- for item in grouping_items:
676
- if isinstance(item, catalog.Table):
677
- if len(grouping_items) > 1:
678
- raise excs.Error(f'group_by(): only one table can be specified')
679
- if len(self._from_clause.tbls) > 1:
680
- raise excs.Error(f'group_by() with Table not supported for joins')
681
- # we need to make sure that the grouping table is a base of self.tbl
682
- base = self._first_tbl.find_tbl_version(item._tbl_version_path.tbl_id())
683
- if base is None or base.id == self._first_tbl.tbl_id():
684
- raise excs.Error(f'group_by(): {item._name} is not a base table of {self._first_tbl.tbl_name()}')
685
- grouping_tbl = item._tbl_version_path.tbl_version
686
- break
687
- if not isinstance(item, exprs.Expr):
688
- raise excs.Error(f'Invalid expression in group_by(): {item}')
689
- if grouping_tbl is None:
690
- group_by_clause = list(grouping_items)
691
- return DataFrame(
692
- from_clause=self._from_clause,
693
- select_list=self.select_list,
694
- where_clause=self.where_clause,
695
- group_by_clause=group_by_clause,
696
- grouping_tbl=grouping_tbl,
697
- order_by_clause=self.order_by_clause,
698
- limit=self.limit_val,
699
- )
700
-
701
- def order_by(self, *expr_list: exprs.Expr, asc: bool = True) -> DataFrame:
702
- for e in expr_list:
703
- if not isinstance(e, exprs.Expr):
704
- raise excs.Error(f'Invalid expression in order_by(): {e}')
705
- order_by_clause = self.order_by_clause if self.order_by_clause is not None else []
706
- order_by_clause.extend([(e.copy(), asc) for e in expr_list])
707
- return DataFrame(
708
- from_clause=self._from_clause,
709
- select_list=self.select_list,
710
- where_clause=self.where_clause,
711
- group_by_clause=self.group_by_clause,
712
- grouping_tbl=self.grouping_tbl,
713
- order_by_clause=order_by_clause,
714
- limit=self.limit_val,
715
- )
716
-
717
- def limit(self, n: int) -> DataFrame:
718
- # TODO: allow n to be a Variable that can be substituted in bind()
719
- assert n is not None and isinstance(n, int)
720
- return DataFrame(
721
- from_clause=self._from_clause,
722
- select_list=self.select_list,
723
- where_clause=self.where_clause,
724
- group_by_clause=self.group_by_clause,
725
- grouping_tbl=self.grouping_tbl,
726
- order_by_clause=self.order_by_clause,
727
- limit=n,
728
- )
729
-
730
- def update(self, value_spec: dict[str, Any], cascade: bool = True) -> UpdateStatus:
731
- self._validate_mutable('update')
732
- return self._first_tbl.tbl_version.update(value_spec, where=self.where_clause, cascade=cascade)
733
-
734
- def delete(self) -> UpdateStatus:
735
- self._validate_mutable('delete')
736
- if not self._first_tbl.is_insertable():
737
- raise excs.Error(f'Cannot delete from view')
738
- return self._first_tbl.tbl_version.delete(where=self.where_clause)
739
-
740
- def _validate_mutable(self, op_name: str) -> None:
741
- """Tests whether this `DataFrame` can be mutated (such as by an update operation)."""
742
- if self.group_by_clause is not None or self.grouping_tbl is not None:
743
- raise excs.Error(f'Cannot use `{op_name}` after `group_by`')
744
- if self.order_by_clause is not None:
745
- raise excs.Error(f'Cannot use `{op_name}` after `order_by`')
746
- if self.select_list is not None:
747
- raise excs.Error(f'Cannot use `{op_name}` after `select`')
748
- if self.limit_val is not None:
749
- raise excs.Error(f'Cannot use `{op_name}` after `limit`')
750
-
751
- def __getitem__(self, index: Union[exprs.Expr, Sequence[exprs.Expr]]) -> DataFrame:
752
- """
753
- Allowed:
754
- - [list[Expr]]/[tuple[Expr]]: setting the select list
755
- - [Expr]: setting a single-col select list
756
- """
757
- if isinstance(index, exprs.Expr):
758
- return self.select(index)
759
- if isinstance(index, Sequence):
760
- return self.select(*index)
761
- raise TypeError(f'Invalid index type: {type(index)}')
762
-
763
- def as_dict(self) -> dict[str, Any]:
764
- """
765
- Returns:
766
- Dictionary representing this dataframe.
767
- """
768
- d = {
769
- '_classname': 'DataFrame',
770
- 'from_clause': {
771
- 'tbls': [tbl.as_dict() for tbl in self._from_clause.tbls],
772
- 'join_clauses': [dataclasses.asdict(clause) for clause in self._from_clause.join_clauses]
773
- },
774
- 'select_list':
775
- [(e.as_dict(), name) for (e, name) in self.select_list] if self.select_list is not None else None,
776
- 'where_clause': self.where_clause.as_dict() if self.where_clause is not None else None,
777
- 'group_by_clause':
778
- [e.as_dict() for e in self.group_by_clause] if self.group_by_clause is not None else None,
779
- 'grouping_tbl': self.grouping_tbl.as_dict() if self.grouping_tbl is not None else None,
780
- 'order_by_clause':
781
- [(e.as_dict(), asc) for (e,asc) in self.order_by_clause] if self.order_by_clause is not None else None,
782
- 'limit_val': self.limit_val,
783
- }
784
- return d
785
-
786
- @classmethod
787
- def from_dict(cls, d: dict[str, Any]) -> 'DataFrame':
788
- tbls = [catalog.TableVersionPath.from_dict(tbl_dict) for tbl_dict in d['from_clause']['tbls']]
789
- join_clauses = [plan.JoinClause(**clause_dict) for clause_dict in d['from_clause']['join_clauses']]
790
- from_clause = plan.FromClause(tbls=tbls, join_clauses=join_clauses)
791
- select_list = [(exprs.Expr.from_dict(e), name) for e, name in d['select_list']] \
792
- if d['select_list'] is not None else None
793
- where_clause = exprs.Expr.from_dict(d['where_clause']) \
794
- if d['where_clause'] is not None else None
795
- group_by_clause = [exprs.Expr.from_dict(e) for e in d['group_by_clause']] \
796
- if d['group_by_clause'] is not None else None
797
- grouping_tbl = catalog.TableVersion.from_dict(d['grouping_tbl']) \
798
- if d['grouping_tbl'] is not None else None
799
- order_by_clause = [(exprs.Expr.from_dict(e), asc) for e, asc in d['order_by_clause']] \
800
- if d['order_by_clause'] is not None else None
801
- limit_val = d['limit_val']
802
- return DataFrame(
803
- from_clause=from_clause, select_list=select_list, where_clause=where_clause,
804
- group_by_clause=group_by_clause, grouping_tbl=grouping_tbl, order_by_clause=order_by_clause,
805
- limit=limit_val)
806
-
807
- def _hash_result_set(self) -> str:
808
- """Return a hash that changes when the result set changes."""
809
- d = self.as_dict()
810
- # add list of referenced table versions (the actual versions, not the effective ones) in order to force cache
811
- # invalidation when any of the referenced tables changes
812
- d['tbl_versions'] = [
813
- tbl_version.version for tbl in self._from_clause.tbls for tbl_version in tbl.get_tbl_versions()
814
- ]
815
- summary_string = json.dumps(d)
816
- return hashlib.sha256(summary_string.encode()).hexdigest()
817
-
818
- def to_coco_dataset(self) -> Path:
819
- """Convert the dataframe to a COCO dataset.
820
- This dataframe must return a single json-typed output column in the following format:
821
- {
822
- 'image': PIL.Image.Image,
823
- 'annotations': [
824
- {
825
- 'bbox': [x: int, y: int, w: int, h: int],
826
- 'category': str | int,
827
- },
828
- ...
829
- ],
830
- }
831
-
832
- Returns:
833
- Path to the COCO dataset file.
834
- """
835
- from pixeltable.utils.coco import write_coco_dataset
836
-
837
- cache_key = self._hash_result_set()
838
- dest_path = Env.get().dataset_cache_dir / f'coco_{cache_key}'
839
- if dest_path.exists():
840
- assert dest_path.is_dir()
841
- data_file_path = dest_path / 'data.json'
842
- assert data_file_path.exists()
843
- assert data_file_path.is_file()
844
- return data_file_path
845
- else:
846
- return write_coco_dataset(self, dest_path)
847
-
848
- # TODO Factor this out into a separate module.
849
- # The return type is unresolvable, but torch can't be imported since it's an optional dependency.
850
- def to_pytorch_dataset(self, image_format: str = 'pt') -> 'torch.utils.data.IterableDataset':
851
- """
852
- Convert the dataframe to a pytorch IterableDataset suitable for parallel loading
853
- with torch.utils.data.DataLoader.
854
-
855
- This method requires pyarrow >= 13, torch and torchvision to work.
856
-
857
- This method serializes data so it can be read from disk efficiently and repeatedly without
858
- re-executing the query. This data is cached to disk for future re-use.
859
-
860
- Args:
861
- image_format: format of the images. Can be 'pt' (pytorch tensor) or 'np' (numpy array).
862
- 'np' means image columns return as an RGB uint8 array of shape HxWxC.
863
- 'pt' means image columns return as a CxHxW tensor with values in [0,1] and type torch.float32.
864
- (the format output by torchvision.transforms.ToTensor())
865
-
866
- Returns:
867
- A pytorch IterableDataset: Columns become fields of the dataset, where rows are returned as a dictionary
868
- compatible with torch.utils.data.DataLoader default collation.
869
-
870
- Constraints:
871
- The default collate_fn for torch.data.util.DataLoader cannot represent null values as part of a
872
- pytorch tensor when forming batches. These values will raise an exception while running the dataloader.
873
-
874
- If you have them, you can work around None values by providing your custom collate_fn to the DataLoader
875
- (and have your model handle it). Or, if these are not meaningful values within a minibtach, you can
876
- modify or remove any such values through selections and filters prior to calling to_pytorch_dataset().
877
- """
878
- # check dependencies
879
- Env.get().require_package('pyarrow', [13])
880
- Env.get().require_package('torch')
881
- Env.get().require_package('torchvision')
882
-
883
- from pixeltable.io import export_parquet
884
- from pixeltable.utils.pytorch import PixeltablePytorchDataset
885
-
886
- cache_key = self._hash_result_set()
887
-
888
- dest_path = (Env.get().dataset_cache_dir / f'df_{cache_key}').with_suffix('.parquet')
889
- if dest_path.exists(): # fast path: use cache
890
- assert dest_path.is_dir()
891
- else:
892
- export_parquet(self, dest_path, inline_images=True)
893
-
894
- return PixeltablePytorchDataset(path=dest_path, image_format=image_format)
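For orientation, the module deleted above implemented Pixeltable's query-building API (select, where, join, group_by, order_by, limit, collect); the new pixeltable/_query.py in the file list appears to be its successor. The sketch below illustrates, under stated assumptions, how that API is typically driven from table handles. The table and column names (films, ratings, title, year, id, film_id) are hypothetical, and it assumes pxt.get_table() returns handles exposing these query methods.

    # Illustrative sketch only; table/column names are hypothetical.
    import pixeltable as pxt

    films = pxt.get_table('films')
    ratings = pxt.get_table('ratings')

    # select()/where()/order_by()/limit() each return a new DataFrame;
    # collect() executes the plan and returns a DataFrameResultSet.
    recent = (
        films.where(films.year >= 2000)
        .select(films.title, films.year)
        .order_by(films.year, asc=False)
        .limit(10)
        .collect()
    )
    print(recent.to_pandas())

    # Explicit join predicate, as shown in the join() docstring above.
    joined = films.join(ratings, on=films.id == ratings.film_id, how='left')
    print(joined.limit(5).collect())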