datachain 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Files changed (44)
  1. datachain/cache.py +4 -2
  2. datachain/catalog/catalog.py +100 -54
  3. datachain/catalog/datasource.py +4 -6
  4. datachain/cli/__init__.py +311 -0
  5. datachain/cli/commands/__init__.py +29 -0
  6. datachain/cli/commands/datasets.py +129 -0
  7. datachain/cli/commands/du.py +14 -0
  8. datachain/cli/commands/index.py +12 -0
  9. datachain/cli/commands/ls.py +169 -0
  10. datachain/cli/commands/misc.py +28 -0
  11. datachain/cli/commands/query.py +53 -0
  12. datachain/cli/commands/show.py +38 -0
  13. datachain/cli/parser/__init__.py +547 -0
  14. datachain/cli/parser/job.py +120 -0
  15. datachain/cli/parser/studio.py +126 -0
  16. datachain/cli/parser/utils.py +63 -0
  17. datachain/{cli_utils.py → cli/utils.py} +27 -1
  18. datachain/client/azure.py +21 -1
  19. datachain/client/fsspec.py +45 -13
  20. datachain/client/gcs.py +10 -2
  21. datachain/client/local.py +4 -4
  22. datachain/client/s3.py +10 -0
  23. datachain/dataset.py +1 -0
  24. datachain/func/__init__.py +2 -2
  25. datachain/func/conditional.py +52 -0
  26. datachain/func/func.py +5 -1
  27. datachain/lib/arrow.py +4 -0
  28. datachain/lib/dc.py +18 -3
  29. datachain/lib/file.py +1 -1
  30. datachain/lib/listing.py +36 -3
  31. datachain/lib/signal_schema.py +89 -27
  32. datachain/listing.py +1 -5
  33. datachain/node.py +27 -1
  34. datachain/progress.py +2 -2
  35. datachain/query/session.py +1 -1
  36. datachain/studio.py +58 -38
  37. datachain/utils.py +1 -1
  38. {datachain-0.8.2.dist-info → datachain-0.8.4.dist-info}/METADATA +6 -6
  39. {datachain-0.8.2.dist-info → datachain-0.8.4.dist-info}/RECORD +43 -31
  40. {datachain-0.8.2.dist-info → datachain-0.8.4.dist-info}/WHEEL +1 -1
  41. datachain/cli.py +0 -1475
  42. {datachain-0.8.2.dist-info → datachain-0.8.4.dist-info}/LICENSE +0 -0
  43. {datachain-0.8.2.dist-info → datachain-0.8.4.dist-info}/entry_points.txt +0 -0
  44. {datachain-0.8.2.dist-info → datachain-0.8.4.dist-info}/top_level.txt +0 -0
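
Most of this diff is a restructuring of the command-line interface: the monolithic `datachain/cli.py` (removed in full, 1,475 lines) is split into a `datachain/cli/` package with dedicated `commands/` and `parser/` submodules, and `cli_utils.py` moves to `datachain/cli/utils.py`. The hunks below are the new `datachain/cli/commands/` modules.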
datachain/cli/commands/datasets.py (new file)
@@ -0,0 +1,129 @@
+ import sys
+ from typing import TYPE_CHECKING, Optional
+
+ from tabulate import tabulate
+
+ from datachain import utils
+
+ if TYPE_CHECKING:
+     from datachain.catalog import Catalog
+
+ from datachain.cli.utils import determine_flavors
+ from datachain.config import Config
+ from datachain.error import DatasetNotFoundError
+
+
+ def list_datasets(
+     catalog: "Catalog",
+     studio: bool = False,
+     local: bool = False,
+     all: bool = True,
+     team: Optional[str] = None,
+ ):
+     from datachain.studio import list_datasets
+
+     token = Config().read().get("studio", {}).get("token")
+     all, local, studio = determine_flavors(studio, local, all, token)
+
+     local_datasets = set(list_datasets_local(catalog)) if all or local else set()
+     studio_datasets = (
+         set(list_datasets(team=team)) if (all or studio) and token else set()
+     )
+
+     rows = [
+         _datasets_tabulate_row(
+             name=name,
+             version=version,
+             both=(all or (local and studio)) and token,
+             local=(name, version) in local_datasets,
+             studio=(name, version) in studio_datasets,
+         )
+         for name, version in local_datasets.union(studio_datasets)
+     ]
+
+     print(tabulate(rows, headers="keys"))
+
+
+ def list_datasets_local(catalog: "Catalog"):
+     for d in catalog.ls_datasets():
+         for v in d.versions:
+             yield (d.name, v.version)
+
+
+ def _datasets_tabulate_row(name, version, both, local, studio):
+     row = {
+         "Name": name,
+         "Version": version,
+     }
+     if both:
+         row["Studio"] = "\u2714" if studio else "\u2716"
+         row["Local"] = "\u2714" if local else "\u2716"
+     return row
+
+
+ def rm_dataset(
+     catalog: "Catalog",
+     name: str,
+     version: Optional[int] = None,
+     force: Optional[bool] = False,
+     studio: bool = False,
+     local: bool = False,
+     all: bool = True,
+     team: Optional[str] = None,
+ ):
+     from datachain.studio import remove_studio_dataset
+
+     token = Config().read().get("studio", {}).get("token")
+     all, local, studio = determine_flavors(studio, local, all, token)
+
+     if all or local:
+         try:
+             catalog.remove_dataset(name, version=version, force=force)
+         except DatasetNotFoundError:
+             print("Dataset not found in local", file=sys.stderr)
+
+     if (all or studio) and token:
+         remove_studio_dataset(team, name, version, force)
+
+
+ def edit_dataset(
+     catalog: "Catalog",
+     name: str,
+     new_name: Optional[str] = None,
+     description: Optional[str] = None,
+     labels: Optional[list[str]] = None,
+     studio: bool = False,
+     local: bool = False,
+     all: bool = True,
+     team: Optional[str] = None,
+ ):
+     from datachain.studio import edit_studio_dataset
+
+     token = Config().read().get("studio", {}).get("token")
+     all, local, studio = determine_flavors(studio, local, all, token)
+
+     if all or local:
+         try:
+             catalog.edit_dataset(name, new_name, description, labels)
+         except DatasetNotFoundError:
+             print("Dataset not found in local", file=sys.stderr)
+
+     if (all or studio) and token:
+         edit_studio_dataset(team, name, new_name, description, labels)
+
+
+ def dataset_stats(
+     catalog: "Catalog",
+     name: str,
+     version: int,
+     show_bytes=False,
+     si=False,
+ ):
+     stats = catalog.dataset_stats(name, version)
+
+     if stats:
+         print(f"Number of objects: {stats.num_objects}")
+         if show_bytes:
+             print(f"Total objects size: {stats.size}")
+         else:
+             print(f"Total objects size: {utils.sizeof_fmt(stats.size, si=si): >7}")
datachain/cli/commands/du.py (new file)
@@ -0,0 +1,14 @@
+ from typing import TYPE_CHECKING
+
+ from datachain import utils
+
+ if TYPE_CHECKING:
+     from datachain.catalog import Catalog
+
+
+ def du(catalog: "Catalog", sources, show_bytes=False, si=False, **kwargs):
+     for path, size in catalog.du(sources, **kwargs):
+         if show_bytes:
+             print(f"{size} {path}")
+         else:
+             print(f"{utils.sizeof_fmt(size, si=si): >7} {path}")
datachain/cli/commands/index.py (new file)
@@ -0,0 +1,12 @@
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from datachain.catalog import Catalog
+
+
+ def index(
+     catalog: "Catalog",
+     sources,
+     **kwargs,
+ ):
+     catalog.index(sources, **kwargs)
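
`index` is a thin passthrough to `Catalog.index`. A one-line sketch, again with a placeholder URI:

```python
from datachain.catalog import get_catalog
from datachain.cli.commands.index import index

# Any extra kwargs are forwarded to Catalog.index unchanged.
index(get_catalog(), ["s3://my-bucket/"])
```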
datachain/cli/commands/ls.py (new file)
@@ -0,0 +1,169 @@
+ import shlex
+ from collections.abc import Iterable, Iterator
+ from itertools import chain
+ from typing import TYPE_CHECKING, Optional
+
+ if TYPE_CHECKING:
+     from datachain.catalog import Catalog
+
+ from datachain.cli.utils import determine_flavors
+ from datachain.config import Config
+
+
+ def ls(
+     sources,
+     long: bool = False,
+     studio: bool = False,
+     local: bool = False,
+     all: bool = True,
+     team: Optional[str] = None,
+     **kwargs,
+ ):
+     token = Config().read().get("studio", {}).get("token")
+     all, local, studio = determine_flavors(studio, local, all, token)
+
+     if all or local:
+         ls_local(sources, long=long, **kwargs)
+
+     if (all or studio) and token:
+         ls_remote(sources, long=long, team=team)
+
+
+ def ls_local(
+     sources,
+     long: bool = False,
+     catalog: Optional["Catalog"] = None,
+     client_config=None,
+     **kwargs,
+ ):
+     from datachain import DataChain
+
+     if catalog is None:
+         from datachain.catalog import get_catalog
+
+         catalog = get_catalog(client_config=client_config)
+     if sources:
+         actual_sources = list(ls_urls(sources, catalog=catalog, long=long, **kwargs))
+         if len(actual_sources) == 1:
+             for _, entries in actual_sources:
+                 for entry in entries:
+                     print(format_ls_entry(entry))
+         else:
+             first = True
+             for source, entries in actual_sources:
+                 # print a newline between directory listings
+                 if first:
+                     first = False
+                 else:
+                     print()
+                 if source:
+                     print(f"{source}:")
+                 for entry in entries:
+                     print(format_ls_entry(entry))
+     else:
+         chain = DataChain.listings()
+         for ls in chain.collect("listing"):
+             print(format_ls_entry(f"{ls.uri}@v{ls.version}"))  # type: ignore[union-attr]
+
+
+ def format_ls_entry(entry: str) -> str:
+     if entry.endswith("/") or not entry:
+         entry = shlex.quote(entry[:-1])
+         return f"{entry}/"
+     return shlex.quote(entry)
+
+
+ def ls_remote(
+     paths: Iterable[str],
+     long: bool = False,
+     team: Optional[str] = None,
+ ):
+     from datachain.node import long_line_str
+     from datachain.remote.studio import StudioClient
+
+     client = StudioClient(team=team)
+     first = True
+     for path, response in client.ls(paths):
+         if not first:
+             print()
+         if not response.ok or response.data is None:
+             print(f"{path}:\n Error: {response.message}\n")
+             continue
+
+         print(f"{path}:")
+         if long:
+             for row in response.data:
+                 entry = long_line_str(
+                     row["name"] + ("/" if row["dir_type"] else ""),
+                     row["last_modified"],
+                 )
+                 print(format_ls_entry(entry))
+         else:
+             for row in response.data:
+                 entry = row["name"] + ("/" if row["dir_type"] else "")
+                 print(format_ls_entry(entry))
+         first = False
+
+
+ def ls_urls(
+     sources,
+     catalog: "Catalog",
+     long: bool = False,
+     **kwargs,
+ ) -> Iterator[tuple[str, Iterator[str]]]:
+     curr_dir = None
+     value_iterables = []
+     for next_dir, values in _ls_urls_flat(sources, long, catalog, **kwargs):
+         if curr_dir is None or next_dir == curr_dir:  # type: ignore[unreachable]
+             value_iterables.append(values)
+         else:
+             yield curr_dir, chain(*value_iterables)  # type: ignore[unreachable]
+             value_iterables = [values]
+         curr_dir = next_dir
+     if curr_dir is not None:
+         yield curr_dir, chain(*value_iterables)
+
+
+ def _node_data_to_ls_values(row, long_format=False):
+     from datachain.node import DirType, long_line_str
+
+     name = row[0]
+     is_dir = row[1] == DirType.DIR
+     ending = "/" if is_dir else ""
+     value = name + ending
+     if long_format:
+         last_modified = row[2]
+         timestamp = last_modified if not is_dir else None
+         return long_line_str(value, timestamp)
+     return value
+
+
+ def _ls_urls_flat(
+     sources,
+     long: bool,
+     catalog: "Catalog",
+     **kwargs,
+ ) -> Iterator[tuple[str, Iterator[str]]]:
+     from datachain.client import Client
+     from datachain.node import long_line_str
+
+     for source in sources:
+         client_cls = Client.get_implementation(source)
+         if client_cls.is_root_url(source):
+             buckets = client_cls.ls_buckets(**catalog.client_config)
+             if long:
+                 values = (long_line_str(b.name, b.created) for b in buckets)
+             else:
+                 values = (b.name for b in buckets)
+             yield source, values
+         else:
+             found = False
+             fields = ["name", "dir_type"]
+             if long:
+                 fields.append("last_modified")
+             for data_source, results in catalog.ls([source], fields=fields, **kwargs):
+                 values = (_node_data_to_ls_values(r, long) for r in results)
+                 found = True
+                 yield data_source.dirname(), values
+             if not found:
+                 raise FileNotFoundError(f"No such file or directory: {source}")
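
The module splits listing into a local path (`ls_local`, which groups `_ls_urls_flat` results per directory via `ls_urls`) and a Studio path (`ls_remote`). A sketch of the local entry point, with placeholder URIs:

```python
from datachain.cli.commands.ls import ls_local

# ls_local builds its own catalog when none is passed in.
ls_local(["gs://my-bucket/images/"], long=True)

# With no sources it falls back to printing known listings
# from DataChain.listings().
ls_local([])
```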
datachain/cli/commands/misc.py (new file)
@@ -0,0 +1,28 @@
+ from typing import TYPE_CHECKING
+
+ import shtab
+
+ if TYPE_CHECKING:
+     from datachain.catalog import Catalog
+
+
+ def clear_cache(catalog: "Catalog"):
+     catalog.cache.clear()
+
+
+ def garbage_collect(catalog: "Catalog"):
+     temp_tables = catalog.get_temp_table_names()
+     if not temp_tables:
+         print("Nothing to clean up.")
+     else:
+         print(f"Garbage collecting {len(temp_tables)} tables.")
+         catalog.cleanup_tables(temp_tables)
+
+
+ def completion(shell: str) -> str:
+     from datachain.cli import get_parser
+
+     return shtab.complete(
+         get_parser(),
+         shell=shell,
+     )
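
`completion` just feeds the CLI parser to shtab, so generating a completion script takes two lines. A sketch; the output path is a placeholder, and shtab also supports zsh and tcsh:

```python
from datachain.cli.commands.misc import completion

script = completion("bash")  # returns the completion script as a string
with open("datachain-completion.bash", "w", encoding="utf-8") as f:
    f.write(script)
```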
datachain/cli/commands/query.py (new file)
@@ -0,0 +1,53 @@
+ import os
+ import sys
+ import traceback
+ from typing import TYPE_CHECKING, Optional
+
+ if TYPE_CHECKING:
+     from datachain.catalog import Catalog
+
+
+ def query(
+     catalog: "Catalog",
+     script: str,
+     parallel: Optional[int] = None,
+     params: Optional[dict[str, str]] = None,
+ ) -> None:
+     from datachain.data_storage import JobQueryType, JobStatus
+
+     with open(script, encoding="utf-8") as f:
+         script_content = f.read()
+
+     if parallel is not None:
+         # This also sets this environment variable for any subprocesses
+         os.environ["DATACHAIN_SETTINGS_PARALLEL"] = str(parallel)
+
+     python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
+     python_executable = sys.executable
+
+     job_id = catalog.metastore.create_job(
+         name=os.path.basename(script),
+         query=script_content,
+         query_type=JobQueryType.PYTHON,
+         python_version=python_version,
+         params=params,
+     )
+
+     try:
+         catalog.query(
+             script_content,
+             python_executable=python_executable,
+             params=params,
+             job_id=job_id,
+         )
+     except Exception as e:
+         error_message = str(e)
+         error_stack = traceback.format_exc()
+         catalog.metastore.set_job_status(
+             job_id,
+             JobStatus.FAILED,
+             error_message=error_message,
+             error_stack=error_stack,
+         )
+         raise
+     catalog.metastore.set_job_status(job_id, JobStatus.COMPLETE)
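
`query` wraps a script run in a metastore job record: it creates the job, executes the script via `catalog.query`, and marks the job FAILED (re-raising) or COMPLETE. A hedged sketch, with a placeholder script path:

```python
from datachain.catalog import get_catalog
from datachain.cli.commands.query import query

# "my_query.py" is a placeholder; parallel=4 is exported to subprocesses
# as DATACHAIN_SETTINGS_PARALLEL before the script runs.
query(get_catalog(), "my_query.py", parallel=4, params={"limit": "100"})
```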
datachain/cli/commands/show.py (new file)
@@ -0,0 +1,38 @@
+ from collections.abc import Sequence
+ from typing import TYPE_CHECKING, Optional
+
+ if TYPE_CHECKING:
+     from datachain.catalog import Catalog
+
+
+ def show(
+     catalog: "Catalog",
+     name: str,
+     version: Optional[int] = None,
+     limit: int = 10,
+     offset: int = 0,
+     columns: Sequence[str] = (),
+     no_collapse: bool = False,
+     schema: bool = False,
+ ) -> None:
+     from datachain import Session
+     from datachain.lib.dc import DataChain
+     from datachain.query.dataset import DatasetQuery
+     from datachain.utils import show_records
+
+     dataset = catalog.get_dataset(name)
+     dataset_version = dataset.get_version(version or dataset.latest_version)
+
+     query = (
+         DatasetQuery(name=name, version=version, catalog=catalog)
+         .select(*columns)
+         .limit(limit)
+         .offset(offset)
+     )
+     records = query.to_db_records()
+     show_records(records, collapse_columns=not no_collapse)
+     if schema and dataset_version.feature_schema:
+         print("\nSchema:")
+         session = Session.get(catalog=catalog)
+         dc = DataChain.from_dataset(name=name, version=version, session=session)
+         dc.print_schema()
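
A matching sketch for `show`, with a placeholder dataset name; `schema=True` additionally prints the feature schema after the records:

```python
from datachain.catalog import get_catalog
from datachain.cli.commands.show import show

# Shows the first 5 records of the latest version, then the schema.
show(get_catalog(), "my_dataset", limit=5, schema=True)
```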