pixeltable 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of pixeltable might be problematic.

Files changed (99)
  1. pixeltable/__init__.py +18 -9
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/column.py +31 -50
  4. pixeltable/catalog/insertable_table.py +7 -6
  5. pixeltable/catalog/table.py +171 -57
  6. pixeltable/catalog/table_version.py +417 -140
  7. pixeltable/catalog/table_version_path.py +2 -2
  8. pixeltable/dataframe.py +239 -121
  9. pixeltable/env.py +82 -16
  10. pixeltable/exec/__init__.py +2 -1
  11. pixeltable/exec/cache_prefetch_node.py +1 -1
  12. pixeltable/exec/data_row_batch.py +6 -7
  13. pixeltable/exec/expr_eval_node.py +28 -28
  14. pixeltable/exec/in_memory_data_node.py +11 -7
  15. pixeltable/exec/sql_scan_node.py +7 -6
  16. pixeltable/exprs/__init__.py +4 -3
  17. pixeltable/exprs/column_ref.py +9 -0
  18. pixeltable/exprs/comparison.py +3 -3
  19. pixeltable/exprs/data_row.py +5 -1
  20. pixeltable/exprs/expr.py +15 -7
  21. pixeltable/exprs/function_call.py +17 -15
  22. pixeltable/exprs/image_member_access.py +9 -28
  23. pixeltable/exprs/in_predicate.py +96 -0
  24. pixeltable/exprs/inline_array.py +13 -11
  25. pixeltable/exprs/inline_dict.py +15 -13
  26. pixeltable/exprs/literal.py +16 -4
  27. pixeltable/exprs/row_builder.py +15 -41
  28. pixeltable/exprs/similarity_expr.py +65 -0
  29. pixeltable/ext/__init__.py +5 -0
  30. pixeltable/ext/functions/yolox.py +92 -0
  31. pixeltable/func/__init__.py +0 -2
  32. pixeltable/func/aggregate_function.py +18 -15
  33. pixeltable/func/callable_function.py +57 -13
  34. pixeltable/func/expr_template_function.py +20 -3
  35. pixeltable/func/function.py +35 -4
  36. pixeltable/func/globals.py +24 -14
  37. pixeltable/func/signature.py +23 -27
  38. pixeltable/func/udf.py +13 -12
  39. pixeltable/functions/__init__.py +8 -8
  40. pixeltable/functions/eval.py +7 -8
  41. pixeltable/functions/huggingface.py +64 -17
  42. pixeltable/functions/openai.py +36 -3
  43. pixeltable/functions/pil/image.py +61 -64
  44. pixeltable/functions/together.py +21 -0
  45. pixeltable/functions/util.py +11 -0
  46. pixeltable/globals.py +425 -0
  47. pixeltable/index/__init__.py +2 -0
  48. pixeltable/index/base.py +51 -0
  49. pixeltable/index/embedding_index.py +168 -0
  50. pixeltable/io/__init__.py +3 -0
  51. pixeltable/{utils → io}/hf_datasets.py +48 -17
  52. pixeltable/io/pandas.py +148 -0
  53. pixeltable/{utils → io}/parquet.py +58 -33
  54. pixeltable/iterators/__init__.py +1 -1
  55. pixeltable/iterators/base.py +4 -0
  56. pixeltable/iterators/document.py +218 -97
  57. pixeltable/iterators/video.py +8 -9
  58. pixeltable/metadata/__init__.py +7 -3
  59. pixeltable/metadata/converters/convert_12.py +3 -0
  60. pixeltable/metadata/converters/convert_13.py +41 -0
  61. pixeltable/metadata/schema.py +45 -22
  62. pixeltable/plan.py +15 -51
  63. pixeltable/store.py +38 -41
  64. pixeltable/tool/create_test_db_dump.py +39 -4
  65. pixeltable/type_system.py +47 -96
  66. pixeltable/utils/documents.py +42 -12
  67. pixeltable/utils/http_server.py +70 -0
  68. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/METADATA +14 -10
  69. pixeltable-0.2.6.dist-info/RECORD +119 -0
  70. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/WHEEL +1 -1
  71. pixeltable/client.py +0 -604
  72. pixeltable/exprs/image_similarity_predicate.py +0 -58
  73. pixeltable/func/batched_function.py +0 -53
  74. pixeltable/tests/conftest.py +0 -177
  75. pixeltable/tests/functions/test_fireworks.py +0 -42
  76. pixeltable/tests/functions/test_functions.py +0 -60
  77. pixeltable/tests/functions/test_huggingface.py +0 -158
  78. pixeltable/tests/functions/test_openai.py +0 -152
  79. pixeltable/tests/functions/test_together.py +0 -111
  80. pixeltable/tests/test_audio.py +0 -65
  81. pixeltable/tests/test_catalog.py +0 -27
  82. pixeltable/tests/test_client.py +0 -21
  83. pixeltable/tests/test_component_view.py +0 -370
  84. pixeltable/tests/test_dataframe.py +0 -439
  85. pixeltable/tests/test_dirs.py +0 -107
  86. pixeltable/tests/test_document.py +0 -120
  87. pixeltable/tests/test_exprs.py +0 -805
  88. pixeltable/tests/test_function.py +0 -324
  89. pixeltable/tests/test_migration.py +0 -43
  90. pixeltable/tests/test_nos.py +0 -54
  91. pixeltable/tests/test_snapshot.py +0 -208
  92. pixeltable/tests/test_table.py +0 -1267
  93. pixeltable/tests/test_transactional_directory.py +0 -42
  94. pixeltable/tests/test_types.py +0 -22
  95. pixeltable/tests/test_video.py +0 -159
  96. pixeltable/tests/test_view.py +0 -530
  97. pixeltable/tests/utils.py +0 -408
  98. pixeltable-0.2.4.dist-info/RECORD +0 -132
  99. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/LICENSE +0 -0
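
Two entries in the list above are module moves rather than edits: hf_datasets.py (item 51) and parquet.py (item 53) relocate from pixeltable/utils to pixeltable/io. Code that imports the old paths would need to follow the move; a minimal compatibility shim (hypothetical, not shipped in either wheel) could bridge both versions:

# Hypothetical compatibility shim, not part of this release: prefer the new
# pixeltable.io location and fall back to the pre-0.2.6 pixeltable.utils path.
try:
    from pixeltable.io.parquet import save_parquet
except ImportError:  # pixeltable < 0.2.6
    from pixeltable.utils.parquet import save_parquet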
pixeltable/catalog/table_version_path.py CHANGED
@@ -101,8 +101,8 @@ class TableVersionPath:
         return DataFrame(self).__getitem__(index)
 
     def columns(self) -> List[Column]:
-        """Return all columns visible in this tbl version path, including columns from bases"""
-        result = self.tbl_version.cols.copy()
+        """Return all user columns visible in this tbl version path, including columns from bases"""
+        result = list(self.tbl_version.cols_by_name.values())
         if self.base is not None:
             base_cols = self.base.columns()
             # we only include base columns that don't conflict with one of our column names
pixeltable/dataframe.py CHANGED
@@ -11,6 +11,8 @@ import traceback
 from pathlib import Path
 from typing import List, Optional, Any, Dict, Generator, Tuple, Set
 
+import PIL.Image
+import cv2
 import pandas as pd
 import pandas.io.formats.style
 import sqlalchemy as sql
@@ -24,37 +26,20 @@ from pixeltable.catalog import is_valid_identifier
 from pixeltable.env import Env
 from pixeltable.plan import Planner
 from pixeltable.type_system import ColumnType
+from pixeltable.utils.http_server import get_file_uri
 
-__all__ = [
-    'DataFrame'
-]
+__all__ = ['DataFrame']
 
 _logger = logging.getLogger('pixeltable')
 
-def _format_img(img: object) -> str:
-    """
-    Create <img> tag for Image object.
-    """
-    assert isinstance(img, Image.Image), f'Wrong type: {type(img)}'
-    with io.BytesIO() as buffer:
-        img.save(buffer, 'jpeg')
-        img_base64 = base64.b64encode(buffer.getvalue()).decode()
-        return f'<div style="width:200px;"><img src="data:image/jpeg;base64,{img_base64}" width="200" /></div>'
 
 def _create_source_tag(file_path: str) -> str:
-    abs_path = Path(file_path)
-    assert abs_path.is_absolute()
-    src_url = f'{Env.get().http_address}/{abs_path}'
+    src_url = get_file_uri(Env.get().http_address, file_path)
     mime = mimetypes.guess_type(src_url)[0]
     # if mime is None, the attribute string would not be valid html.
     mime_attr = f'type="{mime}"' if mime is not None else ''
     return f'<source src="{src_url}" {mime_attr} />'
 
-def _format_video(file_path: str) -> str:
-    return f'<video controls>{_create_source_tag(file_path)}</video>'
-
-def _format_audio(file_path: str) -> str:
-    return f'<audio controls>{_create_source_tag(file_path)}</audio>'
 
 class DataFrameResultSet:
     def __init__(self, rows: List[List[Any]], col_names: List[str], col_types: List[ColumnType]):
@@ -62,9 +47,10 @@ class DataFrameResultSet:
         self._col_names = col_names
         self._col_types = col_types
         self._formatters = {
-            ts.ImageType: _format_img,
-            ts.VideoType: _format_video,
-            ts.AudioType: _format_audio,
+            ts.ImageType: self._format_img,
+            ts.VideoType: self._format_video,
+            ts.AudioType: self._format_audio,
+            ts.DocumentType: self._format_document,
         }
 
     def __len__(self) -> int:
@@ -85,9 +71,7 @@ class DataFrameResultSet:
             for col_name, col_type in zip(self._col_names, self._col_types)
             if col_type.__class__ in self._formatters
         }
-
-        # TODO: why does mypy complain about formatters having an incorrect type?
-        return self.to_pandas().to_html(formatters=formatters, escape=False, index=False)  # type: ignore[arg-type]
+        return self.to_pandas().to_html(formatters=formatters, escape=False, index=False)
 
     def __str__(self) -> str:
         return self.to_pandas().to_string()
@@ -102,6 +86,100 @@ class DataFrameResultSet:
     def _row_to_dict(self, row_idx: int) -> Dict[str, Any]:
         return {self._col_names[i]: self._rows[row_idx][i] for i in range(len(self._col_names))}
 
+    # Formatters
+    def _format_img(self, img: Image.Image) -> str:
+        """
+        Create <img> tag for Image object.
+        """
+        assert isinstance(img, Image.Image), f'Wrong type: {type(img)}'
+        # Try to make it look decent in a variety of display scenarios
+        if len(self._rows) > 1:
+            width = 240  # Multiple rows: display small images
+        elif len(self._col_names) > 1:
+            width = 480  # Multiple columns: display medium images
+        else:
+            width = 640  # A single image: larger display
+        with io.BytesIO() as buffer:
+            img.save(buffer, 'jpeg')
+            img_base64 = base64.b64encode(buffer.getvalue()).decode()
+        return f"""
+        <div class="pxt_image" style="width:{width}px;">
+            <img src="data:image/jpeg;base64,{img_base64}" width="{width}" />
+        </div>
+        """
+
+    def _format_video(self, file_path: str) -> str:
+        thumb_tag = ''
+        # Attempt to extract the first frame of the video to use as a thumbnail,
+        # so that the notebook can be exported as HTML and viewed in contexts where
+        # the video itself is not accessible.
+        # TODO(aaron-siegel): If the video is backed by a concrete external URL,
+        # should we link to that instead?
+        video_reader = cv2.VideoCapture(str(file_path))
+        if video_reader.isOpened():
+            status, img_array = video_reader.read()
+            if status:
+                img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2RGB)
+                thumb = PIL.Image.fromarray(img_array)
+                with io.BytesIO() as buffer:
+                    thumb.save(buffer, 'jpeg')
+                    thumb_base64 = base64.b64encode(buffer.getvalue()).decode()
+                thumb_tag = f'poster="data:image/jpeg;base64,{thumb_base64}"'
+            video_reader.release()
+        if len(self._rows) > 1:
+            width = 320
+        elif len(self._col_names) > 1:
+            width = 480
+        else:
+            width = 800
+        return f"""
+        <div class="pxt_video" style="width:{width}px;">
+            <video controls width="{width}" {thumb_tag}>
+                {_create_source_tag(file_path)}
+            </video>
+        </div>
+        """
+
+    def _format_document(self, file_path: str) -> str:
+        max_width = max_height = 320
+        # by default, file path will be shown as a link
+        inner_element = file_path
+        # try generating a thumbnail for different types and use that if successful
+        if file_path.lower().endswith('.pdf'):
+            try:
+                import fitz
+
+                doc = fitz.open(file_path)
+                p = doc.get_page_pixmap(0)
+                while p.width > max_width or p.height > max_height:
+                    # shrink(1) will halve each dimension
+                    p.shrink(1)
+                data = p.tobytes(output='jpeg')
+                thumb_base64 = base64.b64encode(data).decode()
+                img_src = f'data:image/jpeg;base64,{thumb_base64}'
+                inner_element = f"""
+                <img style="object-fit: contain; border: 1px solid black;" src="{img_src}" />
+                """
+            except:
+                logging.warning(f'Failed to produce PDF thumbnail {file_path}. Make sure you have PyMuPDF installed.')
+
+        return f"""
+        <div class="pxt_document" style="width:{max_width}px;">
+            <a href="{get_file_uri(Env.get().http_address, file_path)}">
+                {inner_element}
+            </a>
+        </div>
+        """
+
+    def _format_audio(self, file_path: str) -> str:
+        return f"""
+        <div class="pxt_audio">
+            <audio controls>
+                {_create_source_tag(file_path)}
+            </audio>
+        </div>
+        """
+
     def __getitem__(self, index: Any) -> Any:
         if isinstance(index, str):
             if index not in self._col_names:
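
The two hunks above work together: the constructor now registers bound formatter methods per column type (including the new DocumentType), and _repr_html_ maps them onto column names before handing them to pandas' to_html. The underlying mechanism is plain pandas; a self-contained sketch of the same idea, with illustrative names and a tiny in-memory image, is:

import base64
import io

import pandas as pd
import PIL.Image

# Sketch of the per-column formatter mechanism the diff relies on: to_html()
# accepts a mapping of column name -> callable, and escape=False keeps the
# returned <img> markup as raw HTML instead of escaping it.
def format_img(img: PIL.Image.Image) -> str:
    with io.BytesIO() as buffer:
        img.save(buffer, 'jpeg')
        b64 = base64.b64encode(buffer.getvalue()).decode()
    return f'<img src="data:image/jpeg;base64,{b64}" width="240" />'

df = pd.DataFrame({'img': [PIL.Image.new('RGB', (32, 32), 'red')]})
html = df.to_html(formatters={'img': format_img}, escape=False, index=False)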
@@ -141,52 +219,53 @@ class DataFrameResultSetIterator:
         return row
 
 
-# TODO: remove this; it's only here as a reminder that we still need to call release() in the current implementation
-class AnalysisInfo:
-    def __init__(self, tbl: catalog.TableVersion):
-        self.tbl = tbl
-        # output of the SQL scan stage
-        self.sql_scan_output_exprs: List[exprs.Expr] = []
-        # output of the agg stage
-        self.agg_output_exprs: List[exprs.Expr] = []
-        # Where clause of the Select stmt of the SQL scan stage
-        self.sql_where_clause: Optional[sql.ClauseElement] = None
-        # filter predicate applied to input rows of the SQL scan stage
-        self.filter: Optional[exprs.Predicate] = None
-        self.similarity_clause: Optional[exprs.ImageSimilarityPredicate] = None
-        self.agg_fn_calls: List[exprs.FunctionCall] = []  # derived from unique_exprs
-        self.has_frame_col: bool = False  # True if we're referencing the frame col
-
-        self.evaluator: Optional[exprs.Evaluator] = None
-        self.sql_scan_eval_ctx: List[exprs.Expr] = []  # needed to materialize output of SQL scan stage
-        self.agg_eval_ctx: List[exprs.Expr] = []  # needed to materialize output of agg stage
-        self.filter_eval_ctx: List[exprs.Expr] = []
-        self.group_by_eval_ctx: List[exprs.Expr] = []
-
-    def finalize_exec(self) -> None:
-        """
-        Call release() on all collected Exprs.
-        """
-        exprs.Expr.release_list(self.sql_scan_output_exprs)
-        exprs.Expr.release_list(self.agg_output_exprs)
-        if self.filter is not None:
-            self.filter.release()
-
+# # TODO: remove this; it's only here as a reminder that we still need to call release() in the current implementation
+# class AnalysisInfo:
+#     def __init__(self, tbl: catalog.TableVersion):
+#         self.tbl = tbl
+#         # output of the SQL scan stage
+#         self.sql_scan_output_exprs: List[exprs.Expr] = []
+#         # output of the agg stage
+#         self.agg_output_exprs: List[exprs.Expr] = []
+#         # Where clause of the Select stmt of the SQL scan stage
+#         self.sql_where_clause: Optional[sql.ClauseElement] = None
+#         # filter predicate applied to input rows of the SQL scan stage
+#         self.filter: Optional[exprs.Predicate] = None
+#         self.similarity_clause: Optional[exprs.ImageSimilarityPredicate] = None
+#         self.agg_fn_calls: List[exprs.FunctionCall] = []  # derived from unique_exprs
+#         self.has_frame_col: bool = False  # True if we're referencing the frame col
+#
+#         self.evaluator: Optional[exprs.Evaluator] = None
+#         self.sql_scan_eval_ctx: List[exprs.Expr] = []  # needed to materialize output of SQL scan stage
+#         self.agg_eval_ctx: List[exprs.Expr] = []  # needed to materialize output of agg stage
+#         self.filter_eval_ctx: List[exprs.Expr] = []
+#         self.group_by_eval_ctx: List[exprs.Expr] = []
+#
+#     def finalize_exec(self) -> None:
+#         """
+#         Call release() on all collected Exprs.
+#         """
+#         exprs.Expr.release_list(self.sql_scan_output_exprs)
+#         exprs.Expr.release_list(self.agg_output_exprs)
+#         if self.filter is not None:
+#             self.filter.release()
 
 
 class DataFrame:
     def __init__(
-        self, tbl: catalog.TableVersionPath,
-        select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]] = None,
-        where_clause: Optional[exprs.Predicate] = None,
-        group_by_clause: Optional[List[exprs.Expr]] = None,
-        grouping_tbl: Optional[catalog.TableVersion] = None,
-        order_by_clause: Optional[List[Tuple[exprs.Expr, bool]]] = None,  # List[(expr, asc)]
-        limit: Optional[int] = None):
+        self,
+        tbl: catalog.TableVersionPath,
+        select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]] = None,
+        where_clause: Optional[exprs.Predicate] = None,
+        group_by_clause: Optional[List[exprs.Expr]] = None,
+        grouping_tbl: Optional[catalog.TableVersion] = None,
+        order_by_clause: Optional[List[Tuple[exprs.Expr, bool]]] = None,  # List[(expr, asc)]
+        limit: Optional[int] = None,
+    ):
         self.tbl = tbl
 
         # select list logic
-        DataFrame._select_list_check_rep(select_list) # check select list without expansion
+        DataFrame._select_list_check_rep(select_list)  # check select list without expansion
         # exprs contain execution state and therefore cannot be shared
         select_list = copy.deepcopy(select_list)
         select_list_exprs, column_names = DataFrame._normalize_select_list(tbl, select_list)
@@ -205,12 +284,12 @@ class DataFrame:
         self.limit_val = limit
 
     @classmethod
-    def _select_list_check_rep(cls,
+    def _select_list_check_rep(
+        cls,
         select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]],
     ) -> None:
-        """Validate basic select list types.
-        """
-        if select_list is None: # basic check for valid select list
+        """Validate basic select list types."""
+        if select_list is None:  # basic check for valid select list
             return
 
         assert len(select_list) > 0
@@ -223,13 +302,14 @@
             assert is_valid_identifier(ent[1])
 
     @classmethod
-    def _normalize_select_list(cls,
+    def _normalize_select_list(
+        cls,
         tbl: catalog.TableVersionPath,
         select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]],
     ) -> Tuple[List[exprs.Expr], List[str]]:
         """
         Expand select list information with all columns and their names
-        Returns:
+        Returns:
             a pair composed of the list of expressions and the list of corresponding names
         """
         if select_list is None:
@@ -237,9 +317,9 @@
         else:
             expanded_list = select_list
 
-        out_exprs : List[exprs.Expr] = []
-        out_names : List[str] = []  # keep track of order
-        seen_out_names : set[str] = set()  # use to check for duplicates in loop, avoid square complexity
+        out_exprs: List[exprs.Expr] = []
+        out_names: List[str] = []  # keep track of order
+        seen_out_names: set[str] = set()  # use to check for duplicates in loop, avoid square complexity
         for i, (expr, name) in enumerate(expanded_list):
             if name is None:
                 # use default, add suffix if needed so default adds no duplicates
@@ -248,13 +328,13 @@
                     column_name = default_name
                     if default_name in seen_out_names:
                         # already used, then add suffix until unique name is found
-                        for j in range(1, len(out_names)+1):
+                        for j in range(1, len(out_names) + 1):
                             column_name = f'{default_name}_{j}'
                             if column_name not in seen_out_names:
                                 break
-                else: # no default name, eg some expressions
+                else:  # no default name, eg some expressions
                     column_name = f'col_{i}'
-            else: # user provided name, no attempt to rename
+            else:  # user provided name, no attempt to rename
                 column_name = name
 
             out_exprs.append(expr)
@@ -282,9 +362,13 @@ class DataFrame:
         for item in self._select_list_exprs:
             item.bind_rel_paths(None)
         plan = Planner.create_query_plan(
-            self.tbl, self._select_list_exprs, where_clause=self.where_clause, group_by_clause=group_by_clause,
+            self.tbl,
+            self._select_list_exprs,
+            where_clause=self.where_clause,
+            group_by_clause=group_by_clause,
             order_by_clause=self.order_by_clause if self.order_by_clause is not None else [],
-            limit=self.limit_val if self.limit_val is not None else 0) # limit_val == 0: no limit_val
+            limit=self.limit_val if self.limit_val is not None else 0,
+        )  # limit_val == 0: no limit_val
 
         with Env.get().engine.begin() as conn:
             plan.ctx.conn = conn
@@ -330,12 +414,10 @@
                 result_row = [data_row[e.slot_idx] for e in self._select_list_exprs]
                 result_rows.append(result_row)
         except excs.ExprEvalError as e:
-            msg = (f'In row {e.row_num} the {e.expr_msg} encountered exception '
-                   f'{type(e.exc).__name__}:\n{str(e.exc)}')
+            msg = f'In row {e.row_num} the {e.expr_msg} encountered exception ' f'{type(e.exc).__name__}:\n{str(e.exc)}'
             if len(e.input_vals) > 0:
                 input_msgs = [
-                    f"'{d}' = {d.col_type.print_value(e.input_vals[i])}"
-                    for i, d in enumerate(e.expr.dependencies())
+                    f"'{d}' = {d.col_type.print_value(e.input_vals[i])}" for i, d in enumerate(e.expr.dependencies())
                 ]
                 msg += f'\nwith {", ".join(input_msgs)}'
             assert e.exc_tb is not None
@@ -355,6 +437,7 @@
 
     def count(self) -> int:
         from pixeltable.plan import Planner
+
         stmt = Planner.create_count_stmt(self.tbl, self.where_clause)
         with Env.get().engine.connect() as conn:
             result: int = conn.execute(stmt).scalar_one()
@@ -380,9 +463,9 @@
         if self.order_by_clause is not None:
             heading_vals.append('Order By')
             heading_vals.extend([''] * (len(self.order_by_clause) - 1))
-            info_vals.extend([
-                f'{e[0].display_str(inline=False)} {"asc" if e[1] else "desc"}' for e in self.order_by_clause
-            ])
+            info_vals.extend(
+                [f'{e[0].display_str(inline=False)} {"asc" if e[1] else "desc"}' for e in self.order_by_clause]
+            )
         if self.limit_val is not None:
             heading_vals.append('Limit')
             info_vals.append(str(self.limit_val))
@@ -396,9 +479,12 @@
         pd_df = self._description()
         # white-space: pre-wrap: print \n as newline
         # th: center-align headings
-        return pd_df.style.set_properties(**{'white-space': 'pre-wrap', 'text-align': 'left'}) \
-            .set_table_styles([dict(selector='th', props=[('text-align', 'center')])]) \
-            .hide(axis='index').hide(axis='columns')
+        return (
+            pd_df.style.set_properties(**{'white-space': 'pre-wrap', 'text-align': 'left'})
+            .set_table_styles([dict(selector='th', props=[('text-align', 'center')])])
+            .hide(axis='index')
+            .hide(axis='columns')
+        )
 
     def describe(self) -> None:
         """
@@ -409,6 +495,7 @@
         try:
             __IPYTHON__
             from IPython.display import display
+
             display(self._description_html())
         except NameError:
             print(self.__repr__())
@@ -419,16 +506,16 @@
     def _repr_html_(self) -> str:
         return self._description_html()._repr_html_()
 
-    def select(self, *items: Any, **named_items : Any) -> DataFrame:
+    def select(self, *items: Any, **named_items: Any) -> DataFrame:
         if self.select_list is not None:
             raise excs.Error(f'Select list already specified')
-        for (name, _) in named_items.items():
+        for name, _ in named_items.items():
             if not isinstance(name, str) or not is_valid_identifier(name):
                 raise excs.Error(f'Invalid name: {name}')
         base_list = [(expr, None) for expr in items] + [(expr, k) for (k, expr) in named_items.items()]
         if len(base_list) == 0:
             raise excs.Error(f'Empty select list')
-
+
         # analyze select list; wrap literals with the corresponding expressions
         select_list = []
         for raw_expr, name in base_list:
@@ -457,13 +544,25 @@
             seen.add(name)
 
         return DataFrame(
-            self.tbl, select_list=select_list, where_clause=self.where_clause, group_by_clause=self.group_by_clause,
-            grouping_tbl=self.grouping_tbl, order_by_clause=self.order_by_clause, limit=self.limit_val)
+            self.tbl,
+            select_list=select_list,
+            where_clause=self.where_clause,
+            group_by_clause=self.group_by_clause,
+            grouping_tbl=self.grouping_tbl,
+            order_by_clause=self.order_by_clause,
+            limit=self.limit_val,
+        )
 
     def where(self, pred: exprs.Predicate) -> DataFrame:
         return DataFrame(
-            self.tbl, select_list=self.select_list, where_clause=pred, group_by_clause=self.group_by_clause,
-            grouping_tbl=self.grouping_tbl, order_by_clause=self.order_by_clause, limit=self.limit_val)
+            self.tbl,
+            select_list=self.select_list,
+            where_clause=pred,
+            group_by_clause=self.group_by_clause,
+            grouping_tbl=self.grouping_tbl,
+            order_by_clause=self.order_by_clause,
+            limit=self.limit_val,
+        )
 
     def group_by(self, *grouping_items: Any) -> DataFrame:
         """Add a group-by clause to this DataFrame.
@@ -490,8 +589,14 @@
         if grouping_tbl is None:
             group_by_clause = list(grouping_items)
         return DataFrame(
-            self.tbl, select_list=self.select_list, where_clause=self.where_clause, group_by_clause=group_by_clause,
-            grouping_tbl=grouping_tbl, order_by_clause=self.order_by_clause, limit=self.limit_val)
+            self.tbl,
+            select_list=self.select_list,
+            where_clause=self.where_clause,
+            group_by_clause=group_by_clause,
+            grouping_tbl=grouping_tbl,
+            order_by_clause=self.order_by_clause,
+            limit=self.limit_val,
+        )
 
     def order_by(self, *expr_list: exprs.Expr, asc: bool = True) -> DataFrame:
         for e in expr_list:
@@ -500,16 +605,26 @@
         order_by_clause = self.order_by_clause if self.order_by_clause is not None else []
         order_by_clause.extend([(e.copy(), asc) for e in expr_list])
         return DataFrame(
-            self.tbl, select_list=self.select_list, where_clause=self.where_clause,
-            group_by_clause=self.group_by_clause, grouping_tbl=self.grouping_tbl, order_by_clause=order_by_clause,
-            limit=self.limit_val)
+            self.tbl,
+            select_list=self.select_list,
+            where_clause=self.where_clause,
+            group_by_clause=self.group_by_clause,
+            grouping_tbl=self.grouping_tbl,
+            order_by_clause=order_by_clause,
+            limit=self.limit_val,
+        )
 
     def limit(self, n: int) -> DataFrame:
         assert n is not None and isinstance(n, int)
         return DataFrame(
-            self.tbl, select_list=self.select_list, where_clause=self.where_clause,
-            group_by_clause=self.group_by_clause, grouping_tbl=self.grouping_tbl, order_by_clause=self.order_by_clause,
-            limit=n)
+            self.tbl,
+            select_list=self.select_list,
+            where_clause=self.where_clause,
+            group_by_clause=self.group_by_clause,
+            grouping_tbl=self.grouping_tbl,
+            order_by_clause=self.order_by_clause,
+            limit=n,
+        )
 
     def __getitem__(self, index: object) -> DataFrame:
         """
@@ -527,24 +642,27 @@
         if isinstance(index, list):
             return self.select(*index)
         raise TypeError(f'Invalid index type: {type(index)}')
-
+
     def _as_dict(self) -> Dict[str, Any]:
-        """
-        Returns:
-            Dictionary representing this dataframe.
+        """
+        Returns:
+            Dictionary representing this dataframe.
         """
         tbl_versions = self.tbl.get_tbl_versions()
         d = {
             '_classname': 'DataFrame',
             'tbl_ids': [str(t.id) for t in tbl_versions],
             'tbl_versions': [t.version for t in tbl_versions],
-            'select_list':
-                [(e.as_dict(), name) for (e, name) in self.select_list] if self.select_list is not None else None,
+            'select_list': [(e.as_dict(), name) for (e, name) in self.select_list]
+            if self.select_list is not None
+            else None,
             'where_clause': self.where_clause.as_dict() if self.where_clause is not None else None,
-            'group_by_clause':
-                [e.as_dict() for e in self.group_by_clause] if self.group_by_clause is not None else None,
-            'order_by_clause':
-                [(e.as_dict(), asc) for (e,asc) in self.order_by_clause] if self.order_by_clause is not None else None,
+            'group_by_clause': [e.as_dict() for e in self.group_by_clause]
+            if self.group_by_clause is not None
+            else None,
+            'order_by_clause': [(e.as_dict(), asc) for (e, asc) in self.order_by_clause]
+            if self.order_by_clause is not None
+            else None,
             'limit_val': self.limit_val,
         }
         return d
@@ -571,7 +689,7 @@
         summary_string = json.dumps(self._as_dict())
         cache_key = hashlib.sha256(summary_string.encode()).hexdigest()
 
-        dest_path = (Env.get().dataset_cache_dir / f'coco_{cache_key}')
+        dest_path = Env.get().dataset_cache_dir / f'coco_{cache_key}'
         if dest_path.exists():
             assert dest_path.is_dir()
             data_file_path = dest_path / 'data.json'
@@ -616,14 +734,14 @@
         Env.get().require_package('torch')
         Env.get().require_package('torchvision')
 
-        from pixeltable.utils.parquet import save_parquet  # pylint: disable=import-outside-toplevel
-        from pixeltable.utils.pytorch import PixeltablePytorchDataset  # pylint: disable=import-outside-toplevel
+        from pixeltable.io.parquet import save_parquet  # pylint: disable=import-outside-toplevel
+        from pixeltable.utils.pytorch import PixeltablePytorchDataset  # pylint: disable=import-outside-toplevel
 
-        summary_string = json.dumps(self._as_dict())
+        summary_string = json.dumps(self._as_dict())
         cache_key = hashlib.sha256(summary_string.encode()).hexdigest()
-
-        dest_path = (Env.get().dataset_cache_dir / f'df_{cache_key}').with_suffix('.parquet')  # pylint: disable = protected-access
-        if dest_path.exists(): # fast path: use cache
+
+        dest_path = (Env.get().dataset_cache_dir / f'df_{cache_key}').with_suffix('.parquet')  # pylint: disable = protected-access
+        if dest_path.exists():  # fast path: use cache
             assert dest_path.is_dir()
         else:
             save_parquet(self, dest_path)
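
The last hunk also shows the caching scheme used for exported datasets: the DataFrame's dict representation is serialized, hashed with SHA-256, and the digest becomes part of the cache path, so an identical query can reuse a previously materialized Parquet dataset instead of rebuilding it. A generic sketch of that content-keyed caching pattern (illustrative names, not pixeltable's API):

import hashlib
import json
from pathlib import Path

# Illustrative sketch: derive a stable cache location from a full description
# of the query, then reuse the artifact if it has already been materialized.
def cached_dataset_path(query_description: dict, cache_dir: Path) -> Path:
    key = hashlib.sha256(json.dumps(query_description).encode()).hexdigest()
    return (cache_dir / f'df_{key}').with_suffix('.parquet')

def get_or_build(query_description: dict, cache_dir: Path, build) -> Path:
    path = cached_dataset_path(query_description, cache_dir)
    if not path.exists():
        build(path)  # e.g. save_parquet(df, path) in the hunk above
    return path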