pixeltable 0.1.0__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pixeltable might be problematic.
- pixeltable/__init__.py +34 -6
- pixeltable/catalog/__init__.py +13 -0
- pixeltable/catalog/catalog.py +159 -0
- pixeltable/catalog/column.py +200 -0
- pixeltable/catalog/dir.py +32 -0
- pixeltable/catalog/globals.py +33 -0
- pixeltable/catalog/insertable_table.py +191 -0
- pixeltable/catalog/named_function.py +36 -0
- pixeltable/catalog/path.py +58 -0
- pixeltable/catalog/path_dict.py +139 -0
- pixeltable/catalog/schema_object.py +39 -0
- pixeltable/catalog/table.py +581 -0
- pixeltable/catalog/table_version.py +749 -0
- pixeltable/catalog/table_version_path.py +133 -0
- pixeltable/catalog/view.py +203 -0
- pixeltable/client.py +590 -30
- pixeltable/dataframe.py +540 -349
- pixeltable/env.py +359 -45
- pixeltable/exceptions.py +12 -21
- pixeltable/exec/__init__.py +9 -0
- pixeltable/exec/aggregation_node.py +78 -0
- pixeltable/exec/cache_prefetch_node.py +116 -0
- pixeltable/exec/component_iteration_node.py +79 -0
- pixeltable/exec/data_row_batch.py +95 -0
- pixeltable/exec/exec_context.py +22 -0
- pixeltable/exec/exec_node.py +61 -0
- pixeltable/exec/expr_eval_node.py +217 -0
- pixeltable/exec/in_memory_data_node.py +69 -0
- pixeltable/exec/media_validation_node.py +43 -0
- pixeltable/exec/sql_scan_node.py +225 -0
- pixeltable/exprs/__init__.py +24 -0
- pixeltable/exprs/arithmetic_expr.py +102 -0
- pixeltable/exprs/array_slice.py +71 -0
- pixeltable/exprs/column_property_ref.py +77 -0
- pixeltable/exprs/column_ref.py +105 -0
- pixeltable/exprs/comparison.py +77 -0
- pixeltable/exprs/compound_predicate.py +98 -0
- pixeltable/exprs/data_row.py +195 -0
- pixeltable/exprs/expr.py +586 -0
- pixeltable/exprs/expr_set.py +39 -0
- pixeltable/exprs/function_call.py +380 -0
- pixeltable/exprs/globals.py +69 -0
- pixeltable/exprs/image_member_access.py +115 -0
- pixeltable/exprs/image_similarity_predicate.py +58 -0
- pixeltable/exprs/inline_array.py +107 -0
- pixeltable/exprs/inline_dict.py +101 -0
- pixeltable/exprs/is_null.py +38 -0
- pixeltable/exprs/json_mapper.py +121 -0
- pixeltable/exprs/json_path.py +159 -0
- pixeltable/exprs/literal.py +54 -0
- pixeltable/exprs/object_ref.py +41 -0
- pixeltable/exprs/predicate.py +44 -0
- pixeltable/exprs/row_builder.py +355 -0
- pixeltable/exprs/rowid_ref.py +94 -0
- pixeltable/exprs/type_cast.py +53 -0
- pixeltable/exprs/variable.py +45 -0
- pixeltable/func/__init__.py +9 -0
- pixeltable/func/aggregate_function.py +194 -0
- pixeltable/func/batched_function.py +53 -0
- pixeltable/func/callable_function.py +69 -0
- pixeltable/func/expr_template_function.py +82 -0
- pixeltable/func/function.py +110 -0
- pixeltable/func/function_registry.py +227 -0
- pixeltable/func/globals.py +36 -0
- pixeltable/func/nos_function.py +202 -0
- pixeltable/func/signature.py +166 -0
- pixeltable/func/udf.py +163 -0
- pixeltable/functions/__init__.py +52 -103
- pixeltable/functions/eval.py +216 -0
- pixeltable/functions/fireworks.py +34 -0
- pixeltable/functions/huggingface.py +120 -0
- pixeltable/functions/image.py +16 -0
- pixeltable/functions/openai.py +256 -0
- pixeltable/functions/pil/image.py +148 -7
- pixeltable/functions/string.py +13 -0
- pixeltable/functions/together.py +122 -0
- pixeltable/functions/util.py +41 -0
- pixeltable/functions/video.py +62 -0
- pixeltable/iterators/__init__.py +3 -0
- pixeltable/iterators/base.py +48 -0
- pixeltable/iterators/document.py +311 -0
- pixeltable/iterators/video.py +89 -0
- pixeltable/metadata/__init__.py +54 -0
- pixeltable/metadata/converters/convert_10.py +18 -0
- pixeltable/metadata/schema.py +211 -0
- pixeltable/plan.py +656 -0
- pixeltable/store.py +418 -182
- pixeltable/tests/conftest.py +146 -88
- pixeltable/tests/functions/test_fireworks.py +42 -0
- pixeltable/tests/functions/test_functions.py +60 -0
- pixeltable/tests/functions/test_huggingface.py +158 -0
- pixeltable/tests/functions/test_openai.py +152 -0
- pixeltable/tests/functions/test_together.py +111 -0
- pixeltable/tests/test_audio.py +65 -0
- pixeltable/tests/test_catalog.py +27 -0
- pixeltable/tests/test_client.py +14 -14
- pixeltable/tests/test_component_view.py +370 -0
- pixeltable/tests/test_dataframe.py +439 -0
- pixeltable/tests/test_dirs.py +78 -62
- pixeltable/tests/test_document.py +120 -0
- pixeltable/tests/test_exprs.py +592 -135
- pixeltable/tests/test_function.py +297 -67
- pixeltable/tests/test_migration.py +43 -0
- pixeltable/tests/test_nos.py +54 -0
- pixeltable/tests/test_snapshot.py +208 -0
- pixeltable/tests/test_table.py +1195 -263
- pixeltable/tests/test_transactional_directory.py +42 -0
- pixeltable/tests/test_types.py +5 -11
- pixeltable/tests/test_video.py +151 -34
- pixeltable/tests/test_view.py +530 -0
- pixeltable/tests/utils.py +320 -45
- pixeltable/tool/create_test_db_dump.py +149 -0
- pixeltable/tool/create_test_video.py +81 -0
- pixeltable/type_system.py +445 -124
- pixeltable/utils/__init__.py +17 -46
- pixeltable/utils/arrow.py +98 -0
- pixeltable/utils/clip.py +12 -15
- pixeltable/utils/coco.py +136 -0
- pixeltable/utils/documents.py +39 -0
- pixeltable/utils/filecache.py +195 -0
- pixeltable/utils/help.py +11 -0
- pixeltable/utils/hf_datasets.py +157 -0
- pixeltable/utils/media_store.py +76 -0
- pixeltable/utils/parquet.py +167 -0
- pixeltable/utils/pytorch.py +91 -0
- pixeltable/utils/s3.py +13 -0
- pixeltable/utils/sql.py +17 -0
- pixeltable/utils/transactional_directory.py +35 -0
- pixeltable-0.2.4.dist-info/LICENSE +18 -0
- pixeltable-0.2.4.dist-info/METADATA +127 -0
- pixeltable-0.2.4.dist-info/RECORD +132 -0
- {pixeltable-0.1.0.dist-info → pixeltable-0.2.4.dist-info}/WHEEL +1 -1
- pixeltable/catalog.py +0 -1421
- pixeltable/exprs.py +0 -1745
- pixeltable/function.py +0 -269
- pixeltable/functions/clip.py +0 -10
- pixeltable/functions/pil/__init__.py +0 -23
- pixeltable/functions/tf.py +0 -21
- pixeltable/index.py +0 -57
- pixeltable/tests/test_dict.py +0 -24
- pixeltable/tests/test_functions.py +0 -11
- pixeltable/tests/test_tf.py +0 -69
- pixeltable/tf.py +0 -33
- pixeltable/utils/tf.py +0 -33
- pixeltable/utils/video.py +0 -32
- pixeltable-0.1.0.dist-info/METADATA +0 -34
- pixeltable-0.1.0.dist-info/RECORD +0 -36
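The listing above captures the shape of the 0.2.x restructuring: the monolithic pixeltable/catalog.py (-1421 lines) and pixeltable/exprs.py (-1745 lines) are deleted and reappear as the pixeltable/catalog/ and pixeltable/exprs/ packages, query planning and execution move into the new pixeltable/plan.py and pixeltable/exec/ modules, and UDF machinery is consolidated under pixeltable/func/. As a sketch of what this means at the import level, the following lines are confirmed by the dataframe.py diff below (nothing beyond them is implied):

    import pixeltable.catalog as catalog                 # package replacing the old catalog.py module
    import pixeltable.exprs as exprs                     # package replacing the old exprs.py module
    import pixeltable.type_system as ts
    from pixeltable.catalog import is_valid_identifier
    from pixeltable.plan import Planner                  # new query planner (pixeltable/plan.py)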
pixeltable/dataframe.py
CHANGED
@@ -1,24 +1,35 @@
+from __future__ import annotations
+
 import base64
+import copy
+import hashlib
 import io
-import …
-
+import json
+import logging
+import mimetypes
+import traceback
 from pathlib import Path
-from …
+from typing import List, Optional, Any, Dict, Generator, Tuple, Set
+
 import pandas as pd
+import pandas.io.formats.style
 import sqlalchemy as sql
 from PIL import Image
-import copy
 
-
+import pixeltable.catalog as catalog
+import pixeltable.exceptions as excs
+import pixeltable.exprs as exprs
+import pixeltable.type_system as ts
+from pixeltable.catalog import is_valid_identifier
 from pixeltable.env import Env
+from pixeltable.plan import Planner
 from pixeltable.type_system import ColumnType
-from pixeltable import exprs
-from pixeltable import exceptions as exc
 
 __all__ = [
     'DataFrame'
 ]
 
+_logger = logging.getLogger('pixeltable')
 
 def _format_img(img: object) -> str:
     """
@@ -28,360 +39,479 @@ def _format_img(img: object) -> str:
     with io.BytesIO() as buffer:
         img.save(buffer, 'jpeg')
         img_base64 = base64.b64encode(buffer.getvalue()).decode()
-        return f'<img src="data:image/jpeg;base64,{img_base64}">'
-
-def …
-
-
-
-
-# …
-
-
-
-
-
-
+        return f'<div style="width:200px;"><img src="data:image/jpeg;base64,{img_base64}" width="200" /></div>'
+
+def _create_source_tag(file_path: str) -> str:
+    abs_path = Path(file_path)
+    assert abs_path.is_absolute()
+    src_url = f'{Env.get().http_address}/{abs_path}'
+    mime = mimetypes.guess_type(src_url)[0]
+    # if mime is None, the attribute string would not be valid html.
+    mime_attr = f'type="{mime}"' if mime is not None else ''
+    return f'<source src="{src_url}" {mime_attr} />'
+
+def _format_video(file_path: str) -> str:
+    return f'<video controls>{_create_source_tag(file_path)}</video>'
+
+def _format_audio(file_path: str) -> str:
+    return f'<audio controls>{_create_source_tag(file_path)}</audio>'
 
 class DataFrameResultSet:
-    def __init__(self, rows: List[List], col_names: List[str], col_types: List[ColumnType]):
-        self.…
-        self.…
-        self.…
+    def __init__(self, rows: List[List[Any]], col_names: List[str], col_types: List[ColumnType]):
+        self._rows = rows
+        self._col_names = col_names
+        self._col_types = col_types
+        self._formatters = {
+            ts.ImageType: _format_img,
+            ts.VideoType: _format_video,
+            ts.AudioType: _format_audio,
+        }
 
     def __len__(self) -> int:
-        return len(self.…
+        return len(self._rows)
+
+    def column_names(self) -> List[str]:
+        return self._col_names
+
+    def column_types(self) -> List[ColumnType]:
+        return self._col_types
+
+    def __repr__(self) -> str:
+        return self.to_pandas().__repr__()
 
     def _repr_html_(self) -> str:
-
-
-
-
-
+        formatters = {
+            col_name: self._formatters[col_type.__class__]
+            for col_name, col_type in zip(self._col_names, self._col_types)
+            if col_type.__class__ in self._formatters
+        }
+
         # TODO: why does mypy complain about formatters having an incorrect type?
         return self.to_pandas().to_html(formatters=formatters, escape=False, index=False)  # type: ignore[arg-type]
 
     def __str__(self) -> str:
         return self.to_pandas().to_string()
 
-    def …
-
+    def _reverse(self) -> None:
+        """Reverse order of rows"""
+        self._rows.reverse()
 
-    def …
-
-        if len(index) != 2 or not isinstance(index[0], int) or not isinstance(index[1], int):
-            raise exc.OperationalError(f'Bad index: {index}')
-        return self.rows[index[0]][index[1]]
+    def to_pandas(self) -> pd.DataFrame:
+        return pd.DataFrame.from_records(self._rows, columns=self._col_names)
 
+    def _row_to_dict(self, row_idx: int) -> Dict[str, Any]:
+        return {self._col_names[i]: self._rows[row_idx][i] for i in range(len(self._col_names))}
 
+    def __getitem__(self, index: Any) -> Any:
+        if isinstance(index, str):
+            if index not in self._col_names:
+                raise excs.Error(f'Invalid column name: {index}')
+            col_idx = self._col_names.index(index)
+            return [row[col_idx] for row in self._rows]
+        if isinstance(index, int):
+            return self._row_to_dict(index)
+        if isinstance(index, tuple) and len(index) == 2:
+            if not isinstance(index[0], int) or not (isinstance(index[1], str) or isinstance(index[1], int)):
+                raise excs.Error(f'Bad index, expected [<row idx>, <column name | column index>]: {index}')
+            if isinstance(index[1], str) and index[1] not in self._col_names:
+                raise excs.Error(f'Invalid column name: {index[1]}')
+            col_idx = self._col_names.index(index[1]) if isinstance(index[1], str) else index[1]
+            return self._rows[index[0]][col_idx]
+        raise excs.Error(f'Bad index: {index}')
+
+    def __iter__(self) -> DataFrameResultSetIterator:
+        return DataFrameResultSetIterator(self)
+
+    def __eq__(self, other):
+        if not isinstance(other, DataFrameResultSet):
+            return False
+        return self.to_pandas().equals(other.to_pandas())
+
+
+class DataFrameResultSetIterator:
+    def __init__(self, result_set: DataFrameResultSet):
+        self._result_set = result_set
+        self._idx = 0
+
+    def __next__(self) -> Dict[str, Any]:
+        if self._idx >= len(self._result_set):
+            raise StopIteration
+        row = self._result_set._row_to_dict(self._idx)
+        self._idx += 1
+        return row
+
+
+# TODO: remove this; it's only here as a reminder that we still need to call release() in the current implementation
 class AnalysisInfo:
-    def __init__(self):
+    def __init__(self, tbl: catalog.TableVersion):
+        self.tbl = tbl
         # output of the SQL scan stage
         self.sql_scan_output_exprs: List[exprs.Expr] = []
         # output of the agg stage
         self.agg_output_exprs: List[exprs.Expr] = []
-        # select list providing the input to the SQL scan stage
-        self.sql_select_list: List[sql.sql.expression.ClauseElement] = []
         # Where clause of the Select stmt of the SQL scan stage
-        self.sql_where_clause: Optional[sql.…
+        self.sql_where_clause: Optional[sql.ClauseElement] = None
         # filter predicate applied to input rows of the SQL scan stage
         self.filter: Optional[exprs.Predicate] = None
         self.similarity_clause: Optional[exprs.ImageSimilarityPredicate] = None
         self.agg_fn_calls: List[exprs.FunctionCall] = []  # derived from unique_exprs
+        self.has_frame_col: bool = False  # True if we're referencing the frame col
 
-        self.…
-        self.…
+        self.evaluator: Optional[exprs.Evaluator] = None
+        self.sql_scan_eval_ctx: List[exprs.Expr] = []  # needed to materialize output of SQL scan stage
+        self.agg_eval_ctx: List[exprs.Expr] = []  # needed to materialize output of agg stage
+        self.filter_eval_ctx: List[exprs.Expr] = []
+        self.group_by_eval_ctx: List[exprs.Expr] = []
 
-
-    def num_materialized(self) -> int:
-        return self.next_data_row_idx
-
-    def assign_idxs(self, expr_list: List[exprs.Expr]) -> None:
+    def finalize_exec(self) -> None:
         """
-        …
-        An expr with to_sql() != None is assumed to be materialized fully via SQL; its components
-        aren't materialized and don't receive idxs.
+        Call release() on all collected Exprs.
         """
-
-
-
-
-    def _assign_idxs_aux(self, expr: exprs.Expr) -> None:
-        if not self.unique_exprs.add(expr):
-            # nothing left to do
-            return
-
-        sql_expr = expr.sql_expr()
-        # if this can be materialized via SQL we don't need to look at its components;
-        # we special-case Literals because we don't want to have to materialize them via SQL
-        if sql_expr is not None and not isinstance(expr, exprs.Literal):
-            assert expr.data_row_idx < 0
-            expr.data_row_idx = self.next_data_row_idx
-            self.next_data_row_idx += 1
-            expr.sql_row_idx = len(self.sql_select_list)
-            self.sql_select_list.append(sql_expr)
-            return
+        exprs.Expr.release_list(self.sql_scan_output_exprs)
+        exprs.Expr.release_list(self.agg_output_exprs)
+        if self.filter is not None:
+            self.filter.release()
 
-        # expr value needs to be computed via Expr.eval()
-        for c in expr.components:
-            self._assign_idxs_aux(c)
-        assert expr.data_row_idx < 0
-        expr.data_row_idx = self.next_data_row_idx
-        self.next_data_row_idx += 1
 
 
 class DataFrame:
     def __init__(
-            self, tbl: catalog.…
-            select_list: Optional[List[exprs.Expr]] = None,
-            where_clause: Optional[exprs.Predicate] = None
+            self, tbl: catalog.TableVersionPath,
+            select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]] = None,
+            where_clause: Optional[exprs.Predicate] = None,
+            group_by_clause: Optional[List[exprs.Expr]] = None,
+            grouping_tbl: Optional[catalog.TableVersion] = None,
+            order_by_clause: Optional[List[Tuple[exprs.Expr, bool]]] = None,  # List[(expr, asc)]
+            limit: Optional[int] = None):
         self.tbl = tbl
-        # self.select_list and self.where_clause contain execution state and therefore cannot be shared
-        self.select_list: Optional[List[exprs.Expr]] = None  # None: implies all cols
-        if select_list is not None:
-            self.select_list = [e.copy() for e in select_list]
-        self.where_clause: Optional[exprs.Predicate] = None
-        if where_clause is not None:
-            self.where_clause = where_clause.copy()
-        self.group_by_clause: Optional[List[exprs.Expr]] = None
-        self.analysis_info: Optional[AnalysisInfo] = None
-
-    def analyze(self) -> None:
-        """
-        Populates self.analysis_info.
-        """
-        info = self.analysis_info = AnalysisInfo()
-        if self.where_clause is not None:
-            info.sql_where_clause, info.filter = self.where_clause.extract_sql_predicate()
-            if info.filter is not None:
-                similarity_clauses, info.filter = info.filter.split_conjuncts(
-                    lambda e: isinstance(e, exprs.ImageSimilarityPredicate))
-                if len(similarity_clauses) > 1:
-                    raise exc.OperationalError(f'More than one nearest() or matches() not supported')
-                if len(similarity_clauses) == 1:
-                    info.similarity_clause = similarity_clauses[0]
-                    img_col = info.similarity_clause.img_col_ref.col
-                    if not img_col.is_indexed:
-                        raise exc.OperationalError(
-                            f'nearest()/matches() not available for unindexed column {img_col.name}')
-
-        if info.filter is not None:
-            info.assign_idxs([info.filter])
-        if len(self.group_by_clause) > 0:
-            info.assign_idxs(self.group_by_clause)
-            for e in self.group_by_clause:
-                self._analyze_group_by(e, True)
-        info.assign_idxs(self.select_list)
-        grouping_expr_idxs = set([e.data_row_idx for e in self.group_by_clause])
-        item_is_agg = [self._analyze_select_list(e, grouping_expr_idxs)[0] for e in self.select_list]
-
-        if self.is_agg():
-            # this is an aggregation
-            if item_is_agg.count(False) > 0:
-                raise exc.Error(f'Invalid non-aggregate in select list: {self.select_list[item_is_agg.find(False)]}')
-            # the agg stage materializes select list items that haven't already been provided by SQL
-            info.agg_output_exprs = [e for e in self.select_list if e.sql_row_idx == -1]
-            # our sql scan stage needs to materialize: grouping exprs, arguments of agg fn calls
-            info.sql_scan_output_exprs = copy.copy(self.group_by_clause)
-            unique_args: Set[int] = set()
-            for fn_call in info.agg_fn_calls:
-                for c in fn_call.components:
-                    unique_args.add(c.data_row_idx)
-            all_exprs = {e.data_row_idx: e for e in info.unique_exprs}
-            info.sql_scan_output_exprs.extend([all_exprs[idx] for idx in unique_args])
-        else:
-            info.sql_scan_output_exprs = self.select_list
 
-
-
-
-
-
-
-
-
+        # select list logic
+        DataFrame._select_list_check_rep(select_list)  # check select list without expansion
+        # exprs contain execution state and therefore cannot be shared
+        select_list = copy.deepcopy(select_list)
+        select_list_exprs, column_names = DataFrame._normalize_select_list(tbl, select_list)
+        DataFrame._select_list_check_rep(list(zip(select_list_exprs, column_names)))
+        # check select list after expansion to catch early
+        # the following two lists are always non empty, even if select list is None.
+        self._select_list_exprs = select_list_exprs
+        self._column_names = column_names
+        self.select_list = select_list
+
+        self.where_clause = copy.deepcopy(where_clause)
+        assert group_by_clause is None or grouping_tbl is None
+        self.group_by_clause = copy.deepcopy(group_by_clause)
+        self.grouping_tbl = grouping_tbl
+        self.order_by_clause = copy.deepcopy(order_by_clause)
+        self.limit_val = limit
+
+    @classmethod
+    def _select_list_check_rep(cls,
+                               select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]],
+                               ) -> None:
+        """Validate basic select list types.
         """
-
-
-
-
-
-
-
-
-
-
+        if select_list is None:  # basic check for valid select list
+            return
+
+        assert len(select_list) > 0
+        for ent in select_list:
+            assert isinstance(ent, tuple)
+            assert len(ent) == 2
+            assert isinstance(ent[0], exprs.Expr)
+            assert ent[1] is None or isinstance(ent[1], str)
+            if isinstance(ent[1], str):
+                assert is_valid_identifier(ent[1])
+
+    @classmethod
+    def _normalize_select_list(cls,
+                               tbl: catalog.TableVersionPath,
+                               select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]],
+                               ) -> Tuple[List[exprs.Expr], List[str]]:
        """
-        …
-        …
+        Expand select list information with all columns and their names
+        Returns:
+            a pair composed of the list of expressions and the list of corresponding names
         """
-        if …
-
-        elif self._is_agg_fn_call(e):
-            for c in e.components:
-                _, is_scan_output = self._analyze_select_list(c, grouping_exprs)
-                if not is_scan_output:
-                    raise exc.Error(f'Invalid nested aggregates: {e}')
-            return True, False
-        elif isinstance(e, exprs.Literal):
-            return True, True
-        elif isinstance(e, exprs.ColumnRef):
-            # we already know that this isn't a grouping expr
-            return False, True
+        if select_list is None:
+            expanded_list = [(exprs.ColumnRef(col), None) for col in tbl.columns()]
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            expanded_list = select_list
+
+        out_exprs : List[exprs.Expr] = []
+        out_names : List[str] = []  # keep track of order
+        seen_out_names : set[str] = set()  # use to check for duplicates in loop, avoid square complexity
+        for i, (expr, name) in enumerate(expanded_list):
+            if name is None:
+                # use default, add suffix if needed so default adds no duplicates
+                default_name = expr.default_column_name()
+                if default_name is not None:
+                    column_name = default_name
+                    if default_name in seen_out_names:
+                        # already used, then add suffix until unique name is found
+                        for j in range(1, len(out_names)+1):
+                            column_name = f'{default_name}_{j}'
+                            if column_name not in seen_out_names:
+                                break
+                else:  # no default name, eg some expressions
+                    column_name = f'col_{i}'
+            else:  # user provided name, no attempt to rename
+                column_name = name
+
+            out_exprs.append(expr)
+            out_names.append(column_name)
+            seen_out_names.add(column_name)
+        assert len(out_exprs) == len(out_names)
+        assert set(out_names) == seen_out_names
+        return out_exprs, out_names
+
+    def _exec(self) -> Generator[exprs.DataRow, None, None]:
+        """Run the query and return rows as a generator.
+        This function must not modify the state of the DataFrame, otherwise it breaks dataset caching.
         """
-        if …
-
-        if self.…
-            self.group_by_clause …
-
+        # construct a group-by clause if we're grouping by a table
+        group_by_clause: List[exprs.Expr] = []
+        if self.grouping_tbl is not None:
+            assert self.group_by_clause is None
+            num_rowid_cols = len(self.grouping_tbl.store_tbl.rowid_columns())
+            # the grouping table must be a base of self.tbl
+            assert num_rowid_cols <= len(self.tbl.tbl_version.store_tbl.rowid_columns())
+            group_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
+        elif self.group_by_clause is not None:
+            group_by_clause = self.group_by_clause
+
+        for item in self._select_list_exprs:
             item.bind_rel_paths(None)
-
-        self.…
-
-
-
-
-
-
-
-
-
-
-
-
-
-            order_by_exprs = window_fn_calls[0].get_window_sort_exprs()
-        elif self.is_agg():
-            # TODO: collect aggs with order-by and analyze for compatibility
-            order_by_exprs = self.group_by_clause + self.analysis_info.agg_fn_calls[0].get_agg_order_by()
-        order_by_clause = [e.sql_expr() for e in order_by_exprs]
-        for i in range(len(order_by_exprs)):
-            if order_by_clause[i] is None:
-                raise exc.Error(f'order_by element cannot be expressed in SQL: {order_by_exprs[i]}')
-
-        idx_rowids: List[int] = []  # rowids returned by index lookup
-        if self.analysis_info.similarity_clause is not None:
-            # do index lookup
-            assert self.analysis_info.similarity_clause.img_col_ref.col.idx is not None
-            embed = self.analysis_info.similarity_clause.embedding()
-            idx_rowids = self.analysis_info.similarity_clause.img_col_ref.col.idx.search(embed, n, self.tbl.valid_rowids)
-
-        with Env.get().engine.connect() as conn:
-            stmt = self._create_select_stmt(
-                self.analysis_info.sql_select_list, self.analysis_info.sql_where_clause, idx_rowids, select_pk,
-                order_by_clause)
-            num_rows = 0
-            sql_scan_evaluator = exprs.ExprEvaluator(
-                self.analysis_info.sql_scan_output_exprs, self.analysis_info.filter)
-            agg_evaluator = exprs.ExprEvaluator(self.analysis_info.agg_output_exprs, None)
-
-            current_group: Optional[List[Any]] = None  # for grouping agg, the values of the group-by exprs
-            for row in conn.execute(stmt):
-                sql_row = row._data
-                data_row: List[Any] = [None] * self.analysis_info.num_materialized
-                if not sql_scan_evaluator.eval(sql_row, data_row):
-                    continue
-
-                # copy select list results into contiguous array
-                result_row: Optional[List[Any]] = None
-                if self.is_agg():
-                    group = [data_row[e.data_row_idx] for e in self.group_by_clause]
-                    if current_group is None:
-                        current_group = group
-                    if group != current_group:
-                        # we're entering a new group, emit a row for the last one
-                        agg_evaluator.eval(last_sql_row, last_data_row)
-                        result_row = [last_data_row[e.data_row_idx] for e in self.select_list]
-                        current_group = group
-                        for fn_call in self.analysis_info.agg_fn_calls:
-                            fn_call.reset_agg()
-                    for fn_call in self.analysis_info.agg_fn_calls:
-                        fn_call.update(data_row)
-                else:
-                    result_row = [data_row[e.data_row_idx] for e in self.select_list]
-                    if select_pk:
-                        result_row.extend(sql_row[-2:])
-
-                last_data_row = data_row
-                last_sql_row = row._data
-                if result_row is not None:
-                    yield result_row
-                    num_rows += 1
-                    if n > 0 and num_rows == n:
-                        break
-
-            if self.is_agg():
-                # we need to emit the output row for the current group
-                agg_evaluator.eval(sql_row, data_row)
-                result_row = [data_row[e.data_row_idx] for e in self.select_list]
-                yield result_row
+        plan = Planner.create_query_plan(
+            self.tbl, self._select_list_exprs, where_clause=self.where_clause, group_by_clause=group_by_clause,
+            order_by_clause=self.order_by_clause if self.order_by_clause is not None else [],
+            limit=self.limit_val if self.limit_val is not None else 0)  # limit_val == 0: no limit_val
+
+        with Env.get().engine.begin() as conn:
+            plan.ctx.conn = conn
+            plan.open()
+            try:
+                for row_batch in plan:
+                    for data_row in row_batch:
+                        yield data_row
+            finally:
+                plan.close()
+        return
 
     def show(self, n: int = 20) -> DataFrameResultSet:
-
-
-
-
-
+        assert n is not None
+        return self.limit(n).collect()
+
+    def head(self, n: int = 10) -> DataFrameResultSet:
+        if self.order_by_clause is not None:
+            raise excs.Error(f'head() cannot be used with order_by()')
+        num_rowid_cols = len(self.tbl.tbl_version.store_tbl.rowid_columns())
+        order_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
+        return self.order_by(*order_by_clause, asc=True).limit(n).collect()
+
+    def tail(self, n: int = 10) -> DataFrameResultSet:
+        if self.order_by_clause is not None:
+            raise excs.Error(f'tail() cannot be used with order_by()')
+        num_rowid_cols = len(self.tbl.tbl_version.store_tbl.rowid_columns())
+        order_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
+        result = self.order_by(*order_by_clause, asc=False).limit(n).collect()
+        result._reverse()
+        return result
+
+    def get_column_names(self) -> List[str]:
+        return self._column_names
+
+    def get_column_types(self) -> List[ColumnType]:
+        return [expr.col_type for expr in self._select_list_exprs]
+
+    def collect(self) -> DataFrameResultSet:
+        try:
+            result_rows = []
+            for data_row in self._exec():
+                result_row = [data_row[e.slot_idx] for e in self._select_list_exprs]
+                result_rows.append(result_row)
+        except excs.ExprEvalError as e:
+            msg = (f'In row {e.row_num} the {e.expr_msg} encountered exception '
+                   f'{type(e.exc).__name__}:\n{str(e.exc)}')
+            if len(e.input_vals) > 0:
+                input_msgs = [
+                    f"'{d}' = {d.col_type.print_value(e.input_vals[i])}"
+                    for i, d in enumerate(e.expr.dependencies())
+                ]
+                msg += f'\nwith {", ".join(input_msgs)}'
+            assert e.exc_tb is not None
+            stack_trace = traceback.format_tb(e.exc_tb)
+            if len(stack_trace) > 2:
+                # append a stack trace if the exception happened in user code
+                # (frame 0 is ExprEvaluator and frame 1 is some expr's eval()
+                nl = '\n'
+                # [-1:0:-1]: leave out entry 0 and reverse order, so that the most recent frame is at the top
+                msg += f'\nStack:\n{nl.join(stack_trace[-1:1:-1])}'
+            raise excs.Error(msg)
+        except sql.exc.DBAPIError as e:
+            raise excs.Error(f'Error during SQL execution:\n{e}')
+
+        col_types = self.get_column_types()
+        return DataFrameResultSet(result_rows, self._column_names, col_types)
 
     def count(self) -> int:
-
-
-        """
-        stmt = sql.select(sql.func.count('*')).select_from(self.tbl.sa_tbl) \
-            .where(self.tbl.v_min_col <= self.tbl.version) \
-            .where(self.tbl.v_max_col > self.tbl.version)
-        if self.where_clause is not None:
-            sql_where_clause = self.where_clause.sql_expr()
-            assert sql_where_clause is not None
-            stmt = stmt.where(sql_where_clause)
+        from pixeltable.plan import Planner
+        stmt = Planner.create_count_stmt(self.tbl, self.where_clause)
         with Env.get().engine.connect() as conn:
             result: int = conn.execute(stmt).scalar_one()
             assert isinstance(result, int)
             return result
 
-    def …
+    def _description(self) -> pd.DataFrame:
+        """see DataFrame.describe()"""
+        heading_vals: List[str] = []
+        info_vals: List[str] = []
+        if self.select_list is not None:
+            assert len(self.select_list) > 0
+            heading_vals.append('Select')
+            heading_vals.extend([''] * (len(self.select_list) - 1))
+            info_vals.extend(self.get_column_names())
+        if self.where_clause is not None:
+            heading_vals.append('Where')
+            info_vals.append(self.where_clause.display_str(inline=False))
+        if self.group_by_clause is not None:
+            heading_vals.append('Group By')
+            heading_vals.extend([''] * (len(self.group_by_clause) - 1))
+            info_vals.extend([e.display_str(inline=False) for e in self.group_by_clause])
+        if self.order_by_clause is not None:
+            heading_vals.append('Order By')
+            heading_vals.extend([''] * (len(self.order_by_clause) - 1))
+            info_vals.extend([
+                f'{e[0].display_str(inline=False)} {"asc" if e[1] else "desc"}' for e in self.order_by_clause
+            ])
+        if self.limit_val is not None:
+            heading_vals.append('Limit')
+            info_vals.append(str(self.limit_val))
+        assert len(heading_vals) > 0
+        assert len(info_vals) > 0
+        assert len(heading_vals) == len(info_vals)
+        return pd.DataFrame({'Heading': heading_vals, 'Info': info_vals})
+
+    def _description_html(self) -> pandas.io.formats.style.Styler:
+        """Return the description in an ipython-friendly manner."""
+        pd_df = self._description()
+        # white-space: pre-wrap: print \n as newline
+        # th: center-align headings
+        return pd_df.style.set_properties(**{'white-space': 'pre-wrap', 'text-align': 'left'}) \
+            .set_table_styles([dict(selector='th', props=[('text-align', 'center')])]) \
+            .hide(axis='index').hide(axis='columns')
+
+    def describe(self) -> None:
         """
-        …
-        …
+        Prints a tabular description of this DataFrame.
+        The description has two columns, heading and info, which list the contents of each 'component'
+        (select list, where clause, ...) vertically.
         """
-
-
-
-
-
-
-        stmt = sql.select(sql.distinct(col.sa_col)) \
-            .where(self.tbl.v_min_col <= self.tbl.version) \
-            .where(self.tbl.v_max_col > self.tbl.version) \
-            .order_by(col.sa_col)
-        if self.where_clause is not None:
-            sql_where_clause = self.where_clause.sql_expr()
-            assert sql_where_clause is not None
-            stmt = stmt.where(sql_where_clause)
-        with Env.get().engine.connect() as conn:
-            result = {row._data[0]: i for i, row in enumerate(conn.execute(stmt))}
-            return result
+        try:
+            __IPYTHON__
+            from IPython.display import display
+            display(self._description_html())
+        except NameError:
+            print(self.__repr__())
 
-    def …
+    def __repr__(self) -> str:
+        return self._description().to_string(header=False, index=False)
+
+    def _repr_html_(self) -> str:
+        return self._description_html()._repr_html_()
+
+    def select(self, *items: Any, **named_items : Any) -> DataFrame:
+        if self.select_list is not None:
+            raise excs.Error(f'Select list already specified')
+        for (name, _) in named_items.items():
+            if not isinstance(name, str) or not is_valid_identifier(name):
+                raise excs.Error(f'Invalid name: {name}')
+        base_list = [(expr, None) for expr in items] + [(expr, k) for (k, expr) in named_items.items()]
+        if len(base_list) == 0:
+            raise excs.Error(f'Empty select list')
+
+        # analyze select list; wrap literals with the corresponding expressions
+        select_list = []
+        for raw_expr, name in base_list:
+            if isinstance(raw_expr, exprs.Expr):
+                select_list.append((raw_expr, name))
+            elif isinstance(raw_expr, dict):
+                select_list.append((exprs.InlineDict(raw_expr), name))
+            elif isinstance(raw_expr, list):
+                select_list.append((exprs.InlineArray(raw_expr), name))
+            else:
+                select_list.append((exprs.Literal(raw_expr), name))
+            expr = select_list[-1][0]
+            if expr.col_type.is_invalid_type():
+                raise excs.Error(f'Invalid type: {raw_expr}')
+            # TODO: check that ColumnRefs in expr refer to self.tbl
+
+        # check user provided names do not conflict among themselves
+        # or with auto-generated ones
+        seen: Set[str] = set()
+        _, names = DataFrame._normalize_select_list(self.tbl, select_list)
+        for name in names:
+            if name in seen:
+                repeated_names = [j for j, x in enumerate(names) if x == name]
+                pretty = ', '.join(map(str, repeated_names))
+                raise excs.Error(f'Repeated column name "{name}" in select() at positions: {pretty}')
+            seen.add(name)
+
+        return DataFrame(
+            self.tbl, select_list=select_list, where_clause=self.where_clause, group_by_clause=self.group_by_clause,
+            grouping_tbl=self.grouping_tbl, order_by_clause=self.order_by_clause, limit=self.limit_val)
+
+    def where(self, pred: exprs.Predicate) -> DataFrame:
+        return DataFrame(
+            self.tbl, select_list=self.select_list, where_clause=pred, group_by_clause=self.group_by_clause,
+            grouping_tbl=self.grouping_tbl, order_by_clause=self.order_by_clause, limit=self.limit_val)
+
+    def group_by(self, *grouping_items: Any) -> DataFrame:
+        """Add a group-by clause to this DataFrame.
+        Variants:
+        - group_by(<base table>): group a component view by their respective base table rows
+        - group_by(<expr>, ...): group by the given expressions
+        """
+        if self.group_by_clause is not None:
+            raise excs.Error(f'Group-by already specified')
+        grouping_tbl: Optional[catalog.TableVersion] = None
+        group_by_clause: Optional[List[exprs.Expr]] = None
+        for item in grouping_items:
+            if isinstance(item, catalog.Table):
+                if len(grouping_items) > 1:
+                    raise excs.Error(f'group_by(): only one table can be specified')
+                # we need to make sure that the grouping table is a base of self.tbl
+                base = self.tbl.find_tbl_version(item.tbl_version_path.tbl_id())
+                if base is None or base.id == self.tbl.tbl_id():
+                    raise excs.Error(f'group_by(): {item.name} is not a base table of {self.tbl.tbl_name()}')
+                grouping_tbl = item.tbl_version_path.tbl_version
+                break
+            if not isinstance(item, exprs.Expr):
+                raise excs.Error(f'Invalid expression in group_by(): {item}')
+        if grouping_tbl is None:
+            group_by_clause = list(grouping_items)
+        return DataFrame(
+            self.tbl, select_list=self.select_list, where_clause=self.where_clause, group_by_clause=group_by_clause,
+            grouping_tbl=grouping_tbl, order_by_clause=self.order_by_clause, limit=self.limit_val)
+
+    def order_by(self, *expr_list: exprs.Expr, asc: bool = True) -> DataFrame:
+        for e in expr_list:
+            if not isinstance(e, exprs.Expr):
+                raise excs.Error(f'Invalid expression in order_by(): {e}')
+        order_by_clause = self.order_by_clause if self.order_by_clause is not None else []
+        order_by_clause.extend([(e.copy(), asc) for e in expr_list])
+        return DataFrame(
+            self.tbl, select_list=self.select_list, where_clause=self.where_clause,
+            group_by_clause=self.group_by_clause, grouping_tbl=self.grouping_tbl, order_by_clause=order_by_clause,
+            limit=self.limit_val)
+
+    def limit(self, n: int) -> DataFrame:
+        assert n is not None and isinstance(n, int)
+        return DataFrame(
+            self.tbl, select_list=self.select_list, where_clause=self.where_clause,
+            group_by_clause=self.group_by_clause, grouping_tbl=self.grouping_tbl, order_by_clause=self.order_by_clause,
+            limit=n)
+
+    def __getitem__(self, index: object) -> DataFrame:
         """
         Allowed:
         - [<Predicate>]: filter operation
@@ -389,52 +519,113 @@ class DataFrame:
         - [Expr]: setting a single-col select list
         """
         if isinstance(index, exprs.Predicate):
-            return …
+            return self.where(index)
         if isinstance(index, tuple):
             index = list(index)
         if isinstance(index, exprs.Expr):
             index = [index]
         if isinstance(index, list):
-
-                raise exc.OperationalError(f'[] for column selection is only allowed once')
-            # analyze select list; wrap literals with the corresponding expressions and update it in place
-            for i in range(len(index)):
-                expr = index[i]
-                if isinstance(expr, dict):
-                    index[i] = expr = exprs.InlineDict(expr)
-                if isinstance(expr, list):
-                    index[i] = expr = exprs.InlineArray(tuple(expr))
-                if not isinstance(expr, exprs.Expr):
-                    raise exc.OperationalError(f'Invalid expression in []: {expr}')
-                if expr.col_type.is_invalid_type():
-                    raise exc.OperationalError(f'Invalid type: {expr}')
-                # TODO: check that ColumnRefs in expr refer to self.tbl
-            return DataFrame(self.tbl, select_list=index, where_clause=self.where_clause)
+            return self.select(*index)
         raise TypeError(f'Invalid index type: {type(index)}')
+
+    def _as_dict(self) -> Dict[str, Any]:
+        """
+        Returns:
+            Dictionary representing this dataframe.
+        """
+        tbl_versions = self.tbl.get_tbl_versions()
+        d = {
+            '_classname': 'DataFrame',
+            'tbl_ids': [str(t.id) for t in tbl_versions],
+            'tbl_versions': [t.version for t in tbl_versions],
+            'select_list':
+                [(e.as_dict(), name) for (e, name) in self.select_list] if self.select_list is not None else None,
+            'where_clause': self.where_clause.as_dict() if self.where_clause is not None else None,
+            'group_by_clause':
+                [e.as_dict() for e in self.group_by_clause] if self.group_by_clause is not None else None,
+            'order_by_clause':
+                [(e.as_dict(), asc) for (e,asc) in self.order_by_clause] if self.order_by_clause is not None else None,
+            'limit_val': self.limit_val,
+        }
+        return d
+
+    def to_coco_dataset(self) -> Path:
+        """Convert the dataframe to a COCO dataset.
+        This dataframe must return a single json-typed output column in the following format:
+        {
+            'image': PIL.Image.Image,
+            'annotations': [
+                {
+                    'bbox': [x: int, y: int, w: int, h: int],
+                    'category': str | int,
+                },
+                ...
+            ],
+        }
+
+        Returns:
+            Path to the COCO dataset file.
+        """
+        from pixeltable.utils.coco import write_coco_dataset
+
+        summary_string = json.dumps(self._as_dict())
+        cache_key = hashlib.sha256(summary_string.encode()).hexdigest()
+
+        dest_path = (Env.get().dataset_cache_dir / f'coco_{cache_key}')
+        if dest_path.exists():
+            assert dest_path.is_dir()
+            data_file_path = dest_path / 'data.json'
+            assert data_file_path.exists()
+            assert data_file_path.is_file()
+            return data_file_path
+        else:
+            return write_coco_dataset(self, dest_path)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # TODO Factor this out into a separate module.
+    # The return type is unresolvable, but torch can't be imported since it's an optional dependency.
+    def to_pytorch_dataset(self, image_format: str = 'pt') -> 'torch.utils.data.IterableDataset':
+        """
+        Convert the dataframe to a pytorch IterableDataset suitable for parallel loading
+        with torch.utils.data.DataLoader.
+
+        This method requires pyarrow >= 13, torch and torchvision to work.
+
+        This method serializes data so it can be read from disk efficiently and repeatedly without
+        re-executing the query. This data is cached to disk for future re-use.
+
+        Args:
+            image_format: format of the images. Can be 'pt' (pytorch tensor) or 'np' (numpy array).
+                'np' means image columns return as an RGB uint8 array of shape HxWxC.
+                'pt' means image columns return as a CxHxW tensor with values in [0,1] and type torch.float32.
+                (the format output by torchvision.transforms.ToTensor())
+
+        Returns:
+            A pytorch IterableDataset: Columns become fields of the dataset, where rows are returned as a dictionary
+            compatible with torch.utils.data.DataLoader default collation.
+
+        Constraints:
+            The default collate_fn for torch.data.util.DataLoader cannot represent null values as part of a
+            pytorch tensor when forming batches. These values will raise an exception while running the dataloader.
+
+            If you have them, you can work around None values by providing your custom collate_fn to the DataLoader
+            (and have your model handle it). Or, if these are not meaningful values within a minibtach, you can
+            modify or remove any such values through selections and filters prior to calling to_pytorch_dataset().
+        """
+        # check dependencies
+        Env.get().require_package('pyarrow', [13])
+        Env.get().require_package('torch')
+        Env.get().require_package('torchvision')
+
+        from pixeltable.utils.parquet import save_parquet  # pylint: disable=import-outside-toplevel
+        from pixeltable.utils.pytorch import PixeltablePytorchDataset  # pylint: disable=import-outside-toplevel
+
+        summary_string = json.dumps(self._as_dict())
+        cache_key = hashlib.sha256(summary_string.encode()).hexdigest()
+
+        dest_path = (Env.get().dataset_cache_dir / f'df_{cache_key}').with_suffix('.parquet')  # pylint: disable = protected-access
+        if dest_path.exists():  # fast path: use cache
+            assert dest_path.is_dir()
+        else:
+            save_parquet(self, dest_path)
+
+        return PixeltablePytorchDataset(path=dest_path, image_format=image_format)
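Taken together, the dataframe.py rewrite replaces the old analyze()/show() pipeline with a fluent, immutable builder: select(), where(), group_by(), order_by(), and limit() each return a new DataFrame, and collect() executes it through Planner.create_query_plan(). A minimal usage sketch follows; the DataFrame and DataFrameResultSet methods are the ones shown in the diff above, but the client entry point, the table handle t, and its columns title and year are assumptions (hypothetical names, based only on pixeltable/client.py being the public entry point in this release):

    import pixeltable as pxt

    cl = pxt.Client()                          # assumption: client API from pixeltable/client.py
    t = cl.get_table('movies')                 # hypothetical table with columns 'title' and 'year';
                                               # assumes the handle exposes the DataFrame builder methods

    res = (
        t.where(t.year > 2000)                 # Predicate -> new DataFrame (where() in the diff)
         .select(t.title, t.year)              # builds the (expr, name) select list
         .order_by(t.year, asc=False)          # appends (expr, asc) tuples to the order-by clause
         .limit(10)
         .collect()                            # runs Planner.create_query_plan(), returns DataFrameResultSet
    )

    row = res[0]                               # int index -> dict of column name to value
    years = res['year']                        # str index -> list of values for that column
    for r in res:                              # iteration yields one dict per row
        print(r['title'])

The new dataset-export path works on the same object; per the docstring in the diff it requires pyarrow >= 13, torch, and torchvision, and caches the serialized query under dataset_cache_dir, keyed by a hash of the query's _as_dict() representation:

    import torch.utils.data

    ds = t.select(t.title).to_pytorch_dataset(image_format='pt')   # cached as parquet on first call
    loader = torch.utils.data.DataLoader(ds, batch_size=32)        # rows arrive as dicts under default collation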