datachain 0.8.7__py3-none-any.whl → 0.8.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of datachain might be problematic.
- datachain/cache.py +3 -3
- datachain/catalog/catalog.py +1 -1
- datachain/cli/__init__.py +12 -4
- datachain/cli/commands/datasets.py +2 -3
- datachain/cli/parser/__init__.py +51 -69
- datachain/cli/parser/job.py +20 -25
- datachain/cli/parser/studio.py +22 -46
- datachain/cli/parser/utils.py +1 -1
- datachain/client/azure.py +1 -1
- datachain/client/fsspec.py +1 -1
- datachain/client/gcs.py +1 -1
- datachain/client/local.py +1 -1
- datachain/client/s3.py +1 -1
- datachain/data_storage/sqlite.py +1 -1
- datachain/data_storage/warehouse.py +1 -1
- datachain/lib/arrow.py +2 -2
- datachain/lib/convert/unflatten.py +1 -2
- datachain/lib/dc.py +38 -11
- datachain/lib/file.py +27 -4
- datachain/lib/hf.py +1 -1
- datachain/lib/listing.py +4 -4
- datachain/lib/pytorch.py +3 -1
- datachain/lib/udf.py +56 -20
- datachain/listing.py +1 -1
- datachain/model/bbox.py +9 -9
- datachain/model/pose.py +9 -9
- datachain/model/segment.py +6 -6
- datachain/progress.py +0 -133
- datachain/query/dataset.py +19 -12
- datachain/studio.py +15 -9
- {datachain-0.8.7.dist-info → datachain-0.8.9.dist-info}/METADATA +4 -3
- {datachain-0.8.7.dist-info → datachain-0.8.9.dist-info}/RECORD +36 -36
- {datachain-0.8.7.dist-info → datachain-0.8.9.dist-info}/LICENSE +0 -0
- {datachain-0.8.7.dist-info → datachain-0.8.9.dist-info}/WHEEL +0 -0
- {datachain-0.8.7.dist-info → datachain-0.8.9.dist-info}/entry_points.txt +0 -0
- {datachain-0.8.7.dist-info → datachain-0.8.9.dist-info}/top_level.txt +0 -0
datachain/cli/parser/studio.py
CHANGED
@@ -1,10 +1,8 @@
 def add_studio_parser(subparsers, parent_parser) -> None:
-    studio_help = "
+    studio_help = "Manage Studio authentication"
     studio_description = (
-        "
-        "
-        "DataChain will utilize it for seamlessly sharing datasets\n"
-        "and using Studio features from CLI"
+        "Manage authentication and settings for Studio. "
+        "Configure tokens for sharing datasets and using Studio features."
     )

     studio_parser = subparsers.add_parser(
@@ -15,14 +13,13 @@ def add_studio_parser(subparsers, parent_parser) -> None:
     )
     studio_subparser = studio_parser.add_subparsers(
         dest="cmd",
-        help="Use `
-        required=True,
+        help="Use `datachain studio CMD --help` to display command-specific help",
     )

-    studio_login_help = "Authenticate
+    studio_login_help = "Authenticate with Studio"
     studio_login_description = (
-        "
-        "
+        "Authenticate with Studio using default scopes. "
+        "A random name will be assigned as the token name if not specified."
     )
     login_parser = studio_subparser.add_parser(
         "login",
@@ -36,14 +33,14 @@ def add_studio_parser(subparsers, parent_parser) -> None:
         "--hostname",
         action="store",
         default=None,
-        help="
+        help="Hostname of the Studio instance",
     )
     login_parser.add_argument(
         "-s",
         "--scopes",
         action="store",
         default=None,
-        help="
+        help="Authentication token scopes",
     )

     login_parser.add_argument(
@@ -51,21 +48,20 @@ def add_studio_parser(subparsers, parent_parser) -> None:
         "--name",
         action="store",
         default=None,
-        help="
-        "identify token shown in Studio profile.",
+        help="Authentication token name (shown in Studio profile)",
     )

     login_parser.add_argument(
         "--no-open",
         action="store_true",
         default=False,
-        help="Use
-        "You will be presented with user code to enter in browser.\n"
-        "DataChain will also use this if it cannot launch browser on your behalf.",
+        help="Use code-based authentication without browser",
     )

-    studio_logout_help = "
-    studio_logout_description =
+    studio_logout_help = "Log out from Studio"
+    studio_logout_description = (
+        "Remove the Studio authentication token from global config."
+    )

     studio_subparser.add_parser(
         "logout",
@@ -74,10 +70,8 @@ def add_studio_parser(subparsers, parent_parser) -> None:
         help=studio_logout_help,
     )

-    studio_team_help = "Set
-    studio_team_description =
-        "Set the default team for DataChain to use when interacting with Studio."
-    )
+    studio_team_help = "Set default team for Studio operations"
+    studio_team_description = "Set the default team for Studio operations."

     team_parser = studio_subparser.add_parser(
         "team",
@@ -88,39 +82,21 @@ def add_studio_parser(subparsers, parent_parser) -> None:
     team_parser.add_argument(
         "team_name",
         action="store",
-        help="
+        help="Name of the team to set as default",
     )
     team_parser.add_argument(
         "--global",
         action="store_true",
         default=False,
-        help="Set
+        help="Set team globally for all projects",
     )

-    studio_token_help = "View
+    studio_token_help = "View Studio authentication token"  # noqa: S105
+    studio_token_description = "Display the current authentication token for Studio."  # noqa: S105

     studio_subparser.add_parser(
         "token",
         parents=[parent_parser],
-        description=
+        description=studio_token_description,
         help=studio_token_help,
     )
-
-    studio_ls_dataset_help = "List the available datasets from Studio"
-    studio_ls_dataset_description = (
-        "This command lists all the datasets available in Studio.\n"
-        "It will show the dataset name and the number of versions available."
-    )
-
-    ls_dataset_parser = studio_subparser.add_parser(
-        "dataset",
-        parents=[parent_parser],
-        description=studio_ls_dataset_description,
-        help=studio_ls_dataset_help,
-    )
-    ls_dataset_parser.add_argument(
-        "--team",
-        action="store",
-        default=None,
-        help="The team to list datasets for. By default, it will use team from config.",
-    )
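The reworked parser keeps the login, logout, team, and token subcommands and drops the Studio dataset listing subcommand. A minimal sketch of how `add_studio_parser` plugs into an argparse-based entry point follows; the top-level parser setup and the bare `parent_parser` are assumptions for illustration, not the actual datachain CLI wiring.

    import argparse

    from datachain.cli.parser.studio import add_studio_parser

    # Assumed stand-ins for the real CLI objects.
    parent_parser = argparse.ArgumentParser(add_help=False)
    parser = argparse.ArgumentParser(prog="datachain")
    subparsers = parser.add_subparsers(dest="command")

    add_studio_parser(subparsers, parent_parser)

    # `required=True` was dropped from the studio subparsers, so both of these parse:
    args = parser.parse_args(["studio", "login", "--no-open"])
    print(args.cmd, args.no_open)  # -> login True
    args = parser.parse_args(["studio"])
    print(args.cmd)  # -> None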
datachain/cli/parser/utils.py
CHANGED
datachain/client/azure.py
CHANGED
datachain/client/fsspec.py
CHANGED
@@ -23,7 +23,7 @@ from botocore.exceptions import ClientError
 from dvc_objects.fs.system import reflink
 from fsspec.asyn import get_loop, sync
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
-from tqdm import tqdm
+from tqdm.auto import tqdm

 from datachain.cache import DataChainCache
 from datachain.client.fileslice import FileWrapper
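Several modules in this release switch the import from `tqdm` to `tqdm.auto`, which selects the notebook-friendly progress bar when running under Jupyter and falls back to the plain console bar otherwise. A small standalone illustration of the behaviour being relied on:

    from time import sleep

    from tqdm.auto import tqdm  # picks the right frontend automatically

    for _ in tqdm(range(5), desc="download", unit="file"):
        sleep(0.1)  # stand-in for real work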
datachain/client/gcs.py
CHANGED
datachain/client/local.py
CHANGED
@@ -38,7 +38,7 @@ class FileClient(Client):
     def get_uri(cls, name: str) -> "StorageURI":
         from datachain.dataset import StorageURI

-        return StorageURI(f
+        return StorageURI(f"{cls.PREFIX}/{name.removeprefix('/')}")

     @classmethod
     def ls_buckets(cls, **kwargs):
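The new `get_uri` line builds the URI with `str.removeprefix("/")`, which strips at most one leading slash (unlike `lstrip("/")`, which would drop all of them). A quick illustration; the `file://` value is an assumption about what `FileClient.PREFIX` holds:

    PREFIX = "file://"  # assumed value of FileClient.PREFIX

    name = "/home/user/data"
    print(f"{PREFIX}/{name.removeprefix('/')}")            # file:///home/user/data
    print(f"{PREFIX}/{'//share/data'.removeprefix('/')}")  # only one leading slash removed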
datachain/client/s3.py
CHANGED
datachain/data_storage/sqlite.py
CHANGED
@@ -21,7 +21,7 @@ from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
 from sqlalchemy.sql import func
 from sqlalchemy.sql.expression import bindparam, cast
 from sqlalchemy.sql.selectable import Select
-from tqdm import tqdm
+from tqdm.auto import tqdm

 import datachain.sql.sqlite
 from datachain.data_storage import AbstractDBMetastore, AbstractWarehouse
datachain/data_storage/warehouse.py
CHANGED
@@ -14,7 +14,7 @@ import sqlalchemy as sa
 from sqlalchemy import Table, case, select
 from sqlalchemy.sql import func
 from sqlalchemy.sql.expression import true
-from tqdm import tqdm
+from tqdm.auto import tqdm

 from datachain.client import Client
 from datachain.data_storage.schema import convert_rows_custom_column_types
datachain/lib/arrow.py
CHANGED
@@ -7,7 +7,7 @@ import orjson
 import pyarrow as pa
 from fsspec.core import split_protocol
 from pyarrow.dataset import CsvFileFormat, dataset
-from tqdm import tqdm
+from tqdm.auto import tqdm

 from datachain.lib.data_model import dict_to_data_model
 from datachain.lib.file import ArrowRow, File
@@ -33,7 +33,7 @@ class ReferenceFileSystem(fsspec.implementations.reference.ReferenceFileSystem):
         # reads the whole file in-memory.
         (uri,) = self.references[path]
         protocol, _ = split_protocol(uri)
-        return self.fss[protocol].
+        return self.fss[protocol].open(uri, mode, *args, **kwargs)


 class ArrowGenerator(Generator):
datachain/lib/convert/unflatten.py
CHANGED
@@ -35,8 +35,7 @@ def unflatten_to_json_pos(
     def _normalize(name: str) -> str:
         if DEFAULT_DELIMITER in name:
             raise RuntimeError(
-                f"variable '{name}' cannot be used "
-                f"because it contains {DEFAULT_DELIMITER}"
+                f"variable '{name}' cannot be used because it contains {DEFAULT_DELIMITER}"
             )
         return _to_snake_case(name)

datachain/lib/dc.py
CHANGED
@@ -11,6 +11,7 @@ from typing import (
     BinaryIO,
     Callable,
     ClassVar,
+    Literal,
     Optional,
     TypeVar,
     Union,
@@ -1276,7 +1277,12 @@ class DataChain:
             yield ret[0] if len(cols) == 1 else tuple(ret)

     def to_pytorch(
-        self,
+        self,
+        transform=None,
+        tokenizer=None,
+        tokenizer_kwargs=None,
+        num_samples=0,
+        remove_prefetched: bool = False,
     ):
         """Convert to pytorch dataset format.

@@ -1286,6 +1292,7 @@ class DataChain:
             tokenizer_kwargs (dict): Additional kwargs to pass when calling tokenizer.
             num_samples (int): Number of random samples to draw for each epoch.
                 This argument is ignored if `num_samples=0` (the default).
+            remove_prefetched (bool): Whether to remove prefetched files after reading.

         Example:
             ```py
@@ -1312,6 +1319,7 @@ class DataChain:
             tokenizer_kwargs=tokenizer_kwargs,
             num_samples=num_samples,
             dc_settings=chain._settings,
+            remove_prefetched=remove_prefetched,
         )

     def remove_file_signals(self) -> "Self":  # noqa: D102
@@ -1330,19 +1338,27 @@ class DataChain:

         Parameters:
             right_ds: Chain to join with.
-            on: Predicate
-
-
-
-
+            on: Predicate ("column.name", C("column.name"), or Func) or list of
+                Predicates to join on. If both chains have the same predicates then
+                this predicate is enough for the join. Otherwise, `right_on` parameter
+                has to specify the predicates for the other chain.
+            right_on: Optional predicate or list of Predicates for the `right_ds`
+                to join.
             inner (bool): Whether to run inner join or outer join.
-            rname (str):
+            rname (str): Name prefix for conflicting signal names.

-
+        Examples:
             ```py
             meta = meta_emd.merge(meta_pq, on=(C.name, C.emd__index),
                                   right_on=(C.name, C.pq__index))
             ```
+
+            ```py
+            imgs.merge(captions,
+                       on=func.path.file_stem(imgs.c("file.path")),
+                       right_on=func.path.file_stem(captions.c("file.path"))
+            ```
+            )
         """
         if on is None:
             raise DatasetMergeError(["None"], None, "'on' must be specified")
@@ -2407,11 +2423,22 @@
     def export_files(
         self,
         output: str,
-        signal="file",
+        signal: str = "file",
         placement: FileExportPlacement = "fullpath",
         use_cache: bool = True,
+        link_type: Literal["copy", "symlink"] = "copy",
     ) -> None:
-        """
+        """Export files from a specified signal to a directory.
+
+        Args:
+            output: Path to the target directory for exporting files.
+            signal: Name of the signal to export files from.
+            placement: The method to use for naming exported files.
+                The possible values are: "filename", "etag", "fullpath", and "checksum".
+            use_cache: If `True`, cache the files before exporting.
+            link_type: Method to use for exporting files.
+                Falls back to `'copy'` if symlinking fails.
+        """
         if placement == "filename" and (
             self._query.distinct(pathfunc.name(C(f"{signal}__path"))).count()
             != self._query.count()
@@ -2419,7 +2446,7 @@
             raise ValueError("Files with the same name found")

         for file in self.collect(signal):
-            file.export(output, placement, use_cache)  # type: ignore[union-attr]
+            file.export(output, placement, use_cache, link_type=link_type)  # type: ignore[union-attr]

     def shuffle(self) -> "Self":
         """Shuffle the rows of the chain deterministically."""
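Taken together, the dc.py changes expand the `merge` docstring, add `remove_prefetched` to `to_pytorch`, and add a `link_type` option to `export_files`. A hedged usage sketch follows; the storage URIs are placeholders and the import paths for `DataChain` and `func` are assumed to match the documented examples.

    from datachain import func
    from datachain.lib.dc import DataChain

    imgs = DataChain.from_storage("s3://example-bucket/images/")        # placeholder URI
    captions = DataChain.from_storage("s3://example-bucket/captions/")  # placeholder URI

    # Join on the file stem, mirroring the second docstring example above.
    pairs = imgs.merge(
        captions,
        on=func.path.file_stem(imgs.c("file.path")),
        right_on=func.path.file_stem(captions.c("file.path")),
    )

    # Export the matched images as symlinks; export_files falls back to
    # copying when symlinking fails.
    pairs.export_files("exported/", signal="file", link_type="symlink")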
datachain/lib/file.py
CHANGED
@@ -1,3 +1,4 @@
+import errno
 import hashlib
 import io
 import json
@@ -76,18 +77,18 @@ class TarVFile(VFile):
     def open(cls, file: "File", location: list[dict]):
         """Stream file from tar archive based on location in archive."""
         if len(location) > 1:
-            VFileError(file, "multiple 'location's are not supported yet")
+            raise VFileError(file, "multiple 'location's are not supported yet")

         loc = location[0]

         if (offset := loc.get("offset", None)) is None:
-            VFileError(file, "'offset' is not specified")
+            raise VFileError(file, "'offset' is not specified")

         if (size := loc.get("size", None)) is None:
-            VFileError(file, "'size' is not specified")
+            raise VFileError(file, "'size' is not specified")

         if (parent := loc.get("parent", None)) is None:
-            VFileError(file, "'parent' is not specified")
+            raise VFileError(file, "'parent' is not specified")

         tar_file = File(**parent)
         tar_file._set_stream(file._catalog)
@@ -236,11 +237,26 @@ class File(DataModel):
         with open(destination, mode="wb") as f:
             f.write(self.read())

+    def _symlink_to(self, destination: str):
+        if self.location:
+            raise OSError(errno.ENOTSUP, "Symlinking virtual file is not supported")
+
+        if self._caching_enabled:
+            self.ensure_cached()
+            source = self.get_local_path()
+            assert source, "File was not cached"
+        elif self.source.startswith("file://"):
+            source = self.get_path()
+        else:
+            raise OSError(errno.EXDEV, "can't link across filesystems")
+        return os.symlink(source, destination)
+
     def export(
         self,
         output: str,
         placement: ExportPlacement = "fullpath",
         use_cache: bool = True,
+        link_type: Literal["copy", "symlink"] = "copy",
     ) -> None:
         """Export file to new location."""
         if use_cache:
@@ -249,6 +265,13 @@ class File(DataModel):
         dst_dir = os.path.dirname(dst)
         os.makedirs(dst_dir, exist_ok=True)

+        if link_type == "symlink":
+            try:
+                return self._symlink_to(dst)
+            except OSError as exc:
+                if exc.errno not in (errno.ENOTSUP, errno.EXDEV, errno.ENOSYS):
+                    raise
+
         self.save(dst)

     def _set_stream(
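From a caller's perspective, `link_type="symlink"` makes `export` try the new `_symlink_to` helper first and silently fall back to a copy when the attempt fails with ENOTSUP, EXDEV, or ENOSYS (virtual files, cross-filesystem sources, or platforms without symlinks). A small sketch; the `files` iterable is assumed to yield datachain `File` objects, e.g. from a chain's "file" signal.

    def export_with_symlinks(files, output_dir: str) -> None:
        # `files` is assumed to be an iterable of datachain File objects,
        # e.g. chain.collect("file").
        for file in files:
            file.export(output_dir, placement="fullpath", link_type="symlink")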
datachain/lib/hf.py
CHANGED
@@ -29,7 +29,7 @@ from io import BytesIO
 from typing import TYPE_CHECKING, Any, Union

 import PIL
-from tqdm import tqdm
+from tqdm.auto import tqdm

 from datachain.lib.arrow import arrow_type_mapper
 from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
datachain/lib/listing.py
CHANGED
@@ -113,14 +113,14 @@ def parse_listing_uri(uri: str, cache, client_config) -> tuple[Optional[str], st
     telemetry.log_param("client", client.PREFIX)

     if not uri.endswith("/") and _isfile(client, uri):
-        return None, f
+        return None, f"{storage_uri}/{path.lstrip('/')}", path
     if uses_glob(path):
         lst_uri_path = posixpath.dirname(path)
     else:
-        storage_uri, path = Client.parse_url(f
+        storage_uri, path = Client.parse_url(f"{uri.rstrip('/')}/")
         lst_uri_path = path

-    lst_uri = f
+    lst_uri = f"{storage_uri}/{lst_uri_path.lstrip('/')}"
     ds_name = (
         f"{LISTING_PREFIX}{storage_uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}"
     )
@@ -180,7 +180,7 @@ def get_listing(
     # for local file system we need to fix listing path / prefix
     # if we are reusing existing listing
     if isinstance(client, FileClient) and listing and listing.name != ds_name:
-        list_path = f
+        list_path = f"{ds_name.strip('/').removeprefix(listing.name)}/{list_path}"

     ds_name = listing.name if listing else ds_name

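The reconstructed f-strings all follow the same normalisation idiom: trim slashes from one side before joining so the final URI has exactly one separator. A standalone illustration with made-up values:

    storage_uri = "s3://example-bucket"   # made-up values for illustration
    lst_uri_path = "/photos/2024"

    lst_uri = f"{storage_uri}/{lst_uri_path.lstrip('/')}"
    print(lst_uri)                        # s3://example-bucket/photos/2024

    uri = "s3://example-bucket/photos/2024///"
    print(f"{uri.rstrip('/')}/")          # exactly one trailing slash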
datachain/lib/pytorch.py
CHANGED
@@ -50,6 +50,7 @@ class PytorchDataset(IterableDataset):
         tokenizer_kwargs: Optional[dict[str, Any]] = None,
         num_samples: int = 0,
         dc_settings: Optional[Settings] = None,
+        remove_prefetched: bool = False,
     ):
         """
         Pytorch IterableDataset that streams DataChain datasets.
@@ -84,6 +85,7 @@ class PytorchDataset(IterableDataset):

         self._cache = catalog.cache
         self._prefetch_cache: Optional[Cache] = None
+        self._remove_prefetched = remove_prefetched
         if prefetch and not self.cache:
             tmp_dir = catalog.cache.tmp_dir
             assert tmp_dir
@@ -147,7 +149,7 @@ class PytorchDataset(IterableDataset):
             rows,
             self.prefetch,
             download_cb=download_cb,
-
+            remove_prefetched=self._remove_prefetched,
         )

         with download_cb, closing(rows):
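`PytorchDataset` now accepts `remove_prefetched` and forwards it to the row prefetcher, so each prefetched file can be deleted once the iterator has yielded it. A hedged sketch of consuming such a dataset through a torch `DataLoader`; the dataset name and loader settings are placeholders.

    from torch.utils.data import DataLoader

    from datachain.lib.dc import DataChain

    chain = DataChain.from_dataset("my-dataset")  # placeholder dataset name
    loader = DataLoader(
        chain.to_pytorch(remove_prefetched=True),  # prefetched files removed after use
        batch_size=16,
        num_workers=2,
    )
    for batch in loader:
        ...  # training / inference step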
datachain/lib/udf.py
CHANGED
@@ -16,6 +16,7 @@ from datachain.lib.convert.flatten import flatten
 from datachain.lib.data_model import DataValue
 from datachain.lib.file import File
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
+from datachain.progress import CombinedDownloadCallback
 from datachain.query.batch import (
     Batch,
     BatchingStrategy,
@@ -301,20 +302,42 @@ async def _prefetch_input(
     return row


+def _remove_prefetched(row: T) -> None:
+    for obj in row:
+        if isinstance(obj, File):
+            catalog = obj._catalog
+            assert catalog is not None
+            try:
+                catalog.cache.remove(obj)
+            except Exception as e:  # noqa: BLE001
+                print(f"Failed to remove prefetched item {obj.name!r}: {e!s}")
+
+
 def _prefetch_inputs(
     prepared_inputs: "Iterable[T]",
     prefetch: int = 0,
     download_cb: Optional["Callback"] = None,
-    after_prefetch:
+    after_prefetch: Optional[Callable[[], None]] = None,
+    remove_prefetched: bool = False,
 ) -> "abc.Generator[T, None, None]":
-    if prefetch
-
-
-
-
-
-
-
+    if not prefetch:
+        yield from prepared_inputs
+        return
+
+    if after_prefetch is None:
+        after_prefetch = noop
+        if isinstance(download_cb, CombinedDownloadCallback):
+            after_prefetch = download_cb.increment_file_count
+
+    f = partial(_prefetch_input, download_cb=download_cb, after_prefetch=after_prefetch)
+    mapper = AsyncMapper(f, prepared_inputs, workers=prefetch)
+    with closing(mapper.iterate()) as row_iter:
+        for row in row_iter:
+            try:
+                yield row  # type: ignore[misc]
+            finally:
+                if remove_prefetched:
+                    _remove_prefetched(row)


 def _get_cache(
@@ -351,7 +374,13 @@ class Mapper(UDFBase):
         )

         prepared_inputs = _prepare_rows(udf_inputs)
-        prepared_inputs = _prefetch_inputs(
+        prepared_inputs = _prefetch_inputs(
+            prepared_inputs,
+            self.prefetch,
+            download_cb=download_cb,
+            remove_prefetched=bool(self.prefetch) and not cache,
+        )
+
         with closing(prepared_inputs):
             for id_, *udf_args in prepared_inputs:
                 result_objs = self.process_safe(udf_args)
@@ -391,9 +420,9 @@ class BatchMapper(UDFBase):
                 )
                 result_objs = list(self.process_safe(udf_args))
                 n_objs = len(result_objs)
-                assert (
-                    n_objs
-                )
+                assert n_objs == n_rows, (
+                    f"{self.name} returns {n_objs} rows, but {n_rows} were expected"
+                )
                 udf_outputs = (self._flatten_row(row) for row in result_objs)
                 output = [
                     {"sys__id": row_id} | dict(zip(self.signal_names, signals))
@@ -429,15 +458,22 @@ class Generator(UDFBase):
                 row, udf_fields, catalog, cache, download_cb
             )

+        def _process_row(row):
+            with safe_closing(self.process_safe(row)) as result_objs:
+                for result_obj in result_objs:
+                    udf_output = self._flatten_row(result_obj)
+                    yield dict(zip(self.signal_names, udf_output))
+
         prepared_inputs = _prepare_rows(udf_inputs)
-        prepared_inputs = _prefetch_inputs(
+        prepared_inputs = _prefetch_inputs(
+            prepared_inputs,
+            self.prefetch,
+            download_cb=download_cb,
+            remove_prefetched=bool(self.prefetch) and not cache,
+        )
         with closing(prepared_inputs):
-            for row in prepared_inputs:
-
-                udf_outputs = (self._flatten_row(row) for row in result_objs)
-                output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
-                processed_cb.relative_update(1)
-                yield output
+            for row in processed_cb.wrap(prepared_inputs):
+                yield _process_row(row)

         self.teardown()

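The rewritten `_prefetch_inputs` wraps an async mapper and uses try/finally so prefetched inputs can be cleaned up even when the consumer stops iterating early. A simplified, standalone sketch of that generator pattern; it uses a thread pool as a stand-in for datachain's `AsyncMapper` and a plain callback in place of the cache removal.

    from concurrent.futures import ThreadPoolExecutor
    from contextlib import closing


    def prefetch_inputs(items, prefetch=0, cleanup=None):
        if not prefetch:
            yield from items
            return
        with ThreadPoolExecutor(max_workers=prefetch) as pool:
            # pool.map stands in for the async download of each input
            for item in pool.map(lambda x: x, items):
                try:
                    yield item
                finally:
                    if cleanup is not None:
                        cleanup(item)  # e.g. drop the prefetched file from cache


    with closing(prefetch_inputs(range(3), prefetch=2, cleanup=print)) as rows:
        for row in rows:
            pass  # each row is cleaned up right after it is consumed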
datachain/listing.py
CHANGED
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional

 from sqlalchemy import Column
 from sqlalchemy.sql import func
-from tqdm import tqdm
+from tqdm.auto import tqdm

 from datachain.node import DirType, Node, NodeWithPath
 from datachain.sql.functions import path as pathfunc
datachain/model/bbox.py
CHANGED
@@ -22,9 +22,9 @@ class BBox(DataModel):
     @staticmethod
     def from_list(coords: list[float], title: str = "") -> "BBox":
         assert len(coords) == 4, "Bounding box must be a list of 4 coordinates."
-        assert all(
-
-        )
+        assert all(isinstance(value, (int, float)) for value in coords), (
+            "Bounding box coordinates must be floats or integers."
+        )
         return BBox(
             title=title,
             coords=[round(c) for c in coords],
@@ -64,12 +64,12 @@ class OBBox(DataModel):

     @staticmethod
     def from_list(coords: list[float], title: str = "") -> "OBBox":
-        assert (
-
-        )
-        assert all(
-
-        )
+        assert len(coords) == 8, (
+            "Oriented bounding box must be a list of 8 coordinates."
+        )
+        assert all(isinstance(value, (int, float)) for value in coords), (
+            "Oriented bounding box coordinates must be floats or integers."
+        )
         return OBBox(
             title=title,
             coords=[round(c) for c in coords],