pixeltable 0.4.14__py3-none-any.whl → 0.4.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pixeltable might be problematic.

Files changed (64)
  1. pixeltable/__init__.py +6 -1
  2. pixeltable/catalog/catalog.py +107 -45
  3. pixeltable/catalog/column.py +7 -2
  4. pixeltable/catalog/table.py +1 -0
  5. pixeltable/catalog/table_metadata.py +5 -0
  6. pixeltable/catalog/table_version.py +100 -106
  7. pixeltable/catalog/table_version_handle.py +4 -1
  8. pixeltable/catalog/update_status.py +12 -0
  9. pixeltable/config.py +6 -0
  10. pixeltable/dataframe.py +11 -5
  11. pixeltable/env.py +52 -19
  12. pixeltable/exec/__init__.py +2 -0
  13. pixeltable/exec/cell_materialization_node.py +231 -0
  14. pixeltable/exec/cell_reconstruction_node.py +135 -0
  15. pixeltable/exec/exec_node.py +1 -1
  16. pixeltable/exec/expr_eval/evaluators.py +1 -0
  17. pixeltable/exec/expr_eval/expr_eval_node.py +14 -0
  18. pixeltable/exec/expr_eval/globals.py +2 -0
  19. pixeltable/exec/globals.py +32 -0
  20. pixeltable/exec/object_store_save_node.py +1 -4
  21. pixeltable/exec/row_update_node.py +16 -9
  22. pixeltable/exec/sql_node.py +107 -14
  23. pixeltable/exprs/__init__.py +1 -1
  24. pixeltable/exprs/arithmetic_expr.py +10 -11
  25. pixeltable/exprs/column_property_ref.py +10 -10
  26. pixeltable/exprs/column_ref.py +2 -2
  27. pixeltable/exprs/data_row.py +106 -37
  28. pixeltable/exprs/expr.py +9 -0
  29. pixeltable/exprs/expr_set.py +14 -7
  30. pixeltable/exprs/inline_expr.py +2 -19
  31. pixeltable/exprs/json_path.py +45 -12
  32. pixeltable/exprs/row_builder.py +54 -22
  33. pixeltable/functions/__init__.py +1 -0
  34. pixeltable/functions/bedrock.py +7 -0
  35. pixeltable/functions/deepseek.py +11 -4
  36. pixeltable/functions/llama_cpp.py +7 -0
  37. pixeltable/functions/math.py +1 -1
  38. pixeltable/functions/ollama.py +7 -0
  39. pixeltable/functions/openai.py +4 -4
  40. pixeltable/functions/openrouter.py +143 -0
  41. pixeltable/functions/video.py +123 -9
  42. pixeltable/functions/whisperx.py +2 -0
  43. pixeltable/functions/yolox.py +2 -0
  44. pixeltable/globals.py +56 -31
  45. pixeltable/io/__init__.py +1 -0
  46. pixeltable/io/globals.py +16 -15
  47. pixeltable/io/table_data_conduit.py +46 -21
  48. pixeltable/iterators/__init__.py +1 -0
  49. pixeltable/metadata/__init__.py +1 -1
  50. pixeltable/metadata/converters/convert_40.py +73 -0
  51. pixeltable/metadata/notes.py +1 -0
  52. pixeltable/plan.py +175 -46
  53. pixeltable/share/publish.py +0 -1
  54. pixeltable/store.py +2 -2
  55. pixeltable/type_system.py +5 -3
  56. pixeltable/utils/console_output.py +4 -1
  57. pixeltable/utils/exception_handler.py +5 -28
  58. pixeltable/utils/image.py +7 -0
  59. pixeltable/utils/misc.py +5 -0
  60. {pixeltable-0.4.14.dist-info → pixeltable-0.4.16.dist-info}/METADATA +2 -1
  61. {pixeltable-0.4.14.dist-info → pixeltable-0.4.16.dist-info}/RECORD +64 -57
  62. {pixeltable-0.4.14.dist-info → pixeltable-0.4.16.dist-info}/WHEEL +0 -0
  63. {pixeltable-0.4.14.dist-info → pixeltable-0.4.16.dist-info}/entry_points.txt +0 -0
  64. {pixeltable-0.4.14.dist-info → pixeltable-0.4.16.dist-info}/licenses/LICENSE +0 -0
pixeltable/functions/video.py CHANGED
@@ -4,7 +4,6 @@ Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs)
 
 import logging
 import pathlib
-import shutil
 import subprocess
 from typing import Literal, NoReturn
 
@@ -327,6 +326,7 @@ def clip(
     Returns:
         New video containing only the specified time range or None if start_time is beyond the end of the video.
     """
+    Env.get().require_binary('ffmpeg')
     if start_time < 0:
         raise pxt.Error(f'start_time must be non-negative, got {start_time}')
     if end_time is not None and end_time <= start_time:
@@ -335,8 +335,6 @@ def clip(
         raise pxt.Error(f'duration must be positive, got {duration}')
     if end_time is not None and duration is not None:
         raise pxt.Error('end_time and duration cannot both be specified')
-    if not shutil.which('ffmpeg'):
-        raise pxt.Error('ffmpeg is not installed or not in PATH. Please install ffmpeg to use get_clip().')
 
     video_duration = av_utils.get_video_duration(video)
     if video_duration is not None and start_time > video_duration:
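The per-UDF `shutil.which('ffmpeg')` guards are replaced by a single `Env.get().require_binary('ffmpeg')` call. That helper lives in pixeltable/env.py (also changed in this release); a minimal sketch of the assumed behavior, for illustration only:

    import shutil
    import pixeltable as pxt

    def require_binary(name: str) -> None:
        # Assumed behavior of Env.require_binary(): fail fast with a pixeltable error
        # if the executable is not on PATH (sketch; not the actual implementation).
        if shutil.which(name) is None:
            raise pxt.Error(f'{name} is not installed or not in PATH.')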
@@ -388,10 +386,9 @@ def segment_video(video: pxt.Video, *, duration: float) -> list[str]:
         >>> duration = tbl.video.get_duration()
         >>> tbl.select(segment_paths=tbl.video.segment_video(duration=duration / 2 + 1)).collect()
     """
+    Env.get().require_binary('ffmpeg')
     if duration <= 0:
         raise pxt.Error(f'duration must be positive, got {duration}')
-    if not shutil.which('ffmpeg'):
-        raise pxt.Error('ffmpeg is not installed or not in PATH. Please install ffmpeg to use segment_video().')
 
     base_path = TempStore.create_path(extension='')
 
@@ -436,10 +433,9 @@ def concat_videos(videos: list[pxt.Video]) -> pxt.Video:
     Returns:
         A new video containing the merged videos.
     """
+    Env.get().require_binary('ffmpeg')
     if len(videos) == 0:
         raise pxt.Error('concat_videos(): empty argument list')
-    if not shutil.which('ffmpeg'):
-        raise pxt.Error('ffmpeg is not installed or not in PATH. Please install ffmpeg to use concat_videos().')
 
     # Check that all videos have the same resolution
     resolutions: list[tuple[int, int]] = []
@@ -528,6 +524,125 @@ def concat_videos(videos: list[pxt.Video]) -> pxt.Video:
         filelist_path.unlink()
 
 
+@pxt.udf
+def with_audio(
+    video: pxt.Video,
+    audio: pxt.Audio,
+    *,
+    video_start_time: float = 0.0,
+    video_duration: float | None = None,
+    audio_start_time: float = 0.0,
+    audio_duration: float | None = None,
+) -> pxt.Video:
+    """
+    Creates a new video that combines the video stream from `video` and the audio stream from `audio`.
+    The `start_time` and `duration` parameters can be used to select a specific time range from each input.
+    If the audio input (or selected time range) is longer than the video, the audio will be truncated.
+
+
+    __Requirements:__
+
+    - `ffmpeg` needs to be installed and in PATH
+
+    Args:
+        video: Input video.
+        audio: Input audio.
+        video_start_time: Start time in the video input (in seconds).
+        video_duration: Duration of video segment (in seconds). If None, uses the remainder of the video after
+            `video_start_time`. `video_duration` determines the duration of the output video.
+        audio_start_time: Start time in the audio input (in seconds).
+        audio_duration: Duration of audio segment (in seconds). If None, uses the remainder of the audio after
+            `audio_start_time`. If the audio is longer than the output video, it will be truncated.
+
+    Returns:
+        A new video file with the audio track added.
+
+    Examples:
+        Add background music to a video:
+
+        >>> tbl.select(tbl.video.with_audio(tbl.music_track)).collect()
+
+        Add audio starting 5 seconds into both files:
+
+        >>> tbl.select(
+        ...     tbl.video.with_audio(
+        ...         tbl.music_track,
+        ...         video_start_time=5.0,
+        ...         audio_start_time=5.0
+        ...     )
+        ... ).collect()
+
+        Use a 10-second clip from the middle of both files:
+
+        >>> tbl.select(
+        ...     tbl.video.with_audio(
+        ...         tbl.music_track,
+        ...         video_start_time=30.0,
+        ...         video_duration=10.0,
+        ...         audio_start_time=15.0,
+        ...         audio_duration=10.0
+        ...     )
+        ... ).collect()
+    """
+    Env.get().require_binary('ffmpeg')
+    if video_start_time < 0:
+        raise pxt.Error(f'video_offset must be non-negative, got {video_start_time}')
+    if audio_start_time < 0:
+        raise pxt.Error(f'audio_offset must be non-negative, got {audio_start_time}')
+    if video_duration is not None and video_duration <= 0:
+        raise pxt.Error(f'video_duration must be positive, got {video_duration}')
+    if audio_duration is not None and audio_duration <= 0:
+        raise pxt.Error(f'audio_duration must be positive, got {audio_duration}')
+
+    output_path = str(TempStore.create_path(extension='.mp4'))
+
+    cmd = ['ffmpeg']
+    if video_start_time > 0:
+        # fast seek, must precede -i
+        cmd.extend(['-ss', str(video_start_time)])
+    if video_duration is not None:
+        cmd.extend(['-t', str(video_duration)])
+    else:
+        video_duration = av_utils.get_video_duration(video)
+    cmd.extend(['-i', str(video)])
+
+    if audio_start_time > 0:
+        cmd.extend(['-ss', str(audio_start_time)])
+    if audio_duration is not None:
+        cmd.extend(['-t', str(audio_duration)])
+    cmd.extend(['-i', str(audio)])
+
+    cmd.extend(
+        [
+            '-map',
+            '0:v:0',  # video from first input
+            '-map',
+            '1:a:0',  # audio from second input
+            '-c:v',
+            'copy',  # avoid re-encoding
+            '-c:a',
+            'copy',  # avoid re-encoding
+            '-t',
+            str(video_duration),  # limit output duration to video duration
+            '-loglevel',
+            'error',  # only show errors
+            output_path,
+        ]
+    )
+
+    _logger.debug(f'with_audio(): {" ".join(cmd)}')
+
+    try:
+        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
+        output_file = pathlib.Path(output_path)
+        if not output_file.exists() or output_file.stat().st_size == 0:
+            stderr_output = result.stderr.strip() if result.stderr is not None else ''
+            raise pxt.Error(f'ffmpeg failed to create output file for commandline: {" ".join(cmd)}\n{stderr_output}')
+        return output_path
+    except subprocess.CalledProcessError as e:
+        _handle_ffmpeg_error(e)
+
+
 @pxt.udf(is_method=True)
 def overlay_text(
     video: pxt.Video,
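For reference, with the default offsets and no explicit durations, the argv that `with_audio()` assembles reduces to roughly the following (file names and the probed duration are hypothetical):

    cmd = [
        'ffmpeg',
        '-i', 'input_video.mp4',
        '-i', 'music.mp3',
        '-map', '0:v:0', '-map', '1:a:0',
        '-c:v', 'copy', '-c:a', 'copy',
        '-t', '93.5',           # duration probed via av_utils.get_video_duration()
        '-loglevel', 'error',
        '/tmp/out.mp4',
    ]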
@@ -614,8 +729,7 @@ def overlay_text(
        ...     )
        ... ).collect()
    """
-    if not shutil.which('ffmpeg'):
-        raise pxt.Error('ffmpeg is not installed or not in PATH. Please install ffmpeg to use overlay_text().')
+    Env.get().require_binary('ffmpeg')
    if font_size <= 0:
        raise pxt.Error(f'font_size must be positive, got {font_size}')
    if opacity < 0.0 or opacity > 1.0:
pixeltable/functions/whisperx.py CHANGED
@@ -1,3 +1,5 @@
+"""WhisperX audio transcription and diarization functions."""
+
 from typing import TYPE_CHECKING, Any, Optional
 
 import numpy as np
pixeltable/functions/yolox.py CHANGED
@@ -1,3 +1,5 @@
+"""YOLOX object detection functions."""
+
 import logging
 from typing import TYPE_CHECKING
 
pixeltable/globals.py CHANGED
@@ -179,7 +179,7 @@ def create_table(
            'Unable to create a proper schema from supplied `source`. Please use appropriate `schema_overrides`.'
        )
 
-    table = Catalog.get().create_table(
+    table, was_created = Catalog.get().create_table(
        path_obj,
        schema,
        data_source.pxt_df if isinstance(data_source, DFTableDataConduit) else None,
@@ -189,7 +189,7 @@ def create_table(
        media_validation=media_validation_,
        num_retained_versions=num_retained_versions,
    )
-    if data_source is not None and not is_direct_df:
+    if was_created and data_source is not None and not is_direct_df:
        fail_on_exception = OnErrorParameter.fail_on_exception(on_error)
        table.insert_table_data_source(data_source=data_source, fail_on_exception=fail_on_exception)
 
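Returning a `was_created` flag lets `create_table()` skip the `source` insert when it hands back an already-existing table. A hedged usage sketch; the table name, file name, and `if_exists` behavior shown here are assumptions for illustration:

    import pixeltable as pxt

    tbl = pxt.create_table('films', source='films.csv', if_exists='ignore')  # creates and loads
    tbl = pxt.create_table('films', source='films.csv', if_exists='ignore')  # reuses; no second insert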
@@ -397,47 +397,66 @@ def create_snapshot(
    )
 
 
-def create_replica(
-    destination: str,
+def publish(
    source: str | catalog.Table,
+    destination_uri: str,
    bucket_name: str | None = None,
    access: Literal['public', 'private'] = 'private',
-) -> Optional[catalog.Table]:
+) -> None:
    """
-    Create a replica of a table. Can be used either to create a remote replica of a local table, or to create a local
-    replica of a remote table. A given table can have at most one replica per Pixeltable instance.
+    Publishes a replica of a local Pixeltable table to Pixeltable cloud. A given table can be published to at most one
+    URI per Pixeltable cloud database.
 
    Args:
-        destination: Path where the replica will be created. Can be either a local path such as `'my_dir.my_table'`, or
-            a remote URI such as `'pxt://username/mydir.my_table'`.
-        source: Path to the source table, or (if the source table is a local table) a handle to the source table.
-        bucket_name: The name of the pixeltable cloud-registered bucket to use to store replica's data.
-            If no `bucket_name` is provided, the default Pixeltable storage bucket will be used.
+        source: Path or table handle of the local table to be published.
+        destination_uri: Remote URI where the replica will be published, such as `'pxt://org_name/my_dir/my_table'`.
+        bucket_name: The name of the bucket to use to store replica's data. The bucket must be registered with
+            Pixeltable cloud. If no `bucket_name` is provided, the default storage bucket for the destination
+            database will be used.
        access: Access control for the replica.
 
            - `'public'`: Anyone can access this replica.
-            - `'private'`: Only the owner can access.
+            - `'private'`: Only the host organization can access.
    """
-    remote_dest = destination.startswith('pxt://')
-    remote_source = isinstance(source, str) and source.startswith('pxt://')
-    if remote_dest == remote_source:
-        raise excs.Error('Exactly one of `destination` or `source` must be a remote URI.')
-
-    if remote_dest:
-        if isinstance(source, str):
-            source = get_table(source)
-        share.push_replica(destination, source, bucket_name, access)
-        return None
-    else:
-        assert isinstance(source, str)
-        return share.pull_replica(destination, source)
+    if not destination_uri.startswith('pxt://'):
+        raise excs.Error("`destination_uri` must be a remote Pixeltable URI with the prefix 'pxt://'")
 
+    if isinstance(source, str):
+        source = get_table(source)
 
-def get_table(path: str) -> catalog.Table:
+    share.push_replica(destination_uri, source, bucket_name, access)
+
+
+def replicate(remote_uri: str, local_path: str) -> catalog.Table:
+    """
+    Retrieve a replica from Pixeltable cloud as a local table. This will create a full local copy of the replica in a
+    way that preserves the table structure of the original source data. Once replicated, the local table can be
+    queried offline just as any other Pixeltable table.
+
+    Args:
+        remote_uri: Remote URI of the table to be replicated, such as `'pxt://org_name/my_dir/my_table'`.
+        local_path: Local table path where the replica will be created, such as `'my_new_dir.my_new_tbl'`. It can be
+            the same or different from the cloud table name.
+
+    Returns:
+        A handle to the newly created local replica table.
+    """
+    if not remote_uri.startswith('pxt://'):
+        raise excs.Error("`remote_uri` must be a remote Pixeltable URI with the prefix 'pxt://'")
+
+    return share.pull_replica(local_path, remote_uri)
+
+
+def get_table(path: str, if_not_exists: Literal['error', 'ignore'] = 'error') -> catalog.Table | None:
    """Get a handle to an existing table, view, or snapshot.
 
    Args:
        path: Path to the table.
+        if_not_exists: Directive regarding how to handle if the path does not exist.
+            Must be one of the following:
+
+            - `'error'`: raise an error
+            - `'ignore'`: do nothing and return `None`
 
    Returns:
        A handle to the [`Table`][pixeltable.Table].
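`create_replica()` is split into `publish()` (local to cloud) and `replicate()` (cloud to local). A sketch of the round trip, assuming both are re-exported at the package level like the other globals (organization and paths are hypothetical):

    import pixeltable as pxt

    tbl = pxt.get_table('my_dir.my_table')
    pxt.publish(tbl, 'pxt://my_org/my_dir/my_table')                        # local -> cloud
    replica = pxt.replicate('pxt://my_org/my_dir/my_table', 'local.copy')   # cloud -> local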
@@ -462,8 +481,9 @@ def get_table(path: str) -> catalog.Table:
 
        >>> tbl = pxt.get_table('my_table:722')
    """
+    if_not_exists_ = catalog.IfNotExistsParam.validated(if_not_exists, 'if_not_exists')
    path_obj = catalog.Path.parse(path, allow_versioned_path=True)
-    tbl = Catalog.get().get_table(path_obj)
+    tbl = Catalog.get().get_table(path_obj, if_not_exists_)
    return tbl
 
 
@@ -498,10 +518,11 @@ def move(path: str, new_path: str) -> None:
 def drop_table(
    table: str | catalog.Table, force: bool = False, if_not_exists: Literal['error', 'ignore'] = 'error'
 ) -> None:
-    """Drop a table, view, or snapshot.
+    """Drop a table, view, snapshot, or replica.
 
    Args:
-        table: Fully qualified name, or handle, of the table to be dropped.
+        table: Fully qualified name or table handle of the table to be dropped; or a remote URI of a cloud replica to
+            be deleted.
        force: If `True`, will also drop all views and sub-views of this table.
        if_not_exists: Directive regarding how to handle if the path does not exist.
            Must be one of the following:
@@ -541,13 +562,17 @@ def drop_table(
        assert isinstance(table, str)
        tbl_path = table
 
+    if_not_exists_ = catalog.IfNotExistsParam.validated(if_not_exists, 'if_not_exists')
+
    if tbl_path.startswith('pxt://'):
        # Remote table
+        if force:
+            raise excs.Error('Cannot use `force=True` with a cloud replica URI.')
+        # TODO: Handle if_not_exists properly
        share.delete_replica(tbl_path)
    else:
        # Local table
        path_obj = catalog.Path.parse(tbl_path)
-        if_not_exists_ = catalog.IfNotExistsParam.validated(if_not_exists, 'if_not_exists')
        Catalog.get().drop_table(path_obj, force=force, if_not_exists=if_not_exists_)
 
 
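With these changes, `drop_table()` also accepts a cloud replica URI (with `force=True` rejected), and `get_table()` can return `None` instead of raising. A short sketch; all paths are hypothetical:

    import pixeltable as pxt

    pxt.drop_table('pxt://my_org/my_dir/my_table')                        # deletes the cloud replica
    tbl = pxt.get_table('my_dir.maybe_missing', if_not_exists='ignore')
    if tbl is None:
        print('no such table')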
pixeltable/io/__init__.py CHANGED
@@ -1,3 +1,4 @@
+"""Functions for importing and exporting Pixeltable data."""
 # ruff: noqa: F401
 
 from .datarows import import_json, import_rows
pixeltable/io/globals.py CHANGED
@@ -103,25 +103,26 @@ def create_label_studio_project(
        column of the table `tbl`:
 
        >>> config = \"\"\"
-        <View>
-        <Video name="video_obj" value="$video_col"/>
-        <Choices name="video-category" toName="video" showInLine="true">
-        <Choice value="city"/>
-        <Choice value="food"/>
-        <Choice value="sports"/>
-        </Choices>
-        </View>\"\"\"
-        create_label_studio_project(tbl, config)
+        ... <View>
+        ... <Video name="video_obj" value="$video_col"/>
+        ... <Choices name="video-category" toName="video" showInLine="true">
+        ... <Choice value="city"/>
+        ... <Choice value="food"/>
+        ... <Choice value="sports"/>
+        ... </Choices>
+        ... </View>
+        ... \"\"\"
+        >>> create_label_studio_project(tbl, config)
 
        Create a Label Studio project with the same configuration, using `media_import_method='url'`,
        whose media are stored in an S3 bucket:
 
        >>> create_label_studio_project(
-        tbl,
-        config,
-        media_import_method='url',
-        s3_configuration={'bucket': 'my-bucket', 'region_name': 'us-east-2'}
-        )
+        ...     tbl,
+        ...     config,
+        ...     media_import_method='url',
+        ...     s3_configuration={'bucket': 'my-bucket', 'region_name': 'us-east-2'}
+        ... )
    """
    Env.get().require_package('label_studio_sdk')
 
@@ -204,7 +205,7 @@ def export_images_as_fo_dataset(
        Export the images in the `image` column of the table `tbl` as a Voxel51 dataset, using classification
        labels from `tbl.classifications`:
 
-        >>> export_as_fiftyone(
+        >>> export_images_as_fo_dataset(
        ...     tbl,
        ...     tbl.image,
        ...     classifications=tbl.classifications
pixeltable/io/table_data_conduit.py CHANGED
@@ -10,7 +10,9 @@ from dataclasses import dataclass, field, fields
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, Optional, cast
 
+import numpy as np
 import pandas as pd
+import PIL
 from pyarrow.parquet import ParquetDataset
 
 import pixeltable as pxt
@@ -325,7 +327,11 @@ class JsonTableDataConduit(TableDataConduit):
 
 
 class HFTableDataConduit(TableDataConduit):
-    hf_ds: datasets.Dataset | datasets.DatasetDict | None = None
+    """
+    TODO:
+    - use set_format('arrow') and convert ChunkedArrays to PIL.Image.Image instead of going through numpy, which is slow
+    """
+
    column_name_for_split: Optional[str] = None
    categorical_features: dict[str, dict[int, str]]
    dataset_dict: dict[str, datasets.Dataset] = None
@@ -339,9 +345,19 @@
        import datasets
 
        assert isinstance(tds.source, (datasets.Dataset, datasets.DatasetDict))
-        t.hf_ds = tds.source
        if 'column_name_for_split' in t.extra_fields:
            t.column_name_for_split = t.extra_fields['column_name_for_split']
+
+        # make sure we get numpy arrays for arrays, not Python lists
+        source = tds.source.with_format(type='numpy')
+        if isinstance(source, datasets.Dataset):
+            # when loading an hf dataset partially, dataset.split._name is sometimes the form "train[0:1000]"
+            raw_name = source.split._name
+            split_name = raw_name.split('[')[0] if raw_name is not None else None
+            t.dataset_dict = {split_name: source}
+        else:
+            assert isinstance(source, datasets.DatasetDict)
+            t.dataset_dict = source
        return t
 
    @classmethod
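`with_format(type='numpy')` is a standard Hugging Face `datasets` call that makes rows come back as numpy scalars and arrays, which the `_translate_val()` method added in the following hunks then normalizes. A small standalone example (the dataset name is illustrative):

    from datasets import load_dataset

    ds = load_dataset('mnist', split='train[:100]').with_format(type='numpy')
    row = ds[0]   # e.g. {'image': numpy array, 'label': numpy integer scalar}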
@@ -361,7 +377,7 @@
        if self.source_column_map is None:
            if self.src_schema_overrides is None:
                self.src_schema_overrides = {}
-            self.hf_schema_source = _get_hf_schema(self.hf_ds)
+            self.hf_schema_source = _get_hf_schema(self.source)
            self.src_schema = huggingface_schema_to_pxt_schema(
                self.hf_schema_source, self.src_schema_overrides, self.src_pk
            )
@@ -396,15 +412,6 @@
    def prepare_insert(self) -> None:
        import datasets
 
-        if isinstance(self.source, datasets.Dataset):
-            # when loading an hf dataset partially, dataset.split._name is sometimes the form "train[0:1000]"
-            raw_name = self.source.split._name
-            split_name = raw_name.split('[')[0] if raw_name is not None else None
-            self.dataset_dict = {split_name: self.source}
-        else:
-            assert isinstance(self.source, datasets.DatasetDict)
-            self.dataset_dict = self.source
-
        # extract all class labels from the dataset to translate category ints to strings
        self.categorical_features = {
            feature_name: feature_type.names
@@ -415,26 +422,44 @@
        self.source_column_map = {}
        self.check_source_columns_are_insertable(self.hf_schema_source.keys())
 
-    def _translate_row(self, row: dict[str, Any], split_name: str) -> dict[str, Any]:
+    def _translate_row(self, row: dict[str, Any], split_name: str, features: datasets.Features) -> dict[str, Any]:
        output_row: dict[str, Any] = {}
        for col_name, val in row.items():
            # translate category ints to strings
            new_val = self.categorical_features[col_name][val] if col_name in self.categorical_features else val
            mapped_col_name = self.source_column_map.get(col_name, col_name)
 
-            # Convert values to the appropriate type if needed
-            try:
-                checked_val = self.pxt_schema[mapped_col_name].create_literal(new_val)
-            except TypeError as e:
-                msg = str(e)
-                raise excs.Error(f'Error in column {col_name}: {msg[0].lower() + msg[1:]}\nRow: {row}') from e
-            output_row[mapped_col_name] = checked_val
+            new_val = self._translate_val(new_val, features[col_name])
+            output_row[mapped_col_name] = new_val
 
        # add split name to output row
        if self.column_name_for_split is not None:
            output_row[self.column_name_for_split] = split_name
        return output_row
 
+    def _translate_val(self, val: Any, feature: datasets.Feature) -> Any:
+        """Convert numpy scalars to Python types and images to PIL.Image.Image"""
+        import datasets
+
+        if isinstance(feature, datasets.Value):
+            if isinstance(val, (np.generic, np.ndarray)):
+                # a scalar, which we want as a standard Python type
+                assert np.ndim(val) == 0
+                return val.item()
+            else:
+                # a standard Python object
+                return val
+        elif isinstance(feature, datasets.Sequence):
+            assert np.ndim(val) > 0
+            return val
+        elif isinstance(feature, datasets.Image):
+            return PIL.Image.fromarray(val)
+        elif isinstance(feature, dict):
+            assert isinstance(val, dict)
+            return {k: self._translate_val(v, feature[k]) for k, v in val.items()}
+        else:
+            return val
+
    def valid_row_batch(self) -> Iterator[RowData]:
        for split_name, split_dataset in self.dataset_dict.items():
            num_batches = split_dataset.size_in_bytes / self._K_BATCH_SIZE_BYTES
@@ -443,7 +468,7 @@
 
            batch = []
            for row in split_dataset:
-                batch.append(self._translate_row(row, split_name))
+                batch.append(self._translate_row(row, split_name, split_dataset.features))
                if len(batch) >= tuples_per_batch:
                    yield batch
                    batch = []
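A quick illustration of what `_translate_val()` produces for the main feature kinds (sketch only):

    import numpy as np
    import PIL.Image

    np.int64(7).item()                                      # datasets.Value    -> plain Python 7
    np.array([1.0, 2.0])                                    # datasets.Sequence -> kept as an ndarray
    PIL.Image.fromarray(np.zeros((32, 32, 3), np.uint8))    # datasets.Image    -> PIL.Image.Image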
pixeltable/iterators/__init__.py CHANGED
@@ -1,3 +1,4 @@
+"""Iterators for splitting media and documents into components."""
 # ruff: noqa: F401
 
 from .audio import AudioSplitter
pixeltable/metadata/__init__.py CHANGED
@@ -18,7 +18,7 @@ _console_logger = ConsoleLogger(logging.getLogger('pixeltable'))
 _logger = logging.getLogger('pixeltable')
 
 # current version of the metadata; this is incremented whenever the metadata schema changes
-VERSION = 40
+VERSION = 41
 
 
 def create_system_info(engine: sql.engine.Engine) -> None:
pixeltable/metadata/converters/convert_40.py ADDED
@@ -0,0 +1,73 @@
+import logging
+from uuid import UUID
+
+import sqlalchemy as sql
+
+from pixeltable.metadata import register_converter
+from pixeltable.metadata.converters.util import convert_table_md
+
+_logger = logging.getLogger('pixeltable')
+
+
+@register_converter(version=40)
+def _(engine: sql.engine.Engine) -> None:
+    convert_table_md(engine, table_modifier=__table_modifier)
+
+
+def __table_modifier(conn: sql.Connection, tbl_id: UUID, orig_table_md: dict, updated_table_md: dict) -> None:
+    store_prefix = 'view' if orig_table_md['view_md'] is not None else 'tbl'
+    store_name = f'{store_prefix}_{tbl_id.hex}'
+
+    # Get the list of column names that need _cellmd columns
+    _logger.info(f'Checking table {orig_table_md["name"]} ({store_name})')
+    col_ids = find_target_columns(orig_table_md)
+    if len(col_ids) == 0:
+        _logger.info(f'No Array or Json columns found in table {orig_table_md["name"]}. Skipping migration.')
+        return
+
+    # Check which columns already exist in the table
+    check_columns_sql = sql.text(f"""
+        SELECT column_name
+        FROM information_schema.columns
+        WHERE table_name = '{store_name}'
+    """)
+    existing_columns = {row[0] for row in conn.execute(check_columns_sql)}
+
+    # Filter out columns that already have _cellmd
+    col_ids_to_add: list[int] = []
+    for col_id in col_ids:
+        cellmd_col = f'col_{col_id}_cellmd'
+        if cellmd_col not in existing_columns:
+            col_ids_to_add.append(col_id)
+        else:
+            _logger.info(f'Column {cellmd_col} already exists in table {orig_table_md["name"]}. Skipping.')
+
+    if len(col_ids_to_add) == 0:
+        _logger.info(f'All _cellmd columns already exist in table {orig_table_md["name"]}. Skipping migration.')
+        return
+
+    return add_cellmd_columns(conn, store_name, col_ids_to_add)
+
+
+def find_target_columns(table_md: dict) -> list[int]:
+    """Returns ids of stored array and json columns"""
+    result: list[int] = []
+    for col_id, col_md in table_md['column_md'].items():
+        col_type = col_md['col_type']
+        classname = col_type.get('_classname')
+        if classname in ['ArrayType', 'JsonType'] and col_md.get('stored', False):
+            result.append(col_id)
+            _logger.info(f'Found {classname} column: {col_id}')
+    return result
+
+
+def add_cellmd_columns(conn: sql.Connection, store_name: str, col_ids: list[int]) -> None:
+    try:
+        # Add new columns
+        add_column_str = ', '.join(f'ADD COLUMN col_{col_id}_cellmd JSONB DEFAULT NULL' for col_id in col_ids)
+        add_column_sql = sql.text(f'ALTER TABLE {store_name} {add_column_str}')
+        conn.execute(add_column_sql)
+        _logger.info(f'Added columns to {store_name}: {", ".join(f"col_{col_id}_cellmd" for col_id in col_ids)}')
+    except sql.exc.SQLAlchemyError as e:
+        _logger.error(f'Migration for table {store_name} failed: {e}')
+        raise
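For a table whose stored Array/Json columns have ids 3 and 5, `add_cellmd_columns()` builds a single statement along these lines (the store name is hypothetical):

    col_ids = [3, 5]
    store_name = 'tbl_0123456789abcdef0123456789abcdef'  # hypothetical
    stmt = f'ALTER TABLE {store_name} ' + ', '.join(
        f'ADD COLUMN col_{c}_cellmd JSONB DEFAULT NULL' for c in col_ids
    )
    # -> ALTER TABLE tbl_0123... ADD COLUMN col_3_cellmd JSONB DEFAULT NULL, ADD COLUMN col_5_cellmd JSONB DEFAULT NULL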
pixeltable/metadata/notes.py CHANGED
@@ -2,6 +2,7 @@
 # rather than as a comment, so that the existence of a description can be enforced by
 # the unit tests when new versions are added.
 VERSION_NOTES = {
+    41: 'Cellmd columns for array and json columns',
     40: 'Convert error property columns to cellmd columns',
     39: 'ColumnHandles in external stores',
     38: 'Added TableMd.view_sn',