pixeltable 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (119)
  1. pixeltable/__init__.py +53 -0
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/__init__.py +13 -0
  4. pixeltable/catalog/catalog.py +159 -0
  5. pixeltable/catalog/column.py +181 -0
  6. pixeltable/catalog/dir.py +32 -0
  7. pixeltable/catalog/globals.py +33 -0
  8. pixeltable/catalog/insertable_table.py +192 -0
  9. pixeltable/catalog/named_function.py +36 -0
  10. pixeltable/catalog/path.py +58 -0
  11. pixeltable/catalog/path_dict.py +139 -0
  12. pixeltable/catalog/schema_object.py +39 -0
  13. pixeltable/catalog/table.py +695 -0
  14. pixeltable/catalog/table_version.py +1026 -0
  15. pixeltable/catalog/table_version_path.py +133 -0
  16. pixeltable/catalog/view.py +203 -0
  17. pixeltable/dataframe.py +749 -0
  18. pixeltable/env.py +466 -0
  19. pixeltable/exceptions.py +17 -0
  20. pixeltable/exec/__init__.py +10 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +116 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +94 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +73 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +226 -0
  31. pixeltable/exprs/__init__.py +25 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +114 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +199 -0
  39. pixeltable/exprs/expr.py +594 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +382 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +96 -0
  44. pixeltable/exprs/in_predicate.py +96 -0
  45. pixeltable/exprs/inline_array.py +109 -0
  46. pixeltable/exprs/inline_dict.py +103 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +66 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +329 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/similarity_expr.py +65 -0
  56. pixeltable/exprs/type_cast.py +53 -0
  57. pixeltable/exprs/variable.py +45 -0
  58. pixeltable/ext/__init__.py +5 -0
  59. pixeltable/ext/functions/yolox.py +92 -0
  60. pixeltable/func/__init__.py +7 -0
  61. pixeltable/func/aggregate_function.py +197 -0
  62. pixeltable/func/callable_function.py +113 -0
  63. pixeltable/func/expr_template_function.py +99 -0
  64. pixeltable/func/function.py +141 -0
  65. pixeltable/func/function_registry.py +227 -0
  66. pixeltable/func/globals.py +46 -0
  67. pixeltable/func/nos_function.py +202 -0
  68. pixeltable/func/signature.py +162 -0
  69. pixeltable/func/udf.py +164 -0
  70. pixeltable/functions/__init__.py +95 -0
  71. pixeltable/functions/eval.py +215 -0
  72. pixeltable/functions/fireworks.py +34 -0
  73. pixeltable/functions/huggingface.py +167 -0
  74. pixeltable/functions/image.py +16 -0
  75. pixeltable/functions/openai.py +289 -0
  76. pixeltable/functions/pil/image.py +147 -0
  77. pixeltable/functions/string.py +13 -0
  78. pixeltable/functions/together.py +143 -0
  79. pixeltable/functions/util.py +52 -0
  80. pixeltable/functions/video.py +62 -0
  81. pixeltable/globals.py +425 -0
  82. pixeltable/index/__init__.py +2 -0
  83. pixeltable/index/base.py +51 -0
  84. pixeltable/index/embedding_index.py +168 -0
  85. pixeltable/io/__init__.py +3 -0
  86. pixeltable/io/hf_datasets.py +188 -0
  87. pixeltable/io/pandas.py +148 -0
  88. pixeltable/io/parquet.py +192 -0
  89. pixeltable/iterators/__init__.py +3 -0
  90. pixeltable/iterators/base.py +52 -0
  91. pixeltable/iterators/document.py +432 -0
  92. pixeltable/iterators/video.py +88 -0
  93. pixeltable/metadata/__init__.py +58 -0
  94. pixeltable/metadata/converters/convert_10.py +18 -0
  95. pixeltable/metadata/converters/convert_12.py +3 -0
  96. pixeltable/metadata/converters/convert_13.py +41 -0
  97. pixeltable/metadata/schema.py +234 -0
  98. pixeltable/plan.py +620 -0
  99. pixeltable/store.py +424 -0
  100. pixeltable/tool/create_test_db_dump.py +184 -0
  101. pixeltable/tool/create_test_video.py +81 -0
  102. pixeltable/type_system.py +846 -0
  103. pixeltable/utils/__init__.py +17 -0
  104. pixeltable/utils/arrow.py +98 -0
  105. pixeltable/utils/clip.py +18 -0
  106. pixeltable/utils/coco.py +136 -0
  107. pixeltable/utils/documents.py +69 -0
  108. pixeltable/utils/filecache.py +195 -0
  109. pixeltable/utils/help.py +11 -0
  110. pixeltable/utils/http_server.py +70 -0
  111. pixeltable/utils/media_store.py +76 -0
  112. pixeltable/utils/pytorch.py +91 -0
  113. pixeltable/utils/s3.py +13 -0
  114. pixeltable/utils/sql.py +17 -0
  115. pixeltable/utils/transactional_directory.py +35 -0
  116. pixeltable-0.0.0.dist-info/LICENSE +18 -0
  117. pixeltable-0.0.0.dist-info/METADATA +131 -0
  118. pixeltable-0.0.0.dist-info/RECORD +119 -0
  119. pixeltable-0.0.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,749 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import copy
5
+ import hashlib
6
+ import io
7
+ import json
8
+ import logging
9
+ import mimetypes
10
+ import traceback
11
+ from pathlib import Path
12
+ from typing import List, Optional, Any, Dict, Generator, Tuple, Set
13
+
14
+ import PIL.Image
15
+ import cv2
16
+ import pandas as pd
17
+ import pandas.io.formats.style
18
+ import sqlalchemy as sql
19
+ from PIL import Image
20
+
21
+ import pixeltable.catalog as catalog
22
+ import pixeltable.exceptions as excs
23
+ import pixeltable.exprs as exprs
24
+ import pixeltable.type_system as ts
25
+ from pixeltable.catalog import is_valid_identifier
26
+ from pixeltable.env import Env
27
+ from pixeltable.plan import Planner
28
+ from pixeltable.type_system import ColumnType
29
+ from pixeltable.utils.http_server import get_file_uri
30
+
31
# Public API of this module for `from ... import *`: only DataFrame is exported; result sets are obtained
# through DataFrame methods such as collect().
__all__ = ['DataFrame']

# Module logger, registered under the 'pixeltable' logger name.
_logger = logging.getLogger('pixeltable')
34
+
35
+
36
def _create_source_tag(file_path: str) -> str:
    """Return an HTML `<source>` element whose src is the HTTP URI served for `file_path`.

    A `type` attribute is emitted only when a MIME type can be guessed from that URI.
    """
    uri = get_file_uri(Env.get().http_address, file_path)
    guessed_type, _encoding = mimetypes.guess_type(uri)
    # without a guess, leave the attribute out entirely: type="None" would not be valid html
    type_attr = '' if guessed_type is None else f'type="{guessed_type}"'
    return f'<source src="{uri}" {type_attr} />'
42
+
43
+
44
class DataFrameResultSet:
    """Materialized result of a DataFrame query.

    Holds the result rows (each row is a list of values in output-column order) together with the output column
    names and types, and renders them via pandas (`__repr__`, `__str__`) or as HTML for notebooks (`_repr_html_`),
    where image/video/audio/document columns get media-specific formatting.
    """

    def __init__(self, rows: List[List[Any]], col_names: List[str], col_types: List[ColumnType]):
        self._rows = rows
        self._col_names = col_names
        self._col_types = col_types
        # HTML formatters for media columns, keyed by ColumnType subclass; consulted by _repr_html_()
        self._formatters = {
            ts.ImageType: self._format_img,
            ts.VideoType: self._format_video,
            ts.AudioType: self._format_audio,
            ts.DocumentType: self._format_document,
        }

    def __len__(self) -> int:
        return len(self._rows)

    def column_names(self) -> List[str]:
        return self._col_names

    def column_types(self) -> List[ColumnType]:
        return self._col_types

    def __repr__(self) -> str:
        return self.to_pandas().__repr__()

    def _repr_html_(self) -> str:
        # only columns whose type has a registered media formatter get custom rendering
        formatters = {
            col_name: self._formatters[col_type.__class__]
            for col_name, col_type in zip(self._col_names, self._col_types)
            if col_type.__class__ in self._formatters
        }
        # escape=False: the formatters emit raw HTML tags
        return self.to_pandas().to_html(formatters=formatters, escape=False, index=False)

    def __str__(self) -> str:
        return self.to_pandas().to_string()

    def _reverse(self) -> None:
        """Reverse order of rows"""
        self._rows.reverse()

    def to_pandas(self) -> pd.DataFrame:
        return pd.DataFrame.from_records(self._rows, columns=self._col_names)

    def _row_to_dict(self, row_idx: int) -> Dict[str, Any]:
        return {self._col_names[i]: self._rows[row_idx][i] for i in range(len(self._col_names))}

    def _media_width(self, multi_row_width: int, multi_col_width: int, single_width: int) -> int:
        """Pick a display width (in px) for a media cell, so output looks decent in various display scenarios.

        Multiple rows get `multi_row_width`, a single row with multiple columns gets `multi_col_width`,
        and a single value gets `single_width`.
        """
        if len(self._rows) > 1:
            return multi_row_width
        if len(self._col_names) > 1:
            return multi_col_width
        return single_width

    # Formatters
    def _format_img(self, img: Image.Image) -> str:
        """
        Create <img> tag for Image object.
        """
        assert isinstance(img, Image.Image), f'Wrong type: {type(img)}'
        # multiple rows: small images; multiple columns: medium images; a single image: larger display
        width = self._media_width(240, 480, 640)
        # JPEG cannot encode modes such as RGBA, LA or P (img.save() would raise), so convert those to RGB first
        if img.mode not in ('RGB', 'L'):
            img = img.convert('RGB')
        with io.BytesIO() as buffer:
            img.save(buffer, 'jpeg')
            img_base64 = base64.b64encode(buffer.getvalue()).decode()
        return f"""
            <div class="pxt_image" style="width:{width}px;">
                <img src="data:image/jpeg;base64,{img_base64}" width="{width}" />
            </div>
            """

    def _format_video(self, file_path: str) -> str:
        """Create a <video> element for a local video file, with its first frame as poster image when readable."""
        thumb_tag = ''
        # Attempt to extract the first frame of the video to use as a thumbnail,
        # so that the notebook can be exported as HTML and viewed in contexts where
        # the video itself is not accessible.
        # TODO(aaron-siegel): If the video is backed by a concrete external URL,
        # should we link to that instead?
        video_reader = cv2.VideoCapture(str(file_path))
        try:
            if video_reader.isOpened():
                status, img_array = video_reader.read()
                if status:
                    img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2RGB)
                    thumb = PIL.Image.fromarray(img_array)
                    with io.BytesIO() as buffer:
                        thumb.save(buffer, 'jpeg')
                        thumb_base64 = base64.b64encode(buffer.getvalue()).decode()
                    thumb_tag = f'poster="data:image/jpeg;base64,{thumb_base64}"'
        finally:
            # release the capture on every path, including when decoding the first frame raises
            video_reader.release()
        width = self._media_width(320, 480, 800)
        return f"""
            <div class="pxt_video" style="width:{width}px;">
                <video controls width="{width}" {thumb_tag}>
                    {_create_source_tag(file_path)}
                </video>
            </div>
            """

    def _format_document(self, file_path: str) -> str:
        """Create a link to a document; for PDFs the link shows a first-page thumbnail when one can be rendered."""
        max_width = max_height = 320
        # by default, file path will be shown as a link
        inner_element = file_path
        # try generating a thumbnail for different types and use that if successful
        if file_path.lower().endswith('.pdf'):
            try:
                import fitz

                # the context manager closes the document once the thumbnail has been rendered
                with fitz.open(file_path) as doc:
                    p = doc.get_page_pixmap(0)
                    while p.width > max_width or p.height > max_height:
                        # shrink(1) will halve each dimension
                        p.shrink(1)
                    data = p.tobytes(output='jpeg')
                thumb_base64 = base64.b64encode(data).decode()
                img_src = f'data:image/jpeg;base64,{thumb_base64}'
                inner_element = f"""
                    <img style="object-fit: contain; border: 1px solid black;" src="{img_src}" />
                """
            except Exception:
                # best effort (this also covers PyMuPDF not being installed): fall back to the plain path link
                _logger.warning(
                    'Failed to produce PDF thumbnail %s. Make sure you have PyMuPDF installed.', file_path
                )

        return f"""
            <div class="pxt_document" style="width:{max_width}px;">
                <a href="{get_file_uri(Env.get().http_address, file_path)}">
                    {inner_element}
                </a>
            </div>
            """

    def _format_audio(self, file_path: str) -> str:
        """Create an <audio> element with controls for a local audio file."""
        return f"""
            <div class="pxt_audio">
                <audio controls>
                    {_create_source_tag(file_path)}
                </audio>
            </div>
            """

    def __getitem__(self, index: Any) -> Any:
        """Index the result set.

        Supported forms:
        - [<column name>]: list of that column's values
        - [<row idx>]: dict {column name: value} for that row
        - [<row idx>, <column name | column index>]: a single value
        """
        if isinstance(index, str):
            if index not in self._col_names:
                raise excs.Error(f'Invalid column name: {index}')
            col_idx = self._col_names.index(index)
            return [row[col_idx] for row in self._rows]
        if isinstance(index, int):
            return self._row_to_dict(index)
        if isinstance(index, tuple) and len(index) == 2:
            if not isinstance(index[0], int) or not (isinstance(index[1], str) or isinstance(index[1], int)):
                raise excs.Error(f'Bad index, expected [<row idx>, <column name | column index>]: {index}')
            if isinstance(index[1], str) and index[1] not in self._col_names:
                raise excs.Error(f'Invalid column name: {index[1]}')
            col_idx = self._col_names.index(index[1]) if isinstance(index[1], str) else index[1]
            return self._rows[index[0]][col_idx]
        raise excs.Error(f'Bad index: {index}')

    def __iter__(self) -> DataFrameResultSetIterator:
        return DataFrameResultSetIterator(self)

    def __eq__(self, other):
        if not isinstance(other, DataFrameResultSet):
            return False
        return self.to_pandas().equals(other.to_pandas())
207
+
208
+
209
class DataFrameResultSetIterator:
    """Iterator over a DataFrameResultSet that yields each row as a {column name: value} dict."""

    def __init__(self, result_set: DataFrameResultSet):
        self._result_set = result_set
        self._idx = 0  # index of the next row to return

    def __iter__(self) -> DataFrameResultSetIterator:
        # the iterator protocol requires iterators to be iterable themselves (e.g. `for r in iter(rs)`)
        return self

    def __next__(self) -> Dict[str, Any]:
        if self._idx >= len(self._result_set):
            raise StopIteration
        row = self._result_set._row_to_dict(self._idx)
        self._idx += 1
        return row
220
+
221
+
222
+ # # TODO: remove this; it's only here as a reminder that we still need to call release() in the current implementation
223
+ # class AnalysisInfo:
224
+ # def __init__(self, tbl: catalog.TableVersion):
225
+ # self.tbl = tbl
226
+ # # output of the SQL scan stage
227
+ # self.sql_scan_output_exprs: List[exprs.Expr] = []
228
+ # # output of the agg stage
229
+ # self.agg_output_exprs: List[exprs.Expr] = []
230
+ # # Where clause of the Select stmt of the SQL scan stage
231
+ # self.sql_where_clause: Optional[sql.ClauseElement] = None
232
+ # # filter predicate applied to input rows of the SQL scan stage
233
+ # self.filter: Optional[exprs.Predicate] = None
234
+ # self.similarity_clause: Optional[exprs.ImageSimilarityPredicate] = None
235
+ # self.agg_fn_calls: List[exprs.FunctionCall] = [] # derived from unique_exprs
236
+ # self.has_frame_col: bool = False # True if we're referencing the frame col
237
+ #
238
+ # self.evaluator: Optional[exprs.Evaluator] = None
239
+ # self.sql_scan_eval_ctx: List[exprs.Expr] = [] # needed to materialize output of SQL scan stage
240
+ # self.agg_eval_ctx: List[exprs.Expr] = [] # needed to materialize output of agg stage
241
+ # self.filter_eval_ctx: List[exprs.Expr] = []
242
+ # self.group_by_eval_ctx: List[exprs.Expr] = []
243
+ #
244
+ # def finalize_exec(self) -> None:
245
+ # """
246
+ # Call release() on all collected Exprs.
247
+ # """
248
+ # exprs.Expr.release_list(self.sql_scan_output_exprs)
249
+ # exprs.Expr.release_list(self.agg_output_exprs)
250
+ # if self.filter is not None:
251
+ # self.filter.release()
252
+
253
+
254
class DataFrame:
    """A query over a table path (catalog.TableVersionPath).

    The query-building methods (select(), where(), group_by(), order_by(), limit(), __getitem__()) do not modify
    this instance; they return a new DataFrame. The query is executed by collect() (also used by show(), head()
    and tail()), and count() runs a separate count statement.
    """

    def __init__(
        self,
        tbl: catalog.TableVersionPath,
        select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]] = None,
        where_clause: Optional[exprs.Predicate] = None,
        group_by_clause: Optional[List[exprs.Expr]] = None,
        grouping_tbl: Optional[catalog.TableVersion] = None,
        order_by_clause: Optional[List[Tuple[exprs.Expr, bool]]] = None,  # List[(expr, asc)]
        limit: Optional[int] = None,
    ):
        self.tbl = tbl

        # select list logic
        DataFrame._select_list_check_rep(select_list)  # check select list without expansion
        # exprs contain execution state and therefore cannot be shared
        select_list = copy.deepcopy(select_list)
        select_list_exprs, column_names = DataFrame._normalize_select_list(tbl, select_list)
        # check select list after expansion to catch errors early
        DataFrame._select_list_check_rep(list(zip(select_list_exprs, column_names)))
        # the following two lists are populated even if select_list is None (then they cover all columns of tbl)
        self._select_list_exprs = select_list_exprs
        self._column_names = column_names
        self.select_list = select_list

        self.where_clause = copy.deepcopy(where_clause)
        # grouping is either by explicit expressions or by a base table, never both
        assert group_by_clause is None or grouping_tbl is None
        self.group_by_clause = copy.deepcopy(group_by_clause)
        self.grouping_tbl = grouping_tbl
        self.order_by_clause = copy.deepcopy(order_by_clause)
        self.limit_val = limit

    @classmethod
    def _select_list_check_rep(
        cls,
        select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]],
    ) -> None:
        """Validate basic select list types (internal consistency check; None is always accepted)."""
        if select_list is None:  # basic check for valid select list
            return

        assert len(select_list) > 0
        for ent in select_list:
            assert isinstance(ent, tuple)
            assert len(ent) == 2
            assert isinstance(ent[0], exprs.Expr)
            assert ent[1] is None or isinstance(ent[1], str)
            if isinstance(ent[1], str):
                assert is_valid_identifier(ent[1])

    @classmethod
    def _normalize_select_list(
        cls,
        tbl: catalog.TableVersionPath,
        select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]],
    ) -> Tuple[List[exprs.Expr], List[str]]:
        """
        Expand select list information with all columns and their names.

        Entries without a user-provided name get the expression's default column name, with a numeric suffix
        (_1, _2, ...) appended if that name is already taken, or 'col_<position>' if the expression has no
        default name. User-provided names are used as-is (conflicts are reported by select()).

        Returns:
            a pair composed of the list of expressions and the list of corresponding names
        """
        if select_list is None:
            expanded_list = [(exprs.ColumnRef(col), None) for col in tbl.columns()]
        else:
            expanded_list = select_list

        out_exprs: List[exprs.Expr] = []
        out_names: List[str] = []  # keep track of order
        seen_out_names: set[str] = set()  # use to check for duplicates in loop, avoid square complexity
        for i, (expr, name) in enumerate(expanded_list):
            if name is None:
                # use default, add suffix if needed so default adds no duplicates
                default_name = expr.default_column_name()
                if default_name is not None:
                    column_name = default_name
                    if default_name in seen_out_names:
                        # already used, then add suffix until unique name is found; by pigeonhole, one of the
                        # len(out_names) candidates must be free
                        for j in range(1, len(out_names) + 1):
                            column_name = f'{default_name}_{j}'
                            if column_name not in seen_out_names:
                                break
                else:  # no default name, eg some expressions
                    column_name = f'col_{i}'
            else:  # user provided name, no attempt to rename
                column_name = name

            out_exprs.append(expr)
            out_names.append(column_name)
            seen_out_names.add(column_name)
        assert len(out_exprs) == len(out_names)
        assert set(out_names) == seen_out_names
        return out_exprs, out_names

    def _exec(self) -> Generator[exprs.DataRow, None, None]:
        """Run the query and return rows as a generator.
        This function must not modify the state of the DataFrame, otherwise it breaks dataset caching.
        """
        # construct a group-by clause if we're grouping by a table
        group_by_clause: List[exprs.Expr] = []
        if self.grouping_tbl is not None:
            assert self.group_by_clause is None
            num_rowid_cols = len(self.grouping_tbl.store_tbl.rowid_columns())
            # the grouping table must be a base of self.tbl
            assert num_rowid_cols <= len(self.tbl.tbl_version.store_tbl.rowid_columns())
            group_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
        elif self.group_by_clause is not None:
            group_by_clause = self.group_by_clause

        for item in self._select_list_exprs:
            item.bind_rel_paths(None)
        plan = Planner.create_query_plan(
            self.tbl,
            self._select_list_exprs,
            where_clause=self.where_clause,
            group_by_clause=group_by_clause,
            order_by_clause=self.order_by_clause if self.order_by_clause is not None else [],
            limit=self.limit_val if self.limit_val is not None else 0,  # limit_val == 0: no limit_val
        )

        with Env.get().engine.begin() as conn:
            plan.ctx.conn = conn
            plan.open()
            try:
                for row_batch in plan:
                    for data_row in row_batch:
                        yield data_row
            finally:
                plan.close()

    def show(self, n: int = 20) -> DataFrameResultSet:
        """Execute the query with a limit of n rows and return the result."""
        assert n is not None
        return self.limit(n).collect()

    def head(self, n: int = 10) -> DataFrameResultSet:
        """Return the first n rows, ordered by ascending rowid columns of self.tbl.

        Raises:
            excs.Error: if this DataFrame already has an order-by clause.
        """
        if self.order_by_clause is not None:
            raise excs.Error('head() cannot be used with order_by()')
        num_rowid_cols = len(self.tbl.tbl_version.store_tbl.rowid_columns())
        order_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
        return self.order_by(*order_by_clause, asc=True).limit(n).collect()

    def tail(self, n: int = 10) -> DataFrameResultSet:
        """Return the last n rows (by rowid columns of self.tbl), in ascending order.

        Raises:
            excs.Error: if this DataFrame already has an order-by clause.
        """
        if self.order_by_clause is not None:
            raise excs.Error('tail() cannot be used with order_by()')
        num_rowid_cols = len(self.tbl.tbl_version.store_tbl.rowid_columns())
        order_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
        # fetch the last n rows in descending order, then flip them back into ascending order
        result = self.order_by(*order_by_clause, asc=False).limit(n).collect()
        result._reverse()
        return result

    def get_column_names(self) -> List[str]:
        return self._column_names

    def get_column_types(self) -> List[ColumnType]:
        return [expr.col_type for expr in self._select_list_exprs]

    def collect(self) -> DataFrameResultSet:
        """Execute the query and return all result rows.

        Raises:
            excs.Error: if evaluating an expression fails for some row (the message names the row, the expression,
                its inputs and, when the failure happened in user code, a stack trace), or if SQL execution fails.
        """
        try:
            result_rows = []
            for data_row in self._exec():
                result_row = [data_row[e.slot_idx] for e in self._select_list_exprs]
                result_rows.append(result_row)
        except excs.ExprEvalError as e:
            msg = f'In row {e.row_num} the {e.expr_msg} encountered exception {type(e.exc).__name__}:\n{str(e.exc)}'
            if len(e.input_vals) > 0:
                input_msgs = [
                    f"'{d}' = {d.col_type.print_value(e.input_vals[i])}" for i, d in enumerate(e.expr.dependencies())
                ]
                msg += f'\nwith {", ".join(input_msgs)}'
            assert e.exc_tb is not None
            stack_trace = traceback.format_tb(e.exc_tb)
            if len(stack_trace) > 2:
                # append a stack trace if the exception happened in user code
                # (frame 0 is ExprEvaluator and frame 1 is some expr's eval())
                nl = '\n'
                # [-1:1:-1]: leave out entries 0 and 1 and reverse order, so that the most recent frame is at the top
                msg += f'\nStack:\n{nl.join(stack_trace[-1:1:-1])}'
            raise excs.Error(msg) from e
        except sql.exc.DBAPIError as e:
            raise excs.Error(f'Error during SQL execution:\n{e}') from e

        col_types = self.get_column_types()
        return DataFrameResultSet(result_rows, self._column_names, col_types)

    def count(self) -> int:
        """Return the number of rows of self.tbl that satisfy the where clause.

        Only the where clause is passed to the count statement; the select list, group-by, order-by and limit
        are not.
        """
        stmt = Planner.create_count_stmt(self.tbl, self.where_clause)
        with Env.get().engine.connect() as conn:
            result: int = conn.execute(stmt).scalar_one()
            assert isinstance(result, int)
            return result

    def _description(self) -> pd.DataFrame:
        """see DataFrame.describe()

        Returns a two-column ('Heading', 'Info') pandas DataFrame; it has no rows if no clause was specified.
        """
        heading_vals: List[str] = []
        info_vals: List[str] = []
        if self.select_list is not None:
            assert len(self.select_list) > 0
            heading_vals.append('Select')
            heading_vals.extend([''] * (len(self.select_list) - 1))
            info_vals.extend(self.get_column_names())
        if self.where_clause is not None:
            heading_vals.append('Where')
            info_vals.append(self.where_clause.display_str(inline=False))
        if self.group_by_clause is not None:
            heading_vals.append('Group By')
            heading_vals.extend([''] * (len(self.group_by_clause) - 1))
            info_vals.extend([e.display_str(inline=False) for e in self.group_by_clause])
        if self.order_by_clause is not None:
            heading_vals.append('Order By')
            heading_vals.extend([''] * (len(self.order_by_clause) - 1))
            info_vals.extend(
                [f'{e[0].display_str(inline=False)} {"asc" if e[1] else "desc"}' for e in self.order_by_clause]
            )
        if self.limit_val is not None:
            heading_vals.append('Limit')
            info_vals.append(str(self.limit_val))
        # a DataFrame without any clause is valid (e.g. a plain table scan), so empty lists are allowed here
        assert len(heading_vals) == len(info_vals)
        return pd.DataFrame({'Heading': heading_vals, 'Info': info_vals})

    def _description_html(self) -> pandas.io.formats.style.Styler:
        """Return the description in an ipython-friendly manner."""
        pd_df = self._description()
        # white-space: pre-wrap: print \n as newline
        # th: center-align headings
        return (
            pd_df.style.set_properties(**{'white-space': 'pre-wrap', 'text-align': 'left'})
            .set_table_styles([dict(selector='th', props=[('text-align', 'center')])])
            .hide(axis='index')
            .hide(axis='columns')
        )

    def describe(self) -> None:
        """
        Prints a tabular description of this DataFrame.
        The description has two columns, heading and info, which list the contents of each 'component'
        (select list, where clause, ...) vertically.
        """
        try:
            # __IPYTHON__ is only defined inside IPython; outside of it, the NameError selects plain printing
            __IPYTHON__
            from IPython.display import display

            display(self._description_html())
        except NameError:
            print(self.__repr__())

    def __repr__(self) -> str:
        return self._description().to_string(header=False, index=False)

    def _repr_html_(self) -> str:
        return self._description_html()._repr_html_()

    def select(self, *items: Any, **named_items: Any) -> DataFrame:
        """Return a new DataFrame with the given select list.

        Positional items get generated column names; keyword items use the keyword as column name.
        Items that are not Exprs are wrapped: dict -> InlineDict, list -> InlineArray, anything else -> Literal.

        Raises:
            excs.Error: if a select list was already specified, a name is invalid or repeated, the select list is
                empty, or an item has an invalid type.
        """
        if self.select_list is not None:
            raise excs.Error('Select list already specified')
        for name, _ in named_items.items():
            if not isinstance(name, str) or not is_valid_identifier(name):
                raise excs.Error(f'Invalid name: {name}')
        base_list = [(expr, None) for expr in items] + [(expr, k) for (k, expr) in named_items.items()]
        if len(base_list) == 0:
            raise excs.Error('Empty select list')

        # analyze select list; wrap literals with the corresponding expressions
        select_list = []
        for raw_expr, name in base_list:
            if isinstance(raw_expr, exprs.Expr):
                select_list.append((raw_expr, name))
            elif isinstance(raw_expr, dict):
                select_list.append((exprs.InlineDict(raw_expr), name))
            elif isinstance(raw_expr, list):
                select_list.append((exprs.InlineArray(raw_expr), name))
            else:
                select_list.append((exprs.Literal(raw_expr), name))
            expr = select_list[-1][0]
            if expr.col_type.is_invalid_type():
                raise excs.Error(f'Invalid type: {raw_expr}')
            # TODO: check that ColumnRefs in expr refer to self.tbl

        # check user provided names do not conflict among themselves
        # or with auto-generated ones
        seen: Set[str] = set()
        _, names = DataFrame._normalize_select_list(self.tbl, select_list)
        for name in names:
            if name in seen:
                repeated_names = [j for j, x in enumerate(names) if x == name]
                pretty = ', '.join(map(str, repeated_names))
                raise excs.Error(f'Repeated column name "{name}" in select() at positions: {pretty}')
            seen.add(name)

        return DataFrame(
            self.tbl,
            select_list=select_list,
            where_clause=self.where_clause,
            group_by_clause=self.group_by_clause,
            grouping_tbl=self.grouping_tbl,
            order_by_clause=self.order_by_clause,
            limit=self.limit_val,
        )

    def where(self, pred: exprs.Predicate) -> DataFrame:
        """Return a new DataFrame filtered by `pred` (this replaces any previously specified where clause)."""
        return DataFrame(
            self.tbl,
            select_list=self.select_list,
            where_clause=pred,
            group_by_clause=self.group_by_clause,
            grouping_tbl=self.grouping_tbl,
            order_by_clause=self.order_by_clause,
            limit=self.limit_val,
        )

    def group_by(self, *grouping_items: Any) -> DataFrame:
        """Add a group-by clause to this DataFrame.
        Variants:
        - group_by(<base table>): group a component view by their respective base table rows
        - group_by(<expr>, ...): group by the given expressions

        Raises:
            excs.Error: if a grouping (by expressions or by a base table) was already specified, if more than one
                table is given, if the table is not a base of self.tbl, or if an item is not an Expr.
        """
        # check both grouping variants: __init__ asserts that they are never combined
        if self.group_by_clause is not None or self.grouping_tbl is not None:
            raise excs.Error('Group-by already specified')
        grouping_tbl: Optional[catalog.TableVersion] = None
        group_by_clause: Optional[List[exprs.Expr]] = None
        for item in grouping_items:
            if isinstance(item, catalog.Table):
                if len(grouping_items) > 1:
                    raise excs.Error('group_by(): only one table can be specified')
                # we need to make sure that the grouping table is a base of self.tbl
                base = self.tbl.find_tbl_version(item.tbl_version_path.tbl_id())
                if base is None or base.id == self.tbl.tbl_id():
                    raise excs.Error(f'group_by(): {item.name} is not a base table of {self.tbl.tbl_name()}')
                grouping_tbl = item.tbl_version_path.tbl_version
                break
            if not isinstance(item, exprs.Expr):
                raise excs.Error(f'Invalid expression in group_by(): {item}')
        if grouping_tbl is None:
            group_by_clause = list(grouping_items)
        return DataFrame(
            self.tbl,
            select_list=self.select_list,
            where_clause=self.where_clause,
            group_by_clause=group_by_clause,
            grouping_tbl=grouping_tbl,
            order_by_clause=self.order_by_clause,
            limit=self.limit_val,
        )

    def order_by(self, *expr_list: exprs.Expr, asc: bool = True) -> DataFrame:
        """Return a new DataFrame whose order-by clause is the existing one extended by `expr_list`.

        Raises:
            excs.Error: if an item of `expr_list` is not an Expr.
        """
        for e in expr_list:
            if not isinstance(e, exprs.Expr):
                raise excs.Error(f'Invalid expression in order_by(): {e}')
        # copy before extending: extending self.order_by_clause in place would modify this DataFrame
        order_by_clause = list(self.order_by_clause) if self.order_by_clause is not None else []
        order_by_clause.extend((e.copy(), asc) for e in expr_list)
        return DataFrame(
            self.tbl,
            select_list=self.select_list,
            where_clause=self.where_clause,
            group_by_clause=self.group_by_clause,
            grouping_tbl=self.grouping_tbl,
            order_by_clause=order_by_clause,
            limit=self.limit_val,
        )

    def limit(self, n: int) -> DataFrame:
        """Return a new DataFrame that returns at most n rows.

        Raises:
            excs.Error: if n is not an int.
        """
        # explicit check instead of assert, so validation also happens under `python -O`
        if not isinstance(n, int):
            raise excs.Error(f'limit(): expected an int, got {type(n).__name__}')
        return DataFrame(
            self.tbl,
            select_list=self.select_list,
            where_clause=self.where_clause,
            group_by_clause=self.group_by_clause,
            grouping_tbl=self.grouping_tbl,
            order_by_clause=self.order_by_clause,
            limit=n,
        )

    def __getitem__(self, index: object) -> DataFrame:
        """
        Allowed:
        - [<Predicate>]: filter operation
        - [List[Expr]]/[Tuple[Expr]]: setting the select list
        - [Expr]: setting a single-col select list
        """
        if isinstance(index, exprs.Predicate):
            return self.where(index)
        if isinstance(index, tuple):
            index = list(index)
        if isinstance(index, exprs.Expr):
            index = [index]
        if isinstance(index, list):
            return self.select(*index)
        raise TypeError(f'Invalid index type: {type(index)}')

    def _as_dict(self) -> Dict[str, Any]:
        """
        Returns:
            Dictionary representing this dataframe.
        """
        tbl_versions = self.tbl.get_tbl_versions()
        d = {
            '_classname': 'DataFrame',
            'tbl_ids': [str(t.id) for t in tbl_versions],
            'tbl_versions': [t.version for t in tbl_versions],
            'select_list': [(e.as_dict(), name) for (e, name) in self.select_list]
            if self.select_list is not None
            else None,
            'where_clause': self.where_clause.as_dict() if self.where_clause is not None else None,
            'group_by_clause': [e.as_dict() for e in self.group_by_clause]
            if self.group_by_clause is not None
            else None,
            'order_by_clause': [(e.as_dict(), asc) for (e, asc) in self.order_by_clause]
            if self.order_by_clause is not None
            else None,
            'limit_val': self.limit_val,
        }
        return d

    def _cache_key(self) -> str:
        """Return the SHA-256 hex digest of the JSON-serialized _as_dict(); used to name on-disk dataset caches."""
        summary_string = json.dumps(self._as_dict())
        return hashlib.sha256(summary_string.encode()).hexdigest()

    def to_coco_dataset(self) -> Path:
        """Convert the dataframe to a COCO dataset.
        This dataframe must return a single json-typed output column in the following format:
        {
            'image': PIL.Image.Image,
            'annotations': [
                {
                    'bbox': [x: int, y: int, w: int, h: int],
                    'category': str | int,
                },
                ...
            ],
        }

        The result is cached in the dataset cache dir under a key derived from this DataFrame's description.

        Returns:
            Path to the COCO dataset file.
        """
        from pixeltable.utils.coco import write_coco_dataset

        dest_path = Env.get().dataset_cache_dir / f'coco_{self._cache_key()}'
        if dest_path.exists():
            assert dest_path.is_dir()
            data_file_path = dest_path / 'data.json'
            assert data_file_path.exists()
            assert data_file_path.is_file()
            return data_file_path
        else:
            return write_coco_dataset(self, dest_path)

    # TODO Factor this out into a separate module.
    # The return type is unresolvable, but torch can't be imported since it's an optional dependency.
    def to_pytorch_dataset(self, image_format: str = 'pt') -> 'torch.utils.data.IterableDataset':
        """
        Convert the dataframe to a pytorch IterableDataset suitable for parallel loading
        with torch.utils.data.DataLoader.

        This method requires pyarrow >= 13, torch and torchvision to work.

        This method serializes data so it can be read from disk efficiently and repeatedly without
        re-executing the query. This data is cached to disk for future re-use.

        Args:
            image_format: format of the images. Can be 'pt' (pytorch tensor) or 'np' (numpy array).
                    'np' means image columns return as an RGB uint8 array of shape HxWxC.
                    'pt' means image columns return as a CxHxW tensor with values in [0,1] and type torch.float32.
                        (the format output by torchvision.transforms.ToTensor())

        Returns:
            A pytorch IterableDataset: Columns become fields of the dataset, where rows are returned as a dictionary
                compatible with torch.utils.data.DataLoader default collation.

        Constraints:
            The default collate_fn for torch.utils.data.DataLoader cannot represent null values as part of a
            pytorch tensor when forming batches. These values will raise an exception while running the dataloader.

            If you have them, you can work around None values by providing your custom collate_fn to the DataLoader
            (and have your model handle it). Or, if these are not meaningful values within a minibatch, you can
            modify or remove any such values through selections and filters prior to calling to_pytorch_dataset().
        """
        # check dependencies
        Env.get().require_package('pyarrow', [13])
        Env.get().require_package('torch')
        Env.get().require_package('torchvision')

        from pixeltable.io.parquet import save_parquet  # pylint: disable=import-outside-toplevel
        from pixeltable.utils.pytorch import PixeltablePytorchDataset  # pylint: disable=import-outside-toplevel

        dest_path = (Env.get().dataset_cache_dir / f'df_{self._cache_key()}').with_suffix('.parquet')
        if dest_path.exists():  # fast path: use cache
            assert dest_path.is_dir()
        else:
            save_parquet(self, dest_path)

        return PixeltablePytorchDataset(path=dest_path, image_format=image_format)