datachain 0.8.12__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic. Click here for more details.

datachain/lib/file.py CHANGED
@@ -17,7 +17,7 @@ from urllib.parse import unquote, urlparse
17
17
  from urllib.request import url2pathname
18
18
 
19
19
  from fsspec.callbacks import DEFAULT_CALLBACK, Callback
20
- from PIL import Image
20
+ from PIL import Image as PilImage
21
21
  from pydantic import Field, field_validator
22
22
 
23
23
  from datachain.client.fileslice import FileSlice
@@ -27,6 +27,7 @@ from datachain.sql.types import JSON, Boolean, DateTime, Int, String
27
27
  from datachain.utils import TIME_ZERO
28
28
 
29
29
  if TYPE_CHECKING:
30
+ from numpy import ndarray
30
31
  from typing_extensions import Self
31
32
 
32
33
  from datachain.catalog import Catalog
@@ -40,7 +41,7 @@ logger = logging.getLogger("datachain")
40
41
  # how to create file path when exporting
41
42
  ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]
42
43
 
43
- FileType = Literal["binary", "text", "image"]
44
+ FileType = Literal["binary", "text", "image", "video"]
44
45
 
45
46
 
46
47
  class VFileError(DataChainError):
@@ -121,7 +122,21 @@ class VFileRegistry:
121
122
 
122
123
 
123
124
  class File(DataModel):
124
- """`DataModel` for reading binary files."""
125
+ """
126
+ `DataModel` for reading binary files.
127
+
128
+ Attributes:
129
+ source (str): The source of the file (e.g., 's3://bucket-name/').
130
+ path (str): The path to the file (e.g., 'path/to/file.txt').
131
+ size (int): The size of the file in bytes. Defaults to 0.
132
+ version (str): The version of the file. Defaults to an empty string.
133
+ etag (str): The ETag of the file. Defaults to an empty string.
134
+ is_latest (bool): Whether the file is the latest version. Defaults to `True`.
135
+ last_modified (datetime): The last modified timestamp of the file.
136
+ Defaults to Unix epoch (`1970-01-01T00:00:00`).
137
+ location (dict | list[dict], optional): The location of the file.
138
+ Defaults to `None`.
139
+ """
125
140
 
126
141
  source: str = Field(default="")
127
142
  path: str
@@ -193,7 +208,7 @@ class File(DataModel):
193
208
  @classmethod
194
209
  def upload(
195
210
  cls, data: bytes, path: str, catalog: Optional["Catalog"] = None
196
- ) -> "File":
211
+ ) -> "Self":
197
212
  if catalog is None:
198
213
  from datachain.catalog.loader import get_catalog
199
214
 
@@ -203,6 +218,8 @@ class File(DataModel):
203
218
 
204
219
  client = catalog.get_client(parent)
205
220
  file = client.upload(data, name)
221
+ if not isinstance(file, cls):
222
+ file = cls(**file.model_dump())
206
223
  file._set_stream(catalog)
207
224
  return file
208
225
 
@@ -486,13 +503,281 @@ class ImageFile(File):
486
503
  def read(self):
487
504
  """Returns `PIL.Image.Image` object."""
488
505
  fobj = super().read()
489
- return Image.open(BytesIO(fobj))
506
+ return PilImage.open(BytesIO(fobj))
490
507
 
491
508
  def save(self, destination: str):
492
509
  """Writes its content to destination"""
493
510
  self.read().save(destination)
494
511
 
495
512
 
513
class Image(DataModel):
    """
    Metadata describing an image file.

    Attributes:
        width (int): Image width in pixels; -1 when not known.
        height (int): Image height in pixels; -1 when not known.
        format (str): Image file format, e.g. 'jpg' or 'png';
            an empty string when not known.
    """

    width: int = Field(default=-1)
    height: int = Field(default=-1)
    format: str = Field(default="")
527
+
528
+
529
class VideoFile(File):
    """
    A data model for handling video files.

    This model inherits from the `File` model and provides additional functionality
    for reading video files, extracting video frames, and splitting videos into
    fragments.
    """

    def get_info(self) -> "Video":
        """
        Retrieves metadata and information about the video file.

        Returns:
            Video: A Model containing video metadata such as duration,
                resolution, frame rate, and codec details.
        """
        # Imported lazily: `.video` imports this module and requires the
        # optional video dependencies.
        from .video import video_info

        return video_info(self)

    def get_frame(self, frame: int) -> "VideoFrame":
        """
        Returns a specific video frame by its frame number.

        Args:
            frame (int): The frame number to read.

        Returns:
            VideoFrame: Video frame model.

        Raises:
            ValueError: If `frame` is negative.
        """
        if frame < 0:
            raise ValueError("frame must be a non-negative integer")

        return VideoFrame(video=self, frame=frame)

    def get_frames(
        self,
        start: int = 0,
        end: Optional[int] = None,
        step: int = 1,
    ) -> "Iterator[VideoFrame]":
        """
        Returns video frames from the specified range in the video.

        Args:
            start (int): The starting frame number (default: 0).
            end (int, optional): The ending frame number (exclusive). If None,
                frames are read until the end of the video (default: None).
            step (int): The interval between frames to read (default: 1).

        Returns:
            Iterator[VideoFrame]: An iterator yielding video frames.

        Note:
            If end is not specified, number of frames will be taken from the
            video file, this means video file needs to be downloaded.

            This is a generator, so the arguments are validated when iteration
            starts, not when the method is called.
        """
        from .video import validate_frame_range

        start, end, step = validate_frame_range(self, start, end, step)

        for frame in range(start, end, step):
            yield self.get_frame(frame)

    def get_fragment(self, start: float, end: float) -> "VideoFragment":
        """
        Returns a video fragment from the specified time range.

        Args:
            start (float): The start time of the fragment in seconds.
            end (float): The end time of the fragment in seconds.

        Returns:
            VideoFragment: A Model representing the video fragment.

        Raises:
            ValueError: If either bound is negative or `start >= end`.
        """
        if start < 0 or end < 0 or start >= end:
            raise ValueError(f"Invalid time range: ({start:.3f}, {end:.3f})")

        return VideoFragment(video=self, start=start, end=end)

    def get_fragments(
        self,
        duration: float,
        start: float = 0,
        end: Optional[float] = None,
    ) -> "Iterator[VideoFragment]":
        """
        Splits the video into multiple fragments of a specified duration.

        The last fragment is shortened so that it never extends past `end`.

        Args:
            duration (float): The duration of each video fragment in seconds.
            start (float): The starting time in seconds (default: 0).
            end (float, optional): The ending time in seconds. If None, the entire
                remaining video is processed (default: None).

        Returns:
            Iterator[VideoFragment]: An iterator yielding video fragments.

        Raises:
            ValueError: If `duration` is not positive, `start` or `end` is
                negative, or `start >= end`.

        Note:
            If end is not specified, the duration will be taken from the video
            file, this means video file needs to be downloaded.

            This is a generator, so the arguments are validated when iteration
            starts, not when the method is called.
        """
        if duration <= 0:
            raise ValueError("duration must be a positive float")
        if start < 0:
            raise ValueError("start must be a non-negative float")

        if end is None:
            end = self.get_info().duration

        if end < 0:
            raise ValueError("end must be a non-negative float")
        if start >= end:
            raise ValueError("start must be less than end")

        # Derive each boundary from the fragment index instead of repeatedly
        # adding `duration`: accumulated float error (ten additions of 0.1
        # give 0.9999999999999999) would otherwise yield a spurious,
        # near-zero-length trailing fragment.
        index = 0
        fragment_start = start
        while fragment_start < end:
            fragment_end = min(fragment_start + duration, end)
            yield self.get_fragment(fragment_start, fragment_end)
            index += 1
            fragment_start = start + index * duration
649
+
650
+
651
class VideoFrame(DataModel):
    """
    A data model for representing a video frame.

    This model references a `VideoFile` through its `video` attribute (it does
    not inherit from `VideoFile`) and adds a `frame` attribute, which identifies
    a specific frame within that video file. It allows access to individual
    frames and provides functionality for reading and saving video frames as
    image files.

    Attributes:
        video (VideoFile): The video file containing the video frame.
        frame (int): The frame number referencing a specific frame in the video file.
    """

    video: VideoFile
    frame: int

    def get_np(self) -> "ndarray":
        """
        Returns a video frame from the video file as a NumPy array.

        Returns:
            ndarray: A NumPy array representing the video frame,
                in the shape (height, width, channels).
        """
        # Imported lazily: `.video` imports this module and requires the
        # optional video dependencies.
        from .video import video_frame_np

        return video_frame_np(self.video, self.frame)

    def read_bytes(self, format: str = "jpg") -> bytes:
        """
        Returns a video frame from the video file as image bytes.

        Args:
            format (str): The desired image format (e.g., 'jpg', 'png').
                Defaults to 'jpg'.

        Returns:
            bytes: The encoded video frame as image bytes.
        """
        from .video import video_frame_bytes

        return video_frame_bytes(self.video, self.frame, format)

    def save(self, output: str, format: str = "jpg") -> "ImageFile":
        """
        Saves the current video frame as an image file.

        The image is written under `output` using a file name built from the
        video's file stem, the zero-padded frame number and `format`.
        If `output` is a remote path, the image file will be uploaded to remote
        storage.

        Args:
            output (str): The destination location, which can be a local path
                or a remote URL.
            format (str): The image format (e.g., 'jpg', 'png'). Defaults to 'jpg'.

        Returns:
            ImageFile: A Model representing the saved image file.
        """
        from .video import save_video_frame

        return save_video_frame(self.video, self.frame, output, format)
712
+
713
+
714
class VideoFragment(DataModel):
    """
    A data model for representing a video fragment.

    This model references a `VideoFile` through its `video` attribute (it does
    not inherit from `VideoFile`) and adds `start` and `end` attributes, which
    delimit a specific fragment within that video file. It provides
    functionality for saving the fragment as a separate video file.

    Attributes:
        video (VideoFile): The video file containing the video fragment.
        start (float): The starting time of the video fragment in seconds.
        end (float): The ending time of the video fragment in seconds.
    """

    video: VideoFile
    start: float
    end: float

    def save(self, output: str, format: Optional[str] = None) -> "VideoFile":
        """
        Saves the video fragment as a new video file.

        The file is written under `output` using a file name built from the
        video's file stem, the start and end times in milliseconds and `format`.
        If `output` is a remote path, the video file will be uploaded to remote
        storage.

        Args:
            output (str): The destination location, which can be a local path
                or a remote URL.
            format (str, optional): The output video format (e.g., 'mp4', 'avi').
                If None, the file extension of the source video is used.

        Returns:
            VideoFile: A Model representing the saved video file.
        """
        # Imported lazily: `.video` imports this module and requires the
        # optional video dependencies.
        from .video import save_video_fragment

        return save_video_fragment(self.video, self.start, self.end, output, format)
752
+
753
+
754
class Video(DataModel):
    """
    Metadata describing a video file.

    Attributes:
        width (int): Frame width in pixels; -1 when not known.
        height (int): Frame height in pixels; -1 when not known.
        fps (float): Frame rate in frames per second; -1.0 when not known.
        duration (float): Total length of the video in seconds;
            -1.0 when not known.
        frames (int): Total number of frames; -1 when not known.
        format (str): Video file format, e.g. 'mp4' or 'avi';
            an empty string when not known.
        codec (str): Codec used to encode the video;
            an empty string when not known.
    """

    width: int = Field(default=-1)
    height: int = Field(default=-1)
    fps: float = Field(default=-1.0)
    duration: float = Field(default=-1.0)
    frames: int = Field(default=-1)
    format: str = Field(default="")
    codec: str = Field(default="")
779
+
780
+
496
781
  class ArrowRow(DataModel):
497
782
  """`DataModel` for reading row from Arrow-supported file."""
498
783
 
@@ -528,5 +813,7 @@ def get_file_type(type_: FileType = "binary") -> type[File]:
528
813
  file = TextFile
529
814
  elif type_ == "image":
530
815
  file = ImageFile # type: ignore[assignment]
816
+ elif type_ == "video":
817
+ file = VideoFile
531
818
 
532
819
  return file
datachain/lib/hf.py CHANGED
@@ -20,7 +20,7 @@ try:
20
20
 
21
21
  except ImportError as exc:
22
22
  raise ImportError(
23
- "Missing dependencies for huggingface datasets:\n"
23
+ "Missing dependencies for huggingface datasets.\n"
24
24
  "To install run:\n\n"
25
25
  " pip install 'datachain[hf]'\n"
26
26
  ) from exc
datachain/lib/video.py ADDED
@@ -0,0 +1,223 @@
1
+ import posixpath
2
+ import shutil
3
+ import tempfile
4
+ from typing import Optional
5
+
6
+ from numpy import ndarray
7
+
8
+ from datachain.lib.file import FileError, ImageFile, Video, VideoFile
9
+
10
+ try:
11
+ import ffmpeg
12
+ import imageio.v3 as iio
13
+ except ImportError as exc:
14
+ raise ImportError(
15
+ "Missing dependencies for processing video.\n"
16
+ "To install run:\n\n"
17
+ " pip install 'datachain[video]'\n"
18
+ ) from exc
19
+
20
+
21
def video_info(file: VideoFile) -> Video:
    """
    Returns video file information.

    The file is probed through `ffmpeg.probe` using its local path;
    `file.ensure_cached()` is called first when no local path is available.

    Args:
        file (VideoFile): Video file object.

    Returns:
        Video: Video file information. Values missing from the probe output are
            reported as 0 (or as an empty string for format and codec).

    Raises:
        FileError: If no local copy of the file can be obtained, probing fails,
            or the probe output has no format section or no video stream.
    """
    if not (file_path := file.get_local_path()):
        file.ensure_cached()
        file_path = file.get_local_path()
        if not file_path:
            raise FileError(file, "unable to download video file")

    try:
        probe = ffmpeg.probe(file_path)
    except Exception as exc:
        raise FileError(file, "unable to extract metadata from video file") from exc

    all_streams = probe.get("streams")
    video_format = probe.get("format")
    if not all_streams or not video_format:
        raise FileError(file, "unable to extract metadata from video file")

    # `.get` keeps a stream entry without "codec_type" from raising KeyError.
    video_streams = [s for s in all_streams if s.get("codec_type") == "video"]
    if not video_streams:
        raise FileError(file, "unable to extract metadata from video file")

    # Only the first video stream is described.
    video_stream = video_streams[0]

    r_frame_rate = video_stream.get("r_frame_rate", "0")
    if "/" in r_frame_rate:
        num, denom = r_frame_rate.split("/")
        denominator = float(denom)
        # A zero denominator (e.g. "0/0") means the rate is unknown; report
        # 0.0 instead of failing with ZeroDivisionError.
        fps = float(num) / denominator if denominator else 0.0
    else:
        fps = float(r_frame_rate)

    width = int(video_stream.get("width", 0))
    height = int(video_stream.get("height", 0))
    duration = float(video_format.get("duration", 0))
    if "nb_frames" in video_stream:
        frames = int(video_stream["nb_frames"])
    else:
        # No explicit frame count: estimate it from duration and frame rate.
        start_time = float(video_format.get("start_time", 0))
        frames = int((duration - start_time) * fps)
    format_name = video_format.get("format_name", "")
    codec_name = video_stream.get("codec_name", "")

    return Video(
        width=width,
        height=height,
        fps=fps,
        duration=duration,
        frames=frames,
        format=format_name,
        codec=codec_name,
    )
80
+
81
+
82
def video_frame_np(video: VideoFile, frame: int) -> ndarray:
    """
    Decodes a single frame of a video file into a numpy array.

    Args:
        video (VideoFile): Video file object.
        frame (int): Index of the frame to read.

    Returns:
        ndarray: The decoded video frame.

    Raises:
        ValueError: If `frame` is negative.
    """
    if frame < 0:
        raise ValueError("frame must be a non-negative integer")

    with video.open() as stream:
        decoded = iio.imread(stream, index=frame, plugin="pyav")  # type: ignore[arg-type]
    return decoded
98
+
99
+
100
def validate_frame_range(
    video: VideoFile,
    start: int = 0,
    end: Optional[int] = None,
    step: int = 1,
) -> tuple[int, int, int]:
    """
    Checks a frame range for a video file and fills in a missing end.

    Args:
        video (VideoFile): Video file object.
        start (int): Start frame index (default: 0).
        end (int, optional): End frame index (default: None). When None, the
            frame count reported by `video_info(video)` is used.
        step (int): Step between frames (default: 1).

    Returns:
        tuple[int, int, int]: Start frame index, end frame index, and step.

    Raises:
        ValueError: If `start` or the resolved end is negative, `step` is
            smaller than 1, or `start` is greater than the resolved end.
    """
    if start < 0:
        raise ValueError("start_frame must be a non-negative integer.")
    if step < 1:
        raise ValueError("step must be a positive integer.")

    # The video is only inspected when the caller did not supply an end.
    last = video_info(video).frames if end is None else end

    if last < 0:
        raise ValueError("end_frame must be a non-negative integer.")
    if start > last:
        raise ValueError("start_frame must be less than or equal to end_frame.")

    return (start, last, step)
132
+
133
+
134
def video_frame_bytes(video: VideoFile, frame: int, format: str = "jpg") -> bytes:
    """
    Decodes a single video frame and encodes it as image bytes.

    Args:
        video (VideoFile): Video file object.
        frame (int): Index of the frame to read.
        format (str): Image format, used as the file extension handed to the
            image writer (default: 'jpg').

    Returns:
        bytes: The encoded frame image.
    """
    frame_array = video_frame_np(video, frame)
    extension = f".{format}"
    return iio.imwrite("<bytes>", frame_array, extension=extension)
148
+
149
+
150
def save_video_frame(
    video: VideoFile,
    frame: int,
    output: str,
    format: str = "jpg",
) -> ImageFile:
    """
    Stores a single video frame as a new image file.

    The image is named `<video stem>_<frame, zero-padded to 4 digits>.<format>`
    and placed under `output`. If output is a remote path, the image file will
    be uploaded to the remote storage.

    Args:
        video (VideoFile): Video file object.
        frame (int): Index of the frame to save.
        output (str): Output location, can be a local path or a remote path.
        format (str): Image format (default: 'jpg').

    Returns:
        ImageFile: Image file model.
    """
    encoded = video_frame_bytes(video, frame, format=format)
    filename = f"{video.get_file_stem()}_{frame:04d}.{format}"
    return ImageFile.upload(encoded, posixpath.join(output, filename))
174
+
175
+
176
def save_video_fragment(
    video: VideoFile,
    start: float,
    end: float,
    output: str,
    format: Optional[str] = None,
) -> VideoFile:
    """
    Saves video interval as a new video file. If output is a remote path,
    the video file will be uploaded to the remote storage.

    The fragment is named `<video stem>_<start ms>_<end ms>.<format>` (times
    zero-padded to 6 digits) and placed under `output`. It is first written to
    a temporary directory, which is removed afterwards.

    Args:
        video (VideoFile): Video file object.
        start (float): Start time in seconds.
        end (float): End time in seconds.
        output (str): Output location, can be a local path or a remote path.
        format (str, optional): Output format (default: None). If not provided,
            the file extension of the source video is used.

    Returns:
        VideoFile: Video fragment model.

    Raises:
        ValueError: If either bound is negative or `start >= end`.
        FileError: If no local copy of the video can be obtained.
    """
    if start < 0 or end < 0 or start >= end:
        raise ValueError(f"Invalid time range: ({start:.3f}, {end:.3f})")

    if format is None:
        format = video.get_file_ext()

    # ffmpeg reads from a local path; make sure one exists instead of handing
    # it an empty path when the video has not been cached yet.
    if not (input_path := video.get_local_path()):
        video.ensure_cached()
        input_path = video.get_local_path()
        if not input_path:
            raise FileError(video, "unable to download video file")

    start_ms = int(start * 1000)
    end_ms = int(end * 1000)
    output_file = posixpath.join(
        output, f"{video.get_file_stem()}_{start_ms:06d}_{end_ms:06d}.{format}"
    )

    temp_dir = tempfile.mkdtemp()
    try:
        output_file_tmp = posixpath.join(temp_dir, posixpath.basename(output_file))
        ffmpeg.input(
            input_path,
            ss=start,
            to=end,
        ).output(output_file_tmp).run(quiet=True)

        with open(output_file_tmp, "rb") as f:
            return VideoFile.upload(f.read(), output_file)
    finally:
        # Always drop the temporary directory, even when ffmpeg or upload fails.
        shutil.rmtree(temp_dir)
@@ -42,13 +42,17 @@ from datachain.data_storage.schema import (
42
42
  partition_col_names,
43
43
  partition_columns,
44
44
  )
45
- from datachain.dataset import DatasetStatus, RowDict
46
- from datachain.error import DatasetNotFoundError, QueryScriptCancelError
45
+ from datachain.dataset import DATASET_PREFIX, DatasetStatus, RowDict
46
+ from datachain.error import (
47
+ DatasetNotFoundError,
48
+ QueryScriptCancelError,
49
+ )
47
50
  from datachain.func.base import Function
48
51
  from datachain.lib.udf import UDFAdapter, _get_cache
49
52
  from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
50
53
  from datachain.query.schema import C, UDFParamSpec, normalize_param
51
54
  from datachain.query.session import Session
55
+ from datachain.remote.studio import is_token_set
52
56
  from datachain.sql.functions.random import rand
53
57
  from datachain.utils import (
54
58
  batched,
@@ -1081,6 +1085,7 @@ class DatasetQuery:
1081
1085
  session: Optional[Session] = None,
1082
1086
  indexing_column_types: Optional[dict[str, Any]] = None,
1083
1087
  in_memory: bool = False,
1088
+ fallback_to_remote: bool = True,
1084
1089
  ) -> None:
1085
1090
  self.session = Session.get(session, catalog=catalog, in_memory=in_memory)
1086
1091
  self.catalog = catalog or self.session.catalog
@@ -1097,7 +1102,12 @@ class DatasetQuery:
1097
1102
  self.column_types: Optional[dict[str, Any]] = None
1098
1103
 
1099
1104
  self.name = name
1100
- ds = self.catalog.get_dataset(name)
1105
+
1106
+ if fallback_to_remote and is_token_set():
1107
+ ds = self.catalog.get_dataset_with_remote_fallback(name, version)
1108
+ else:
1109
+ ds = self.catalog.get_dataset(name)
1110
+
1101
1111
  self.version = version or ds.latest_version
1102
1112
  self.feature_schema = ds.get_version(self.version).feature_schema
1103
1113
  self.column_types = copy(ds.schema)
@@ -1112,6 +1122,21 @@ class DatasetQuery:
1112
1122
  def __or__(self, other):
1113
1123
  return self.union(other)
1114
1124
 
1125
def pull_dataset(self, name: str, version: Optional[int] = None) -> "DatasetRecord":
    """
    Pulls a dataset through the catalog and returns the local record.

    The remote URI is `DATASET_PREFIX` followed by `name`, with `@v<version>`
    appended when a truthy `version` is given. The dataset is stored locally
    under the same name and version.

    Args:
        name (str): Dataset name, used for both the remote and local dataset.
        version (int, optional): Dataset version to pull (default: None).

    Returns:
        DatasetRecord: Result of `self.catalog.get_dataset(name)` after the pull.
    """
    print("Dataset not found in local catalog, trying to get from studio")

    version_suffix = f"@v{version}" if version else ""
    self.catalog.pull_dataset(
        remote_ds_uri=f"{DATASET_PREFIX}{name}{version_suffix}",
        local_ds_name=name,
        local_ds_version=version,
    )

    return self.catalog.get_dataset(name)
1139
+
1115
1140
  @staticmethod
1116
1141
  def get_table() -> "TableClause":
1117
1142
  table_name = "".join(