pixeltable 0.2.26__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pixeltable/__init__.py +83 -19
- pixeltable/_query.py +1444 -0
- pixeltable/_version.py +1 -0
- pixeltable/catalog/__init__.py +7 -4
- pixeltable/catalog/catalog.py +2394 -119
- pixeltable/catalog/column.py +225 -104
- pixeltable/catalog/dir.py +38 -9
- pixeltable/catalog/globals.py +53 -34
- pixeltable/catalog/insertable_table.py +265 -115
- pixeltable/catalog/path.py +80 -17
- pixeltable/catalog/schema_object.py +28 -43
- pixeltable/catalog/table.py +1270 -677
- pixeltable/catalog/table_metadata.py +103 -0
- pixeltable/catalog/table_version.py +1270 -751
- pixeltable/catalog/table_version_handle.py +109 -0
- pixeltable/catalog/table_version_path.py +137 -42
- pixeltable/catalog/tbl_ops.py +53 -0
- pixeltable/catalog/update_status.py +191 -0
- pixeltable/catalog/view.py +251 -134
- pixeltable/config.py +215 -0
- pixeltable/env.py +736 -285
- pixeltable/exceptions.py +26 -2
- pixeltable/exec/__init__.py +7 -2
- pixeltable/exec/aggregation_node.py +39 -21
- pixeltable/exec/cache_prefetch_node.py +87 -109
- pixeltable/exec/cell_materialization_node.py +268 -0
- pixeltable/exec/cell_reconstruction_node.py +168 -0
- pixeltable/exec/component_iteration_node.py +25 -28
- pixeltable/exec/data_row_batch.py +11 -46
- pixeltable/exec/exec_context.py +26 -11
- pixeltable/exec/exec_node.py +35 -27
- pixeltable/exec/expr_eval/__init__.py +3 -0
- pixeltable/exec/expr_eval/evaluators.py +365 -0
- pixeltable/exec/expr_eval/expr_eval_node.py +413 -0
- pixeltable/exec/expr_eval/globals.py +200 -0
- pixeltable/exec/expr_eval/row_buffer.py +74 -0
- pixeltable/exec/expr_eval/schedulers.py +413 -0
- pixeltable/exec/globals.py +35 -0
- pixeltable/exec/in_memory_data_node.py +35 -27
- pixeltable/exec/object_store_save_node.py +293 -0
- pixeltable/exec/row_update_node.py +44 -29
- pixeltable/exec/sql_node.py +414 -115
- pixeltable/exprs/__init__.py +8 -5
- pixeltable/exprs/arithmetic_expr.py +79 -45
- pixeltable/exprs/array_slice.py +5 -5
- pixeltable/exprs/column_property_ref.py +40 -26
- pixeltable/exprs/column_ref.py +254 -61
- pixeltable/exprs/comparison.py +14 -9
- pixeltable/exprs/compound_predicate.py +9 -10
- pixeltable/exprs/data_row.py +213 -72
- pixeltable/exprs/expr.py +270 -104
- pixeltable/exprs/expr_dict.py +6 -5
- pixeltable/exprs/expr_set.py +20 -11
- pixeltable/exprs/function_call.py +383 -284
- pixeltable/exprs/globals.py +18 -5
- pixeltable/exprs/in_predicate.py +7 -7
- pixeltable/exprs/inline_expr.py +37 -37
- pixeltable/exprs/is_null.py +8 -4
- pixeltable/exprs/json_mapper.py +120 -54
- pixeltable/exprs/json_path.py +90 -60
- pixeltable/exprs/literal.py +61 -16
- pixeltable/exprs/method_ref.py +7 -6
- pixeltable/exprs/object_ref.py +19 -8
- pixeltable/exprs/row_builder.py +238 -75
- pixeltable/exprs/rowid_ref.py +53 -15
- pixeltable/exprs/similarity_expr.py +65 -50
- pixeltable/exprs/sql_element_cache.py +5 -5
- pixeltable/exprs/string_op.py +107 -0
- pixeltable/exprs/type_cast.py +25 -13
- pixeltable/exprs/variable.py +2 -2
- pixeltable/func/__init__.py +9 -5
- pixeltable/func/aggregate_function.py +197 -92
- pixeltable/func/callable_function.py +119 -35
- pixeltable/func/expr_template_function.py +101 -48
- pixeltable/func/function.py +375 -62
- pixeltable/func/function_registry.py +20 -19
- pixeltable/func/globals.py +6 -5
- pixeltable/func/mcp.py +74 -0
- pixeltable/func/query_template_function.py +151 -35
- pixeltable/func/signature.py +178 -49
- pixeltable/func/tools.py +164 -0
- pixeltable/func/udf.py +176 -53
- pixeltable/functions/__init__.py +44 -4
- pixeltable/functions/anthropic.py +226 -47
- pixeltable/functions/audio.py +148 -11
- pixeltable/functions/bedrock.py +137 -0
- pixeltable/functions/date.py +188 -0
- pixeltable/functions/deepseek.py +113 -0
- pixeltable/functions/document.py +81 -0
- pixeltable/functions/fal.py +76 -0
- pixeltable/functions/fireworks.py +72 -20
- pixeltable/functions/gemini.py +249 -0
- pixeltable/functions/globals.py +208 -53
- pixeltable/functions/groq.py +108 -0
- pixeltable/functions/huggingface.py +1088 -95
- pixeltable/functions/image.py +155 -84
- pixeltable/functions/json.py +8 -11
- pixeltable/functions/llama_cpp.py +31 -19
- pixeltable/functions/math.py +169 -0
- pixeltable/functions/mistralai.py +50 -75
- pixeltable/functions/net.py +70 -0
- pixeltable/functions/ollama.py +29 -36
- pixeltable/functions/openai.py +548 -160
- pixeltable/functions/openrouter.py +143 -0
- pixeltable/functions/replicate.py +15 -14
- pixeltable/functions/reve.py +250 -0
- pixeltable/functions/string.py +310 -85
- pixeltable/functions/timestamp.py +37 -19
- pixeltable/functions/together.py +77 -120
- pixeltable/functions/twelvelabs.py +188 -0
- pixeltable/functions/util.py +7 -2
- pixeltable/functions/uuid.py +30 -0
- pixeltable/functions/video.py +1528 -117
- pixeltable/functions/vision.py +26 -26
- pixeltable/functions/voyageai.py +289 -0
- pixeltable/functions/whisper.py +19 -10
- pixeltable/functions/whisperx.py +179 -0
- pixeltable/functions/yolox.py +112 -0
- pixeltable/globals.py +716 -236
- pixeltable/index/__init__.py +3 -1
- pixeltable/index/base.py +17 -21
- pixeltable/index/btree.py +32 -22
- pixeltable/index/embedding_index.py +155 -92
- pixeltable/io/__init__.py +12 -7
- pixeltable/io/datarows.py +140 -0
- pixeltable/io/external_store.py +83 -125
- pixeltable/io/fiftyone.py +24 -33
- pixeltable/io/globals.py +47 -182
- pixeltable/io/hf_datasets.py +96 -127
- pixeltable/io/label_studio.py +171 -156
- pixeltable/io/lancedb.py +3 -0
- pixeltable/io/pandas.py +136 -115
- pixeltable/io/parquet.py +40 -153
- pixeltable/io/table_data_conduit.py +702 -0
- pixeltable/io/utils.py +100 -0
- pixeltable/iterators/__init__.py +8 -4
- pixeltable/iterators/audio.py +207 -0
- pixeltable/iterators/base.py +9 -3
- pixeltable/iterators/document.py +144 -87
- pixeltable/iterators/image.py +17 -38
- pixeltable/iterators/string.py +15 -12
- pixeltable/iterators/video.py +523 -127
- pixeltable/metadata/__init__.py +33 -8
- pixeltable/metadata/converters/convert_10.py +2 -3
- pixeltable/metadata/converters/convert_13.py +2 -2
- pixeltable/metadata/converters/convert_15.py +15 -11
- pixeltable/metadata/converters/convert_16.py +4 -5
- pixeltable/metadata/converters/convert_17.py +4 -5
- pixeltable/metadata/converters/convert_18.py +4 -6
- pixeltable/metadata/converters/convert_19.py +6 -9
- pixeltable/metadata/converters/convert_20.py +3 -6
- pixeltable/metadata/converters/convert_21.py +6 -8
- pixeltable/metadata/converters/convert_22.py +3 -2
- pixeltable/metadata/converters/convert_23.py +33 -0
- pixeltable/metadata/converters/convert_24.py +55 -0
- pixeltable/metadata/converters/convert_25.py +19 -0
- pixeltable/metadata/converters/convert_26.py +23 -0
- pixeltable/metadata/converters/convert_27.py +29 -0
- pixeltable/metadata/converters/convert_28.py +13 -0
- pixeltable/metadata/converters/convert_29.py +110 -0
- pixeltable/metadata/converters/convert_30.py +63 -0
- pixeltable/metadata/converters/convert_31.py +11 -0
- pixeltable/metadata/converters/convert_32.py +15 -0
- pixeltable/metadata/converters/convert_33.py +17 -0
- pixeltable/metadata/converters/convert_34.py +21 -0
- pixeltable/metadata/converters/convert_35.py +9 -0
- pixeltable/metadata/converters/convert_36.py +38 -0
- pixeltable/metadata/converters/convert_37.py +15 -0
- pixeltable/metadata/converters/convert_38.py +39 -0
- pixeltable/metadata/converters/convert_39.py +124 -0
- pixeltable/metadata/converters/convert_40.py +73 -0
- pixeltable/metadata/converters/convert_41.py +12 -0
- pixeltable/metadata/converters/convert_42.py +9 -0
- pixeltable/metadata/converters/convert_43.py +44 -0
- pixeltable/metadata/converters/util.py +44 -18
- pixeltable/metadata/notes.py +21 -0
- pixeltable/metadata/schema.py +185 -42
- pixeltable/metadata/utils.py +74 -0
- pixeltable/mypy/__init__.py +3 -0
- pixeltable/mypy/mypy_plugin.py +123 -0
- pixeltable/plan.py +616 -225
- pixeltable/share/__init__.py +3 -0
- pixeltable/share/packager.py +797 -0
- pixeltable/share/protocol/__init__.py +33 -0
- pixeltable/share/protocol/common.py +165 -0
- pixeltable/share/protocol/operation_types.py +33 -0
- pixeltable/share/protocol/replica.py +119 -0
- pixeltable/share/publish.py +349 -0
- pixeltable/store.py +398 -232
- pixeltable/type_system.py +730 -267
- pixeltable/utils/__init__.py +40 -0
- pixeltable/utils/arrow.py +201 -29
- pixeltable/utils/av.py +298 -0
- pixeltable/utils/azure_store.py +346 -0
- pixeltable/utils/coco.py +26 -27
- pixeltable/utils/code.py +4 -4
- pixeltable/utils/console_output.py +46 -0
- pixeltable/utils/coroutine.py +24 -0
- pixeltable/utils/dbms.py +92 -0
- pixeltable/utils/description_helper.py +11 -12
- pixeltable/utils/documents.py +60 -61
- pixeltable/utils/exception_handler.py +36 -0
- pixeltable/utils/filecache.py +38 -22
- pixeltable/utils/formatter.py +88 -51
- pixeltable/utils/gcs_store.py +295 -0
- pixeltable/utils/http.py +133 -0
- pixeltable/utils/http_server.py +14 -13
- pixeltable/utils/iceberg.py +13 -0
- pixeltable/utils/image.py +17 -0
- pixeltable/utils/lancedb.py +90 -0
- pixeltable/utils/local_store.py +322 -0
- pixeltable/utils/misc.py +5 -0
- pixeltable/utils/object_stores.py +573 -0
- pixeltable/utils/pydantic.py +60 -0
- pixeltable/utils/pytorch.py +20 -20
- pixeltable/utils/s3_store.py +527 -0
- pixeltable/utils/sql.py +32 -5
- pixeltable/utils/system.py +30 -0
- pixeltable/utils/transactional_directory.py +4 -3
- pixeltable-0.5.7.dist-info/METADATA +579 -0
- pixeltable-0.5.7.dist-info/RECORD +227 -0
- {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
- pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
- pixeltable/__version__.py +0 -3
- pixeltable/catalog/named_function.py +0 -36
- pixeltable/catalog/path_dict.py +0 -141
- pixeltable/dataframe.py +0 -894
- pixeltable/exec/expr_eval_node.py +0 -232
- pixeltable/ext/__init__.py +0 -14
- pixeltable/ext/functions/__init__.py +0 -8
- pixeltable/ext/functions/whisperx.py +0 -77
- pixeltable/ext/functions/yolox.py +0 -157
- pixeltable/tool/create_test_db_dump.py +0 -311
- pixeltable/tool/create_test_video.py +0 -81
- pixeltable/tool/doc_plugins/griffe.py +0 -50
- pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
- pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
- pixeltable/tool/embed_udf.py +0 -9
- pixeltable/tool/mypy_plugin.py +0 -55
- pixeltable/utils/media_store.py +0 -76
- pixeltable/utils/s3.py +0 -16
- pixeltable-0.2.26.dist-info/METADATA +0 -400
- pixeltable-0.2.26.dist-info/RECORD +0 -156
- pixeltable-0.2.26.dist-info/entry_points.txt +0 -3
- {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
pixeltable/io/table_data_conduit.py (new file)

@@ -0,0 +1,702 @@
from __future__ import annotations

import enum
import json
import logging
import urllib.request
from dataclasses import dataclass, field, fields
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, cast

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.types as pat
from pyarrow.parquet import ParquetDataset

import pixeltable as pxt
import pixeltable.exceptions as excs
import pixeltable.type_system as ts
from pixeltable.io.pandas import _df_check_primary_key_values, _df_row_to_pxt_row, df_infer_schema
from pixeltable.utils import parse_local_file_path

from .utils import normalize_schema_names

_logger = logging.getLogger('pixeltable')


if TYPE_CHECKING:
    import datasets  # type: ignore[import-untyped]

    from pixeltable.globals import RowData, TableDataSource

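# Enumerates the file formats a conduit can be explicitly asked to parse;
# is_valid() accepts any casing ('CSV', 'csv') and rejects non-string input.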
class TableDataConduitFormat(str, enum.Enum):
    """Supported formats for TableDataConduit"""

    JSON = 'json'
    CSV = 'csv'
    EXCEL = 'excel'
    PARQUET = 'parquet'

    @classmethod
    def is_valid(cls, x: Any) -> bool:
        if isinstance(x, str):
            return x.lower() in [c.value for c in cls]
        return False

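# Base class for all data conduits. A conduit wraps a data source (query, DataFrame,
# file path/URL, HuggingFace dataset, list of row dicts, ...) behind a uniform
# interface: infer_schema() derives a Pixeltable schema from the source, and
# valid_row_batch() yields batches of validated, insertable rows.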
@dataclass
class TableDataConduit:
    source: 'TableDataSource'
    source_format: str | None = None
    source_column_map: dict[str, str] | None = None
    if_row_exists: Literal['update', 'ignore', 'error'] = 'error'
    pxt_schema: dict[str, ts.ColumnType] | None = None
    src_schema_overrides: dict[str, ts.ColumnType] | None = None
    src_schema: dict[str, ts.ColumnType] | None = None
    pxt_pk: list[str] | None = None
    src_pk: list[str] | None = None
    valid_rows: RowData | None = None
    extra_fields: dict[str, Any] = field(default_factory=dict)

    reqd_col_names: set[str] = field(default_factory=set)
    computed_col_names: set[str] = field(default_factory=set)

    total_rows: int = 0  # total number of rows emitted via valid_row_batch Iterator

    _K_BATCH_SIZE_BYTES = 256 * 2**20

    def check_source_format(self) -> None:
        assert self.source_format is None or TableDataConduitFormat.is_valid(self.source_format)

    def __post_init__(self) -> None:
        """If no extra_fields were provided, initialize to empty dict"""
        if self.extra_fields is None:
            self.extra_fields = {}

    @classmethod
    def is_rowdata_structure(cls, d: TableDataSource) -> bool:
        if not isinstance(d, list) or len(d) == 0:
            return False
        return all(isinstance(row, dict) for row in d)

    def is_direct_query(self) -> bool:
        return isinstance(self.source, pxt.Query) and self.source_column_map is None

    def normalize_pxt_schema_types(self) -> None:
        for name, coltype in self.pxt_schema.items():
            self.pxt_schema[name] = ts.ColumnType.normalize_type(coltype)

    def infer_schema(self) -> dict[str, ts.ColumnType]:
        raise NotImplementedError

    def valid_row_batch(self) -> Iterator[RowData]:
        raise NotImplementedError

    def prepare_for_insert_into_table(self) -> None:
        if self.source is None:
            return
        raise NotImplementedError

    def add_table_info(self, table: pxt.Table) -> None:
        """Add information about the table into which we are inserting data"""
        assert isinstance(table, pxt.Table)
        self.pxt_schema = table._get_schema()
        self.pxt_pk = table._tbl_version.get().primary_key
        for col in table._tbl_version_path.columns():
            if col.is_required_for_insert:
                self.reqd_col_names.add(col.name)
            if col.is_computed:
                self.computed_col_names.add(col.name)
        self.src_pk = []

    # Check source columns: required, computed, unknown
    def check_source_columns_are_insertable(self, columns: Iterable[str]) -> None:
        col_name_set: set[str] = set()
        for col_name in columns:  # FIXME
            mapped_col_name = self.source_column_map.get(col_name, col_name)
            col_name_set.add(mapped_col_name)
            if mapped_col_name not in self.pxt_schema:
                raise excs.Error(f'Unknown column name {mapped_col_name}')
            if mapped_col_name in self.computed_col_names:
                raise excs.Error(f'Value for computed column {mapped_col_name}')
        missing_cols = self.reqd_col_names - col_name_set
        if len(missing_cols) > 0:
            raise excs.Error(f'Missing required column(s) ({", ".join(missing_cols)})')

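# Conduit for sources that are already Pixeltable objects: a pxt.Table is wrapped in
# a full select() query, a pxt.Query is used as-is. The from_tds() constructor
# pattern below (repeated in every subclass) copies all dataclass fields of a
# generic conduit into the specialized one.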
class QueryTableDataConduit(TableDataConduit):
    pxt_query: pxt.Query = None

    @classmethod
    def from_tds(cls, tds: TableDataConduit) -> 'QueryTableDataConduit':
        tds_fields = {f.name for f in fields(tds)}
        kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
        t = cls(**kwargs)
        if isinstance(tds.source, pxt.Table):
            t.pxt_query = tds.source.select()
        else:
            assert isinstance(tds.source, pxt.Query)
            t.pxt_query = tds.source
        return t

    def infer_schema(self) -> dict[str, ts.ColumnType]:
        self.pxt_schema = self.pxt_query.schema
        self.pxt_pk = self.src_pk
        return self.pxt_schema

    def prepare_for_insert_into_table(self) -> None:
        if self.source_column_map is None:
            self.source_column_map = {}
        self.check_source_columns_are_insertable(self.pxt_query.schema.keys())

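# Conduit for in-memory row data: a list of dicts, or an iterator producing such
# rows (drained eagerly in from_tds()). Each row is validated individually in
# _translate_row() and the result is emitted as a single batch.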
class RowDataTableDataConduit(TableDataConduit):
    raw_rows: RowData | None = None
    disable_mapping: bool = True
    batch_count: int = 0

    @classmethod
    def from_tds(cls, tds: TableDataConduit) -> 'RowDataTableDataConduit':
        tds_fields = {f.name for f in fields(tds)}
        kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
        t = cls(**kwargs)
        if isinstance(tds.source, Iterator):
            # Instantiate the iterator to get the raw rows here
            t.raw_rows = list(tds.source)
        elif TYPE_CHECKING:
            t.raw_rows = cast(RowData, tds.source)
        else:
            t.raw_rows = tds.source
        t.batch_count = 0
        return t

    def infer_schema(self) -> dict[str, ts.ColumnType]:
        from .datarows import _infer_schema_from_rows

        if self.source_column_map is None:
            if self.src_schema_overrides is None:
                self.src_schema_overrides = {}
            self.src_schema = _infer_schema_from_rows(self.raw_rows, self.src_schema_overrides, self.src_pk)
            self.pxt_schema, self.pxt_pk, self.source_column_map = normalize_schema_names(
                self.src_schema, self.src_pk, self.src_schema_overrides, self.disable_mapping
            )
            self.normalize_pxt_schema_types()
        else:
            raise NotImplementedError()

        self.prepare_for_insert_into_table()
        return self.pxt_schema

    def prepare_for_insert_into_table(self) -> None:
        # Converting rows to an insertable format is not needed; misnamed columns and
        # mistyped values are errors in the incoming row format
        if self.source_column_map is None:
            self.source_column_map = {}
        self.valid_rows = [self._translate_row(row) for row in self.raw_rows]

        self.batch_count = 1 if self.raw_rows is not None else 0

    def _translate_row(self, row: dict[str, Any]) -> dict[str, Any]:
        if not isinstance(row, dict):
            raise excs.Error(f'row {row} is not a dictionary')

        col_names: set[str] = set()
        output_row: dict[str, Any] = {}
        for col_name, val in row.items():
            mapped_col_name = self.source_column_map.get(col_name, col_name)
            col_names.add(mapped_col_name)
            if mapped_col_name not in self.pxt_schema:
                raise excs.Error(f'Unknown column name {mapped_col_name} in row {row}')
            if mapped_col_name in self.computed_col_names:
                raise excs.Error(f'Value for computed column {mapped_col_name} in row {row}')
            # basic sanity checks here
            try:
                checked_val = self.pxt_schema[mapped_col_name].create_literal(val)
            except TypeError as e:
                msg = str(e)
                raise excs.Error(f'Error in column {col_name}: {msg[0].lower() + msg[1:]}\nRow: {row}') from e
            output_row[mapped_col_name] = checked_val
        missing_cols = self.reqd_col_names - col_names
        if len(missing_cols) > 0:
            raise excs.Error(f'Missing required column(s) ({", ".join(missing_cols)}) in row {row}')
        return output_row

    def valid_row_batch(self) -> Iterator[RowData]:
        if self.batch_count > 0:
            self.batch_count -= 1
            yield self.valid_rows

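# Conduit for pandas DataFrames: the schema is inferred via df_infer_schema(), all
# rows are converted up front with _df_row_to_pxt_row(), and the result is emitted
# as a single batch.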
class PandasTableDataConduit(TableDataConduit):
    pd_df: pd.DataFrame = None
    batch_count: int = 0

    @classmethod
    def from_tds(cls, tds: TableDataConduit) -> PandasTableDataConduit:
        tds_fields = {f.name for f in fields(tds)}
        kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
        t = cls(**kwargs)
        assert isinstance(tds.source, pd.DataFrame)
        t.pd_df = tds.source
        t.batch_count = 0
        return t

    def infer_schema_part1(self) -> tuple[dict[str, ts.ColumnType], list[str]]:
        """Return the inferred schema and inferred primary key; also sets the source column map"""
        if self.source_column_map is None:
            if self.src_schema_overrides is None:
                self.src_schema_overrides = {}
            self.src_schema = df_infer_schema(self.pd_df, self.src_schema_overrides, self.src_pk)
            inferred_schema, inferred_pk, self.source_column_map = normalize_schema_names(
                self.src_schema, self.src_pk, self.src_schema_overrides, False
            )
            return inferred_schema, inferred_pk
        else:
            raise NotImplementedError()

    def infer_schema(self) -> dict[str, ts.ColumnType]:
        self.pxt_schema, self.pxt_pk = self.infer_schema_part1()
        self.normalize_pxt_schema_types()
        _df_check_primary_key_values(self.pd_df, self.src_pk)
        self.prepare_insert()
        return self.pxt_schema

    def prepare_for_insert_into_table(self) -> None:
        _, inferred_pk = self.infer_schema_part1()
        assert len(inferred_pk) == 0
        self.prepare_insert()

    def prepare_insert(self) -> None:
        if self.source_column_map is None:
            self.source_column_map = {}
        self.check_source_columns_are_insertable(self.pd_df.columns)
        # Convert all rows to insertable format
        self.valid_rows = [
            _df_row_to_pxt_row(row, self.src_schema, self.source_column_map) for row in self.pd_df.itertuples()
        ]
        self.batch_count = 1

    def valid_row_batch(self) -> Iterator[RowData]:
        if self.batch_count > 0:
            self.batch_count -= 1
            yield self.valid_rows

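# The CSV and Excel conduits delegate parsing to pandas (read_csv/read_excel) and
# then hand off to PandasTableDataConduit; the JSON conduit loads a local file or
# URL and hands off to RowDataTableDataConduit. In all three, extra_fields is
# forwarded as keyword arguments to the underlying reader.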
class CSVTableDataConduit(TableDataConduit):
    @classmethod
    def from_tds(cls, tds: TableDataConduit) -> 'PandasTableDataConduit':
        tds_fields = {f.name for f in fields(tds)}
        kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
        t = cls(**kwargs)
        assert isinstance(t.source, str)
        t.source = pd.read_csv(t.source, **t.extra_fields)
        return PandasTableDataConduit.from_tds(t)


class ExcelTableDataConduit(TableDataConduit):
    @classmethod
    def from_tds(cls, tds: TableDataConduit) -> 'PandasTableDataConduit':
        tds_fields = {f.name for f in fields(tds)}
        kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
        t = cls(**kwargs)
        assert isinstance(t.source, str)
        t.source = pd.read_excel(t.source, **t.extra_fields)
        return PandasTableDataConduit.from_tds(t)


class JsonTableDataConduit(TableDataConduit):
    @classmethod
    def from_tds(cls, tds: TableDataConduit) -> RowDataTableDataConduit:
        tds_fields = {f.name for f in fields(tds)}
        kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
        t = cls(**kwargs)
        assert isinstance(t.source, str)

        path = parse_local_file_path(t.source)
        if path is None:  # it's a URL
            # TODO: This should read from S3 as well.
            contents = urllib.request.urlopen(t.source).read()
        else:
            with open(path, 'r', encoding='utf-8') as fp:
                contents = fp.read()
        rows = json.loads(contents, **t.extra_fields)
        t.source = rows
        t2 = RowDataTableDataConduit.from_tds(t)
        t2.disable_mapping = False
        return t2

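# Conduit for HuggingFace datasets (Dataset, DatasetDict, and their Iterable*
# variants). Iterable datasets are switched to Arrow format so batches arrive as
# Arrow tables; Audio/Image auto-decoding is disabled so raw bytes can be written
# straight to temp files.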
class HFTableDataConduit(TableDataConduit):
    """HuggingFace dataset importer"""

    column_name_for_split: str | None = None
    categorical_features: dict[str, dict[int, str]]
    dataset_dict: dict[str, 'datasets.Dataset'] = None  # key: split name
    hf_schema_source: dict[str, Any] = None

    @classmethod
    def from_tds(cls, tds: TableDataConduit) -> HFTableDataConduit:
        tds_fields = {f.name for f in fields(tds)}
        kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
        t = cls(**kwargs)
        import datasets

        assert isinstance(tds.source, cls._get_dataset_classes())
        if 'column_name_for_split' in t.extra_fields:
            t.column_name_for_split = t.extra_fields['column_name_for_split']

        if isinstance(tds.source, (datasets.IterableDataset, datasets.IterableDatasetDict)):
            tds.source = tds.source.with_format('arrow')

        if isinstance(tds.source, (datasets.Dataset, datasets.IterableDataset)):
            split_name = str(tds.source.split) if tds.source.split is not None else None
            t.dataset_dict = {split_name: tds.source}
        else:
            assert isinstance(tds.source, (datasets.DatasetDict, datasets.IterableDatasetDict))
            t.dataset_dict = dict(tds.source)

        # Disable auto-decoding for Audio and Image columns; we want to write the bytes directly to temp files
        for ds_split_name, dataset in list(t.dataset_dict.items()):
            for col_name, feature in dataset.features.items():
                if isinstance(feature, (datasets.Audio, datasets.Image)):
                    t.dataset_dict[ds_split_name] = t.dataset_dict[ds_split_name].cast_column(
                        col_name, feature.__class__(decode=False)
                    )
        return t

    @classmethod
    def _get_dataset_classes(cls) -> tuple[type, ...]:
        import datasets

        return (datasets.Dataset, datasets.DatasetDict, datasets.IterableDataset, datasets.IterableDatasetDict)

    @classmethod
    def is_applicable(cls, tds: TableDataConduit) -> bool:
        try:
            return (isinstance(tds.source_format, str) and tds.source_format.lower() == 'huggingface') or isinstance(
                tds.source, cls._get_dataset_classes()
            )
        except ImportError:
            return False

    def infer_schema_part1(self) -> tuple[dict[str, ts.ColumnType], list[str]]:
        from pixeltable.io.hf_datasets import _get_hf_schema, huggingface_schema_to_pxt_schema

        if self.source_column_map is None:
            if self.src_schema_overrides is None:
                self.src_schema_overrides = {}
            if self.src_pk is None:
                self.src_pk = []
            self.hf_schema_source = _get_hf_schema(self.source)
            self.src_schema = huggingface_schema_to_pxt_schema(
                self.hf_schema_source, self.src_schema_overrides, self.src_pk
            )

            # Add the split column to the schema if requested
            if self.column_name_for_split is not None:
                if self.column_name_for_split in self.src_schema:
                    raise excs.Error(
                        f'Column name `{self.column_name_for_split}` already exists in dataset schema; '
                        f'provide a different `column_name_for_split`'
                    )
                self.src_schema[self.column_name_for_split] = ts.StringType(nullable=True)

            inferred_schema, inferred_pk, self.source_column_map = normalize_schema_names(
                self.src_schema, self.src_pk, self.src_schema_overrides
            )
            return inferred_schema, inferred_pk
        else:
            raise NotImplementedError()

    def infer_schema(self) -> dict[str, Any]:
        self.pxt_schema, self.pxt_pk = self.infer_schema_part1()
        self.normalize_pxt_schema_types()
        self.prepare_insert()
        return self.pxt_schema

    def prepare_for_insert_into_table(self) -> None:
        _, inferred_pk = self.infer_schema_part1()
        assert len(inferred_pk) == 0
        self.prepare_insert()

    def prepare_insert(self) -> None:
        import datasets

        # Extract all class labels from the dataset to translate category ints to strings
        self.categorical_features = {
            feature_name: feature_type.names
            for (feature_name, feature_type) in self.hf_schema_source.items()
            if isinstance(feature_type, datasets.ClassLabel)
        }
        if self.source_column_map is None:
            self.source_column_map = {}
        self.check_source_columns_are_insertable(self.hf_schema_source.keys())

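    # Column-level conversion: each Arrow column is turned into a list of Python
    # values according to its HF feature type (scalar Values, ClassLabels, nested
    # sequences, Audio/Image structs, plain structs, fixed-shape Array<N>D features).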
    def _convert_column(self, column: 'pa.ChunkedArray', feature: object) -> list:
        """
        Convert an Arrow column to a list of Python values based on the HF feature type.
        Handles all feature types at the column level, recursing for structs.
        Returns a list of length chunk_size.
        """
        import datasets

        # return scalars as Python scalars
        if isinstance(feature, datasets.Value):
            return column.to_pylist()

        # ClassLabel: int -> string name
        if isinstance(feature, datasets.ClassLabel):
            values = column.to_pylist()
            return [feature.names[v] if v is not None else None for v in values]

        # check for list of dict before Sequence, which could contain array data
        is_list_of_dict = isinstance(feature, (datasets.Sequence, datasets.LargeList)) and isinstance(
            feature.feature, dict
        )
        if is_list_of_dict:
            return column.to_pylist()

        # array data represented as a (possibly nested) sequence of numerical data: convert to numpy arrays
        if self._is_sequence_of_numerical(feature):
            arr = column.to_numpy(zero_copy_only=False)
            result: list = []
            for i in range(len(column)):
                val = arr[i]
                assert not isinstance(val, dict)  # we dealt with list of dicts earlier
                # convert object array of arrays (e.g., multi-channel audio) to a proper ndarray
                if (
                    isinstance(val, np.ndarray)
                    and val.dtype == object
                    and len(val) > 0
                    and isinstance(val[0], np.ndarray)
                ):
                    val = np.stack(list(val))
                result.append(val)
            return result

        if isinstance(feature, (datasets.Audio, datasets.Image)):
            # Audio/Image is stored in Arrow as struct<bytes: binary, path: string>
            from pixeltable.utils.local_store import TempStore

            arrow_type = column.type
            if not pa.types.is_struct(arrow_type):
                raise pxt.Error(f'Expected struct type for Audio/Image column, got {arrow_type}')
            field_names = {field.name for field in arrow_type}
            if 'bytes' not in field_names or 'path' not in field_names:
                raise pxt.Error(f"Audio/Image struct missing required fields 'bytes' and/or 'path', has: {field_names}")

            bytes_column = pc.struct_field(column, 'bytes')
            path_column = pc.struct_field(column, 'path')

            bytes_list = bytes_column.to_pylist()
            path_list = path_column.to_pylist()

            result = []
            for bytes, path in zip(bytes_list, path_list):
                if bytes is None:
                    result.append(None)
                    continue
                # we want to preserve the extension from the original path
                ext = Path(path).suffix if path is not None else None
                temp_path = TempStore.create_path(extension=ext)
                temp_path.write_bytes(bytes)
                result.append(str(temp_path))
            return result

        if isinstance(feature, dict):
            return self._convert_struct_column(column, feature)

        if isinstance(feature, list):
            return column.to_pylist()

        # Array<N>D: multi-dimensional fixed-shape arrays
        if isinstance(feature, (datasets.Array2D, datasets.Array3D, datasets.Array4D, datasets.Array5D)):
            return self._convert_array_feature(column, feature.shape)

        return column.to_pylist()

    def _is_sequence_of_numerical(self, feature: object) -> bool:
        """Returns True if feature is a (nested) Sequence of numerical values."""
        import datasets

        if not isinstance(feature, datasets.Sequence):
            return False
        if isinstance(feature.feature, datasets.Sequence):
            return self._is_sequence_of_numerical(feature.feature)

        pa_type = feature.feature.pa_type
        return pa_type is not None and (pat.is_integer(pa_type) or pat.is_floating(pa_type))

    def _convert_struct_column(self, column: 'pa.ChunkedArray', feature: dict[str, object]) -> list[dict[str, Any]]:
        """
        Convert a StructArray column to a list of dicts by recursively converting each field.
        """
        results: list[dict[str, Any]] = [{} for _ in range(len(column))]
        for field_name, field_feature in feature.items():
            field_column = pc.struct_field(column, field_name)
            field_values = self._convert_column(field_column, field_feature)

            for i, val in enumerate(field_values):
                results[i][field_name] = val

        return results

    def _convert_array_feature(self, column: 'pa.ChunkedArray', shape: tuple[int, ...]) -> list[np.ndarray]:
        arr: pa.ExtensionArray
        # TODO: can we get multiple chunks here?
        if column.num_chunks == 1:
            arr = column.chunks[0]  # type: ignore[assignment]
        else:
            arr = column.combine_chunks()  # type: ignore[assignment]

        # an Array<N>D feature is stored in Arrow as a list<list<...<dtype>>>; we want to peel off the outer lists
        # to get to contiguous storage and then reshape that
        storage = arr.storage
        vals = storage.values
        while hasattr(vals, 'values'):
            vals = vals.values
        flat_arr = vals.to_numpy()
        chunk_shape = (len(column), *shape)
        reshaped = flat_arr.reshape(chunk_shape)

        # Return as a list of array views (shares memory with reshaped)
        return list(reshaped)

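    # Batch iteration: materialized datasets are processed chunk by chunk from the
    # underlying Arrow table; iterable datasets are probed with a small trial batch
    # to size subsequent batches at roughly _K_BATCH_SIZE_BYTES (256 MiB) each.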
    def valid_row_batch(self) -> Iterator['RowData']:
        import datasets

        for split_name, split_dataset in self.dataset_dict.items():
            features = split_dataset.features
            if isinstance(split_dataset, datasets.Dataset):
                table = split_dataset.data  # the underlying Arrow table
                yield from self._process_arrow_table(table, split_name, features)
            else:
                # we're getting batches of Arrow tables, since we did with_format('arrow');
                # use a trial batch to determine the target batch size
                first_batch = next(split_dataset.iter(batch_size=16))
                bytes_per_row = int(first_batch.nbytes / len(first_batch))
                batch_size = self._K_BATCH_SIZE_BYTES // bytes_per_row
                yield from self._process_arrow_table(first_batch, split_name, features)
                for batch in split_dataset.skip(16).iter(batch_size=batch_size):
                    yield from self._process_arrow_table(batch, split_name, features)

    def _process_arrow_table(self, table: 'pa.Table', split_name: str, features: dict[str, Any]) -> Iterator[RowData]:
        # get chunk boundaries from the first column's ChunkedArray
        first_column = table.column(0)
        offset = 0
        for chunk in first_column.chunks:
            chunk_size = len(chunk)
            # zero-copy slice using existing chunk boundaries
            batch = table.slice(offset, chunk_size)

            # we assemble per-row dicts from lists of per-column values
            rows: list[dict[str, Any]] = [{} for _ in range(chunk_size)]
            if self.column_name_for_split is not None:
                for row in rows:
                    row[self.column_name_for_split] = split_name

            for col_idx, col_name in enumerate(batch.schema.names):
                feature = features[col_name]
                mapped_col_name = self.source_column_map.get(col_name, col_name)
                column = batch.column(col_idx)
                values = self._convert_column(column, feature)
                for i, val in enumerate(values):
                    rows[i][mapped_col_name] = val

            offset += chunk_size
            yield rows

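# Conduit for Parquet files/directories, backed by pyarrow.parquet.ParquetDataset.
# Unlike the pandas-based conduits, rows are streamed fragment by fragment instead
# of being materialized up front.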
class ParquetTableDataConduit(TableDataConduit):
    pq_ds: ParquetDataset | None = None

    @classmethod
    def from_tds(cls, tds: TableDataConduit) -> 'ParquetTableDataConduit':
        tds_fields = {f.name for f in fields(tds)}
        kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
        t = cls(**kwargs)

        assert isinstance(tds.source, str)
        input_path = Path(tds.source).expanduser()
        t.pq_ds = pa.parquet.ParquetDataset(str(input_path))
        return t

    def infer_schema_part1(self) -> tuple[dict[str, ts.ColumnType], list[str]]:
        from pixeltable.utils.arrow import to_pxt_schema

        if self.source_column_map is None:
            if self.src_schema_overrides is None:
                self.src_schema_overrides = {}
            self.src_schema = to_pxt_schema(self.pq_ds.schema, self.src_schema_overrides, self.src_pk)
            inferred_schema, inferred_pk, self.source_column_map = normalize_schema_names(
                self.src_schema, self.src_pk, self.src_schema_overrides
            )
            return inferred_schema, inferred_pk
        else:
            raise NotImplementedError()

    def infer_schema(self) -> dict[str, ts.ColumnType]:
        self.pxt_schema, self.pxt_pk = self.infer_schema_part1()
        self.normalize_pxt_schema_types()
        self.prepare_insert()
        return self.pxt_schema

    def prepare_for_insert_into_table(self) -> None:
        _, inferred_pk = self.infer_schema_part1()
        assert len(inferred_pk) == 0
        self.prepare_insert()

    def prepare_insert(self) -> None:
        if self.source_column_map is None:
            self.source_column_map = {}
        self.check_source_columns_are_insertable(self.pq_ds.schema.names)
        self.total_rows = 0

    def valid_row_batch(self) -> Iterator[RowData]:
        from pixeltable.utils.arrow import iter_tuples2

        try:
            for fragment in self.pq_ds.fragments:
                for batch in fragment.to_batches():
                    dict_batch = list(iter_tuples2(batch, self.source_column_map, self.pxt_schema))
                    self.total_rows += len(dict_batch)
                    yield dict_batch
        except Exception as e:
            _logger.error(f'Error after inserting {self.total_rows} rows from Parquet file into table: {e}')
            raise e

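# Fallback conduit for when the source type is unknown at construction time:
# specialize() inspects source_format and the source itself (its type, or a file
# extension in the path) and returns the matching concrete conduit.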
class UnkTableDataConduit(TableDataConduit):
    """Source type is not known at the time of creation"""

    def specialize(self) -> TableDataConduit:
        if isinstance(self.source, (pxt.Table, pxt.Query)):
            return QueryTableDataConduit.from_tds(self)
        if isinstance(self.source, pd.DataFrame):
            return PandasTableDataConduit.from_tds(self)
        if HFTableDataConduit.is_applicable(self):
            return HFTableDataConduit.from_tds(self)
        if self.source_format == 'csv' or (isinstance(self.source, str) and '.csv' in self.source.lower()):
            return CSVTableDataConduit.from_tds(self)
        if self.source_format == 'excel' or (isinstance(self.source, str) and '.xls' in self.source.lower()):
            return ExcelTableDataConduit.from_tds(self)
        if self.source_format == 'json' or (isinstance(self.source, str) and '.json' in self.source.lower()):
            return JsonTableDataConduit.from_tds(self)
        if self.source_format == 'parquet' or (
            isinstance(self.source, str) and any(s in self.source.lower() for s in ['.parquet', '.pq', '.parq'])
        ):
            return ParquetTableDataConduit.from_tds(self)
        if (
            self.is_rowdata_structure(self.source)
            # An Iterator as a source is assumed to produce rows
            or isinstance(self.source, Iterator)
        ):
            return RowDataTableDataConduit.from_tds(self)

        raise excs.Error(f'Unsupported data source type: {type(self.source)}')
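
A minimal sketch of how these pieces appear to fit together (hypothetical driver
code, not part of the package; only names defined in this file are assumed):

    import pandas as pd
    from pixeltable.io.table_data_conduit import UnkTableDataConduit

    # wrap an arbitrary source, then let specialize() pick the concrete conduit
    df = pd.DataFrame({'name': ['a', 'b'], 'score': [1.0, 2.0]})
    conduit = UnkTableDataConduit(source=df).specialize()  # -> PandasTableDataConduit
    schema = conduit.infer_schema()        # inferred Pixeltable schema (and pxt_pk)
    for batch in conduit.valid_row_batch():
        ...                                # batches of insertable row dicts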