datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/lib/file.py
CHANGED
|
@@ -1,18 +1,18 @@
|
|
|
1
1
|
import errno
|
|
2
2
|
import hashlib
|
|
3
3
|
import io
|
|
4
|
-
import json
|
|
5
4
|
import logging
|
|
6
5
|
import os
|
|
7
6
|
import posixpath
|
|
7
|
+
import warnings
|
|
8
8
|
from abc import ABC, abstractmethod
|
|
9
9
|
from collections.abc import Iterator
|
|
10
10
|
from contextlib import contextmanager
|
|
11
11
|
from datetime import datetime
|
|
12
12
|
from functools import partial
|
|
13
13
|
from io import BytesIO
|
|
14
|
-
from pathlib import Path, PurePosixPath
|
|
15
|
-
from typing import TYPE_CHECKING, Any, ClassVar, Literal
|
|
14
|
+
from pathlib import Path, PurePath, PurePosixPath
|
|
15
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Literal
|
|
16
16
|
from urllib.parse import unquote, urlparse
|
|
17
17
|
from urllib.request import url2pathname
|
|
18
18
|
|
|
@@ -20,9 +20,10 @@ from fsspec.callbacks import DEFAULT_CALLBACK, Callback
|
|
|
20
20
|
from fsspec.utils import stringify_path
|
|
21
21
|
from pydantic import Field, field_validator
|
|
22
22
|
|
|
23
|
+
from datachain import json
|
|
23
24
|
from datachain.client.fileslice import FileSlice
|
|
24
25
|
from datachain.lib.data_model import DataModel
|
|
25
|
-
from datachain.lib.utils import DataChainError
|
|
26
|
+
from datachain.lib.utils import DataChainError, rebase_path
|
|
26
27
|
from datachain.nodes_thread_pool import NodesThreadPool
|
|
27
28
|
from datachain.sql.types import JSON, Boolean, DateTime, Int, String
|
|
28
29
|
from datachain.utils import TIME_ZERO
|
|
@@ -34,15 +35,16 @@ if TYPE_CHECKING:
|
|
|
34
35
|
from datachain.catalog import Catalog
|
|
35
36
|
from datachain.client.fsspec import Client
|
|
36
37
|
from datachain.dataset import RowDict
|
|
38
|
+
from datachain.query.session import Session
|
|
37
39
|
|
|
38
40
|
sha256 = partial(hashlib.sha256, usedforsecurity=False)
|
|
39
41
|
|
|
40
42
|
logger = logging.getLogger("datachain")
|
|
41
43
|
|
|
42
44
|
# how to create file path when exporting
|
|
43
|
-
ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]
|
|
45
|
+
ExportPlacement = Literal["filename", "etag", "fullpath", "checksum", "filepath"]
|
|
44
46
|
|
|
45
|
-
FileType = Literal["binary", "text", "image", "video"]
|
|
47
|
+
FileType = Literal["binary", "text", "image", "video", "audio"]
|
|
46
48
|
EXPORT_FILES_MAX_THREADS = 5
|
|
47
49
|
|
|
48
50
|
|
|
@@ -51,12 +53,12 @@ class FileExporter(NodesThreadPool):
|
|
|
51
53
|
|
|
52
54
|
def __init__(
|
|
53
55
|
self,
|
|
54
|
-
output:
|
|
56
|
+
output: str | os.PathLike[str],
|
|
55
57
|
placement: ExportPlacement,
|
|
56
58
|
use_cache: bool,
|
|
57
59
|
link_type: Literal["copy", "symlink"],
|
|
58
60
|
max_threads: int = EXPORT_FILES_MAX_THREADS,
|
|
59
|
-
client_config:
|
|
61
|
+
client_config: dict | None = None,
|
|
60
62
|
):
|
|
61
63
|
super().__init__(max_threads)
|
|
62
64
|
self.output = output
|
|
@@ -69,7 +71,7 @@ class FileExporter(NodesThreadPool):
|
|
|
69
71
|
for task in done:
|
|
70
72
|
task.result()
|
|
71
73
|
|
|
72
|
-
def do_task(self, file):
|
|
74
|
+
def do_task(self, file: "File"):
|
|
73
75
|
file.export(
|
|
74
76
|
self.output,
|
|
75
77
|
self.placement,
|
|
@@ -81,14 +83,28 @@ class FileExporter(NodesThreadPool):
|
|
|
81
83
|
|
|
82
84
|
|
|
83
85
|
class VFileError(DataChainError):
|
|
84
|
-
def __init__(self,
|
|
86
|
+
def __init__(self, message: str, source: str, path: str, vtype: str = ""):
|
|
87
|
+
self.message = message
|
|
88
|
+
self.source = source
|
|
89
|
+
self.path = path
|
|
90
|
+
self.vtype = vtype
|
|
91
|
+
|
|
85
92
|
type_ = f" of vtype '{vtype}'" if vtype else ""
|
|
86
|
-
super().__init__(f"Error in v-file '{
|
|
93
|
+
super().__init__(f"Error in v-file '{source}/{path}'{type_}: {message}")
|
|
94
|
+
|
|
95
|
+
def __reduce__(self):
|
|
96
|
+
return self.__class__, (self.message, self.source, self.path, self.vtype)
|
|
87
97
|
|
|
88
98
|
|
|
89
99
|
class FileError(DataChainError):
|
|
90
|
-
def __init__(self,
|
|
91
|
-
|
|
100
|
+
def __init__(self, message: str, source: str, path: str):
|
|
101
|
+
self.message = message
|
|
102
|
+
self.source = source
|
|
103
|
+
self.path = path
|
|
104
|
+
super().__init__(f"Error in file '{source}/{path}': {message}")
|
|
105
|
+
|
|
106
|
+
def __reduce__(self):
|
|
107
|
+
return self.__class__, (self.message, self.source, self.path)
|
|
92
108
|
|
|
93
109
|
|
|
94
110
|
class VFile(ABC):
|
|
@@ -113,26 +129,36 @@ class TarVFile(VFile):
|
|
|
113
129
|
@classmethod
|
|
114
130
|
def open(cls, file: "File", location: list[dict]):
|
|
115
131
|
"""Stream file from tar archive based on location in archive."""
|
|
116
|
-
|
|
117
|
-
raise VFileError(file, "multiple 'location's are not supported yet")
|
|
132
|
+
tar_file = cls.parent(file, location)
|
|
118
133
|
|
|
119
134
|
loc = location[0]
|
|
120
135
|
|
|
121
136
|
if (offset := loc.get("offset", None)) is None:
|
|
122
|
-
raise VFileError(
|
|
137
|
+
raise VFileError("'offset' is not specified", file.source, file.path)
|
|
123
138
|
|
|
124
139
|
if (size := loc.get("size", None)) is None:
|
|
125
|
-
raise VFileError(
|
|
140
|
+
raise VFileError("'size' is not specified", file.source, file.path)
|
|
141
|
+
|
|
142
|
+
client = file._catalog.get_client(tar_file.source)
|
|
143
|
+
fd = client.open_object(tar_file, use_cache=file._caching_enabled)
|
|
144
|
+
return FileSlice(fd, offset, size, file.name)
|
|
145
|
+
|
|
146
|
+
@classmethod
|
|
147
|
+
def parent(cls, file: "File", location: list[dict]) -> "File":
|
|
148
|
+
if len(location) > 1:
|
|
149
|
+
raise VFileError(
|
|
150
|
+
"multiple 'location's are not supported yet", file.source, file.path
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
loc = location[0]
|
|
126
154
|
|
|
127
155
|
if (parent := loc.get("parent", None)) is None:
|
|
128
|
-
raise VFileError(
|
|
156
|
+
raise VFileError("'parent' is not specified", file.source, file.path)
|
|
129
157
|
|
|
130
158
|
tar_file = File(**parent)
|
|
131
159
|
tar_file._set_stream(file._catalog)
|
|
132
160
|
|
|
133
|
-
|
|
134
|
-
fd = client.open_object(tar_file, use_cache=file._caching_enabled)
|
|
135
|
-
return FileSlice(fd, offset, size, file.name)
|
|
161
|
+
return tar_file
|
|
136
162
|
|
|
137
163
|
|
|
138
164
|
class VFileRegistry:
|
|
@@ -143,19 +169,33 @@ class VFileRegistry:
|
|
|
143
169
|
cls._vtype_readers[reader.get_vtype()] = reader
|
|
144
170
|
|
|
145
171
|
@classmethod
|
|
146
|
-
def
|
|
172
|
+
def _get_reader(cls, file: "File", location: list[dict]):
|
|
147
173
|
if len(location) == 0:
|
|
148
|
-
raise VFileError(
|
|
174
|
+
raise VFileError(
|
|
175
|
+
"'location' must not be list of JSONs", file.source, file.path
|
|
176
|
+
)
|
|
149
177
|
|
|
150
178
|
if not (vtype := location[0].get("vtype", "")):
|
|
151
|
-
raise VFileError(
|
|
179
|
+
raise VFileError("vtype is not specified", file.source, file.path)
|
|
152
180
|
|
|
153
181
|
reader = cls._vtype_readers.get(vtype, None)
|
|
154
182
|
if not reader:
|
|
155
|
-
raise VFileError(
|
|
183
|
+
raise VFileError(
|
|
184
|
+
"reader not registered", file.source, file.path, vtype=vtype
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
return reader
|
|
156
188
|
|
|
189
|
+
@classmethod
|
|
190
|
+
def open(cls, file: "File", location: list[dict]):
|
|
191
|
+
reader = cls._get_reader(file, location)
|
|
157
192
|
return reader.open(file, location)
|
|
158
193
|
|
|
194
|
+
@classmethod
|
|
195
|
+
def parent(cls, file: "File", location: list[dict]) -> "File":
|
|
196
|
+
reader = cls._get_reader(file, location)
|
|
197
|
+
return reader.parent(file, location)
|
|
198
|
+
|
|
159
199
|
|
|
160
200
|
class File(DataModel):
|
|
161
201
|
"""
|
|
@@ -181,7 +221,7 @@ class File(DataModel):
|
|
|
181
221
|
etag: str = Field(default="")
|
|
182
222
|
is_latest: bool = Field(default=True)
|
|
183
223
|
last_modified: datetime = Field(default=TIME_ZERO)
|
|
184
|
-
location:
|
|
224
|
+
location: dict | list[dict] | None = Field(default=None)
|
|
185
225
|
|
|
186
226
|
_datachain_column_types: ClassVar[dict[str, Any]] = {
|
|
187
227
|
"source": String,
|
|
@@ -213,10 +253,19 @@ class File(DataModel):
|
|
|
213
253
|
"last_modified",
|
|
214
254
|
]
|
|
215
255
|
|
|
256
|
+
# Allowed kwargs we forward to TextIOWrapper
|
|
257
|
+
_TEXT_WRAPPER_ALLOWED: ClassVar[tuple[str, ...]] = (
|
|
258
|
+
"encoding",
|
|
259
|
+
"errors",
|
|
260
|
+
"newline",
|
|
261
|
+
"line_buffering",
|
|
262
|
+
"write_through",
|
|
263
|
+
)
|
|
264
|
+
|
|
216
265
|
@staticmethod
|
|
217
266
|
def _validate_dict(
|
|
218
|
-
v:
|
|
219
|
-
) ->
|
|
267
|
+
v: str | dict | list[dict] | None,
|
|
268
|
+
) -> str | dict | list[dict] | None:
|
|
220
269
|
if v is None or v == "":
|
|
221
270
|
return None
|
|
222
271
|
if isinstance(v, str):
|
|
@@ -236,8 +285,8 @@ class File(DataModel):
|
|
|
236
285
|
|
|
237
286
|
@field_validator("path", mode="before")
|
|
238
287
|
@classmethod
|
|
239
|
-
def validate_path(cls, path):
|
|
240
|
-
return
|
|
288
|
+
def validate_path(cls, path: str) -> str:
|
|
289
|
+
return PurePath(path).as_posix() if path else ""
|
|
241
290
|
|
|
242
291
|
def model_dump_custom(self):
|
|
243
292
|
res = self.model_dump()
|
|
@@ -248,6 +297,16 @@ class File(DataModel):
|
|
|
248
297
|
super().__init__(**kwargs)
|
|
249
298
|
self._catalog = None
|
|
250
299
|
self._caching_enabled: bool = False
|
|
300
|
+
self._download_cb: Callback = DEFAULT_CALLBACK
|
|
301
|
+
|
|
302
|
+
def __getstate__(self):
|
|
303
|
+
state = super().__getstate__()
|
|
304
|
+
# Exclude _catalog from pickling - it contains SQLAlchemy engine and other
|
|
305
|
+
# non-picklable objects. The catalog will be re-set by _set_stream() on the
|
|
306
|
+
# worker side when needed.
|
|
307
|
+
state["__dict__"] = state["__dict__"].copy()
|
|
308
|
+
state["__dict__"]["_catalog"] = None
|
|
309
|
+
return state
|
|
251
310
|
|
|
252
311
|
def as_text_file(self) -> "TextFile":
|
|
253
312
|
"""Convert the file to a `TextFile` object."""
|
|
@@ -273,19 +332,31 @@ class File(DataModel):
|
|
|
273
332
|
file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
|
|
274
333
|
return file
|
|
275
334
|
|
|
335
|
+
def as_audio_file(self) -> "AudioFile":
|
|
336
|
+
"""Convert the file to a `AudioFile` object."""
|
|
337
|
+
if isinstance(self, AudioFile):
|
|
338
|
+
return self
|
|
339
|
+
file = AudioFile(**self.model_dump())
|
|
340
|
+
file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
|
|
341
|
+
return file
|
|
342
|
+
|
|
276
343
|
@classmethod
|
|
277
344
|
def upload(
|
|
278
|
-
cls,
|
|
345
|
+
cls,
|
|
346
|
+
data: bytes,
|
|
347
|
+
path: str | os.PathLike[str],
|
|
348
|
+
catalog: "Catalog | None" = None,
|
|
279
349
|
) -> "Self":
|
|
280
350
|
if catalog is None:
|
|
281
|
-
from datachain.
|
|
282
|
-
|
|
283
|
-
catalog = get_catalog()
|
|
351
|
+
from datachain.query.session import Session
|
|
284
352
|
|
|
353
|
+
catalog = Session.get().catalog
|
|
285
354
|
from datachain.client.fsspec import Client
|
|
286
355
|
|
|
287
|
-
|
|
288
|
-
|
|
356
|
+
path_str = stringify_path(path)
|
|
357
|
+
|
|
358
|
+
client_cls = Client.get_implementation(path_str)
|
|
359
|
+
source, rel_path = client_cls.split_url(path_str)
|
|
289
360
|
|
|
290
361
|
client = catalog.get_client(client_cls.get_uri(source))
|
|
291
362
|
file = client.upload(data, rel_path)
|
|
@@ -294,49 +365,150 @@ class File(DataModel):
|
|
|
294
365
|
file._set_stream(catalog)
|
|
295
366
|
return file
|
|
296
367
|
|
|
368
|
+
@classmethod
|
|
369
|
+
def at(
|
|
370
|
+
cls, uri: str | os.PathLike[str], session: "Session | None" = None
|
|
371
|
+
) -> "Self":
|
|
372
|
+
"""Construct a File from a full URI in one call.
|
|
373
|
+
|
|
374
|
+
Example:
|
|
375
|
+
file = File.at("s3://bucket/path/to/output.png")
|
|
376
|
+
with file.open("wb") as f: ...
|
|
377
|
+
"""
|
|
378
|
+
from datachain.client.fsspec import Client
|
|
379
|
+
from datachain.query.session import Session
|
|
380
|
+
|
|
381
|
+
if session is None:
|
|
382
|
+
session = Session.get()
|
|
383
|
+
catalog = session.catalog
|
|
384
|
+
uri_str = stringify_path(uri)
|
|
385
|
+
if uri_str.endswith(("/", os.sep)):
|
|
386
|
+
raise ValueError(
|
|
387
|
+
f"File.at directory URL/path given (trailing slash), got: {uri_str}"
|
|
388
|
+
)
|
|
389
|
+
client_cls = Client.get_implementation(uri_str)
|
|
390
|
+
uri_str = client_cls.path_to_uri(uri_str)
|
|
391
|
+
source, rel_path = client_cls.split_url(uri_str)
|
|
392
|
+
source_uri = client_cls.get_uri(source)
|
|
393
|
+
file = cls(source=source_uri, path=rel_path)
|
|
394
|
+
file._set_stream(catalog)
|
|
395
|
+
return file
|
|
396
|
+
|
|
297
397
|
@classmethod
|
|
298
398
|
def _from_row(cls, row: "RowDict") -> "Self":
|
|
299
399
|
return cls(**{key: row[key] for key in cls._datachain_column_types})
|
|
300
400
|
|
|
301
401
|
@property
|
|
302
|
-
def name(self):
|
|
402
|
+
def name(self) -> str:
|
|
303
403
|
return PurePosixPath(self.path).name
|
|
304
404
|
|
|
305
405
|
@property
|
|
306
|
-
def parent(self):
|
|
406
|
+
def parent(self) -> str:
|
|
307
407
|
return str(PurePosixPath(self.path).parent)
|
|
308
408
|
|
|
309
409
|
@contextmanager
|
|
310
|
-
def open(
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
410
|
+
def open(
|
|
411
|
+
self,
|
|
412
|
+
mode: str = "rb",
|
|
413
|
+
*,
|
|
414
|
+
client_config: dict[str, Any] | None = None,
|
|
415
|
+
**open_kwargs,
|
|
416
|
+
) -> Iterator[Any]:
|
|
417
|
+
"""Open the file and return a file-like object.
|
|
418
|
+
|
|
419
|
+
Supports both read ("rb", "r") and write modes (e.g. "wb", "w", "ab").
|
|
420
|
+
When opened in a write mode, metadata is refreshed after closing.
|
|
421
|
+
"""
|
|
422
|
+
writing = any(ch in mode for ch in "wax+")
|
|
423
|
+
if self.location and writing:
|
|
424
|
+
raise VFileError(
|
|
425
|
+
"Writing to virtual file is not supported",
|
|
426
|
+
self.source,
|
|
427
|
+
self.path,
|
|
428
|
+
)
|
|
315
429
|
|
|
316
|
-
|
|
430
|
+
if self._catalog is None:
|
|
431
|
+
raise RuntimeError("Cannot open file: catalog is not set")
|
|
432
|
+
|
|
433
|
+
base_cfg = getattr(self._catalog, "client_config", {}) or {}
|
|
434
|
+
merged_cfg = {**base_cfg, **(client_config or {})}
|
|
435
|
+
client: Client = self._catalog.get_client(self.source, **merged_cfg)
|
|
436
|
+
|
|
437
|
+
if not writing:
|
|
438
|
+
if self.location:
|
|
439
|
+
with VFileRegistry.open(self, self.location) as f: # type: ignore[arg-type]
|
|
440
|
+
yield self._wrap_text(f, mode, open_kwargs)
|
|
441
|
+
return
|
|
317
442
|
if self._caching_enabled:
|
|
318
443
|
self.ensure_cached()
|
|
319
|
-
client: Client = self._catalog.get_client(self.source)
|
|
320
444
|
with client.open_object(
|
|
321
445
|
self, use_cache=self._caching_enabled, cb=self._download_cb
|
|
322
446
|
) as f:
|
|
323
|
-
yield
|
|
447
|
+
yield self._wrap_text(f, mode, open_kwargs)
|
|
448
|
+
return
|
|
449
|
+
|
|
450
|
+
# write path
|
|
451
|
+
full_path = client.get_full_path(self.get_path_normalized())
|
|
452
|
+
with client.fs.open(full_path, mode, **open_kwargs) as f:
|
|
453
|
+
yield self._wrap_text(f, mode, open_kwargs)
|
|
454
|
+
|
|
455
|
+
version_hint = self._extract_write_version(f)
|
|
456
|
+
|
|
457
|
+
# refresh metadata pinned to the version that was just written
|
|
458
|
+
refreshed = client.get_file_info(
|
|
459
|
+
self.get_path_normalized(), version_id=version_hint
|
|
460
|
+
)
|
|
461
|
+
for k, v in refreshed.model_dump().items():
|
|
462
|
+
setattr(self, k, v)
|
|
463
|
+
|
|
464
|
+
def _wrap_text(self, f: Any, mode: str, open_kwargs: dict[str, Any]) -> Any:
|
|
465
|
+
"""Return stream possibly wrapped for text."""
|
|
466
|
+
if "b" in mode or isinstance(f, io.TextIOBase):
|
|
467
|
+
return f
|
|
468
|
+
filtered = {
|
|
469
|
+
k: open_kwargs[k] for k in self._TEXT_WRAPPER_ALLOWED if k in open_kwargs
|
|
470
|
+
}
|
|
471
|
+
return io.TextIOWrapper(f, **filtered)
|
|
472
|
+
|
|
473
|
+
def _extract_write_version(self, handle: Any) -> str | None:
|
|
474
|
+
"""Best-effort extraction of object version after a write.
|
|
475
|
+
|
|
476
|
+
S3 (s3fs) and Azure (adlfs) populate version_id on the handle.
|
|
477
|
+
GCS (gcsfs) populates generation. Azure and GCS require upstream
|
|
478
|
+
fixes to be released.
|
|
479
|
+
"""
|
|
480
|
+
for attr in ("version_id", "generation"):
|
|
481
|
+
if value := getattr(handle, attr, None):
|
|
482
|
+
return value
|
|
483
|
+
return None
|
|
324
484
|
|
|
325
485
|
def read_bytes(self, length: int = -1):
|
|
326
486
|
"""Returns file contents as bytes."""
|
|
327
487
|
with self.open() as stream:
|
|
328
488
|
return stream.read(length)
|
|
329
489
|
|
|
330
|
-
def read_text(self):
|
|
331
|
-
"""
|
|
332
|
-
|
|
490
|
+
def read_text(self, **open_kwargs):
|
|
491
|
+
"""Return file contents decoded as text.
|
|
492
|
+
|
|
493
|
+
**open_kwargs : Any
|
|
494
|
+
Extra keyword arguments forwarded to ``open(mode="r", ...)``
|
|
495
|
+
(e.g. ``encoding="utf-8"``, ``errors="ignore"``)
|
|
496
|
+
"""
|
|
497
|
+
if self.location:
|
|
498
|
+
raise VFileError(
|
|
499
|
+
"Reading text from virtual file is not supported",
|
|
500
|
+
self.source,
|
|
501
|
+
self.path,
|
|
502
|
+
)
|
|
503
|
+
|
|
504
|
+
with self.open(mode="r", **open_kwargs) as stream:
|
|
333
505
|
return stream.read()
|
|
334
506
|
|
|
335
507
|
def read(self, length: int = -1):
|
|
336
508
|
"""Returns file contents."""
|
|
337
509
|
return self.read_bytes(length)
|
|
338
510
|
|
|
339
|
-
def save(self, destination: str, client_config:
|
|
511
|
+
def save(self, destination: str, client_config: dict | None = None):
|
|
340
512
|
"""Writes it's content to destination"""
|
|
341
513
|
destination = stringify_path(destination)
|
|
342
514
|
client: Client = self._catalog.get_client(destination, **(client_config or {}))
|
|
@@ -346,7 +518,7 @@ class File(DataModel):
|
|
|
346
518
|
|
|
347
519
|
client.upload(self.read(), destination)
|
|
348
520
|
|
|
349
|
-
def _symlink_to(self, destination: str):
|
|
521
|
+
def _symlink_to(self, destination: str) -> None:
|
|
350
522
|
if self.location:
|
|
351
523
|
raise OSError(errno.ENOTSUP, "Symlinking virtual file is not supported")
|
|
352
524
|
|
|
@@ -355,7 +527,7 @@ class File(DataModel):
|
|
|
355
527
|
source = self.get_local_path()
|
|
356
528
|
assert source, "File was not cached"
|
|
357
529
|
elif self.source.startswith("file://"):
|
|
358
|
-
source = self.
|
|
530
|
+
source = self.get_fs_path()
|
|
359
531
|
else:
|
|
360
532
|
raise OSError(errno.EXDEV, "can't link across filesystems")
|
|
361
533
|
|
|
@@ -363,11 +535,11 @@ class File(DataModel):
|
|
|
363
535
|
|
|
364
536
|
def export(
|
|
365
537
|
self,
|
|
366
|
-
output:
|
|
538
|
+
output: str | os.PathLike[str],
|
|
367
539
|
placement: ExportPlacement = "fullpath",
|
|
368
540
|
use_cache: bool = True,
|
|
369
541
|
link_type: Literal["copy", "symlink"] = "copy",
|
|
370
|
-
client_config:
|
|
542
|
+
client_config: dict | None = None,
|
|
371
543
|
) -> None:
|
|
372
544
|
"""Export file to new location."""
|
|
373
545
|
self._caching_enabled = use_cache
|
|
@@ -403,18 +575,22 @@ class File(DataModel):
|
|
|
403
575
|
client = self._catalog.get_client(self.source)
|
|
404
576
|
client.download(self, callback=self._download_cb)
|
|
405
577
|
|
|
406
|
-
async def _prefetch(self, download_cb:
|
|
578
|
+
async def _prefetch(self, download_cb: "Callback | None" = None) -> bool:
|
|
407
579
|
if self._catalog is None:
|
|
408
580
|
raise RuntimeError("cannot prefetch file because catalog is not setup")
|
|
409
581
|
|
|
582
|
+
file = self
|
|
583
|
+
if self.location:
|
|
584
|
+
file = VFileRegistry.parent(self, self.location) # type: ignore[arg-type]
|
|
585
|
+
|
|
410
586
|
client = self._catalog.get_client(self.source)
|
|
411
|
-
await client._download(
|
|
412
|
-
|
|
587
|
+
await client._download(file, callback=download_cb or self._download_cb)
|
|
588
|
+
file._set_stream(
|
|
413
589
|
self._catalog, caching_enabled=True, download_cb=DEFAULT_CALLBACK
|
|
414
590
|
)
|
|
415
591
|
return True
|
|
416
592
|
|
|
417
|
-
def get_local_path(self) ->
|
|
593
|
+
def get_local_path(self) -> str | None:
|
|
418
594
|
"""Return path to a file in a local cache.
|
|
419
595
|
|
|
420
596
|
Returns None if file is not cached.
|
|
@@ -432,31 +608,66 @@ class File(DataModel):
|
|
|
432
608
|
|
|
433
609
|
def get_file_ext(self):
|
|
434
610
|
"""Returns last part of file name without `.`."""
|
|
435
|
-
return PurePosixPath(self.path).suffix.
|
|
611
|
+
return PurePosixPath(self.path).suffix.lstrip(".")
|
|
436
612
|
|
|
437
613
|
def get_file_stem(self):
|
|
438
614
|
"""Returns file name without extension."""
|
|
439
615
|
return PurePosixPath(self.path).stem
|
|
440
616
|
|
|
441
617
|
def get_full_name(self):
|
|
442
|
-
"""
|
|
618
|
+
"""
|
|
619
|
+
[DEPRECATED] Use `file.path` directly instead.
|
|
620
|
+
|
|
621
|
+
Returns name with parent directories.
|
|
622
|
+
"""
|
|
623
|
+
warnings.warn(
|
|
624
|
+
"file.get_full_name() is deprecated and will be removed "
|
|
625
|
+
"in a future version. Use `file.path` directly.",
|
|
626
|
+
DeprecationWarning,
|
|
627
|
+
stacklevel=2,
|
|
628
|
+
)
|
|
443
629
|
return self.path
|
|
444
630
|
|
|
445
|
-
def
|
|
631
|
+
def get_path_normalized(self) -> str:
|
|
632
|
+
if not self.path:
|
|
633
|
+
raise FileError("path must not be empty", self.source, self.path)
|
|
634
|
+
|
|
635
|
+
if self.path.endswith("/"):
|
|
636
|
+
raise FileError("path must not be a directory", self.source, self.path)
|
|
637
|
+
|
|
638
|
+
normpath = os.path.normpath(self.path)
|
|
639
|
+
normpath = PurePath(normpath).as_posix()
|
|
640
|
+
|
|
641
|
+
if normpath == ".":
|
|
642
|
+
raise FileError("path must not be a directory", self.source, self.path)
|
|
643
|
+
|
|
644
|
+
if any(part == ".." for part in PurePath(normpath).parts):
|
|
645
|
+
raise FileError("path must not contain '..'", self.source, self.path)
|
|
646
|
+
|
|
647
|
+
return normpath
|
|
648
|
+
|
|
649
|
+
def get_uri(self) -> str:
|
|
446
650
|
"""Returns file URI."""
|
|
447
|
-
return f"{self.source}/{self.
|
|
651
|
+
return f"{self.source}/{self.get_path_normalized()}"
|
|
652
|
+
|
|
653
|
+
def get_fs_path(self) -> str:
|
|
654
|
+
"""
|
|
655
|
+
Returns file path with respect to the filescheme.
|
|
656
|
+
|
|
657
|
+
If `normalize` is True, the path is normalized to remove any redundant
|
|
658
|
+
separators and up-level references.
|
|
448
659
|
|
|
449
|
-
|
|
450
|
-
|
|
660
|
+
If the file scheme is "file", the path is converted to a local file path
|
|
661
|
+
using `url2pathname`. Otherwise, the original path with scheme is returned.
|
|
662
|
+
"""
|
|
451
663
|
path = unquote(self.get_uri())
|
|
452
|
-
|
|
453
|
-
if
|
|
454
|
-
path =
|
|
455
|
-
path = url2pathname(path)
|
|
664
|
+
path_parsed = urlparse(path)
|
|
665
|
+
if path_parsed.scheme == "file":
|
|
666
|
+
path = url2pathname(path_parsed.path)
|
|
456
667
|
return path
|
|
457
668
|
|
|
458
669
|
def get_destination_path(
|
|
459
|
-
self, output:
|
|
670
|
+
self, output: str | os.PathLike[str], placement: ExportPlacement
|
|
460
671
|
) -> str:
|
|
461
672
|
"""
|
|
462
673
|
Returns full destination path of a file for exporting to some output
|
|
@@ -467,10 +678,12 @@ class File(DataModel):
|
|
|
467
678
|
elif placement == "etag":
|
|
468
679
|
path = f"{self.etag}{self.get_file_suffix()}"
|
|
469
680
|
elif placement == "fullpath":
|
|
470
|
-
path = unquote(self.
|
|
681
|
+
path = unquote(self.get_path_normalized())
|
|
471
682
|
source = urlparse(self.source)
|
|
472
683
|
if source.scheme and source.scheme != "file":
|
|
473
684
|
path = posixpath.join(source.netloc, path)
|
|
685
|
+
elif placement == "filepath":
|
|
686
|
+
path = unquote(self.get_path_normalized())
|
|
474
687
|
elif placement == "checksum":
|
|
475
688
|
raise NotImplementedError("Checksum placement not implemented yet")
|
|
476
689
|
else:
|
|
@@ -505,9 +718,10 @@ class File(DataModel):
|
|
|
505
718
|
) from e
|
|
506
719
|
|
|
507
720
|
try:
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
721
|
+
normalized_path = self.get_path_normalized()
|
|
722
|
+
info = client.fs.info(client.get_full_path(normalized_path))
|
|
723
|
+
converted_info = client.info_to_file(info, normalized_path)
|
|
724
|
+
res = type(self)(
|
|
511
725
|
path=self.path,
|
|
512
726
|
source=self.source,
|
|
513
727
|
size=converted_info.size,
|
|
@@ -517,10 +731,21 @@ class File(DataModel):
|
|
|
517
731
|
last_modified=converted_info.last_modified,
|
|
518
732
|
location=self.location,
|
|
519
733
|
)
|
|
734
|
+
res._set_stream(self._catalog)
|
|
735
|
+
return res
|
|
736
|
+
except FileError as e:
|
|
737
|
+
logger.warning(
|
|
738
|
+
"File error when resolving %s/%s: %s", self.source, self.path, str(e)
|
|
739
|
+
)
|
|
520
740
|
except (FileNotFoundError, PermissionError, OSError) as e:
|
|
521
|
-
logger.warning(
|
|
741
|
+
logger.warning(
|
|
742
|
+
"File system error when resolving %s/%s: %s",
|
|
743
|
+
self.source,
|
|
744
|
+
self.path,
|
|
745
|
+
str(e),
|
|
746
|
+
)
|
|
522
747
|
|
|
523
|
-
|
|
748
|
+
res = type(self)(
|
|
524
749
|
path=self.path,
|
|
525
750
|
source=self.source,
|
|
526
751
|
size=0,
|
|
@@ -530,10 +755,48 @@ class File(DataModel):
|
|
|
530
755
|
last_modified=TIME_ZERO,
|
|
531
756
|
location=self.location,
|
|
532
757
|
)
|
|
758
|
+
res._set_stream(self._catalog)
|
|
759
|
+
return res
|
|
760
|
+
|
|
761
|
+
def rebase(
|
|
762
|
+
self,
|
|
763
|
+
old_base: str,
|
|
764
|
+
new_base: str,
|
|
765
|
+
suffix: str = "",
|
|
766
|
+
extension: str = "",
|
|
767
|
+
) -> str:
|
|
768
|
+
"""
|
|
769
|
+
Rebase the file's URI from one base directory to another.
|
|
770
|
+
|
|
771
|
+
Args:
|
|
772
|
+
old_base: Base directory to remove from the file's URI
|
|
773
|
+
new_base: New base directory to prepend
|
|
774
|
+
suffix: Optional suffix to add before file extension
|
|
775
|
+
extension: Optional new file extension (without dot)
|
|
776
|
+
|
|
777
|
+
Returns:
|
|
778
|
+
str: Rebased URI with new base directory
|
|
779
|
+
|
|
780
|
+
Raises:
|
|
781
|
+
ValueError: If old_base is not found in the file's URI
|
|
782
|
+
|
|
783
|
+
Examples:
|
|
784
|
+
>>> file = File(source="s3://bucket", path="data/2025-05-27/file.wav")
|
|
785
|
+
>>> file.rebase("s3://bucket/data", "s3://output-bucket/processed", \
|
|
786
|
+
extension="mp3")
|
|
787
|
+
's3://output-bucket/processed/2025-05-27/file.mp3'
|
|
788
|
+
|
|
789
|
+
>>> file.rebase("data/audio", "/local/output", suffix="_ch1",
|
|
790
|
+
extension="npy")
|
|
791
|
+
'/local/output/file_ch1.npy'
|
|
792
|
+
"""
|
|
793
|
+
return rebase_path(self.get_uri(), old_base, new_base, suffix, extension)
|
|
533
794
|
|
|
534
795
|
|
|
535
796
|
def resolve(file: File) -> File:
|
|
536
797
|
"""
|
|
798
|
+
[DEPRECATED] Use `file.resolve()` directly instead.
|
|
799
|
+
|
|
537
800
|
Resolve a File object by checking its existence and updating its metadata.
|
|
538
801
|
|
|
539
802
|
This function is a wrapper around the File.resolve() method, designed to be
|
|
@@ -549,6 +812,12 @@ def resolve(file: File) -> File:
|
|
|
549
812
|
RuntimeError: If the file's catalog is not set or if
|
|
550
813
|
the file source protocol is unsupported.
|
|
551
814
|
"""
|
|
815
|
+
warnings.warn(
|
|
816
|
+
"resolve() is deprecated and will be removed "
|
|
817
|
+
"in a future version. Use file.resolve() directly.",
|
|
818
|
+
DeprecationWarning,
|
|
819
|
+
stacklevel=2,
|
|
820
|
+
)
|
|
552
821
|
return file.resolve()
|
|
553
822
|
|
|
554
823
|
|
|
@@ -556,17 +825,30 @@ class TextFile(File):
|
|
|
556
825
|
"""`DataModel` for reading text files."""
|
|
557
826
|
|
|
558
827
|
@contextmanager
|
|
559
|
-
def open(
|
|
560
|
-
|
|
561
|
-
|
|
828
|
+
def open(
|
|
829
|
+
self,
|
|
830
|
+
mode: str = "r",
|
|
831
|
+
*,
|
|
832
|
+
client_config: dict[str, Any] | None = None,
|
|
833
|
+
**open_kwargs,
|
|
834
|
+
) -> Iterator[Any]:
|
|
835
|
+
"""Open the file and return a file-like object.
|
|
836
|
+
Default to text mode"""
|
|
837
|
+
with super().open(
|
|
838
|
+
mode=mode, client_config=client_config, **open_kwargs
|
|
839
|
+
) as stream:
|
|
562
840
|
yield stream
|
|
563
841
|
|
|
564
|
-
def read_text(self):
|
|
565
|
-
"""
|
|
566
|
-
|
|
842
|
+
def read_text(self, **open_kwargs):
|
|
843
|
+
"""Return file contents as text.
|
|
844
|
+
|
|
845
|
+
**open_kwargs : Any
|
|
846
|
+
Extra keyword arguments forwarded to ``open()`` (e.g. encoding).
|
|
847
|
+
"""
|
|
848
|
+
with self.open(**open_kwargs) as stream:
|
|
567
849
|
return stream.read()
|
|
568
850
|
|
|
569
|
-
def save(self, destination: str, client_config:
|
|
851
|
+
def save(self, destination: str, client_config: dict | None = None):
|
|
570
852
|
"""Writes it's content to destination"""
|
|
571
853
|
destination = stringify_path(destination)
|
|
572
854
|
|
|
@@ -599,13 +881,30 @@ class ImageFile(File):
|
|
|
599
881
|
def save( # type: ignore[override]
|
|
600
882
|
self,
|
|
601
883
|
destination: str,
|
|
602
|
-
format:
|
|
603
|
-
client_config:
|
|
884
|
+
format: str | None = None,
|
|
885
|
+
client_config: dict | None = None,
|
|
604
886
|
):
|
|
605
887
|
"""Writes it's content to destination"""
|
|
606
888
|
destination = stringify_path(destination)
|
|
607
889
|
|
|
608
890
|
client: Client = self._catalog.get_client(destination, **(client_config or {}))
|
|
891
|
+
|
|
892
|
+
# If format is not provided, determine it from the file extension
|
|
893
|
+
if format is None:
|
|
894
|
+
from pathlib import PurePosixPath
|
|
895
|
+
|
|
896
|
+
from PIL import Image as PilImage
|
|
897
|
+
|
|
898
|
+
ext = PurePosixPath(destination).suffix.lower()
|
|
899
|
+
format = PilImage.registered_extensions().get(ext)
|
|
900
|
+
|
|
901
|
+
if not format:
|
|
902
|
+
raise FileError(
|
|
903
|
+
f"Can't determine format for destination '{destination}'",
|
|
904
|
+
self.source,
|
|
905
|
+
self.path,
|
|
906
|
+
)
|
|
907
|
+
|
|
609
908
|
with client.fs.open(destination, mode="wb") as f:
|
|
610
909
|
self.read().save(f, format=format)
|
|
611
910
|
|
|
@@ -665,7 +964,7 @@ class VideoFile(File):
|
|
|
665
964
|
def get_frames(
|
|
666
965
|
self,
|
|
667
966
|
start: int = 0,
|
|
668
|
-
end:
|
|
967
|
+
end: int | None = None,
|
|
669
968
|
step: int = 1,
|
|
670
969
|
) -> "Iterator[VideoFrame]":
|
|
671
970
|
"""
|
|
@@ -704,7 +1003,10 @@ class VideoFile(File):
|
|
|
704
1003
|
VideoFragment: A Model representing the video fragment.
|
|
705
1004
|
"""
|
|
706
1005
|
if start < 0 or end < 0 or start >= end:
|
|
707
|
-
raise ValueError(
|
|
1006
|
+
raise ValueError(
|
|
1007
|
+
f"Can't get video fragment for '{self.path}', "
|
|
1008
|
+
f"invalid time range: ({start:.3f}, {end:.3f})"
|
|
1009
|
+
)
|
|
708
1010
|
|
|
709
1011
|
return VideoFragment(video=self, start=start, end=end)
|
|
710
1012
|
|
|
@@ -712,7 +1014,7 @@ class VideoFile(File):
|
|
|
712
1014
|
self,
|
|
713
1015
|
duration: float,
|
|
714
1016
|
start: float = 0,
|
|
715
|
-
end:
|
|
1017
|
+
end: float | None = None,
|
|
716
1018
|
) -> "Iterator[VideoFragment]":
|
|
717
1019
|
"""
|
|
718
1020
|
Splits the video into multiple fragments of a specified duration.
|
|
@@ -748,6 +1050,189 @@ class VideoFile(File):
|
|
|
748
1050
|
start += duration
|
|
749
1051
|
|
|
750
1052
|
|
|
1053
|
+
class AudioFile(File):
|
|
1054
|
+
"""
|
|
1055
|
+
A data model for handling audio files.
|
|
1056
|
+
|
|
1057
|
+
This model inherits from the `File` model and provides additional functionality
|
|
1058
|
+
for reading audio files, extracting audio fragments, and splitting audio into
|
|
1059
|
+
fragments.
|
|
1060
|
+
"""
|
|
1061
|
+
|
|
1062
|
+
def get_info(self) -> "Audio":
|
|
1063
|
+
"""
|
|
1064
|
+
Retrieves metadata and information about the audio file. It does not
|
|
1065
|
+
download the file if possible, only reads its header. It is thus might be
|
|
1066
|
+
a good idea to disable caching and prefetching for UDF if you only need
|
|
1067
|
+
audio metadata.
|
|
1068
|
+
|
|
1069
|
+
Returns:
|
|
1070
|
+
Audio: A Model containing audio metadata such as duration,
|
|
1071
|
+
sample rate, channels, and codec details.
|
|
1072
|
+
"""
|
|
1073
|
+
from .audio import audio_info
|
|
1074
|
+
|
|
1075
|
+
return audio_info(self)
|
|
1076
|
+
|
|
1077
|
+
def get_fragment(self, start: float, end: float) -> "AudioFragment":
|
|
1078
|
+
"""
|
|
1079
|
+
Returns an audio fragment from the specified time range. It does not
|
|
1080
|
+
download the file, neither it actually extracts the fragment. It returns
|
|
1081
|
+
a Model representing the audio fragment, which can be used to read or save
|
|
1082
|
+
it later.
|
|
1083
|
+
|
|
1084
|
+
Args:
|
|
1085
|
+
start (float): The start time of the fragment in seconds.
|
|
1086
|
+
end (float): The end time of the fragment in seconds.
|
|
1087
|
+
|
|
1088
|
+
Returns:
|
|
1089
|
+
AudioFragment: A Model representing the audio fragment.
|
|
1090
|
+
"""
|
|
1091
|
+
if start < 0 or end < 0 or start >= end:
|
|
1092
|
+
raise ValueError(
|
|
1093
|
+
f"Can't get audio fragment for '{self.path}', "
|
|
1094
|
+
f"invalid time range: ({start:.3f}, {end:.3f})"
|
|
1095
|
+
)
|
|
1096
|
+
|
|
1097
|
+
return AudioFragment(audio=self, start=start, end=end)
|
|
1098
|
+
|
|
1099
|
+
def get_fragments(
|
|
1100
|
+
self,
|
|
1101
|
+
duration: float,
|
|
1102
|
+
start: float = 0,
|
|
1103
|
+
end: float | None = None,
|
|
1104
|
+
) -> "Iterator[AudioFragment]":
|
|
1105
|
+
"""
|
|
1106
|
+
Splits the audio into multiple fragments of a specified duration.
|
|
1107
|
+
|
|
1108
|
+
Args:
|
|
1109
|
+
duration (float): The duration of each audio fragment in seconds.
|
|
1110
|
+
start (float): The starting time in seconds (default: 0).
|
|
1111
|
+
end (float, optional): The ending time in seconds. If None, the entire
|
|
1112
|
+
remaining audio is processed (default: None).
|
|
1113
|
+
|
|
1114
|
+
Returns:
|
|
1115
|
+
Iterator[AudioFragment]: An iterator yielding audio fragments.
|
|
1116
|
+
|
|
1117
|
+
Note:
|
|
1118
|
+
If end is not specified, number of samples will be taken from the
|
|
1119
|
+
audio file, this means audio file needs to be downloaded.
|
|
1120
|
+
"""
|
|
1121
|
+
if duration <= 0:
|
|
1122
|
+
raise ValueError("duration must be a positive float")
|
|
1123
|
+
if start < 0:
|
|
1124
|
+
raise ValueError("start must be a non-negative float")
|
|
1125
|
+
|
|
1126
|
+
if end is None:
|
|
1127
|
+
end = self.get_info().duration
|
|
1128
|
+
|
|
1129
|
+
if end < 0:
|
|
1130
|
+
raise ValueError("end must be a non-negative float")
|
|
1131
|
+
if start >= end:
|
|
1132
|
+
raise ValueError("start must be less than end")
|
|
1133
|
+
|
|
1134
|
+
while start < end:
|
|
1135
|
+
yield self.get_fragment(start, min(start + duration, end))
|
|
1136
|
+
start += duration
|
|
1137
|
+
|
|
1138
|
+
def save( # type: ignore[override]
|
|
1139
|
+
self,
|
|
1140
|
+
output: str,
|
|
1141
|
+
format: str | None = None,
|
|
1142
|
+
start: float = 0,
|
|
1143
|
+
end: float | None = None,
|
|
1144
|
+
client_config: dict | None = None,
|
|
1145
|
+
) -> "AudioFile":
|
|
1146
|
+
"""Save audio file or extract fragment to specified format.
|
|
1147
|
+
|
|
1148
|
+
Args:
|
|
1149
|
+
output: Output directory path
|
|
1150
|
+
format: Output format ('wav', 'mp3', etc). Defaults to source format
|
|
1151
|
+
start: Start time in seconds (>= 0). Defaults to 0
|
|
1152
|
+
end: End time in seconds. If None, extracts to end of file
|
|
1153
|
+
client_config: Optional client configuration
|
|
1154
|
+
|
|
1155
|
+
Returns:
|
|
1156
|
+
AudioFile: New audio file with format conversion/extraction applied
|
|
1157
|
+
|
|
1158
|
+
Examples:
|
|
1159
|
+
audio.save("/path", "mp3") # Entire file to MP3
|
|
1160
|
+
audio.save("s3://bucket/path", "wav", start=2.5) # From 2.5s to end as WAV
|
|
1161
|
+
audio.save("/path", "flac", start=1, end=3) # 1-3s fragment as FLAC
|
|
1162
|
+
"""
|
|
1163
|
+
from .audio import save_audio
|
|
1164
|
+
|
|
1165
|
+
return save_audio(self, output, format, start, end)
|
|
1166
|
+
|
|
1167
|
+
|
|
1168
|
+
class AudioFragment(DataModel):
|
|
1169
|
+
"""
|
|
1170
|
+
A data model for representing an audio fragment.
|
|
1171
|
+
|
|
1172
|
+
This model represents a specific fragment within an audio file with defined
|
|
1173
|
+
start and end times. It allows access to individual fragments and provides
|
|
1174
|
+
functionality for reading and saving audio fragments as separate audio files.
|
|
1175
|
+
|
|
1176
|
+
Attributes:
|
|
1177
|
+
audio (AudioFile): The audio file containing the audio fragment.
|
|
1178
|
+
start (float): The starting time of the audio fragment in seconds.
|
|
1179
|
+
end (float): The ending time of the audio fragment in seconds.
|
|
1180
|
+
"""
|
|
1181
|
+
|
|
1182
|
+
audio: AudioFile
|
|
1183
|
+
start: float
|
|
1184
|
+
end: float
|
|
1185
|
+
|
|
1186
|
+
def get_np(self) -> tuple["ndarray", int]:
|
|
1187
|
+
"""
|
|
1188
|
+
Returns the audio fragment as a NumPy array with sample rate.
|
|
1189
|
+
|
|
1190
|
+
Returns:
|
|
1191
|
+
tuple[ndarray, int]: A tuple containing the audio data as a NumPy array
|
|
1192
|
+
and the sample rate.
|
|
1193
|
+
"""
|
|
1194
|
+
from .audio import audio_to_np
|
|
1195
|
+
|
|
1196
|
+
duration = self.end - self.start
|
|
1197
|
+
return audio_to_np(self.audio, self.start, duration)
|
|
1198
|
+
|
|
1199
|
+
def read_bytes(self, format: str = "wav") -> bytes:
|
|
1200
|
+
"""
|
|
1201
|
+
Returns the audio fragment as audio bytes.
|
|
1202
|
+
|
|
1203
|
+
Args:
|
|
1204
|
+
format (str): The desired audio format (e.g., 'wav', 'mp3').
|
|
1205
|
+
Defaults to 'wav'.
|
|
1206
|
+
|
|
1207
|
+
Returns:
|
|
1208
|
+
bytes: The encoded audio fragment as bytes.
|
|
1209
|
+
"""
|
|
1210
|
+
from .audio import audio_to_bytes
|
|
1211
|
+
|
|
1212
|
+
duration = self.end - self.start
|
|
1213
|
+
return audio_to_bytes(self.audio, format, self.start, duration)
|
|
1214
|
+
|
|
1215
|
+
def save(self, output: str, format: str | None = None) -> "AudioFile":
|
|
1216
|
+
"""
|
|
1217
|
+
Saves the audio fragment as a new audio file.
|
|
1218
|
+
|
|
1219
|
+
If `output` is a remote path, the audio file will be uploaded to remote storage.
|
|
1220
|
+
|
|
1221
|
+
Args:
|
|
1222
|
+
output (str): The destination path, which can be a local file path
|
|
1223
|
+
or a remote URL.
|
|
1224
|
+
format (str, optional): The output audio format (e.g., 'wav', 'mp3').
|
|
1225
|
+
If None, the format is inferred from the
|
|
1226
|
+
file extension.
|
|
1227
|
+
|
|
1228
|
+
Returns:
|
|
1229
|
+
AudioFile: A Model representing the saved audio file.
|
|
1230
|
+
"""
|
|
1231
|
+
from .audio import save_audio
|
|
1232
|
+
|
|
1233
|
+
return save_audio(self.audio, output, format, self.start, self.end)
|
|
1234
|
+
|
|
1235
|
+
|
|
751
1236
|
class VideoFrame(DataModel):
|
|
752
1237
|
"""
|
|
753
1238
|
A data model for representing a video frame.
|
|
@@ -830,7 +1315,7 @@ class VideoFragment(DataModel):
|
|
|
830
1315
|
start: float
|
|
831
1316
|
end: float
|
|
832
1317
|
|
|
833
|
-
def save(self, output: str, format:
|
|
1318
|
+
def save(self, output: str, format: str | None = None) -> "VideoFile":
|
|
834
1319
|
"""
|
|
835
1320
|
Saves the video fragment as a new video file.
|
|
836
1321
|
|
|
@@ -878,6 +1363,52 @@ class Video(DataModel):
|
|
|
878
1363
|
codec: str = Field(default="")
|
|
879
1364
|
|
|
880
1365
|
|
|
1366
|
+
class Audio(DataModel):
|
|
1367
|
+
"""
|
|
1368
|
+
A data model representing metadata for an audio file.
|
|
1369
|
+
|
|
1370
|
+
Attributes:
|
|
1371
|
+
sample_rate (int): The sample rate of the audio (samples per second).
|
|
1372
|
+
Defaults to -1 if unknown.
|
|
1373
|
+
channels (int): The number of audio channels. Defaults to -1 if unknown.
|
|
1374
|
+
duration (float): The total duration of the audio in seconds.
|
|
1375
|
+
Defaults to -1.0 if unknown.
|
|
1376
|
+
samples (int): The total number of samples in the audio.
|
|
1377
|
+
Defaults to -1 if unknown.
|
|
1378
|
+
format (str): The format of the audio file (e.g., 'wav', 'mp3').
|
|
1379
|
+
Defaults to an empty string.
|
|
1380
|
+
codec (str): The codec used for encoding the audio. Defaults to an empty string.
|
|
1381
|
+
bit_rate (int): The bit rate of the audio in bits per second.
|
|
1382
|
+
Defaults to -1 if unknown.
|
|
1383
|
+
"""
|
|
1384
|
+
|
|
1385
|
+
sample_rate: int = Field(default=-1)
|
|
1386
|
+
channels: int = Field(default=-1)
|
|
1387
|
+
duration: float = Field(default=-1.0)
|
|
1388
|
+
samples: int = Field(default=-1)
|
|
1389
|
+
format: str = Field(default="")
|
|
1390
|
+
codec: str = Field(default="")
|
|
1391
|
+
bit_rate: int = Field(default=-1)
|
|
1392
|
+
|
|
1393
|
+
@staticmethod
|
|
1394
|
+
def get_channel_name(num_channels: int, channel_idx: int) -> str:
|
|
1395
|
+
"""Map channel index to meaningful name based on common audio formats"""
|
|
1396
|
+
channel_mappings = {
|
|
1397
|
+
1: ["Mono"],
|
|
1398
|
+
2: ["Left", "Right"],
|
|
1399
|
+
4: ["W", "X", "Y", "Z"], # First-order Ambisonics
|
|
1400
|
+
6: ["FL", "FR", "FC", "LFE", "BL", "BR"], # 5.1 surround
|
|
1401
|
+
8: ["FL", "FR", "FC", "LFE", "BL", "BR", "SL", "SR"], # 7.1 surround
|
|
1402
|
+
}
|
|
1403
|
+
|
|
1404
|
+
if num_channels in channel_mappings:
|
|
1405
|
+
channels = channel_mappings[num_channels]
|
|
1406
|
+
if 0 <= channel_idx < len(channels):
|
|
1407
|
+
return channels[channel_idx]
|
|
1408
|
+
|
|
1409
|
+
return f"Ch{channel_idx + 1}"
|
|
1410
|
+
|
|
1411
|
+
|
|
881
1412
|
class ArrowRow(DataModel):
|
|
882
1413
|
"""`DataModel` for reading row from Arrow-supported file."""
|
|
883
1414
|
|
|
@@ -896,7 +1427,7 @@ class ArrowRow(DataModel):
|
|
|
896
1427
|
ds = dataset(path, **self.kwargs)
|
|
897
1428
|
|
|
898
1429
|
else:
|
|
899
|
-
path = self.file.
|
|
1430
|
+
path = self.file.get_fs_path()
|
|
900
1431
|
ds = dataset(path, filesystem=self.file.get_fs(), **self.kwargs)
|
|
901
1432
|
|
|
902
1433
|
return ds.take([self.index]).to_reader()
|
|
@@ -915,5 +1446,7 @@ def get_file_type(type_: FileType = "binary") -> type[File]:
|
|
|
915
1446
|
file = ImageFile # type: ignore[assignment]
|
|
916
1447
|
elif type_ == "video":
|
|
917
1448
|
file = VideoFile
|
|
1449
|
+
elif type_ == "audio":
|
|
1450
|
+
file = AudioFile
|
|
918
1451
|
|
|
919
1452
|
return file
|