pixeltable 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (52)
  1. pixeltable/__init__.py +1 -2
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/catalog.py +509 -103
  4. pixeltable/catalog/column.py +5 -0
  5. pixeltable/catalog/dir.py +15 -6
  6. pixeltable/catalog/globals.py +16 -0
  7. pixeltable/catalog/insertable_table.py +82 -41
  8. pixeltable/catalog/path.py +15 -0
  9. pixeltable/catalog/schema_object.py +7 -12
  10. pixeltable/catalog/table.py +81 -67
  11. pixeltable/catalog/table_version.py +23 -7
  12. pixeltable/catalog/view.py +9 -6
  13. pixeltable/env.py +15 -9
  14. pixeltable/exec/exec_node.py +1 -1
  15. pixeltable/exprs/__init__.py +2 -1
  16. pixeltable/exprs/arithmetic_expr.py +2 -0
  17. pixeltable/exprs/column_ref.py +38 -2
  18. pixeltable/exprs/expr.py +61 -12
  19. pixeltable/exprs/function_call.py +1 -4
  20. pixeltable/exprs/globals.py +12 -0
  21. pixeltable/exprs/json_mapper.py +4 -4
  22. pixeltable/exprs/json_path.py +10 -11
  23. pixeltable/exprs/similarity_expr.py +5 -20
  24. pixeltable/exprs/string_op.py +107 -0
  25. pixeltable/ext/functions/yolox.py +21 -64
  26. pixeltable/func/callable_function.py +5 -2
  27. pixeltable/func/query_template_function.py +6 -18
  28. pixeltable/func/tools.py +2 -2
  29. pixeltable/functions/__init__.py +1 -1
  30. pixeltable/functions/globals.py +16 -5
  31. pixeltable/globals.py +172 -262
  32. pixeltable/io/__init__.py +3 -2
  33. pixeltable/io/datarows.py +138 -0
  34. pixeltable/io/external_store.py +8 -5
  35. pixeltable/io/globals.py +7 -160
  36. pixeltable/io/hf_datasets.py +21 -98
  37. pixeltable/io/pandas.py +29 -43
  38. pixeltable/io/parquet.py +17 -42
  39. pixeltable/io/table_data_conduit.py +569 -0
  40. pixeltable/io/utils.py +6 -21
  41. pixeltable/metadata/__init__.py +1 -1
  42. pixeltable/metadata/converters/convert_30.py +50 -0
  43. pixeltable/metadata/converters/util.py +26 -1
  44. pixeltable/metadata/notes.py +1 -0
  45. pixeltable/metadata/schema.py +3 -0
  46. pixeltable/utils/arrow.py +32 -7
  47. pixeltable/utils/coroutine.py +41 -0
  48. {pixeltable-0.3.8.dist-info → pixeltable-0.3.10.dist-info}/METADATA +1 -1
  49. {pixeltable-0.3.8.dist-info → pixeltable-0.3.10.dist-info}/RECORD +52 -47
  50. {pixeltable-0.3.8.dist-info → pixeltable-0.3.10.dist-info}/WHEEL +1 -1
  51. {pixeltable-0.3.8.dist-info → pixeltable-0.3.10.dist-info}/LICENSE +0 -0
  52. {pixeltable-0.3.8.dist-info → pixeltable-0.3.10.dist-info}/entry_points.txt +0 -0
pixeltable/io/datarows.py ADDED
@@ -0,0 +1,138 @@
+ from __future__ import annotations
+
+ from typing import Any, Iterable, Optional, Union
+
+ import pixeltable as pxt
+ from pixeltable import exceptions as excs
+
+
+ def _infer_schema_from_rows(
+     rows: Iterable[dict[str, Any]], schema_overrides: dict[str, Any], primary_key: list[str]
+ ) -> dict[str, pxt.ColumnType]:
+     schema: dict[str, pxt.ColumnType] = {}
+     cols_with_nones: set[str] = set()
+
+     for n, row in enumerate(rows):
+         for col_name, value in row.items():
+             if col_name in schema_overrides:
+                 # We do the insertion here; this will ensure that the column order matches the order
+                 # in which the column names are encountered in the input data, even if `schema_overrides`
+                 # is specified.
+                 if col_name not in schema:
+                     schema[col_name] = schema_overrides[col_name]
+             elif value is not None:
+                 # If `key` is not in `schema_overrides`, then we infer its type from the data.
+                 # The column type will always be nullable by default.
+                 col_type = pxt.ColumnType.infer_literal_type(value, nullable=col_name not in primary_key)
+                 if col_type is None:
+                     raise excs.Error(
+                         f'Could not infer type for column `{col_name}`; the value in row {n} '
+                         f'has an unsupported type: {type(value)}'
+                     )
+                 if col_name not in schema:
+                     schema[col_name] = col_type
+                 else:
+                     supertype = schema[col_name].supertype(col_type)
+                     if supertype is None:
+                         raise excs.Error(
+                             f'Could not infer type of column `{col_name}`; the value in row {n} '
+                             f'does not match preceding type {schema[col_name]}: {value!r}\n'
+                             'Consider specifying the type explicitly in `schema_overrides`.'
+                         )
+                     schema[col_name] = supertype
+             else:
+                 cols_with_nones.add(col_name)
+
+     entirely_none_cols = cols_with_nones - schema.keys()
+     if len(entirely_none_cols) > 0:
+         # A column can only end up in `entirely_none_cols` if it was not in `schema_overrides` and
+         # was not encountered in any row with a non-None value.
+         raise excs.Error(
+             f'The following columns have no non-null values: {", ".join(entirely_none_cols)}\n'
+             'Consider specifying the type(s) explicitly in `schema_overrides`.'
+         )
+     return schema
+
+
+ def import_rows(
+     tbl_path: str,
+     rows: list[dict[str, Any]],
+     *,
+     schema_overrides: Optional[dict[str, Any]] = None,
+     primary_key: Optional[Union[str, list[str]]] = None,
+     num_retained_versions: int = 10,
+     comment: str = '',
+ ) -> pxt.Table:
+     """
+     Creates a new base table from a list of dictionaries. The dictionaries must be of the
+     form `{column_name: value, ...}`. Pixeltable will attempt to infer the schema of the table from the
+     supplied data, using the most specific type that can represent all the values in a column.
+
+     If `schema_overrides` is specified, then for each entry `(column_name, type)` in `schema_overrides`,
+     Pixeltable will force the specified column to the specified type (and will not attempt any type inference
+     for that column).
+
+     All column types of the new table will be nullable unless explicitly specified as non-nullable in
+     `schema_overrides`.
+
+     Args:
+         tbl_path: The qualified name of the table to create.
+         rows: The list of dictionaries to import.
+         schema_overrides: If specified, then columns in `schema_overrides` will be given the specified types
+             as described above.
+         primary_key: The primary key of the table (see [`create_table()`][pixeltable.create_table]).
+         num_retained_versions: The number of retained versions of the table
+             (see [`create_table()`][pixeltable.create_table]).
+         comment: A comment to attach to the table (see [`create_table()`][pixeltable.create_table]).
+
+     Returns:
+         A handle to the newly created [`Table`][pixeltable.Table].
+     """
+     return pxt.create_table(
+         tbl_path,
+         source=rows,
+         schema_overrides=schema_overrides,
+         primary_key=primary_key,
+         num_retained_versions=num_retained_versions,
+         comment=comment,
+     )
+
+
+ def import_json(
+     tbl_path: str,
+     filepath_or_url: str,
+     *,
+     schema_overrides: Optional[dict[str, Any]] = None,
+     primary_key: Optional[Union[str, list[str]]] = None,
+     num_retained_versions: int = 10,
+     comment: str = '',
+     **kwargs: Any,
+ ) -> pxt.Table:
+     """
+     Creates a new base table from a JSON file. This is a convenience method and is
+     equivalent to calling `import_data(table_path, json.loads(file_contents, **kwargs), ...)`, where `file_contents`
+     is the contents of the specified `filepath_or_url`.
+
+     Args:
+         tbl_path: The name of the table to create.
+         filepath_or_url: The path or URL of the JSON file.
+         schema_overrides: If specified, then columns in `schema_overrides` will be given the specified types
+             (see [`import_rows()`][pixeltable.io.import_rows]).
+         primary_key: The primary key of the table (see [`create_table()`][pixeltable.create_table]).
+         num_retained_versions: The number of retained versions of the table
+             (see [`create_table()`][pixeltable.create_table]).
+         comment: A comment to attach to the table (see [`create_table()`][pixeltable.create_table]).
+         kwargs: Additional keyword arguments to pass to `json.loads`.
+
+     Returns:
+         A handle to the newly created [`Table`][pixeltable.Table].
+     """
+     return pxt.create_table(
+         tbl_path,
+         source=filepath_or_url,
+         schema_overrides=schema_overrides,
+         primary_key=primary_key,
+         num_retained_versions=num_retained_versions,
+         comment=comment,
+         extra_args=kwargs,
+     )
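
Both relocated helpers are now thin wrappers over `create_table` with a `source` argument. A minimal usage sketch (the table name and row data are hypothetical; inferred types are nullable by default, except for primary-key columns):

    import pixeltable as pxt

    rows = [
        {'name': 'Alice', 'score': 9.5},
        {'name': 'Bob', 'score': None},  # None is fine: `score` is inferred as a nullable Float
    ]

    # Equivalent to pxt.create_table('people', source=rows, primary_key='name')
    tbl = pxt.io.import_rows('people', rows, primary_key='name')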
pixeltable/io/external_store.py CHANGED
@@ -97,7 +97,7 @@ class Project(ExternalStore, abc.ABC):
          # This ensures that the media in those columns resides in the media store.
          # First determine which columns (if any) need stored proxies, but don't have one yet.
          stored_proxies_needed: list[Column] = []
-         for col in self.col_mapping.keys():
+         for col in self.col_mapping:
              if col.col_type.is_media_type() and not (col.is_stored and col.is_computed):
                  # If this column is already proxied in some other Project, use the existing proxy to avoid
                  # duplication. Otherwise, we'll create a new one.
@@ -234,7 +234,8 @@ class Project(ExternalStore, abc.ABC):
          else:
              raise excs.Error(
                  f'Column `{t_col}` does not exist in Table `{table._name}`. Either add a column `{t_col}`, '
-                 f'or specify a `col_mapping` to associate a different column with the external field `{ext_col}`.'
+                 f'or specify a `col_mapping` to associate a different column with '
+                 f'the external field `{ext_col}`.'
              )
          if ext_col not in export_cols and ext_col not in import_cols:
              raise excs.Error(
@@ -253,7 +254,8 @@ class Project(ExternalStore, abc.ABC):
              ext_col_type = export_cols[ext_col]
              if not ext_col_type.is_supertype_of(t_col_type, ignore_nullable=True):
                  raise excs.Error(
-                     f'Column `{t_col}` cannot be exported to external column `{ext_col}` (incompatible types; expecting `{ext_col_type}`)'
+                     f'Column `{t_col}` cannot be exported to external column `{ext_col}` '
+                     f'(incompatible types; expecting `{ext_col_type}`)'
                  )
              if ext_col in import_cols:
                  # Validate that the external column can be assigned to the table column
@@ -264,7 +266,8 @@ class Project(ExternalStore, abc.ABC):
              ext_col_type = import_cols[ext_col]
              if not t_col_type.is_supertype_of(ext_col_type, ignore_nullable=True):
                  raise excs.Error(
-                     f'Column `{t_col}` cannot be imported from external column `{ext_col}` (incompatible types; expecting `{ext_col_type}`)'
+                     f'Column `{t_col}` cannot be imported from external column `{ext_col}` '
+                     f'(incompatible types; expecting `{ext_col_type}`)'
                  )
          return resolved_col_mapping

@@ -368,7 +371,7 @@ class MockProject(Project):
              {cls._column_from_dict(entry[0]): cls._column_from_dict(entry[1]) for entry in md['stored_proxies']},
          )

-     def __eq__(self, other: Any) -> bool:
+     def __eq__(self, other: object) -> bool:
          if not isinstance(other, MockProject):
              return False
          return self.name == other.name
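
The rewrapped error messages encode a direction-sensitive compatibility rule: exporting requires the external column's type to be a supertype of the table column's type, and importing requires the reverse. A sketch of that check using `is_supertype_of` as it appears above (the concrete types are illustrative only):

    import pixeltable.type_system as ts

    ext_col_type = ts.FloatType(nullable=True)   # type advertised by the external store
    t_col_type = ts.FloatType(nullable=False)    # type of the mapped table column

    # Export direction: every table value must be representable in the external column.
    can_export = ext_col_type.is_supertype_of(t_col_type, ignore_nullable=True)
    # Import direction: every external value must be representable in the table column.
    can_import = t_col_type.is_supertype_of(ext_col_type, ignore_nullable=True)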
pixeltable/io/globals.py CHANGED
@@ -1,7 +1,5 @@
- import json
- import urllib.parse
- import urllib.request
- from pathlib import Path
+ from __future__ import annotations
+
  from typing import TYPE_CHECKING, Any, Literal, Optional, Union

  import pixeltable as pxt
@@ -9,61 +7,11 @@ import pixeltable.exceptions as excs
  from pixeltable import Table, exprs
  from pixeltable.env import Env
  from pixeltable.io.external_store import SyncStatus
- from pixeltable.utils import parse_local_file_path

  if TYPE_CHECKING:
      import fiftyone as fo  # type: ignore[import-untyped]


- from .utils import find_or_create_table, normalize_import_parameters, normalize_schema_names
-
-
- def _infer_schema_from_rows(
-     rows: list[dict[str, Any]], schema_overrides: dict[str, Any], primary_key: list[str]
- ) -> dict[str, pxt.ColumnType]:
-     schema: dict[str, pxt.ColumnType] = {}
-     cols_with_nones: set[str] = set()
-
-     for n, row in enumerate(rows):
-         for col_name, value in row.items():
-             if col_name in schema_overrides:
-                 # We do the insertion here; this will ensure that the column order matches the order
-                 # in which the column names are encountered in the input data, even if `schema_overrides`
-                 # is specified.
-                 if col_name not in schema:
-                     schema[col_name] = schema_overrides[col_name]
-             elif value is not None:
-                 # If `key` is not in `schema_overrides`, then we infer its type from the data.
-                 # The column type will always be nullable by default.
-                 col_type = pxt.ColumnType.infer_literal_type(value, nullable=col_name not in primary_key)
-                 if col_type is None:
-                     raise excs.Error(
-                         f'Could not infer type for column `{col_name}`; the value in row {n} has an unsupported type: {type(value)}'
-                     )
-                 if col_name not in schema:
-                     schema[col_name] = col_type
-                 else:
-                     supertype = schema[col_name].supertype(col_type)
-                     if supertype is None:
-                         raise excs.Error(
-                             f'Could not infer type of column `{col_name}`; the value in row {n} does not match preceding type {schema[col_name]}: {value!r}\n'
-                             'Consider specifying the type explicitly in `schema_overrides`.'
-                         )
-                     schema[col_name] = supertype
-             else:
-                 cols_with_nones.add(col_name)
-
-     entirely_none_cols = cols_with_nones - schema.keys()
-     if len(entirely_none_cols) > 0:
-         # A column can only end up in `entirely_none_cols` if it was not in `schema_overrides` and
-         # was not encountered in any row with a non-None value.
-         raise excs.Error(
-             f'The following columns have no non-null values: {", ".join(entirely_none_cols)}\n'
-             'Consider specifying the type(s) explicitly in `schema_overrides`.'
-         )
-     return schema
-
-
  def create_label_studio_project(
      t: Table,
      label_config: str,
@@ -140,9 +88,9 @@ def create_label_studio_project(
          parameters of the Label Studio `connect_s3_import_storage` method, as described in the
          [Label Studio connect_s3_import_storage docs](https://labelstud.io/sdk/project.html#label_studio_sdk.project.Project.connect_s3_import_storage).
          `bucket` must be specified; all other parameters are optional. If credentials are not specified explicitly,
-         Pixeltable will attempt to retrieve them from the environment (such as from `~/.aws/credentials`). If a title is not
-         specified, Pixeltable will use the default `'Pixeltable-S3-Import-Storage'`. All other parameters use their Label
-         Studio defaults.
+         Pixeltable will attempt to retrieve them from the environment (such as from `~/.aws/credentials`).
+         If a title is not specified, Pixeltable will use the default `'Pixeltable-S3-Import-Storage'`.
+         All other parameters use their Label Studio defaults.
      kwargs: Additional keyword arguments are passed to the `start_project` method in the Label
          Studio SDK, as described in the
          [Label Studio start_project docs](https://labelstud.io/sdk/project.html#label_studio_sdk.project.Project.start_project).
@@ -151,7 +99,8 @@ def create_label_studio_project(
      A `SyncStatus` representing the status of any synchronization operations that occurred.

  Examples:
-     Create a Label Studio project whose tasks correspond to videos stored in the `video_col` column of the table `tbl`:
+     Create a Label Studio project whose tasks correspond to videos stored in the `video_col`
+     column of the table `tbl`:

      >>> config = \"\"\"
          <View>
@@ -190,108 +139,6 @@ def create_label_studio_project(
      return SyncStatus.empty()


- def import_rows(
-     tbl_path: str,
-     rows: list[dict[str, Any]],
-     *,
-     schema_overrides: Optional[dict[str, Any]] = None,
-     primary_key: Optional[Union[str, list[str]]] = None,
-     num_retained_versions: int = 10,
-     comment: str = '',
- ) -> Table:
-     """
-     Creates a new base table from a list of dictionaries. The dictionaries must be of the
-     form `{column_name: value, ...}`. Pixeltable will attempt to infer the schema of the table from the
-     supplied data, using the most specific type that can represent all the values in a column.
-
-     If `schema_overrides` is specified, then for each entry `(column_name, type)` in `schema_overrides`,
-     Pixeltable will force the specified column to the specified type (and will not attempt any type inference
-     for that column).
-
-     All column types of the new table will be nullable unless explicitly specified as non-nullable in
-     `schema_overrides`.
-
-     Args:
-         tbl_path: The qualified name of the table to create.
-         rows: The list of dictionaries to import.
-         schema_overrides: If specified, then columns in `schema_overrides` will be given the specified types
-             as described above.
-         primary_key: The primary key of the table (see [`create_table()`][pixeltable.create_table]).
-         num_retained_versions: The number of retained versions of the table (see [`create_table()`][pixeltable.create_table]).
-         comment: A comment to attach to the table (see [`create_table()`][pixeltable.create_table]).
-
-     Returns:
-         A handle to the newly created [`Table`][pixeltable.Table].
-     """
-     schema_overrides, primary_key = normalize_import_parameters(schema_overrides, primary_key)
-     row_schema = _infer_schema_from_rows(rows, schema_overrides, primary_key)
-     schema, pxt_pk, _ = normalize_schema_names(row_schema, primary_key, schema_overrides, True)
-
-     table = find_or_create_table(
-         tbl_path, schema, primary_key=pxt_pk, num_retained_versions=num_retained_versions, comment=comment
-     )
-     table.insert(rows)
-     return table
-
-
- def import_json(
-     tbl_path: str,
-     filepath_or_url: str,
-     *,
-     schema_overrides: Optional[dict[str, Any]] = None,
-     primary_key: Optional[Union[str, list[str]]] = None,
-     num_retained_versions: int = 10,
-     comment: str = '',
-     **kwargs: Any,
- ) -> Table:
-     """
-     Creates a new base table from a JSON file. This is a convenience method and is
-     equivalent to calling `import_data(table_path, json.loads(file_contents, **kwargs), ...)`, where `file_contents`
-     is the contents of the specified `filepath_or_url`.
-
-     Args:
-         tbl_path: The name of the table to create.
-         filepath_or_url: The path or URL of the JSON file.
-         schema_overrides: If specified, then columns in `schema_overrides` will be given the specified types
-             (see [`import_rows()`][pixeltable.io.import_rows]).
-         primary_key: The primary key of the table (see [`create_table()`][pixeltable.create_table]).
-         num_retained_versions: The number of retained versions of the table (see [`create_table()`][pixeltable.create_table]).
-         comment: A comment to attach to the table (see [`create_table()`][pixeltable.create_table]).
-         kwargs: Additional keyword arguments to pass to `json.loads`.
-
-     Returns:
-         A handle to the newly created [`Table`][pixeltable.Table].
-     """
-     path = parse_local_file_path(filepath_or_url)
-     if path is None:  # it's a URL
-         # TODO: This should read from S3 as well.
-         contents = urllib.request.urlopen(filepath_or_url).read()
-     else:
-         with open(path) as fp:
-             contents = fp.read()
-
-     rows = json.loads(contents, **kwargs)
-
-     schema_overrides, primary_key = normalize_import_parameters(schema_overrides, primary_key)
-     row_schema = _infer_schema_from_rows(rows, schema_overrides, primary_key)
-     schema, pxt_pk, col_mapping = normalize_schema_names(row_schema, primary_key, schema_overrides, False)
-
-     # Convert all rows to insertable format - not needed, misnamed columns and types are errors in the incoming row format
-     if col_mapping is not None:
-         tbl_rows = [
-             {field if col_mapping is None else col_mapping[field]: val for field, val in row.items()} for row in rows
-         ]
-     else:
-         tbl_rows = rows
-
-     table = find_or_create_table(
-         tbl_path, schema, primary_key=pxt_pk, num_retained_versions=num_retained_versions, comment=comment
-     )
-
-     table.insert(tbl_rows)
-     return table
-
-
  def export_images_as_fo_dataset(
      tbl: pxt.Table,
      images: exprs.Expr,
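
With the bodies above removed, `import_json` delegates entirely to `create_table`, so the following two calls are essentially equivalent in 0.3.10 (table names and URL are hypothetical; `import_json` additionally forwards its `**kwargs` as `extra_args`):

    import pixeltable as pxt

    url = 'https://example.com/films.json'  # hypothetical JSON source

    t1 = pxt.io.import_json('films_a', url)
    t2 = pxt.create_table('films_b', source=url)  # same entry point, presumably backed by the new io/table_data_conduit.py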
pixeltable/io/hf_datasets.py CHANGED
@@ -1,41 +1,38 @@
  from __future__ import annotations

- import logging
- import math
- import random
  import typing
  from typing import Any, Optional, Union

  import pixeltable as pxt
  import pixeltable.type_system as ts
- from pixeltable import exceptions as excs
-
- from .utils import normalize_import_parameters, normalize_schema_names

  if typing.TYPE_CHECKING:
      import datasets  # type: ignore[import-untyped]

- _logger = logging.getLogger('pixeltable')
-
- # use 100MB as the batch size limit for loading a huggingface dataset into pixeltable.
- # The primary goal is to bound memory use, regardless of dataset size.
- # Second goal is to limit overhead. 100MB is presumed to be reasonable for a lot of storage systems.
- _K_BATCH_SIZE_BYTES = 100_000_000

- # note, there are many more types. we allow overrides in the schema_override parameter
+ # note, there are many more types. we allow overrides in the schema_overrides parameter
  # to handle cases where the appropriate type is not yet mapped, or to override this mapping.
  # https://huggingface.co/docs/datasets/v2.17.0/en/package_reference/main_classes#datasets.Value
  _hf_to_pxt: dict[str, ts.ColumnType] = {
-     'int32': ts.IntType(nullable=True),  # pixeltable widens to big int
-     'int64': ts.IntType(nullable=True),
      'bool': ts.BoolType(nullable=True),
+     'int8': ts.IntType(nullable=True),
+     'int16': ts.IntType(nullable=True),
+     'int32': ts.IntType(nullable=True),
+     'int64': ts.IntType(nullable=True),
+     'uint8': ts.IntType(nullable=True),
+     'uint16': ts.IntType(nullable=True),
+     'uint32': ts.IntType(nullable=True),
+     'uint64': ts.IntType(nullable=True),
+     'float16': ts.FloatType(nullable=True),
      'float32': ts.FloatType(nullable=True),
      'float64': ts.FloatType(nullable=True),
-     'large_string': ts.StringType(nullable=True),
      'string': ts.StringType(nullable=True),
+     'large_string': ts.StringType(nullable=True),
      'timestamp[s]': ts.TimestampType(nullable=True),
      'timestamp[ms]': ts.TimestampType(nullable=True),  # HF dataset iterator converts timestamps to datetime.datetime
      'timestamp[us]': ts.TimestampType(nullable=True),
+     'date32': ts.StringType(nullable=True),  # date32 is not supported in pixeltable, use string
+     'date64': ts.StringType(nullable=True),  # date64 is not supported in pixeltable, use string
  }

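The mapping now covers the full set of fixed-width integer and float dtypes, with `date32`/`date64` falling back to strings. A sketch of how this module-private table drives schema inference (the feature dict is hypothetical; dtypes missing from `_hf_to_pxt` must be supplied via `schema_overrides`):

    # inside pixeltable/io/hf_datasets.py, where _hf_to_pxt is in scope
    features = {'id': 'int64', 'label': 'string', 'when': 'date32'}  # hypothetical HF dtype strings
    pxt_schema = {name: _hf_to_pxt[dtype] for name, dtype in features.items()}
    # -> {'id': IntType, 'label': StringType, 'when': StringType (the date32 fallback)}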
@@ -88,7 +85,6 @@ def import_huggingface_dataset(
      table_path: str,
      dataset: Union[datasets.Dataset, datasets.DatasetDict],
      *,
-     column_name_for_split: Optional[str] = None,
      schema_overrides: Optional[dict[str, Any]] = None,
      primary_key: Optional[Union[str, list[str]]] = None,
      **kwargs: Any,
@@ -101,91 +97,18 @@ def import_huggingface_dataset(
          dataset: Huggingface [`datasets.Dataset`](https://huggingface.co/docs/datasets/en/package_reference/main_classes#datasets.Dataset)
              or [`datasets.DatasetDict`](https://huggingface.co/docs/datasets/en/package_reference/main_classes#datasets.DatasetDict)
              to insert into the table.
-         column_name_for_split: column name to use for split information. If None, no split information will be stored.
          schema_overrides: If specified, then for each (name, type) pair in `schema_overrides`, the column with
-             name `name` will be given type `type`, instead of being inferred from the `Dataset` or `DatasetDict`. The keys in
-             `schema_overrides` should be the column names of the `Dataset` or `DatasetDict` (whether or not they are valid
-             Pixeltable identifiers).
+             name `name` will be given type `type`, instead of being inferred from the `Dataset` or `DatasetDict`.
+             The keys in `schema_overrides` should be the column names of the `Dataset` or `DatasetDict` (whether or not
+             they are valid Pixeltable identifiers).
          primary_key: The primary key of the table (see [`create_table()`][pixeltable.create_table]).
          kwargs: Additional arguments to pass to `create_table`.
+             An argument of `column_name_for_split` must be provided if the source is a DatasetDict.
+             This column name will contain the split information. If None, no split information will be stored.

      Returns:
          A handle to the newly created [`Table`][pixeltable.Table].
      """
-     import datasets
-
-     import pixeltable as pxt
-
-     if not isinstance(dataset, (datasets.Dataset, datasets.DatasetDict)):
-         raise excs.Error(f'`type(dataset)` must be `datasets.Dataset` or `datasets.DatasetDict`. Got {type(dataset)=}')
-
-     # Create the pixeltable schema from the huggingface schema
-     hf_schema_source = _get_hf_schema(dataset)
-     schema_overrides, primary_key = normalize_import_parameters(schema_overrides, primary_key)
-     hf_schema = huggingface_schema_to_pxt_schema(hf_schema_source, schema_overrides, primary_key)
-
-     # Add the split column to the schema if requested
-     if column_name_for_split is not None:
-         if column_name_for_split in hf_schema:
-             raise excs.Error(
-                 f'Column name `{column_name_for_split}` already exists in dataset schema; provide a different `column_name_for_split`'
-             )
-         hf_schema[column_name_for_split] = ts.StringType(nullable=True)
-
-     schema, pxt_pk, _ = normalize_schema_names(hf_schema, primary_key, schema_overrides, True)
-
-     # Prepare to create table and insert data
-     if table_path in pxt.list_tables():
-         raise excs.Error(f'table {table_path} already exists')
-
-     if isinstance(dataset, datasets.Dataset):
-         # when loading an hf dataset partially, dataset.split._name is sometimes the form "train[0:1000]"
-         raw_name = dataset.split._name
-         split_name = raw_name.split('[')[0] if raw_name is not None else None
-         dataset_dict = {split_name: dataset}
-     else:
-         dataset_dict = dataset
-
-     # extract all class labels from the dataset to translate category ints to strings
-     categorical_features = {
-         feature_name: feature_type.names
-         for (feature_name, feature_type) in hf_schema_source.items()
-         if isinstance(feature_type, datasets.ClassLabel)
-     }
-
-     try:
-         # random tmp name
-         tmp_name = f'{table_path}_tmp_{random.randint(0, 100000000)}'
-         tab = pxt.create_table(tmp_name, schema, primary_key=pxt_pk, **kwargs)
-
-         def _translate_row(row: dict[str, Any], split_name: str) -> dict[str, Any]:
-             output_row = row.copy()
-             # map all class labels to strings
-             for field, values in categorical_features.items():
-                 output_row[field] = values[row[field]]
-             # add split name to row
-             if column_name_for_split is not None:
-                 output_row[column_name_for_split] = split_name
-             return output_row
-
-         for split_name, split_dataset in dataset_dict.items():
-             num_batches = split_dataset.size_in_bytes / _K_BATCH_SIZE_BYTES
-             tuples_per_batch = math.ceil(split_dataset.num_rows / num_batches)
-             assert tuples_per_batch > 0
-
-             batch = []
-             for row in split_dataset:
-                 batch.append(_translate_row(row, split_name))
-                 if len(batch) >= tuples_per_batch:
-                     tab.insert(batch)
-                     batch = []
-             # last batch
-             if len(batch) > 0:
-                 tab.insert(batch)
-
-     except Exception as e:
-         _logger.error(f'Error while inserting dataset into table: {tmp_name}')
-         raise e
-
-     pxt.move(tmp_name, table_path)
-     return pxt.get_table(table_path)
+     return pxt.create_table(
+         table_path, source=dataset, schema_overrides=schema_overrides, primary_key=primary_key, extra_args=kwargs
+     )
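
Per the revised docstring, `column_name_for_split` now travels through `**kwargs` (forwarded as `extra_args` to `create_table`) and is required when the source is a `DatasetDict`. A usage sketch, assuming the function remains exported via `pixeltable.io` (the dataset name and table path are hypothetical):

    import datasets
    import pixeltable as pxt

    ds = datasets.load_dataset('rotten_tomatoes')  # hypothetical example; a DatasetDict with several splits

    # Each row's split name ('train', 'validation', ...) lands in the 'split' column.
    tbl = pxt.io.import_huggingface_dataset('reviews', ds, column_name_for_split='split')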