pixeltable 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry (PyPI). It is provided for informational purposes only.
- pixeltable/__init__.py +1 -2
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/catalog.py +509 -103
- pixeltable/catalog/column.py +5 -0
- pixeltable/catalog/dir.py +15 -6
- pixeltable/catalog/globals.py +16 -0
- pixeltable/catalog/insertable_table.py +82 -41
- pixeltable/catalog/path.py +15 -0
- pixeltable/catalog/schema_object.py +7 -12
- pixeltable/catalog/table.py +81 -67
- pixeltable/catalog/table_version.py +23 -7
- pixeltable/catalog/view.py +9 -6
- pixeltable/env.py +15 -9
- pixeltable/exec/exec_node.py +1 -1
- pixeltable/exprs/__init__.py +2 -1
- pixeltable/exprs/arithmetic_expr.py +2 -0
- pixeltable/exprs/column_ref.py +38 -2
- pixeltable/exprs/expr.py +61 -12
- pixeltable/exprs/function_call.py +1 -4
- pixeltable/exprs/globals.py +12 -0
- pixeltable/exprs/json_mapper.py +4 -4
- pixeltable/exprs/json_path.py +10 -11
- pixeltable/exprs/similarity_expr.py +5 -20
- pixeltable/exprs/string_op.py +107 -0
- pixeltable/ext/functions/yolox.py +21 -64
- pixeltable/func/callable_function.py +5 -2
- pixeltable/func/query_template_function.py +6 -18
- pixeltable/func/tools.py +2 -2
- pixeltable/functions/__init__.py +1 -1
- pixeltable/functions/globals.py +16 -5
- pixeltable/globals.py +172 -262
- pixeltable/io/__init__.py +3 -2
- pixeltable/io/datarows.py +138 -0
- pixeltable/io/external_store.py +8 -5
- pixeltable/io/globals.py +7 -160
- pixeltable/io/hf_datasets.py +21 -98
- pixeltable/io/pandas.py +29 -43
- pixeltable/io/parquet.py +17 -42
- pixeltable/io/table_data_conduit.py +569 -0
- pixeltable/io/utils.py +6 -21
- pixeltable/metadata/__init__.py +1 -1
- pixeltable/metadata/converters/convert_30.py +50 -0
- pixeltable/metadata/converters/util.py +26 -1
- pixeltable/metadata/notes.py +1 -0
- pixeltable/metadata/schema.py +3 -0
- pixeltable/utils/arrow.py +32 -7
- pixeltable/utils/coroutine.py +41 -0
- {pixeltable-0.3.8.dist-info → pixeltable-0.3.10.dist-info}/METADATA +1 -1
- {pixeltable-0.3.8.dist-info → pixeltable-0.3.10.dist-info}/RECORD +52 -47
- {pixeltable-0.3.8.dist-info → pixeltable-0.3.10.dist-info}/WHEEL +1 -1
- {pixeltable-0.3.8.dist-info → pixeltable-0.3.10.dist-info}/LICENSE +0 -0
- {pixeltable-0.3.8.dist-info → pixeltable-0.3.10.dist-info}/entry_points.txt +0 -0
pixeltable/io/table_data_conduit.py
ADDED
@@ -0,0 +1,569 @@
from __future__ import annotations

import enum
import json
import logging
import math
import urllib.parse
import urllib.request
from dataclasses import dataclass, field, fields
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, Optional, Union, cast

import pandas as pd
from pyarrow.parquet import ParquetDataset

import pixeltable as pxt
import pixeltable.exceptions as excs
from pixeltable.io.pandas import _df_check_primary_key_values, _df_row_to_pxt_row, df_infer_schema
from pixeltable.utils import parse_local_file_path

from .utils import normalize_schema_names

_logger = logging.getLogger('pixeltable')

# ---------------------------------------------------------------------------------------------------------

if TYPE_CHECKING:
    import datasets  # type: ignore[import-untyped]

    from pixeltable.globals import RowData, TableDataSource


class TableDataConduitFormat(str, enum.Enum):
    """Supported formats for TableDataConduit"""

    JSON = 'json'
    CSV = 'csv'
    EXCEL = 'excel'
    PARQUET = 'parquet'

    @classmethod
    def is_valid(cls, x: Any) -> bool:
        if isinstance(x, str):
            return x.lower() in [c.value for c in cls]
        return False


# ---------------------------------------------------------------------------------------------------------


@dataclass
class TableDataConduit:
    source: TableDataSource
    source_format: Optional[str] = None
    source_column_map: Optional[dict[str, str]] = None
    if_row_exists: Literal['update', 'ignore', 'error'] = 'error'
    pxt_schema: Optional[dict[str, Any]] = None
    src_schema_overrides: Optional[dict[str, Any]] = None
    src_schema: Optional[dict[str, Any]] = None
    pxt_pk: Optional[list[str]] = None
    src_pk: Optional[list[str]] = None
    valid_rows: Optional[RowData] = None
    extra_fields: dict[str, Any] = field(default_factory=dict)

    reqd_col_names: set[str] = field(default_factory=set)
    computed_col_names: set[str] = field(default_factory=set)

    total_rows: int = 0  # total number of rows emitted via valid_row_batch Iterator

    _K_BATCH_SIZE_BYTES = 100_000_000  # 100 MB

    def check_source_format(self) -> None:
        assert self.source_format is None or TableDataConduitFormat.is_valid(self.source_format)

    @classmethod
    def is_rowdata_structure(cls, d: TableDataSource) -> bool:
        if not isinstance(d, list) or len(d) == 0:
            return False
        return all(isinstance(row, dict) for row in d)

    def is_direct_df(self) -> bool:
        return isinstance(self.source, pxt.DataFrame) and self.source_column_map is None

    def normalize_pxt_schema_types(self) -> None:
        for name, coltype in self.pxt_schema.items():
            self.pxt_schema[name] = pxt.ColumnType.normalize_type(coltype)

    def infer_schema(self) -> dict[str, Any]:
        raise NotImplementedError

    def valid_row_batch(self) -> Iterator[RowData]:
        raise NotImplementedError

    def prepare_for_insert_into_table(self) -> None:
        if self.source is None:
            return
        raise NotImplementedError

    def add_table_info(self, table: pxt.Table) -> None:
        """Add information about the table into which we are inserting data"""
        assert isinstance(table, pxt.Table)
        self.pxt_schema = table._schema
        self.pxt_pk = table._tbl_version.get().primary_key
        for col in table._tbl_version_path.columns():
            if col.is_required_for_insert:
                self.reqd_col_names.add(col.name)
            if col.is_computed:
                self.computed_col_names.add(col.name)
        self.src_pk = []

    # Check source columns: required, computed, unknown
    def check_source_columns_are_insertable(self, columns: Iterable[str]) -> None:
        col_name_set: set[str] = set()
        for col_name in columns:  # FIXME
            mapped_col_name = self.source_column_map.get(col_name, col_name)
            col_name_set.add(mapped_col_name)
            if mapped_col_name not in self.pxt_schema:
                raise excs.Error(f'Unknown column name {mapped_col_name}')
            if mapped_col_name in self.computed_col_names:
                raise excs.Error(f'Value for computed column {mapped_col_name}')
        missing_cols = self.reqd_col_names - col_name_set
        if len(missing_cols) > 0:
            raise excs.Error(f'Missing required column(s) ({", ".join(missing_cols)})')


# ---------------------------------------------------------------------------------------------------------


class DFTableDataConduit(TableDataConduit):
    pxt_df: pxt.DataFrame = None

    @classmethod
    def from_tds(cls, tds: TableDataConduit) -> 'DFTableDataConduit':
        tds_fields = {f.name for f in fields(tds)}
        kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
        t = cls(**kwargs)
        assert isinstance(tds.source, pxt.DataFrame)
        t.pxt_df = tds.source
        return t

    def infer_schema(self) -> dict[str, Any]:
        self.pxt_schema = self.pxt_df.schema
        self.pxt_pk = self.src_pk
        return self.pxt_schema

    def prepare_for_insert_into_table(self) -> None:
        if self.source_column_map is None:
            self.source_column_map = {}
        self.check_source_columns_are_insertable(self.pxt_df.schema.keys())


# ---------------------------------------------------------------------------------------------------------


class RowDataTableDataConduit(TableDataConduit):
    raw_rows: Optional[RowData] = None
    disable_mapping: bool = True
    batch_count: int = 0

    @classmethod
    def from_tds(cls, tds: TableDataConduit) -> 'RowDataTableDataConduit':
        tds_fields = {f.name for f in fields(tds)}
        kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
        t = cls(**kwargs)
        if isinstance(tds.source, Iterator):
            # Instantiate the iterator to get the raw rows here
            t.raw_rows = list(tds.source)
        elif TYPE_CHECKING:
            t.raw_rows = cast(RowData, tds.source)
        else:
            t.raw_rows = tds.source
        t.batch_count = 0
        return t

    def infer_schema(self) -> dict[str, Any]:
        from .datarows import _infer_schema_from_rows

        if self.source_column_map is None:
            if self.src_schema_overrides is None:
                self.src_schema_overrides = {}
            self.src_schema = _infer_schema_from_rows(self.raw_rows, self.src_schema_overrides, self.src_pk)
            self.pxt_schema, self.pxt_pk, self.source_column_map = normalize_schema_names(
                self.src_schema, self.src_pk, self.src_schema_overrides, self.disable_mapping
            )
            self.normalize_pxt_schema_types()
        else:
            raise NotImplementedError()

        self.prepare_for_insert_into_table()
        return self.pxt_schema

    def prepare_for_insert_into_table(self) -> None:
        # Converting rows to insertable format is not needed; misnamed columns and types
        # are errors in the incoming row format
        if self.source_column_map is None:
            self.source_column_map = {}
        self.valid_rows = [self._translate_row(row) for row in self.raw_rows]

        self.batch_count = 1 if self.raw_rows is not None else 0

    def _translate_row(self, row: dict[str, Any]) -> dict[str, Any]:
        if not isinstance(row, dict):
            raise excs.Error(f'row {row} is not a dictionary')

        col_names: set[str] = set()
        output_row: dict[str, Any] = {}
        for col_name, val in row.items():
            mapped_col_name = self.source_column_map.get(col_name, col_name)
            col_names.add(mapped_col_name)
            if mapped_col_name not in self.pxt_schema:
                raise excs.Error(f'Unknown column name {mapped_col_name} in row {row}')
            if mapped_col_name in self.computed_col_names:
                raise excs.Error(f'Value for computed column {mapped_col_name} in row {row}')
            # basic sanity checks here
            try:
                checked_val = self.pxt_schema[mapped_col_name].create_literal(val)
            except TypeError as e:
                msg = str(e)
                raise excs.Error(f'Error in column {col_name}: {msg[0].lower() + msg[1:]}\nRow: {row}') from e
            output_row[mapped_col_name] = checked_val
        missing_cols = self.reqd_col_names - col_names
        if len(missing_cols) > 0:
            raise excs.Error(f'Missing required column(s) ({", ".join(missing_cols)}) in row {row}')
        return output_row

    def valid_row_batch(self) -> Iterator[RowData]:
        if self.batch_count > 0:
            self.batch_count -= 1
            yield self.valid_rows


# ---------------------------------------------------------------------------------------------------------


class PandasTableDataConduit(TableDataConduit):
    pd_df: pd.DataFrame = None
    batch_count: int = 0

    @classmethod
    def from_tds(cls, tds: TableDataConduit) -> PandasTableDataConduit:
        tds_fields = {f.name for f in fields(tds)}
        kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
        t = cls(**kwargs)
        assert isinstance(tds.source, pd.DataFrame)
        t.pd_df = tds.source
        t.batch_count = 0
        return t

    def infer_schema_part1(self) -> tuple[dict[str, Any], list[str]]:
        """Return inferred schema, inferred primary key, and source column map"""
        if self.source_column_map is None:
            if self.src_schema_overrides is None:
                self.src_schema_overrides = {}
            self.src_schema = df_infer_schema(self.pd_df, self.src_schema_overrides, self.src_pk)
            inferred_schema, inferred_pk, self.source_column_map = normalize_schema_names(
                self.src_schema, self.src_pk, self.src_schema_overrides, False
            )
            return inferred_schema, inferred_pk
        else:
            raise NotImplementedError()

    def infer_schema(self) -> dict[str, Any]:
        self.pxt_schema, self.pxt_pk = self.infer_schema_part1()
        self.normalize_pxt_schema_types()
        _df_check_primary_key_values(self.pd_df, self.src_pk)
        self.prepare_insert()
        return self.pxt_schema

    def prepare_for_insert_into_table(self) -> None:
        _, inferred_pk = self.infer_schema_part1()
        assert len(inferred_pk) == 0
        self.prepare_insert()

    def prepare_insert(self) -> None:
        if self.source_column_map is None:
            self.source_column_map = {}
        self.check_source_columns_are_insertable(self.pd_df.columns)
        # Convert all rows to insertable format
        self.valid_rows = [
            _df_row_to_pxt_row(row, self.src_schema, self.source_column_map) for row in self.pd_df.itertuples()
        ]
        self.batch_count = 1

    def valid_row_batch(self) -> Iterator[RowData]:
        if self.batch_count > 0:
            self.batch_count -= 1
            yield self.valid_rows


# ---------------------------------------------------------------------------------------------------------


class CSVTableDataConduit(TableDataConduit):
    @classmethod
    def from_tds(cls, tds: TableDataConduit) -> 'PandasTableDataConduit':
        tds_fields = {f.name for f in fields(tds)}
        kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
        t = cls(**kwargs)
        assert isinstance(t.source, str)
        t.source = pd.read_csv(t.source, **t.extra_fields)
        return PandasTableDataConduit.from_tds(t)


# ---------------------------------------------------------------------------------------------------------


class ExcelTableDataConduit(TableDataConduit):
    @classmethod
    def from_tds(cls, tds: TableDataConduit) -> 'PandasTableDataConduit':
        tds_fields = {f.name for f in fields(tds)}
        kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
        t = cls(**kwargs)
        assert isinstance(t.source, str)
        t.source = pd.read_excel(t.source, **t.extra_fields)
        return PandasTableDataConduit.from_tds(t)


# ---------------------------------------------------------------------------------------------------------


class JsonTableDataConduit(TableDataConduit):
    @classmethod
    def from_tds(cls, tds: TableDataConduit) -> RowDataTableDataConduit:
        tds_fields = {f.name for f in fields(tds)}
        kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
        t = cls(**kwargs)
        assert isinstance(t.source, str)

        path = parse_local_file_path(t.source)
        if path is None:  # it's a URL
            # TODO: This should read from S3 as well.
            contents = urllib.request.urlopen(t.source).read()
        else:
            with open(path, 'r', encoding='utf-8') as fp:
                contents = fp.read()
        rows = json.loads(contents, **t.extra_fields)
        t.source = rows
        t2 = RowDataTableDataConduit.from_tds(t)
        t2.disable_mapping = False
        return t2


# ---------------------------------------------------------------------------------------------------------


class HFTableDataConduit(TableDataConduit):
    hf_ds: Optional[Union[datasets.Dataset, datasets.DatasetDict]] = None
    column_name_for_split: Optional[str] = None
    categorical_features: dict[str, dict[int, str]]
    hf_schema: dict[str, Any] = None
    dataset_dict: dict[str, datasets.Dataset] = None
    hf_schema_source: dict[str, Any] = None

    @classmethod
    def from_tds(cls, tds: TableDataConduit) -> 'HFTableDataConduit':
        tds_fields = {f.name for f in fields(tds)}
        kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
        t = cls(**kwargs)
        import datasets

        assert isinstance(tds.source, (datasets.Dataset, datasets.DatasetDict))
        t.hf_ds = tds.source
        if 'column_name_for_split' in t.extra_fields:
            t.column_name_for_split = t.extra_fields['column_name_for_split']
        return t

    @classmethod
    def is_applicable(cls, tds: TableDataConduit) -> bool:
        try:
            import datasets

            return (isinstance(tds.source_format, str) and tds.source_format.lower() == 'huggingface') or isinstance(
                tds.source, (datasets.Dataset, datasets.DatasetDict)
            )
        except ImportError:
            return False

    def infer_schema_part1(self) -> tuple[dict[str, Any], list[str]]:
        from pixeltable.io.hf_datasets import _get_hf_schema, huggingface_schema_to_pxt_schema

        if self.source_column_map is None:
            if self.src_schema_overrides is None:
                self.src_schema_overrides = {}
            self.hf_schema_source = _get_hf_schema(self.hf_ds)
            self.src_schema = huggingface_schema_to_pxt_schema(
                self.hf_schema_source, self.src_schema_overrides, self.src_pk
            )

            # Add the split column to the schema if requested
            if self.column_name_for_split is not None:
                if self.column_name_for_split in self.src_schema:
                    raise excs.Error(
                        f'Column name `{self.column_name_for_split}` already exists in dataset schema;'
                        f'provide a different `column_name_for_split`'
                    )
                self.src_schema[self.column_name_for_split] = pxt.StringType(nullable=True)

            inferred_schema, inferred_pk, self.source_column_map = normalize_schema_names(
                self.src_schema, self.src_pk, self.src_schema_overrides, True
            )
            return inferred_schema, inferred_pk
        else:
            raise NotImplementedError()

    def infer_schema(self) -> dict[str, Any]:
        self.pxt_schema, self.pxt_pk = self.infer_schema_part1()
        self.normalize_pxt_schema_types()
        self.prepare_insert()
        return self.pxt_schema

    def prepare_for_insert_into_table(self) -> None:
        _, inferred_pk = self.infer_schema_part1()
        assert len(inferred_pk) == 0
        self.prepare_insert()

    def prepare_insert(self) -> None:
        import datasets

        if isinstance(self.source, datasets.Dataset):
            # when loading an hf dataset partially, dataset.split._name is sometimes the form "train[0:1000]"
            raw_name = self.source.split._name
            split_name = raw_name.split('[')[0] if raw_name is not None else None
            self.dataset_dict = {split_name: self.source}
        else:
            assert isinstance(self.source, datasets.DatasetDict)
            self.dataset_dict = self.source

        # extract all class labels from the dataset to translate category ints to strings
        self.categorical_features = {
            feature_name: feature_type.names
            for (feature_name, feature_type) in self.hf_schema_source.items()
            if isinstance(feature_type, datasets.ClassLabel)
        }
        if self.source_column_map is None:
            self.source_column_map = {}
        self.check_source_columns_are_insertable(self.hf_schema_source.keys())

    def _translate_row(self, row: dict[str, Any], split_name: str) -> dict[str, Any]:
        output_row: dict[str, Any] = {}
        for col_name, val in row.items():
            # translate category ints to strings
            new_val = self.categorical_features[col_name][val] if col_name in self.categorical_features else val
            mapped_col_name = self.source_column_map.get(col_name, col_name)

            # Convert values to the appropriate type if needed
            try:
                checked_val = self.pxt_schema[mapped_col_name].create_literal(new_val)
            except TypeError as e:
                msg = str(e)
                raise excs.Error(f'Error in column {col_name}: {msg[0].lower() + msg[1:]}\nRow: {row}') from e
            output_row[mapped_col_name] = checked_val

        # add split name to output row
        if self.column_name_for_split is not None:
            output_row[self.column_name_for_split] = split_name
        return output_row

    def valid_row_batch(self) -> Iterator[RowData]:
        for split_name, split_dataset in self.dataset_dict.items():
            num_batches = split_dataset.size_in_bytes / self._K_BATCH_SIZE_BYTES
            tuples_per_batch = math.ceil(split_dataset.num_rows / num_batches)
            assert tuples_per_batch > 0

            batch = []
            for row in split_dataset:
                batch.append(self._translate_row(row, split_name))
                if len(batch) >= tuples_per_batch:
                    yield batch
                    batch = []
            # last batch
            if len(batch) > 0:
                yield batch


# ---------------------------------------------------------------------------------------------------------


class ParquetTableDataConduit(TableDataConduit):
    pq_ds: Optional[ParquetDataset] = None

    @classmethod
    def from_tds(cls, tds: TableDataConduit) -> 'ParquetTableDataConduit':
        tds_fields = {f.name for f in fields(tds)}
        kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
        t = cls(**kwargs)

        from pyarrow import parquet

        assert isinstance(tds.source, str)
        input_path = Path(tds.source).expanduser()
        t.pq_ds = parquet.ParquetDataset(str(input_path))
        return t

    def infer_schema_part1(self) -> tuple[dict[str, Any], list[str]]:
        from pixeltable.utils.arrow import ar_infer_schema

        if self.source_column_map is None:
            if self.src_schema_overrides is None:
                self.src_schema_overrides = {}
            self.src_schema = ar_infer_schema(self.pq_ds.schema, self.src_schema_overrides, self.src_pk)
            inferred_schema, inferred_pk, self.source_column_map = normalize_schema_names(
                self.src_schema, self.src_pk, self.src_schema_overrides
            )
            return inferred_schema, inferred_pk
        else:
            raise NotImplementedError()

    def infer_schema(self) -> dict[str, Any]:
        self.pxt_schema, self.pxt_pk = self.infer_schema_part1()
        self.normalize_pxt_schema_types()
        self.prepare_insert()
        return self.pxt_schema

    def prepare_for_insert_into_table(self) -> None:
        _, inferred_pk = self.infer_schema_part1()
        assert len(inferred_pk) == 0
        self.prepare_insert()

    def prepare_insert(self) -> None:
        if self.source_column_map is None:
            self.source_column_map = {}
        self.check_source_columns_are_insertable(self.pq_ds.schema.names)
        self.total_rows = 0

    def valid_row_batch(self) -> Iterator[RowData]:
        from pixeltable.utils.arrow import iter_tuples2

        try:
            for fragment in self.pq_ds.fragments:  # type: ignore[attr-defined]
                for batch in fragment.to_batches():
                    dict_batch = list(iter_tuples2(batch, self.source_column_map, self.pxt_schema))
                    self.total_rows += len(dict_batch)
                    yield dict_batch
        except Exception as e:
            _logger.error(f'Error after inserting {self.total_rows} rows from Parquet file into table: {e}')
            raise e


# ---------------------------------------------------------------------------------------------------------


class UnkTableDataConduit(TableDataConduit):
    """Source type is not known at the time of creation"""

    def specialize(self) -> TableDataConduit:
        if isinstance(self.source, pxt.DataFrame):
            return DFTableDataConduit.from_tds(self)
        if isinstance(self.source, pd.DataFrame):
            return PandasTableDataConduit.from_tds(self)
        if HFTableDataConduit.is_applicable(self):
            return HFTableDataConduit.from_tds(self)
        if self.source_format == 'csv' or (isinstance(self.source, str) and '.csv' in self.source.lower()):
            return CSVTableDataConduit.from_tds(self)
        if self.source_format == 'excel' or (isinstance(self.source, str) and '.xls' in self.source.lower()):
            return ExcelTableDataConduit.from_tds(self)
        if self.source_format == 'json' or (isinstance(self.source, str) and '.json' in self.source.lower()):
            return JsonTableDataConduit.from_tds(self)
        if self.source_format == 'parquet' or (
            isinstance(self.source, str) and any(s in self.source.lower() for s in ['.parquet', '.pq', '.parq'])
        ):
            return ParquetTableDataConduit.from_tds(self)
        if (
            self.is_rowdata_structure(self.source)
            # An Iterator as a source is assumed to produce rows
            or isinstance(self.source, Iterator)
        ):
            return RowDataTableDataConduit.from_tds(self)

        raise excs.Error(f'Unsupported data source type: {type(self.source)}')
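
Taken together, these classes form a small dispatch layer: a source of unknown type is wrapped in UnkTableDataConduit, specialize() picks the matching subclass, and the caller then asks for a schema and row batches. A minimal usage sketch follows, assuming the wiring in pixeltable/globals.py (which changed in this release but is not shown here) drives the conduits roughly like this; the table name 'demo' and the insert loop are illustrative, not taken from the diff.

# Hedged usage sketch of the conduit dispatch above (names 'demo', rows are illustrative)
import pixeltable as pxt
from pixeltable.io.table_data_conduit import UnkTableDataConduit

rows = [{'name': 'a', 'value': 1}, {'name': 'b', 'value': 2}]
tdc = UnkTableDataConduit(source=rows).specialize()  # list of dicts -> RowDataTableDataConduit
schema = tdc.infer_schema()                          # infers types, validates and caches the rows
tbl = pxt.create_table('demo', schema)
for batch in tdc.valid_row_batch():                  # in-memory row data yields a single batch
    tbl.insert(batch)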
pixeltable/io/utils.py
CHANGED
@@ -3,7 +3,6 @@ from typing import Any, Optional, Union
 
 import pixeltable as pxt
 import pixeltable.exceptions as excs
-from pixeltable import Table
 from pixeltable.catalog.globals import is_system_column_name
 
 
@@ -22,16 +21,14 @@ def normalize_pxt_col_name(name: str) -> str:
     return id
 
 
-def 
-    schema_overrides: Optional[dict[str, Any]] = None, primary_key: Optional[Union[str, list[str]]] = None
-) -> tuple[dict[str, Any], list[str]]:
-    if schema_overrides is None:
-        schema_overrides = {}
+def normalize_primary_key_parameter(primary_key: Optional[Union[str, list[str]]] = None) -> list[str]:
     if primary_key is None:
         primary_key = []
     elif isinstance(primary_key, str):
         primary_key = [primary_key]
-
+    elif not isinstance(primary_key, list) or not all(isinstance(pk, str) for pk in primary_key):
+        raise excs.Error('primary_key must be a single column name or a list of column names')
+    return primary_key
 
 
 def _is_usable_as_column_name(name: str, destination_schema: dict[str, Any]) -> bool:
@@ -65,7 +62,8 @@ def normalize_schema_names(
     extraneous_overrides = schema_overrides.keys() - in_schema.keys()
     if len(extraneous_overrides) > 0:
         raise excs.Error(
-            f'Some column(s) specified in `schema_overrides` are not present
+            f'Some column(s) specified in `schema_overrides` are not present '
+            f'in the source: {", ".join(extraneous_overrides)}'
         )
 
     schema: dict[str, Any] = {}
@@ -100,16 +98,3 @@ def normalize_schema_names(
     pxt_pk = [col_mapping[pk] for pk in primary_key] if col_mapping is not None else primary_key
 
     return schema, pxt_pk, col_mapping
-
-
-def find_or_create_table(
-    tbl_path: str,
-    schema: dict[str, Any],
-    *,
-    primary_key: Optional[Union[str, list[str]]],
-    num_retained_versions: int,
-    comment: str,
-) -> Table:
-    return pxt.create_table(
-        tbl_path, schema, primary_key=primary_key, num_retained_versions=num_retained_versions, comment=comment
-    )
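
The removed helper normalized schema overrides and the primary key together; the new normalize_primary_key_parameter handles only the primary key and now rejects non-string values explicitly. A behavior sketch, derived directly from the function body above:

# Behavior of normalize_primary_key_parameter (from the diff above)
from pixeltable.io.utils import normalize_primary_key_parameter

normalize_primary_key_parameter(None)           # -> []
normalize_primary_key_parameter('id')           # -> ['id']
normalize_primary_key_parameter(['id', 'ts'])   # -> ['id', 'ts']
normalize_primary_key_parameter(42)             # raises excs.Error: primary_key must be a single
                                                # column name or a list of column names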
pixeltable/metadata/__init__.py
CHANGED
@@ -16,7 +16,7 @@ _console_logger = ConsoleLogger(logging.getLogger('pixeltable'))
 
 
 # current version of the metadata; this is incremented whenever the metadata schema changes
-VERSION = 30
+VERSION = 31
 
 
 def create_system_info(engine: sql.engine.Engine) -> None:
pixeltable/metadata/converters/convert_30.py
ADDED
@@ -0,0 +1,50 @@
import copy

import sqlalchemy as sql

from pixeltable.metadata import register_converter
from pixeltable.metadata.converters.util import (
    convert_table_record,
    convert_table_schema_version_record,
    convert_table_version_record,
)
from pixeltable.metadata.schema import Table, TableSchemaVersion, TableVersion


@register_converter(version=30)
def _(engine: sql.engine.Engine) -> None:
    convert_table_record(engine, table_record_updater=__update_table_record)
    convert_table_version_record(engine, table_version_record_updater=__update_table_version_record)
    convert_table_schema_version_record(
        engine, table_schema_version_record_updater=__update_table_schema_version_record
    )


def __update_table_record(record: Table) -> None:
    """
    Update TableMd with table_id
    """
    assert isinstance(record.md, dict)
    md = copy.copy(record.md)
    md['tbl_id'] = str(record.id)
    record.md = md


def __update_table_version_record(record: TableVersion) -> None:
    """
    Update TableVersion with table_id.
    """
    assert isinstance(record.md, dict)
    md = copy.copy(record.md)
    md['tbl_id'] = str(record.tbl_id)
    record.md = md


def __update_table_schema_version_record(record: TableSchemaVersion) -> None:
    """
    Update TableSchemaVersion with table_id.
    """
    assert isinstance(record.md, dict)
    md = copy.copy(record.md)
    md['tbl_id'] = str(record.tbl_id)
    record.md = md
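
All three updaters perform the same copy-and-stamp transformation: copy the stored metadata dict and write the owning table's id into it under 'tbl_id'. A standalone sketch of that effect; FakeRecord is a hypothetical stand-in for the SQLAlchemy row objects (Table, TableVersion, TableSchemaVersion) that the real converter receives.

# Effect sketch for the v30 -> v31 converter; FakeRecord is a made-up stand-in
import copy
import uuid
from dataclasses import dataclass
from typing import Any

@dataclass
class FakeRecord:
    id: uuid.UUID
    md: dict[str, Any]

def update_table_record(record: FakeRecord) -> None:
    md = copy.copy(record.md)      # copy-on-write, as in __update_table_record
    md['tbl_id'] = str(record.id)  # stamp the table id into the metadata
    record.md = md

rec = FakeRecord(id=uuid.uuid4(), md={'name': 'my_table'})
update_table_record(rec)
assert rec.md['tbl_id'] == str(rec.id)  # md now carries the table id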