datachain 0.3.17__py3-none-any.whl → 0.3.19__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of datachain might be problematic.
- datachain/__init__.py +5 -2
- datachain/cache.py +14 -55
- datachain/catalog/catalog.py +17 -97
- datachain/cli.py +7 -2
- datachain/client/fsspec.py +29 -63
- datachain/client/local.py +2 -3
- datachain/dataset.py +7 -2
- datachain/error.py +6 -4
- datachain/lib/arrow.py +10 -4
- datachain/lib/dc.py +6 -2
- datachain/lib/file.py +64 -28
- datachain/lib/listing.py +2 -0
- datachain/listing.py +4 -4
- datachain/node.py +6 -6
- datachain/nodes_fetcher.py +12 -5
- datachain/nodes_thread_pool.py +1 -1
- datachain/progress.py +2 -12
- datachain/query/dataset.py +6 -40
- datachain/query/dispatch.py +2 -15
- datachain/query/schema.py +25 -24
- datachain/query/udf.py +0 -106
- datachain/sql/types.py +4 -2
- datachain/telemetry.py +37 -0
- datachain/utils.py +11 -0
- {datachain-0.3.17.dist-info → datachain-0.3.19.dist-info}/METADATA +5 -4
- {datachain-0.3.17.dist-info → datachain-0.3.19.dist-info}/RECORD +30 -29
- {datachain-0.3.17.dist-info → datachain-0.3.19.dist-info}/LICENSE +0 -0
- {datachain-0.3.17.dist-info → datachain-0.3.19.dist-info}/WHEEL +0 -0
- {datachain-0.3.17.dist-info → datachain-0.3.19.dist-info}/entry_points.txt +0 -0
- {datachain-0.3.17.dist-info → datachain-0.3.19.dist-info}/top_level.txt +0 -0
datachain/__init__.py
CHANGED
@@ -1,21 +1,23 @@
 from datachain.lib.data_model import DataModel, DataType, is_chain_type
 from datachain.lib.dc import C, Column, DataChain, Sys
 from datachain.lib.file import (
+    ArrowRow,
     File,
     FileError,
     ImageFile,
-    IndexedFile,
     TarVFile,
     TextFile,
 )
 from datachain.lib.model_store import ModelStore
 from datachain.lib.udf import Aggregator, Generator, Mapper
 from datachain.lib.utils import AbstractUDF, DataChainError
+from datachain.query import metrics, param
 from datachain.query.session import Session

 __all__ = [
     "AbstractUDF",
     "Aggregator",
+    "ArrowRow",
     "C",
     "Column",
     "DataChain",
@@ -26,7 +28,6 @@ __all__ = [
     "FileError",
     "Generator",
     "ImageFile",
-    "IndexedFile",
     "Mapper",
     "ModelStore",
     "Session",
@@ -34,4 +35,6 @@ __all__ = [
     "TarVFile",
     "TextFile",
     "is_chain_type",
+    "metrics",
+    "param",
 ]
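
The public surface shifts accordingly: ArrowRow, metrics and param are now importable from the package root, while IndexedFile no longer is. A minimal sketch of the new import surface (the try/except is only illustrative):

from datachain import ArrowRow, metrics, param  # new in 0.3.19

try:
    from datachain import IndexedFile  # removed from the package root in 0.3.19
except ImportError:
    IndexedFile = None
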
datachain/cache.py
CHANGED
@@ -1,56 +1,15 @@
-import hashlib
-import json
 import os
-from datetime import datetime
-from functools import partial
 from typing import TYPE_CHECKING, Optional

-import attrs
 from dvc_data.hashfile.db.local import LocalHashFileDB
 from dvc_objects.fs.local import LocalFileSystem
 from fsspec.callbacks import Callback, TqdmCallback

-from datachain.utils import TIME_ZERO
-
 from .progress import Tqdm

 if TYPE_CHECKING:
     from datachain.client import Client
-    from datachain.
-
-sha256 = partial(hashlib.sha256, usedforsecurity=False)
-
-
-@attrs.frozen
-class UniqueId:
-    storage: "StorageURI"
-    path: str
-    size: int
-    etag: str
-    version: str = ""
-    is_latest: bool = True
-    location: Optional[str] = None
-    last_modified: datetime = TIME_ZERO
-
-    def get_parsed_location(self) -> Optional[dict]:
-        if not self.location:
-            return None
-
-        loc_stack = (
-            json.loads(self.location)
-            if isinstance(self.location, str)
-            else self.location
-        )
-        if len(loc_stack) > 1:
-            raise NotImplementedError("Nested v-objects are not supported yet.")
-
-        return loc_stack[0]
-
-    def get_hash(self) -> str:
-        fingerprint = f"{self.storage}/{self.path}/{self.version}/{self.etag}"
-        if self.location:
-            fingerprint += f"/{self.location}"
-        return sha256(fingerprint.encode()).hexdigest()
+    from datachain.lib.file import File


 def try_scandir(path):
@@ -77,30 +36,30 @@ class DataChainCache:
     def tmp_dir(self):
         return self.odb.tmp_dir

-    def get_path(self,
-        if self.contains(
-            return self.path_from_checksum(
+    def get_path(self, file: "File") -> Optional[str]:
+        if self.contains(file):
+            return self.path_from_checksum(file.get_hash())
         return None

-    def contains(self,
-        return self.odb.exists(
+    def contains(self, file: "File") -> bool:
+        return self.odb.exists(file.get_hash())

     def path_from_checksum(self, checksum: str) -> str:
         assert checksum
         return self.odb.oid_to_path(checksum)

-    def remove(self,
-        self.odb.delete(
+    def remove(self, file: "File") -> None:
+        self.odb.delete(file.get_hash())

     async def download(
-        self,
+        self, file: "File", client: "Client", callback: Optional[Callback] = None
     ) -> None:
-        from_path = f"{
+        from_path = f"{file.source}/{file.path}"
         from dvc_objects.fs.utils import tmp_fname

         odb_fs = self.odb.fs
         tmp_info = odb_fs.join(self.odb.tmp_dir, tmp_fname())  # type: ignore[arg-type]
-        size =
+        size = file.size
         if size < 0:
             size = await client.get_size(from_path)
         cb = callback or TqdmCallback(
@@ -115,13 +74,13 @@ class DataChainCache:
         cb.close()

         try:
-            oid =
+            oid = file.get_hash()
            self.odb.add(tmp_info, self.odb.fs, oid)
         finally:
             os.unlink(tmp_info)

-    def store_data(self,
-        checksum =
+    def store_data(self, file: "File", contents: bytes) -> None:
+        checksum = file.get_hash()
         dst = self.path_from_checksum(checksum)
         if not os.path.exists(dst):
             # Create the file only if it's not already in cache
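
The net effect is that UniqueId is gone and the cache is keyed directly by datachain.lib.file.File through File.get_hash(). A minimal usage sketch; the get_catalog().cache accessor and the File field values are assumptions for illustration:

from datachain.catalog import get_catalog
from datachain.lib.file import File

cache = get_catalog().cache               # DataChainCache instance (assumed accessor)
file = File(source="s3://bucket", path="images/cat.jpg", size=1024, etag="abc123")

if cache.contains(file):                  # keyed by file.get_hash()
    local_path = cache.get_path(file)     # returns None when the object is not cached
else:
    cache.store_data(file, b"raw bytes")  # or: await cache.download(file, client)
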
datachain/catalog/catalog.py
CHANGED
@@ -1,4 +1,3 @@
-import ast
 import glob
 import io
 import json
@@ -34,7 +33,7 @@ import yaml
 from sqlalchemy import Column
 from tqdm import tqdm

-from datachain.cache import DataChainCache
+from datachain.cache import DataChainCache
 from datachain.client import Client
 from datachain.config import get_remote_config, read_config
 from datachain.dataset import (
@@ -53,9 +52,9 @@ from datachain.error import (
     DataChainError,
     DatasetInvalidVersionError,
     DatasetNotFoundError,
+    DatasetVersionNotFoundError,
     PendingIndexingError,
     QueryScriptCancelError,
-    QueryScriptCompileError,
     QueryScriptRunError,
 )
 from datachain.listing import Listing
@@ -588,44 +587,13 @@ class Catalog:
     def generate_query_dataset_name(cls) -> str:
         return f"{QUERY_DATASET_PREFIX}_{uuid4().hex}"

-    def
-        if code_ast.body:
-            last_expr = code_ast.body[-1]
-            if isinstance(last_expr, ast.Expr):
-                new_expressions = [
-                    ast.Import(
-                        names=[ast.alias(name="datachain.query.dataset", asname=None)]
-                    ),
-                    ast.Expr(
-                        value=ast.Call(
-                            func=ast.Attribute(
-                                value=ast.Attribute(
-                                    value=ast.Attribute(
-                                        value=ast.Name(id="datachain", ctx=ast.Load()),
-                                        attr="query",
-                                        ctx=ast.Load(),
-                                    ),
-                                    attr="dataset",
-                                    ctx=ast.Load(),
-                                ),
-                                attr="query_wrapper",
-                                ctx=ast.Load(),
-                            ),
-                            args=[last_expr],
-                            keywords=[],
-                        )
-                    ),
-                ]
-                code_ast.body[-1:] = new_expressions
-        return code_ast
-
-    def get_client(self, uri: StorageURI, **config: Any) -> Client:
+    def get_client(self, uri: str, **config: Any) -> Client:
         """
         Return the client corresponding to the given source `uri`.
         """
         config = config or self.client_config
         cls = Client.get_implementation(uri)
-        return cls.from_source(uri, self.cache, **config)
+        return cls.from_source(StorageURI(uri), self.cache, **config)

     def enlist_source(
         self,
@@ -1218,7 +1186,9 @@ class Catalog:

         dataset_version = dataset.get_version(version)
         if not dataset_version:
-            raise
+            raise DatasetVersionNotFoundError(
+                f"Dataset {dataset.name} does not have version {version}"
+            )

         if not dataset_version.is_final_status():
             raise ValueError("Cannot register dataset version in non final status")
@@ -1431,7 +1401,7 @@ class Catalog:

     def get_file_signals(
         self, dataset_name: str, dataset_version: int, row: RowDict
-    ) -> Optional[
+    ) -> Optional[RowDict]:
         """
         Function that returns file signals from dataset row.
         Note that signal names are without prefix, so if there was 'laion__file__source'
@@ -1448,7 +1418,7 @@ class Catalog:

         version = self.get_dataset(dataset_name).get_version(dataset_version)

-        file_signals_values =
+        file_signals_values = RowDict()

         schema = SignalSchema.deserialize(version.feature_schema)
         for file_signals in schema.get_signals(File):
@@ -1476,6 +1446,8 @@ class Catalog:
         use_cache: bool = True,
         **config: Any,
     ):
+        from datachain.lib.file import File
+
         file_signals = self.get_file_signals(dataset_name, dataset_version, row)
         if not file_signals:
             raise RuntimeError("Cannot open object without file signals")
@@ -1483,22 +1455,10 @@ class Catalog:
         config = config or self.client_config
         client = self.get_client(file_signals["source"], **config)
         return client.open_object(
-
+            File._from_row(file_signals),
             use_cache=use_cache,
         )

-    def _get_row_uid(self, row: RowDict) -> UniqueId:
-        return UniqueId(
-            row["source"],
-            row["path"],
-            row["size"],
-            row["etag"],
-            row["version"],
-            row["is_latest"],
-            row["location"],
-            row["last_modified"],
-        )
-
     def ls(
         self,
         sources: list[str],
@@ -1591,7 +1551,7 @@ class Catalog:

         try:
             remote_dataset_version = remote_dataset.get_version(version)
-        except (
+        except (DatasetVersionNotFoundError, StopIteration) as exc:
             raise DataChainError(
                 f"Dataset {remote_dataset_name} doesn't have version {version}"
                 " on server"
@@ -1732,64 +1692,24 @@ class Catalog:
         query_script: str,
         env: Optional[Mapping[str, str]] = None,
         python_executable: str = sys.executable,
-
-        capture_output: bool = True,
+        capture_output: bool = False,
         output_hook: Callable[[str], None] = noop,
         params: Optional[dict[str, str]] = None,
         job_id: Optional[str] = None,
-        _execute_last_expression: bool = False,
     ) -> None:
-        ""
-        Method to run custom user Python script to run a query and, as result,
-        creates new dataset from the results of a query.
-        Returns tuple of result dataset and script output.
-
-        Constraints on query script:
-        1. datachain.query.DatasetQuery should be used in order to create query
-        for a dataset
-        2. There should not be any .save() call on DatasetQuery since the idea
-        is to create only one dataset as the outcome of the script
-        3. Last statement must be an instance of DatasetQuery
-
-        If save is set to True, we are creating new dataset with results
-        from dataset query. If it's set to False, we will just print results
-        without saving anything
-
-        Example of query script:
-        from datachain.query import DatasetQuery, C
-        DatasetQuery('s3://ldb-public/remote/datasets/mnist-tiny/').filter(
-        C.size > 1000
-        )
-        """
-        if _execute_last_expression:
-            try:
-                code_ast = ast.parse(query_script)
-                code_ast = self.attach_query_wrapper(code_ast)
-                query_script_compiled = ast.unparse(code_ast)
-            except Exception as exc:
-                raise QueryScriptCompileError(
-                    f"Query script failed to compile, reason: {exc}"
-                ) from exc
-        else:
-            query_script_compiled = query_script
-            assert not save
-
+        cmd = [python_executable, "-c", query_script]
         env = dict(env or os.environ)
         env.update(
             {
                 "DATACHAIN_QUERY_PARAMS": json.dumps(params or {}),
-                "PYTHONPATH": os.getcwd(),  # For local imports
-                "DATACHAIN_QUERY_SAVE": "1" if save else "",
-                "PYTHONUNBUFFERED": "1",
                 "DATACHAIN_JOB_ID": job_id or "",
             },
         )
-        popen_kwargs = {}
+        popen_kwargs: dict[str, Any] = {}
         if capture_output:
             popen_kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.STDOUT}

-        cmd =
-        with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc:  # type: ignore[call-overload] # noqa: S603
+        with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc:  # noqa: S603
             if capture_output:
                 args = (proc.stdout, output_hook)
                 thread = Thread(target=_process_stream, args=args, daemon=True)
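
After this change, Catalog.query() no longer rewrites the script's AST or injects a query wrapper; it simply runs the script in a subprocess and passes parameters through the environment. A roughly equivalent standalone sketch (the script body and parameter values are illustrative):

import json
import os
import subprocess
import sys

script = "import datachain; print(datachain.__name__)"
env = dict(os.environ)
env.update(
    {
        "DATACHAIN_QUERY_PARAMS": json.dumps({"limit": "10"}),
        "DATACHAIN_JOB_ID": "",
    }
)
subprocess.run([sys.executable, "-c", script], env=env, check=True)
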
datachain/cli.py
CHANGED
@@ -15,6 +15,7 @@ import shtab
 from datachain import utils
 from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyValueArgs
 from datachain.lib.dc import DataChain
+from datachain.telemetry import telemetry
 from datachain.utils import DataChainDir

 if TYPE_CHECKING:
@@ -803,7 +804,6 @@ def query(
     catalog.query(
         script_content,
         python_executable=python_executable,
-        capture_output=False,
         params=params,
         job_id=job_id,
     )
@@ -872,6 +872,7 @@ def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR09
     # This also sets this environment variable for any subprocesses
     os.environ["DEBUG_SHOW_SQL_QUERIES"] = "True"

+    error = None
     try:
         catalog = get_catalog(client_config=client_config)
         if args.command == "cp":
@@ -1003,14 +1004,16 @@ def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR09
             print(f"invalid command: {args.command}", file=sys.stderr)
             return 1
         return 0
-    except BrokenPipeError:
+    except BrokenPipeError as exc:
         # Python flushes standard streams on exit; redirect remaining output
         # to devnull to avoid another BrokenPipeError at shutdown
         # See: https://docs.python.org/3/library/signal.html#note-on-sigpipe
+        error = str(exc)
         devnull = os.open(os.devnull, os.O_WRONLY)
         os.dup2(devnull, sys.stdout.fileno())
         return 141  # 128 + 13 (SIGPIPE)
     except (KeyboardInterrupt, Exception) as exc:
+        error = str(exc)
         if isinstance(exc, KeyboardInterrupt):
             msg = "Operation cancelled by the user"
         else:
@@ -1028,3 +1031,5 @@ def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR09

         pdb.post_mortem()
         return 1
+    finally:
+        telemetry.send_cli_call(args.command, error=error)
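
The telemetry hook follows a simple pattern: one error slot is filled by whichever exception handler runs, and the event is sent exactly once from finally. A stripped-down, self-contained sketch of that control flow; run() and the command name are placeholders, while the send_cli_call signature is taken from the diff above:

from datachain.telemetry import telemetry

def run(command: str) -> int:
    # Placeholder for the real per-command dispatch in main().
    return 0

error = None
command = "ls"
try:
    code = run(command)
except (KeyboardInterrupt, Exception) as exc:   # BrokenPipeError is handled the same way
    error = str(exc)
    code = 1
finally:
    telemetry.send_cli_call(command, error=error)
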
datachain/client/fsspec.py
CHANGED
@@ -3,7 +3,6 @@ import functools
 import logging
 import multiprocessing
 import os
-import posixpath
 import re
 import sys
 from abc import ABC, abstractmethod
@@ -26,8 +25,8 @@ from fsspec.asyn import get_loop, sync
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from tqdm import tqdm

-from datachain.cache import DataChainCache
-from datachain.client.fileslice import
+from datachain.cache import DataChainCache
+from datachain.client.fileslice import FileWrapper
 from datachain.error import ClientError as DataChainClientError
 from datachain.lib.file import File
 from datachain.nodes_fetcher import NodesFetcher
@@ -187,8 +186,8 @@ class Client(ABC):
     def url(self, path: str, expires: int = 3600, **kwargs) -> str:
         return self.fs.sign(self.get_full_path(path), expiration=expires, **kwargs)

-    async def get_current_etag(self,
-        info = await self.fs._info(self.get_full_path(
+    async def get_current_etag(self, file: "File") -> str:
+        info = await self.fs._info(self.get_full_path(file.path))
         return self.info_to_file(info, "").etag

     async def get_size(self, path: str) -> int:
@@ -317,7 +316,7 @@ class Client(ABC):

     def instantiate_object(
         self,
-
+        file: "File",
         dst: str,
         progress_bar: tqdm,
         force: bool = False,
@@ -328,10 +327,10 @@ class Client(ABC):
         else:
             progress_bar.close()
             raise FileExistsError(f"Path {dst} already exists")
-        self.do_instantiate_object(
+        self.do_instantiate_object(file, dst)

-    def do_instantiate_object(self,
-        src = self.cache.get_path(
+    def do_instantiate_object(self, file: "File", dst: str) -> None:
+        src = self.cache.get_path(file)
         assert src is not None

         try:
@@ -341,66 +340,33 @@ class Client(ABC):
             copy2(src, dst)

     def open_object(
-        self,
+        self, file: File, use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
     ) -> BinaryIO:
         """Open a file, including files in tar archives."""
-
-        if use_cache and (cache_path := self.cache.get_path(uid)):
+        if use_cache and (cache_path := self.cache.get_path(file)):
             return open(cache_path, mode="rb")  # noqa: SIM115
-
-
-
-
-
-
-
-
-        offset = location["offset"]
-        size = location["size"]
-        parent = location["parent"]
-
-        parent_uid = UniqueId(
-            parent["source"],
-            parent["path"],
-            parent["size"],
-            parent["etag"],
-            location=parent["location"],
-        )
-        f = self.open_object(parent_uid, use_cache=use_cache)
-        return FileSlice(f, offset, size, posixpath.basename(uid.path))
-
-    def download(self, uid: UniqueId, *, callback: Callback = DEFAULT_CALLBACK) -> None:
-        sync(get_loop(), functools.partial(self._download, uid, callback=callback))
-
-    async def _download(self, uid: UniqueId, *, callback: "Callback" = None) -> None:
-        if self.cache.contains(uid):
+        assert not file.location
+        return FileWrapper(self.fs.open(self.get_full_path(file.path)), cb)  # type: ignore[return-value]
+
+    def download(self, file: File, *, callback: Callback = DEFAULT_CALLBACK) -> None:
+        sync(get_loop(), functools.partial(self._download, file, callback=callback))
+
+    async def _download(self, file: File, *, callback: "Callback" = None) -> None:
+        if self.cache.contains(file):
             # Already in cache, so there's nothing to do.
             return
-        await self._put_in_cache(
+        await self._put_in_cache(file, callback=callback)

-    def put_in_cache(self,
-        sync(get_loop(), functools.partial(self._put_in_cache,
+    def put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
+        sync(get_loop(), functools.partial(self._put_in_cache, file, callback=callback))

-    async def _put_in_cache(
-
-
-
-
-        loop = asyncio.get_running_loop()
-        await loop.run_in_executor(
-            None, functools.partial(self._download_from_tar, uid, callback=callback)
-        )
-        return
-        if uid.etag:
-            etag = await self.get_current_etag(uid)
-            if uid.etag != etag:
+    async def _put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
+        assert not file.location
+        if file.etag:
+            etag = await self.get_current_etag(file)
+            if file.etag != etag:
                 raise FileNotFoundError(
-                    f"Invalid etag for {
-                    f"expected {
+                    f"Invalid etag for {file.source}/{file.path}: "
+                    f"expected {file.etag}, got {etag}"
                 )
-        await self.cache.download(
-
-    def _download_from_tar(self, uid, *, callback: "Callback" = None):
-        with self._open_tar(uid, use_cache=False) as f:
-            contents = f.read()
-            self.cache.store_data(uid, contents)
+        await self.cache.download(file, self, callback=callback)
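
Client methods now take File objects as well, and plain (non-virtual) files are streamed through FileWrapper instead of the removed tar/FileSlice path. A rough usage sketch; the get_catalog() accessor and the File field values are assumptions for illustration:

from datachain.catalog import get_catalog
from datachain.lib.file import File

catalog = get_catalog()
client = catalog.get_client("s3://bucket")                # uri is now a plain str
file = File(source="s3://bucket", path="images/cat.jpg")  # illustrative fields

client.download(file)                                     # no-op if already cached
with client.open_object(file, use_cache=True) as f:
    header = f.read(1024)                                 # cache hit or a wrapped fs stream
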
datachain/client/local.py
CHANGED
@@ -7,7 +7,6 @@ from urllib.parse import urlparse

 from fsspec.implementations.local import LocalFileSystem

-from datachain.cache import UniqueId
 from datachain.lib.file import File
 from datachain.storage import StorageURI

@@ -114,8 +113,8 @@ class FileClient(Client):
             use_symlinks=use_symlinks,
         )

-    async def get_current_etag(self,
-        info = self.fs.info(self.get_full_path(
+    async def get_current_etag(self, file: "File") -> str:
+        info = self.fs.info(self.get_full_path(file.path))
         return self.info_to_file(info, "").etag

     async def get_size(self, path: str) -> int:

datachain/dataset.py
CHANGED
@@ -12,6 +12,7 @@ from typing import (
 from urllib.parse import urlparse

 from datachain.client import Client
+from datachain.error import DatasetVersionNotFoundError
 from datachain.sql.types import NAME_TYPES_MAPPING, SQLType

 if TYPE_CHECKING:
@@ -417,7 +418,9 @@ class DatasetRecord:

     def get_version(self, version: int) -> DatasetVersion:
         if not self.has_version(version):
-            raise
+            raise DatasetVersionNotFoundError(
+                f"Dataset {self.name} does not have version {version}"
+            )
         return next(
             v
             for v in self.versions  # type: ignore [union-attr]
@@ -435,7 +438,9 @@ class DatasetRecord:
         Get identifier in the form my-dataset@v3
         """
         if not self.has_version(version):
-            raise
+            raise DatasetVersionNotFoundError(
+                f"Dataset {self.name} doesn't have a version {version}"
+            )
         return f"{self.name}@v{version}"

     def uri(self, version: int) -> str:
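
Instead of a bare raise, missing versions now surface as a dedicated exception, so callers can tell "the dataset exists but not at this version" apart from other failures. A small sketch, assuming the dataset name and version are illustrative:

from datachain.catalog import get_catalog
from datachain.error import DatasetVersionNotFoundError

dataset = get_catalog().get_dataset("my-dataset")   # DatasetRecord
try:
    version = dataset.get_version(3)
except DatasetVersionNotFoundError as exc:
    print(f"version missing: {exc}")
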
datachain/error.py
CHANGED
@@ -10,6 +10,10 @@ class DatasetNotFoundError(NotFoundError):
     pass


+class DatasetVersionNotFoundError(NotFoundError):
+    pass
+
+
 class DatasetInvalidVersionError(Exception):
     pass

@@ -32,14 +36,12 @@ class QueryScriptRunError(Exception):
     Attributes:
         message Explanation of the error
         return_code Code returned by the subprocess
-        output STDOUT + STDERR output of the subprocess
     """

-    def __init__(self, message: str, return_code: int = 0
+    def __init__(self, message: str, return_code: int = 0):
         self.message = message
         self.return_code = return_code
-
-        super().__init__(self.message)
+        super().__init__(message)


 class QueryScriptCancelError(QueryScriptRunError):

datachain/lib/arrow.py
CHANGED
@@ -4,11 +4,11 @@ from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Optional

 import pyarrow as pa
-from pyarrow.dataset import dataset
+from pyarrow.dataset import CsvFileFormat, dataset
 from tqdm import tqdm

 from datachain.lib.data_model import dict_to_data_model
-from datachain.lib.file import
+from datachain.lib.file import ArrowRow, File
 from datachain.lib.model_store import ModelStore
 from datachain.lib.udf import Generator

@@ -49,7 +49,8 @@ class ArrowGenerator(Generator):

     def process(self, file: File):
         if file._caching_enabled:
-
+            file.ensure_cached()
+            path = file.get_local_path()
             ds = dataset(path, schema=self.input_schema, **self.kwargs)
         elif self.nrows:
             path = _nrows_file(file, self.nrows)
@@ -83,7 +84,12 @@ class ArrowGenerator(Generator):
             vals_dict[field] = val
         vals = [self.output_schema(**vals_dict)]
         if self.source:
-
+            kwargs: dict = self.kwargs
+            # Can't serialize CsvFileFormat; may lose formatting options.
+            if isinstance(kwargs.get("format"), CsvFileFormat):
+                kwargs["format"] = "csv"
+            arrow_file = ArrowRow(file=file, index=index, kwargs=kwargs)
+            yield [arrow_file, *vals]
         else:
             yield vals
         index += 1
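
When source is enabled, each generated row now carries an ArrowRow recording the originating file, the row index, and the (serializable) dataset kwargs; a CsvFileFormat kwarg is downgraded to the plain "csv" string, as in the diff above. A minimal construction sketch with illustrative values:

from datachain.lib.file import ArrowRow, File

src = File(path="data/part-0.parquet")                               # illustrative source file
row_ref = ArrowRow(file=src, index=42, kwargs={"format": "parquet"})
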