pixeltable 0.2.26__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. pixeltable/__init__.py +83 -19
  2. pixeltable/_query.py +1444 -0
  3. pixeltable/_version.py +1 -0
  4. pixeltable/catalog/__init__.py +7 -4
  5. pixeltable/catalog/catalog.py +2394 -119
  6. pixeltable/catalog/column.py +225 -104
  7. pixeltable/catalog/dir.py +38 -9
  8. pixeltable/catalog/globals.py +53 -34
  9. pixeltable/catalog/insertable_table.py +265 -115
  10. pixeltable/catalog/path.py +80 -17
  11. pixeltable/catalog/schema_object.py +28 -43
  12. pixeltable/catalog/table.py +1270 -677
  13. pixeltable/catalog/table_metadata.py +103 -0
  14. pixeltable/catalog/table_version.py +1270 -751
  15. pixeltable/catalog/table_version_handle.py +109 -0
  16. pixeltable/catalog/table_version_path.py +137 -42
  17. pixeltable/catalog/tbl_ops.py +53 -0
  18. pixeltable/catalog/update_status.py +191 -0
  19. pixeltable/catalog/view.py +251 -134
  20. pixeltable/config.py +215 -0
  21. pixeltable/env.py +736 -285
  22. pixeltable/exceptions.py +26 -2
  23. pixeltable/exec/__init__.py +7 -2
  24. pixeltable/exec/aggregation_node.py +39 -21
  25. pixeltable/exec/cache_prefetch_node.py +87 -109
  26. pixeltable/exec/cell_materialization_node.py +268 -0
  27. pixeltable/exec/cell_reconstruction_node.py +168 -0
  28. pixeltable/exec/component_iteration_node.py +25 -28
  29. pixeltable/exec/data_row_batch.py +11 -46
  30. pixeltable/exec/exec_context.py +26 -11
  31. pixeltable/exec/exec_node.py +35 -27
  32. pixeltable/exec/expr_eval/__init__.py +3 -0
  33. pixeltable/exec/expr_eval/evaluators.py +365 -0
  34. pixeltable/exec/expr_eval/expr_eval_node.py +413 -0
  35. pixeltable/exec/expr_eval/globals.py +200 -0
  36. pixeltable/exec/expr_eval/row_buffer.py +74 -0
  37. pixeltable/exec/expr_eval/schedulers.py +413 -0
  38. pixeltable/exec/globals.py +35 -0
  39. pixeltable/exec/in_memory_data_node.py +35 -27
  40. pixeltable/exec/object_store_save_node.py +293 -0
  41. pixeltable/exec/row_update_node.py +44 -29
  42. pixeltable/exec/sql_node.py +414 -115
  43. pixeltable/exprs/__init__.py +8 -5
  44. pixeltable/exprs/arithmetic_expr.py +79 -45
  45. pixeltable/exprs/array_slice.py +5 -5
  46. pixeltable/exprs/column_property_ref.py +40 -26
  47. pixeltable/exprs/column_ref.py +254 -61
  48. pixeltable/exprs/comparison.py +14 -9
  49. pixeltable/exprs/compound_predicate.py +9 -10
  50. pixeltable/exprs/data_row.py +213 -72
  51. pixeltable/exprs/expr.py +270 -104
  52. pixeltable/exprs/expr_dict.py +6 -5
  53. pixeltable/exprs/expr_set.py +20 -11
  54. pixeltable/exprs/function_call.py +383 -284
  55. pixeltable/exprs/globals.py +18 -5
  56. pixeltable/exprs/in_predicate.py +7 -7
  57. pixeltable/exprs/inline_expr.py +37 -37
  58. pixeltable/exprs/is_null.py +8 -4
  59. pixeltable/exprs/json_mapper.py +120 -54
  60. pixeltable/exprs/json_path.py +90 -60
  61. pixeltable/exprs/literal.py +61 -16
  62. pixeltable/exprs/method_ref.py +7 -6
  63. pixeltable/exprs/object_ref.py +19 -8
  64. pixeltable/exprs/row_builder.py +238 -75
  65. pixeltable/exprs/rowid_ref.py +53 -15
  66. pixeltable/exprs/similarity_expr.py +65 -50
  67. pixeltable/exprs/sql_element_cache.py +5 -5
  68. pixeltable/exprs/string_op.py +107 -0
  69. pixeltable/exprs/type_cast.py +25 -13
  70. pixeltable/exprs/variable.py +2 -2
  71. pixeltable/func/__init__.py +9 -5
  72. pixeltable/func/aggregate_function.py +197 -92
  73. pixeltable/func/callable_function.py +119 -35
  74. pixeltable/func/expr_template_function.py +101 -48
  75. pixeltable/func/function.py +375 -62
  76. pixeltable/func/function_registry.py +20 -19
  77. pixeltable/func/globals.py +6 -5
  78. pixeltable/func/mcp.py +74 -0
  79. pixeltable/func/query_template_function.py +151 -35
  80. pixeltable/func/signature.py +178 -49
  81. pixeltable/func/tools.py +164 -0
  82. pixeltable/func/udf.py +176 -53
  83. pixeltable/functions/__init__.py +44 -4
  84. pixeltable/functions/anthropic.py +226 -47
  85. pixeltable/functions/audio.py +148 -11
  86. pixeltable/functions/bedrock.py +137 -0
  87. pixeltable/functions/date.py +188 -0
  88. pixeltable/functions/deepseek.py +113 -0
  89. pixeltable/functions/document.py +81 -0
  90. pixeltable/functions/fal.py +76 -0
  91. pixeltable/functions/fireworks.py +72 -20
  92. pixeltable/functions/gemini.py +249 -0
  93. pixeltable/functions/globals.py +208 -53
  94. pixeltable/functions/groq.py +108 -0
  95. pixeltable/functions/huggingface.py +1088 -95
  96. pixeltable/functions/image.py +155 -84
  97. pixeltable/functions/json.py +8 -11
  98. pixeltable/functions/llama_cpp.py +31 -19
  99. pixeltable/functions/math.py +169 -0
  100. pixeltable/functions/mistralai.py +50 -75
  101. pixeltable/functions/net.py +70 -0
  102. pixeltable/functions/ollama.py +29 -36
  103. pixeltable/functions/openai.py +548 -160
  104. pixeltable/functions/openrouter.py +143 -0
  105. pixeltable/functions/replicate.py +15 -14
  106. pixeltable/functions/reve.py +250 -0
  107. pixeltable/functions/string.py +310 -85
  108. pixeltable/functions/timestamp.py +37 -19
  109. pixeltable/functions/together.py +77 -120
  110. pixeltable/functions/twelvelabs.py +188 -0
  111. pixeltable/functions/util.py +7 -2
  112. pixeltable/functions/uuid.py +30 -0
  113. pixeltable/functions/video.py +1528 -117
  114. pixeltable/functions/vision.py +26 -26
  115. pixeltable/functions/voyageai.py +289 -0
  116. pixeltable/functions/whisper.py +19 -10
  117. pixeltable/functions/whisperx.py +179 -0
  118. pixeltable/functions/yolox.py +112 -0
  119. pixeltable/globals.py +716 -236
  120. pixeltable/index/__init__.py +3 -1
  121. pixeltable/index/base.py +17 -21
  122. pixeltable/index/btree.py +32 -22
  123. pixeltable/index/embedding_index.py +155 -92
  124. pixeltable/io/__init__.py +12 -7
  125. pixeltable/io/datarows.py +140 -0
  126. pixeltable/io/external_store.py +83 -125
  127. pixeltable/io/fiftyone.py +24 -33
  128. pixeltable/io/globals.py +47 -182
  129. pixeltable/io/hf_datasets.py +96 -127
  130. pixeltable/io/label_studio.py +171 -156
  131. pixeltable/io/lancedb.py +3 -0
  132. pixeltable/io/pandas.py +136 -115
  133. pixeltable/io/parquet.py +40 -153
  134. pixeltable/io/table_data_conduit.py +702 -0
  135. pixeltable/io/utils.py +100 -0
  136. pixeltable/iterators/__init__.py +8 -4
  137. pixeltable/iterators/audio.py +207 -0
  138. pixeltable/iterators/base.py +9 -3
  139. pixeltable/iterators/document.py +144 -87
  140. pixeltable/iterators/image.py +17 -38
  141. pixeltable/iterators/string.py +15 -12
  142. pixeltable/iterators/video.py +523 -127
  143. pixeltable/metadata/__init__.py +33 -8
  144. pixeltable/metadata/converters/convert_10.py +2 -3
  145. pixeltable/metadata/converters/convert_13.py +2 -2
  146. pixeltable/metadata/converters/convert_15.py +15 -11
  147. pixeltable/metadata/converters/convert_16.py +4 -5
  148. pixeltable/metadata/converters/convert_17.py +4 -5
  149. pixeltable/metadata/converters/convert_18.py +4 -6
  150. pixeltable/metadata/converters/convert_19.py +6 -9
  151. pixeltable/metadata/converters/convert_20.py +3 -6
  152. pixeltable/metadata/converters/convert_21.py +6 -8
  153. pixeltable/metadata/converters/convert_22.py +3 -2
  154. pixeltable/metadata/converters/convert_23.py +33 -0
  155. pixeltable/metadata/converters/convert_24.py +55 -0
  156. pixeltable/metadata/converters/convert_25.py +19 -0
  157. pixeltable/metadata/converters/convert_26.py +23 -0
  158. pixeltable/metadata/converters/convert_27.py +29 -0
  159. pixeltable/metadata/converters/convert_28.py +13 -0
  160. pixeltable/metadata/converters/convert_29.py +110 -0
  161. pixeltable/metadata/converters/convert_30.py +63 -0
  162. pixeltable/metadata/converters/convert_31.py +11 -0
  163. pixeltable/metadata/converters/convert_32.py +15 -0
  164. pixeltable/metadata/converters/convert_33.py +17 -0
  165. pixeltable/metadata/converters/convert_34.py +21 -0
  166. pixeltable/metadata/converters/convert_35.py +9 -0
  167. pixeltable/metadata/converters/convert_36.py +38 -0
  168. pixeltable/metadata/converters/convert_37.py +15 -0
  169. pixeltable/metadata/converters/convert_38.py +39 -0
  170. pixeltable/metadata/converters/convert_39.py +124 -0
  171. pixeltable/metadata/converters/convert_40.py +73 -0
  172. pixeltable/metadata/converters/convert_41.py +12 -0
  173. pixeltable/metadata/converters/convert_42.py +9 -0
  174. pixeltable/metadata/converters/convert_43.py +44 -0
  175. pixeltable/metadata/converters/util.py +44 -18
  176. pixeltable/metadata/notes.py +21 -0
  177. pixeltable/metadata/schema.py +185 -42
  178. pixeltable/metadata/utils.py +74 -0
  179. pixeltable/mypy/__init__.py +3 -0
  180. pixeltable/mypy/mypy_plugin.py +123 -0
  181. pixeltable/plan.py +616 -225
  182. pixeltable/share/__init__.py +3 -0
  183. pixeltable/share/packager.py +797 -0
  184. pixeltable/share/protocol/__init__.py +33 -0
  185. pixeltable/share/protocol/common.py +165 -0
  186. pixeltable/share/protocol/operation_types.py +33 -0
  187. pixeltable/share/protocol/replica.py +119 -0
  188. pixeltable/share/publish.py +349 -0
  189. pixeltable/store.py +398 -232
  190. pixeltable/type_system.py +730 -267
  191. pixeltable/utils/__init__.py +40 -0
  192. pixeltable/utils/arrow.py +201 -29
  193. pixeltable/utils/av.py +298 -0
  194. pixeltable/utils/azure_store.py +346 -0
  195. pixeltable/utils/coco.py +26 -27
  196. pixeltable/utils/code.py +4 -4
  197. pixeltable/utils/console_output.py +46 -0
  198. pixeltable/utils/coroutine.py +24 -0
  199. pixeltable/utils/dbms.py +92 -0
  200. pixeltable/utils/description_helper.py +11 -12
  201. pixeltable/utils/documents.py +60 -61
  202. pixeltable/utils/exception_handler.py +36 -0
  203. pixeltable/utils/filecache.py +38 -22
  204. pixeltable/utils/formatter.py +88 -51
  205. pixeltable/utils/gcs_store.py +295 -0
  206. pixeltable/utils/http.py +133 -0
  207. pixeltable/utils/http_server.py +14 -13
  208. pixeltable/utils/iceberg.py +13 -0
  209. pixeltable/utils/image.py +17 -0
  210. pixeltable/utils/lancedb.py +90 -0
  211. pixeltable/utils/local_store.py +322 -0
  212. pixeltable/utils/misc.py +5 -0
  213. pixeltable/utils/object_stores.py +573 -0
  214. pixeltable/utils/pydantic.py +60 -0
  215. pixeltable/utils/pytorch.py +20 -20
  216. pixeltable/utils/s3_store.py +527 -0
  217. pixeltable/utils/sql.py +32 -5
  218. pixeltable/utils/system.py +30 -0
  219. pixeltable/utils/transactional_directory.py +4 -3
  220. pixeltable-0.5.7.dist-info/METADATA +579 -0
  221. pixeltable-0.5.7.dist-info/RECORD +227 -0
  222. {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
  223. pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
  224. pixeltable/__version__.py +0 -3
  225. pixeltable/catalog/named_function.py +0 -36
  226. pixeltable/catalog/path_dict.py +0 -141
  227. pixeltable/dataframe.py +0 -894
  228. pixeltable/exec/expr_eval_node.py +0 -232
  229. pixeltable/ext/__init__.py +0 -14
  230. pixeltable/ext/functions/__init__.py +0 -8
  231. pixeltable/ext/functions/whisperx.py +0 -77
  232. pixeltable/ext/functions/yolox.py +0 -157
  233. pixeltable/tool/create_test_db_dump.py +0 -311
  234. pixeltable/tool/create_test_video.py +0 -81
  235. pixeltable/tool/doc_plugins/griffe.py +0 -50
  236. pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
  237. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
  238. pixeltable/tool/embed_udf.py +0 -9
  239. pixeltable/tool/mypy_plugin.py +0 -55
  240. pixeltable/utils/media_store.py +0 -76
  241. pixeltable/utils/s3.py +0 -16
  242. pixeltable-0.2.26.dist-info/METADATA +0 -400
  243. pixeltable-0.2.26.dist-info/RECORD +0 -156
  244. pixeltable-0.2.26.dist-info/entry_points.txt +0 -3
  245. {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
pixeltable/io/table_data_conduit.py (new file)
@@ -0,0 +1,702 @@
+ from __future__ import annotations
+
+ import enum
+ import json
+ import logging
+ import urllib.request
+ from dataclasses import dataclass, field, fields
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, cast
+
+ import numpy as np
+ import pandas as pd
+ import pyarrow as pa
+ import pyarrow.compute as pc
+ import pyarrow.types as pat
+ from pyarrow.parquet import ParquetDataset
+
+ import pixeltable as pxt
+ import pixeltable.exceptions as excs
+ import pixeltable.type_system as ts
+ from pixeltable.io.pandas import _df_check_primary_key_values, _df_row_to_pxt_row, df_infer_schema
+ from pixeltable.utils import parse_local_file_path
+
+ from .utils import normalize_schema_names
+
+ _logger = logging.getLogger('pixeltable')
+
+
+ if TYPE_CHECKING:
+     import datasets  # type: ignore[import-untyped]
+
+     from pixeltable.globals import RowData, TableDataSource
+
+
+ class TableDataConduitFormat(str, enum.Enum):
+     """Supported formats for TableDataConduit"""
+
+     JSON = 'json'
+     CSV = 'csv'
+     EXCEL = 'excel'
+     PARQUET = 'parquet'
+
+     @classmethod
+     def is_valid(cls, x: Any) -> bool:
+         if isinstance(x, str):
+             return x.lower() in [c.value for c in cls]
+         return False
+
+
+ @dataclass
+ class TableDataConduit:
+     source: 'TableDataSource'
+     source_format: str | None = None
+     source_column_map: dict[str, str] | None = None
+     if_row_exists: Literal['update', 'ignore', 'error'] = 'error'
+     pxt_schema: dict[str, ts.ColumnType] | None = None
+     src_schema_overrides: dict[str, ts.ColumnType] | None = None
+     src_schema: dict[str, ts.ColumnType] | None = None
+     pxt_pk: list[str] | None = None
+     src_pk: list[str] | None = None
+     valid_rows: RowData | None = None
+     extra_fields: dict[str, Any] = field(default_factory=dict)
+
+     reqd_col_names: set[str] = field(default_factory=set)
+     computed_col_names: set[str] = field(default_factory=set)
+
+     total_rows: int = 0  # total number of rows emitted via valid_row_batch Iterator
+
+     _K_BATCH_SIZE_BYTES = 256 * 2**20
+
+     def check_source_format(self) -> None:
+         assert self.source_format is None or TableDataConduitFormat.is_valid(self.source_format)
+
+     def __post_init__(self) -> None:
+         """If no extra_fields were provided, initialize to empty dict"""
+         if self.extra_fields is None:
+             self.extra_fields = {}
+
+     @classmethod
+     def is_rowdata_structure(cls, d: TableDataSource) -> bool:
+         if not isinstance(d, list) or len(d) == 0:
+             return False
+         return all(isinstance(row, dict) for row in d)
+
+     def is_direct_query(self) -> bool:
+         return isinstance(self.source, pxt.Query) and self.source_column_map is None
+
+     def normalize_pxt_schema_types(self) -> None:
+         for name, coltype in self.pxt_schema.items():
+             self.pxt_schema[name] = ts.ColumnType.normalize_type(coltype)
+
+     def infer_schema(self) -> dict[str, ts.ColumnType]:
+         raise NotImplementedError
+
+     def valid_row_batch(self) -> Iterator[RowData]:
+         raise NotImplementedError
+
+     def prepare_for_insert_into_table(self) -> None:
+         if self.source is None:
+             return
+         raise NotImplementedError
+
+     def add_table_info(self, table: pxt.Table) -> None:
+         """Add information about the table into which we are inserting data"""
+         assert isinstance(table, pxt.Table)
+         self.pxt_schema = table._get_schema()
+         self.pxt_pk = table._tbl_version.get().primary_key
+         for col in table._tbl_version_path.columns():
+             if col.is_required_for_insert:
+                 self.reqd_col_names.add(col.name)
+             if col.is_computed:
+                 self.computed_col_names.add(col.name)
+         self.src_pk = []
+
+     # Check source columns: required, computed, unknown
+     def check_source_columns_are_insertable(self, columns: Iterable[str]) -> None:
+         col_name_set: set[str] = set()
+         for col_name in columns:  # FIXME
+             mapped_col_name = self.source_column_map.get(col_name, col_name)
+             col_name_set.add(mapped_col_name)
+             if mapped_col_name not in self.pxt_schema:
+                 raise excs.Error(f'Unknown column name {mapped_col_name}')
+             if mapped_col_name in self.computed_col_names:
+                 raise excs.Error(f'Value for computed column {mapped_col_name}')
+         missing_cols = self.reqd_col_names - col_name_set
+         if len(missing_cols) > 0:
+             raise excs.Error(f'Missing required column(s) ({", ".join(missing_cols)})')
+
+
+ class QueryTableDataConduit(TableDataConduit):
+     pxt_query: pxt.Query = None
+
+     @classmethod
+     def from_tds(cls, tds: TableDataConduit) -> 'QueryTableDataConduit':
+         tds_fields = {f.name for f in fields(tds)}
+         kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
+         t = cls(**kwargs)
+         if isinstance(tds.source, pxt.Table):
+             t.pxt_query = tds.source.select()
+         else:
+             assert isinstance(tds.source, pxt.Query)
+             t.pxt_query = tds.source
+         return t
+
+     def infer_schema(self) -> dict[str, ts.ColumnType]:
+         self.pxt_schema = self.pxt_query.schema
+         self.pxt_pk = self.src_pk
+         return self.pxt_schema
+
+     def prepare_for_insert_into_table(self) -> None:
+         if self.source_column_map is None:
+             self.source_column_map = {}
+         self.check_source_columns_are_insertable(self.pxt_query.schema.keys())
+
+
+ class RowDataTableDataConduit(TableDataConduit):
+     raw_rows: RowData | None = None
+     disable_mapping: bool = True
+     batch_count: int = 0
+
+     @classmethod
+     def from_tds(cls, tds: TableDataConduit) -> 'RowDataTableDataConduit':
+         tds_fields = {f.name for f in fields(tds)}
+         kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
+         t = cls(**kwargs)
+         if isinstance(tds.source, Iterator):
+             # Instantiate the iterator to get the raw rows here
+             t.raw_rows = list(tds.source)
+         elif TYPE_CHECKING:
+             t.raw_rows = cast(RowData, tds.source)
+         else:
+             t.raw_rows = tds.source
+         t.batch_count = 0
+         return t
+
+     def infer_schema(self) -> dict[str, ts.ColumnType]:
+         from .datarows import _infer_schema_from_rows
+
+         if self.source_column_map is None:
+             if self.src_schema_overrides is None:
+                 self.src_schema_overrides = {}
+             self.src_schema = _infer_schema_from_rows(self.raw_rows, self.src_schema_overrides, self.src_pk)
+             self.pxt_schema, self.pxt_pk, self.source_column_map = normalize_schema_names(
+                 self.src_schema, self.src_pk, self.src_schema_overrides, self.disable_mapping
+             )
+             self.normalize_pxt_schema_types()
+         else:
+             raise NotImplementedError()
+
+         self.prepare_for_insert_into_table()
+         return self.pxt_schema
+
+     def prepare_for_insert_into_table(self) -> None:
+         # Converting rows to an insertable format is not needed; misnamed columns and
+         # invalid types in the incoming rows are errors
+         if self.source_column_map is None:
+             self.source_column_map = {}
+         self.valid_rows = [self._translate_row(row) for row in self.raw_rows]
+
+         self.batch_count = 1 if self.raw_rows is not None else 0
+
+     def _translate_row(self, row: dict[str, Any]) -> dict[str, Any]:
+         if not isinstance(row, dict):
+             raise excs.Error(f'row {row} is not a dictionary')
+
+         col_names: set[str] = set()
+         output_row: dict[str, Any] = {}
+         for col_name, val in row.items():
+             mapped_col_name = self.source_column_map.get(col_name, col_name)
+             col_names.add(mapped_col_name)
+             if mapped_col_name not in self.pxt_schema:
+                 raise excs.Error(f'Unknown column name {mapped_col_name} in row {row}')
+             if mapped_col_name in self.computed_col_names:
+                 raise excs.Error(f'Value for computed column {mapped_col_name} in row {row}')
+             # basic sanity checks here
+             try:
+                 checked_val = self.pxt_schema[mapped_col_name].create_literal(val)
+             except TypeError as e:
+                 msg = str(e)
+                 raise excs.Error(f'Error in column {col_name}: {msg[0].lower() + msg[1:]}\nRow: {row}') from e
+             output_row[mapped_col_name] = checked_val
+         missing_cols = self.reqd_col_names - col_names
+         if len(missing_cols) > 0:
+             raise excs.Error(f'Missing required column(s) ({", ".join(missing_cols)}) in row {row}')
+         return output_row
+
+     def valid_row_batch(self) -> Iterator[RowData]:
+         if self.batch_count > 0:
+             self.batch_count -= 1
+             yield self.valid_rows
+
+
+ class PandasTableDataConduit(TableDataConduit):
+     pd_df: pd.DataFrame = None
+     batch_count: int = 0
+
+     @classmethod
+     def from_tds(cls, tds: TableDataConduit) -> PandasTableDataConduit:
+         tds_fields = {f.name for f in fields(tds)}
+         kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
+         t = cls(**kwargs)
+         assert isinstance(tds.source, pd.DataFrame)
+         t.pd_df = tds.source
+         t.batch_count = 0
+         return t
+
+     def infer_schema_part1(self) -> tuple[dict[str, ts.ColumnType], list[str]]:
+         """Return the inferred schema and primary key; also records the source column map"""
+         if self.source_column_map is None:
+             if self.src_schema_overrides is None:
+                 self.src_schema_overrides = {}
+             self.src_schema = df_infer_schema(self.pd_df, self.src_schema_overrides, self.src_pk)
+             inferred_schema, inferred_pk, self.source_column_map = normalize_schema_names(
+                 self.src_schema, self.src_pk, self.src_schema_overrides, False
+             )
+             return inferred_schema, inferred_pk
+         else:
+             raise NotImplementedError()
+
+     def infer_schema(self) -> dict[str, ts.ColumnType]:
+         self.pxt_schema, self.pxt_pk = self.infer_schema_part1()
+         self.normalize_pxt_schema_types()
+         _df_check_primary_key_values(self.pd_df, self.src_pk)
+         self.prepare_insert()
+         return self.pxt_schema
+
+     def prepare_for_insert_into_table(self) -> None:
+         _, inferred_pk = self.infer_schema_part1()
+         assert len(inferred_pk) == 0
+         self.prepare_insert()
+
+     def prepare_insert(self) -> None:
+         if self.source_column_map is None:
+             self.source_column_map = {}
+         self.check_source_columns_are_insertable(self.pd_df.columns)
+         # Convert all rows to an insertable format
+         self.valid_rows = [
+             _df_row_to_pxt_row(row, self.src_schema, self.source_column_map) for row in self.pd_df.itertuples()
+         ]
+         self.batch_count = 1
+
+     def valid_row_batch(self) -> Iterator[RowData]:
+         if self.batch_count > 0:
+             self.batch_count -= 1
+             yield self.valid_rows
+
+
+ class CSVTableDataConduit(TableDataConduit):
+     @classmethod
+     def from_tds(cls, tds: TableDataConduit) -> 'PandasTableDataConduit':
+         tds_fields = {f.name for f in fields(tds)}
+         kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
+         t = cls(**kwargs)
+         assert isinstance(t.source, str)
+         t.source = pd.read_csv(t.source, **t.extra_fields)
+         return PandasTableDataConduit.from_tds(t)
+
+
+ class ExcelTableDataConduit(TableDataConduit):
+     @classmethod
+     def from_tds(cls, tds: TableDataConduit) -> 'PandasTableDataConduit':
+         tds_fields = {f.name for f in fields(tds)}
+         kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
+         t = cls(**kwargs)
+         assert isinstance(t.source, str)
+         t.source = pd.read_excel(t.source, **t.extra_fields)
+         return PandasTableDataConduit.from_tds(t)
+
+
+ class JsonTableDataConduit(TableDataConduit):
+     @classmethod
+     def from_tds(cls, tds: TableDataConduit) -> RowDataTableDataConduit:
+         tds_fields = {f.name for f in fields(tds)}
+         kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
+         t = cls(**kwargs)
+         assert isinstance(t.source, str)
+
+         path = parse_local_file_path(t.source)
+         if path is None:  # it's a URL
+             # TODO: This should read from S3 as well.
+             contents = urllib.request.urlopen(t.source).read()
+         else:
+             with open(path, 'r', encoding='utf-8') as fp:
+                 contents = fp.read()
+         rows = json.loads(contents, **t.extra_fields)
+         t.source = rows
+         t2 = RowDataTableDataConduit.from_tds(t)
+         t2.disable_mapping = False
+         return t2
+
+
+ class HFTableDataConduit(TableDataConduit):
+     """HuggingFace dataset importer"""
+
+     column_name_for_split: str | None = None
+     categorical_features: dict[str, dict[int, str]]
+     dataset_dict: dict[str, 'datasets.Dataset'] = None  # key: split name
+     hf_schema_source: dict[str, Any] = None
+
+     @classmethod
+     def from_tds(cls, tds: TableDataConduit) -> HFTableDataConduit:
+         tds_fields = {f.name for f in fields(tds)}
+         kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
+         t = cls(**kwargs)
+         import datasets
+
+         assert isinstance(tds.source, cls._get_dataset_classes())
+         if 'column_name_for_split' in t.extra_fields:
+             t.column_name_for_split = t.extra_fields['column_name_for_split']
+
+         if isinstance(tds.source, (datasets.IterableDataset, datasets.IterableDatasetDict)):
+             tds.source = tds.source.with_format('arrow')
+
+         if isinstance(tds.source, (datasets.Dataset, datasets.IterableDataset)):
+             split_name = str(tds.source.split) if tds.source.split is not None else None
+             t.dataset_dict = {split_name: tds.source}
+         else:
+             assert isinstance(tds.source, (datasets.DatasetDict, datasets.IterableDatasetDict))
+             t.dataset_dict = dict(tds.source)
+
+         # Disable auto-decoding for Audio and Image columns; we want to write the bytes directly to temp files
+         for ds_split_name, dataset in list(t.dataset_dict.items()):
+             for col_name, feature in dataset.features.items():
+                 if isinstance(feature, (datasets.Audio, datasets.Image)):
+                     t.dataset_dict[ds_split_name] = t.dataset_dict[ds_split_name].cast_column(
+                         col_name, feature.__class__(decode=False)
+                     )
+         return t
+
+     @classmethod
+     def _get_dataset_classes(cls) -> tuple[type, ...]:
+         import datasets
+
+         return (datasets.Dataset, datasets.DatasetDict, datasets.IterableDataset, datasets.IterableDatasetDict)
+
+     @classmethod
+     def is_applicable(cls, tds: TableDataConduit) -> bool:
+         try:
+             return (isinstance(tds.source_format, str) and tds.source_format.lower() == 'huggingface') or isinstance(
+                 tds.source, cls._get_dataset_classes()
+             )
+         except ImportError:
+             return False
+
+     def infer_schema_part1(self) -> tuple[dict[str, ts.ColumnType], list[str]]:
+         from pixeltable.io.hf_datasets import _get_hf_schema, huggingface_schema_to_pxt_schema
+
+         if self.source_column_map is None:
+             if self.src_schema_overrides is None:
+                 self.src_schema_overrides = {}
+             if self.src_pk is None:
+                 self.src_pk = []
+             self.hf_schema_source = _get_hf_schema(self.source)
+             self.src_schema = huggingface_schema_to_pxt_schema(
+                 self.hf_schema_source, self.src_schema_overrides, self.src_pk
+             )
+
+             # Add the split column to the schema if requested
+             if self.column_name_for_split is not None:
+                 if self.column_name_for_split in self.src_schema:
+                     raise excs.Error(
+                         f'Column name `{self.column_name_for_split}` already exists in dataset schema; '
+                         f'provide a different `column_name_for_split`'
+                     )
+                 self.src_schema[self.column_name_for_split] = ts.StringType(nullable=True)
+
+             inferred_schema, inferred_pk, self.source_column_map = normalize_schema_names(
+                 self.src_schema, self.src_pk, self.src_schema_overrides
+             )
+             return inferred_schema, inferred_pk
+         else:
+             raise NotImplementedError()
+
+     def infer_schema(self) -> dict[str, Any]:
+         self.pxt_schema, self.pxt_pk = self.infer_schema_part1()
+         self.normalize_pxt_schema_types()
+         self.prepare_insert()
+         return self.pxt_schema
+
+     def prepare_for_insert_into_table(self) -> None:
+         _, inferred_pk = self.infer_schema_part1()
+         assert len(inferred_pk) == 0
+         self.prepare_insert()
+
+     def prepare_insert(self) -> None:
+         import datasets
+
+         # Extract all class labels from the dataset to translate category ints to strings
+         self.categorical_features = {
+             feature_name: feature_type.names
+             for (feature_name, feature_type) in self.hf_schema_source.items()
+             if isinstance(feature_type, datasets.ClassLabel)
+         }
+         if self.source_column_map is None:
+             self.source_column_map = {}
+         self.check_source_columns_are_insertable(self.hf_schema_source.keys())
+
+     def _convert_column(self, column: 'pa.ChunkedArray', feature: object) -> list:
+         """
+         Convert an Arrow column to a list of Python values based on HF feature type.
+         Handles all feature types at the column level, recursing for structs.
+         Returns a list of length chunk_size.
+         """
+         import datasets
+
+         # return scalars as Python scalars
+         if isinstance(feature, datasets.Value):
+             return column.to_pylist()
+
+         # ClassLabel: int -> string name
+         if isinstance(feature, datasets.ClassLabel):
+             values = column.to_pylist()
+             return [feature.names[v] if v is not None else None for v in values]
+
+         # check for list of dict before Sequence, which could contain array data
+         is_list_of_dict = isinstance(feature, (datasets.Sequence, datasets.LargeList)) and isinstance(
+             feature.feature, dict
+         )
+         if is_list_of_dict:
+             return column.to_pylist()
+
+         # array data represented as a (possibly nested) sequence of numerical data: convert to numpy arrays
+         if self._is_sequence_of_numerical(feature):
+             arr = column.to_numpy(zero_copy_only=False)
+             result: list = []
+             for i in range(len(column)):
+                 val = arr[i]
+                 assert not isinstance(val, dict)  # we dealt with list of dicts earlier
+                 # convert object array of arrays (e.g., multi-channel audio) to proper ndarray
+                 if (
+                     isinstance(val, np.ndarray)
+                     and val.dtype == object
+                     and len(val) > 0
+                     and isinstance(val[0], np.ndarray)
+                 ):
+                     val = np.stack(list(val))
+                 result.append(val)
+             return result
+
+         if isinstance(feature, (datasets.Audio, datasets.Image)):
+             # Audio/Image is stored in Arrow as struct<bytes: binary, path: string>
+
+             from pixeltable.utils.local_store import TempStore
+
+             arrow_type = column.type
+             if not pa.types.is_struct(arrow_type):
+                 raise pxt.Error(f'Expected struct type for Audio column, got {arrow_type}')
+             field_names = {field.name for field in arrow_type}
+             if 'bytes' not in field_names or 'path' not in field_names:
+                 raise pxt.Error(f"Audio struct missing required fields 'bytes' and/or 'path', has: {field_names}")
+
+             bytes_column = pc.struct_field(column, 'bytes')
+             path_column = pc.struct_field(column, 'path')
+
+             bytes_list = bytes_column.to_pylist()
+             path_list = path_column.to_pylist()
+
+             result = []
+             for bytes, path in zip(bytes_list, path_list):
+                 if bytes is None:
+                     result.append(None)
+                     continue
+                 # we want to preserve the extension from the original path
+                 ext = Path(path).suffix if path is not None else None
+                 temp_path = TempStore.create_path(extension=ext)
+                 temp_path.write_bytes(bytes)
+                 result.append(str(temp_path))
+             return result
+
+         if isinstance(feature, dict):
+             return self._convert_struct_column(column, feature)
+
+         if isinstance(feature, list):
+             return column.to_pylist()
+
+         # Array<N>D: multi-dimensional fixed-shape arrays
+         if isinstance(feature, (datasets.Array2D, datasets.Array3D, datasets.Array4D, datasets.Array5D)):
+             return self._convert_array_feature(column, feature.shape)
+
+         return column.to_pylist()
+
+     def _is_sequence_of_numerical(self, feature: object) -> bool:
+         """Returns True if feature is a (nested) Sequence of numerical values."""
+         import datasets
+
+         if not isinstance(feature, datasets.Sequence):
+             return False
+         if isinstance(feature.feature, datasets.Sequence):
+             return self._is_sequence_of_numerical(feature.feature)
+
+         pa_type = feature.feature.pa_type
+         return pa_type is not None and (pat.is_integer(pa_type) or pat.is_floating(pa_type))
+
+     def _convert_struct_column(self, column: 'pa.ChunkedArray', feature: dict[str, object]) -> list[dict[str, Any]]:
+         """
+         Convert a StructArray column to a list of dicts by recursively
+         converting each field.
+         """
+
+         results: list[dict[str, Any]] = [{} for _ in range(len(column))]
+         for field_name, field_feature in feature.items():
+             field_column = pc.struct_field(column, field_name)
+             field_values = self._convert_column(field_column, field_feature)
+
+             for i, val in enumerate(field_values):
+                 results[i][field_name] = val
+
+         return results
+
+     def _convert_array_feature(self, column: 'pa.ChunkedArray', shape: tuple[int, ...]) -> list[np.ndarray]:
+         arr: pa.ExtensionArray
+         # TODO: can we get multiple chunks here?
+         if column.num_chunks == 1:
+             arr = column.chunks[0]  # type: ignore[assignment]
+         else:
+             arr = column.combine_chunks()  # type: ignore[assignment]
+
+         # an Array<N>D feature is stored in Arrow as a list<list<...<dtype>>>; we want to peel off the outer lists
+         # to get to contiguous storage and then reshape that
+         storage = arr.storage
+         vals = storage.values
+         while hasattr(vals, 'values'):
+             vals = vals.values
+         flat_arr = vals.to_numpy()
+         chunk_shape = (len(column), *shape)
+         reshaped = flat_arr.reshape(chunk_shape)
+
+         # Return as list of array views (shares memory with reshaped)
+         return list(reshaped)
+
+     def valid_row_batch(self) -> Iterator['RowData']:
+         import datasets
+
+         for split_name, split_dataset in self.dataset_dict.items():
+             features = split_dataset.features
+             if isinstance(split_dataset, datasets.Dataset):
+                 table = split_dataset.data  # the underlying Arrow table
+                 yield from self._process_arrow_table(table, split_name, features)
+             else:
+                 # we're getting batches of Arrow tables, since we did with_format('arrow');
+                 # use a trial batch to determine the target batch size
+                 first_batch = next(split_dataset.iter(batch_size=16))
+                 bytes_per_row = int(first_batch.nbytes / len(first_batch))
+                 batch_size = self._K_BATCH_SIZE_BYTES // bytes_per_row
+                 yield from self._process_arrow_table(first_batch, split_name, features)
+                 for batch in split_dataset.skip(16).iter(batch_size=batch_size):
+                     yield from self._process_arrow_table(batch, split_name, features)
+
+     def _process_arrow_table(self, table: 'pa.Table', split_name: str, features: dict[str, Any]) -> Iterator[RowData]:
+         # get chunk boundaries from the first column's ChunkedArray
+         first_column = table.column(0)
+         offset = 0
+         for chunk in first_column.chunks:
+             chunk_size = len(chunk)
+             # zero-copy slice using existing chunk boundaries
+             batch = table.slice(offset, chunk_size)
+
+             # we assemble per-row dicts from lists of per-column values
+             rows: list[dict[str, Any]] = [{} for _ in range(chunk_size)]
+             if self.column_name_for_split is not None:
+                 for row in rows:
+                     row[self.column_name_for_split] = split_name
+
+             for col_idx, col_name in enumerate(batch.schema.names):
+                 feature = features[col_name]
+                 mapped_col_name = self.source_column_map.get(col_name, col_name)
+                 column = batch.column(col_idx)
+                 values = self._convert_column(column, feature)
+                 for i, val in enumerate(values):
+                     rows[i][mapped_col_name] = val
+
+             offset += chunk_size
+             yield rows
+
+
+ class ParquetTableDataConduit(TableDataConduit):
+     pq_ds: ParquetDataset | None = None
+
+     @classmethod
+     def from_tds(cls, tds: TableDataConduit) -> 'ParquetTableDataConduit':
+         tds_fields = {f.name for f in fields(tds)}
+         kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
+         t = cls(**kwargs)
+
+         assert isinstance(tds.source, str)
+         input_path = Path(tds.source).expanduser()
+         t.pq_ds = pa.parquet.ParquetDataset(str(input_path))
+         return t
+
+     def infer_schema_part1(self) -> tuple[dict[str, ts.ColumnType], list[str]]:
+         from pixeltable.utils.arrow import to_pxt_schema
+
+         if self.source_column_map is None:
+             if self.src_schema_overrides is None:
+                 self.src_schema_overrides = {}
+             self.src_schema = to_pxt_schema(self.pq_ds.schema, self.src_schema_overrides, self.src_pk)
+             inferred_schema, inferred_pk, self.source_column_map = normalize_schema_names(
+                 self.src_schema, self.src_pk, self.src_schema_overrides
+             )
+             return inferred_schema, inferred_pk
+         else:
+             raise NotImplementedError()
+
+     def infer_schema(self) -> dict[str, ts.ColumnType]:
+         self.pxt_schema, self.pxt_pk = self.infer_schema_part1()
+         self.normalize_pxt_schema_types()
+         self.prepare_insert()
+         return self.pxt_schema
+
+     def prepare_for_insert_into_table(self) -> None:
+         _, inferred_pk = self.infer_schema_part1()
+         assert len(inferred_pk) == 0
+         self.prepare_insert()
+
+     def prepare_insert(self) -> None:
+         if self.source_column_map is None:
+             self.source_column_map = {}
+         self.check_source_columns_are_insertable(self.pq_ds.schema.names)
+         self.total_rows = 0
+
+     def valid_row_batch(self) -> Iterator[RowData]:
+         from pixeltable.utils.arrow import iter_tuples2
+
+         try:
+             for fragment in self.pq_ds.fragments:
+                 for batch in fragment.to_batches():
+                     dict_batch = list(iter_tuples2(batch, self.source_column_map, self.pxt_schema))
+                     self.total_rows += len(dict_batch)
+                     yield dict_batch
+         except Exception as e:
+             _logger.error(f'Error after inserting {self.total_rows} rows from Parquet file into table: {e}')
+             raise e
+
+
+ class UnkTableDataConduit(TableDataConduit):
+     """Source type is not known at the time of creation"""
+
+     def specialize(self) -> TableDataConduit:
+         if isinstance(self.source, (pxt.Table, pxt.Query)):
+             return QueryTableDataConduit.from_tds(self)
+         if isinstance(self.source, pd.DataFrame):
+             return PandasTableDataConduit.from_tds(self)
+         if HFTableDataConduit.is_applicable(self):
+             return HFTableDataConduit.from_tds(self)
+         if self.source_format == 'csv' or (isinstance(self.source, str) and '.csv' in self.source.lower()):
+             return CSVTableDataConduit.from_tds(self)
+         if self.source_format == 'excel' or (isinstance(self.source, str) and '.xls' in self.source.lower()):
+             return ExcelTableDataConduit.from_tds(self)
+         if self.source_format == 'json' or (isinstance(self.source, str) and '.json' in self.source.lower()):
+             return JsonTableDataConduit.from_tds(self)
+         if self.source_format == 'parquet' or (
+             isinstance(self.source, str) and any(s in self.source.lower() for s in ['.parquet', '.pq', '.parq'])
+         ):
+             return ParquetTableDataConduit.from_tds(self)
+         if (
+             self.is_rowdata_structure(self.source)
+             # An Iterator as a source is assumed to produce rows
+             or isinstance(self.source, Iterator)
+         ):
+             return RowDataTableDataConduit.from_tds(self)
+
+         raise excs.Error(f'Unsupported data source type: {type(self.source)}')
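
The module's entry point is UnkTableDataConduit.specialize(), which dispatches on the source's runtime type or its declared source_format. A minimal sketch of the schema-inference path for a pandas source, assuming only the classes shown in this diff (inserting into an existing table would instead go through add_table_info() followed by prepare_for_insert_into_table()):

import pandas as pd

from pixeltable.io.table_data_conduit import UnkTableDataConduit

# a pandas DataFrame source; specialize() should route it to PandasTableDataConduit
df = pd.DataFrame({'name': ['a', 'b'], 'score': [1.0, 2.0]})
conduit = UnkTableDataConduit(source=df, src_pk=[]).specialize()  # no primary key

# infer_schema() normalizes column names/types and pre-converts the rows
schema = conduit.infer_schema()  # dict of Pixeltable column types; pk lands in conduit.pxt_pk

# valid_row_batch() then yields batches of plain row dicts ready for insertion
for batch in conduit.valid_row_batch():
    print(f'{len(batch)} rows, first: {batch[0]}')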
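
_convert_array_feature() relies on Arrow's layout for fixed-shape array features: nested (fixed-size) lists over a single contiguous child buffer, so peeling off the list levels and reshaping the flat buffer recovers one ndarray view per row without copying. A self-contained illustration of that technique in plain pyarrow and numpy (the 2x3 shape and data are made up for the example; in the package the column is an Arrow extension array, hence the extra arr.storage step above):

import numpy as np
import pyarrow as pa

rows = [np.arange(i, i + 6, dtype=np.int64).reshape(2, 3) for i in range(4)]
# fixed_size_list<fixed_size_list<int64, 3>, 2> mirrors an Array2D((2, 3)) feature's storage
arr = pa.array([r.tolist() for r in rows], type=pa.list_(pa.list_(pa.int64(), 3), 2))

vals = arr
while hasattr(vals, 'values'):  # peel off the list levels down to the flat int64 child array
    vals = vals.values
flat = vals.to_numpy()  # one contiguous buffer of 4 * 2 * 3 values

reshaped = flat.reshape((len(arr), 2, 3))  # one (2, 3) view per row
assert all((reshaped[i] == rows[i]).all() for i in range(4))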