datachain 0.25.2__py3-none-any.whl → 0.26.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic; see the package registry's release page for more details.
- datachain/__init__.py +6 -0
- datachain/catalog/loader.py +4 -0
- datachain/func/__init__.py +2 -1
- datachain/func/conditional.py +34 -0
- datachain/lib/audio.py +151 -0
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/dc/datachain.py +227 -67
- datachain/lib/file.py +190 -1
- datachain/lib/model_store.py +8 -0
- datachain/lib/pytorch.py +4 -1
- datachain/lib/signal_schema.py +56 -11
- datachain/lib/udf.py +17 -5
- datachain/query/dataset.py +37 -9
- {datachain-0.25.2.dist-info → datachain-0.26.1.dist-info}/METADATA +6 -2
- {datachain-0.25.2.dist-info → datachain-0.26.1.dist-info}/RECORD +19 -18
- {datachain-0.25.2.dist-info → datachain-0.26.1.dist-info}/WHEEL +0 -0
- {datachain-0.25.2.dist-info → datachain-0.26.1.dist-info}/entry_points.txt +0 -0
- {datachain-0.25.2.dist-info → datachain-0.26.1.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.25.2.dist-info → datachain-0.26.1.dist-info}/top_level.txt +0 -0
datachain/lib/file.py
CHANGED
|
@@ -43,7 +43,7 @@ logger = logging.getLogger("datachain")
|
|
|
43
43
|
# how to create file path when exporting
|
|
44
44
|
ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]
|
|
45
45
|
|
|
46
|
-
FileType = Literal["binary", "text", "image", "video"]
|
|
46
|
+
FileType = Literal["binary", "text", "image", "video", "audio"]
|
|
47
47
|
EXPORT_FILES_MAX_THREADS = 5
|
|
48
48
|
|
|
49
49
|
|
|
@@ -312,6 +312,14 @@ class File(DataModel):
|
|
|
312
312
|
file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
|
|
313
313
|
return file
|
|
314
314
|
|
|
315
|
+
def as_audio_file(self) -> "AudioFile":
|
|
316
|
+
"""Convert the file to a `AudioFile` object."""
|
|
317
|
+
if isinstance(self, AudioFile):
|
|
318
|
+
return self
|
|
319
|
+
file = AudioFile(**self.model_dump())
|
|
320
|
+
file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
|
|
321
|
+
return file
|
|
322
|
+
|
|
315
323
|
@classmethod
|
|
316
324
|
def upload(
|
|
317
325
|
cls, data: bytes, path: str, catalog: Optional["Catalog"] = None
|
|
@@ -851,6 +859,157 @@ class VideoFile(File):
|
|
|
851
859
|
start += duration
|
|
852
860
|
|
|
853
861
|
|
|
862
|
+
class AudioFile(File):
|
|
863
|
+
"""
|
|
864
|
+
A data model for handling audio files.
|
|
865
|
+
|
|
866
|
+
This model inherits from the `File` model and provides additional functionality
|
|
867
|
+
for reading audio files, extracting audio fragments, and splitting audio into
|
|
868
|
+
fragments.
|
|
869
|
+
"""
|
|
870
|
+
|
|
871
|
+
def get_info(self) -> "Audio":
|
|
872
|
+
"""
|
|
873
|
+
Retrieves metadata and information about the audio file. It does not
|
|
874
|
+
download the file if possible, only reads its header. It is thus might be
|
|
875
|
+
a good idea to disable caching and prefetching for UDF if you only need
|
|
876
|
+
audio metadata.
|
|
877
|
+
|
|
878
|
+
Returns:
|
|
879
|
+
Audio: A Model containing audio metadata such as duration,
|
|
880
|
+
sample rate, channels, and codec details.
|
|
881
|
+
"""
|
|
882
|
+
from .audio import audio_info
|
|
883
|
+
|
|
884
|
+
return audio_info(self)
|
|
885
|
+
|
|
886
|
+
def get_fragment(self, start: float, end: float) -> "AudioFragment":
|
|
887
|
+
"""
|
|
888
|
+
Returns an audio fragment from the specified time range. It does not
|
|
889
|
+
download the file, neither it actually extracts the fragment. It returns
|
|
890
|
+
a Model representing the audio fragment, which can be used to read or save
|
|
891
|
+
it later.
|
|
892
|
+
|
|
893
|
+
Args:
|
|
894
|
+
start (float): The start time of the fragment in seconds.
|
|
895
|
+
end (float): The end time of the fragment in seconds.
|
|
896
|
+
|
|
897
|
+
Returns:
|
|
898
|
+
AudioFragment: A Model representing the audio fragment.
|
|
899
|
+
"""
|
|
900
|
+
if start < 0 or end < 0 or start >= end:
|
|
901
|
+
raise ValueError(f"Invalid time range: ({start:.3f}, {end:.3f})")
|
|
902
|
+
|
|
903
|
+
return AudioFragment(audio=self, start=start, end=end)
|
|
904
|
+
|
|
905
|
+
def get_fragments(
|
|
906
|
+
self,
|
|
907
|
+
duration: float,
|
|
908
|
+
start: float = 0,
|
|
909
|
+
end: Optional[float] = None,
|
|
910
|
+
) -> "Iterator[AudioFragment]":
|
|
911
|
+
"""
|
|
912
|
+
Splits the audio into multiple fragments of a specified duration.
|
|
913
|
+
|
|
914
|
+
Args:
|
|
915
|
+
duration (float): The duration of each audio fragment in seconds.
|
|
916
|
+
start (float): The starting time in seconds (default: 0).
|
|
917
|
+
end (float, optional): The ending time in seconds. If None, the entire
|
|
918
|
+
remaining audio is processed (default: None).
|
|
919
|
+
|
|
920
|
+
Returns:
|
|
921
|
+
Iterator[AudioFragment]: An iterator yielding audio fragments.
|
|
922
|
+
|
|
923
|
+
Note:
|
|
924
|
+
If end is not specified, number of samples will be taken from the
|
|
925
|
+
audio file, this means audio file needs to be downloaded.
|
|
926
|
+
"""
|
|
927
|
+
if duration <= 0:
|
|
928
|
+
raise ValueError("duration must be a positive float")
|
|
929
|
+
if start < 0:
|
|
930
|
+
raise ValueError("start must be a non-negative float")
|
|
931
|
+
|
|
932
|
+
if end is None:
|
|
933
|
+
end = self.get_info().duration
|
|
934
|
+
|
|
935
|
+
if end < 0:
|
|
936
|
+
raise ValueError("end must be a non-negative float")
|
|
937
|
+
if start >= end:
|
|
938
|
+
raise ValueError("start must be less than end")
|
|
939
|
+
|
|
940
|
+
while start < end:
|
|
941
|
+
yield self.get_fragment(start, min(start + duration, end))
|
|
942
|
+
start += duration
|
|
943
|
+
|
|
944
|
+
|
|
945
|
+
class AudioFragment(DataModel):
|
|
946
|
+
"""
|
|
947
|
+
A data model for representing an audio fragment.
|
|
948
|
+
|
|
949
|
+
This model represents a specific fragment within an audio file with defined
|
|
950
|
+
start and end times. It allows access to individual fragments and provides
|
|
951
|
+
functionality for reading and saving audio fragments as separate audio files.
|
|
952
|
+
|
|
953
|
+
Attributes:
|
|
954
|
+
audio (AudioFile): The audio file containing the audio fragment.
|
|
955
|
+
start (float): The starting time of the audio fragment in seconds.
|
|
956
|
+
end (float): The ending time of the audio fragment in seconds.
|
|
957
|
+
"""
|
|
958
|
+
|
|
959
|
+
audio: AudioFile
|
|
960
|
+
start: float
|
|
961
|
+
end: float
|
|
962
|
+
|
|
963
|
+
def get_np(self) -> tuple["ndarray", int]:
|
|
964
|
+
"""
|
|
965
|
+
Returns the audio fragment as a NumPy array with sample rate.
|
|
966
|
+
|
|
967
|
+
Returns:
|
|
968
|
+
tuple[ndarray, int]: A tuple containing the audio data as a NumPy array
|
|
969
|
+
and the sample rate.
|
|
970
|
+
"""
|
|
971
|
+
from .audio import audio_fragment_np
|
|
972
|
+
|
|
973
|
+
duration = self.end - self.start
|
|
974
|
+
return audio_fragment_np(self.audio, self.start, duration)
|
|
975
|
+
|
|
976
|
+
def read_bytes(self, format: str = "wav") -> bytes:
|
|
977
|
+
"""
|
|
978
|
+
Returns the audio fragment as audio bytes.
|
|
979
|
+
|
|
980
|
+
Args:
|
|
981
|
+
format (str): The desired audio format (e.g., 'wav', 'mp3').
|
|
982
|
+
Defaults to 'wav'.
|
|
983
|
+
|
|
984
|
+
Returns:
|
|
985
|
+
bytes: The encoded audio fragment as bytes.
|
|
986
|
+
"""
|
|
987
|
+
from .audio import audio_fragment_bytes
|
|
988
|
+
|
|
989
|
+
duration = self.end - self.start
|
|
990
|
+
return audio_fragment_bytes(self.audio, self.start, duration, format)
|
|
991
|
+
|
|
992
|
+
def save(self, output: str, format: Optional[str] = None) -> "AudioFile":
|
|
993
|
+
"""
|
|
994
|
+
Saves the audio fragment as a new audio file.
|
|
995
|
+
|
|
996
|
+
If `output` is a remote path, the audio file will be uploaded to remote storage.
|
|
997
|
+
|
|
998
|
+
Args:
|
|
999
|
+
output (str): The destination path, which can be a local file path
|
|
1000
|
+
or a remote URL.
|
|
1001
|
+
format (str, optional): The output audio format (e.g., 'wav', 'mp3').
|
|
1002
|
+
If None, the format is inferred from the
|
|
1003
|
+
file extension.
|
|
1004
|
+
|
|
1005
|
+
Returns:
|
|
1006
|
+
AudioFile: A Model representing the saved audio file.
|
|
1007
|
+
"""
|
|
1008
|
+
from .audio import save_audio_fragment
|
|
1009
|
+
|
|
1010
|
+
return save_audio_fragment(self.audio, self.start, self.end, output, format)
|
|
1011
|
+
|
|
1012
|
+
|
|
854
1013
|
class VideoFrame(DataModel):
|
|
855
1014
|
"""
|
|
856
1015
|
A data model for representing a video frame.
|
|
@@ -981,6 +1140,34 @@ class Video(DataModel):
|
|
|
981
1140
|
codec: str = Field(default="")
|
|
982
1141
|
|
|
983
1142
|
|
|
1143
|
+
class Audio(DataModel):
|
|
1144
|
+
"""
|
|
1145
|
+
A data model representing metadata for an audio file.
|
|
1146
|
+
|
|
1147
|
+
Attributes:
|
|
1148
|
+
sample_rate (int): The sample rate of the audio (samples per second).
|
|
1149
|
+
Defaults to -1 if unknown.
|
|
1150
|
+
channels (int): The number of audio channels. Defaults to -1 if unknown.
|
|
1151
|
+
duration (float): The total duration of the audio in seconds.
|
|
1152
|
+
Defaults to -1.0 if unknown.
|
|
1153
|
+
samples (int): The total number of samples in the audio.
|
|
1154
|
+
Defaults to -1 if unknown.
|
|
1155
|
+
format (str): The format of the audio file (e.g., 'wav', 'mp3').
|
|
1156
|
+
Defaults to an empty string.
|
|
1157
|
+
codec (str): The codec used for encoding the audio. Defaults to an empty string.
|
|
1158
|
+
bit_rate (int): The bit rate of the audio in bits per second.
|
|
1159
|
+
Defaults to -1 if unknown.
|
|
1160
|
+
"""
|
|
1161
|
+
|
|
1162
|
+
sample_rate: int = Field(default=-1)
|
|
1163
|
+
channels: int = Field(default=-1)
|
|
1164
|
+
duration: float = Field(default=-1.0)
|
|
1165
|
+
samples: int = Field(default=-1)
|
|
1166
|
+
format: str = Field(default="")
|
|
1167
|
+
codec: str = Field(default="")
|
|
1168
|
+
bit_rate: int = Field(default=-1)
|
|
1169
|
+
|
|
1170
|
+
|
|
984
1171
|
class ArrowRow(DataModel):
|
|
985
1172
|
"""`DataModel` for reading row from Arrow-supported file."""
|
|
986
1173
|
|
|
@@ -1018,5 +1205,7 @@ def get_file_type(type_: FileType = "binary") -> type[File]:
|
|
|
1018
1205
|
file = ImageFile # type: ignore[assignment]
|
|
1019
1206
|
elif type_ == "video":
|
|
1020
1207
|
file = VideoFile
|
|
1208
|
+
elif type_ == "audio":
|
|
1209
|
+
file = AudioFile
|
|
1021
1210
|
|
|
1022
1211
|
return file
|
datachain/lib/model_store.py
CHANGED
|
@@ -81,3 +81,11 @@ class ModelStore:
|
|
|
81
81
|
if val is None or not ModelStore.is_pydantic(val):
|
|
82
82
|
return None
|
|
83
83
|
return val
|
|
84
|
+
|
|
85
|
+
@staticmethod
|
|
86
|
+
def is_partial(parent_type) -> bool:
|
|
87
|
+
return (
|
|
88
|
+
parent_type
|
|
89
|
+
and ModelStore.is_pydantic(parent_type)
|
|
90
|
+
and "@" in ModelStore.get_name(parent_type)
|
|
91
|
+
)
|
datachain/lib/pytorch.py
CHANGED
|
@@ -125,7 +125,10 @@ class PytorchDataset(IterableDataset):
|
|
|
125
125
|
ds = read_dataset(
|
|
126
126
|
name=self.name, version=self.version, session=session
|
|
127
127
|
).settings(cache=self.cache, prefetch=self.prefetch)
|
|
128
|
-
|
|
128
|
+
|
|
129
|
+
# remove file signals from dataset
|
|
130
|
+
schema = ds.signals_schema.clone_without_file_signals()
|
|
131
|
+
ds = ds.select(*schema.values.keys())
|
|
129
132
|
|
|
130
133
|
if self.num_samples > 0:
|
|
131
134
|
ds = ds.sample(self.num_samples)
|
datachain/lib/signal_schema.py
CHANGED
|
@@ -446,14 +446,14 @@ class SignalSchema:
|
|
|
446
446
|
res[db_name] = python_to_sql(type_)
|
|
447
447
|
return res
|
|
448
448
|
|
|
449
|
-
def row_to_objs(self, row: Sequence[Any]) -> list[
|
|
449
|
+
def row_to_objs(self, row: Sequence[Any]) -> list[Any]:
|
|
450
450
|
self._init_setup_values()
|
|
451
451
|
|
|
452
|
-
objs: list[
|
|
452
|
+
objs: list[Any] = []
|
|
453
453
|
pos = 0
|
|
454
454
|
for name, fr_type in self.values.items():
|
|
455
|
-
if self.setup_values and
|
|
456
|
-
objs.append(
|
|
455
|
+
if self.setup_values and name in self.setup_values:
|
|
456
|
+
objs.append(self.setup_values.get(name))
|
|
457
457
|
elif (fr := ModelStore.to_pydantic(fr_type)) is not None:
|
|
458
458
|
j, pos = unflatten_to_json_pos(fr, row, pos)
|
|
459
459
|
objs.append(fr(**j))
|
|
@@ -589,6 +589,9 @@ class SignalSchema:
|
|
|
589
589
|
]
|
|
590
590
|
|
|
591
591
|
if name:
|
|
592
|
+
if "." in name:
|
|
593
|
+
name = name.replace(".", "__")
|
|
594
|
+
|
|
592
595
|
signals = [
|
|
593
596
|
s
|
|
594
597
|
for s in signals
|
|
@@ -607,24 +610,38 @@ class SignalSchema:
|
|
|
607
610
|
return SignalSchema(schema)
|
|
608
611
|
|
|
609
612
|
def _find_in_tree(self, path: list[str]) -> DataType:
|
|
613
|
+
if val := self.tree.get(".".join(path)):
|
|
614
|
+
# If the path is a single string, we can directly access it
|
|
615
|
+
# without traversing the tree.
|
|
616
|
+
return val[0]
|
|
617
|
+
|
|
610
618
|
curr_tree = self.tree
|
|
611
619
|
curr_type = None
|
|
612
620
|
i = 0
|
|
613
621
|
while curr_tree is not None and i < len(path):
|
|
614
622
|
if val := curr_tree.get(path[i]):
|
|
615
623
|
curr_type, curr_tree = val
|
|
616
|
-
elif i == 0 and len(path) > 1 and (val := curr_tree.get(".".join(path))):
|
|
617
|
-
curr_type, curr_tree = val
|
|
618
|
-
break
|
|
619
624
|
else:
|
|
620
625
|
curr_type = None
|
|
626
|
+
break
|
|
621
627
|
i += 1
|
|
622
628
|
|
|
623
|
-
if curr_type is None:
|
|
629
|
+
if curr_type is None or i < len(path):
|
|
630
|
+
# If we reached the end of the path and didn't find a type,
|
|
631
|
+
# or if we didn't traverse the entire path, raise an error.
|
|
624
632
|
raise SignalResolvingError(path, "is not found")
|
|
625
633
|
|
|
626
634
|
return curr_type
|
|
627
635
|
|
|
636
|
+
def group_by(
|
|
637
|
+
self, partition_by: Sequence[str], new_column: Sequence[Column]
|
|
638
|
+
) -> "SignalSchema":
|
|
639
|
+
orig_schema = SignalSchema(copy.deepcopy(self.values))
|
|
640
|
+
schema = orig_schema.to_partial(*partition_by)
|
|
641
|
+
|
|
642
|
+
vals = {c.name: sql_to_python(c) for c in new_column}
|
|
643
|
+
return SignalSchema(schema.values | vals)
|
|
644
|
+
|
|
628
645
|
def select_except_signals(self, *args: str) -> "SignalSchema":
|
|
629
646
|
def has_signal(signal: str):
|
|
630
647
|
signal = signal.replace(".", DEFAULT_DELIMITER)
|
|
@@ -888,7 +905,7 @@ class SignalSchema:
|
|
|
888
905
|
|
|
889
906
|
return res
|
|
890
907
|
|
|
891
|
-
def to_partial(self, *columns: str) -> "SignalSchema":
|
|
908
|
+
def to_partial(self, *columns: str) -> "SignalSchema": # noqa: C901
|
|
892
909
|
"""
|
|
893
910
|
Convert the schema to a partial schema with only the specified columns.
|
|
894
911
|
|
|
@@ -931,9 +948,15 @@ class SignalSchema:
|
|
|
931
948
|
partial_versions: dict[str, int] = {}
|
|
932
949
|
|
|
933
950
|
def _type_name_to_partial(signal_name: str, type_name: str) -> str:
|
|
934
|
-
if
|
|
951
|
+
# Check if we need to create a partial for this type
|
|
952
|
+
# Only create partials for custom types that are in the custom_types dict
|
|
953
|
+
if type_name not in custom_types:
|
|
935
954
|
return type_name
|
|
936
|
-
|
|
955
|
+
|
|
956
|
+
if "@" in type_name:
|
|
957
|
+
model_name, _ = ModelStore.parse_name_version(type_name)
|
|
958
|
+
else:
|
|
959
|
+
model_name = type_name
|
|
937
960
|
|
|
938
961
|
if signal_name not in signal_partials:
|
|
939
962
|
partial_versions.setdefault(model_name, 0)
|
|
@@ -957,6 +980,14 @@ class SignalSchema:
|
|
|
957
980
|
parent_type_partial = _type_name_to_partial(signal, parent_type)
|
|
958
981
|
|
|
959
982
|
schema[signal] = parent_type_partial
|
|
983
|
+
|
|
984
|
+
# If this is a complex signal without field specifier (just "file")
|
|
985
|
+
# and it's a custom type, include the entire complex signal
|
|
986
|
+
if len(column_parts) == 1 and parent_type in custom_types:
|
|
987
|
+
# Include the entire complex signal - no need to create partial
|
|
988
|
+
schema[signal] = parent_type
|
|
989
|
+
continue
|
|
990
|
+
|
|
960
991
|
continue
|
|
961
992
|
|
|
962
993
|
if parent_type not in custom_types:
|
|
@@ -971,6 +1002,20 @@ class SignalSchema:
|
|
|
971
1002
|
f"Field {signal} not found in custom type {parent_type}"
|
|
972
1003
|
)
|
|
973
1004
|
|
|
1005
|
+
# Check if this is the last part and if the column type is a complex
|
|
1006
|
+
is_last_part = i == len(column_parts) - 1
|
|
1007
|
+
is_complex_signal = signal_type in custom_types
|
|
1008
|
+
|
|
1009
|
+
if is_last_part and is_complex_signal:
|
|
1010
|
+
schema[column] = signal_type
|
|
1011
|
+
# Also need to remove the partial schema entry we created for the
|
|
1012
|
+
# parent since we're promoting the nested complex column to root
|
|
1013
|
+
parent_signal = column_parts[0]
|
|
1014
|
+
schema.pop(parent_signal, None)
|
|
1015
|
+
# Don't create partial types for this case
|
|
1016
|
+
break
|
|
1017
|
+
|
|
1018
|
+
# Create partial type for this field
|
|
974
1019
|
partial_type = _type_name_to_partial(
|
|
975
1020
|
".".join(column_parts[: i + 1]),
|
|
976
1021
|
signal_type,
|
datachain/lib/udf.py
CHANGED
|
@@ -13,8 +13,7 @@ from datachain.asyn import AsyncMapper
|
|
|
13
13
|
from datachain.cache import temporary_cache
|
|
14
14
|
from datachain.dataset import RowDict
|
|
15
15
|
from datachain.lib.convert.flatten import flatten
|
|
16
|
-
from datachain.lib.
|
|
17
|
-
from datachain.lib.file import File
|
|
16
|
+
from datachain.lib.file import DataModel, File
|
|
18
17
|
from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
|
|
19
18
|
from datachain.query.batch import (
|
|
20
19
|
Batch,
|
|
@@ -266,15 +265,28 @@ class UDFBase(AbstractUDF):
|
|
|
266
265
|
|
|
267
266
|
def _parse_row(
|
|
268
267
|
self, row_dict: RowDict, catalog: "Catalog", cache: bool, download_cb: Callback
|
|
269
|
-
) -> list[
|
|
268
|
+
) -> list[Any]:
|
|
270
269
|
assert self.params
|
|
271
270
|
row = [row_dict[p] for p in self.params.to_udf_spec()]
|
|
272
271
|
obj_row = self.params.row_to_objs(row)
|
|
273
272
|
for obj in obj_row:
|
|
274
|
-
|
|
275
|
-
obj._set_stream(catalog, caching_enabled=cache, download_cb=download_cb)
|
|
273
|
+
self._set_stream_recursive(obj, catalog, cache, download_cb)
|
|
276
274
|
return obj_row
|
|
277
275
|
|
|
276
|
+
def _set_stream_recursive(
|
|
277
|
+
self, obj: Any, catalog: "Catalog", cache: bool, download_cb: Callback
|
|
278
|
+
) -> None:
|
|
279
|
+
"""Recursively set the catalog stream on all File objects within an object."""
|
|
280
|
+
if isinstance(obj, File):
|
|
281
|
+
obj._set_stream(catalog, caching_enabled=cache, download_cb=download_cb)
|
|
282
|
+
|
|
283
|
+
# Check all fields for nested File objects, but only for DataModel objects
|
|
284
|
+
if isinstance(obj, DataModel):
|
|
285
|
+
for field_name in obj.model_fields:
|
|
286
|
+
field_value = getattr(obj, field_name, None)
|
|
287
|
+
if isinstance(field_value, DataModel):
|
|
288
|
+
self._set_stream_recursive(field_value, catalog, cache, download_cb)
|
|
289
|
+
|
|
278
290
|
def _prepare_row(self, row, udf_fields, catalog, cache, download_cb):
|
|
279
291
|
row_dict = RowDict(zip(udf_fields, row))
|
|
280
292
|
return self._parse_row(row_dict, catalog, cache, download_cb)
|
datachain/query/dataset.py
CHANGED
|
@@ -559,7 +559,13 @@ class UDFStep(Step, ABC):
|
|
|
559
559
|
"""
|
|
560
560
|
Create temporary table with group by partitions.
|
|
561
561
|
"""
|
|
562
|
+
# Check if partition_by is set, we need it to create partitions.
|
|
562
563
|
assert self.partition_by is not None
|
|
564
|
+
# Check if sys__id is in the query, we need it to be able to join
|
|
565
|
+
# the partition table with the udf table later.
|
|
566
|
+
assert any(c.name == "sys__id" for c in query.selected_columns), (
|
|
567
|
+
"Query must have sys__id column to use partitioning."
|
|
568
|
+
)
|
|
563
569
|
|
|
564
570
|
if isinstance(self.partition_by, (list, tuple, GeneratorType)):
|
|
565
571
|
list_partition_by = list(self.partition_by)
|
|
@@ -606,6 +612,22 @@ class UDFStep(Step, ABC):
|
|
|
606
612
|
|
|
607
613
|
# Apply partitioning if needed.
|
|
608
614
|
if self.partition_by is not None:
|
|
615
|
+
if not any(c.name == "sys__id" for c in query.selected_columns):
|
|
616
|
+
# If sys__id is not in the query, we need to create a temp table
|
|
617
|
+
# to hold the query results, so we can join it with the
|
|
618
|
+
# partition table later.
|
|
619
|
+
columns = [
|
|
620
|
+
c if isinstance(c, Column) else Column(c.name, c.type)
|
|
621
|
+
for c in query.subquery().columns
|
|
622
|
+
]
|
|
623
|
+
temp_table = self.catalog.warehouse.create_dataset_rows_table(
|
|
624
|
+
self.catalog.warehouse.temp_table_name(),
|
|
625
|
+
columns=columns,
|
|
626
|
+
)
|
|
627
|
+
temp_tables.append(temp_table.name)
|
|
628
|
+
self.catalog.warehouse.copy_table(temp_table, query)
|
|
629
|
+
_query = query = temp_table.select()
|
|
630
|
+
|
|
609
631
|
partition_tbl = self.create_partitions_table(query)
|
|
610
632
|
temp_tables.append(partition_tbl.name)
|
|
611
633
|
query = query.outerjoin(
|
|
@@ -1031,16 +1053,22 @@ class SQLGroupBy(SQLClause):
|
|
|
1031
1053
|
c.get_column() if isinstance(c, Function) else c for c in self.group_by
|
|
1032
1054
|
]
|
|
1033
1055
|
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
if isinstance(c, Function)
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1056
|
+
cols_dict: dict[str, Any] = {}
|
|
1057
|
+
for c in (*group_by, *self.cols):
|
|
1058
|
+
if isinstance(c, Function):
|
|
1059
|
+
key = c.name
|
|
1060
|
+
value = c.get_column()
|
|
1061
|
+
elif isinstance(c, (str, C)):
|
|
1062
|
+
key = str(c)
|
|
1063
|
+
value = subquery.c[str(c)]
|
|
1064
|
+
else:
|
|
1065
|
+
key = c.name
|
|
1066
|
+
value = c # type: ignore[assignment]
|
|
1067
|
+
cols_dict[key] = value
|
|
1068
|
+
|
|
1069
|
+
unique_cols = cols_dict.values()
|
|
1042
1070
|
|
|
1043
|
-
return sqlalchemy.select(*
|
|
1071
|
+
return sqlalchemy.select(*unique_cols).select_from(subquery).group_by(*group_by)
|
|
1044
1072
|
|
|
1045
1073
|
|
|
1046
1074
|
def _validate_columns(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: datachain
|
|
3
|
-
Version: 0.25.2
|
|
3
|
+
Version: 0.26.1
|
|
4
4
|
Summary: Wrangle unstructured AI data at scale
|
|
5
5
|
Author-email: Dmitry Petrov <support@dvc.org>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -63,6 +63,9 @@ Provides-Extra: torch
|
|
|
63
63
|
Requires-Dist: torch>=2.1.0; extra == "torch"
|
|
64
64
|
Requires-Dist: torchvision; extra == "torch"
|
|
65
65
|
Requires-Dist: transformers>=4.36.0; extra == "torch"
|
|
66
|
+
Provides-Extra: audio
|
|
67
|
+
Requires-Dist: torchaudio; extra == "audio"
|
|
68
|
+
Requires-Dist: soundfile; extra == "audio"
|
|
66
69
|
Provides-Extra: remote
|
|
67
70
|
Requires-Dist: lz4; extra == "remote"
|
|
68
71
|
Requires-Dist: requests>=2.22.0; extra == "remote"
|
|
@@ -78,7 +81,7 @@ Requires-Dist: ffmpeg-python; extra == "video"
|
|
|
78
81
|
Requires-Dist: imageio[ffmpeg,pyav]>=2.37.0; extra == "video"
|
|
79
82
|
Requires-Dist: opencv-python; extra == "video"
|
|
80
83
|
Provides-Extra: tests
|
|
81
|
-
Requires-Dist: datachain[hf,remote,torch,vector,video]; extra == "tests"
|
|
84
|
+
Requires-Dist: datachain[audio,hf,remote,torch,vector,video]; extra == "tests"
|
|
82
85
|
Requires-Dist: pytest<9,>=8; extra == "tests"
|
|
83
86
|
Requires-Dist: pytest-sugar>=0.9.6; extra == "tests"
|
|
84
87
|
Requires-Dist: pytest-cov>=4.1.0; extra == "tests"
|
|
@@ -108,6 +111,7 @@ Requires-Dist: accelerate; extra == "examples"
|
|
|
108
111
|
Requires-Dist: huggingface_hub[hf_transfer]; extra == "examples"
|
|
109
112
|
Requires-Dist: ultralytics; extra == "examples"
|
|
110
113
|
Requires-Dist: open_clip_torch; extra == "examples"
|
|
114
|
+
Requires-Dist: openai; extra == "examples"
|
|
111
115
|
Dynamic: license-file
|
|
112
116
|
|
|
113
117
|
================
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
datachain/__init__.py,sha256=
|
|
1
|
+
datachain/__init__.py,sha256=2TZ8ptSB9BtnYF31mDEhWG9N16EQ5pf9vNqQaFr2txs,1712
|
|
2
2
|
datachain/__main__.py,sha256=hG3Y4ARGEqe1AWwNMd259rBlqtphx1Wk39YbueQ0yV8,91
|
|
3
3
|
datachain/asyn.py,sha256=RH_jFwJcTXxhEFomaI9yL6S3Onau6NZ6FSKfKFGtrJE,9689
|
|
4
4
|
datachain/cache.py,sha256=ESVRaCJXEThMIfGEFVHx6wJPOZA7FYk9V6WxjyuqUBY,3626
|
|
@@ -23,7 +23,7 @@ datachain/utils.py,sha256=DNqOi-Ydb7InyWvD9m7_yailxz6-YGpZzh00biQaHNo,15305
|
|
|
23
23
|
datachain/catalog/__init__.py,sha256=cMZzSz3VoUi-6qXSVaHYN-agxQuAcz2XSqnEPZ55crE,353
|
|
24
24
|
datachain/catalog/catalog.py,sha256=QTWCXy75iWo-0MCXyfV_WbsKeZ1fpLpvL8d60rxn1ws,65528
|
|
25
25
|
datachain/catalog/datasource.py,sha256=IkGMh0Ttg6Q-9DWfU_H05WUnZepbGa28HYleECi6K7I,1353
|
|
26
|
-
datachain/catalog/loader.py,sha256=
|
|
26
|
+
datachain/catalog/loader.py,sha256=B2cps5coFE4MBttM-j8cs7JgNVPjnHKF4Gx1s2fJrxw,6119
|
|
27
27
|
datachain/cli/__init__.py,sha256=WvBqnwjG8Wp9xGCn-4eqfoZ3n7Sj1HJemCi4MayJh_c,8221
|
|
28
28
|
datachain/cli/utils.py,sha256=wrLnAh7Wx8O_ojZE8AE4Lxn5WoxHbOj7as8NWlLAA74,3036
|
|
29
29
|
datachain/cli/commands/__init__.py,sha256=zp3bYIioO60x_X04A4-IpZqSYVnpwOa1AdERQaRlIhI,493
|
|
@@ -58,11 +58,11 @@ datachain/diff/__init__.py,sha256=-OFZzgOplqO84iWgGY7kfe60NXaWR9JRIh9T-uJboAM,96
|
|
|
58
58
|
datachain/fs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
59
59
|
datachain/fs/reference.py,sha256=A8McpXF0CqbXPqanXuvpKu50YLB3a2ZXA3YAPxtBXSM,914
|
|
60
60
|
datachain/fs/utils.py,sha256=s-FkTOCGBk-b6TT3toQH51s9608pofoFjUSTc1yy7oE,825
|
|
61
|
-
datachain/func/__init__.py,sha256=
|
|
61
|
+
datachain/func/__init__.py,sha256=9K2MEC1NclY_zWuqevfEUOcrSE26cXDVnGqhNTj4lF8,1288
|
|
62
62
|
datachain/func/aggregate.py,sha256=fmVEKf3MUR29dEgllGdtl6nG7Lwz-SiyA5X1EyRRNUk,12456
|
|
63
63
|
datachain/func/array.py,sha256=fz5NUIPkp_KZ7tadCqJQSSJwWMYXEfYn60QkG2epC3k,13627
|
|
64
64
|
datachain/func/base.py,sha256=wA0sBQAVyN9LPxoo7Ox83peS0zUVnyuKxukwAcjGLfY,534
|
|
65
|
-
datachain/func/conditional.py,sha256=
|
|
65
|
+
datachain/func/conditional.py,sha256=9YYurD_PBMyf5rR9dj2gLv-Lo7UhYfHW6EtrUErVpz8,10165
|
|
66
66
|
datachain/func/func.py,sha256=fpslnn4edr0dH3mD8BSTndRFJiiVZvbJoBJV6HkHMqw,17400
|
|
67
67
|
datachain/func/numeric.py,sha256=J6FgzuIAcS6B02Cm1qPnJdB6ut21jyBDVXSBrkZNZaQ,6978
|
|
68
68
|
datachain/func/path.py,sha256=9Jas35QhEtRai4l54hMqVvuJsqxHvOx88oo4vym1H_I,4077
|
|
@@ -71,24 +71,25 @@ datachain/func/string.py,sha256=X9u4ip97U63RCaKRhMddoze7HgPiY3LbPRn9G06UWWo,7311
|
|
|
71
71
|
datachain/func/window.py,sha256=ImyRpc1QI8QUSPO7KdD60e_DPVo7Ja0G5kcm6BlyMcw,1584
|
|
72
72
|
datachain/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
73
73
|
datachain/lib/arrow.py,sha256=hdEQ8I1JgNmEAaXTaqaU1qvZDi5dgtes1IC69ycthz8,10753
|
|
74
|
+
datachain/lib/audio.py,sha256=J7XJ14ItPF9y6pN-tmMV9In9X9rgwlBwzyzdGOUkPGk,4376
|
|
74
75
|
datachain/lib/clip.py,sha256=lm5CzVi4Cj1jVLEKvERKArb-egb9j1Ls-fwTItT6vlI,6150
|
|
75
76
|
datachain/lib/data_model.py,sha256=ZwBXELtqROEdLL4DmxTipnwUZmhQvMz_UVDzyf7nQ9Y,2899
|
|
76
77
|
datachain/lib/dataset_info.py,sha256=7w-DoKOyIVoOtWGCgciMLcP5CiAWJB3rVI-vUDF80k0,3311
|
|
77
|
-
datachain/lib/file.py,sha256=
|
|
78
|
+
datachain/lib/file.py,sha256=tHBBacsh1580UPFC6fAINBNwNiyymNgzj89rpsz1LKc,40817
|
|
78
79
|
datachain/lib/hf.py,sha256=_dCoGTv7n5cBgxhCDfZI-t3hnMCXGHd6sEsxRThcizE,5754
|
|
79
80
|
datachain/lib/image.py,sha256=erWvZW5M3emnbl6_fGAOPyKm-1EKbt3vOdWPfe3Oo7U,3265
|
|
80
81
|
datachain/lib/listing.py,sha256=U-2stsTEwEsq4Y80dqGfktGzkmB5-ZntnL1_rzXlH0k,7089
|
|
81
82
|
datachain/lib/listing_info.py,sha256=9ua40Hw0aiQByUw3oAEeNzMavJYfW0Uhe8YdCTK-m_g,1110
|
|
82
83
|
datachain/lib/meta_formats.py,sha256=zdyg6XLk3QIsSk3I7s0Ez5kaCJSlE3uq7JiGxf7UwtU,6348
|
|
83
|
-
datachain/lib/model_store.py,sha256=
|
|
84
|
+
datachain/lib/model_store.py,sha256=dkL2rcT5ag-kbgkhQPL_byEs-TCYr29qvdltroL5NxM,2734
|
|
84
85
|
datachain/lib/namespaces.py,sha256=it52UbbwB8dzhesO2pMs_nThXiPQ1Ph9sD9I3GQkg5s,2099
|
|
85
86
|
datachain/lib/projects.py,sha256=8lN0qV8czX1LGtWURCUvRlSJk-RpO9w9Rra_pOZus6g,2595
|
|
86
|
-
datachain/lib/pytorch.py,sha256=
|
|
87
|
+
datachain/lib/pytorch.py,sha256=S-st2SAczYut13KMf6eSqP_OQ8otWI5TRmzhK5fN3k0,7828
|
|
87
88
|
datachain/lib/settings.py,sha256=9wi0FoHxRxNiyn99pR28IYsMkoo47jQxeXuObQr2Ar0,2929
|
|
88
|
-
datachain/lib/signal_schema.py,sha256=
|
|
89
|
+
datachain/lib/signal_schema.py,sha256=tOWcWEG0ZwiU0qxywEYs3qkTexQQHmzg28wZ1CJGyEI,38552
|
|
89
90
|
datachain/lib/tar.py,sha256=MLcVjzIgBqRuJacCNpZ6kwSZNq1i2tLyROc8PVprHsA,999
|
|
90
91
|
datachain/lib/text.py,sha256=UNHm8fhidk7wdrWqacEWaA6I9ykfYqarQ2URby7jc7M,1261
|
|
91
|
-
datachain/lib/udf.py,sha256=
|
|
92
|
+
datachain/lib/udf.py,sha256=nkcB3HNtSteUspwsGmOKyy3mH2F-Sfo6iW64-Ep47-I,17299
|
|
92
93
|
datachain/lib/udf_signature.py,sha256=Yz20iJ-WF1pijT3hvcDIKFzgWV9gFxZM73KZRx3NbPk,7560
|
|
93
94
|
datachain/lib/utils.py,sha256=rG2y7NwTqZOuomZZRmrA-Q-ANM_j1cToQYqDJoOeGyU,1480
|
|
94
95
|
datachain/lib/video.py,sha256=u6fLJWj5G6QqsVkpfHnKGklBNpG3BRRg6v3izngnNcU,6767
|
|
@@ -97,13 +98,13 @@ datachain/lib/webdataset_laion.py,sha256=xvT6m_r5y0KbOx14BUe7UC5mOgrktJq53Mh-H0E
|
|
|
97
98
|
datachain/lib/convert/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
98
99
|
datachain/lib/convert/flatten.py,sha256=IZFiUYbgXSxXhPSG5Cqf5IjnJ4ZDZKXMr4o_yCR1NY4,1505
|
|
99
100
|
datachain/lib/convert/python_to_sql.py,sha256=wg-O5FRKX3x3Wh8ZL1b9ntMlgf1zRO4djMP3t8CHJLo,3188
|
|
100
|
-
datachain/lib/convert/sql_to_python.py,sha256=
|
|
101
|
+
datachain/lib/convert/sql_to_python.py,sha256=Gxc4FylWC_Pvvuawuc2MKZIiuAWI7wje8pyeN1MxRrU,670
|
|
101
102
|
datachain/lib/convert/unflatten.py,sha256=ysMkstwJzPMWUlnxn-Z-tXJR3wmhjHeSN_P-sDcLS6s,2010
|
|
102
103
|
datachain/lib/convert/values_to_tuples.py,sha256=j5yZMrVUH6W7b-7yUvdCTGI7JCUAYUOzHUGPoyZXAB0,4360
|
|
103
104
|
datachain/lib/dc/__init__.py,sha256=TFci5HTvYGjBesNUxDAnXaX36PnzPEUSn5a6JxB9o0U,872
|
|
104
105
|
datachain/lib/dc/csv.py,sha256=q6a9BpapGwP6nwy6c5cklxQumep2fUp9l2LAjtTJr6s,4411
|
|
105
106
|
datachain/lib/dc/database.py,sha256=g5M6NjYR1T0vKte-abV-3Ejnm-HqxTIMir5cRi_SziE,6051
|
|
106
|
-
datachain/lib/dc/datachain.py,sha256=
|
|
107
|
+
datachain/lib/dc/datachain.py,sha256=ap54lcuj71tvp0zX1jiFFiEWvA5UPeyYJRJkd2APmlI,92897
|
|
107
108
|
datachain/lib/dc/datasets.py,sha256=P6CIJizD2IYFwOQG5D3VbQRjDmUiRH0ysdtb551Xdm8,15098
|
|
108
109
|
datachain/lib/dc/hf.py,sha256=PJl2wiLjdRsMz0SYbLT-6H8b-D5i2WjeH7li8HHOk_0,2145
|
|
109
110
|
datachain/lib/dc/json.py,sha256=dNijfJ-H92vU3soyR7X1IiDrWhm6yZIGG3bSnZkPdAE,2733
|
|
@@ -125,7 +126,7 @@ datachain/model/ultralytics/pose.py,sha256=pBlmt63Qe68FKmexHimUGlNbNOoOlMHXG4fzX
|
|
|
125
126
|
datachain/model/ultralytics/segment.py,sha256=63bDCj43E6iZ0hFI5J6uQfksdCmjEp6sEm1XzVaE8pw,2986
|
|
126
127
|
datachain/query/__init__.py,sha256=7DhEIjAA8uZJfejruAVMZVcGFmvUpffuZJwgRqNwe-c,263
|
|
127
128
|
datachain/query/batch.py,sha256=-goxLpE0EUvaDHu66rstj53UnfHpYfBUGux8GSpJ93k,4306
|
|
128
|
-
datachain/query/dataset.py,sha256=
|
|
129
|
+
datachain/query/dataset.py,sha256=cYNrg1QyrZpO-oup3mqmSYHUvgEYBKe8RgkVbyQa6p0,62777
|
|
129
130
|
datachain/query/dispatch.py,sha256=A0nPxn6mEN5d9dDo6S8m16Ji_9IvJLXrgF2kqXdi4fs,15546
|
|
130
131
|
datachain/query/metrics.py,sha256=DOK5HdNVaRugYPjl8qnBONvTkwjMloLqAr7Mi3TjCO0,858
|
|
131
132
|
datachain/query/params.py,sha256=O_j89mjYRLOwWNhYZl-z7mi-rkdP7WyFmaDufsdTryE,863
|
|
@@ -157,9 +158,9 @@ datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR
|
|
|
157
158
|
datachain/toolkit/__init__.py,sha256=eQ58Q5Yf_Fgv1ZG0IO5dpB4jmP90rk8YxUWmPc1M2Bo,68
|
|
158
159
|
datachain/toolkit/split.py,sha256=ktGWzY4kyzjWyR86dhvzw-Zhl0lVk_LOX3NciTac6qo,2914
|
|
159
160
|
datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
|
|
160
|
-
datachain-0.
|
|
161
|
-
datachain-0.
|
|
162
|
-
datachain-0.
|
|
163
|
-
datachain-0.
|
|
164
|
-
datachain-0.
|
|
165
|
-
datachain-0.
|
|
161
|
+
datachain-0.26.1.dist-info/licenses/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
|
|
162
|
+
datachain-0.26.1.dist-info/METADATA,sha256=C0Pb9d9IcJ6oOPXihcyEhTc_Rf7Fe4pP_anKhC3JfeU,13543
|
|
163
|
+
datachain-0.26.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
164
|
+
datachain-0.26.1.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
|
|
165
|
+
datachain-0.26.1.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
|
|
166
|
+
datachain-0.26.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|