datachain 0.8.12__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic. Click here for more details.
- datachain/__init__.py +10 -0
- datachain/catalog/catalog.py +32 -9
- datachain/cli/__init__.py +2 -0
- datachain/cli/commands/datasets.py +78 -12
- datachain/cli/parser/__init__.py +62 -12
- datachain/cli/parser/job.py +14 -4
- datachain/cli/parser/studio.py +8 -0
- datachain/cli/parser/utils.py +20 -1
- datachain/dataset.py +7 -4
- datachain/diff/__init__.py +78 -128
- datachain/fs/reference.py +21 -0
- datachain/func/__init__.py +5 -2
- datachain/func/array.py +39 -1
- datachain/func/conditional.py +66 -2
- datachain/job.py +1 -1
- datachain/lib/arrow.py +1 -11
- datachain/lib/dc.py +2 -0
- datachain/lib/file.py +292 -5
- datachain/lib/hf.py +1 -1
- datachain/lib/video.py +223 -0
- datachain/query/dataset.py +28 -3
- datachain/remote/studio.py +13 -6
- datachain/sql/functions/array.py +13 -1
- datachain/sql/sqlite/base.py +17 -1
- datachain/sql/sqlite/types.py +5 -0
- datachain/studio.py +34 -12
- datachain/utils.py +12 -2
- {datachain-0.8.12.dist-info → datachain-0.9.0.dist-info}/METADATA +13 -5
- {datachain-0.8.12.dist-info → datachain-0.9.0.dist-info}/RECORD +34 -32
- /datachain/{lib/vfile.py → fs/__init__.py} +0 -0
- {datachain-0.8.12.dist-info → datachain-0.9.0.dist-info}/LICENSE +0 -0
- {datachain-0.8.12.dist-info → datachain-0.9.0.dist-info}/WHEEL +0 -0
- {datachain-0.8.12.dist-info → datachain-0.9.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.8.12.dist-info → datachain-0.9.0.dist-info}/top_level.txt +0 -0
datachain/lib/file.py
CHANGED
|
@@ -17,7 +17,7 @@ from urllib.parse import unquote, urlparse
|
|
|
17
17
|
from urllib.request import url2pathname
|
|
18
18
|
|
|
19
19
|
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
|
|
20
|
-
from PIL import Image
|
|
20
|
+
from PIL import Image as PilImage
|
|
21
21
|
from pydantic import Field, field_validator
|
|
22
22
|
|
|
23
23
|
from datachain.client.fileslice import FileSlice
|
|
@@ -27,6 +27,7 @@ from datachain.sql.types import JSON, Boolean, DateTime, Int, String
|
|
|
27
27
|
from datachain.utils import TIME_ZERO
|
|
28
28
|
|
|
29
29
|
if TYPE_CHECKING:
|
|
30
|
+
from numpy import ndarray
|
|
30
31
|
from typing_extensions import Self
|
|
31
32
|
|
|
32
33
|
from datachain.catalog import Catalog
|
|
@@ -40,7 +41,7 @@ logger = logging.getLogger("datachain")
|
|
|
40
41
|
# how to create file path when exporting
|
|
41
42
|
ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]
|
|
42
43
|
|
|
43
|
-
FileType = Literal["binary", "text", "image"]
|
|
44
|
+
FileType = Literal["binary", "text", "image", "video"]
|
|
44
45
|
|
|
45
46
|
|
|
46
47
|
class VFileError(DataChainError):
|
|
@@ -121,7 +122,21 @@ class VFileRegistry:
|
|
|
121
122
|
|
|
122
123
|
|
|
123
124
|
class File(DataModel):
|
|
124
|
-
"""
|
|
125
|
+
"""
|
|
126
|
+
`DataModel` for reading binary files.
|
|
127
|
+
|
|
128
|
+
Attributes:
|
|
129
|
+
source (str): The source of the file (e.g., 's3://bucket-name/').
|
|
130
|
+
path (str): The path to the file (e.g., 'path/to/file.txt').
|
|
131
|
+
size (int): The size of the file in bytes. Defaults to 0.
|
|
132
|
+
version (str): The version of the file. Defaults to an empty string.
|
|
133
|
+
etag (str): The ETag of the file. Defaults to an empty string.
|
|
134
|
+
is_latest (bool): Whether the file is the latest version. Defaults to `True`.
|
|
135
|
+
last_modified (datetime): The last modified timestamp of the file.
|
|
136
|
+
Defaults to Unix epoch (`1970-01-01T00:00:00`).
|
|
137
|
+
location (dict | list[dict], optional): The location of the file.
|
|
138
|
+
Defaults to `None`.
|
|
139
|
+
"""
|
|
125
140
|
|
|
126
141
|
source: str = Field(default="")
|
|
127
142
|
path: str
|
|
@@ -193,7 +208,7 @@ class File(DataModel):
|
|
|
193
208
|
@classmethod
|
|
194
209
|
def upload(
|
|
195
210
|
cls, data: bytes, path: str, catalog: Optional["Catalog"] = None
|
|
196
|
-
) -> "
|
|
211
|
+
) -> "Self":
|
|
197
212
|
if catalog is None:
|
|
198
213
|
from datachain.catalog.loader import get_catalog
|
|
199
214
|
|
|
@@ -203,6 +218,8 @@ class File(DataModel):
|
|
|
203
218
|
|
|
204
219
|
client = catalog.get_client(parent)
|
|
205
220
|
file = client.upload(data, name)
|
|
221
|
+
if not isinstance(file, cls):
|
|
222
|
+
file = cls(**file.model_dump())
|
|
206
223
|
file._set_stream(catalog)
|
|
207
224
|
return file
|
|
208
225
|
|
|
@@ -486,13 +503,281 @@ class ImageFile(File):
|
|
|
486
503
|
def read(self):
    """Return the file content opened as a `PIL.Image.Image` object."""
    # The parent `read()` result is wrapped in an in-memory buffer so that
    # PIL can open it like a file.
    return PilImage.open(BytesIO(super().read()))
|
|
490
507
|
|
|
491
508
|
def save(self, destination: str):
    """Write the image returned by `read()` to `destination`."""
    self.read().save(destination)
|
|
494
511
|
|
|
495
512
|
|
|
513
|
+
class Image(DataModel):
    """
    Metadata describing an image file.

    Attributes:
        width (int): Image width in pixels; -1 when unknown.
        height (int): Image height in pixels; -1 when unknown.
        format (str): Image file format (e.g., 'jpg', 'png');
                      an empty string when unknown.
    """

    width: int = -1
    height: int = -1
    format: str = ""
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
class VideoFile(File):
    """
    A data model for handling video files.

    This model inherits from the `File` model and provides additional functionality
    for reading video metadata, addressing individual video frames, and splitting
    videos into time-based fragments.
    """

    def get_info(self) -> "Video":
        """
        Retrieves metadata and information about the video file.

        Returns:
            Video: A Model containing video metadata such as duration,
                   resolution, frame rate, and codec details.
        """
        # Imported lazily: `.video` raises ImportError when the optional
        # video dependencies are not installed.
        from .video import video_info

        return video_info(self)

    def get_frame(self, frame: int) -> "VideoFrame":
        """
        Returns a specific video frame by its frame number.

        Args:
            frame (int): The frame number to read.

        Returns:
            VideoFrame: Video frame model.

        Raises:
            ValueError: If `frame` is negative.
        """
        if frame < 0:
            raise ValueError("frame must be a non-negative integer")

        return VideoFrame(video=self, frame=frame)

    def get_frames(
        self,
        start: int = 0,
        end: Optional[int] = None,
        step: int = 1,
    ) -> "Iterator[VideoFrame]":
        """
        Returns video frames from the specified range in the video.

        Args:
            start (int): The starting frame number (default: 0).
            end (int, optional): The ending frame number (exclusive). If None,
                                 frames are read until the end of the video
                                 (default: None).
            step (int): The interval between frames to read (default: 1).

        Returns:
            Iterator[VideoFrame]: An iterator yielding video frames.

        Note:
            If end is not specified, number of frames will be taken from the video file,
            this means video file needs to be downloaded.
        """
        from .video import validate_frame_range

        start, end, step = validate_frame_range(self, start, end, step)

        for frame in range(start, end, step):
            yield self.get_frame(frame)

    def get_fragment(self, start: float, end: float) -> "VideoFragment":
        """
        Returns a video fragment from the specified time range.

        Args:
            start (float): The start time of the fragment in seconds.
            end (float): The end time of the fragment in seconds.

        Returns:
            VideoFragment: A Model representing the video fragment.

        Raises:
            ValueError: If either bound is negative or `start >= end`.
        """
        if start < 0 or end < 0 or start >= end:
            raise ValueError(f"Invalid time range: ({start:.3f}, {end:.3f})")

        return VideoFragment(video=self, start=start, end=end)

    def get_fragments(
        self,
        duration: float,
        start: float = 0,
        end: Optional[float] = None,
    ) -> "Iterator[VideoFragment]":
        """
        Splits the video into multiple fragments of a specified duration.

        The last fragment is shortened so that it never extends past `end`.

        Args:
            duration (float): The duration of each video fragment in seconds.
            start (float): The starting time in seconds (default: 0).
            end (float, optional): The ending time in seconds. If None, the entire
                                   remaining video is processed (default: None).

        Returns:
            Iterator[VideoFragment]: An iterator yielding video fragments.

        Raises:
            ValueError: If `duration` is not positive, `start` or `end` is
                        negative, or `start >= end`.

        Note:
            If end is not specified, the video duration will be taken from the
            video file, this means video file needs to be downloaded.
        """
        if duration <= 0:
            raise ValueError("duration must be a positive float")
        if start < 0:
            raise ValueError("start must be a non-negative float")

        if end is None:
            end = self.get_info().duration

        if end < 0:
            raise ValueError("end must be a non-negative float")
        if start >= end:
            raise ValueError("start must be less than end")

        # Derive every boundary from the fragment index instead of repeatedly
        # adding `duration`: accumulated float error (e.g. ten additions of 0.1
        # giving 0.9999999999999999) would otherwise produce a spurious,
        # near-zero-length trailing fragment.
        index = 0
        fragment_start = start
        while fragment_start < end:
            fragment_end = min(start + (index + 1) * duration, end)
            yield self.get_fragment(fragment_start, fragment_end)
            index += 1
            fragment_start = start + index * duration
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
class VideoFrame(DataModel):
    """
    A data model for representing a single frame of a video.

    The model holds a reference to the source `VideoFile` together with a
    frame number, and delegates reading and saving of that frame to the
    helpers in the sibling `video` module.

    Attributes:
        video (VideoFile): The video file containing the video frame.
        frame (int): The frame number referencing a specific frame in the video file.
    """

    video: VideoFile
    frame: int

    def get_np(self) -> "ndarray":
        """
        Reads this frame from the video file as a NumPy array.

        Returns:
            ndarray: A NumPy array representing the video frame,
                     in the shape (height, width, channels).
        """
        from .video import video_frame_np as read_frame_array

        source, index = self.video, self.frame
        return read_frame_array(source, index)

    def read_bytes(self, format: str = "jpg") -> bytes:
        """
        Reads this frame from the video file as encoded image bytes.

        Args:
            format (str): The desired image format (e.g., 'jpg', 'png').
                          Defaults to 'jpg'.

        Returns:
            bytes: The encoded video frame as image bytes.
        """
        from .video import video_frame_bytes as encode_frame

        source, index = self.video, self.frame
        return encode_frame(source, index, format)

    def save(self, output: str, format: str = "jpg") -> "ImageFile":
        """
        Saves this video frame as an image file.

        If `output` is a remote path, the image file will be uploaded to remote storage.

        Args:
            output (str): The destination path, which can be a local file path
                          or a remote URL.
            format (str): The image format (e.g., 'jpg', 'png'). Defaults to 'jpg'.

        Returns:
            ImageFile: A Model representing the saved image file.
        """
        from .video import save_video_frame as store_frame

        source, index = self.video, self.frame
        return store_frame(source, index, output, format)
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
class VideoFragment(DataModel):
    """
    A data model for representing a time-bounded fragment of a video.

    The model holds a reference to the source `VideoFile` together with the
    `start` and `end` times of the fragment, and delegates saving of the
    fragment as a separate video file to the sibling `video` module.

    Attributes:
        video (VideoFile): The video file containing the video fragment.
        start (float): The starting time of the video fragment in seconds.
        end (float): The ending time of the video fragment in seconds.
    """

    video: VideoFile
    start: float
    end: float

    def save(self, output: str, format: Optional[str] = None) -> "VideoFile":
        """
        Saves the video fragment as a new video file.

        If `output` is a remote path, the video file will be uploaded to remote storage.

        Args:
            output (str): The destination path, which can be a local file path
                          or a remote URL.
            format (str, optional): The output video format (e.g., 'mp4', 'avi').
                                    If None, the format is inferred from the
                                    file extension.

        Returns:
            VideoFile: A Model representing the saved video file.
        """
        from .video import save_video_fragment as store_fragment

        source = self.video
        begin, finish = self.start, self.end
        return store_fragment(source, begin, finish, output, format)
|
|
752
|
+
|
|
753
|
+
|
|
754
|
+
class Video(DataModel):
    """
    Metadata describing a video file.

    Attributes:
        width (int): Video width in pixels; -1 when unknown.
        height (int): Video height in pixels; -1 when unknown.
        fps (float): Frame rate in frames per second; -1.0 when unknown.
        duration (float): Total duration in seconds; -1.0 when unknown.
        frames (int): Total number of frames; -1 when unknown.
        format (str): Video file format (e.g., 'mp4', 'avi');
                      an empty string when unknown.
        codec (str): Codec used for encoding the video;
                     an empty string when unknown.
    """

    width: int = -1
    height: int = -1
    fps: float = -1.0
    duration: float = -1.0
    frames: int = -1
    format: str = ""
    codec: str = ""
|
|
779
|
+
|
|
780
|
+
|
|
496
781
|
class ArrowRow(DataModel):
|
|
497
782
|
"""`DataModel` for reading row from Arrow-supported file."""
|
|
498
783
|
|
|
@@ -528,5 +813,7 @@ def get_file_type(type_: FileType = "binary") -> type[File]:
|
|
|
528
813
|
file = TextFile
|
|
529
814
|
elif type_ == "image":
|
|
530
815
|
file = ImageFile # type: ignore[assignment]
|
|
816
|
+
elif type_ == "video":
|
|
817
|
+
file = VideoFile
|
|
531
818
|
|
|
532
819
|
return file
|
datachain/lib/hf.py
CHANGED
datachain/lib/video.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
import posixpath
|
|
2
|
+
import shutil
|
|
3
|
+
import tempfile
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from numpy import ndarray
|
|
7
|
+
|
|
8
|
+
from datachain.lib.file import FileError, ImageFile, Video, VideoFile
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
import ffmpeg
|
|
12
|
+
import imageio.v3 as iio
|
|
13
|
+
except ImportError as exc:
|
|
14
|
+
raise ImportError(
|
|
15
|
+
"Missing dependencies for processing video.\n"
|
|
16
|
+
"To install run:\n\n"
|
|
17
|
+
" pip install 'datachain[video]'\n"
|
|
18
|
+
) from exc
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def video_info(file: VideoFile) -> Video:
    """
    Returns video file information.

    The file is cached locally first if no local path is available, then
    probed with `ffmpeg.probe`. Metadata is taken from the first stream whose
    `codec_type` is "video" and from the container `format` section.

    Args:
        file (VideoFile): Video file object.

    Returns:
        Video: Video file information.

    Raises:
        FileError: If the file cannot be made available locally, probing
                   fails, or the probe result has no video stream or no
                   format section.
    """
    if not (file_path := file.get_local_path()):
        file.ensure_cached()
        file_path = file.get_local_path()
        if not file_path:
            raise FileError(file, "unable to download video file")

    try:
        probe = ffmpeg.probe(file_path)
    except Exception as exc:
        raise FileError(file, "unable to extract metadata from video file") from exc

    all_streams = probe.get("streams")
    video_format = probe.get("format")
    if not all_streams or not video_format:
        raise FileError(file, "unable to extract metadata from video file")

    # `.get` keeps a stream entry without "codec_type" from raising KeyError;
    # such an entry is simply not a video stream.
    video_stream = next(
        (s for s in all_streams if s.get("codec_type") == "video"), None
    )
    if video_stream is None:
        raise FileError(file, "unable to extract metadata from video file")

    fps = _parse_frame_rate(video_stream.get("r_frame_rate", "0"))

    width = int(video_stream.get("width", 0))
    height = int(video_stream.get("height", 0))
    duration = float(video_format.get("duration", 0))
    if "nb_frames" in video_stream:
        frames = int(video_stream.get("nb_frames", 0))
    else:
        # No explicit frame count: estimate it from play time and frame rate.
        start_time = float(video_format.get("start_time", 0))
        frames = int((duration - start_time) * fps)
    format_name = video_format.get("format_name", "")
    codec_name = video_stream.get("codec_name", "")

    return Video(
        width=width,
        height=height,
        fps=fps,
        duration=duration,
        frames=frames,
        format=format_name,
        codec=codec_name,
    )


def _parse_frame_rate(rate: str) -> float:
    """Convert a frame rate given as "num/denom" or as a plain number to a float."""
    if "/" not in rate:
        return float(rate)

    num, denom = rate.split("/")
    denominator = float(denom)
    # A zero denominator (e.g. "0/0") would raise ZeroDivisionError; report 0.0,
    # the same value a missing frame rate produces.
    return float(num) / denominator if denominator else 0.0
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def video_frame_np(video: VideoFile, frame: int) -> ndarray:
    """
    Reads a single frame from a video file as a numpy array.

    Args:
        video (VideoFile): Video file object.
        frame (int): Frame index.

    Returns:
        ndarray: Video frame.

    Raises:
        ValueError: If `frame` is negative.
    """
    if frame < 0:
        raise ValueError("frame must be a non-negative integer")

    with video.open() as stream:
        frame_array = iio.imread(stream, index=frame, plugin="pyav")  # type: ignore[arg-type]
    return frame_array
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def validate_frame_range(
    video: VideoFile,
    start: int = 0,
    end: Optional[int] = None,
    step: int = 1,
) -> tuple[int, int, int]:
    """
    Checks a frame range for a video file and fills in a missing end.

    Args:
        video (VideoFile): Video file object.
        start (int): Start frame index (default: 0).
        end (int, optional): End frame index (default: None). If None, the
                             frame count reported by `video_info` is used.
        step (int): Step between frames (default: 1).

    Returns:
        tuple[int, int, int]: Start frame index, end frame index, and step.

    Raises:
        ValueError: If `start` or the end is negative, `step` is less than 1,
                    or `start` is greater than the end.
    """
    if start < 0:
        raise ValueError("start_frame must be a non-negative integer.")
    if step < 1:
        raise ValueError("step must be a positive integer.")

    # The video is only probed when the caller gave no explicit end.
    last_frame = video_info(video).frames if end is None else end

    if last_frame < 0:
        raise ValueError("end_frame must be a non-negative integer.")
    if start > last_frame:
        raise ValueError("start_frame must be less than or equal to end_frame.")

    return (start, last_frame, step)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def video_frame_bytes(video: VideoFile, frame: int, format: str = "jpg") -> bytes:
    """
    Reads a single frame from a video file and encodes it as image bytes.

    Args:
        video (VideoFile): Video file object.
        frame (int): Frame index.
        format (str): Image format (default: 'jpg').

    Returns:
        bytes: Video frame image as bytes.
    """
    frame_array = video_frame_np(video, frame)
    extension = f".{format}"
    return iio.imwrite("<bytes>", frame_array, extension=extension)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def save_video_frame(
    video: VideoFile,
    frame: int,
    output: str,
    format: str = "jpg",
) -> ImageFile:
    """
    Stores a video frame as a new image file. If output is a remote path,
    the image file will be uploaded to the remote storage.

    The image is named `<video stem>_<frame, zero-padded to 4 digits>.<format>`
    and placed under `output`.

    Args:
        video (VideoFile): Video file object.
        frame (int): Frame index.
        output (str): Output path, can be a local path or a remote path.
        format (str): Image format (default: 'jpg').

    Returns:
        ImageFile: Image file model.
    """
    encoded = video_frame_bytes(video, frame, format=format)
    file_name = f"{video.get_file_stem()}_{frame:04d}.{format}"
    destination = posixpath.join(output, file_name)
    return ImageFile.upload(encoded, destination)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def save_video_fragment(
    video: VideoFile,
    start: float,
    end: float,
    output: str,
    format: Optional[str] = None,
) -> VideoFile:
    """
    Saves video interval as a new video file. If output is a remote path,
    the video file will be uploaded to the remote storage.

    The fragment is written by ffmpeg into a temporary directory, uploaded
    via `VideoFile.upload`, and the temporary directory is then removed. The
    file is named `<video stem>_<start ms>_<end ms>.<format>` under `output`.

    Args:
        video (VideoFile): Video file object.
        start (float): Start time in seconds.
        end (float): End time in seconds.
        output (str): Output path, can be a local path or a remote path.
        format (str, optional): Output format (default: None). If not provided,
                                the format will be inferred from the video fragment
                                file extension.

    Returns:
        VideoFile: Video fragment model.

    Raises:
        ValueError: If either bound is negative or `start >= end`.
        FileError: If no local copy of the video file can be obtained.
    """
    if start < 0 or end < 0 or start >= end:
        raise ValueError(f"Invalid time range: ({start:.3f}, {end:.3f})")

    # ffmpeg needs a local file; cache the video first when no local path
    # is available yet (same approach as `video_info`).
    if not (input_path := video.get_local_path()):
        video.ensure_cached()
        input_path = video.get_local_path()
        if not input_path:
            raise FileError(video, "unable to download video file")

    if format is None:
        format = video.get_file_ext()

    start_ms = int(start * 1000)
    end_ms = int(end * 1000)
    output_file = posixpath.join(
        output, f"{video.get_file_stem()}_{start_ms:06d}_{end_ms:06d}.{format}"
    )

    # The context manager removes the directory even when ffmpeg or the
    # upload raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        output_file_tmp = posixpath.join(temp_dir, posixpath.basename(output_file))
        ffmpeg.input(
            input_path,
            ss=start,
            to=end,
        ).output(output_file_tmp).run(quiet=True)

        with open(output_file_tmp, "rb") as f:
            return VideoFile.upload(f.read(), output_file)
|
datachain/query/dataset.py
CHANGED
|
@@ -42,13 +42,17 @@ from datachain.data_storage.schema import (
|
|
|
42
42
|
partition_col_names,
|
|
43
43
|
partition_columns,
|
|
44
44
|
)
|
|
45
|
-
from datachain.dataset import DatasetStatus, RowDict
|
|
46
|
-
from datachain.error import
|
|
45
|
+
from datachain.dataset import DATASET_PREFIX, DatasetStatus, RowDict
|
|
46
|
+
from datachain.error import (
|
|
47
|
+
DatasetNotFoundError,
|
|
48
|
+
QueryScriptCancelError,
|
|
49
|
+
)
|
|
47
50
|
from datachain.func.base import Function
|
|
48
51
|
from datachain.lib.udf import UDFAdapter, _get_cache
|
|
49
52
|
from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
|
|
50
53
|
from datachain.query.schema import C, UDFParamSpec, normalize_param
|
|
51
54
|
from datachain.query.session import Session
|
|
55
|
+
from datachain.remote.studio import is_token_set
|
|
52
56
|
from datachain.sql.functions.random import rand
|
|
53
57
|
from datachain.utils import (
|
|
54
58
|
batched,
|
|
@@ -1081,6 +1085,7 @@ class DatasetQuery:
|
|
|
1081
1085
|
session: Optional[Session] = None,
|
|
1082
1086
|
indexing_column_types: Optional[dict[str, Any]] = None,
|
|
1083
1087
|
in_memory: bool = False,
|
|
1088
|
+
fallback_to_remote: bool = True,
|
|
1084
1089
|
) -> None:
|
|
1085
1090
|
self.session = Session.get(session, catalog=catalog, in_memory=in_memory)
|
|
1086
1091
|
self.catalog = catalog or self.session.catalog
|
|
@@ -1097,7 +1102,12 @@ class DatasetQuery:
|
|
|
1097
1102
|
self.column_types: Optional[dict[str, Any]] = None
|
|
1098
1103
|
|
|
1099
1104
|
self.name = name
|
|
1100
|
-
|
|
1105
|
+
|
|
1106
|
+
if fallback_to_remote and is_token_set():
|
|
1107
|
+
ds = self.catalog.get_dataset_with_remote_fallback(name, version)
|
|
1108
|
+
else:
|
|
1109
|
+
ds = self.catalog.get_dataset(name)
|
|
1110
|
+
|
|
1101
1111
|
self.version = version or ds.latest_version
|
|
1102
1112
|
self.feature_schema = ds.get_version(self.version).feature_schema
|
|
1103
1113
|
self.column_types = copy(ds.schema)
|
|
@@ -1112,6 +1122,21 @@ class DatasetQuery:
|
|
|
1112
1122
|
def __or__(self, other):
    """Support `self | other` by returning `self.union(other)`."""
    return self.union(other)
|
|
1114
1124
|
|
|
1125
|
+
def pull_dataset(self, name: str, version: Optional[int] = None) -> "DatasetRecord":
    """
    Pull a dataset through the catalog and return the local record.

    The remote URI is `DATASET_PREFIX` followed by `name`; when `version` is
    truthy, `@v<version>` is appended. The dataset is pulled under the same
    name and version locally and then looked up in the catalog.

    Args:
        name (str): Dataset name, used for both the remote URI and the local
                    dataset name.
        version (int, optional): Dataset version to pull (default: None).

    Returns:
        DatasetRecord: The dataset returned by `self.catalog.get_dataset(name)`.
    """
    print("Dataset not found in local catalog, trying to get from studio")

    version_suffix = f"@v{version}" if version else ""
    self.catalog.pull_dataset(
        remote_ds_uri=f"{DATASET_PREFIX}{name}{version_suffix}",
        local_ds_name=name,
        local_ds_version=version,
    )

    return self.catalog.get_dataset(name)
|
|
1139
|
+
|
|
1115
1140
|
@staticmethod
|
|
1116
1141
|
def get_table() -> "TableClause":
|
|
1117
1142
|
table_name = "".join(
|