pixeltable 0.4.15__py3-none-any.whl → 0.4.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (68) hide show
  1. pixeltable/__init__.py +4 -0
  2. pixeltable/catalog/catalog.py +125 -63
  3. pixeltable/catalog/column.py +7 -2
  4. pixeltable/catalog/table.py +1 -0
  5. pixeltable/catalog/table_metadata.py +4 -0
  6. pixeltable/catalog/table_version.py +174 -117
  7. pixeltable/catalog/table_version_handle.py +4 -1
  8. pixeltable/catalog/table_version_path.py +0 -11
  9. pixeltable/catalog/view.py +6 -0
  10. pixeltable/config.py +7 -0
  11. pixeltable/dataframe.py +10 -5
  12. pixeltable/env.py +56 -19
  13. pixeltable/exec/__init__.py +2 -0
  14. pixeltable/exec/cell_materialization_node.py +231 -0
  15. pixeltable/exec/cell_reconstruction_node.py +135 -0
  16. pixeltable/exec/exec_node.py +1 -1
  17. pixeltable/exec/expr_eval/evaluators.py +1 -0
  18. pixeltable/exec/expr_eval/expr_eval_node.py +3 -0
  19. pixeltable/exec/expr_eval/globals.py +2 -0
  20. pixeltable/exec/globals.py +32 -0
  21. pixeltable/exec/object_store_save_node.py +1 -4
  22. pixeltable/exec/row_update_node.py +16 -9
  23. pixeltable/exec/sql_node.py +107 -14
  24. pixeltable/exprs/__init__.py +1 -1
  25. pixeltable/exprs/arithmetic_expr.py +23 -18
  26. pixeltable/exprs/column_property_ref.py +10 -10
  27. pixeltable/exprs/column_ref.py +2 -2
  28. pixeltable/exprs/data_row.py +106 -37
  29. pixeltable/exprs/expr.py +9 -0
  30. pixeltable/exprs/expr_set.py +14 -7
  31. pixeltable/exprs/inline_expr.py +2 -19
  32. pixeltable/exprs/json_path.py +45 -12
  33. pixeltable/exprs/row_builder.py +54 -22
  34. pixeltable/functions/__init__.py +1 -0
  35. pixeltable/functions/bedrock.py +7 -0
  36. pixeltable/functions/deepseek.py +11 -4
  37. pixeltable/functions/llama_cpp.py +7 -0
  38. pixeltable/functions/math.py +1 -1
  39. pixeltable/functions/ollama.py +7 -0
  40. pixeltable/functions/openai.py +4 -4
  41. pixeltable/functions/openrouter.py +143 -0
  42. pixeltable/functions/video.py +110 -28
  43. pixeltable/globals.py +10 -4
  44. pixeltable/io/globals.py +18 -17
  45. pixeltable/io/parquet.py +1 -1
  46. pixeltable/io/table_data_conduit.py +47 -22
  47. pixeltable/iterators/document.py +61 -23
  48. pixeltable/iterators/video.py +126 -53
  49. pixeltable/metadata/__init__.py +1 -1
  50. pixeltable/metadata/converters/convert_40.py +73 -0
  51. pixeltable/metadata/notes.py +1 -0
  52. pixeltable/plan.py +175 -46
  53. pixeltable/share/packager.py +155 -26
  54. pixeltable/store.py +2 -3
  55. pixeltable/type_system.py +5 -3
  56. pixeltable/utils/arrow.py +6 -6
  57. pixeltable/utils/av.py +65 -0
  58. pixeltable/utils/console_output.py +4 -1
  59. pixeltable/utils/exception_handler.py +5 -28
  60. pixeltable/utils/image.py +7 -0
  61. pixeltable/utils/misc.py +5 -0
  62. pixeltable/utils/object_stores.py +16 -1
  63. pixeltable/utils/s3_store.py +44 -11
  64. {pixeltable-0.4.15.dist-info → pixeltable-0.4.17.dist-info}/METADATA +29 -28
  65. {pixeltable-0.4.15.dist-info → pixeltable-0.4.17.dist-info}/RECORD +68 -61
  66. {pixeltable-0.4.15.dist-info → pixeltable-0.4.17.dist-info}/WHEEL +0 -0
  67. {pixeltable-0.4.15.dist-info → pixeltable-0.4.17.dist-info}/entry_points.txt +0 -0
  68. {pixeltable-0.4.15.dist-info → pixeltable-0.4.17.dist-info}/licenses/LICENSE +0 -0
@@ -1,3 +1,10 @@
1
+ """
2
+ Pixeltable UDFs for Ollama local models.
3
+
4
+ Provides integration with Ollama for running large language models locally,
5
+ including chat completions and embeddings.
6
+ """
7
+
1
8
  from typing import TYPE_CHECKING, Optional
2
9
 
3
10
  import numpy as np
@@ -395,10 +395,10 @@ async def chat_completions(
395
395
  of the table `tbl`:
396
396
 
397
397
  >>> messages = [
398
- {'role': 'system', 'content': 'You are a helpful assistant.'},
399
- {'role': 'user', 'content': tbl.prompt}
400
- ]
401
- tbl.add_computed_column(response=chat_completions(messages, model='gpt-4o-mini'))
398
+ ... {'role': 'system', 'content': 'You are a helpful assistant.'},
399
+ ... {'role': 'user', 'content': tbl.prompt}
400
+ ... ]
401
+ >>> tbl.add_computed_column(response=chat_completions(messages, model='gpt-4o-mini'))
402
402
  """
403
403
  if model_kwargs is None:
404
404
  model_kwargs = {}
@@ -0,0 +1,143 @@
1
+ """
2
+ Pixeltable UDFs that wrap the OpenRouter API.
3
+
4
+ OpenRouter provides a unified interface to multiple LLM providers. In order to use it,
5
+ you must first sign up at https://openrouter.ai, create an API key, and configure it
6
+ as described in the Working with OpenRouter tutorial.
7
+ """
8
+
9
+ from typing import TYPE_CHECKING, Any, Optional
10
+
11
+ import pixeltable as pxt
12
+ from pixeltable.env import Env, register_client
13
+ from pixeltable.utils.code import local_public_names
14
+
15
+ if TYPE_CHECKING:
16
+ import openai
17
+
18
+
19
+ @register_client('openrouter')
20
+ def _(api_key: str, site_url: Optional[str] = None, app_name: Optional[str] = None) -> 'openai.AsyncOpenAI':
21
+ import openai
22
+
23
+ # Create default headers for OpenRouter
24
+ default_headers: dict[str, Any] = {}
25
+ if site_url:
26
+ default_headers['HTTP-Referer'] = site_url
27
+ if app_name:
28
+ default_headers['X-Title'] = app_name
29
+
30
+ return openai.AsyncOpenAI(base_url='https://openrouter.ai/api/v1', api_key=api_key, default_headers=default_headers)
31
+
32
+
33
+ def _openrouter_client() -> 'openai.AsyncOpenAI':
34
+ return Env.get().get_client('openrouter')
35
+
36
+
37
+ @pxt.udf(resource_pool='request-rate:openrouter')
38
+ async def chat_completions(
39
+ messages: list,
40
+ *,
41
+ model: str,
42
+ model_kwargs: Optional[dict[str, Any]] = None,
43
+ tools: Optional[list[dict[str, Any]]] = None,
44
+ tool_choice: Optional[dict[str, Any]] = None,
45
+ provider: Optional[dict[str, Any]] = None,
46
+ transforms: Optional[list[str]] = None,
47
+ ) -> dict:
48
+ """
49
+ Chat Completion API via OpenRouter.
50
+
51
+ OpenRouter provides access to multiple LLM providers through a unified API.
52
+ For additional details, see: <https://openrouter.ai/docs>
53
+
54
+ Supported models can be found at: <https://openrouter.ai/models>
55
+
56
+ Request throttling:
57
+ Applies the rate limit set in the config (section `openrouter`, key `rate_limit`). If no rate
58
+ limit is configured, uses a default of 600 RPM.
59
+
60
+ __Requirements:__
61
+
62
+ - `pip install openai`
63
+
64
+ Args:
65
+ messages: A list of messages comprising the conversation so far.
66
+ model: ID of the model to use (e.g., 'anthropic/claude-3.5-sonnet', 'openai/gpt-4').
67
+ model_kwargs: Additional OpenAI-compatible parameters.
68
+ tools: List of tools available to the model.
69
+ tool_choice: Controls which (if any) tool is called by the model.
70
+ provider: OpenRouter-specific provider preferences (e.g., {'order': ['Anthropic', 'OpenAI']}).
71
+ transforms: List of message transforms to apply (e.g., ['middle-out']).
72
+
73
+ Returns:
74
+ A dictionary containing the response in OpenAI format.
75
+
76
+ Examples:
77
+ Basic chat completion:
78
+
79
+ >>> messages = [{'role': 'user', 'content': tbl.prompt}]
80
+ ... tbl.add_computed_column(
81
+ ... response=chat_completions(
82
+ ... messages,
83
+ ... model='anthropic/claude-3.5-sonnet'
84
+ ... )
85
+ ... )
86
+
87
+ With provider routing:
88
+
89
+ >>> tbl.add_computed_column(
90
+ ... response=chat_completions(
91
+ ... messages,
92
+ ... model='anthropic/claude-3.5-sonnet',
93
+ ... provider={'require_parameters': True, 'order': ['Anthropic']}
94
+ ... )
95
+ ... )
96
+
97
+ With transforms:
98
+
99
+ >>> tbl.add_computed_column(
100
+ ... response=chat_completions(
101
+ ... messages,
102
+ ... model='openai/gpt-4',
103
+ ... transforms=['middle-out'] # Optimize for long contexts
104
+ ... )
105
+ ... )
106
+ """
107
+ if model_kwargs is None:
108
+ model_kwargs = {}
109
+
110
+ Env.get().require_package('openai')
111
+
112
+ # Handle tools if provided
113
+ if tools is not None:
114
+ model_kwargs['tools'] = [{'type': 'function', 'function': tool} for tool in tools]
115
+
116
+ if tool_choice is not None:
117
+ if tool_choice['auto']:
118
+ model_kwargs['tool_choice'] = 'auto'
119
+ elif tool_choice['required']:
120
+ model_kwargs['tool_choice'] = 'required'
121
+ else:
122
+ assert tool_choice['tool'] is not None
123
+ model_kwargs['tool_choice'] = {'type': 'function', 'function': {'name': tool_choice['tool']}}
124
+
125
+ # Prepare OpenRouter-specific parameters for extra_body
126
+ extra_body: dict[str, Any] = {}
127
+ if provider is not None:
128
+ extra_body['provider'] = provider
129
+ if transforms is not None:
130
+ extra_body['transforms'] = transforms
131
+
132
+ # Make the API call
133
+ result = await _openrouter_client().chat.completions.create(
134
+ messages=messages, model=model, extra_body=extra_body if extra_body else None, **model_kwargs
135
+ )
136
+ return result.model_dump()
137
+
138
+
139
+ __all__ = local_public_names(__name__)
140
+
141
+
142
+ def __dir__() -> list[str]:
143
+ return __all__
@@ -2,10 +2,11 @@
2
2
  Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs) for `VideoType`.
3
3
  """
4
4
 
5
+ import glob
5
6
  import logging
6
7
  import pathlib
7
8
  import subprocess
8
- from typing import Literal, NoReturn
9
+ from typing import Any, Literal, NoReturn
9
10
 
10
11
  import av
11
12
  import av.stream
@@ -358,9 +359,17 @@ def clip(
358
359
 
359
360
 
360
361
  @pxt.udf(is_method=True)
361
- def segment_video(video: pxt.Video, *, duration: float) -> list[str]:
362
+ def segment_video(
363
+ video: pxt.Video,
364
+ *,
365
+ duration: float | None = None,
366
+ segment_times: list[float] | None = None,
367
+ mode: Literal['fast', 'accurate'] = 'fast',
368
+ video_encoder: str | None = None,
369
+ video_encoder_args: dict[str, Any] | None = None,
370
+ ) -> list[str]:
362
371
  """
363
- Split a video into fixed-size segments.
372
+ Split a video into segments.
364
373
 
365
374
  __Requirements:__
366
375
 
@@ -368,7 +377,19 @@ def segment_video(video: pxt.Video, *, duration: float) -> list[str]:
368
377
 
369
378
  Args:
370
379
  video: Input video file to segment
371
- duration: Approximate duration of each segment (in seconds).
380
+ duration: Duration of each segment (in seconds). For `mode='fast'`, this is approximate;
381
+ for `mode='accurate'`, segments will have exact durations. Cannot be specified together with
382
+ `segment_times`.
383
+ segment_times: List of timestamps (in seconds) in the video where segments should be split. Note that these are not
384
+ segment durations. If all segment times are less than the duration of the video, produces exactly
385
+ `len(segment_times) + 1` segments. Cannot be empty or be specified together with `duration`.
386
+ mode: Segmentation mode:
387
+
388
+ - `'fast'`: Quick segmentation using stream copy (splits only at keyframes, approximate durations)
389
+ - `'accurate'`: Precise segmentation with re-encoding (exact durations, slower)
390
+ video_encoder: Video encoder to use. If not specified, uses the default encoder for the current platform.
391
+ Only available for `mode='accurate'`.
392
+ video_encoder_args: Additional arguments to pass to the video encoder. Only available for `mode='accurate'`.
372
393
 
373
394
  Returns:
374
395
  List of file paths for the generated video segments.
@@ -377,45 +398,106 @@ def segment_video(video: pxt.Video, *, duration: float) -> list[str]:
377
398
  pxt.Error: If the video is missing timing information.
378
399
 
379
400
  Examples:
380
- Split a video at 1 minute intervals
401
+ Split a video at 1 minute intervals using fast mode:
381
402
 
382
403
  >>> tbl.select(segment_paths=tbl.video.segment_video(duration=60)).collect()
383
404
 
405
+ Split video into exact 10-second segments with accurate mode, using the libx264 encoder with a CRF of 23 and
406
+ slow preset (for smaller output files):
407
+
408
+ >>> tbl.select(
409
+ ... segment_paths=tbl.video.segment_video(
410
+ ... duration=10,
411
+ ... mode='accurate',
412
+ ... video_encoder='libx264',
413
+ ... video_encoder_args={'crf': 23, 'preset': 'slow'}
414
+ ... )
415
+ ... ).collect()
416
+
384
417
  Split video into two parts at the midpoint:
385
418
 
386
419
  >>> duration = tbl.video.get_duration()
387
- >>> tbl.select(segment_paths=tbl.video.segment_video(duration=duration / 2 + 1)).collect()
420
+ >>> tbl.select(segment_paths=tbl.video.segment_video(segment_times=[duration / 2])).collect()
388
421
  """
389
422
  Env.get().require_binary('ffmpeg')
390
- if duration <= 0:
423
+ if duration is not None and segment_times is not None:
424
+ raise pxt.Error('duration and segment_times cannot both be specified')
425
+ if duration is not None and duration <= 0:
391
426
  raise pxt.Error(f'duration must be positive, got {duration}')
427
+ if segment_times is not None and len(segment_times) == 0:
428
+ raise pxt.Error('segment_times cannot be empty')
429
+ if mode == 'fast':
430
+ if video_encoder is not None:
431
+ raise pxt.Error("video_encoder is not supported for mode='fast'")
432
+ if video_encoder_args is not None:
433
+ raise pxt.Error("video_encoder_args is not supported for mode='fast'")
392
434
 
393
435
  base_path = TempStore.create_path(extension='')
394
436
 
395
- # we extract consecutive clips instead of running ffmpeg -f segment, which is inexplicably much slower
396
- start_time = 0.0
397
- result: list[str] = []
398
- try:
399
- while True:
400
- segment_path = f'{base_path}_segment_{len(result)}.mp4'
401
- cmd = av_utils.ffmpeg_clip_cmd(str(video), segment_path, start_time, duration)
437
+ output_paths: list[str] = []
438
+ if mode == 'accurate':
439
+ # Use ffmpeg -f segment for accurate segmentation with re-encoding
440
+ output_pattern = f'{base_path}_segment_%04d.mp4'
441
+ cmd = av_utils.ffmpeg_segment_cmd(
442
+ str(video),
443
+ output_pattern,
444
+ segment_duration=duration,
445
+ segment_times=segment_times,
446
+ video_encoder=video_encoder,
447
+ video_encoder_args=video_encoder_args,
448
+ )
402
449
 
450
+ try:
403
451
  _ = subprocess.run(cmd, capture_output=True, text=True, check=True)
404
- segment_duration = av_utils.get_video_duration(segment_path)
405
- if segment_duration == 0.0:
406
- # we're done
407
- pathlib.Path(segment_path).unlink()
408
- return result
409
- result.append(segment_path)
410
- start_time += segment_duration # use the actual segment duration here, it won't match duration exactly
452
+ output_paths = sorted(glob.glob(f'{base_path}_segment_*.mp4'))
453
+ # TODO: is this actually an error?
454
+ # if len(output_paths) == 0:
455
+ # stderr_output = result.stderr.strip() if result.stderr is not None else ''
456
+ # raise pxt.Error(
457
+ # f'ffmpeg failed to create output files for commandline: {" ".join(cmd)}\n{stderr_output}'
458
+ # )
459
+ return output_paths
460
+
461
+ except subprocess.CalledProcessError as e:
462
+ _handle_ffmpeg_error(e)
411
463
 
412
- return result
413
-
414
- except subprocess.CalledProcessError as e:
415
- # clean up partial results
416
- for segment_path in result:
417
- pathlib.Path(segment_path).unlink()
418
- _handle_ffmpeg_error(e)
464
+ else:
465
+ # Fast mode: extract consecutive clips using stream copy (no re-encoding)
466
+ # This is faster but can only split at keyframes, leading to approximate durations
467
+ start_time = 0.0
468
+ segment_idx = 0
469
+ try:
470
+ while True:
471
+ target_duration: float | None
472
+ if duration is not None:
473
+ target_duration = duration
474
+ elif segment_idx < len(segment_times):
475
+ target_duration = segment_times[segment_idx] - start_time
476
+ else:
477
+ target_duration = None # the rest
478
+ segment_path = f'{base_path}_segment_{len(output_paths)}.mp4'
479
+ cmd = av_utils.ffmpeg_clip_cmd(str(video), segment_path, start_time, target_duration)
480
+
481
+ _ = subprocess.run(cmd, capture_output=True, text=True, check=True)
482
+ segment_duration = av_utils.get_video_duration(segment_path)
483
+ if segment_duration == 0.0:
484
+ # we're done
485
+ pathlib.Path(segment_path).unlink()
486
+ return output_paths
487
+ output_paths.append(segment_path)
488
+ start_time += segment_duration # use the actual segment duration here, it won't match duration exactly
489
+
490
+ segment_idx += 1
491
+ if segment_times is not None and segment_idx > len(segment_times):
492
+ break
493
+
494
+ return output_paths
495
+
496
+ except subprocess.CalledProcessError as e:
497
+ # clean up partial results
498
+ for segment_path in output_paths:
499
+ pathlib.Path(segment_path).unlink()
500
+ _handle_ffmpeg_error(e)
419
501
 
420
502
 
421
503
  @pxt.udf(is_method=True)
pixeltable/globals.py CHANGED
@@ -179,7 +179,7 @@ def create_table(
179
179
  'Unable to create a proper schema from supplied `source`. Please use appropriate `schema_overrides`.'
180
180
  )
181
181
 
182
- table = Catalog.get().create_table(
182
+ table, was_created = Catalog.get().create_table(
183
183
  path_obj,
184
184
  schema,
185
185
  data_source.pxt_df if isinstance(data_source, DFTableDataConduit) else None,
@@ -189,7 +189,7 @@ def create_table(
189
189
  media_validation=media_validation_,
190
190
  num_retained_versions=num_retained_versions,
191
191
  )
192
- if data_source is not None and not is_direct_df:
192
+ if was_created and data_source is not None and not is_direct_df:
193
193
  fail_on_exception = OnErrorParameter.fail_on_exception(on_error)
194
194
  table.insert_table_data_source(data_source=data_source, fail_on_exception=fail_on_exception)
195
195
 
@@ -447,11 +447,16 @@ def replicate(remote_uri: str, local_path: str) -> catalog.Table:
447
447
  return share.pull_replica(local_path, remote_uri)
448
448
 
449
449
 
450
- def get_table(path: str) -> catalog.Table:
450
+ def get_table(path: str, if_not_exists: Literal['error', 'ignore'] = 'error') -> catalog.Table | None:
451
451
  """Get a handle to an existing table, view, or snapshot.
452
452
 
453
453
  Args:
454
454
  path: Path to the table.
455
+ if_not_exists: Directive specifying how to handle the case where the path does not exist.
456
+ Must be one of the following:
457
+
458
+ - `'error'`: raise an error
459
+ - `'ignore'`: do nothing and return `None`
455
460
 
456
461
  Returns:
457
462
  A handle to the [`Table`][pixeltable.Table].
@@ -476,8 +481,9 @@ def get_table(path: str) -> catalog.Table:
476
481
 
477
482
  >>> tbl = pxt.get_table('my_table:722')
478
483
  """
484
+ if_not_exists_ = catalog.IfNotExistsParam.validated(if_not_exists, 'if_not_exists')
479
485
  path_obj = catalog.Path.parse(path, allow_versioned_path=True)
480
- tbl = Catalog.get().get_table(path_obj)
486
+ tbl = Catalog.get().get_table(path_obj, if_not_exists_)
481
487
  return tbl
482
488
 
483
489
 
pixeltable/io/globals.py CHANGED
@@ -103,25 +103,26 @@ def create_label_studio_project(
103
103
  column of the table `tbl`:
104
104
 
105
105
  >>> config = \"\"\"
106
- <View>
107
- <Video name="video_obj" value="$video_col"/>
108
- <Choices name="video-category" toName="video" showInLine="true">
109
- <Choice value="city"/>
110
- <Choice value="food"/>
111
- <Choice value="sports"/>
112
- </Choices>
113
- </View>\"\"\"
114
- create_label_studio_project(tbl, config)
106
+ ... <View>
107
+ ... <Video name="video_obj" value="$video_col"/>
108
+ ... <Choices name="video-category" toName="video" showInLine="true">
109
+ ... <Choice value="city"/>
110
+ ... <Choice value="food"/>
111
+ ... <Choice value="sports"/>
112
+ ... </Choices>
113
+ ... </View>
114
+ ... \"\"\"
115
+ >>> create_label_studio_project(tbl, config)
115
116
 
116
117
  Create a Label Studio project with the same configuration, using `media_import_method='url'`,
117
118
  whose media are stored in an S3 bucket:
118
119
 
119
120
  >>> create_label_studio_project(
120
- tbl,
121
- config,
122
- media_import_method='url',
123
- s3_configuration={'bucket': 'my-bucket', 'region_name': 'us-east-2'}
124
- )
121
+ ... tbl,
122
+ ... config,
123
+ ... media_import_method='url',
124
+ ... s3_configuration={'bucket': 'my-bucket', 'region_name': 'us-east-2'}
125
+ ... )
125
126
  """
126
127
  Env.get().require_package('label_studio_sdk')
127
128
 
@@ -151,7 +152,7 @@ def export_images_as_fo_dataset(
151
152
  (or expression) containing image data, along with optional additional columns containing labels. Currently, only
152
153
  classification and detection labels are supported.
153
154
 
154
- The [Working with Voxel51 in Pixeltable](https://docs.pixeltable.com/docs/working-with-voxel51) tutorial contains a
155
+ The [Working with Voxel51 in Pixeltable](https://docs.pixeltable.com/examples/vision/voxel51) tutorial contains a
155
156
  fully worked example showing how to export data from a Pixeltable table and load it into Voxel51.
156
157
 
157
158
  Images in the dataset that already exist on disk will be exported directly, in whatever format they
@@ -204,13 +205,13 @@ def export_images_as_fo_dataset(
204
205
  Export the images in the `image` column of the table `tbl` as a Voxel51 dataset, using classification
205
206
  labels from `tbl.classifications`:
206
207
 
207
- >>> export_as_fiftyone(
208
+ >>> export_images_as_fo_dataset(
208
209
  ... tbl,
209
210
  ... tbl.image,
210
211
  ... classifications=tbl.classifications
211
212
  ... )
212
213
 
213
- See the [Working with Voxel51 in Pixeltable](https://docs.pixeltable.com/docs/working-with-voxel51) tutorial
214
+ See the [Working with Voxel51 in Pixeltable](https://docs.pixeltable.com/examples/vision/voxel51) tutorial
214
215
  for a fully worked example.
215
216
  """
216
217
  Env.get().require_package('fiftyone')
pixeltable/io/parquet.py CHANGED
@@ -62,7 +62,7 @@ def export_parquet(
62
62
  with Catalog.get().begin_xact(for_write=False):
63
63
  for record_batch in to_record_batches(df, partition_size_bytes):
64
64
  output_path = temp_path / f'part-{batch_num:05d}.parquet'
65
- arrow_tbl = pa.Table.from_batches([record_batch]) # type: ignore
65
+ arrow_tbl = pa.Table.from_batches([record_batch])
66
66
  pa.parquet.write_table(arrow_tbl, str(output_path))
67
67
  batch_num += 1
68
68
 
@@ -10,7 +10,9 @@ from dataclasses import dataclass, field, fields
10
10
  from pathlib import Path
11
11
  from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, Optional, cast
12
12
 
13
+ import numpy as np
13
14
  import pandas as pd
15
+ import PIL
14
16
  from pyarrow.parquet import ParquetDataset
15
17
 
16
18
  import pixeltable as pxt
@@ -325,7 +327,11 @@ class JsonTableDataConduit(TableDataConduit):
325
327
 
326
328
 
327
329
  class HFTableDataConduit(TableDataConduit):
328
- hf_ds: datasets.Dataset | datasets.DatasetDict | None = None
330
+ """
331
+ TODO:
332
+ - use set_format('arrow') and convert ChunkedArrays to PIL.Image.Image instead of going through numpy, which is slow
333
+ """
334
+
329
335
  column_name_for_split: Optional[str] = None
330
336
  categorical_features: dict[str, dict[int, str]]
331
337
  dataset_dict: dict[str, datasets.Dataset] = None
@@ -339,9 +345,19 @@ class HFTableDataConduit(TableDataConduit):
339
345
  import datasets
340
346
 
341
347
  assert isinstance(tds.source, (datasets.Dataset, datasets.DatasetDict))
342
- t.hf_ds = tds.source
343
348
  if 'column_name_for_split' in t.extra_fields:
344
349
  t.column_name_for_split = t.extra_fields['column_name_for_split']
350
+
351
+ # make sure we get numpy arrays for arrays, not Python lists
352
+ source = tds.source.with_format(type='numpy')
353
+ if isinstance(source, datasets.Dataset):
354
+ # when loading an hf dataset partially, dataset.split._name is sometimes the form "train[0:1000]"
355
+ raw_name = source.split._name
356
+ split_name = raw_name.split('[')[0] if raw_name is not None else None
357
+ t.dataset_dict = {split_name: source}
358
+ else:
359
+ assert isinstance(source, datasets.DatasetDict)
360
+ t.dataset_dict = source
345
361
  return t
346
362
 
347
363
  @classmethod
@@ -361,7 +377,7 @@ class HFTableDataConduit(TableDataConduit):
361
377
  if self.source_column_map is None:
362
378
  if self.src_schema_overrides is None:
363
379
  self.src_schema_overrides = {}
364
- self.hf_schema_source = _get_hf_schema(self.hf_ds)
380
+ self.hf_schema_source = _get_hf_schema(self.source)
365
381
  self.src_schema = huggingface_schema_to_pxt_schema(
366
382
  self.hf_schema_source, self.src_schema_overrides, self.src_pk
367
383
  )
@@ -396,15 +412,6 @@ class HFTableDataConduit(TableDataConduit):
396
412
  def prepare_insert(self) -> None:
397
413
  import datasets
398
414
 
399
- if isinstance(self.source, datasets.Dataset):
400
- # when loading an hf dataset partially, dataset.split._name is sometimes the form "train[0:1000]"
401
- raw_name = self.source.split._name
402
- split_name = raw_name.split('[')[0] if raw_name is not None else None
403
- self.dataset_dict = {split_name: self.source}
404
- else:
405
- assert isinstance(self.source, datasets.DatasetDict)
406
- self.dataset_dict = self.source
407
-
408
415
  # extract all class labels from the dataset to translate category ints to strings
409
416
  self.categorical_features = {
410
417
  feature_name: feature_type.names
@@ -415,26 +422,44 @@ class HFTableDataConduit(TableDataConduit):
415
422
  self.source_column_map = {}
416
423
  self.check_source_columns_are_insertable(self.hf_schema_source.keys())
417
424
 
418
- def _translate_row(self, row: dict[str, Any], split_name: str) -> dict[str, Any]:
425
+ def _translate_row(self, row: dict[str, Any], split_name: str, features: datasets.Features) -> dict[str, Any]:
419
426
  output_row: dict[str, Any] = {}
420
427
  for col_name, val in row.items():
421
428
  # translate category ints to strings
422
429
  new_val = self.categorical_features[col_name][val] if col_name in self.categorical_features else val
423
430
  mapped_col_name = self.source_column_map.get(col_name, col_name)
424
431
 
425
- # Convert values to the appropriate type if needed
426
- try:
427
- checked_val = self.pxt_schema[mapped_col_name].create_literal(new_val)
428
- except TypeError as e:
429
- msg = str(e)
430
- raise excs.Error(f'Error in column {col_name}: {msg[0].lower() + msg[1:]}\nRow: {row}') from e
431
- output_row[mapped_col_name] = checked_val
432
+ new_val = self._translate_val(new_val, features[col_name])
433
+ output_row[mapped_col_name] = new_val
432
434
 
433
435
  # add split name to output row
434
436
  if self.column_name_for_split is not None:
435
437
  output_row[self.column_name_for_split] = split_name
436
438
  return output_row
437
439
 
440
+ def _translate_val(self, val: Any, feature: datasets.Feature) -> Any:
441
+ """Convert numpy scalars to Python types and images to PIL.Image.Image"""
442
+ import datasets
443
+
444
+ if isinstance(feature, datasets.Value):
445
+ if isinstance(val, (np.generic, np.ndarray)):
446
+ # a scalar, which we want as a standard Python type
447
+ assert np.ndim(val) == 0
448
+ return val.item()
449
+ else:
450
+ # a standard Python object
451
+ return val
452
+ elif isinstance(feature, datasets.Sequence):
453
+ assert np.ndim(val) > 0
454
+ return val
455
+ elif isinstance(feature, datasets.Image):
456
+ return PIL.Image.fromarray(val)
457
+ elif isinstance(feature, dict):
458
+ assert isinstance(val, dict)
459
+ return {k: self._translate_val(v, feature[k]) for k, v in val.items()}
460
+ else:
461
+ return val
462
+
438
463
  def valid_row_batch(self) -> Iterator[RowData]:
439
464
  for split_name, split_dataset in self.dataset_dict.items():
440
465
  num_batches = split_dataset.size_in_bytes / self._K_BATCH_SIZE_BYTES
@@ -443,7 +468,7 @@ class HFTableDataConduit(TableDataConduit):
443
468
 
444
469
  batch = []
445
470
  for row in split_dataset:
446
- batch.append(self._translate_row(row, split_name))
471
+ batch.append(self._translate_row(row, split_name, split_dataset.features))
447
472
  if len(batch) >= tuples_per_batch:
448
473
  yield batch
449
474
  batch = []
@@ -503,7 +528,7 @@ class ParquetTableDataConduit(TableDataConduit):
503
528
  from pixeltable.utils.arrow import iter_tuples2
504
529
 
505
530
  try:
506
- for fragment in self.pq_ds.fragments: # type: ignore[attr-defined]
531
+ for fragment in self.pq_ds.fragments:
507
532
  for batch in fragment.to_batches():
508
533
  dict_batch = list(iter_tuples2(batch, self.source_column_map, self.pxt_schema))
509
534
  self.total_rows += len(dict_batch)