datachain 0.8.13__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic; see the package's registry page for more details.

datachain/lib/file.py CHANGED
@@ -17,7 +17,7 @@ from urllib.parse import unquote, urlparse
17
17
  from urllib.request import url2pathname
18
18
 
19
19
  from fsspec.callbacks import DEFAULT_CALLBACK, Callback
20
- from PIL import Image
20
+ from PIL import Image as PilImage
21
21
  from pydantic import Field, field_validator
22
22
 
23
23
  from datachain.client.fileslice import FileSlice
@@ -27,6 +27,7 @@ from datachain.sql.types import JSON, Boolean, DateTime, Int, String
27
27
  from datachain.utils import TIME_ZERO
28
28
 
29
29
  if TYPE_CHECKING:
30
+ from numpy import ndarray
30
31
  from typing_extensions import Self
31
32
 
32
33
  from datachain.catalog import Catalog
@@ -40,7 +41,7 @@ logger = logging.getLogger("datachain")
40
41
  # how to create file path when exporting
41
42
  ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]
42
43
 
43
- FileType = Literal["binary", "text", "image"]
44
+ FileType = Literal["binary", "text", "image", "video"]
44
45
 
45
46
 
46
47
  class VFileError(DataChainError):
@@ -121,7 +122,21 @@ class VFileRegistry:
121
122
 
122
123
 
123
124
  class File(DataModel):
124
- """`DataModel` for reading binary files."""
125
+ """
126
+ `DataModel` for reading binary files.
127
+
128
+ Attributes:
129
+ source (str): The source of the file (e.g., 's3://bucket-name/').
130
+ path (str): The path to the file (e.g., 'path/to/file.txt').
131
+ size (int): The size of the file in bytes. Defaults to 0.
132
+ version (str): The version of the file. Defaults to an empty string.
133
+ etag (str): The ETag of the file. Defaults to an empty string.
134
+ is_latest (bool): Whether the file is the latest version. Defaults to `True`.
135
+ last_modified (datetime): The last modified timestamp of the file.
136
+ Defaults to Unix epoch (`1970-01-01T00:00:00`).
137
+ location (dict | list[dict], optional): The location of the file.
138
+ Defaults to `None`.
139
+ """
125
140
 
126
141
  source: str = Field(default="")
127
142
  path: str
@@ -193,16 +208,21 @@ class File(DataModel):
193
208
  @classmethod
194
209
  def upload(
195
210
  cls, data: bytes, path: str, catalog: Optional["Catalog"] = None
196
- ) -> "File":
211
+ ) -> "Self":
197
212
  if catalog is None:
198
213
  from datachain.catalog.loader import get_catalog
199
214
 
200
215
  catalog = get_catalog()
201
216
 
202
- parent, name = posixpath.split(path)
217
+ from datachain.client.fsspec import Client
203
218
 
204
- client = catalog.get_client(parent)
205
- file = client.upload(data, name)
219
+ client_cls = Client.get_implementation(path)
220
+ source, rel_path = client_cls.split_url(path)
221
+
222
+ client = catalog.get_client(client_cls.get_uri(source))
223
+ file = client.upload(data, rel_path)
224
+ if not isinstance(file, cls):
225
+ file = cls(**file.model_dump())
206
226
  file._set_stream(catalog)
207
227
  return file
208
228
 
@@ -486,13 +506,281 @@ class ImageFile(File):
486
506
  def read(self):
487
507
  """Returns `PIL.Image.Image` object."""
488
508
  fobj = super().read()
489
- return Image.open(BytesIO(fobj))
509
+ return PilImage.open(BytesIO(fobj))
490
510
 
491
511
  def save(self, destination: str):
492
512
  """Writes it's content to destination"""
493
513
  self.read().save(destination)
494
514
 
495
515
 
516
class Image(DataModel):
    """
    Metadata describing an image file.

    Attributes:
        width (int): Image width in pixels; defaults to -1 (unknown).
        height (int): Image height in pixels; defaults to -1 (unknown).
        format (str): Image file format such as 'jpg' or 'png'; defaults to
            an empty string.
    """

    width: int = -1
    height: int = -1
    format: str = ""
530
+
531
+
532
class VideoFile(File):
    """
    A data model for handling video files.

    Extends `File` with helpers that read video metadata, reference individual
    frames and split the video into time-based fragments. Frames and fragments
    are returned as lightweight reference models (`VideoFrame`,
    `VideoFragment`); no pixel data is decoded until one of their read/save
    methods is called.
    """

    def get_info(self) -> "Video":
        """
        Retrieves metadata and information about the video file.

        Returns:
            Video: A model containing video metadata such as duration,
                resolution, frame rate, and codec details.
        """
        from .video import video_info

        return video_info(self)

    def get_frame(self, frame: int) -> "VideoFrame":
        """
        Returns a reference to a specific video frame by its frame number.

        Args:
            frame (int): The frame number to reference.

        Returns:
            VideoFrame: Video frame model.

        Raises:
            ValueError: If `frame` is negative.
        """
        if frame < 0:
            raise ValueError("frame must be a non-negative integer")

        return VideoFrame(video=self, frame=frame)

    def get_frames(
        self,
        start: int = 0,
        end: Optional[int] = None,
        step: int = 1,
    ) -> "Iterator[VideoFrame]":
        """
        Returns video frames from the specified range in the video.

        Args:
            start (int): The starting frame number (default: 0).
            end (int, optional): The ending frame number (exclusive). If None,
                frames are read until the end of the video (default: None).
            step (int): The interval between frames to read (default: 1).

        Returns:
            Iterator[VideoFrame]: An iterator yielding video frames.

        Raises:
            ValueError: If the range is invalid. Because this method is a
                generator, validation runs when iteration starts, not when the
                method is called.

        Note:
            If end is not specified, the number of frames is taken from the
            video metadata, which means the video file needs to be downloaded.
        """
        from .video import validate_frame_range

        start, end, step = validate_frame_range(self, start, end, step)

        for frame in range(start, end, step):
            yield self.get_frame(frame)

    def get_fragment(self, start: float, end: float) -> "VideoFragment":
        """
        Returns a video fragment from the specified time range.

        Args:
            start (float): The start time of the fragment in seconds.
            end (float): The end time of the fragment in seconds.

        Returns:
            VideoFragment: A model representing the video fragment.

        Raises:
            ValueError: If either bound is negative or `start >= end`.
        """
        if start < 0 or end < 0 or start >= end:
            raise ValueError(f"Invalid time range: ({start:.3f}, {end:.3f})")

        return VideoFragment(video=self, start=start, end=end)

    def get_fragments(
        self,
        duration: float,
        start: float = 0,
        end: Optional[float] = None,
    ) -> "Iterator[VideoFragment]":
        """
        Splits the video into consecutive fragments of a specified duration.

        The last fragment is truncated to `end` if the range is not an exact
        multiple of `duration`.

        Args:
            duration (float): The duration of each video fragment in seconds.
            start (float): The starting time in seconds (default: 0).
            end (float, optional): The ending time in seconds. If None, the
                entire remaining video is processed (default: None).

        Returns:
            Iterator[VideoFragment]: An iterator yielding video fragments.

        Raises:
            ValueError: If `duration` is not positive, either bound is
                negative, or `start >= end`. Because this method is a
                generator, validation runs when iteration starts.

        Note:
            If end is not specified, the video duration is read from its
            metadata (`get_info()`), which means the video file needs to be
            downloaded.
        """
        if duration <= 0:
            raise ValueError("duration must be a positive float")
        if start < 0:
            raise ValueError("start must be a non-negative float")

        if end is None:
            end = self.get_info().duration

        if end < 0:
            raise ValueError("end must be a non-negative float")
        if start >= end:
            raise ValueError("start must be less than end")

        # Derive every boundary from `start` by multiplication rather than
        # accumulating `start += duration`. Repeated float addition drifts
        # (ten additions of 0.1 give 0.9999999999999999), which could emit a
        # spurious near-zero-length fragment just before `end`.
        index = 0
        fragment_start = start
        while fragment_start < end:
            fragment_end = min(fragment_start + duration, end)
            yield self.get_fragment(fragment_start, fragment_end)
            index += 1
            fragment_start = start + index * duration
652
+
653
+
654
class VideoFrame(DataModel):
    """
    A data model referencing a single frame of a video file.

    This is a standalone `DataModel`, not a subclass of `VideoFile`. It stores
    the source video in its `video` field together with a frame index.
    Pixel data is not stored; it is decoded from the video on demand by
    `get_np`, `read_bytes` and `save`.

    Attributes:
        video (VideoFile): The video file containing the video frame.
        frame (int): The frame number referencing a specific frame in the video
            file. This model does not validate it; `VideoFile.get_frame`
            rejects negative values.
    """

    video: VideoFile
    frame: int

    def get_np(self) -> "ndarray":
        """
        Decodes this frame from the video file into a NumPy array.

        Returns:
            ndarray: A NumPy array representing the video frame,
                in the shape (height, width, channels).
        """
        from .video import video_frame_np

        return video_frame_np(self.video, self.frame)

    def read_bytes(self, format: str = "jpg") -> bytes:
        """
        Decodes this frame and encodes it as image bytes.

        Args:
            format (str): The desired image format (e.g., 'jpg', 'png').
                Defaults to 'jpg'.

        Returns:
            bytes: The encoded video frame as image bytes.
        """
        from .video import video_frame_bytes

        return video_frame_bytes(self.video, self.frame, format)

    def save(self, output: str, format: str = "jpg") -> "ImageFile":
        """
        Saves this frame as an image file inside the `output` directory.

        The file is named `<video stem>_<frame, zero-padded to 4 digits>.<format>`
        and is written through `ImageFile.upload`, so `output` may be a local
        path or a remote storage URL.

        Args:
            output (str): The destination directory, which can be a local
                path or a remote URL.
            format (str): The image format (e.g., 'jpg', 'png'). Defaults to 'jpg'.

        Returns:
            ImageFile: A model representing the saved image file.
        """
        from .video import save_video_frame

        return save_video_frame(self.video, self.frame, output, format)
715
+
716
+
717
class VideoFragment(DataModel):
    """
    A data model referencing a time range within a video file.

    This is a standalone `DataModel`, not a subclass of `VideoFile`. It stores
    the source video in its `video` field together with `start`/`end` bounds
    in seconds. No video data is extracted until `save` is called.

    Attributes:
        video (VideoFile): The video file containing the video fragment.
        start (float): The starting time of the video fragment in seconds.
        end (float): The ending time of the video fragment in seconds.
            The model itself does not validate the range;
            `VideoFile.get_fragment` requires `0 <= start < end`.
    """

    video: VideoFile
    start: float
    end: float

    def save(self, output: str, format: Optional[str] = None) -> "VideoFile":
        """
        Cuts the fragment out of the video and saves it as a new video file.

        The file is written inside the `output` directory as
        `<video stem>_<start ms, 6 digits>_<end ms, 6 digits>.<format>` through
        `VideoFile.upload`, so `output` may be a local path or a remote URL.

        Args:
            output (str): The destination directory, which can be a local
                path or a remote URL.
            format (str, optional): The output video format (e.g., 'mp4', 'avi').
                If None, the source video's file extension is used.

        Returns:
            VideoFile: A model representing the saved video file.
        """
        from .video import save_video_fragment

        return save_video_fragment(self.video, self.start, self.end, output, format)
755
+
756
+
757
class Video(DataModel):
    """
    Metadata describing a video file.

    Fields not provided at construction fall back to -1 / -1.0 for numbers and
    to an empty string for text.

    Attributes:
        width (int): Frame width in pixels.
        height (int): Frame height in pixels.
        fps (float): Frame rate in frames per second.
        duration (float): Total duration of the video in seconds.
        frames (int): Total number of frames in the video.
        format (str): Format of the video file (e.g., 'mp4', 'avi').
        codec (str): Codec used to encode the video stream.
    """

    width: int = -1
    height: int = -1
    fps: float = -1.0
    duration: float = -1.0
    frames: int = -1
    format: str = ""
    codec: str = ""
782
+
783
+
496
784
  class ArrowRow(DataModel):
497
785
  """`DataModel` for reading row from Arrow-supported file."""
498
786
 
@@ -528,5 +816,7 @@ def get_file_type(type_: FileType = "binary") -> type[File]:
528
816
  file = TextFile
529
817
  elif type_ == "image":
530
818
  file = ImageFile # type: ignore[assignment]
819
+ elif type_ == "video":
820
+ file = VideoFile
531
821
 
532
822
  return file
datachain/lib/hf.py CHANGED
@@ -20,7 +20,7 @@ try:
20
20
 
21
21
  except ImportError as exc:
22
22
  raise ImportError(
23
- "Missing dependencies for huggingface datasets:\n"
23
+ "Missing dependencies for huggingface datasets.\n"
24
24
  "To install run:\n\n"
25
25
  " pip install 'datachain[hf]'\n"
26
26
  ) from exc
datachain/lib/video.py ADDED
@@ -0,0 +1,223 @@
1
+ import posixpath
2
+ import shutil
3
+ import tempfile
4
+ from typing import Optional
5
+
6
+ from numpy import ndarray
7
+
8
+ from datachain.lib.file import FileError, ImageFile, Video, VideoFile
9
+
10
+ try:
11
+ import ffmpeg
12
+ import imageio.v3 as iio
13
+ except ImportError as exc:
14
+ raise ImportError(
15
+ "Missing dependencies for processing video.\n"
16
+ "To install run:\n\n"
17
+ " pip install 'datachain[video]'\n"
18
+ ) from exc
19
+
20
+
21
def _parse_frame_rate(rate: str) -> float:
    """
    Converts an ffprobe rate string such as "30000/1001" or "25" to a float.

    A zero denominator (e.g. "0/0") yields 0.0 instead of raising
    ZeroDivisionError.
    """
    if "/" not in rate:
        return float(rate)
    num, denom = rate.split("/")
    denominator = float(denom)
    if denominator == 0:
        return 0.0
    return float(num) / denominator


def video_info(file: VideoFile) -> Video:
    """
    Returns video file information.

    The file is downloaded into the local cache if needed, then probed with
    ffprobe (via `ffmpeg.probe`). Values come from the first video stream and
    the container's format section.

    Args:
        file (VideoFile): Video file object.

    Returns:
        Video: Video file information. Width, height and duration are 0 when
            ffprobe omits them. When the stream has no `nb_frames`, the frame
            count is estimated as `(duration - start_time) * fps`.

    Raises:
        FileError: If the file cannot be downloaded, probed, or contains no
            video stream.
    """
    if not (file_path := file.get_local_path()):
        file.ensure_cached()
        file_path = file.get_local_path()
        if not file_path:
            raise FileError(file, "unable to download video file")

    try:
        probe = ffmpeg.probe(file_path)
    except Exception as exc:
        raise FileError(file, "unable to extract metadata from video file") from exc

    all_streams = probe.get("streams")
    video_format = probe.get("format")
    if not all_streams or not video_format:
        raise FileError(file, "unable to extract metadata from video file")

    # `.get` so that streams without a `codec_type` key are skipped, not fatal.
    video_streams = [s for s in all_streams if s.get("codec_type") == "video"]
    if not video_streams:
        raise FileError(file, "unable to extract metadata from video file")

    video_stream = video_streams[0]

    fps = _parse_frame_rate(video_stream.get("r_frame_rate", "0"))

    width = int(video_stream.get("width", 0))
    height = int(video_stream.get("height", 0))
    duration = float(video_format.get("duration", 0))
    if "nb_frames" in video_stream:
        frames = int(video_stream.get("nb_frames", 0))
    else:
        start_time = float(video_format.get("start_time", 0))
        frames = int((duration - start_time) * fps)
    format_name = video_format.get("format_name", "")
    codec_name = video_stream.get("codec_name", "")

    return Video(
        width=width,
        height=height,
        fps=fps,
        duration=duration,
        frames=frames,
        format=format_name,
        codec=codec_name,
    )
80
+
81
+
82
def video_frame_np(video: VideoFile, frame: int) -> ndarray:
    """
    Decodes a single frame of a video into a NumPy array.

    The video is opened as a stream and the frame is read with imageio's
    `pyav` plugin.

    Args:
        video (VideoFile): Video file object.
        frame (int): Frame index.

    Returns:
        ndarray: Video frame.

    Raises:
        ValueError: If `frame` is negative.
    """
    if frame < 0:
        raise ValueError("frame must be a non-negative integer")

    with video.open() as stream:
        pixels = iio.imread(stream, index=frame, plugin="pyav")  # type: ignore[arg-type]
    return pixels
98
+
99
+
100
def validate_frame_range(
    video: VideoFile,
    start: int = 0,
    end: Optional[int] = None,
    step: int = 1,
) -> tuple[int, int, int]:
    """
    Checks a frame range and resolves a missing end from the video metadata.

    Args:
        video (VideoFile): Video file object; only read when `end` is None.
        start (int): Start frame index (default: 0).
        end (int, optional): End frame index. When None, the frame count
            reported by `video_info` is used, which may download the video.
        step (int): Step between frames (default: 1).

    Returns:
        tuple[int, int, int]: Start frame index, end frame index, and step.

    Raises:
        ValueError: If `start` or the resolved end is negative, `step` is
            less than 1, or `start` is greater than the resolved end.
    """
    if start < 0:
        raise ValueError("start_frame must be a non-negative integer.")
    if step < 1:
        raise ValueError("step must be a positive integer.")

    resolved_end = video_info(video).frames if end is None else end

    if resolved_end < 0:
        raise ValueError("end_frame must be a non-negative integer.")
    if start > resolved_end:
        raise ValueError("start_frame must be less than or equal to end_frame.")

    return start, resolved_end, step
132
+
133
+
134
def video_frame_bytes(video: VideoFile, frame: int, format: str = "jpg") -> bytes:
    """
    Decodes one video frame and encodes it as image bytes.

    Args:
        video (VideoFile): Video file object.
        frame (int): Frame index.
        format (str): Image format, used as the encoder's file extension
            (default: 'jpg').

    Returns:
        bytes: Video frame image as bytes.
    """
    pixels = video_frame_np(video, frame)
    extension = f".{format}"
    return iio.imwrite("<bytes>", pixels, extension=extension)
148
+
149
+
150
def save_video_frame(
    video: VideoFile,
    frame: int,
    output: str,
    format: str = "jpg",
) -> ImageFile:
    """
    Encodes one video frame and uploads it as a new image file.

    The image is written inside `output` as
    `<video stem>_<frame, zero-padded to 4 digits>.<format>` through
    `ImageFile.upload`, so `output` may be local or remote storage.

    Args:
        video (VideoFile): Video file object.
        frame (int): Frame index.
        output (str): Output directory, can be a local path or a remote path.
        format (str): Image format (default: 'jpg').

    Returns:
        ImageFile: Image file model.
    """
    encoded = video_frame_bytes(video, frame, format=format)
    file_name = f"{video.get_file_stem()}_{frame:04d}.{format}"
    destination = posixpath.join(output, file_name)
    return ImageFile.upload(encoded, destination)
174
+
175
+
176
def save_video_fragment(
    video: VideoFile,
    start: float,
    end: float,
    output: str,
    format: Optional[str] = None,
) -> VideoFile:
    """
    Saves video interval as a new video file. If output is a remote path,
    the video file will be uploaded to the remote storage.

    The fragment is cut with ffmpeg into a temporary local file, then uploaded
    inside `output` as
    `<video stem>_<start ms, 6 digits>_<end ms, 6 digits>.<format>`.

    Args:
        video (VideoFile): Video file object.
        start (float): Start time in seconds.
        end (float): End time in seconds.
        output (str): Output directory, can be a local path or a remote path.
        format (str, optional): Output format (default: None). If not provided,
            the source video's file extension is used; ffmpeg infers the
            container from the resulting file extension.

    Returns:
        VideoFile: Video fragment model.

    Raises:
        ValueError: If either bound is negative or `start >= end`.
        FileError: If the source video cannot be downloaded to the local cache.
    """
    if start < 0 or end < 0 or start >= end:
        raise ValueError(f"Invalid time range: ({start:.3f}, {end:.3f})")

    if format is None:
        format = video.get_file_ext()

    start_ms = int(start * 1000)
    end_ms = int(end * 1000)
    output_file = posixpath.join(
        output, f"{video.get_file_stem()}_{start_ms:06d}_{end_ms:06d}.{format}"
    )

    # ffmpeg reads from the local filesystem. `get_local_path()` may be empty
    # when the video has not been cached yet, so download it first (the same
    # approach `video_info` uses) instead of handing ffmpeg a None input.
    if not (source_path := video.get_local_path()):
        video.ensure_cached()
        source_path = video.get_local_path()
        if not source_path:
            raise FileError(video, "unable to download video file")

    with tempfile.TemporaryDirectory() as temp_dir:
        output_file_tmp = posixpath.join(temp_dir, posixpath.basename(output_file))
        ffmpeg.input(
            source_path,
            ss=start,
            to=end,
        ).output(output_file_tmp).run(quiet=True)

        with open(output_file_tmp, "rb") as f:
            return VideoFile.upload(f.read(), output_file)
@@ -42,13 +42,17 @@ from datachain.data_storage.schema import (
42
42
  partition_col_names,
43
43
  partition_columns,
44
44
  )
45
- from datachain.dataset import DatasetStatus, RowDict
46
- from datachain.error import DatasetNotFoundError, QueryScriptCancelError
45
+ from datachain.dataset import DATASET_PREFIX, DatasetStatus, RowDict
46
+ from datachain.error import (
47
+ DatasetNotFoundError,
48
+ QueryScriptCancelError,
49
+ )
47
50
  from datachain.func.base import Function
48
51
  from datachain.lib.udf import UDFAdapter, _get_cache
49
52
  from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
50
53
  from datachain.query.schema import C, UDFParamSpec, normalize_param
51
54
  from datachain.query.session import Session
55
+ from datachain.remote.studio import is_token_set
52
56
  from datachain.sql.functions.random import rand
53
57
  from datachain.utils import (
54
58
  batched,
@@ -1081,6 +1085,7 @@ class DatasetQuery:
1081
1085
  session: Optional[Session] = None,
1082
1086
  indexing_column_types: Optional[dict[str, Any]] = None,
1083
1087
  in_memory: bool = False,
1088
+ fallback_to_remote: bool = True,
1084
1089
  ) -> None:
1085
1090
  self.session = Session.get(session, catalog=catalog, in_memory=in_memory)
1086
1091
  self.catalog = catalog or self.session.catalog
@@ -1097,7 +1102,12 @@ class DatasetQuery:
1097
1102
  self.column_types: Optional[dict[str, Any]] = None
1098
1103
 
1099
1104
  self.name = name
1100
- ds = self.catalog.get_dataset(name)
1105
+
1106
+ if fallback_to_remote and is_token_set():
1107
+ ds = self.catalog.get_dataset_with_remote_fallback(name, version)
1108
+ else:
1109
+ ds = self.catalog.get_dataset(name)
1110
+
1101
1111
  self.version = version or ds.latest_version
1102
1112
  self.feature_schema = ds.get_version(self.version).feature_schema
1103
1113
  self.column_types = copy(ds.schema)
@@ -1112,6 +1122,21 @@ class DatasetQuery:
1112
1122
  def __or__(self, other):
1113
1123
  return self.union(other)
1114
1124
 
1125
def pull_dataset(self, name: str, version: Optional[int] = None) -> "DatasetRecord":
    """
    Pulls a dataset from Studio into the local catalog and returns its record.

    Args:
        name (str): Dataset name, used both for the remote URI and as the
            local dataset name.
        version (int, optional): Version to pull. When falsy, no `@v<version>`
            suffix is added to the remote URI.

    Returns:
        DatasetRecord: The dataset as looked up in the local catalog after the
        pull.
    """
    # NOTE(review): this message is printed unconditionally, even when the
    # method is called directly rather than after a failed local lookup.
    print("Dataset not found in local catalog, trying to get from studio")

    version_suffix = f"@v{version}" if version else ""
    remote_ds_uri = f"{DATASET_PREFIX}{name}{version_suffix}"

    self.catalog.pull_dataset(
        remote_ds_uri=remote_ds_uri,
        local_ds_name=name,
        local_ds_version=version,
    )
    return self.catalog.get_dataset(name)
1139
+
1115
1140
  @staticmethod
1116
1141
  def get_table() -> "TableClause":
1117
1142
  table_name = "".join(