pixeltable 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic.
- pixeltable/__init__.py +53 -0
- pixeltable/__version__.py +3 -0
- pixeltable/catalog/__init__.py +13 -0
- pixeltable/catalog/catalog.py +159 -0
- pixeltable/catalog/column.py +181 -0
- pixeltable/catalog/dir.py +32 -0
- pixeltable/catalog/globals.py +33 -0
- pixeltable/catalog/insertable_table.py +192 -0
- pixeltable/catalog/named_function.py +36 -0
- pixeltable/catalog/path.py +58 -0
- pixeltable/catalog/path_dict.py +139 -0
- pixeltable/catalog/schema_object.py +39 -0
- pixeltable/catalog/table.py +695 -0
- pixeltable/catalog/table_version.py +1026 -0
- pixeltable/catalog/table_version_path.py +133 -0
- pixeltable/catalog/view.py +203 -0
- pixeltable/dataframe.py +749 -0
- pixeltable/env.py +466 -0
- pixeltable/exceptions.py +17 -0
- pixeltable/exec/__init__.py +10 -0
- pixeltable/exec/aggregation_node.py +78 -0
- pixeltable/exec/cache_prefetch_node.py +116 -0
- pixeltable/exec/component_iteration_node.py +79 -0
- pixeltable/exec/data_row_batch.py +94 -0
- pixeltable/exec/exec_context.py +22 -0
- pixeltable/exec/exec_node.py +61 -0
- pixeltable/exec/expr_eval_node.py +217 -0
- pixeltable/exec/in_memory_data_node.py +73 -0
- pixeltable/exec/media_validation_node.py +43 -0
- pixeltable/exec/sql_scan_node.py +226 -0
- pixeltable/exprs/__init__.py +25 -0
- pixeltable/exprs/arithmetic_expr.py +102 -0
- pixeltable/exprs/array_slice.py +71 -0
- pixeltable/exprs/column_property_ref.py +77 -0
- pixeltable/exprs/column_ref.py +114 -0
- pixeltable/exprs/comparison.py +77 -0
- pixeltable/exprs/compound_predicate.py +98 -0
- pixeltable/exprs/data_row.py +199 -0
- pixeltable/exprs/expr.py +594 -0
- pixeltable/exprs/expr_set.py +39 -0
- pixeltable/exprs/function_call.py +382 -0
- pixeltable/exprs/globals.py +69 -0
- pixeltable/exprs/image_member_access.py +96 -0
- pixeltable/exprs/in_predicate.py +96 -0
- pixeltable/exprs/inline_array.py +109 -0
- pixeltable/exprs/inline_dict.py +103 -0
- pixeltable/exprs/is_null.py +38 -0
- pixeltable/exprs/json_mapper.py +121 -0
- pixeltable/exprs/json_path.py +159 -0
- pixeltable/exprs/literal.py +66 -0
- pixeltable/exprs/object_ref.py +41 -0
- pixeltable/exprs/predicate.py +44 -0
- pixeltable/exprs/row_builder.py +329 -0
- pixeltable/exprs/rowid_ref.py +94 -0
- pixeltable/exprs/similarity_expr.py +65 -0
- pixeltable/exprs/type_cast.py +53 -0
- pixeltable/exprs/variable.py +45 -0
- pixeltable/ext/__init__.py +5 -0
- pixeltable/ext/functions/yolox.py +92 -0
- pixeltable/func/__init__.py +7 -0
- pixeltable/func/aggregate_function.py +197 -0
- pixeltable/func/callable_function.py +113 -0
- pixeltable/func/expr_template_function.py +99 -0
- pixeltable/func/function.py +141 -0
- pixeltable/func/function_registry.py +227 -0
- pixeltable/func/globals.py +46 -0
- pixeltable/func/nos_function.py +202 -0
- pixeltable/func/signature.py +162 -0
- pixeltable/func/udf.py +164 -0
- pixeltable/functions/__init__.py +95 -0
- pixeltable/functions/eval.py +215 -0
- pixeltable/functions/fireworks.py +34 -0
- pixeltable/functions/huggingface.py +167 -0
- pixeltable/functions/image.py +16 -0
- pixeltable/functions/openai.py +289 -0
- pixeltable/functions/pil/image.py +147 -0
- pixeltable/functions/string.py +13 -0
- pixeltable/functions/together.py +143 -0
- pixeltable/functions/util.py +52 -0
- pixeltable/functions/video.py +62 -0
- pixeltable/globals.py +425 -0
- pixeltable/index/__init__.py +2 -0
- pixeltable/index/base.py +51 -0
- pixeltable/index/embedding_index.py +168 -0
- pixeltable/io/__init__.py +3 -0
- pixeltable/io/hf_datasets.py +188 -0
- pixeltable/io/pandas.py +148 -0
- pixeltable/io/parquet.py +192 -0
- pixeltable/iterators/__init__.py +3 -0
- pixeltable/iterators/base.py +52 -0
- pixeltable/iterators/document.py +432 -0
- pixeltable/iterators/video.py +88 -0
- pixeltable/metadata/__init__.py +58 -0
- pixeltable/metadata/converters/convert_10.py +18 -0
- pixeltable/metadata/converters/convert_12.py +3 -0
- pixeltable/metadata/converters/convert_13.py +41 -0
- pixeltable/metadata/schema.py +234 -0
- pixeltable/plan.py +620 -0
- pixeltable/store.py +424 -0
- pixeltable/tool/create_test_db_dump.py +184 -0
- pixeltable/tool/create_test_video.py +81 -0
- pixeltable/type_system.py +846 -0
- pixeltable/utils/__init__.py +17 -0
- pixeltable/utils/arrow.py +98 -0
- pixeltable/utils/clip.py +18 -0
- pixeltable/utils/coco.py +136 -0
- pixeltable/utils/documents.py +69 -0
- pixeltable/utils/filecache.py +195 -0
- pixeltable/utils/help.py +11 -0
- pixeltable/utils/http_server.py +70 -0
- pixeltable/utils/media_store.py +76 -0
- pixeltable/utils/pytorch.py +91 -0
- pixeltable/utils/s3.py +13 -0
- pixeltable/utils/sql.py +17 -0
- pixeltable/utils/transactional_directory.py +35 -0
- pixeltable-0.0.0.dist-info/LICENSE +18 -0
- pixeltable-0.0.0.dist-info/METADATA +131 -0
- pixeltable-0.0.0.dist-info/RECORD +119 -0
- pixeltable-0.0.0.dist-info/WHEEL +4 -0
pixeltable/dataframe.py
ADDED
@@ -0,0 +1,749 @@

from __future__ import annotations

import base64
import copy
import hashlib
import io
import json
import logging
import mimetypes
import traceback
from pathlib import Path
from typing import List, Optional, Any, Dict, Generator, Tuple, Set

import PIL.Image
import cv2
import pandas as pd
import pandas.io.formats.style
import sqlalchemy as sql
from PIL import Image

import pixeltable.catalog as catalog
import pixeltable.exceptions as excs
import pixeltable.exprs as exprs
import pixeltable.type_system as ts
from pixeltable.catalog import is_valid_identifier
from pixeltable.env import Env
from pixeltable.plan import Planner
from pixeltable.type_system import ColumnType
from pixeltable.utils.http_server import get_file_uri

__all__ = ['DataFrame']

_logger = logging.getLogger('pixeltable')

def _create_source_tag(file_path: str) -> str:
    src_url = get_file_uri(Env.get().http_address, file_path)
    mime = mimetypes.guess_type(src_url)[0]
    # if mime is None, the attribute string would not be valid html
    mime_attr = f'type="{mime}"' if mime is not None else ''
    return f'<source src="{src_url}" {mime_attr} />'
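
# Illustrative note (not part of the released file): for a local file such as
# '/tmp/clip.mp4' served via the embedded HTTP server, _create_source_tag()
# yields markup along the lines of
#
#     <source src="http://127.0.0.1:<port>/tmp/clip.mp4" type="video/mp4" />
#
# where the exact URL shape depends on Env.get().http_address; the address
# shown here is only an assumed example.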

class DataFrameResultSet:
    def __init__(self, rows: List[List[Any]], col_names: List[str], col_types: List[ColumnType]):
        self._rows = rows
        self._col_names = col_names
        self._col_types = col_types
        self._formatters = {
            ts.ImageType: self._format_img,
            ts.VideoType: self._format_video,
            ts.AudioType: self._format_audio,
            ts.DocumentType: self._format_document,
        }

    def __len__(self) -> int:
        return len(self._rows)

    def column_names(self) -> List[str]:
        return self._col_names

    def column_types(self) -> List[ColumnType]:
        return self._col_types

    def __repr__(self) -> str:
        return self.to_pandas().__repr__()

    def _repr_html_(self) -> str:
        formatters = {
            col_name: self._formatters[col_type.__class__]
            for col_name, col_type in zip(self._col_names, self._col_types)
            if col_type.__class__ in self._formatters
        }
        return self.to_pandas().to_html(formatters=formatters, escape=False, index=False)

    def __str__(self) -> str:
        return self.to_pandas().to_string()

    def _reverse(self) -> None:
        """Reverse order of rows"""
        self._rows.reverse()

    def to_pandas(self) -> pd.DataFrame:
        return pd.DataFrame.from_records(self._rows, columns=self._col_names)

    def _row_to_dict(self, row_idx: int) -> Dict[str, Any]:
        return {self._col_names[i]: self._rows[row_idx][i] for i in range(len(self._col_names))}

    # Formatters
    def _format_img(self, img: Image.Image) -> str:
        """
        Create <img> tag for Image object.
        """
        assert isinstance(img, Image.Image), f'Wrong type: {type(img)}'
        # Try to make it look decent in a variety of display scenarios
        if len(self._rows) > 1:
            width = 240  # Multiple rows: display small images
        elif len(self._col_names) > 1:
            width = 480  # Multiple columns: display medium images
        else:
            width = 640  # A single image: larger display
        with io.BytesIO() as buffer:
            img.save(buffer, 'jpeg')
            img_base64 = base64.b64encode(buffer.getvalue()).decode()
        return f"""
        <div class="pxt_image" style="width:{width}px;">
            <img src="data:image/jpeg;base64,{img_base64}" width="{width}" />
        </div>
        """

    def _format_video(self, file_path: str) -> str:
        thumb_tag = ''
        # Attempt to extract the first frame of the video to use as a thumbnail,
        # so that the notebook can be exported as HTML and viewed in contexts where
        # the video itself is not accessible.
        # TODO(aaron-siegel): If the video is backed by a concrete external URL,
        # should we link to that instead?
        video_reader = cv2.VideoCapture(str(file_path))
        if video_reader.isOpened():
            status, img_array = video_reader.read()
            if status:
                img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2RGB)
                thumb = PIL.Image.fromarray(img_array)
                with io.BytesIO() as buffer:
                    thumb.save(buffer, 'jpeg')
                    thumb_base64 = base64.b64encode(buffer.getvalue()).decode()
                thumb_tag = f'poster="data:image/jpeg;base64,{thumb_base64}"'
            video_reader.release()
        if len(self._rows) > 1:
            width = 320
        elif len(self._col_names) > 1:
            width = 480
        else:
            width = 800
        return f"""
        <div class="pxt_video" style="width:{width}px;">
            <video controls width="{width}" {thumb_tag}>
                {_create_source_tag(file_path)}
            </video>
        </div>
        """

    def _format_document(self, file_path: str) -> str:
        max_width = max_height = 320
        # by default, the file path will be shown as a link
        inner_element = file_path
        # try generating a thumbnail for different types and use that if successful
        if file_path.lower().endswith('.pdf'):
            try:
                import fitz

                doc = fitz.open(file_path)
                p = doc.get_page_pixmap(0)
                while p.width > max_width or p.height > max_height:
                    # shrink(1) will halve each dimension
                    p.shrink(1)
                data = p.tobytes(output='jpeg')
                thumb_base64 = base64.b64encode(data).decode()
                img_src = f'data:image/jpeg;base64,{thumb_base64}'
                inner_element = f"""
                    <img style="object-fit: contain; border: 1px solid black;" src="{img_src}" />
                """
            except Exception:
                logging.warning(f'Failed to produce PDF thumbnail for {file_path}. Make sure you have PyMuPDF installed.')

        return f"""
        <div class="pxt_document" style="width:{max_width}px;">
            <a href="{get_file_uri(Env.get().http_address, file_path)}">
                {inner_element}
            </a>
        </div>
        """

    def _format_audio(self, file_path: str) -> str:
        return f"""
        <div class="pxt_audio">
            <audio controls>
                {_create_source_tag(file_path)}
            </audio>
        </div>
        """

    def __getitem__(self, index: Any) -> Any:
        if isinstance(index, str):
            if index not in self._col_names:
                raise excs.Error(f'Invalid column name: {index}')
            col_idx = self._col_names.index(index)
            return [row[col_idx] for row in self._rows]
        if isinstance(index, int):
            return self._row_to_dict(index)
        if isinstance(index, tuple) and len(index) == 2:
            if not isinstance(index[0], int) or not (isinstance(index[1], str) or isinstance(index[1], int)):
                raise excs.Error(f'Bad index, expected [<row idx>, <column name | column index>]: {index}')
            if isinstance(index[1], str) and index[1] not in self._col_names:
                raise excs.Error(f'Invalid column name: {index[1]}')
            col_idx = self._col_names.index(index[1]) if isinstance(index[1], str) else index[1]
            return self._rows[index[0]][col_idx]
        raise excs.Error(f'Bad index: {index}')

    def __iter__(self) -> DataFrameResultSetIterator:
        return DataFrameResultSetIterator(self)

    def __eq__(self, other):
        if not isinstance(other, DataFrameResultSet):
            return False
        return self.to_pandas().equals(other.to_pandas())

class DataFrameResultSetIterator:
    def __init__(self, result_set: DataFrameResultSet):
        self._result_set = result_set
        self._idx = 0

    def __next__(self) -> Dict[str, Any]:
        if self._idx >= len(self._result_set):
            raise StopIteration
        row = self._result_set._row_to_dict(self._idx)
        self._idx += 1
        return row
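
# Illustrative sketch (not part of the released file): DataFrameResultSet
# supports by-column, by-row, and by-cell access, plus iteration via the
# iterator class above. Assuming a result set `rs` with columns 'name' and
# 'score' (hypothetical):
#
#     rs.column_names()        # ['name', 'score']
#     rs['name']               # list of all values in column 'name'
#     rs[0]                    # row 0 as a dict: {'name': ..., 'score': ...}
#     rs[0, 'name']            # a single cell; rs[0, 0] is equivalent
#     for row in rs:           # iteration yields one dict per row
#         print(row['name'])
#     pd_df = rs.to_pandas()   # materialize as a pandas DataFrame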

# # TODO: remove this; it's only here as a reminder that we still need to call release() in the current implementation
# class AnalysisInfo:
#     def __init__(self, tbl: catalog.TableVersion):
#         self.tbl = tbl
#         # output of the SQL scan stage
#         self.sql_scan_output_exprs: List[exprs.Expr] = []
#         # output of the agg stage
#         self.agg_output_exprs: List[exprs.Expr] = []
#         # Where clause of the Select stmt of the SQL scan stage
#         self.sql_where_clause: Optional[sql.ClauseElement] = None
#         # filter predicate applied to input rows of the SQL scan stage
#         self.filter: Optional[exprs.Predicate] = None
#         self.similarity_clause: Optional[exprs.ImageSimilarityPredicate] = None
#         self.agg_fn_calls: List[exprs.FunctionCall] = []  # derived from unique_exprs
#         self.has_frame_col: bool = False  # True if we're referencing the frame col
#
#         self.evaluator: Optional[exprs.Evaluator] = None
#         self.sql_scan_eval_ctx: List[exprs.Expr] = []  # needed to materialize output of SQL scan stage
#         self.agg_eval_ctx: List[exprs.Expr] = []  # needed to materialize output of agg stage
#         self.filter_eval_ctx: List[exprs.Expr] = []
#         self.group_by_eval_ctx: List[exprs.Expr] = []
#
#     def finalize_exec(self) -> None:
#         """
#         Call release() on all collected Exprs.
#         """
#         exprs.Expr.release_list(self.sql_scan_output_exprs)
#         exprs.Expr.release_list(self.agg_output_exprs)
#         if self.filter is not None:
#             self.filter.release()

class DataFrame:
    def __init__(
        self,
        tbl: catalog.TableVersionPath,
        select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]] = None,
        where_clause: Optional[exprs.Predicate] = None,
        group_by_clause: Optional[List[exprs.Expr]] = None,
        grouping_tbl: Optional[catalog.TableVersion] = None,
        order_by_clause: Optional[List[Tuple[exprs.Expr, bool]]] = None,  # List[(expr, asc)]
        limit: Optional[int] = None,
    ):
        self.tbl = tbl

        # select list logic
        DataFrame._select_list_check_rep(select_list)  # check select list without expansion
        # exprs contain execution state and therefore cannot be shared
        select_list = copy.deepcopy(select_list)
        select_list_exprs, column_names = DataFrame._normalize_select_list(tbl, select_list)
        DataFrame._select_list_check_rep(list(zip(select_list_exprs, column_names)))  # re-check after expansion to catch problems early
        # the following two lists are always non-empty, even if select_list is None
        self._select_list_exprs = select_list_exprs
        self._column_names = column_names
        self.select_list = select_list

        self.where_clause = copy.deepcopy(where_clause)
        assert group_by_clause is None or grouping_tbl is None
        self.group_by_clause = copy.deepcopy(group_by_clause)
        self.grouping_tbl = grouping_tbl
        self.order_by_clause = copy.deepcopy(order_by_clause)
        self.limit_val = limit

    @classmethod
    def _select_list_check_rep(
        cls,
        select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]],
    ) -> None:
        """Validate basic select list types."""
        if select_list is None:  # a missing select list is valid
            return

        assert len(select_list) > 0
        for ent in select_list:
            assert isinstance(ent, tuple)
            assert len(ent) == 2
            assert isinstance(ent[0], exprs.Expr)
            assert ent[1] is None or isinstance(ent[1], str)
            if isinstance(ent[1], str):
                assert is_valid_identifier(ent[1])

    @classmethod
    def _normalize_select_list(
        cls,
        tbl: catalog.TableVersionPath,
        select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]],
    ) -> Tuple[List[exprs.Expr], List[str]]:
        """
        Expand the select list to all columns and assign each entry a name.

        Returns:
            a pair composed of the list of expressions and the list of corresponding names
        """
        if select_list is None:
            expanded_list = [(exprs.ColumnRef(col), None) for col in tbl.columns()]
        else:
            expanded_list = select_list

        out_exprs: List[exprs.Expr] = []
        out_names: List[str] = []  # keep track of order
        seen_out_names: set[str] = set()  # used to check for duplicates in the loop; avoids quadratic complexity
        for i, (expr, name) in enumerate(expanded_list):
            if name is None:
                # use the default name, adding a suffix if needed so the default introduces no duplicates
                default_name = expr.default_column_name()
                if default_name is not None:
                    column_name = default_name
                    if default_name in seen_out_names:
                        # already used: add a suffix until a unique name is found
                        for j in range(1, len(out_names) + 1):
                            column_name = f'{default_name}_{j}'
                            if column_name not in seen_out_names:
                                break
                else:  # no default name, e.g., for some expressions
                    column_name = f'col_{i}'
            else:  # user-provided name, no attempt to rename
                column_name = name

            out_exprs.append(expr)
            out_names.append(column_name)
            seen_out_names.add(column_name)
        assert len(out_exprs) == len(out_names)
        assert set(out_names) == seen_out_names
        return out_exprs, out_names
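
    # Illustrative sketch (not part of the released file) of the naming rules
    # above, with hypothetical expressions whose default_column_name() results
    # are shown on the right:
    #
    #     [(t.a, None), (t.a, None), (<expr w/o default>, None), (t.b, 'x')]
    #         -> ['a', 'a_1', 'col_2', 'x']
    #
    # A duplicate default name receives a numeric suffix, an expression without
    # a default becomes col_<position>, and user-provided names are kept as-is.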

    def _exec(self) -> Generator[exprs.DataRow, None, None]:
        """Run the query and return rows as a generator.

        This function must not modify the state of the DataFrame, otherwise it breaks dataset caching.
        """
        # construct a group-by clause if we're grouping by a table
        group_by_clause: List[exprs.Expr] = []
        if self.grouping_tbl is not None:
            assert self.group_by_clause is None
            num_rowid_cols = len(self.grouping_tbl.store_tbl.rowid_columns())
            # the grouping table must be a base of self.tbl
            assert num_rowid_cols <= len(self.tbl.tbl_version.store_tbl.rowid_columns())
            group_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
        elif self.group_by_clause is not None:
            group_by_clause = self.group_by_clause

        for item in self._select_list_exprs:
            item.bind_rel_paths(None)
        plan = Planner.create_query_plan(
            self.tbl,
            self._select_list_exprs,
            where_clause=self.where_clause,
            group_by_clause=group_by_clause,
            order_by_clause=self.order_by_clause if self.order_by_clause is not None else [],
            limit=self.limit_val if self.limit_val is not None else 0,  # limit == 0 means no limit
        )

        with Env.get().engine.begin() as conn:
            plan.ctx.conn = conn
            plan.open()
            try:
                for row_batch in plan:
                    for data_row in row_batch:
                        yield data_row
            finally:
                plan.close()
        return

    def show(self, n: int = 20) -> DataFrameResultSet:
        assert n is not None
        return self.limit(n).collect()

    def head(self, n: int = 10) -> DataFrameResultSet:
        if self.order_by_clause is not None:
            raise excs.Error('head() cannot be used with order_by()')
        num_rowid_cols = len(self.tbl.tbl_version.store_tbl.rowid_columns())
        order_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
        return self.order_by(*order_by_clause, asc=True).limit(n).collect()

    def tail(self, n: int = 10) -> DataFrameResultSet:
        if self.order_by_clause is not None:
            raise excs.Error('tail() cannot be used with order_by()')
        num_rowid_cols = len(self.tbl.tbl_version.store_tbl.rowid_columns())
        order_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
        result = self.order_by(*order_by_clause, asc=False).limit(n).collect()
        result._reverse()
        return result
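
    # Illustrative sketch (not part of the released file): head() and tail()
    # order by the table's internal rowids, i.e., by insertion order, which is
    # why they reject a user-specified order_by(). Assuming a DataFrame `df`
    # over some table (hypothetical):
    #
    #     df.head(3)   # the 3 oldest rows
    #     df.tail(3)   # the 3 newest rows, still returned oldest-first,
    #                  # because tail() fetches in descending rowid order
    #                  # and then calls _reverse()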

    def get_column_names(self) -> List[str]:
        return self._column_names

    def get_column_types(self) -> List[ColumnType]:
        return [expr.col_type for expr in self._select_list_exprs]

    def collect(self) -> DataFrameResultSet:
        try:
            result_rows = []
            for data_row in self._exec():
                result_row = [data_row[e.slot_idx] for e in self._select_list_exprs]
                result_rows.append(result_row)
        except excs.ExprEvalError as e:
            msg = f'In row {e.row_num} the {e.expr_msg} encountered exception {type(e.exc).__name__}:\n{str(e.exc)}'
            if len(e.input_vals) > 0:
                input_msgs = [
                    f"'{d}' = {d.col_type.print_value(e.input_vals[i])}" for i, d in enumerate(e.expr.dependencies())
                ]
                msg += f'\nwith {", ".join(input_msgs)}'
            assert e.exc_tb is not None
            stack_trace = traceback.format_tb(e.exc_tb)
            if len(stack_trace) > 2:
                # append a stack trace if the exception happened in user code
                # (frame 0 is ExprEvaluator and frame 1 is some expr's eval())
                nl = '\n'
                # [-1:1:-1]: leave out entries 0 and 1 and reverse the order, so that the most recent frame is at the top
                msg += f'\nStack:\n{nl.join(stack_trace[-1:1:-1])}'
            raise excs.Error(msg)
        except sql.exc.DBAPIError as e:
            raise excs.Error(f'Error during SQL execution:\n{e}')

        col_types = self.get_column_types()
        return DataFrameResultSet(result_rows, self._column_names, col_types)

    def count(self) -> int:
        from pixeltable.plan import Planner

        stmt = Planner.create_count_stmt(self.tbl, self.where_clause)
        with Env.get().engine.connect() as conn:
            result: int = conn.execute(stmt).scalar_one()
            assert isinstance(result, int)
            return result
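
    # Illustrative sketch (not part of the released file): count() pushes the
    # aggregation into SQL via Planner.create_count_stmt() rather than
    # materializing rows, so (assuming a hypothetical table `t` with a float
    # column `score`)
    #
    #     df.where(t.score > 0.5).count()
    #
    # executes a single COUNT statement against the store table.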

    def _description(self) -> pd.DataFrame:
        """see DataFrame.describe()"""
        heading_vals: List[str] = []
        info_vals: List[str] = []
        if self.select_list is not None:
            assert len(self.select_list) > 0
            heading_vals.append('Select')
            heading_vals.extend([''] * (len(self.select_list) - 1))
            info_vals.extend(self.get_column_names())
        if self.where_clause is not None:
            heading_vals.append('Where')
            info_vals.append(self.where_clause.display_str(inline=False))
        if self.group_by_clause is not None:
            heading_vals.append('Group By')
            heading_vals.extend([''] * (len(self.group_by_clause) - 1))
            info_vals.extend([e.display_str(inline=False) for e in self.group_by_clause])
        if self.order_by_clause is not None:
            heading_vals.append('Order By')
            heading_vals.extend([''] * (len(self.order_by_clause) - 1))
            info_vals.extend(
                [f'{e[0].display_str(inline=False)} {"asc" if e[1] else "desc"}' for e in self.order_by_clause]
            )
        if self.limit_val is not None:
            heading_vals.append('Limit')
            info_vals.append(str(self.limit_val))
        assert len(heading_vals) > 0
        assert len(info_vals) > 0
        assert len(heading_vals) == len(info_vals)
        return pd.DataFrame({'Heading': heading_vals, 'Info': info_vals})

    def _description_html(self) -> pandas.io.formats.style.Styler:
        """Return the description in an ipython-friendly manner."""
        pd_df = self._description()
        # white-space: pre-wrap: print \n as newline
        # th: center-align headings
        return (
            pd_df.style.set_properties(**{'white-space': 'pre-wrap', 'text-align': 'left'})
            .set_table_styles([dict(selector='th', props=[('text-align', 'center')])])
            .hide(axis='index')
            .hide(axis='columns')
        )

    def describe(self) -> None:
        """
        Prints a tabular description of this DataFrame.

        The description has two columns, heading and info, which list the contents of each 'component'
        (select list, where clause, ...) vertically.
        """
        try:
            __IPYTHON__
            from IPython.display import display

            display(self._description_html())
        except NameError:
            print(self.__repr__())

    def __repr__(self) -> str:
        return self._description().to_string(header=False, index=False)

    def _repr_html_(self) -> str:
        return self._description_html()._repr_html_()

    def select(self, *items: Any, **named_items: Any) -> DataFrame:
        if self.select_list is not None:
            raise excs.Error('Select list already specified')
        for name, _ in named_items.items():
            if not isinstance(name, str) or not is_valid_identifier(name):
                raise excs.Error(f'Invalid name: {name}')
        base_list = [(expr, None) for expr in items] + [(expr, k) for (k, expr) in named_items.items()]
        if len(base_list) == 0:
            raise excs.Error('Empty select list')

        # analyze the select list; wrap literals with the corresponding expressions
        select_list = []
        for raw_expr, name in base_list:
            if isinstance(raw_expr, exprs.Expr):
                select_list.append((raw_expr, name))
            elif isinstance(raw_expr, dict):
                select_list.append((exprs.InlineDict(raw_expr), name))
            elif isinstance(raw_expr, list):
                select_list.append((exprs.InlineArray(raw_expr), name))
            else:
                select_list.append((exprs.Literal(raw_expr), name))
            expr = select_list[-1][0]
            if expr.col_type.is_invalid_type():
                raise excs.Error(f'Invalid type: {raw_expr}')
            # TODO: check that ColumnRefs in expr refer to self.tbl

        # check that user-provided names do not conflict among themselves
        # or with auto-generated ones
        seen: Set[str] = set()
        _, names = DataFrame._normalize_select_list(self.tbl, select_list)
        for name in names:
            if name in seen:
                repeated_positions = [j for j, x in enumerate(names) if x == name]
                pretty = ', '.join(map(str, repeated_positions))
                raise excs.Error(f'Repeated column name "{name}" in select() at positions: {pretty}')
            seen.add(name)

        return DataFrame(
            self.tbl,
            select_list=select_list,
            where_clause=self.where_clause,
            group_by_clause=self.group_by_clause,
            grouping_tbl=self.grouping_tbl,
            order_by_clause=self.order_by_clause,
            limit=self.limit_val,
        )
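
    # Illustrative sketch (not part of the released file): select() accepts
    # expressions as well as plain literals, dicts, and lists, and keyword
    # arguments become output column names. Assuming a DataFrame `df` over a
    # hypothetical table `t` with columns `name` and `img`:
    #
    #     df.select(t.name, width=t.img.width, tag='demo')
    #
    # produces output columns ['name', 'width', 'tag']; the string 'demo' is
    # wrapped in an exprs.Literal.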

    def where(self, pred: exprs.Predicate) -> DataFrame:
        return DataFrame(
            self.tbl,
            select_list=self.select_list,
            where_clause=pred,
            group_by_clause=self.group_by_clause,
            grouping_tbl=self.grouping_tbl,
            order_by_clause=self.order_by_clause,
            limit=self.limit_val,
        )

    def group_by(self, *grouping_items: Any) -> DataFrame:
        """Add a group-by clause to this DataFrame.

        Variants:
        - group_by(<base table>): group the rows of a component view by their respective base table rows
        - group_by(<expr>, ...): group by the given expressions
        """
        if self.group_by_clause is not None:
            raise excs.Error('Group-by already specified')
        grouping_tbl: Optional[catalog.TableVersion] = None
        group_by_clause: Optional[List[exprs.Expr]] = None
        for item in grouping_items:
            if isinstance(item, catalog.Table):
                if len(grouping_items) > 1:
                    raise excs.Error('group_by(): only one table can be specified')
                # we need to make sure that the grouping table is a base of self.tbl
                base = self.tbl.find_tbl_version(item.tbl_version_path.tbl_id())
                if base is None or base.id == self.tbl.tbl_id():
                    raise excs.Error(f'group_by(): {item.name} is not a base table of {self.tbl.tbl_name()}')
                grouping_tbl = item.tbl_version_path.tbl_version
                break
            if not isinstance(item, exprs.Expr):
                raise excs.Error(f'Invalid expression in group_by(): {item}')
        if grouping_tbl is None:
            group_by_clause = list(grouping_items)
        return DataFrame(
            self.tbl,
            select_list=self.select_list,
            where_clause=self.where_clause,
            group_by_clause=group_by_clause,
            grouping_tbl=grouping_tbl,
            order_by_clause=self.order_by_clause,
            limit=self.limit_val,
        )
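
    # Illustrative sketch (not part of the released file) of the two variants,
    # assuming a hypothetical base table `videos` and a frame view over it with
    # DataFrame `frames_df`:
    #
    #     frames_df.group_by(videos)     # one group per base-table (video) row
    #     df.group_by(t.category)        # one group per distinct expression value
    #
    # For the table variant, _exec() later turns the grouping table into
    # RowidRefs over the base table's rowid columns.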

    def order_by(self, *expr_list: exprs.Expr, asc: bool = True) -> DataFrame:
        for e in expr_list:
            if not isinstance(e, exprs.Expr):
                raise excs.Error(f'Invalid expression in order_by(): {e}')
        # copy the existing clause so we don't mutate self.order_by_clause in place
        order_by_clause = list(self.order_by_clause) if self.order_by_clause is not None else []
        order_by_clause.extend([(e.copy(), asc) for e in expr_list])
        return DataFrame(
            self.tbl,
            select_list=self.select_list,
            where_clause=self.where_clause,
            group_by_clause=self.group_by_clause,
            grouping_tbl=self.grouping_tbl,
            order_by_clause=order_by_clause,
            limit=self.limit_val,
        )

    def limit(self, n: int) -> DataFrame:
        assert n is not None and isinstance(n, int)
        return DataFrame(
            self.tbl,
            select_list=self.select_list,
            where_clause=self.where_clause,
            group_by_clause=self.group_by_clause,
            grouping_tbl=self.grouping_tbl,
            order_by_clause=self.order_by_clause,
            limit=n,
        )
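
    # Illustrative sketch (not part of the released file): each of where(),
    # group_by(), order_by(), and limit() returns a new DataFrame, so queries
    # compose by chaining (hypothetical table `t`):
    #
    #     rs = (
    #         df.where(t.score > 0.8)
    #           .order_by(t.score, asc=False)
    #           .limit(10)
    #           .collect()
    #     )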

    def __getitem__(self, index: object) -> DataFrame:
        """
        Allowed:
        - [<Predicate>]: filter operation
        - [List[Expr]]/[Tuple[Expr]]: setting the select list
        - [Expr]: setting a single-col select list
        """
        if isinstance(index, exprs.Predicate):
            return self.where(index)
        if isinstance(index, tuple):
            index = list(index)
        if isinstance(index, exprs.Expr):
            index = [index]
        if isinstance(index, list):
            return self.select(*index)
        raise TypeError(f'Invalid index type: {type(index)}')
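
    # Illustrative sketch (not part of the released file): the indexing sugar
    # above makes these pairs equivalent (hypothetical table `t`):
    #
    #     df[t.score > 0.8]      ==  df.where(t.score > 0.8)
    #     df[t.name, t.score]    ==  df.select(t.name, t.score)
    #     df[t.name]             ==  df.select(t.name)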

    def _as_dict(self) -> Dict[str, Any]:
        """
        Returns:
            Dictionary representing this dataframe.
        """
        tbl_versions = self.tbl.get_tbl_versions()
        d = {
            '_classname': 'DataFrame',
            'tbl_ids': [str(t.id) for t in tbl_versions],
            'tbl_versions': [t.version for t in tbl_versions],
            'select_list': [(e.as_dict(), name) for (e, name) in self.select_list]
            if self.select_list is not None
            else None,
            'where_clause': self.where_clause.as_dict() if self.where_clause is not None else None,
            'group_by_clause': [e.as_dict() for e in self.group_by_clause]
            if self.group_by_clause is not None
            else None,
            'order_by_clause': [(e.as_dict(), asc) for (e, asc) in self.order_by_clause]
            if self.order_by_clause is not None
            else None,
            'limit_val': self.limit_val,
        }
        return d

    def to_coco_dataset(self) -> Path:
        """Convert the dataframe to a COCO dataset.

        This dataframe must return a single json-typed output column in the following format:
        {
            'image': PIL.Image.Image,
            'annotations': [
                {
                    'bbox': [x: int, y: int, w: int, h: int],
                    'category': str | int,
                },
                ...
            ],
        }

        Returns:
            Path to the COCO dataset file.
        """
        from pixeltable.utils.coco import write_coco_dataset

        summary_string = json.dumps(self._as_dict())
        cache_key = hashlib.sha256(summary_string.encode()).hexdigest()

        dest_path = Env.get().dataset_cache_dir / f'coco_{cache_key}'
        if dest_path.exists():
            assert dest_path.is_dir()
            data_file_path = dest_path / 'data.json'
            assert data_file_path.exists()
            assert data_file_path.is_file()
            return data_file_path
        else:
            return write_coco_dataset(self, dest_path)
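
    # Illustrative sketch (not part of the released file): assuming a DataFrame
    # `df` whose single json-typed output column follows the format documented
    # above,
    #
    #     path = df.to_coco_dataset()   # written under Env dataset_cache_dir
    #
    # Repeated calls with an identical query reuse the cached dataset, since
    # the cache key is a SHA-256 hash of the query's _as_dict() representation.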

    # TODO: factor this out into a separate module.
    # The return type is unresolvable, but torch can't be imported, since it's an optional dependency.
    def to_pytorch_dataset(self, image_format: str = 'pt') -> 'torch.utils.data.IterableDataset':
        """
        Convert the dataframe to a pytorch IterableDataset suitable for parallel loading
        with torch.utils.data.DataLoader.

        This method requires pyarrow >= 13, torch, and torchvision to work.

        This method serializes data so it can be read from disk efficiently and repeatedly without
        re-executing the query. This data is cached to disk for future re-use.

        Args:
            image_format: format of the images. Can be 'pt' (pytorch tensor) or 'np' (numpy array).
                'np' means image columns return as an RGB uint8 array of shape HxWxC.
                'pt' means image columns return as a CxHxW tensor with values in [0,1] and type torch.float32
                (the format output by torchvision.transforms.ToTensor()).

        Returns:
            A pytorch IterableDataset: columns become fields of the dataset, and rows are returned as dictionaries
            compatible with torch.utils.data.DataLoader's default collation.

        Constraints:
            The default collate_fn for torch.utils.data.DataLoader cannot represent null values as part of a
            pytorch tensor when forming batches. Such values will raise an exception while running the dataloader.

            If you have null values, you can work around them by providing a custom collate_fn to the DataLoader
            (and having your model handle them). Or, if they are not meaningful values within a minibatch, you can
            modify or remove them through selections and filters prior to calling to_pytorch_dataset().
        """
        # check dependencies
        Env.get().require_package('pyarrow', [13])
        Env.get().require_package('torch')
        Env.get().require_package('torchvision')

        from pixeltable.io.parquet import save_parquet  # pylint: disable=import-outside-toplevel
        from pixeltable.utils.pytorch import PixeltablePytorchDataset  # pylint: disable=import-outside-toplevel

        summary_string = json.dumps(self._as_dict())
        cache_key = hashlib.sha256(summary_string.encode()).hexdigest()

        dest_path = (Env.get().dataset_cache_dir / f'df_{cache_key}').with_suffix('.parquet')
        if dest_path.exists():  # fast path: use the cache
            assert dest_path.is_dir()
        else:
            save_parquet(self, dest_path)

        return PixeltablePytorchDataset(path=dest_path, image_format=image_format)
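
# Illustrative sketch (not part of the released file): consuming the dataset
# with torch.utils.data.DataLoader, assuming a DataFrame `df` with columns
# 'img' and 'label' (hypothetical):
#
#     import torch.utils.data
#
#     ds = df.to_pytorch_dataset(image_format='pt')
#     loader = torch.utils.data.DataLoader(ds, batch_size=32, num_workers=4)
#     for batch in loader:
#         imgs = batch['img']       # float32 CxHxW tensors, batched
#         labels = batch['label']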