pixeltable 0.3.14__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. pixeltable/__init__.py +42 -8
  2. pixeltable/{dataframe.py → _query.py} +470 -206
  3. pixeltable/_version.py +1 -0
  4. pixeltable/catalog/__init__.py +5 -4
  5. pixeltable/catalog/catalog.py +1785 -432
  6. pixeltable/catalog/column.py +190 -113
  7. pixeltable/catalog/dir.py +2 -4
  8. pixeltable/catalog/globals.py +19 -46
  9. pixeltable/catalog/insertable_table.py +191 -98
  10. pixeltable/catalog/path.py +63 -23
  11. pixeltable/catalog/schema_object.py +11 -15
  12. pixeltable/catalog/table.py +843 -436
  13. pixeltable/catalog/table_metadata.py +103 -0
  14. pixeltable/catalog/table_version.py +978 -657
  15. pixeltable/catalog/table_version_handle.py +72 -16
  16. pixeltable/catalog/table_version_path.py +112 -43
  17. pixeltable/catalog/tbl_ops.py +53 -0
  18. pixeltable/catalog/update_status.py +191 -0
  19. pixeltable/catalog/view.py +134 -90
  20. pixeltable/config.py +134 -22
  21. pixeltable/env.py +471 -157
  22. pixeltable/exceptions.py +6 -0
  23. pixeltable/exec/__init__.py +4 -1
  24. pixeltable/exec/aggregation_node.py +7 -8
  25. pixeltable/exec/cache_prefetch_node.py +83 -110
  26. pixeltable/exec/cell_materialization_node.py +268 -0
  27. pixeltable/exec/cell_reconstruction_node.py +168 -0
  28. pixeltable/exec/component_iteration_node.py +4 -3
  29. pixeltable/exec/data_row_batch.py +8 -65
  30. pixeltable/exec/exec_context.py +16 -4
  31. pixeltable/exec/exec_node.py +13 -36
  32. pixeltable/exec/expr_eval/evaluators.py +11 -7
  33. pixeltable/exec/expr_eval/expr_eval_node.py +27 -12
  34. pixeltable/exec/expr_eval/globals.py +8 -5
  35. pixeltable/exec/expr_eval/row_buffer.py +1 -2
  36. pixeltable/exec/expr_eval/schedulers.py +106 -56
  37. pixeltable/exec/globals.py +35 -0
  38. pixeltable/exec/in_memory_data_node.py +19 -19
  39. pixeltable/exec/object_store_save_node.py +293 -0
  40. pixeltable/exec/row_update_node.py +16 -9
  41. pixeltable/exec/sql_node.py +351 -84
  42. pixeltable/exprs/__init__.py +1 -1
  43. pixeltable/exprs/arithmetic_expr.py +27 -22
  44. pixeltable/exprs/array_slice.py +3 -3
  45. pixeltable/exprs/column_property_ref.py +36 -23
  46. pixeltable/exprs/column_ref.py +213 -89
  47. pixeltable/exprs/comparison.py +5 -5
  48. pixeltable/exprs/compound_predicate.py +5 -4
  49. pixeltable/exprs/data_row.py +164 -54
  50. pixeltable/exprs/expr.py +70 -44
  51. pixeltable/exprs/expr_dict.py +3 -3
  52. pixeltable/exprs/expr_set.py +17 -10
  53. pixeltable/exprs/function_call.py +100 -40
  54. pixeltable/exprs/globals.py +2 -2
  55. pixeltable/exprs/in_predicate.py +4 -4
  56. pixeltable/exprs/inline_expr.py +18 -32
  57. pixeltable/exprs/is_null.py +7 -3
  58. pixeltable/exprs/json_mapper.py +8 -8
  59. pixeltable/exprs/json_path.py +56 -22
  60. pixeltable/exprs/literal.py +27 -5
  61. pixeltable/exprs/method_ref.py +2 -2
  62. pixeltable/exprs/object_ref.py +2 -2
  63. pixeltable/exprs/row_builder.py +167 -67
  64. pixeltable/exprs/rowid_ref.py +25 -10
  65. pixeltable/exprs/similarity_expr.py +58 -40
  66. pixeltable/exprs/sql_element_cache.py +4 -4
  67. pixeltable/exprs/string_op.py +5 -5
  68. pixeltable/exprs/type_cast.py +3 -5
  69. pixeltable/func/__init__.py +1 -0
  70. pixeltable/func/aggregate_function.py +8 -8
  71. pixeltable/func/callable_function.py +9 -9
  72. pixeltable/func/expr_template_function.py +17 -11
  73. pixeltable/func/function.py +18 -20
  74. pixeltable/func/function_registry.py +6 -7
  75. pixeltable/func/globals.py +2 -3
  76. pixeltable/func/mcp.py +74 -0
  77. pixeltable/func/query_template_function.py +29 -27
  78. pixeltable/func/signature.py +46 -19
  79. pixeltable/func/tools.py +31 -13
  80. pixeltable/func/udf.py +18 -20
  81. pixeltable/functions/__init__.py +16 -0
  82. pixeltable/functions/anthropic.py +123 -77
  83. pixeltable/functions/audio.py +147 -10
  84. pixeltable/functions/bedrock.py +13 -6
  85. pixeltable/functions/date.py +7 -4
  86. pixeltable/functions/deepseek.py +35 -43
  87. pixeltable/functions/document.py +81 -0
  88. pixeltable/functions/fal.py +76 -0
  89. pixeltable/functions/fireworks.py +11 -20
  90. pixeltable/functions/gemini.py +195 -39
  91. pixeltable/functions/globals.py +142 -14
  92. pixeltable/functions/groq.py +108 -0
  93. pixeltable/functions/huggingface.py +1056 -24
  94. pixeltable/functions/image.py +115 -57
  95. pixeltable/functions/json.py +1 -1
  96. pixeltable/functions/llama_cpp.py +28 -13
  97. pixeltable/functions/math.py +67 -5
  98. pixeltable/functions/mistralai.py +18 -55
  99. pixeltable/functions/net.py +70 -0
  100. pixeltable/functions/ollama.py +20 -13
  101. pixeltable/functions/openai.py +240 -226
  102. pixeltable/functions/openrouter.py +143 -0
  103. pixeltable/functions/replicate.py +4 -4
  104. pixeltable/functions/reve.py +250 -0
  105. pixeltable/functions/string.py +239 -69
  106. pixeltable/functions/timestamp.py +16 -16
  107. pixeltable/functions/together.py +24 -84
  108. pixeltable/functions/twelvelabs.py +188 -0
  109. pixeltable/functions/util.py +6 -1
  110. pixeltable/functions/uuid.py +30 -0
  111. pixeltable/functions/video.py +1515 -107
  112. pixeltable/functions/vision.py +8 -8
  113. pixeltable/functions/voyageai.py +289 -0
  114. pixeltable/functions/whisper.py +16 -8
  115. pixeltable/functions/whisperx.py +179 -0
  116. pixeltable/{ext/functions → functions}/yolox.py +2 -4
  117. pixeltable/globals.py +362 -115
  118. pixeltable/index/base.py +17 -21
  119. pixeltable/index/btree.py +28 -22
  120. pixeltable/index/embedding_index.py +100 -118
  121. pixeltable/io/__init__.py +4 -2
  122. pixeltable/io/datarows.py +8 -7
  123. pixeltable/io/external_store.py +56 -105
  124. pixeltable/io/fiftyone.py +13 -13
  125. pixeltable/io/globals.py +31 -30
  126. pixeltable/io/hf_datasets.py +61 -16
  127. pixeltable/io/label_studio.py +74 -70
  128. pixeltable/io/lancedb.py +3 -0
  129. pixeltable/io/pandas.py +21 -12
  130. pixeltable/io/parquet.py +25 -105
  131. pixeltable/io/table_data_conduit.py +250 -123
  132. pixeltable/io/utils.py +4 -4
  133. pixeltable/iterators/__init__.py +2 -1
  134. pixeltable/iterators/audio.py +26 -25
  135. pixeltable/iterators/base.py +9 -3
  136. pixeltable/iterators/document.py +112 -78
  137. pixeltable/iterators/image.py +12 -15
  138. pixeltable/iterators/string.py +11 -4
  139. pixeltable/iterators/video.py +523 -120
  140. pixeltable/metadata/__init__.py +14 -3
  141. pixeltable/metadata/converters/convert_13.py +2 -2
  142. pixeltable/metadata/converters/convert_18.py +2 -2
  143. pixeltable/metadata/converters/convert_19.py +2 -2
  144. pixeltable/metadata/converters/convert_20.py +2 -2
  145. pixeltable/metadata/converters/convert_21.py +2 -2
  146. pixeltable/metadata/converters/convert_22.py +2 -2
  147. pixeltable/metadata/converters/convert_24.py +2 -2
  148. pixeltable/metadata/converters/convert_25.py +2 -2
  149. pixeltable/metadata/converters/convert_26.py +2 -2
  150. pixeltable/metadata/converters/convert_29.py +4 -4
  151. pixeltable/metadata/converters/convert_30.py +34 -21
  152. pixeltable/metadata/converters/convert_34.py +2 -2
  153. pixeltable/metadata/converters/convert_35.py +9 -0
  154. pixeltable/metadata/converters/convert_36.py +38 -0
  155. pixeltable/metadata/converters/convert_37.py +15 -0
  156. pixeltable/metadata/converters/convert_38.py +39 -0
  157. pixeltable/metadata/converters/convert_39.py +124 -0
  158. pixeltable/metadata/converters/convert_40.py +73 -0
  159. pixeltable/metadata/converters/convert_41.py +12 -0
  160. pixeltable/metadata/converters/convert_42.py +9 -0
  161. pixeltable/metadata/converters/convert_43.py +44 -0
  162. pixeltable/metadata/converters/util.py +20 -31
  163. pixeltable/metadata/notes.py +9 -0
  164. pixeltable/metadata/schema.py +140 -53
  165. pixeltable/metadata/utils.py +74 -0
  166. pixeltable/mypy/__init__.py +3 -0
  167. pixeltable/mypy/mypy_plugin.py +123 -0
  168. pixeltable/plan.py +382 -115
  169. pixeltable/share/__init__.py +1 -1
  170. pixeltable/share/packager.py +547 -83
  171. pixeltable/share/protocol/__init__.py +33 -0
  172. pixeltable/share/protocol/common.py +165 -0
  173. pixeltable/share/protocol/operation_types.py +33 -0
  174. pixeltable/share/protocol/replica.py +119 -0
  175. pixeltable/share/publish.py +257 -59
  176. pixeltable/store.py +311 -194
  177. pixeltable/type_system.py +373 -211
  178. pixeltable/utils/__init__.py +2 -3
  179. pixeltable/utils/arrow.py +131 -17
  180. pixeltable/utils/av.py +298 -0
  181. pixeltable/utils/azure_store.py +346 -0
  182. pixeltable/utils/coco.py +6 -6
  183. pixeltable/utils/code.py +3 -3
  184. pixeltable/utils/console_output.py +4 -1
  185. pixeltable/utils/coroutine.py +6 -23
  186. pixeltable/utils/dbms.py +32 -6
  187. pixeltable/utils/description_helper.py +4 -5
  188. pixeltable/utils/documents.py +7 -18
  189. pixeltable/utils/exception_handler.py +7 -30
  190. pixeltable/utils/filecache.py +6 -6
  191. pixeltable/utils/formatter.py +86 -48
  192. pixeltable/utils/gcs_store.py +295 -0
  193. pixeltable/utils/http.py +133 -0
  194. pixeltable/utils/http_server.py +2 -3
  195. pixeltable/utils/iceberg.py +1 -2
  196. pixeltable/utils/image.py +17 -0
  197. pixeltable/utils/lancedb.py +90 -0
  198. pixeltable/utils/local_store.py +322 -0
  199. pixeltable/utils/misc.py +5 -0
  200. pixeltable/utils/object_stores.py +573 -0
  201. pixeltable/utils/pydantic.py +60 -0
  202. pixeltable/utils/pytorch.py +5 -6
  203. pixeltable/utils/s3_store.py +527 -0
  204. pixeltable/utils/sql.py +26 -0
  205. pixeltable/utils/system.py +30 -0
  206. pixeltable-0.5.7.dist-info/METADATA +579 -0
  207. pixeltable-0.5.7.dist-info/RECORD +227 -0
  208. {pixeltable-0.3.14.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
  209. pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
  210. pixeltable/__version__.py +0 -3
  211. pixeltable/catalog/named_function.py +0 -40
  212. pixeltable/ext/__init__.py +0 -17
  213. pixeltable/ext/functions/__init__.py +0 -11
  214. pixeltable/ext/functions/whisperx.py +0 -77
  215. pixeltable/utils/media_store.py +0 -77
  216. pixeltable/utils/s3.py +0 -17
  217. pixeltable-0.3.14.dist-info/METADATA +0 -434
  218. pixeltable-0.3.14.dist-info/RECORD +0 -186
  219. pixeltable-0.3.14.dist-info/entry_points.txt +0 -3
  220. {pixeltable-0.3.14.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
pixeltable/io/table_data_conduit.py CHANGED
@@ -3,14 +3,16 @@ from __future__ import annotations
  import enum
  import json
  import logging
- import math
- import urllib.parse
  import urllib.request
  from dataclasses import dataclass, field, fields
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, Optional, Union, cast
+ from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, cast

+ import numpy as np
  import pandas as pd
+ import pyarrow as pa
+ import pyarrow.compute as pc
+ import pyarrow.types as pat
  from pyarrow.parquet import ParquetDataset

  import pixeltable as pxt
@@ -23,7 +25,6 @@ from .utils import normalize_schema_names

  _logger = logging.getLogger('pixeltable')

- # ---------------------------------------------------------------------------------------------------------

  if TYPE_CHECKING:
      import datasets  # type: ignore[import-untyped]
@@ -46,21 +47,18 @@ class TableDataConduitFormat(str, enum.Enum):
          return False


- # ---------------------------------------------------------------------------------------------------------
-
-
  @dataclass
  class TableDataConduit:
-     source: TableDataSource
-     source_format: Optional[str] = None
-     source_column_map: Optional[dict[str, str]] = None
+     source: 'TableDataSource'
+     source_format: str | None = None
+     source_column_map: dict[str, str] | None = None
      if_row_exists: Literal['update', 'ignore', 'error'] = 'error'
-     pxt_schema: Optional[dict[str, Any]] = None
-     src_schema_overrides: Optional[dict[str, Any]] = None
-     src_schema: Optional[dict[str, Any]] = None
-     pxt_pk: Optional[list[str]] = None
-     src_pk: Optional[list[str]] = None
-     valid_rows: Optional[RowData] = None
+     pxt_schema: dict[str, ts.ColumnType] | None = None
+     src_schema_overrides: dict[str, ts.ColumnType] | None = None
+     src_schema: dict[str, ts.ColumnType] | None = None
+     pxt_pk: list[str] | None = None
+     src_pk: list[str] | None = None
+     valid_rows: RowData | None = None
      extra_fields: dict[str, Any] = field(default_factory=dict)

      reqd_col_names: set[str] = field(default_factory=set)
@@ -68,7 +66,7 @@ class TableDataConduit:

      total_rows: int = 0  # total number of rows emitted via valid_row_batch Iterator

-     _K_BATCH_SIZE_BYTES = 100_000_000  # 100 MB
+     _K_BATCH_SIZE_BYTES = 256 * 2**20

      def check_source_format(self) -> None:
          assert self.source_format is None or TableDataConduitFormat.is_valid(self.source_format)
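Worth noting: the per-batch byte budget changes from a decimal 100 MB to a binary 256 MiB (256 * 2**20), and its consumer changes too — the HuggingFace conduit below now sizes row batches by measuring a 16-row trial batch. A minimal sketch of that sizing logic, assuming a pyarrow batch with the nbytes/len interface used in this diff (the max(1, ...) guard is mine, not the diff's):

import pyarrow as pa

K_BATCH_SIZE_BYTES = 256 * 2**20  # 256 MiB per emitted batch

def target_batch_size(trial_batch: pa.Table) -> int:
    # estimate the per-row footprint from a small trial batch, then fill the budget
    bytes_per_row = int(trial_batch.nbytes / len(trial_batch))
    return max(1, K_BATCH_SIZE_BYTES // bytes_per_row)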
@@ -84,14 +82,14 @@ class TableDataConduit:
              return False
          return all(isinstance(row, dict) for row in d)

-     def is_direct_df(self) -> bool:
-         return isinstance(self.source, pxt.DataFrame) and self.source_column_map is None
+     def is_direct_query(self) -> bool:
+         return isinstance(self.source, pxt.Query) and self.source_column_map is None

      def normalize_pxt_schema_types(self) -> None:
          for name, coltype in self.pxt_schema.items():
              self.pxt_schema[name] = ts.ColumnType.normalize_type(coltype)

-     def infer_schema(self) -> dict[str, Any]:
+     def infer_schema(self) -> dict[str, ts.ColumnType]:
          raise NotImplementedError

      def valid_row_batch(self) -> Iterator[RowData]:
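The is_direct_df → is_direct_query change is the user-visible half of the DataFrame → Query rename (entry 2 above, pixeltable/dataframe.py → pixeltable/_query.py). A hedged sketch of the resulting insert surface, with hypothetical table and column names (insert-from-query is inferred from QueryTableDataConduit below, not spelled out in this diff):

import pixeltable as pxt

films = pxt.get_table('films')                              # pxt.Table
shorts = pxt.create_table('shorts', {'title': pxt.String})
# a Query built via select() can serve as an insert source;
# per from_tds() below, a bare Table is wrapped as table.select()
shorts.insert(films.select(title=films.title))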
@@ -105,7 +103,7 @@ class TableDataConduit:
      def add_table_info(self, table: pxt.Table) -> None:
          """Add information about the table into which we are inserting data"""
          assert isinstance(table, pxt.Table)
-         self.pxt_schema = table._schema
+         self.pxt_schema = table._get_schema()
          self.pxt_pk = table._tbl_version.get().primary_key
          for col in table._tbl_version_path.columns():
              if col.is_required_for_insert:
@@ -129,37 +127,34 @@ class TableDataConduit:
          raise excs.Error(f'Missing required column(s) ({", ".join(missing_cols)})')


- # ---------------------------------------------------------------------------------------------------------
-
-
- class DFTableDataConduit(TableDataConduit):
-     pxt_df: pxt.DataFrame = None
+ class QueryTableDataConduit(TableDataConduit):
+     pxt_query: pxt.Query = None

      @classmethod
-     def from_tds(cls, tds: TableDataConduit) -> 'DFTableDataConduit':
+     def from_tds(cls, tds: TableDataConduit) -> 'QueryTableDataConduit':
          tds_fields = {f.name for f in fields(tds)}
          kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
          t = cls(**kwargs)
-         assert isinstance(tds.source, pxt.DataFrame)
-         t.pxt_df = tds.source
+         if isinstance(tds.source, pxt.Table):
+             t.pxt_query = tds.source.select()
+         else:
+             assert isinstance(tds.source, pxt.Query)
+             t.pxt_query = tds.source
          return t

-     def infer_schema(self) -> dict[str, Any]:
-         self.pxt_schema = self.pxt_df.schema
+     def infer_schema(self) -> dict[str, ts.ColumnType]:
+         self.pxt_schema = self.pxt_query.schema
          self.pxt_pk = self.src_pk
          return self.pxt_schema

      def prepare_for_insert_into_table(self) -> None:
          if self.source_column_map is None:
              self.source_column_map = {}
-         self.check_source_columns_are_insertable(self.pxt_df.schema.keys())
-
-
- # ---------------------------------------------------------------------------------------------------------
+         self.check_source_columns_are_insertable(self.pxt_query.schema.keys())


  class RowDataTableDataConduit(TableDataConduit):
-     raw_rows: Optional[RowData] = None
+     raw_rows: RowData | None = None
      disable_mapping: bool = True
      batch_count: int = 0

@@ -178,7 +173,7 @@ class RowDataTableDataConduit(TableDataConduit):
          t.batch_count = 0
          return t

-     def infer_schema(self) -> dict[str, Any]:
+     def infer_schema(self) -> dict[str, ts.ColumnType]:
          from .datarows import _infer_schema_from_rows

          if self.source_column_map is None:
@@ -235,9 +230,6 @@ class RowDataTableDataConduit(TableDataConduit):
              yield self.valid_rows


- # ---------------------------------------------------------------------------------------------------------
-
-
  class PandasTableDataConduit(TableDataConduit):
      pd_df: pd.DataFrame = None
      batch_count: int = 0
@@ -252,7 +244,7 @@ class PandasTableDataConduit(TableDataConduit):
          t.batch_count = 0
          return t

-     def infer_schema_part1(self) -> tuple[dict[str, Any], list[str]]:
+     def infer_schema_part1(self) -> tuple[dict[str, ts.ColumnType], list[str]]:
          """Return inferred schema, inferred primary key, and source column map"""
          if self.source_column_map is None:
              if self.src_schema_overrides is None:
@@ -265,7 +257,7 @@ class PandasTableDataConduit(TableDataConduit):
          else:
              raise NotImplementedError()

-     def infer_schema(self) -> dict[str, Any]:
+     def infer_schema(self) -> dict[str, ts.ColumnType]:
          self.pxt_schema, self.pxt_pk = self.infer_schema_part1()
          self.normalize_pxt_schema_types()
          _df_check_primary_key_values(self.pd_df, self.src_pk)
@@ -293,9 +285,6 @@ class PandasTableDataConduit(TableDataConduit):
              yield self.valid_rows


- # ---------------------------------------------------------------------------------------------------------
-
-
  class CSVTableDataConduit(TableDataConduit):
      @classmethod
      def from_tds(cls, tds: TableDataConduit) -> 'PandasTableDataConduit':
@@ -307,9 +296,6 @@ class CSVTableDataConduit(TableDataConduit):
          return PandasTableDataConduit.from_tds(t)


- # ---------------------------------------------------------------------------------------------------------
-
-
  class ExcelTableDataConduit(TableDataConduit):
      @classmethod
      def from_tds(cls, tds: TableDataConduit) -> 'PandasTableDataConduit':
@@ -321,9 +307,6 @@ class ExcelTableDataConduit(TableDataConduit):
          return PandasTableDataConduit.from_tds(t)


- # ---------------------------------------------------------------------------------------------------------
-
-
  class JsonTableDataConduit(TableDataConduit):
      @classmethod
      def from_tds(cls, tds: TableDataConduit) -> RowDataTableDataConduit:
@@ -346,48 +329,68 @@ class JsonTableDataConduit(TableDataConduit):
          return t2


- # ---------------------------------------------------------------------------------------------------------
-
-
  class HFTableDataConduit(TableDataConduit):
-     hf_ds: Optional[Union[datasets.Dataset, datasets.DatasetDict]] = None
-     column_name_for_split: Optional[str] = None
+     """HuggingFace dataset importer"""
+
+     column_name_for_split: str | None = None
      categorical_features: dict[str, dict[int, str]]
-     hf_schema: dict[str, Any] = None
-     dataset_dict: dict[str, datasets.Dataset] = None
+     dataset_dict: dict[str, 'datasets.Dataset'] = None  # key: split name
      hf_schema_source: dict[str, Any] = None

      @classmethod
-     def from_tds(cls, tds: TableDataConduit) -> 'HFTableDataConduit':
+     def from_tds(cls, tds: TableDataConduit) -> HFTableDataConduit:
          tds_fields = {f.name for f in fields(tds)}
          kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
          t = cls(**kwargs)
          import datasets

-         assert isinstance(tds.source, (datasets.Dataset, datasets.DatasetDict))
-         t.hf_ds = tds.source
+         assert isinstance(tds.source, cls._get_dataset_classes())
          if 'column_name_for_split' in t.extra_fields:
              t.column_name_for_split = t.extra_fields['column_name_for_split']
+
+         if isinstance(tds.source, (datasets.IterableDataset, datasets.IterableDatasetDict)):
+             tds.source = tds.source.with_format('arrow')
+
+         if isinstance(tds.source, (datasets.Dataset, datasets.IterableDataset)):
+             split_name = str(tds.source.split) if tds.source.split is not None else None
+             t.dataset_dict = {split_name: tds.source}
+         else:
+             assert isinstance(tds.source, (datasets.DatasetDict, datasets.IterableDatasetDict))
+             t.dataset_dict = dict(tds.source)
+
+         # Disable auto-decoding for Audio and Image columns, we want to write the bytes directly to temp files
+         for ds_split_name, dataset in list(t.dataset_dict.items()):
+             for col_name, feature in dataset.features.items():
+                 if isinstance(feature, (datasets.Audio, datasets.Image)):
+                     t.dataset_dict[ds_split_name] = t.dataset_dict[ds_split_name].cast_column(
+                         col_name, feature.__class__(decode=False)
+                     )
          return t

+     @classmethod
+     def _get_dataset_classes(cls) -> tuple[type, ...]:
+         import datasets
+
+         return (datasets.Dataset, datasets.DatasetDict, datasets.IterableDataset, datasets.IterableDatasetDict)
+
      @classmethod
      def is_applicable(cls, tds: TableDataConduit) -> bool:
          try:
-             import datasets
-
              return (isinstance(tds.source_format, str) and tds.source_format.lower() == 'huggingface') or isinstance(
-                 tds.source, (datasets.Dataset, datasets.DatasetDict)
+                 tds.source, cls._get_dataset_classes()
              )
          except ImportError:
              return False

-     def infer_schema_part1(self) -> tuple[dict[str, Any], list[str]]:
+     def infer_schema_part1(self) -> tuple[dict[str, ts.ColumnType], list[str]]:
          from pixeltable.io.hf_datasets import _get_hf_schema, huggingface_schema_to_pxt_schema

          if self.source_column_map is None:
              if self.src_schema_overrides is None:
                  self.src_schema_overrides = {}
-             self.hf_schema_source = _get_hf_schema(self.hf_ds)
+             if self.src_pk is None:
+                 self.src_pk = []
+             self.hf_schema_source = _get_hf_schema(self.source)
              self.src_schema = huggingface_schema_to_pxt_schema(
                  self.hf_schema_source, self.src_schema_overrides, self.src_pk
              )
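The decode=False re-cast in from_tds is stock datasets behavior: an Audio or Image feature cast with decode=False yields the raw {'bytes': ..., 'path': ...} struct instead of a decoded object, which is what lets the conduit write bytes straight to temp files. A standalone illustration (the dataset name is hypothetical):

import datasets

ds = datasets.load_dataset('someorg/audio-clips', split='train')  # hypothetical
feature = ds.features['audio']
if isinstance(feature, (datasets.Audio, datasets.Image)):
    # same pattern as from_tds: keep the feature class, disable decoding
    ds = ds.cast_column('audio', feature.__class__(decode=False))

sample = ds[0]['audio']  # now a dict like {'bytes': b'...', 'path': 'clip0.wav'}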
@@ -402,7 +405,7 @@ class HFTableDataConduit(TableDataConduit):
                  self.src_schema[self.column_name_for_split] = ts.StringType(nullable=True)

              inferred_schema, inferred_pk, self.source_column_map = normalize_schema_names(
-                 self.src_schema, self.src_pk, self.src_schema_overrides, True
+                 self.src_schema, self.src_pk, self.src_schema_overrides
              )
              return inferred_schema, inferred_pk
          else:
@@ -422,16 +425,7 @@ class HFTableDataConduit(TableDataConduit):
      def prepare_insert(self) -> None:
          import datasets

-         if isinstance(self.source, datasets.Dataset):
-             # when loading an hf dataset partially, dataset.split._name is sometimes the form "train[0:1000]"
-             raw_name = self.source.split._name
-             split_name = raw_name.split('[')[0] if raw_name is not None else None
-             self.dataset_dict = {split_name: self.source}
-         else:
-             assert isinstance(self.source, datasets.DatasetDict)
-             self.dataset_dict = self.source
-
-         # extract all class labels from the dataset to translate category ints to strings
+         # Extract all class labels from the dataset to translate category ints to strings
          self.categorical_features = {
              feature_name: feature_type.names
              for (feature_name, feature_type) in self.hf_schema_source.items()
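categorical_features maps each ClassLabel column to its names table, so integer codes can be rendered as strings during conversion; the lookup is plain indexing into ClassLabel.names, exactly as _convert_column does below (None passes through unchanged):

import datasets

label = datasets.ClassLabel(names=['negative', 'positive'])
codes = [1, 0, None]
decoded = [label.names[c] if c is not None else None for c in codes]
assert decoded == ['positive', 'negative', None]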
@@ -441,48 +435,186 @@ class HFTableDataConduit(TableDataConduit):
              self.source_column_map = {}
          self.check_source_columns_are_insertable(self.hf_schema_source.keys())

-     def _translate_row(self, row: dict[str, Any], split_name: str) -> dict[str, Any]:
-         output_row: dict[str, Any] = {}
-         for col_name, val in row.items():
-             # translate category ints to strings
-             new_val = self.categorical_features[col_name][val] if col_name in self.categorical_features else val
-             mapped_col_name = self.source_column_map.get(col_name, col_name)
+     def _convert_column(self, column: 'pa.ChunkedArray', feature: object) -> list:
+         """
+         Convert an Arrow column to a list of Python values based on HF feature type.
+         Handles all feature types at the column level, recursing for structs.
+         Returns a list of length chunk_size.
+         """
+         import datasets

-             # Convert values to the appropriate type if needed
-             try:
-                 checked_val = self.pxt_schema[mapped_col_name].create_literal(new_val)
-             except TypeError as e:
-                 msg = str(e)
-                 raise excs.Error(f'Error in column {col_name}: {msg[0].lower() + msg[1:]}\nRow: {row}') from e
-             output_row[mapped_col_name] = checked_val
+         # return scalars as Python scalars
+         if isinstance(feature, datasets.Value):
+             return column.to_pylist()
+
+         # ClassLabel: int -> string name
+         if isinstance(feature, datasets.ClassLabel):
+             values = column.to_pylist()
+             return [feature.names[v] if v is not None else None for v in values]
+
+         # check for list of dict before Sequence, which could contain array data
+         is_list_of_dict = isinstance(feature, (datasets.Sequence, datasets.LargeList)) and isinstance(
+             feature.feature, dict
+         )
+         if is_list_of_dict:
+             return column.to_pylist()
+
+         # array data represented as a (possibly nested) sequence of numerical data: convert to numpy arrays
+         if self._is_sequence_of_numerical(feature):
+             arr = column.to_numpy(zero_copy_only=False)
+             result: list = []
+             for i in range(len(column)):
+                 val = arr[i]
+                 assert not isinstance(val, dict)  # we dealt with list of dicts earlier
+                 # convert object array of arrays (e.g., multi-channel audio) to proper ndarray
+                 if (
+                     isinstance(val, np.ndarray)
+                     and val.dtype == object
+                     and len(val) > 0
+                     and isinstance(val[0], np.ndarray)
+                 ):
+                     val = np.stack(list(val))
+                 result.append(val)
+             return result
+
+         if isinstance(feature, (datasets.Audio, datasets.Image)):
+             # Audio/Image is stored in Arrow as struct<bytes: binary, path: string>
+
+             from pixeltable.utils.local_store import TempStore
+
+             arrow_type = column.type
+             if not pa.types.is_struct(arrow_type):
+                 raise pxt.Error(f'Expected struct type for Audio column, got {arrow_type}')
+             field_names = {field.name for field in arrow_type}
+             if 'bytes' not in field_names or 'path' not in field_names:
+                 raise pxt.Error(f"Audio struct missing required fields 'bytes' and/or 'path', has: {field_names}")
+
+             bytes_column = pc.struct_field(column, 'bytes')
+             path_column = pc.struct_field(column, 'path')
+
+             bytes_list = bytes_column.to_pylist()
+             path_list = path_column.to_pylist()
+
+             result = []
+             for bytes, path in zip(bytes_list, path_list):
+                 if bytes is None:
+                     result.append(None)
+                     continue
+                 # we want to preserve the extension from the original path
+                 ext = Path(path).suffix if path is not None else None
+                 temp_path = TempStore.create_path(extension=ext)
+                 temp_path.write_bytes(bytes)
+                 result.append(str(temp_path))
+             return result
+
+         if isinstance(feature, dict):
+             return self._convert_struct_column(column, feature)
+
+         if isinstance(feature, list):
+             return column.to_pylist()
+
+         # Array<N>D: multi-dimensional fixed-shape arrays
+         if isinstance(feature, (datasets.Array2D, datasets.Array3D, datasets.Array4D, datasets.Array5D)):
+             return self._convert_array_feature(column, feature.shape)
+
+         return column.to_pylist()
+
+     def _is_sequence_of_numerical(self, feature: object) -> bool:
+         """Returns True if feature is a (nested) Sequence of numerical values."""
+         import datasets

-         # add split name to output row
-         if self.column_name_for_split is not None:
-             output_row[self.column_name_for_split] = split_name
-         return output_row
+         if not isinstance(feature, datasets.Sequence):
+             return False
+         if isinstance(feature.feature, datasets.Sequence):
+             return self._is_sequence_of_numerical(feature.feature)

-     def valid_row_batch(self) -> Iterator[RowData]:
-         for split_name, split_dataset in self.dataset_dict.items():
-             num_batches = split_dataset.size_in_bytes / self._K_BATCH_SIZE_BYTES
-             tuples_per_batch = math.ceil(split_dataset.num_rows / num_batches)
-             assert tuples_per_batch > 0
+         pa_type = feature.feature.pa_type
+         return pa_type is not None and (pat.is_integer(pa_type) or pat.is_floating(pa_type))
+
+     def _convert_struct_column(self, column: 'pa.ChunkedArray', feature: dict[str, object]) -> list[dict[str, Any]]:
+         """
+         Convert a StructArray column to a list of dicts by recursively
+         converting each field.
+         """
+
+         results: list[dict[str, Any]] = [{} for _ in range(len(column))]
+         for field_name, field_feature in feature.items():
+             field_column = pc.struct_field(column, field_name)
+             field_values = self._convert_column(field_column, field_feature)

-             batch = []
-             for row in split_dataset:
-                 batch.append(self._translate_row(row, split_name))
-                 if len(batch) >= tuples_per_batch:
-                     yield batch
-                     batch = []
-             # last batch
-             if len(batch) > 0:
-                 yield batch
+             for i, val in enumerate(field_values):
+                 results[i][field_name] = val

+         return results

- # ---------------------------------------------------------------------------------------------------------
+     def _convert_array_feature(self, column: 'pa.ChunkedArray', shape: tuple[int, ...]) -> list[np.ndarray]:
+         arr: pa.ExtensionArray
+         # TODO: can we get multiple chunks here?
+         if column.num_chunks == 1:
+             arr = column.chunks[0]  # type: ignore[assignment]
+         else:
+             arr = column.combine_chunks()  # type: ignore[assignment]
+
+         # an Array<N>D feature is stored in Arrow as a list<list<...<dtype>>>; we want to peel off the outer lists
+         # to get to contiguous storage and then reshape that
+         storage = arr.storage
+         vals = storage.values
+         while hasattr(vals, 'values'):
+             vals = vals.values
+         flat_arr = vals.to_numpy()
+         chunk_shape = (len(column), *shape)
+         reshaped = flat_arr.reshape(chunk_shape)
+
+         # Return as list of array views (shares memory with reshaped)
+         return list(reshaped)
+
+     def valid_row_batch(self) -> Iterator['RowData']:
+         import datasets
+
+         for split_name, split_dataset in self.dataset_dict.items():
+             features = split_dataset.features
+             if isinstance(split_dataset, datasets.Dataset):
+                 table = split_dataset.data  # the underlying Arrow table
+                 yield from self._process_arrow_table(table, split_name, features)
+             else:
+                 # we're getting batches of Arrow tables, since we did set_format('arrow');
+                 # use a trial batch to determine the target batch size
+                 first_batch = next(split_dataset.iter(batch_size=16))
+                 bytes_per_row = int(first_batch.nbytes / len(first_batch))
+                 batch_size = self._K_BATCH_SIZE_BYTES // bytes_per_row
+                 yield from self._process_arrow_table(first_batch, split_name, features)
+                 for batch in split_dataset.skip(16).iter(batch_size=batch_size):
+                     yield from self._process_arrow_table(batch, split_name, features)
+
+     def _process_arrow_table(self, table: 'pa.Table', split_name: str, features: dict[str, Any]) -> Iterator[RowData]:
+         # get chunk boundaries from first column's ChunkedArray
+         first_column = table.column(0)
+         offset = 0
+         for chunk in first_column.chunks:
+             chunk_size = len(chunk)
+             # zero-copy slice using existing chunk boundaries
+             batch = table.slice(offset, chunk_size)
+
+             # we assemble per-row dicts from lists of per-column values
+             rows: list[dict[str, Any]] = [{} for _ in range(chunk_size)]
+             if self.column_name_for_split is not None:
+                 for row in rows:
+                     row[self.column_name_for_split] = split_name
+
+             for col_idx, col_name in enumerate(batch.schema.names):
+                 feature = features[col_name]
+                 mapped_col_name = self.source_column_map.get(col_name, col_name)
+                 column = batch.column(col_idx)
+                 values = self._convert_column(column, feature)
+                 for i, val in enumerate(values):
+                     rows[i][mapped_col_name] = val
+
+             offset += chunk_size
+             yield rows


  class ParquetTableDataConduit(TableDataConduit):
-     pq_ds: Optional[ParquetDataset] = None
+     pq_ds: ParquetDataset | None = None

      @classmethod
      def from_tds(cls, tds: TableDataConduit) -> 'ParquetTableDataConduit':
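The rewritten HF path is columnar end to end: each Arrow chunk is sliced zero-copy, every column is converted to a Python list exactly once, and per-row dicts are only assembled at the end. A self-contained sketch of the core pivot using plain pyarrow (the data is made up):

import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({
    'id': [1, 2],
    'info': [{'name': 'a', 'score': 0.5}, {'name': 'b', 'score': 0.9}],
})
# pc.struct_field extracts one child of a struct column without copying its siblings
names = pc.struct_field(table.column('info'), 'name').to_pylist()
ids = table.column('id').to_pylist()
# pivot column lists into per-row dicts, as _process_arrow_table does
rows = [{'id': i, 'name': n} for i, n in zip(ids, names)]
assert rows == [{'id': 1, 'name': 'a'}, {'id': 2, 'name': 'b'}]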
@@ -490,20 +622,18 @@ class ParquetTableDataConduit(TableDataConduit):
          kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
          t = cls(**kwargs)

-         from pyarrow import parquet
-
          assert isinstance(tds.source, str)
          input_path = Path(tds.source).expanduser()
-         t.pq_ds = parquet.ParquetDataset(str(input_path))
+         t.pq_ds = pa.parquet.ParquetDataset(str(input_path))
          return t

-     def infer_schema_part1(self) -> tuple[dict[str, Any], list[str]]:
-         from pixeltable.utils.arrow import ar_infer_schema
+     def infer_schema_part1(self) -> tuple[dict[str, ts.ColumnType], list[str]]:
+         from pixeltable.utils.arrow import to_pxt_schema

          if self.source_column_map is None:
              if self.src_schema_overrides is None:
                  self.src_schema_overrides = {}
-             self.src_schema = ar_infer_schema(self.pq_ds.schema, self.src_schema_overrides, self.src_pk)
+             self.src_schema = to_pxt_schema(self.pq_ds.schema, self.src_schema_overrides, self.src_pk)
              inferred_schema, inferred_pk, self.source_column_map = normalize_schema_names(
                  self.src_schema, self.src_pk, self.src_schema_overrides
              )
@@ -511,7 +641,7 @@ class ParquetTableDataConduit(TableDataConduit):
          else:
              raise NotImplementedError()

-     def infer_schema(self) -> dict[str, Any]:
+     def infer_schema(self) -> dict[str, ts.ColumnType]:
          self.pxt_schema, self.pxt_pk = self.infer_schema_part1()
          self.normalize_pxt_schema_types()
          self.prepare_insert()
@@ -532,7 +662,7 @@ class ParquetTableDataConduit(TableDataConduit):
          from pixeltable.utils.arrow import iter_tuples2

          try:
-             for fragment in self.pq_ds.fragments:  # type: ignore[attr-defined]
+             for fragment in self.pq_ds.fragments:
                  for batch in fragment.to_batches():
                      dict_batch = list(iter_tuples2(batch, self.source_column_map, self.pxt_schema))
                      self.total_rows += len(dict_batch)
@@ -542,15 +672,12 @@ class ParquetTableDataConduit(TableDataConduit):
              raise e


- # ---------------------------------------------------------------------------------------------------------
-
-
  class UnkTableDataConduit(TableDataConduit):
      """Source type is not known at the time of creation"""

      def specialize(self) -> TableDataConduit:
-         if isinstance(self.source, pxt.DataFrame):
-             return DFTableDataConduit.from_tds(self)
+         if isinstance(self.source, (pxt.Table, pxt.Query)):
+             return QueryTableDataConduit.from_tds(self)
          if isinstance(self.source, pd.DataFrame):
              return PandasTableDataConduit.from_tds(self)
          if HFTableDataConduit.is_applicable(self):
pixeltable/io/utils.py CHANGED
@@ -1,5 +1,5 @@
  from keyword import iskeyword as is_python_keyword
- from typing import Any, Optional, Union
+ from typing import Any

  import pixeltable as pxt
  import pixeltable.exceptions as excs
@@ -8,7 +8,7 @@ from pixeltable.catalog.globals import is_system_column_name

  def normalize_pxt_col_name(name: str) -> str:
      """
-     Normalizes an arbitrary DataFrame column name into a valid Pixeltable identifier by:
+     Normalizes an arbitrary column name into a valid Pixeltable identifier by:
      - replacing any non-ascii or non-alphanumeric characters with an underscore _
      - prefixing the result with the letter 'c' if it starts with an underscore or a number
      """
@@ -21,7 +21,7 @@ def normalize_pxt_col_name(name: str) -> str:
      return id


- def normalize_primary_key_parameter(primary_key: Optional[Union[str, list[str]]] = None) -> list[str]:
+ def normalize_primary_key_parameter(primary_key: str | list[str] | None = None) -> list[str]:
      if primary_key is None:
          primary_key = []
      elif isinstance(primary_key, str):
@@ -40,7 +40,7 @@ def normalize_schema_names(
      primary_key: list[str],
      schema_overrides: dict[str, Any],
      require_valid_pxt_column_names: bool = False,
- ) -> tuple[dict[str, Any], list[str], Optional[dict[str, str]]]:
+ ) -> tuple[dict[str, Any], list[str], dict[str, str] | None]:
      """
      Convert all names in the input schema from source names to valid Pixeltable identifiers
      - Ensure that all names are unique.
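The docstring above pins down the two normalization rules; here is a hedged re-implementation of just those rules (the real normalize_pxt_col_name also handles Python keywords and system column names, per the imports in this file):

import re

def normalize_demo(name: str) -> str:
    # rule 1: non-ascii / non-alphanumeric characters become underscores
    ident = re.sub(r'[^A-Za-z0-9]', '_', name)
    # rule 2: prefix 'c' if the result starts with an underscore or a digit
    if ident and (ident[0] == '_' or ident[0].isdigit()):
        ident = 'c' + ident
    return ident

assert normalize_demo('2024 sales (EUR)') == 'c2024_sales__EUR_'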
pixeltable/iterators/__init__.py CHANGED
@@ -1,3 +1,4 @@
+ """Iterators for splitting media and documents into components."""
  # ruff: noqa: F401

  from .audio import AudioSplitter
@@ -5,7 +6,7 @@ from .base import ComponentIterator
  from .document import DocumentSplitter
  from .image import TileIterator
  from .string import StringSplitter
- from .video import FrameIterator
+ from .video import FrameIterator, VideoSplitter

  __default_dir = {symbol for symbol in dir() if not symbol.startswith('_')}
  __removed_symbols = {'base', 'document', 'video'}
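VideoSplitter joins FrameIterator in the public iterator set (see pixeltable/iterators/video.py, +523 -120 above). For orientation, the established FrameIterator pattern looks like this — table and column names are hypothetical, and VideoSplitter's exact arguments aren't shown in this diff:

import pixeltable as pxt
from pixeltable.iterators import FrameIterator

videos = pxt.create_table('videos', {'video': pxt.Video})
# one row per extracted frame, sampled at 1 frame per second
frames = pxt.create_view(
    'frames', videos, iterator=FrameIterator.create(video=videos.video, fps=1)
)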