pixeltable-0.3.8-py3-none-any.whl → pixeltable-0.3.10-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pixeltable might be problematic.
Files changed (52)
  1. pixeltable/__init__.py +1 -2
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/catalog.py +509 -103
  4. pixeltable/catalog/column.py +5 -0
  5. pixeltable/catalog/dir.py +15 -6
  6. pixeltable/catalog/globals.py +16 -0
  7. pixeltable/catalog/insertable_table.py +82 -41
  8. pixeltable/catalog/path.py +15 -0
  9. pixeltable/catalog/schema_object.py +7 -12
  10. pixeltable/catalog/table.py +81 -67
  11. pixeltable/catalog/table_version.py +23 -7
  12. pixeltable/catalog/view.py +9 -6
  13. pixeltable/env.py +15 -9
  14. pixeltable/exec/exec_node.py +1 -1
  15. pixeltable/exprs/__init__.py +2 -1
  16. pixeltable/exprs/arithmetic_expr.py +2 -0
  17. pixeltable/exprs/column_ref.py +38 -2
  18. pixeltable/exprs/expr.py +61 -12
  19. pixeltable/exprs/function_call.py +1 -4
  20. pixeltable/exprs/globals.py +12 -0
  21. pixeltable/exprs/json_mapper.py +4 -4
  22. pixeltable/exprs/json_path.py +10 -11
  23. pixeltable/exprs/similarity_expr.py +5 -20
  24. pixeltable/exprs/string_op.py +107 -0
  25. pixeltable/ext/functions/yolox.py +21 -64
  26. pixeltable/func/callable_function.py +5 -2
  27. pixeltable/func/query_template_function.py +6 -18
  28. pixeltable/func/tools.py +2 -2
  29. pixeltable/functions/__init__.py +1 -1
  30. pixeltable/functions/globals.py +16 -5
  31. pixeltable/globals.py +172 -262
  32. pixeltable/io/__init__.py +3 -2
  33. pixeltable/io/datarows.py +138 -0
  34. pixeltable/io/external_store.py +8 -5
  35. pixeltable/io/globals.py +7 -160
  36. pixeltable/io/hf_datasets.py +21 -98
  37. pixeltable/io/pandas.py +29 -43
  38. pixeltable/io/parquet.py +17 -42
  39. pixeltable/io/table_data_conduit.py +569 -0
  40. pixeltable/io/utils.py +6 -21
  41. pixeltable/metadata/__init__.py +1 -1
  42. pixeltable/metadata/converters/convert_30.py +50 -0
  43. pixeltable/metadata/converters/util.py +26 -1
  44. pixeltable/metadata/notes.py +1 -0
  45. pixeltable/metadata/schema.py +3 -0
  46. pixeltable/utils/arrow.py +32 -7
  47. pixeltable/utils/coroutine.py +41 -0
  48. {pixeltable-0.3.8.dist-info → pixeltable-0.3.10.dist-info}/METADATA +1 -1
  49. {pixeltable-0.3.8.dist-info → pixeltable-0.3.10.dist-info}/RECORD +52 -47
  50. {pixeltable-0.3.8.dist-info → pixeltable-0.3.10.dist-info}/WHEEL +1 -1
  51. {pixeltable-0.3.8.dist-info → pixeltable-0.3.10.dist-info}/LICENSE +0 -0
  52. {pixeltable-0.3.8.dist-info → pixeltable-0.3.10.dist-info}/entry_points.txt +0 -0
pixeltable/io/table_data_conduit.py ADDED
@@ -0,0 +1,569 @@
+ from __future__ import annotations
+
+ import enum
+ import json
+ import logging
+ import math
+ import urllib.parse
+ import urllib.request
+ from dataclasses import dataclass, field, fields
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, Optional, Union, cast
+
+ import pandas as pd
+ from pyarrow.parquet import ParquetDataset
+
+ import pixeltable as pxt
+ import pixeltable.exceptions as excs
+ from pixeltable.io.pandas import _df_check_primary_key_values, _df_row_to_pxt_row, df_infer_schema
+ from pixeltable.utils import parse_local_file_path
+
+ from .utils import normalize_schema_names
+
+ _logger = logging.getLogger('pixeltable')
+
+ # ---------------------------------------------------------------------------------------------------------
+
+ if TYPE_CHECKING:
+     import datasets  # type: ignore[import-untyped]
+
+     from pixeltable.globals import RowData, TableDataSource
+
+
+ class TableDataConduitFormat(str, enum.Enum):
+     """Supported formats for TableDataConduit"""
+
+     JSON = 'json'
+     CSV = 'csv'
+     EXCEL = 'excel'
+     PARQUET = 'parquet'
+
+     @classmethod
+     def is_valid(cls, x: Any) -> bool:
+         if isinstance(x, str):
+             return x.lower() in [c.value for c in cls]
+         return False
+
+
+ # ---------------------------------------------------------------------------------------------------------
+
+
+ @dataclass
+ class TableDataConduit:
+     source: TableDataSource
+     source_format: Optional[str] = None
+     source_column_map: Optional[dict[str, str]] = None
+     if_row_exists: Literal['update', 'ignore', 'error'] = 'error'
+     pxt_schema: Optional[dict[str, Any]] = None
+     src_schema_overrides: Optional[dict[str, Any]] = None
+     src_schema: Optional[dict[str, Any]] = None
+     pxt_pk: Optional[list[str]] = None
+     src_pk: Optional[list[str]] = None
+     valid_rows: Optional[RowData] = None
+     extra_fields: dict[str, Any] = field(default_factory=dict)
+
+     reqd_col_names: set[str] = field(default_factory=set)
+     computed_col_names: set[str] = field(default_factory=set)
+
+     total_rows: int = 0  # total number of rows emitted via the valid_row_batch iterator
+
+     _K_BATCH_SIZE_BYTES = 100_000_000  # 100 MB
+
+     def check_source_format(self) -> None:
+         assert self.source_format is None or TableDataConduitFormat.is_valid(self.source_format)
+
+     @classmethod
+     def is_rowdata_structure(cls, d: TableDataSource) -> bool:
+         if not isinstance(d, list) or len(d) == 0:
+             return False
+         return all(isinstance(row, dict) for row in d)
+
+     def is_direct_df(self) -> bool:
+         return isinstance(self.source, pxt.DataFrame) and self.source_column_map is None
+
+     def normalize_pxt_schema_types(self) -> None:
+         for name, coltype in self.pxt_schema.items():
+             self.pxt_schema[name] = pxt.ColumnType.normalize_type(coltype)
+
+     def infer_schema(self) -> dict[str, Any]:
+         raise NotImplementedError
+
+     def valid_row_batch(self) -> Iterator[RowData]:
+         raise NotImplementedError
+
+     def prepare_for_insert_into_table(self) -> None:
+         if self.source is None:
+             return
+         raise NotImplementedError
+
+     def add_table_info(self, table: pxt.Table) -> None:
+         """Add information about the table into which we are inserting data"""
+         assert isinstance(table, pxt.Table)
+         self.pxt_schema = table._schema
+         self.pxt_pk = table._tbl_version.get().primary_key
+         for col in table._tbl_version_path.columns():
+             if col.is_required_for_insert:
+                 self.reqd_col_names.add(col.name)
+             if col.is_computed:
+                 self.computed_col_names.add(col.name)
+         self.src_pk = []
+
+     # Check that source columns are insertable: no unknown or computed columns, all required columns present
+     def check_source_columns_are_insertable(self, columns: Iterable[str]) -> None:
+         col_name_set: set[str] = set()
+         for col_name in columns:  # FIXME
+             mapped_col_name = self.source_column_map.get(col_name, col_name)
+             col_name_set.add(mapped_col_name)
+             if mapped_col_name not in self.pxt_schema:
+                 raise excs.Error(f'Unknown column name {mapped_col_name}')
+             if mapped_col_name in self.computed_col_names:
+                 raise excs.Error(f'Value for computed column {mapped_col_name}')
+         missing_cols = self.reqd_col_names - col_name_set
+         if len(missing_cols) > 0:
+             raise excs.Error(f'Missing required column(s) ({", ".join(missing_cols)})')
+
+
+ # ---------------------------------------------------------------------------------------------------------
+
+
+ class DFTableDataConduit(TableDataConduit):
+     pxt_df: pxt.DataFrame = None
+
+     @classmethod
+     def from_tds(cls, tds: TableDataConduit) -> 'DFTableDataConduit':
+         tds_fields = {f.name for f in fields(tds)}
+         kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
+         t = cls(**kwargs)
+         assert isinstance(tds.source, pxt.DataFrame)
+         t.pxt_df = tds.source
+         return t
+
+     def infer_schema(self) -> dict[str, Any]:
+         self.pxt_schema = self.pxt_df.schema
+         self.pxt_pk = self.src_pk
+         return self.pxt_schema
+
+     def prepare_for_insert_into_table(self) -> None:
+         if self.source_column_map is None:
+             self.source_column_map = {}
+         self.check_source_columns_are_insertable(self.pxt_df.schema.keys())
+
+
+ # ---------------------------------------------------------------------------------------------------------
+
+
+ class RowDataTableDataConduit(TableDataConduit):
+     raw_rows: Optional[RowData] = None
+     disable_mapping: bool = True
+     batch_count: int = 0
+
+     @classmethod
+     def from_tds(cls, tds: TableDataConduit) -> 'RowDataTableDataConduit':
+         tds_fields = {f.name for f in fields(tds)}
+         kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
+         t = cls(**kwargs)
+         if isinstance(tds.source, Iterator):
+             # Instantiate the iterator to get the raw rows here
+             t.raw_rows = list(tds.source)
+         elif TYPE_CHECKING:
+             t.raw_rows = cast(RowData, tds.source)
+         else:
+             t.raw_rows = tds.source
+         t.batch_count = 0
+         return t
+
+     def infer_schema(self) -> dict[str, Any]:
+         from .datarows import _infer_schema_from_rows
+
+         if self.source_column_map is None:
+             if self.src_schema_overrides is None:
+                 self.src_schema_overrides = {}
+             self.src_schema = _infer_schema_from_rows(self.raw_rows, self.src_schema_overrides, self.src_pk)
+             self.pxt_schema, self.pxt_pk, self.source_column_map = normalize_schema_names(
+                 self.src_schema, self.src_pk, self.src_schema_overrides, self.disable_mapping
+             )
+             self.normalize_pxt_schema_types()
+         else:
+             raise NotImplementedError()
+
+         self.prepare_for_insert_into_table()
+         return self.pxt_schema
+
+     def prepare_for_insert_into_table(self) -> None:
+         # Converting rows to an insertable format is not needed; misnamed columns and bad types
+         # are errors in the incoming row format
+         if self.source_column_map is None:
+             self.source_column_map = {}
+         self.valid_rows = [self._translate_row(row) for row in self.raw_rows]
+
+         self.batch_count = 1 if self.raw_rows is not None else 0
+
+     def _translate_row(self, row: dict[str, Any]) -> dict[str, Any]:
+         if not isinstance(row, dict):
+             raise excs.Error(f'row {row} is not a dictionary')
+
+         col_names: set[str] = set()
+         output_row: dict[str, Any] = {}
+         for col_name, val in row.items():
+             mapped_col_name = self.source_column_map.get(col_name, col_name)
+             col_names.add(mapped_col_name)
+             if mapped_col_name not in self.pxt_schema:
+                 raise excs.Error(f'Unknown column name {mapped_col_name} in row {row}')
+             if mapped_col_name in self.computed_col_names:
+                 raise excs.Error(f'Value for computed column {mapped_col_name} in row {row}')
+             # basic sanity checks here
+             try:
+                 checked_val = self.pxt_schema[mapped_col_name].create_literal(val)
+             except TypeError as e:
+                 msg = str(e)
+                 raise excs.Error(f'Error in column {col_name}: {msg[0].lower() + msg[1:]}\nRow: {row}') from e
+             output_row[mapped_col_name] = checked_val
+         missing_cols = self.reqd_col_names - col_names
+         if len(missing_cols) > 0:
+             raise excs.Error(f'Missing required column(s) ({", ".join(missing_cols)}) in row {row}')
+         return output_row
+
+     def valid_row_batch(self) -> Iterator[RowData]:
+         if self.batch_count > 0:
+             self.batch_count -= 1
+             yield self.valid_rows
+
+
+ # ---------------------------------------------------------------------------------------------------------
+
+
+ class PandasTableDataConduit(TableDataConduit):
+     pd_df: pd.DataFrame = None
+     batch_count: int = 0
+
+     @classmethod
+     def from_tds(cls, tds: TableDataConduit) -> PandasTableDataConduit:
+         tds_fields = {f.name for f in fields(tds)}
+         kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
+         t = cls(**kwargs)
+         assert isinstance(tds.source, pd.DataFrame)
+         t.pd_df = tds.source
+         t.batch_count = 0
+         return t
+
+     def infer_schema_part1(self) -> tuple[dict[str, Any], list[str]]:
+         """Return the inferred schema and inferred primary key; sets self.source_column_map as a side effect"""
+         if self.source_column_map is None:
+             if self.src_schema_overrides is None:
+                 self.src_schema_overrides = {}
+             self.src_schema = df_infer_schema(self.pd_df, self.src_schema_overrides, self.src_pk)
+             inferred_schema, inferred_pk, self.source_column_map = normalize_schema_names(
+                 self.src_schema, self.src_pk, self.src_schema_overrides, False
+             )
+             return inferred_schema, inferred_pk
+         else:
+             raise NotImplementedError()
+
+     def infer_schema(self) -> dict[str, Any]:
+         self.pxt_schema, self.pxt_pk = self.infer_schema_part1()
+         self.normalize_pxt_schema_types()
+         _df_check_primary_key_values(self.pd_df, self.src_pk)
+         self.prepare_insert()
+         return self.pxt_schema
+
+     def prepare_for_insert_into_table(self) -> None:
+         _, inferred_pk = self.infer_schema_part1()
+         assert len(inferred_pk) == 0
+         self.prepare_insert()
+
+     def prepare_insert(self) -> None:
+         if self.source_column_map is None:
+             self.source_column_map = {}
+         self.check_source_columns_are_insertable(self.pd_df.columns)
+         # Convert all rows to insertable format
+         self.valid_rows = [
+             _df_row_to_pxt_row(row, self.src_schema, self.source_column_map) for row in self.pd_df.itertuples()
+         ]
+         self.batch_count = 1
+
+     def valid_row_batch(self) -> Iterator[RowData]:
+         if self.batch_count > 0:
+             self.batch_count -= 1
+             yield self.valid_rows
+
+
+ # ---------------------------------------------------------------------------------------------------------
+
+
+ class CSVTableDataConduit(TableDataConduit):
+     @classmethod
+     def from_tds(cls, tds: TableDataConduit) -> 'PandasTableDataConduit':
+         tds_fields = {f.name for f in fields(tds)}
+         kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
+         t = cls(**kwargs)
+         assert isinstance(t.source, str)
+         t.source = pd.read_csv(t.source, **t.extra_fields)
+         return PandasTableDataConduit.from_tds(t)
+
+
+ # ---------------------------------------------------------------------------------------------------------
+
+
+ class ExcelTableDataConduit(TableDataConduit):
+     @classmethod
+     def from_tds(cls, tds: TableDataConduit) -> 'PandasTableDataConduit':
+         tds_fields = {f.name for f in fields(tds)}
+         kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
+         t = cls(**kwargs)
+         assert isinstance(t.source, str)
+         t.source = pd.read_excel(t.source, **t.extra_fields)
+         return PandasTableDataConduit.from_tds(t)
+
+
+ # ---------------------------------------------------------------------------------------------------------
+
+
+ class JsonTableDataConduit(TableDataConduit):
+     @classmethod
+     def from_tds(cls, tds: TableDataConduit) -> RowDataTableDataConduit:
+         tds_fields = {f.name for f in fields(tds)}
+         kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
+         t = cls(**kwargs)
+         assert isinstance(t.source, str)
+
+         path = parse_local_file_path(t.source)
+         if path is None:  # it's a URL
+             # TODO: This should read from S3 as well.
+             contents = urllib.request.urlopen(t.source).read()
+         else:
+             with open(path, 'r', encoding='utf-8') as fp:
+                 contents = fp.read()
+         rows = json.loads(contents, **t.extra_fields)
+         t.source = rows
+         t2 = RowDataTableDataConduit.from_tds(t)
+         t2.disable_mapping = False
+         return t2
+
+
+ # ---------------------------------------------------------------------------------------------------------
+
+
+ class HFTableDataConduit(TableDataConduit):
+     hf_ds: Optional[Union[datasets.Dataset, datasets.DatasetDict]] = None
+     column_name_for_split: Optional[str] = None
+     categorical_features: dict[str, dict[int, str]]
+     hf_schema: dict[str, Any] = None
+     dataset_dict: dict[str, datasets.Dataset] = None
+     hf_schema_source: dict[str, Any] = None
+
+     @classmethod
+     def from_tds(cls, tds: TableDataConduit) -> 'HFTableDataConduit':
+         tds_fields = {f.name for f in fields(tds)}
+         kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
+         t = cls(**kwargs)
+         import datasets
+
+         assert isinstance(tds.source, (datasets.Dataset, datasets.DatasetDict))
+         t.hf_ds = tds.source
+         if 'column_name_for_split' in t.extra_fields:
+             t.column_name_for_split = t.extra_fields['column_name_for_split']
+         return t
+
+     @classmethod
+     def is_applicable(cls, tds: TableDataConduit) -> bool:
+         try:
+             import datasets
+
+             return (isinstance(tds.source_format, str) and tds.source_format.lower() == 'huggingface') or isinstance(
+                 tds.source, (datasets.Dataset, datasets.DatasetDict)
+             )
+         except ImportError:
+             return False
+
+     def infer_schema_part1(self) -> tuple[dict[str, Any], list[str]]:
+         from pixeltable.io.hf_datasets import _get_hf_schema, huggingface_schema_to_pxt_schema
+
+         if self.source_column_map is None:
+             if self.src_schema_overrides is None:
+                 self.src_schema_overrides = {}
+             self.hf_schema_source = _get_hf_schema(self.hf_ds)
+             self.src_schema = huggingface_schema_to_pxt_schema(
+                 self.hf_schema_source, self.src_schema_overrides, self.src_pk
+             )
+
+             # Add the split column to the schema if requested
+             if self.column_name_for_split is not None:
+                 if self.column_name_for_split in self.src_schema:
+                     raise excs.Error(
+                         f'Column name `{self.column_name_for_split}` already exists in dataset schema; '
+                         f'provide a different `column_name_for_split`'
+                     )
+                 self.src_schema[self.column_name_for_split] = pxt.StringType(nullable=True)
+
+             inferred_schema, inferred_pk, self.source_column_map = normalize_schema_names(
+                 self.src_schema, self.src_pk, self.src_schema_overrides, True
+             )
+             return inferred_schema, inferred_pk
+         else:
+             raise NotImplementedError()
+
+     def infer_schema(self) -> dict[str, Any]:
+         self.pxt_schema, self.pxt_pk = self.infer_schema_part1()
+         self.normalize_pxt_schema_types()
+         self.prepare_insert()
+         return self.pxt_schema
+
+     def prepare_for_insert_into_table(self) -> None:
+         _, inferred_pk = self.infer_schema_part1()
+         assert len(inferred_pk) == 0
+         self.prepare_insert()
+
+     def prepare_insert(self) -> None:
+         import datasets
+
+         if isinstance(self.source, datasets.Dataset):
+             # when loading an hf dataset partially, dataset.split._name sometimes has the form "train[0:1000]"
+             raw_name = self.source.split._name
+             split_name = raw_name.split('[')[0] if raw_name is not None else None
+             self.dataset_dict = {split_name: self.source}
+         else:
+             assert isinstance(self.source, datasets.DatasetDict)
+             self.dataset_dict = self.source
+
+         # extract all class labels from the dataset to translate category ints to strings
+         self.categorical_features = {
+             feature_name: feature_type.names
+             for (feature_name, feature_type) in self.hf_schema_source.items()
+             if isinstance(feature_type, datasets.ClassLabel)
+         }
+         if self.source_column_map is None:
+             self.source_column_map = {}
+         self.check_source_columns_are_insertable(self.hf_schema_source.keys())
+
+     def _translate_row(self, row: dict[str, Any], split_name: str) -> dict[str, Any]:
+         output_row: dict[str, Any] = {}
+         for col_name, val in row.items():
+             # translate category ints to strings
+             new_val = self.categorical_features[col_name][val] if col_name in self.categorical_features else val
+             mapped_col_name = self.source_column_map.get(col_name, col_name)
+
+             # Convert values to the appropriate type if needed
+             try:
+                 checked_val = self.pxt_schema[mapped_col_name].create_literal(new_val)
+             except TypeError as e:
+                 msg = str(e)
+                 raise excs.Error(f'Error in column {col_name}: {msg[0].lower() + msg[1:]}\nRow: {row}') from e
+             output_row[mapped_col_name] = checked_val
+
+         # add split name to output row
+         if self.column_name_for_split is not None:
+             output_row[self.column_name_for_split] = split_name
+         return output_row
+
+     def valid_row_batch(self) -> Iterator[RowData]:
+         for split_name, split_dataset in self.dataset_dict.items():
+             num_batches = split_dataset.size_in_bytes / self._K_BATCH_SIZE_BYTES
+             tuples_per_batch = math.ceil(split_dataset.num_rows / num_batches)
+             assert tuples_per_batch > 0
+
+             batch = []
+             for row in split_dataset:
+                 batch.append(self._translate_row(row, split_name))
+                 if len(batch) >= tuples_per_batch:
+                     yield batch
+                     batch = []
+             # last batch
+             if len(batch) > 0:
+                 yield batch
+
+
+ # ---------------------------------------------------------------------------------------------------------
+
+
+ class ParquetTableDataConduit(TableDataConduit):
+     pq_ds: Optional[ParquetDataset] = None
+
+     @classmethod
+     def from_tds(cls, tds: TableDataConduit) -> 'ParquetTableDataConduit':
+         tds_fields = {f.name for f in fields(tds)}
+         kwargs = {k: v for k, v in tds.__dict__.items() if k in tds_fields}
+         t = cls(**kwargs)
+
+         from pyarrow import parquet
+
+         assert isinstance(tds.source, str)
+         input_path = Path(tds.source).expanduser()
+         t.pq_ds = parquet.ParquetDataset(str(input_path))
+         return t
+
+     def infer_schema_part1(self) -> tuple[dict[str, Any], list[str]]:
+         from pixeltable.utils.arrow import ar_infer_schema
+
+         if self.source_column_map is None:
+             if self.src_schema_overrides is None:
+                 self.src_schema_overrides = {}
+             self.src_schema = ar_infer_schema(self.pq_ds.schema, self.src_schema_overrides, self.src_pk)
+             inferred_schema, inferred_pk, self.source_column_map = normalize_schema_names(
+                 self.src_schema, self.src_pk, self.src_schema_overrides
+             )
+             return inferred_schema, inferred_pk
+         else:
+             raise NotImplementedError()
+
+     def infer_schema(self) -> dict[str, Any]:
+         self.pxt_schema, self.pxt_pk = self.infer_schema_part1()
+         self.normalize_pxt_schema_types()
+         self.prepare_insert()
+         return self.pxt_schema
+
+     def prepare_for_insert_into_table(self) -> None:
+         _, inferred_pk = self.infer_schema_part1()
+         assert len(inferred_pk) == 0
+         self.prepare_insert()
+
+     def prepare_insert(self) -> None:
+         if self.source_column_map is None:
+             self.source_column_map = {}
+         self.check_source_columns_are_insertable(self.pq_ds.schema.names)
+         self.total_rows = 0
+
+     def valid_row_batch(self) -> Iterator[RowData]:
+         from pixeltable.utils.arrow import iter_tuples2
+
+         try:
+             for fragment in self.pq_ds.fragments:  # type: ignore[attr-defined]
+                 for batch in fragment.to_batches():
+                     dict_batch = list(iter_tuples2(batch, self.source_column_map, self.pxt_schema))
+                     self.total_rows += len(dict_batch)
+                     yield dict_batch
+         except Exception as e:
+             _logger.error(f'Error after inserting {self.total_rows} rows from Parquet file into table: {e}')
+             raise e
+
+
+ # ---------------------------------------------------------------------------------------------------------
+
+
+ class UnkTableDataConduit(TableDataConduit):
+     """Source type is not known at the time of creation"""
+
+     def specialize(self) -> TableDataConduit:
+         if isinstance(self.source, pxt.DataFrame):
+             return DFTableDataConduit.from_tds(self)
+         if isinstance(self.source, pd.DataFrame):
+             return PandasTableDataConduit.from_tds(self)
+         if HFTableDataConduit.is_applicable(self):
+             return HFTableDataConduit.from_tds(self)
+         if self.source_format == 'csv' or (isinstance(self.source, str) and '.csv' in self.source.lower()):
+             return CSVTableDataConduit.from_tds(self)
+         if self.source_format == 'excel' or (isinstance(self.source, str) and '.xls' in self.source.lower()):
+             return ExcelTableDataConduit.from_tds(self)
+         if self.source_format == 'json' or (isinstance(self.source, str) and '.json' in self.source.lower()):
+             return JsonTableDataConduit.from_tds(self)
+         if self.source_format == 'parquet' or (
+             isinstance(self.source, str) and any(s in self.source.lower() for s in ['.parquet', '.pq', '.parq'])
+         ):
+             return ParquetTableDataConduit.from_tds(self)
+         if (
+             self.is_rowdata_structure(self.source)
+             # An Iterator as a source is assumed to produce rows
+             or isinstance(self.source, Iterator)
+         ):
+             return RowDataTableDataConduit.from_tds(self)
+
+         raise excs.Error(f'Unsupported data source type: {type(self.source)}')
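
The new TableDataConduit hierarchy centralizes how heterogeneous data sources are consumed: an UnkTableDataConduit is specialized into a format-specific subclass, which can then infer a schema and emit validated row batches. A minimal sketch of the create-and-load flow, assuming pixeltable's public create_table/insert API (the table name and CSV path are illustrative, not from the diff):

    import pixeltable as pxt
    from pixeltable.io.table_data_conduit import UnkTableDataConduit

    tds = UnkTableDataConduit(source='measurements.csv')  # hypothetical local CSV
    conduit = tds.specialize()       # '.csv' suffix -> CSVTableDataConduit -> PandasTableDataConduit
    schema = conduit.infer_schema()  # infers the pixeltable schema and stages validated rows
    tbl = pxt.create_table('measurements', schema)
    for batch in conduit.valid_row_batch():  # yields batches of insertable row dicts
        tbl.insert(batch)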
pixeltable/io/utils.py CHANGED
@@ -3,7 +3,6 @@ from typing import Any, Optional, Union
 
  import pixeltable as pxt
  import pixeltable.exceptions as excs
- from pixeltable import Table
  from pixeltable.catalog.globals import is_system_column_name
 
 
@@ -22,16 +21,14 @@ def normalize_pxt_col_name(name: str) -> str:
      return id
 
 
- def normalize_import_parameters(
-     schema_overrides: Optional[dict[str, Any]] = None, primary_key: Optional[Union[str, list[str]]] = None
- ) -> tuple[dict[str, Any], list[str]]:
-     if schema_overrides is None:
-         schema_overrides = {}
+ def normalize_primary_key_parameter(primary_key: Optional[Union[str, list[str]]] = None) -> list[str]:
      if primary_key is None:
          primary_key = []
      elif isinstance(primary_key, str):
          primary_key = [primary_key]
-     return schema_overrides, primary_key
+     elif not isinstance(primary_key, list) or not all(isinstance(pk, str) for pk in primary_key):
+         raise excs.Error('primary_key must be a single column name or a list of column names')
+     return primary_key
 
 
  def _is_usable_as_column_name(name: str, destination_schema: dict[str, Any]) -> bool:
@@ -65,7 +62,8 @@ def normalize_schema_names(
      extraneous_overrides = schema_overrides.keys() - in_schema.keys()
      if len(extraneous_overrides) > 0:
          raise excs.Error(
-             f'Some column(s) specified in `schema_overrides` are not present in the source: {", ".join(extraneous_overrides)}'
+             f'Some column(s) specified in `schema_overrides` are not present '
+             f'in the source: {", ".join(extraneous_overrides)}'
          )
 
      schema: dict[str, Any] = {}
@@ -100,16 +98,3 @@ def normalize_schema_names(
      pxt_pk = [col_mapping[pk] for pk in primary_key] if col_mapping is not None else primary_key
 
      return schema, pxt_pk, col_mapping
-
-
- def find_or_create_table(
-     tbl_path: str,
-     schema: dict[str, Any],
-     *,
-     primary_key: Optional[Union[str, list[str]]],
-     num_retained_versions: int,
-     comment: str,
- ) -> Table:
-     return pxt.create_table(
-         tbl_path, schema, primary_key=primary_key, num_retained_versions=num_retained_versions, comment=comment
-     )
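
The removed normalize_import_parameters bundled schema-override defaulting with primary-key normalization; its replacement handles only the primary key, and now rejects malformed values instead of passing them through. Behavior as implied by the diff (the tuple case is one example of a rejected input):

    from pixeltable.io.utils import normalize_primary_key_parameter

    normalize_primary_key_parameter(None)        # -> []
    normalize_primary_key_parameter('id')        # -> ['id']
    normalize_primary_key_parameter(['a', 'b'])  # -> ['a', 'b']
    normalize_primary_key_parameter(('a', 'b'))  # raises excs.Error: neither a str nor a list of str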
pixeltable/metadata/__init__.py CHANGED
@@ -16,7 +16,7 @@ _console_logger = ConsoleLogger(logging.getLogger('pixeltable'))
 
 
  # current version of the metadata; this is incremented whenever the metadata schema changes
- VERSION = 30
+ VERSION = 31
 
 
  def create_system_info(engine: sql.engine.Engine) -> None:
pixeltable/metadata/converters/convert_30.py ADDED
@@ -0,0 +1,50 @@
+ import copy
+
+ import sqlalchemy as sql
+
+ from pixeltable.metadata import register_converter
+ from pixeltable.metadata.converters.util import (
+     convert_table_record,
+     convert_table_schema_version_record,
+     convert_table_version_record,
+ )
+ from pixeltable.metadata.schema import Table, TableSchemaVersion, TableVersion
+
+
+ @register_converter(version=30)
+ def _(engine: sql.engine.Engine) -> None:
+     convert_table_record(engine, table_record_updater=__update_table_record)
+     convert_table_version_record(engine, table_version_record_updater=__update_table_version_record)
+     convert_table_schema_version_record(
+         engine, table_schema_version_record_updater=__update_table_schema_version_record
+     )
+
+
+ def __update_table_record(record: Table) -> None:
+     """
+     Update TableMd with tbl_id.
+     """
+     assert isinstance(record.md, dict)
+     md = copy.copy(record.md)
+     md['tbl_id'] = str(record.id)
+     record.md = md
+
+
+ def __update_table_version_record(record: TableVersion) -> None:
+     """
+     Update TableVersion metadata with tbl_id.
+     """
+     assert isinstance(record.md, dict)
+     md = copy.copy(record.md)
+     md['tbl_id'] = str(record.tbl_id)
+     record.md = md
+
+
+ def __update_table_schema_version_record(record: TableSchemaVersion) -> None:
+     """
+     Update TableSchemaVersion metadata with tbl_id.
+     """
+     assert isinstance(record.md, dict)
+     md = copy.copy(record.md)
+     md['tbl_id'] = str(record.tbl_id)
+     record.md = md
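
The version-30 converter denormalizes each record's id into its JSON metadata under the new tbl_id key, matching the VERSION bump to 31 above. A minimal stand-in showing the transformation on one record (FakeRecord is hypothetical; the real records are the SQLAlchemy-mapped Table/TableVersion/TableSchemaVersion rows):

    import copy
    import uuid

    class FakeRecord:  # stands in for a metadata.schema.Table row
        id = uuid.uuid4()
        md = {'name': 'films'}  # JSON metadata as stored at version 30

    record = FakeRecord()
    md = copy.copy(record.md)
    md['tbl_id'] = str(record.id)  # version 31: the row id is duplicated into the metadata itself
    record.md = md
    print(record.md)  # {'name': 'films', 'tbl_id': '<uuid string>'}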