datachain 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their public registries. It is provided for informational purposes only and reflects the changes between the two versions.
Potentially problematic release.
- datachain/__init__.py +3 -4
- datachain/cache.py +10 -4
- datachain/catalog/catalog.py +42 -16
- datachain/cli.py +48 -32
- datachain/data_storage/metastore.py +24 -0
- datachain/data_storage/warehouse.py +3 -1
- datachain/job.py +56 -0
- datachain/lib/arrow.py +19 -7
- datachain/lib/clip.py +89 -66
- datachain/lib/convert/{type_converter.py → python_to_sql.py} +6 -6
- datachain/lib/convert/sql_to_python.py +23 -0
- datachain/lib/convert/values_to_tuples.py +51 -33
- datachain/lib/data_model.py +6 -27
- datachain/lib/dataset_info.py +70 -0
- datachain/lib/dc.py +618 -156
- datachain/lib/file.py +130 -22
- datachain/lib/image.py +1 -1
- datachain/lib/meta_formats.py +14 -2
- datachain/lib/model_store.py +3 -2
- datachain/lib/pytorch.py +10 -7
- datachain/lib/signal_schema.py +19 -11
- datachain/lib/text.py +2 -1
- datachain/lib/udf.py +56 -5
- datachain/lib/udf_signature.py +1 -1
- datachain/node.py +11 -8
- datachain/query/dataset.py +62 -28
- datachain/query/schema.py +2 -0
- datachain/query/session.py +4 -4
- datachain/sql/functions/array.py +12 -0
- datachain/sql/functions/string.py +8 -0
- datachain/torch/__init__.py +1 -1
- datachain/utils.py +6 -0
- datachain-0.2.13.dist-info/METADATA +411 -0
- {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/RECORD +38 -42
- {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/WHEEL +1 -1
- datachain/lib/gpt4_vision.py +0 -97
- datachain/lib/hf_image_to_text.py +0 -97
- datachain/lib/hf_pipeline.py +0 -90
- datachain/lib/image_transform.py +0 -103
- datachain/lib/iptc_exif_xmp.py +0 -76
- datachain/lib/unstructured.py +0 -41
- datachain/text/__init__.py +0 -3
- datachain-0.2.11.dist-info/METADATA +0 -431
- {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/LICENSE +0 -0
- {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/top_level.txt +0 -0
datachain/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-from datachain.lib.data_model import DataModel, DataType,
+from datachain.lib.data_model import DataModel, DataType, is_chain_type
 from datachain.lib.dc import C, Column, DataChain, Sys
 from datachain.lib.file import (
     File,
@@ -8,15 +8,14 @@ from datachain.lib.file import (
     TarVFile,
     TextFile,
 )
+from datachain.lib.model_store import ModelStore
 from datachain.lib.udf import Aggregator, Generator, Mapper
 from datachain.lib.utils import AbstractUDF, DataChainError
-from datachain.query.dataset import UDF as BaseUDF  # noqa: N811
 from datachain.query.session import Session

 __all__ = [
     "AbstractUDF",
     "Aggregator",
-    "BaseUDF",
     "C",
     "Column",
     "DataChain",
@@ -24,12 +23,12 @@ __all__ = [
     "DataModel",
     "DataType",
     "File",
-    "FileBasic",
     "FileError",
     "Generator",
     "ImageFile",
     "IndexedFile",
     "Mapper",
+    "ModelStore",
     "Session",
     "Sys",
     "TarVFile",
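In practice, the export changes above shift the package's public import surface: ModelStore is now re-exported from the package root, while BaseUDF and FileBasic no longer are. A minimal sketch of imports valid against 0.2.13, based only on the __all__ list above:

    from datachain import DataChain, DataModel, ModelStore  # available in 0.2.13

    # These two imports worked in 0.2.11 but are gone from the package root:
    # from datachain import BaseUDF, FileBasic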
datachain/cache.py
CHANGED
@@ -1,6 +1,7 @@
 import hashlib
 import json
 import os
+from datetime import datetime
 from functools import partial
 from typing import TYPE_CHECKING, Optional

@@ -9,6 +10,8 @@ from dvc_data.hashfile.db.local import LocalHashFileDB
 from dvc_objects.fs.local import LocalFileSystem
 from fsspec.callbacks import Callback, TqdmCallback

+from datachain.utils import TIME_ZERO
+
 from .progress import Tqdm

 if TYPE_CHECKING:
@@ -23,10 +26,13 @@ class UniqueId:
     storage: "StorageURI"
     parent: str
     name: str
-    etag: str
     size: int
-
-
+    etag: str
+    version: str = ""
+    is_latest: bool = True
+    vtype: str = ""
+    location: Optional[str] = None
+    last_modified: datetime = TIME_ZERO

     @property
     def path(self) -> str:
@@ -49,7 +55,7 @@ class UniqueId:
     def get_hash(self) -> str:
         etag = f"{self.vtype}{self.location}" if self.vtype else self.etag
         return sha256(
-            f"{self.storage}/{self.parent}/{self.name}/{etag}".encode()
+            f"{self.storage}/{self.parent}/{self.name}/{self.version}/{etag}".encode()
        ).hexdigest()


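The new UniqueId fields also change the cache key: get_hash() now folds the object version into the hashed string. A standalone sketch (not datachain code) of the before/after key format, with illustrative values:

    from hashlib import sha256

    def key_0_2_11(storage, parent, name, etag):
        return sha256(f"{storage}/{parent}/{name}/{etag}".encode()).hexdigest()

    def key_0_2_13(storage, parent, name, version, etag):
        return sha256(f"{storage}/{parent}/{name}/{version}/{etag}".encode()).hexdigest()

    old_key = key_0_2_11("s3://bucket", "dir", "file.csv", "etag123")
    new_key = key_0_2_13("s3://bucket", "dir", "file.csv", "v2", "etag123")
    assert old_key != new_key  # versioned objects now hash to distinct cache entries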
datachain/catalog/catalog.py
CHANGED
@@ -1,4 +1,5 @@
 import ast
+import glob
 import io
 import json
 import logging
@@ -84,12 +85,14 @@ if TYPE_CHECKING:
        AbstractMetastore,
        AbstractWarehouse,
    )
+    from datachain.dataset import DatasetVersion
+    from datachain.job import Job

 logger = logging.getLogger("datachain")

 DEFAULT_DATASET_DIR = "dataset"
 DATASET_FILE_SUFFIX = ".edatachain"
-FEATURE_CLASSES = ["
+FEATURE_CLASSES = ["DataModel"]

 TTL_INT = 4 * 60 * 60

@@ -707,7 +710,12 @@

         client_config = client_config or self.client_config
         client, path = self.parse_url(source, **client_config)
-
+        stem = os.path.basename(os.path.normpath(path))
+        prefix = (
+            posixpath.dirname(path)
+            if glob.has_magic(stem) or client.fs.isfile(source)
+            else path
+        )
         storage_dataset_name = Storage.dataset_name(
             client.uri, posixpath.join(prefix, "")
         )
@@ -948,13 +956,9 @@
         ms = self.metastore.clone(uri, None)
         st = self.warehouse.clone()
         listing = Listing(None, ms, st, client, None)
-        rows = (
-
-
-        )
-        .select()
-        .to_records()
-        )
+        rows = DatasetQuery(
+            name=dataset.name, version=ds_version, catalog=self
+        ).to_db_records()
         indexed_sources.append(
             (
                 listing,
@@ -1160,9 +1164,8 @@
         if not dataset_version.preview:
             values["preview"] = (
                 DatasetQuery(name=dataset.name, version=version, catalog=self)
-                .select()
                 .limit(20)
-                .
+                .to_db_records()
             )

         if not values:
@@ -1420,6 +1423,25 @@
             if not d.is_bucket_listing:
                 yield d

+    def list_datasets_versions(
+        self,
+    ) -> Iterator[tuple[DatasetRecord, "DatasetVersion", Optional["Job"]]]:
+        """Iterate over all dataset versions with related jobs."""
+        datasets = list(self.ls_datasets())
+
+        # preselect dataset versions jobs from db to avoid multiple queries
+        jobs_ids: set[str] = {
+            v.job_id for ds in datasets for v in ds.versions if v.job_id
+        }
+        jobs: dict[str, Job] = {}
+        if jobs_ids:
+            jobs = {j.id: j for j in self.metastore.list_jobs_by_ids(list(jobs_ids))}
+
+        for d in datasets:
+            yield from (
+                (d, v, jobs.get(v.job_id) if v.job_id else None) for v in d.versions
+            )
+
     def ls_dataset_rows(
         self, name: str, version: int, offset=None, limit=None
     ) -> list[dict]:
@@ -1427,7 +1449,7 @@

         dataset = self.get_dataset(name)

-        q = DatasetQuery(name=dataset.name, version=version, catalog=self)
+        q = DatasetQuery(name=dataset.name, version=version, catalog=self)
         if limit:
             q = q.limit(limit)
         if offset:
@@ -1435,7 +1457,7 @@

         q = q.order_by("sys__id")

-        return q.
+        return q.to_db_records()

     def signed_url(self, source: str, path: str, client_config=None) -> str:
         client_config = client_config or self.client_config
@@ -1609,6 +1631,7 @@
             ...
         }
         """
+        from datachain.lib.file import File
         from datachain.lib.signal_schema import DEFAULT_DELIMITER, SignalSchema

         version = self.get_dataset(dataset_name).get_version(dataset_version)
@@ -1616,7 +1639,7 @@
         file_signals_values = {}

         schema = SignalSchema.deserialize(version.feature_schema)
-        for file_signals in schema.
+        for file_signals in schema.get_signals(File):
             prefix = file_signals.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER
             file_signals_values[file_signals] = {
                 c_name.removeprefix(prefix): c_value
@@ -1657,10 +1680,13 @@
            row["source"],
            row["parent"],
            row["name"],
-           row["etag"],
            row["size"],
+           row["etag"],
+           row["version"],
+           row["is_latest"],
            row["vtype"],
            row["location"],
+           row["last_modified"],
        )

    def ls(
@@ -1992,7 +2018,7 @@
        )
        if proc.returncode == QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE:
            raise QueryScriptRunError(
-               "Last line in a script was not an instance of
+               "Last line in a script was not an instance of DataChain",
                return_code=proc.returncode,
                output=output,
            )
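A hedged usage sketch of the new Catalog.list_datasets_versions() iterator. get_catalog is imported the same way cli.py does below; the printed attributes (dataset.name, job.name) come from the diffs in this release, everything else is illustrative:

    from datachain.catalog import get_catalog

    catalog = get_catalog()
    for dataset, ds_version, job in catalog.list_datasets_versions():
        # job is None for versions that have no recorded job_id
        print(dataset.name, ds_version, job.name if job else "-")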
datachain/cli.py
CHANGED
@@ -3,7 +3,7 @@ import os
 import shlex
 import sys
 import traceback
-from argparse import
+from argparse import Action, ArgumentParser, ArgumentTypeError, Namespace
 from collections.abc import Iterable, Iterator, Mapping, Sequence
 from importlib.metadata import PackageNotFoundError, version
 from itertools import chain
@@ -106,10 +106,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     parser = ArgumentParser(
         description="DataChain: Wrangle unstructured AI data at scale", prog="datachain"
     )
-
     parser.add_argument("-V", "--version", action="version", version=__version__)
-    parser.add_argument("--internal-run-udf", action="store_true", help=SUPPRESS)
-    parser.add_argument("--internal-run-udf-worker", action="store_true", help=SUPPRESS)

     parent_parser = ArgumentParser(add_help=False)
     parent_parser.add_argument(
@@ -150,9 +147,15 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         help="Drop into the pdb debugger on fatal exception",
     )

-    subp = parser.add_subparsers(
+    subp = parser.add_subparsers(
+        title="Available Commands",
+        metavar="command",
+        dest="command",
+        help=f"Use `{parser.prog} command --help` for command-specific help.",
+        required=True,
+    )
     parse_cp = subp.add_parser(
-        "cp", parents=[parent_parser],
+        "cp", parents=[parent_parser], description="Copy data files from the cloud"
     )
     add_sources_arg(parse_cp).complete = shtab.DIR  # type: ignore[attr-defined]
     parse_cp.add_argument("output", type=str, help="Output")
@@ -179,7 +182,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     parse_clone = subp.add_parser(
-        "clone", parents=[parent_parser],
+        "clone", parents=[parent_parser], description="Copy data files from the cloud"
     )
     add_sources_arg(parse_clone).complete = shtab.DIR  # type: ignore[attr-defined]
     parse_clone.add_argument("output", type=str, help="Output")
@@ -222,7 +225,9 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     parse_pull = subp.add_parser(
-        "pull",
+        "pull",
+        parents=[parent_parser],
+        description="Pull specific dataset version from SaaS",
     )
     parse_pull.add_argument(
         "dataset",
@@ -263,7 +268,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     parse_edit_dataset = subp.add_parser(
-        "edit-dataset", parents=[parent_parser],
+        "edit-dataset", parents=[parent_parser], description="Edit dataset metadata"
     )
     parse_edit_dataset.add_argument("name", type=str, help="Dataset name")
     parse_edit_dataset.add_argument(
@@ -285,9 +290,9 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         help="Dataset labels",
     )

-    subp.add_parser("ls-datasets", parents=[parent_parser],
+    subp.add_parser("ls-datasets", parents=[parent_parser], description="List datasets")
     rm_dataset_parser = subp.add_parser(
-        "rm-dataset", parents=[parent_parser],
+        "rm-dataset", parents=[parent_parser], description="Removes dataset"
     )
     rm_dataset_parser.add_argument("name", type=str, help="Dataset name")
     rm_dataset_parser.add_argument(
@@ -305,7 +310,9 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     dataset_stats_parser = subp.add_parser(
-        "dataset-stats",
+        "dataset-stats",
+        parents=[parent_parser],
+        description="Shows basic dataset stats",
     )
     dataset_stats_parser.add_argument("name", type=str, help="Dataset name")
     dataset_stats_parser.add_argument(
@@ -330,7 +337,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     parse_merge_datasets = subp.add_parser(
-        "merge-datasets", parents=[parent_parser],
+        "merge-datasets", parents=[parent_parser], description="Merges datasets"
     )
     parse_merge_datasets.add_argument(
         "--src",
@@ -360,7 +367,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     parse_ls = subp.add_parser(
-        "ls", parents=[parent_parser],
+        "ls", parents=[parent_parser], description="List storage contents"
     )
     add_sources_arg(parse_ls, nargs="*")
     parse_ls.add_argument(
@@ -378,7 +385,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     parse_du = subp.add_parser(
-        "du", parents=[parent_parser],
+        "du", parents=[parent_parser], description="Display space usage"
     )
     add_sources_arg(parse_du)
     parse_du.add_argument(
@@ -408,7 +415,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     parse_find = subp.add_parser(
-        "find", parents=[parent_parser],
+        "find", parents=[parent_parser], description="Search in a directory hierarchy"
     )
     add_sources_arg(parse_find)
     parse_find.add_argument(
@@ -461,20 +468,20 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     parse_index = subp.add_parser(
-        "index", parents=[parent_parser],
+        "index", parents=[parent_parser], description="Index storage location"
     )
     add_sources_arg(parse_index)

     subp.add_parser(
         "find-stale-storages",
         parents=[parent_parser],
-
+        description="Finds and marks stale storages",
     )

     show_parser = subp.add_parser(
         "show",
         parents=[parent_parser],
-
+        description="Create a new dataset with a query script",
     )
     show_parser.add_argument("name", type=str, help="Dataset name")
     show_parser.add_argument(
@@ -484,12 +491,13 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         type=int,
         help="Dataset version",
     )
+    show_parser.add_argument("--schema", action="store_true", help="Show schema")
     add_show_args(show_parser)

     query_parser = subp.add_parser(
         "query",
         parents=[parent_parser],
-
+        description="Create a new dataset with a query script",
     )
     query_parser.add_argument(
         "script", metavar="<script.py>", type=str, help="Filepath for script"
@@ -520,7 +528,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     apply_udf_parser = subp.add_parser(
-        "apply-udf", parents=[parent_parser],
+        "apply-udf", parents=[parent_parser], description="Apply UDF"
     )
     apply_udf_parser.add_argument("udf", type=str, help="UDF location")
     apply_udf_parser.add_argument("source", type=str, help="Source storage or dataset")
@@ -541,12 +549,14 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         "--udf-params", type=str, default=None, help="UDF class parameters"
     )
     subp.add_parser(
-        "clear-cache", parents=[parent_parser],
+        "clear-cache", parents=[parent_parser], description="Clear the local file cache"
     )
     subp.add_parser(
-        "gc", parents=[parent_parser],
+        "gc", parents=[parent_parser], description="Garbage collect temporary tables"
     )

+    subp.add_parser("internal-run-udf", parents=[parent_parser])
+    subp.add_parser("internal-run-udf-worker", parents=[parent_parser])
     add_completion_parser(subp, [parent_parser])
     return parser

@@ -555,7 +565,7 @@ def add_completion_parser(subparsers, parents):
     parser = subparsers.add_parser(
         "completion",
         parents=parents,
-
+        description="Output shell completion script",
     )
     parser.add_argument(
         "-s",
@@ -807,18 +817,27 @@ def show(
     offset: int = 0,
     columns: Sequence[str] = (),
     no_collapse: bool = False,
+    schema: bool = False,
 ) -> None:
+    from datachain.lib.dc import DataChain
     from datachain.query import DatasetQuery
     from datachain.utils import show_records

+    dataset = catalog.get_dataset(name)
+    dataset_version = dataset.get_version(version or dataset.latest_version)
+
     query = (
         DatasetQuery(name=name, version=version, catalog=catalog)
         .select(*columns)
         .limit(limit)
         .offset(offset)
     )
-    records = query.
+    records = query.to_db_records()
     show_records(records, collapse_columns=not no_collapse)
+    if schema and dataset_version.feature_schema:
+        print("\nSchema:")
+        dc = DataChain(name=name, version=version, catalog=catalog)
+        dc.print_schema()


 def query(
@@ -901,27 +920,23 @@ def completion(shell: str) -> str:
     )


-def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901,
+def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR0915
     # Required for Windows multiprocessing support
     freeze_support()

     parser = get_parser()
     args = parser.parse_args(argv)

-    if args.
+    if args.command == "internal-run-udf":
         from datachain.query.dispatch import udf_entrypoint

         return udf_entrypoint()

-    if args.
+    if args.command == "internal-run-udf-worker":
         from datachain.query.dispatch import udf_worker_entrypoint

         return udf_worker_entrypoint()

-    if args.command is None:
-        parser.print_help()
-        return 1
-
     from .catalog import get_catalog

     logger.addHandler(logging.StreamHandler())
@@ -1008,6 +1023,7 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0911, PLR09
             offset=args.offset,
             columns=args.columns,
             no_collapse=args.no_collapse,
+            schema=args.schema,
         )
     elif args.command == "rm-dataset":
         rm_dataset(catalog, args.name, version=args.version, force=args.force)
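Because a subcommand is now required and the hidden --internal-run-udf flags became internal-run-udf / internal-run-udf-worker subcommands, bare `datachain` no longer prints help and returns 1; argparse rejects the missing command instead. The CLI can still be driven programmatically through main(). A hedged sketch, assuming a local dataset named "my-dataset" already exists:

    from datachain.cli import main

    exit_code = main(["show", "my-dataset", "--schema"])  # --schema is new in 0.2.13
    print(exit_code)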
datachain/data_storage/metastore.py
CHANGED
@@ -40,6 +40,7 @@ from datachain.error import (
     StorageNotFoundError,
     TableMissingError,
 )
+from datachain.job import Job
 from datachain.storage import Storage, StorageStatus, StorageURI
 from datachain.utils import JSONSerialize, is_expired

@@ -67,6 +68,7 @@ class AbstractMetastore(ABC, Serializable):
     storage_class: type[Storage] = Storage
     dataset_class: type[DatasetRecord] = DatasetRecord
     dependency_class: type[DatasetDependency] = DatasetDependency
+    job_class: type[Job] = Job

     def __init__(
         self,
@@ -377,6 +379,9 @@ class AbstractMetastore(ABC, Serializable):
     # Jobs
     #

+    def list_jobs_by_ids(self, ids: list[str], conn=None) -> Iterator["Job"]:
+        raise NotImplementedError
+
     @abstractmethod
     def create_job(
         self,
@@ -1467,6 +1472,10 @@ class AbstractDBMetastore(AbstractMetastore):
             Column("metrics", JSON, nullable=False),
         ]

+    @cached_property
+    def _job_fields(self) -> list[str]:
+        return [c.name for c in self._jobs_columns() if c.name]  # type: ignore[attr-defined]
+
     @cached_property
     def _jobs(self) -> "Table":
         return Table(self.JOBS_TABLE, self.db.metadata, *self._jobs_columns())
@@ -1484,6 +1493,21 @@ class AbstractDBMetastore(AbstractMetastore):
             return self._jobs.update()
         return self._jobs.update().where(*where)

+    def _parse_job(self, rows) -> Job:
+        return Job.parse(*rows)
+
+    def _parse_jobs(self, rows) -> Iterator["Job"]:
+        for _, g in groupby(rows, lambda r: r[0]):
+            yield self._parse_job(*list(g))
+
+    def _jobs_query(self):
+        return self._jobs_select(*[getattr(self._jobs.c, f) for f in self._job_fields])
+
+    def list_jobs_by_ids(self, ids: list[str], conn=None) -> Iterator["Job"]:
+        """List jobs by ids."""
+        query = self._jobs_query().where(self._jobs.c.id.in_(ids))
+        yield from self._parse_jobs(self.db.execute(query, conn=conn))
+
     def create_job(
         self,
         name: str,
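list_jobs_by_ids() is the batch lookup that Catalog.list_datasets_versions() relies on; it can also be called directly on a catalog's metastore. A short hedged sketch with made-up job ids:

    from datachain.catalog import get_catalog

    metastore = get_catalog().metastore
    jobs = {job.id: job for job in metastore.list_jobs_by_ids(["job-1", "job-2"])}
    print(jobs)  # empty dict when none of the ids exist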
datachain/data_storage/warehouse.py
CHANGED
@@ -390,7 +390,9 @@ class AbstractWarehouse(ABC, Serializable):
         expressions: tuple[_ColumnsClauseArgument[Any], ...] = (
             sa.func.count(table.c.sys__id),
         )
-        if "
+        if "file__size" in table.columns:
+            expressions = (*expressions, sa.func.sum(table.c.file__size))
+        elif "size" in table.columns:
             expressions = (*expressions, sa.func.sum(table.c.size))
         query = select(*expressions)
         ((nrows, *rest),) = self.db.execute(query)
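The column check above exists because newer datasets store file size in the flattened File signal column file__size, while older ones keep a plain size column. A standalone SQLAlchemy sketch (not datachain code) of the same conditional aggregation over a hypothetical table:

    import sqlalchemy as sa

    metadata = sa.MetaData()
    table = sa.Table(
        "ds_example",  # hypothetical dataset table
        metadata,
        sa.Column("sys__id", sa.Integer, primary_key=True),
        sa.Column("file__size", sa.Integer),
    )

    expressions = (sa.func.count(table.c.sys__id),)
    if "file__size" in table.columns:
        expressions = (*expressions, sa.func.sum(table.c.file__size))
    elif "size" in table.columns:
        expressions = (*expressions, sa.func.sum(table.c.size))

    print(sa.select(*expressions))  # counts rows and sums file__size for this table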
datachain/job.py
ADDED
@@ -0,0 +1,56 @@
+import json
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any, Optional, TypeVar
+
+J = TypeVar("J", bound="Job")
+
+
+@dataclass
+class Job:
+    id: str
+    name: str
+    status: int
+    created_at: datetime
+    query: str
+    query_type: int
+    workers: int
+    params: dict[str, str]
+    metrics: dict[str, Any]
+    finished_at: Optional[datetime] = None
+    python_version: Optional[str] = None
+    error_message: str = ""
+    error_stack: str = ""
+
+    @classmethod
+    def parse(
+        cls: type[J],
+        id: str,
+        name: str,
+        status: int,
+        created_at: datetime,
+        finished_at: Optional[datetime],
+        query: str,
+        query_type: int,
+        workers: int,
+        python_version: Optional[str],
+        error_message: str,
+        error_stack: str,
+        params: str,
+        metrics: str,
+    ) -> "Job":
+        return cls(
+            id,
+            name,
+            status,
+            created_at,
+            query,
+            query_type,
+            workers,
+            json.loads(params),
+            json.loads(metrics),
+            finished_at,
+            python_version,
+            error_message,
+            error_stack,
+        )
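Job.parse() maps a raw jobs-table row, in the column order listed above, onto the dataclass, decoding the params and metrics JSON columns on the way in. A hedged sketch with made-up row values:

    from datetime import datetime, timezone

    from datachain.job import Job

    row = (
        "job-1",                     # id
        "nightly-index",             # name (illustrative)
        3,                           # status
        datetime.now(timezone.utc),  # created_at
        None,                        # finished_at
        "-- query text --",          # query (stored script text)
        1,                           # query_type
        4,                           # workers
        "3.11",                      # python_version
        "",                          # error_message
        "",                          # error_stack
        '{"team": "ml"}',            # params, JSON-encoded in the DB
        '{"rows": 100}',             # metrics, JSON-encoded in the DB
    )
    job = Job.parse(*row)
    print(job.params["team"], job.metrics["rows"])  # ml 100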
datachain/lib/arrow.py
CHANGED
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional

 import pyarrow as pa
 from pyarrow.dataset import dataset
+from tqdm import tqdm

 from datachain.lib.file import File, IndexedFile
 from datachain.lib.udf import Generator
@@ -13,33 +14,44 @@ if TYPE_CHECKING:


 class ArrowGenerator(Generator):
-    def __init__(
+    def __init__(
+        self,
+        schema: Optional["pa.Schema"] = None,
+        nrows: Optional[int] = None,
+        **kwargs,
+    ):
         """
         Generator for getting rows from tabular files.

         Parameters:

         schema : Optional pyarrow schema for validation.
+        nrows : Optional row limit.
         kwargs: Parameters to pass to pyarrow.dataset.dataset.
         """
         super().__init__()
         self.schema = schema
+        self.nrows = nrows
         self.kwargs = kwargs

     def process(self, file: File):
         path = file.get_path()
         ds = dataset(path, filesystem=file.get_fs(), schema=self.schema, **self.kwargs)
         index = 0
-
-        for
-
-
-
+        with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
+            for record_batch in ds.to_batches():
+                for record in record_batch.to_pylist():
+                    source = IndexedFile(file=file, index=index)
+                    yield [source, *record.values()]
+                    index += 1
+                    if self.nrows and index >= self.nrows:
+                        return
+                pbar.update(len(record_batch))


 def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
     schemas = []
-    for file in chain.
+    for file in chain.collect("file"):
         ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
         schemas.append(ds.schema)
     return pa.unify_schemas(schemas)
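ArrowGenerator.process() now streams record batches under a tqdm progress bar and returns early once nrows rows have been yielded. A standalone sketch (not datachain code) of that batching-with-cap pattern against an in-memory pyarrow table:

    import pyarrow as pa
    from tqdm import tqdm

    def take(table: pa.Table, nrows: int):
        index = 0
        with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
            for record_batch in table.to_batches(max_chunksize=3):
                for record in record_batch.to_pylist():
                    yield index, record
                    index += 1
                    if nrows and index >= nrows:
                        return
                pbar.update(len(record_batch))

    print(list(take(pa.table({"x": list(range(10))}), nrows=4)))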