datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/lib/utils.py
CHANGED
@@ -1,6 +1,9 @@
+import inspect
 import re
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
+from pathlib import PurePosixPath
+from urllib.parse import urlparse
 
 
 class AbstractUDF(ABC):
@@ -18,13 +21,11 @@ class AbstractUDF(ABC):
 
 
 class DataChainError(Exception):
-    def __init__(self, message):
-        super().__init__(message)
+    pass
 
 
 class DataChainParamsError(DataChainError):
-    def __init__(self, message):
-        super().__init__(message)
+    pass
 
 
 class DataChainColumnError(DataChainParamsError):
@@ -32,6 +33,25 @@ class DataChainColumnError(DataChainParamsError):
         super().__init__(f"Error for column {col_name}: {msg}")
 
 
+def callable_name(obj: object) -> str:
+    """Return a friendly name for a callable or UDF-like instance."""
+    # UDF classes in DataChain inherit from AbstractUDF; prefer class name
+    if isinstance(obj, AbstractUDF):
+        return obj.__class__.__name__
+
+    # Plain functions and bound/unbound methods
+    if inspect.ismethod(obj) or inspect.isfunction(obj):
+        # __name__ exists for functions/methods; includes "<lambda>" for lambdas
+        return obj.__name__  # type: ignore[attr-defined]
+
+    # Generic callable object
+    if callable(obj):
+        return obj.__class__.__name__
+
+    # Fallback for non-callables
+    return str(obj)
+
+
 def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
     """Returns normalized_name -> original_name dict."""
     gen_col_counter = 0
@@ -59,3 +79,97 @@ def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
         new_col_names[generated_column] = org_column
 
     return new_col_names
+
+
+def rebase_path(
+    src_path: str,
+    old_base: str,
+    new_base: str,
+    suffix: str = "",
+    extension: str = "",
+) -> str:
+    """
+    Rebase a file path from one base directory to another.
+
+    Args:
+        src_path: Source file path (can include URI scheme like s3://)
+        old_base: Base directory to remove from src_path
+        new_base: New base directory to prepend
+        suffix: Optional suffix to add before file extension
+        extension: Optional new file extension (without dot)
+
+    Returns:
+        str: Rebased path with new base directory
+
+    Raises:
+        ValueError: If old_base is not found in src_path
+    """
+    # Parse URIs to handle schemes properly
+    src_parsed = urlparse(src_path)
+    old_base_parsed = urlparse(old_base)
+    new_base_parsed = urlparse(new_base)
+
+    # Get the path component (without scheme)
+    if src_parsed.scheme:
+        src_path_only = src_parsed.netloc + src_parsed.path
+    else:
+        src_path_only = src_path
+
+    if old_base_parsed.scheme:
+        old_base_only = old_base_parsed.netloc + old_base_parsed.path
+    else:
+        old_base_only = old_base
+
+    # Normalize paths
+    src_path_norm = PurePosixPath(src_path_only).as_posix()
+    old_base_norm = PurePosixPath(old_base_only).as_posix()
+
+    # Find where old_base appears in src_path
+    if old_base_norm in src_path_norm:
+        # Find the index where old_base appears
+        idx = src_path_norm.find(old_base_norm)
+        if idx == -1:
+            raise ValueError(f"old_base '{old_base}' not found in src_path")
+
+        # Extract the relative path after old_base
+        relative_start = idx + len(old_base_norm)
+        # Skip leading slash if present
+        if relative_start < len(src_path_norm) and src_path_norm[relative_start] == "/":
+            relative_start += 1
+        relative_path = src_path_norm[relative_start:]
+    else:
+        raise ValueError(f"old_base '{old_base}' not found in src_path")
+
+    # Parse the filename
+    path_obj = PurePosixPath(relative_path)
+    stem = path_obj.stem
+    current_ext = path_obj.suffix
+
+    # Apply suffix and extension changes
+    new_stem = stem + suffix if suffix else stem
+    if extension:
+        new_ext = f".{extension}"
+    elif current_ext:
+        new_ext = current_ext
+    else:
+        new_ext = ""
+
+    # Build new filename
+    new_name = new_stem + new_ext
+
+    # Reconstruct path with new base
+    parent = str(path_obj.parent)
+    if parent == ".":
+        new_relative_path = new_name
+    else:
+        new_relative_path = str(PurePosixPath(parent) / new_name)
+
+    # Handle new_base URI scheme
+    if new_base_parsed.scheme:
+        # Has scheme like s3://
+        base_path = new_base_parsed.netloc + new_base_parsed.path
+        base_path = PurePosixPath(base_path).as_posix()
+        full_path = str(PurePosixPath(base_path) / new_relative_path)
+        return f"{new_base_parsed.scheme}://{full_path}"
+    # Regular path
+    return str(PurePosixPath(new_base) / new_relative_path)
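Taken together, this file gains two helpers. A quick usage sketch of both, based only on the docstrings and code above (the expected values are inferred from them, not copied from the package's tests):

    from datachain.lib.utils import callable_name, rebase_path

    def embed(file):
        ...

    # Functions and methods report their __name__ ("<lambda>" for lambdas);
    # callable instances report the class name.
    assert callable_name(embed) == "embed"

    # Rebase an s3:// source path onto a local directory, tagging the stem
    # and swapping the extension in one call.
    out = rebase_path(
        "s3://bucket/raw/images/cat.jpg",
        old_base="raw/images",
        new_base="/tmp/resized",
        suffix="_small",
        extension="webp",
    )
    assert out == "/tmp/resized/cat_small.webp"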
datachain/lib/video.py
CHANGED
@@ -1,7 +1,6 @@
 import posixpath
 import shutil
 import tempfile
-from typing import Optional, Union
 
 from numpy import ndarray
 
@@ -18,7 +17,7 @@ except ImportError as exc:
     ) from exc
 
 
-def video_info(file: Union[File, VideoFile]) -> Video:
+def video_info(file: File | VideoFile) -> Video:
     """
     Returns video file information.
 
@@ -34,21 +33,27 @@ def video_info(file: Union[File, VideoFile]) -> Video:
     file.ensure_cached()
     file_path = file.get_local_path()
     if not file_path:
-        raise FileError(
+        raise FileError("unable to download video file", file.source, file.path)
 
     try:
        probe = ffmpeg.probe(file_path)
    except Exception as exc:
-        raise FileError(
+        raise FileError(
+            "unable to extract metadata from video file", file.source, file.path
+        ) from exc
 
     all_streams = probe.get("streams")
     video_format = probe.get("format")
     if not all_streams or not video_format:
-        raise FileError(
+        raise FileError(
+            "unable to extract metadata from video file", file.source, file.path
+        )
 
     video_streams = [s for s in all_streams if s["codec_type"] == "video"]
     if len(video_streams) == 0:
-        raise FileError(
+        raise FileError(
+            "unable to extract metadata from video file", file.source, file.path
+        )
 
     video_stream = video_streams[0]
 
@@ -102,7 +107,7 @@ def video_frame_np(video: VideoFile, frame: int) -> ndarray:
 def validate_frame_range(
     video: VideoFile,
     start: int = 0,
-    end: Optional[int] = None,
+    end: int | None = None,
     step: int = 1,
 ) -> tuple[int, int, int]:
     """
@@ -180,7 +185,7 @@ def save_video_fragment(
     start: float,
     end: float,
     output: str,
-    format: Optional[str] = None,
+    format: str | None = None,
 ) -> VideoFile:
     """
     Saves video interval as a new video file. If output is a remote path,
@@ -199,7 +204,10 @@ def save_video_fragment(
         VideoFile: Video fragment model.
     """
     if start < 0 or end < 0 or start >= end:
-        raise ValueError(
+        raise ValueError(
+            f"Can't save video fragment for '{video.path}', "
+            f"invalid time range: ({start:.3f}, {end:.3f})"
+        )
 
     if format is None:
         format = video.get_file_ext()
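The changes here follow the release-wide migration from typing.Optional/Union to PEP 604 unions, and FileError is now raised with an explicit message plus the file's source and path. A hedged sketch of a caller using that, assuming FileError is still importable from datachain.lib.file and that the Video model exposes duration/fps fields (both assumptions, not confirmed by this diff):

    from datachain.lib.file import FileError, VideoFile
    from datachain.lib.video import video_info

    # Hypothetical file reference; the source/path values are illustrative.
    video = VideoFile(source="s3://bucket", path="clips/intro.mp4")
    try:
        info = video_info(video)
        print(info.duration, info.fps)  # assumed Video fields
    except FileError as exc:
        # The (message, source, path) arguments make the failing object
        # identifiable from the exception text alone.
        print(f"skipping {video.path}: {exc}")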
datachain/lib/webdataset.py
CHANGED
@@ -1,20 +1,13 @@
-import json
 import tarfile
+import types
 import warnings
-from collections.abc import Iterator, Sequence
+from collections.abc import Callable, Iterator, Sequence
 from pathlib import Path
-from typing import (
-    Any,
-    Callable,
-    ClassVar,
-    Optional,
-    Union,
-    get_args,
-    get_origin,
-)
+from typing import Any, ClassVar, Union, get_args, get_origin
 
 from pydantic import Field
 
+from datachain import json
 from datachain.lib.data_model import DataModel
 from datachain.lib.file import File
 from datachain.lib.tar import build_tar_member
@@ -34,29 +27,29 @@ warnings.filterwarnings(
 
 
 class WDSError(DataChainError):
-    def __init__(self,
-        super().__init__(f"WebDataset error '{
+    def __init__(self, tar_name: str, message: str):
+        super().__init__(f"WebDataset error '{tar_name}': {message}")
 
 
 class CoreFileDuplicationError(WDSError):
-    def __init__(self,
+    def __init__(self, tar_name: str, file1: str, file2: str):
         super().__init__(
-
+            tar_name, f"duplication of files with core extensions: {file1}, {file2}"
         )
 
 
 class CoreFileNotFoundError(WDSError):
-    def __init__(self,
+    def __init__(self, tar_name: str, extensions: Sequence[str], stem: str):
         super().__init__(
-
+            tar_name,
             f"no files with the extensions '{','.join(extensions)}'"
             f" were found for file stem {stem}",
         )
 
 
 class UnknownFileExtensionError(WDSError):
-    def __init__(self,
-        super().__init__(
+    def __init__(self, tar_name, name: str, ext: str):
+        super().__init__(tar_name, f"unknown extension '{ext}' for file '{name}'")
 
 
 class WDSBasic(DataModel):
@@ -64,28 +57,28 @@ class WDSBasic(DataModel):
 
 
 class WDSAllFile(WDSBasic):
-    txt: Optional[str] = Field(default=None)
-    text: Optional[str] = Field(default=None)
-    cap: Optional[str] = Field(default=None)
-    transcript: Optional[str] = Field(default=None)
-    cls: Optional[int] = Field(default=None)
-    cls2: Optional[int] = Field(default=None)
-    index: Optional[int] = Field(default=None)
-    inx: Optional[int] = Field(default=None)
-    id: Optional[int] = Field(default=None)
-    json: Optional[dict] = Field(default=None)  # type: ignore[assignment]
-    jsn: Optional[dict] = Field(default=None)
-
-    pyd: Optional[bytes] = Field(default=None)
-    pickle: Optional[bytes] = Field(default=None)
-    pth: Optional[bytes] = Field(default=None)
-    ten: Optional[bytes] = Field(default=None)
-    tb: Optional[bytes] = Field(default=None)
-    mp: Optional[bytes] = Field(default=None)
-    msg: Optional[bytes] = Field(default=None)
-    npy: Optional[bytes] = Field(default=None)
-    npz: Optional[bytes] = Field(default=None)
-    cbor: Optional[bytes] = Field(default=None)
+    txt: str | None = Field(default=None)
+    text: str | None = Field(default=None)
+    cap: str | None = Field(default=None)
+    transcript: str | None = Field(default=None)
+    cls: int | None = Field(default=None)
+    cls2: int | None = Field(default=None)
+    index: int | None = Field(default=None)
+    inx: int | None = Field(default=None)
+    id: int | None = Field(default=None)
+    json: dict | None = Field(default=None)  # type: ignore[assignment]
+    jsn: dict | None = Field(default=None)
+
+    pyd: bytes | None = Field(default=None)
+    pickle: bytes | None = Field(default=None)
+    pth: bytes | None = Field(default=None)
+    ten: bytes | None = Field(default=None)
+    tb: bytes | None = Field(default=None)
+    mp: bytes | None = Field(default=None)
+    msg: bytes | None = Field(default=None)
+    npy: bytes | None = Field(default=None)
+    npz: bytes | None = Field(default=None)
+    cbor: bytes | None = Field(default=None)
 
 
 class WDSReadableSubclass(DataModel):
@@ -113,10 +106,10 @@ class Builder:
     def __init__(
         self,
         tar_stream: File,
-        core_extensions:
+        core_extensions: Sequence[str],
         wds_class: type[WDSBasic],
-        tar,
-        encoding="utf-8",
+        tar: tarfile.TarFile,
+        encoding: str = "utf-8",
     ):
         self._core_extensions = core_extensions
         self._tar_stream = tar_stream
@@ -145,18 +138,20 @@ class Builder:
         if ext in self._core_extensions:
             if self.state.core_file is not None:
                 raise CoreFileDuplicationError(
-                    self._tar_stream, file.name, self.state.core_file.name
+                    self._tar_stream.name, file.name, self.state.core_file.name
                 )
             self.state.core_file = file
         elif ext in self.state.data:
             raise WDSError(
-                self._tar_stream,
+                self._tar_stream.name,
                 f"file with extension '.{ext}' already exists in the archive",
             )
         else:
             type_ = self._get_type(ext)
             if type_ is None:
-                raise UnknownFileExtensionError(
+                raise UnknownFileExtensionError(
+                    self._tar_stream.name, fstream.name, ext
+                )
 
             if issubclass(type_, WDSReadableSubclass):
                 reader = type_._reader
@@ -165,7 +160,7 @@ class Builder:
 
             if reader is None:
                 raise WDSError(
-                    self._tar_stream,
+                    self._tar_stream.name,
                     f"unable to find a reader for type {type_}, extension .{ext}",
                 )
             self.state.data[ext] = reader(self, file)
@@ -173,7 +168,7 @@ class Builder:
     def produce(self):
         if self.state.core_file is None:
             raise CoreFileNotFoundError(
-                self._tar_stream, self._core_extensions, self.state.stem
+                self._tar_stream.name, self._core_extensions, self.state.stem
             )
 
         file = build_tar_member(self._tar_stream, self.state.core_file)
@@ -187,14 +182,22 @@ class Builder:
             return
 
         anno = field.annotation
-
-
-
+        anno_origin = get_origin(anno)
+        if anno_origin in (Union, types.UnionType):
+            anno_args = get_args(anno)
+            if len(anno_args) == 2 and type(None) in anno_args:
+                return anno_args[0] if anno_args[1] is type(None) else anno_args[1]
 
         return anno
 
 
-def get_tar_groups(stream, tar, core_extensions, spec, encoding="utf-8"):
+def get_tar_groups(
+    stream: File,
+    tar: tarfile.TarFile,
+    core_extensions: Sequence[str],
+    spec: type[WDSBasic],
+    encoding: str = "utf-8",
+) -> Iterator[WDSBasic]:
     builder = Builder(stream, core_extensions, spec, tar, encoding)
 
     for item in sorted(tar.getmembers(), key=lambda m: Path(m.name).stem):
@@ -210,9 +213,11 @@ def get_tar_groups(stream, tar, core_extensions, spec, encoding="utf-8"):
 
 
 def process_webdataset(
-    core_extensions: Sequence[str] = ("jpg", "png"),
-
-
+    core_extensions: Sequence[str] = ("jpg", "png"),
+    spec: type[WDSBasic] = WDSAllFile,
+    encoding: str = "utf-8",
+) -> Callable[[File], Iterator]:
+    def wds_func(file: File) -> Iterator[spec]:  # type: ignore[valid-type]
        with file.open() as fd:
            with tarfile.open(fileobj=fd) as tar:
                yield from get_tar_groups(file, tar, core_extensions, spec, encoding)
datachain/lib/webdataset_laion.py
CHANGED
@@ -1,6 +1,5 @@
 import warnings
 from collections.abc import Iterator
-from typing import Optional
 
 import numpy as np
 from pydantic import BaseModel, Field
@@ -23,18 +22,18 @@ warnings.filterwarnings(
 
 class Laion(WDSReadableSubclass):
     uid: str = Field(default="")
-    face_bboxes: Optional[list[list[float]]] = Field(default=None)
-    caption: Optional[str] = Field(default=None)
-    url: Optional[str] = Field(default=None)
-    key: Optional[str] = Field(default=None)
-    status: Optional[str] = Field(default=None)
-    error_message: Optional[str] = Field(default=None)
-    width: Optional[int] = Field(default=None)
-    height: Optional[int] = Field(default=None)
-    original_width: Optional[int] = Field(default=None)
-    original_height: Optional[int] = Field(default=None)
-    exif: Optional[str] = Field(default=None)
-    sha256: Optional[str] = Field(default=None)
+    face_bboxes: list[list[float]] | None = Field(default=None)
+    caption: str | None = Field(default=None)
+    url: str | None = Field(default=None)
+    key: str | None = Field(default=None)
+    status: str | None = Field(default=None)
+    error_message: str | None = Field(default=None)
+    width: int | None = Field(default=None)
+    height: int | None = Field(default=None)
+    original_width: int | None = Field(default=None)
+    original_height: int | None = Field(default=None)
+    exif: str | None = Field(default=None)
+    sha256: str | None = Field(default=None)
 
     @staticmethod
     def _reader(builder, item):
@@ -42,13 +41,13 @@ class Laion(WDSReadableSubclass):
 
 
 class WDSLaion(WDSBasic):
-    txt: Optional[str] = Field(default=None)
-    json: Laion  # type: ignore[assignment]
+    txt: str | None = Field(default=None)
+    json: Laion = Field(default_factory=Laion)  # type: ignore[assignment]
 
 
 class LaionMeta(BaseModel):
     file: File
-    index: Optional[int] = Field(default=None)
+    index: int | None = Field(default=None)
     b32_img: list[float] = Field(default=[])
     b32_txt: list[float] = Field(default=[])
     l14_img: list[float] = Field(default=[])
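The _get_type rewrite in webdataset.py above exists because Optional[str] and the PEP 604 spelling str | None have different runtime origins (typing.Union vs types.UnionType), so the field models in both files can now use either form. A standalone sketch of the same unwrapping logic:

    # typing.Optional[X] and the PEP 604 spelling X | None have different
    # runtime origins, so both must be checked before unwrapping.
    import types
    from typing import Optional, Union, get_args, get_origin

    def unwrap_optional(anno):
        origin = get_origin(anno)
        if origin in (Union, types.UnionType):
            args = get_args(anno)
            if len(args) == 2 and type(None) in args:
                return args[0] if args[1] is type(None) else args[1]
        return anno

    assert get_origin(Optional[str]) is Union          # typing-style union
    assert get_origin(str | None) is types.UnionType   # PEP 604 union
    assert unwrap_optional(Optional[str]) is str
    assert unwrap_optional(str | None) is str
    assert unwrap_optional(int) is int                  # non-unions pass through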
datachain/listing.py
CHANGED
@@ -2,7 +2,7 @@ import glob
 import os
 from collections.abc import Iterable, Iterator
 from functools import cached_property
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 
 from sqlalchemy import Column
 from sqlalchemy.sql import func
@@ -25,16 +25,17 @@ class Listing:
         metastore: "AbstractMetastore",
         warehouse: "AbstractWarehouse",
         client: "Client",
-        dataset_name:
-        dataset_version:
-
+        dataset_name: str | None = None,
+        dataset_version: str | None = None,
+        column: str = "file",
     ):
         self.metastore = metastore
         self.warehouse = warehouse
         self.client = client
         self.dataset_name = dataset_name  # dataset representing bucket listing
         self.dataset_version = dataset_version  # dataset representing bucket listing
-        self.
+        self.column = column
+        self._closed = False
 
     def clone(self) -> "Listing":
         return self.__class__(
@@ -43,7 +44,7 @@ class Listing:
             self.client,
             self.dataset_name,
             self.dataset_version,
-            self.
+            self.column,
         )
 
     def __enter__(self) -> "Listing":
@@ -53,7 +54,13 @@ class Listing:
         self.close()
 
     def close(self) -> None:
-        self.
+        if self._closed:
+            return
+        self._closed = True
+        try:
+            self.warehouse.close_on_exit()
+        finally:
+            self.metastore.close_on_exit()
 
     @property
     def uri(self):
@@ -66,7 +73,12 @@ class Listing:
     @cached_property
     def dataset(self) -> "DatasetRecord":
         assert self.dataset_name
-
+        project = self.metastore.listing_project
+        return self.metastore.get_dataset(
+            self.dataset_name,
+            namespace_name=project.namespace.name,
+            project_name=project.name,
+        )
 
     @cached_property
     def dataset_rows(self):
@@ -74,7 +86,7 @@ class Listing:
         return self.warehouse.dataset_rows(
             dataset,
             self.dataset_version or dataset.latest_version,
-
+            column=self.column,
         )
 
     def expand_path(self, path, use_glob=True) -> list[Node]:
@@ -97,7 +109,7 @@ class Listing:
     def collect_nodes_to_instantiate(
         self,
         sources: Iterable["DataSource"],
-        copy_to_filename: Optional[str],
+        copy_to_filename: str | None,
         recursive=False,
         copy_dir_contents=False,
         from_dataset=False,
datachain/model/bbox.py
CHANGED
@@ -198,7 +198,9 @@ class BBox(DataModel):
     def pose_inside(self, pose: Union["Pose", "Pose3D"]) -> bool:
         """Return True if the pose is inside the bounding box."""
         return all(
-            self.point_inside(x, y) for x, y in zip(pose.x, pose.y)
+            self.point_inside(x, y)
+            for x, y in zip(pose.x, pose.y, strict=False)
+            if x > 0 or y > 0
         )
 
     @staticmethod
datachain/model/ultralytics/bbox.py
CHANGED
@@ -31,11 +31,11 @@ class YoloBBox(DataModel):
         if not summary:
             return YoloBBox(box=BBox())
         name = summary[0].get("name", "")
-
-
-
-
-
+        if summary[0].get("box"):
+            assert isinstance(summary[0]["box"], dict)
+            box = BBox.from_dict(summary[0]["box"], title=name)
+        else:
+            box = BBox()
         return YoloBBox(
             cls=summary[0]["class"],
             name=name,
@@ -69,7 +69,9 @@ class YoloBBoxes(DataModel):
             cls.append(s["class"])
             names.append(name)
             confidence.append(s["confidence"])
-
+            if s.get("box"):
+                assert isinstance(s["box"], dict)
+                box.append(BBox.from_dict(s["box"], title=name))
         return YoloBBoxes(
             cls=cls,
             name=names,
@@ -100,11 +102,11 @@ class YoloOBBox(DataModel):
         if not summary:
             return YoloOBBox(box=OBBox())
         name = summary[0].get("name", "")
-
-
-
-
-
+        if summary[0].get("box"):
+            assert isinstance(summary[0]["box"], dict)
+            box = OBBox.from_dict(summary[0]["box"], title=name)
+        else:
+            box = OBBox()
         return YoloOBBox(
             cls=summary[0]["class"],
             name=name,
@@ -138,7 +140,9 @@ class YoloOBBoxes(DataModel):
             cls.append(s["class"])
             names.append(name)
             confidence.append(s["confidence"])
-
+            if s.get("box"):
+                assert isinstance(s["box"], dict)
+                box.append(OBBox.from_dict(s["box"], title=name))
         return YoloOBBoxes(
             cls=cls,
             name=names,
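All four YOLO models now guard the "box" key before parsing, since an Ultralytics results summary can omit it for some tasks. A sketch of the guarded parse with a hand-written summary list shaped like that output (the dicts below are made up for illustration, and the box-dict key names are assumptions about what BBox.from_dict accepts):

    from datachain.model.bbox import BBox

    def parse_first_box(summary: list[dict]) -> BBox:
        # Mirrors the guarded pattern in the diff above.
        if not summary:
            return BBox()
        name = summary[0].get("name", "")
        if summary[0].get("box"):
            assert isinstance(summary[0]["box"], dict)
            return BBox.from_dict(summary[0]["box"], title=name)
        return BBox()  # entries without a "box" key no longer raise KeyError

    with_box = [{"name": "cat", "class": 0, "confidence": 0.9,
                 "box": {"x1": 10.0, "y1": 20.0, "x2": 110.0, "y2": 220.0}}]
    without_box = [{"name": "cat", "class": 0, "confidence": 0.9}]

    print(parse_first_box(with_box))     # populated BBox
    print(parse_first_box(without_box))  # empty fallback BBox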