datachain 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/__init__.py +3 -4
- datachain/cache.py +10 -4
- datachain/catalog/catalog.py +35 -15
- datachain/cli.py +37 -32
- datachain/data_storage/metastore.py +24 -0
- datachain/data_storage/warehouse.py +3 -1
- datachain/job.py +56 -0
- datachain/lib/arrow.py +19 -7
- datachain/lib/clip.py +89 -66
- datachain/lib/convert/{type_converter.py → python_to_sql.py} +6 -6
- datachain/lib/convert/sql_to_python.py +23 -0
- datachain/lib/convert/values_to_tuples.py +51 -33
- datachain/lib/data_model.py +6 -27
- datachain/lib/dataset_info.py +70 -0
- datachain/lib/dc.py +646 -152
- datachain/lib/file.py +117 -15
- datachain/lib/image.py +1 -1
- datachain/lib/meta_formats.py +14 -2
- datachain/lib/model_store.py +3 -2
- datachain/lib/pytorch.py +10 -7
- datachain/lib/signal_schema.py +39 -14
- datachain/lib/text.py +2 -1
- datachain/lib/udf.py +56 -5
- datachain/lib/udf_signature.py +1 -1
- datachain/lib/webdataset.py +4 -3
- datachain/node.py +11 -8
- datachain/query/dataset.py +66 -147
- datachain/query/dispatch.py +15 -13
- datachain/query/schema.py +2 -0
- datachain/query/session.py +4 -4
- datachain/sql/functions/array.py +12 -0
- datachain/sql/functions/string.py +8 -0
- datachain/torch/__init__.py +1 -1
- datachain/utils.py +45 -0
- datachain-0.2.12.dist-info/METADATA +412 -0
- {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/RECORD +40 -45
- {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/WHEEL +1 -1
- datachain/lib/feature_registry.py +0 -77
- datachain/lib/gpt4_vision.py +0 -97
- datachain/lib/hf_image_to_text.py +0 -97
- datachain/lib/hf_pipeline.py +0 -90
- datachain/lib/image_transform.py +0 -103
- datachain/lib/iptc_exif_xmp.py +0 -76
- datachain/lib/unstructured.py +0 -41
- datachain/text/__init__.py +0 -3
- datachain-0.2.10.dist-info/METADATA +0 -430
- {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/LICENSE +0 -0
- {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/top_level.txt +0 -0
datachain/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-from datachain.lib.data_model import DataModel, DataType,
+from datachain.lib.data_model import DataModel, DataType, is_chain_type
 from datachain.lib.dc import C, Column, DataChain, Sys
 from datachain.lib.file import (
     File,
@@ -8,15 +8,14 @@ from datachain.lib.file import (
     TarVFile,
     TextFile,
 )
+from datachain.lib.model_store import ModelStore
 from datachain.lib.udf import Aggregator, Generator, Mapper
 from datachain.lib.utils import AbstractUDF, DataChainError
-from datachain.query.dataset import UDF as BaseUDF  # noqa: N811
 from datachain.query.session import Session

 __all__ = [
     "AbstractUDF",
     "Aggregator",
-    "BaseUDF",
     "C",
     "Column",
     "DataChain",
@@ -24,12 +23,12 @@ __all__ = [
     "DataModel",
     "DataType",
     "File",
-    "FileBasic",
     "FileError",
     "Generator",
     "ImageFile",
     "IndexedFile",
     "Mapper",
+    "ModelStore",
     "Session",
     "Sys",
     "TarVFile",
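The net effect for importers: `ModelStore` joins the public exports, `is_chain_type` is pulled into the package root, and `BaseUDF` and `"FileBasic"` leave `__all__`. A minimal import check, assuming datachain 0.2.12 is installed:

    from datachain import DataChain, ModelStore, is_chain_type

    # Both new names resolve from the package root in 0.2.12;
    # BaseUDF is no longer re-exported here.
    print(DataChain, ModelStore, is_chain_type)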
datachain/cache.py
CHANGED
@@ -1,6 +1,7 @@
 import hashlib
 import json
 import os
+from datetime import datetime
 from functools import partial
 from typing import TYPE_CHECKING, Optional

@@ -9,6 +10,8 @@ from dvc_data.hashfile.db.local import LocalHashFileDB
 from dvc_objects.fs.local import LocalFileSystem
 from fsspec.callbacks import Callback, TqdmCallback

+from datachain.utils import TIME_ZERO
+
 from .progress import Tqdm

 if TYPE_CHECKING:
@@ -23,10 +26,13 @@ class UniqueId:
     storage: "StorageURI"
     parent: str
     name: str
-    etag: str
     size: int
-
-
+    etag: str
+    version: str = ""
+    is_latest: bool = True
+    vtype: str = ""
+    location: Optional[str] = None
+    last_modified: datetime = TIME_ZERO

     @property
     def path(self) -> str:
@@ -49,7 +55,7 @@ class UniqueId:
     def get_hash(self) -> str:
         etag = f"{self.vtype}{self.location}" if self.vtype else self.etag
         return sha256(
-            f"{self.storage}/{self.parent}/{self.name}/{etag}".encode()
+            f"{self.storage}/{self.parent}/{self.name}/{self.version}/{etag}".encode()
         ).hexdigest()
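The practical effect of the `get_hash` change is that the object version now participates in the cache key, so different versions of the same path no longer collide. A standalone sketch of the before/after key, with hypothetical field values standing in for a `UniqueId`:

    from hashlib import sha256

    # Hypothetical values standing in for UniqueId fields.
    storage, parent, name, version, etag = "s3://bucket", "dir", "a.csv", "v2", "abc"

    key_0_2_10 = sha256(f"{storage}/{parent}/{name}/{etag}".encode()).hexdigest()
    key_0_2_12 = sha256(f"{storage}/{parent}/{name}/{version}/{etag}".encode()).hexdigest()
    assert key_0_2_10 != key_0_2_12  # versions of the same object get distinct cache keys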
datachain/catalog/catalog.py
CHANGED
@@ -84,12 +84,14 @@ if TYPE_CHECKING:
         AbstractMetastore,
         AbstractWarehouse,
     )
+    from datachain.dataset import DatasetVersion
+    from datachain.job import Job

 logger = logging.getLogger("datachain")

 DEFAULT_DATASET_DIR = "dataset"
 DATASET_FILE_SUFFIX = ".edatachain"
-FEATURE_CLASSES = ["
+FEATURE_CLASSES = ["DataModel"]

 TTL_INT = 4 * 60 * 60

@@ -948,13 +950,9 @@ class Catalog:
         ms = self.metastore.clone(uri, None)
         st = self.warehouse.clone()
         listing = Listing(None, ms, st, client, None)
-        rows = (
-
-
-        )
-        .select()
-        .to_records()
-        )
+        rows = DatasetQuery(
+            name=dataset.name, version=ds_version, catalog=self
+        ).to_db_records()
         indexed_sources.append(
             (
                 listing,
@@ -1160,9 +1158,8 @@ class Catalog:
         if not dataset_version.preview:
             values["preview"] = (
                 DatasetQuery(name=dataset.name, version=version, catalog=self)
-                .select()
                 .limit(20)
-                .
+                .to_db_records()
             )

         if not values:
@@ -1420,6 +1417,25 @@ class Catalog:
             if not d.is_bucket_listing:
                 yield d

+    def list_datasets_versions(
+        self,
+    ) -> Iterator[tuple[DatasetRecord, "DatasetVersion", Optional["Job"]]]:
+        """Iterate over all dataset versions with related jobs."""
+        datasets = list(self.ls_datasets())
+
+        # preselect dataset versions jobs from db to avoid multiple queries
+        jobs_ids: set[str] = {
+            v.job_id for ds in datasets for v in ds.versions if v.job_id
+        }
+        jobs: dict[str, Job] = {}
+        if jobs_ids:
+            jobs = {j.id: j for j in self.metastore.list_jobs_by_ids(list(jobs_ids))}
+
+        for d in datasets:
+            yield from (
+                (d, v, jobs.get(v.job_id) if v.job_id else None) for v in d.versions
+            )
+
     def ls_dataset_rows(
         self, name: str, version: int, offset=None, limit=None
     ) -> list[dict]:
@@ -1427,7 +1443,7 @@ class Catalog:

         dataset = self.get_dataset(name)

-        q = DatasetQuery(name=dataset.name, version=version, catalog=self)
+        q = DatasetQuery(name=dataset.name, version=version, catalog=self)
         if limit:
             q = q.limit(limit)
         if offset:
@@ -1435,7 +1451,7 @@ class Catalog:

         q = q.order_by("sys__id")

-        return q.
+        return q.to_db_records()

     def signed_url(self, source: str, path: str, client_config=None) -> str:
         client_config = client_config or self.client_config
@@ -1609,6 +1625,7 @@ class Catalog:
             ...
         }
         """
+        from datachain.lib.file import File
         from datachain.lib.signal_schema import DEFAULT_DELIMITER, SignalSchema

         version = self.get_dataset(dataset_name).get_version(dataset_version)
@@ -1616,7 +1633,7 @@ class Catalog:
         file_signals_values = {}

         schema = SignalSchema.deserialize(version.feature_schema)
-        for file_signals in schema.
+        for file_signals in schema.get_signals(File):
             prefix = file_signals.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER
             file_signals_values[file_signals] = {
                 c_name.removeprefix(prefix): c_value
@@ -1657,10 +1674,13 @@ class Catalog:
                 row["source"],
                 row["parent"],
                 row["name"],
-                row["etag"],
                 row["size"],
+                row["etag"],
+                row["version"],
+                row["is_latest"],
                 row["vtype"],
                 row["location"],
+                row["last_modified"],
             )

     def ls(
@@ -1992,7 +2012,7 @@ class Catalog:
             )
         if proc.returncode == QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE:
             raise QueryScriptRunError(
-                "Last line in a script was not an instance of
+                "Last line in a script was not an instance of DataChain",
                 return_code=proc.returncode,
                 output=output,
             )
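The new `Catalog.list_datasets_versions` batches the job lookup so each dataset version comes back paired with the job that produced it. A usage sketch, assuming a locally configured catalog:

    from datachain.catalog import get_catalog

    catalog = get_catalog()
    for dataset, version, job in catalog.list_datasets_versions():
        # job is None for versions that were not created by a tracked job
        print(dataset.name, version, job)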
datachain/cli.py
CHANGED
@@ -3,7 +3,7 @@ import os
 import shlex
 import sys
 import traceback
-from argparse import
+from argparse import Action, ArgumentParser, ArgumentTypeError, Namespace
 from collections.abc import Iterable, Iterator, Mapping, Sequence
 from importlib.metadata import PackageNotFoundError, version
 from itertools import chain
@@ -106,10 +106,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     parser = ArgumentParser(
         description="DataChain: Wrangle unstructured AI data at scale", prog="datachain"
     )
-
     parser.add_argument("-V", "--version", action="version", version=__version__)
-    parser.add_argument("--internal-run-udf", action="store_true", help=SUPPRESS)
-    parser.add_argument("--internal-run-udf-worker", action="store_true", help=SUPPRESS)

     parent_parser = ArgumentParser(add_help=False)
     parent_parser.add_argument(
@@ -150,9 +147,15 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         help="Drop into the pdb debugger on fatal exception",
     )

-    subp = parser.add_subparsers(
+    subp = parser.add_subparsers(
+        title="Available Commands",
+        metavar="command",
+        dest="command",
+        help=f"Use `{parser.prog} command --help` for command-specific help.",
+        required=True,
+    )
     parse_cp = subp.add_parser(
-        "cp", parents=[parent_parser],
+        "cp", parents=[parent_parser], description="Copy data files from the cloud"
     )
     add_sources_arg(parse_cp).complete = shtab.DIR  # type: ignore[attr-defined]
     parse_cp.add_argument("output", type=str, help="Output")
@@ -179,7 +182,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     parse_clone = subp.add_parser(
-        "clone", parents=[parent_parser],
+        "clone", parents=[parent_parser], description="Copy data files from the cloud"
     )
     add_sources_arg(parse_clone).complete = shtab.DIR  # type: ignore[attr-defined]
     parse_clone.add_argument("output", type=str, help="Output")
@@ -222,7 +225,9 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     parse_pull = subp.add_parser(
-        "pull",
+        "pull",
+        parents=[parent_parser],
+        description="Pull specific dataset version from SaaS",
     )
     parse_pull.add_argument(
         "dataset",
@@ -263,7 +268,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     parse_edit_dataset = subp.add_parser(
-        "edit-dataset", parents=[parent_parser],
+        "edit-dataset", parents=[parent_parser], description="Edit dataset metadata"
     )
     parse_edit_dataset.add_argument("name", type=str, help="Dataset name")
     parse_edit_dataset.add_argument(
@@ -285,9 +290,9 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         help="Dataset labels",
     )

-    subp.add_parser("ls-datasets", parents=[parent_parser],
+    subp.add_parser("ls-datasets", parents=[parent_parser], description="List datasets")
     rm_dataset_parser = subp.add_parser(
-        "rm-dataset", parents=[parent_parser],
+        "rm-dataset", parents=[parent_parser], description="Removes dataset"
     )
     rm_dataset_parser.add_argument("name", type=str, help="Dataset name")
     rm_dataset_parser.add_argument(
@@ -305,7 +310,9 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     dataset_stats_parser = subp.add_parser(
-        "dataset-stats",
+        "dataset-stats",
+        parents=[parent_parser],
+        description="Shows basic dataset stats",
     )
     dataset_stats_parser.add_argument("name", type=str, help="Dataset name")
     dataset_stats_parser.add_argument(
@@ -330,7 +337,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     parse_merge_datasets = subp.add_parser(
-        "merge-datasets", parents=[parent_parser],
+        "merge-datasets", parents=[parent_parser], description="Merges datasets"
     )
     parse_merge_datasets.add_argument(
         "--src",
@@ -360,7 +367,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     parse_ls = subp.add_parser(
-        "ls", parents=[parent_parser],
+        "ls", parents=[parent_parser], description="List storage contents"
     )
     add_sources_arg(parse_ls, nargs="*")
     parse_ls.add_argument(
@@ -378,7 +385,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     parse_du = subp.add_parser(
-        "du", parents=[parent_parser],
+        "du", parents=[parent_parser], description="Display space usage"
     )
     add_sources_arg(parse_du)
     parse_du.add_argument(
@@ -408,7 +415,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     parse_find = subp.add_parser(
-        "find", parents=[parent_parser],
+        "find", parents=[parent_parser], description="Search in a directory hierarchy"
     )
     add_sources_arg(parse_find)
     parse_find.add_argument(
@@ -461,20 +468,20 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     parse_index = subp.add_parser(
-        "index", parents=[parent_parser],
+        "index", parents=[parent_parser], description="Index storage location"
     )
     add_sources_arg(parse_index)

     subp.add_parser(
         "find-stale-storages",
         parents=[parent_parser],
-
+        description="Finds and marks stale storages",
     )

     show_parser = subp.add_parser(
         "show",
         parents=[parent_parser],
-
+        description="Create a new dataset with a query script",
     )
     show_parser.add_argument("name", type=str, help="Dataset name")
     show_parser.add_argument(
@@ -489,7 +496,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     query_parser = subp.add_parser(
         "query",
         parents=[parent_parser],
-
+        description="Create a new dataset with a query script",
     )
     query_parser.add_argument(
         "script", metavar="<script.py>", type=str, help="Filepath for script"
@@ -520,7 +527,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     )

     apply_udf_parser = subp.add_parser(
-        "apply-udf", parents=[parent_parser],
+        "apply-udf", parents=[parent_parser], description="Apply UDF"
     )
     apply_udf_parser.add_argument("udf", type=str, help="UDF location")
     apply_udf_parser.add_argument("source", type=str, help="Source storage or dataset")
@@ -541,12 +548,14 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         "--udf-params", type=str, default=None, help="UDF class parameters"
     )
     subp.add_parser(
-        "clear-cache", parents=[parent_parser],
+        "clear-cache", parents=[parent_parser], description="Clear the local file cache"
     )
     subp.add_parser(
-        "gc", parents=[parent_parser],
+        "gc", parents=[parent_parser], description="Garbage collect temporary tables"
    )

+    subp.add_parser("internal-run-udf", parents=[parent_parser])
+    subp.add_parser("internal-run-udf-worker", parents=[parent_parser])
     add_completion_parser(subp, [parent_parser])
     return parser

@@ -555,7 +564,7 @@ def add_completion_parser(subparsers, parents):
     parser = subparsers.add_parser(
         "completion",
         parents=parents,
-
+        description="Output shell completion script",
     )
     parser.add_argument(
         "-s",
@@ -817,7 +826,7 @@ def show(
         .limit(limit)
         .offset(offset)
     )
-    records = query.
+    records = query.to_db_records()
     show_records(records, collapse_columns=not no_collapse)


@@ -901,27 +910,23 @@ def completion(shell: str) -> str:
     )


-def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901,
+def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR0915
     # Required for Windows multiprocessing support
     freeze_support()

     parser = get_parser()
     args = parser.parse_args(argv)

-    if args.
+    if args.command == "internal-run-udf":
         from datachain.query.dispatch import udf_entrypoint

         return udf_entrypoint()

-    if args.
+    if args.command == "internal-run-udf-worker":
         from datachain.query.dispatch import udf_worker_entrypoint

         return udf_worker_entrypoint()

-    if args.command is None:
-        parser.print_help()
-        return 1
-
     from .catalog import get_catalog

     logger.addHandler(logging.StreamHandler())
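With `required=True` on the subparser group and the internal UDF entrypoints registered as subcommands, argparse now rejects a missing command on its own, which is why the `args.command is None` fallback in `main` could be dropped. A small routing sketch, assuming datachain 0.2.12:

    from datachain.cli import get_parser

    # The hidden --internal-run-udf flag became a subcommand;
    # main() now dispatches on args.command.
    args = get_parser().parse_args(["internal-run-udf"])
    assert args.command == "internal-run-udf"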
datachain/data_storage/metastore.py
CHANGED
@@ -40,6 +40,7 @@ from datachain.error import (
     StorageNotFoundError,
     TableMissingError,
 )
+from datachain.job import Job
 from datachain.storage import Storage, StorageStatus, StorageURI
 from datachain.utils import JSONSerialize, is_expired

@@ -67,6 +68,7 @@ class AbstractMetastore(ABC, Serializable):
     storage_class: type[Storage] = Storage
     dataset_class: type[DatasetRecord] = DatasetRecord
     dependency_class: type[DatasetDependency] = DatasetDependency
+    job_class: type[Job] = Job

     def __init__(
         self,
@@ -377,6 +379,9 @@ class AbstractMetastore(ABC, Serializable):
     # Jobs
     #

+    def list_jobs_by_ids(self, ids: list[str], conn=None) -> Iterator["Job"]:
+        raise NotImplementedError
+
     @abstractmethod
     def create_job(
         self,
@@ -1467,6 +1472,10 @@ class AbstractDBMetastore(AbstractMetastore):
         Column("metrics", JSON, nullable=False),
     ]

+    @cached_property
+    def _job_fields(self) -> list[str]:
+        return [c.name for c in self._jobs_columns() if c.name]  # type: ignore[attr-defined]
+
     @cached_property
     def _jobs(self) -> "Table":
         return Table(self.JOBS_TABLE, self.db.metadata, *self._jobs_columns())
@@ -1484,6 +1493,21 @@ class AbstractDBMetastore(AbstractMetastore):
             return self._jobs.update()
         return self._jobs.update().where(*where)

+    def _parse_job(self, rows) -> Job:
+        return Job.parse(*rows)
+
+    def _parse_jobs(self, rows) -> Iterator["Job"]:
+        for _, g in groupby(rows, lambda r: r[0]):
+            yield self._parse_job(*list(g))
+
+    def _jobs_query(self):
+        return self._jobs_select(*[getattr(self._jobs.c, f) for f in self._job_fields])
+
+    def list_jobs_by_ids(self, ids: list[str], conn=None) -> Iterator["Job"]:
+        """List jobs by ids."""
+        query = self._jobs_query().where(self._jobs.c.id.in_(ids))
+        yield from self._parse_jobs(self.db.execute(query, conn=conn))
+
     def create_job(
         self,
         name: str,
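`list_jobs_by_ids` is the batch query that `Catalog.list_datasets_versions` relies on. A hedged usage sketch against a configured metastore (the job ids here are hypothetical):

    from datachain.catalog import get_catalog

    metastore = get_catalog().metastore
    jobs = list(metastore.list_jobs_by_ids(["job-id-1", "job-id-2"]))
    print(jobs)  # empty unless those ids exist in the jobs table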
datachain/data_storage/warehouse.py
CHANGED
@@ -390,7 +390,9 @@ class AbstractWarehouse(ABC, Serializable):
         expressions: tuple[_ColumnsClauseArgument[Any], ...] = (
             sa.func.count(table.c.sys__id),
         )
-        if "
+        if "file__size" in table.columns:
+            expressions = (*expressions, sa.func.sum(table.c.file__size))
+        elif "size" in table.columns:
             expressions = (*expressions, sa.func.sum(table.c.size))
         query = select(*expressions)
         ((nrows, *rest),) = self.db.execute(query)
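The `dataset_stats` expression now prefers the flattened `file__size` column and only falls back to the legacy `size` column. A self-contained SQLAlchemy sketch of that selection logic over a hypothetical table:

    import sqlalchemy as sa

    table = sa.table("rows", sa.column("sys__id"), sa.column("file__size"))

    expressions = (sa.func.count(table.c.sys__id),)
    if "file__size" in table.columns:
        expressions = (*expressions, sa.func.sum(table.c.file__size))
    elif "size" in table.columns:
        expressions = (*expressions, sa.func.sum(table.c.size))
    print(sa.select(*expressions))  # counts rows and sums file__size when present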
datachain/job.py
ADDED
@@ -0,0 +1,56 @@
+import json
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any, Optional, TypeVar
+
+J = TypeVar("J", bound="Job")
+
+
+@dataclass
+class Job:
+    id: str
+    name: str
+    status: int
+    created_at: datetime
+    query: str
+    query_type: int
+    workers: int
+    params: dict[str, str]
+    metrics: dict[str, Any]
+    finished_at: Optional[datetime] = None
+    python_version: Optional[str] = None
+    error_message: str = ""
+    error_stack: str = ""
+
+    @classmethod
+    def parse(
+        cls: type[J],
+        id: str,
+        name: str,
+        status: int,
+        created_at: datetime,
+        finished_at: Optional[datetime],
+        query: str,
+        query_type: int,
+        workers: int,
+        python_version: Optional[str],
+        error_message: str,
+        error_stack: str,
+        params: str,
+        metrics: str,
+    ) -> "Job":
+        return cls(
+            id,
+            name,
+            status,
+            created_at,
+            query,
+            query_type,
+            workers,
+            json.loads(params),
+            json.loads(metrics),
+            finished_at,
+            python_version,
+            error_message,
+            error_stack,
+        )
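`Job.parse` takes the raw jobs-table columns positionally and decodes the JSON-encoded `params` and `metrics` strings; `AbstractDBMetastore._parse_job` feeds it rows in exactly this order. A sketch with hypothetical values:

    from datetime import datetime, timezone

    from datachain.job import Job

    job = Job.parse(
        "job-1",                     # id (hypothetical)
        "my-job",                    # name
        3,                           # status
        datetime.now(timezone.utc),  # created_at
        None,                        # finished_at
        "print(1)",                  # query
        1,                           # query_type
        4,                           # workers
        "3.11.9",                    # python_version
        "",                          # error_message
        "",                          # error_stack
        '{"team": "ml"}',            # params, JSON string
        '{"rows": 100}',             # metrics, JSON string
    )
    assert job.params == {"team": "ml"} and job.metrics == {"rows": 100}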
datachain/lib/arrow.py
CHANGED
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional

 import pyarrow as pa
 from pyarrow.dataset import dataset
+from tqdm import tqdm

 from datachain.lib.file import File, IndexedFile
 from datachain.lib.udf import Generator
@@ -13,33 +14,44 @@ if TYPE_CHECKING:


 class ArrowGenerator(Generator):
-    def __init__(
+    def __init__(
+        self,
+        schema: Optional["pa.Schema"] = None,
+        nrows: Optional[int] = None,
+        **kwargs,
+    ):
         """
         Generator for getting rows from tabular files.

         Parameters:

         schema : Optional pyarrow schema for validation.
+        nrows : Optional row limit.
         kwargs: Parameters to pass to pyarrow.dataset.dataset.
         """
         super().__init__()
         self.schema = schema
+        self.nrows = nrows
         self.kwargs = kwargs

     def process(self, file: File):
         path = file.get_path()
         ds = dataset(path, filesystem=file.get_fs(), schema=self.schema, **self.kwargs)
         index = 0
-
-        for
-
-
-
+        with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
+            for record_batch in ds.to_batches():
+                for record in record_batch.to_pylist():
+                    source = IndexedFile(file=file, index=index)
+                    yield [source, *record.values()]
+                    index += 1
+                    if self.nrows and index >= self.nrows:
+                        return
+                pbar.update(len(record_batch))


 def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
     schemas = []
-    for file in chain.
+    for file in chain.collect("file"):
         ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
         schemas.append(ds.schema)
     return pa.unify_schemas(schemas)