datachain 0.8.3__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff compares two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Note: this release of datachain has been flagged as potentially problematic.

@@ -0,0 +1,311 @@
+ import logging
+ import os
+ import sys
+ import traceback
+ from multiprocessing import freeze_support
+ from typing import Optional
+
+ from datachain.cli.utils import get_logging_level
+ from datachain.telemetry import telemetry
+
+ from .commands import (
+     clear_cache,
+     completion,
+     dataset_stats,
+     du,
+     edit_dataset,
+     garbage_collect,
+     index,
+     list_datasets,
+     ls,
+     query,
+     rm_dataset,
+     show,
+ )
+ from .parser import get_parser
+
+ logger = logging.getLogger("datachain")
+
+
+ def main(argv: Optional[list[str]] = None) -> int:
+     from datachain.catalog import get_catalog
+
+     # Required for Windows multiprocessing support
+     freeze_support()
+
+     datachain_parser = get_parser()
+     args = datachain_parser.parse_args(argv)
+
+     if args.command in ("internal-run-udf", "internal-run-udf-worker"):
+         return handle_udf(args.command)
+
+     logger.addHandler(logging.StreamHandler())
+     logging_level = get_logging_level(args)
+     logger.setLevel(logging_level)
+
+     client_config = {
+         "aws_endpoint_url": args.aws_endpoint_url,
+         "anon": args.anon,
+     }
+
+     if args.debug_sql:
+         # This also sets this environment variable for any subprocesses
+         os.environ["DEBUG_SHOW_SQL_QUERIES"] = "True"
+
+     error = None
+
+     try:
+         catalog = get_catalog(client_config=client_config)
+         return handle_command(args, catalog, client_config)
+     except BrokenPipeError as exc:
+         error, return_code = handle_broken_pipe_error(exc)
+         return return_code
+     except (KeyboardInterrupt, Exception) as exc:
+         error, return_code = handle_general_exception(exc, args, logging_level)
+         return return_code
+     finally:
+         telemetry.send_cli_call(args.command, error=error)
+
+
+ def handle_command(args, catalog, client_config) -> int:
+     """Handle the different CLI commands."""
+     from datachain.studio import process_jobs_args, process_studio_cli_args
+
+     command_handlers = {
+         "cp": lambda: handle_cp_command(args, catalog),
+         "clone": lambda: handle_clone_command(args, catalog),
+         "dataset": lambda: handle_dataset_command(args, catalog),
+         "ds": lambda: handle_dataset_command(args, catalog),
+         "ls": lambda: handle_ls_command(args, client_config),
+         "show": lambda: handle_show_command(args, catalog),
+         "du": lambda: handle_du_command(args, catalog, client_config),
+         "find": lambda: handle_find_command(args, catalog),
+         "index": lambda: handle_index_command(args, catalog),
+         "completion": lambda: handle_completion_command(args),
+         "query": lambda: handle_query_command(args, catalog),
+         "clear-cache": lambda: clear_cache(catalog),
+         "gc": lambda: garbage_collect(catalog),
+         "studio": lambda: process_studio_cli_args(args),
+         "job": lambda: process_jobs_args(args),
+     }
+
+     handler = command_handlers.get(args.command)
+     if handler:
+         handler()
+         return 0
+     print(f"invalid command: {args.command}", file=sys.stderr)
+     return 1
+
+
+ def handle_cp_command(args, catalog):
+     catalog.cp(
+         args.sources,
+         args.output,
+         force=bool(args.force),
+         update=bool(args.update),
+         recursive=bool(args.recursive),
+         edatachain_file=None,
+         edatachain_only=False,
+         no_edatachain_file=True,
+         no_glob=args.no_glob,
+     )
+
+
+ def handle_clone_command(args, catalog):
+     catalog.clone(
+         args.sources,
+         args.output,
+         force=bool(args.force),
+         update=bool(args.update),
+         recursive=bool(args.recursive),
+         no_glob=args.no_glob,
+         no_cp=args.no_cp,
+         edatachain=args.edatachain,
+         edatachain_file=args.edatachain_file,
+     )
+
+
+ def handle_dataset_command(args, catalog):
+     dataset_commands = {
+         "pull": lambda: catalog.pull_dataset(
+             args.dataset,
+             args.output,
+             local_ds_name=args.local_name,
+             local_ds_version=args.local_version,
+             cp=args.cp,
+             force=bool(args.force),
+             edatachain=args.edatachain,
+             edatachain_file=args.edatachain_file,
+         ),
+         "edit": lambda: edit_dataset(
+             catalog,
+             args.name,
+             new_name=args.new_name,
+             description=args.description,
+             labels=args.labels,
+             studio=args.studio,
+             local=args.local,
+             all=args.all,
+             team=args.team,
+         ),
+         "ls": lambda: list_datasets(
+             catalog=catalog,
+             studio=args.studio,
+             local=args.local,
+             all=args.all,
+             team=args.team,
+         ),
+         "rm": lambda: rm_dataset(
+             catalog,
+             args.name,
+             version=args.version,
+             force=args.force,
+             studio=args.studio,
+             local=args.local,
+             all=args.all,
+             team=args.team,
+         ),
+         "remove": lambda: rm_dataset(
+             catalog,
+             args.name,
+             version=args.version,
+             force=args.force,
+             studio=args.studio,
+             local=args.local,
+             all=args.all,
+             team=args.team,
+         ),
+         "stats": lambda: dataset_stats(
+             catalog,
+             args.name,
+             args.version,
+             show_bytes=args.bytes,
+             si=args.si,
+         ),
+     }
+
+     handler = dataset_commands.get(args.datasets_cmd)
+     if handler:
+         return handler()
+     raise Exception(f"Unexpected command {args.datasets_cmd}")
+
+
+ def handle_ls_command(args, client_config):
+     ls(
+         args.sources,
+         long=bool(args.long),
+         studio=args.studio,
+         local=args.local,
+         all=args.all,
+         team=args.team,
+         update=bool(args.update),
+         client_config=client_config,
+     )
+
+
+ def handle_show_command(args, catalog):
+     show(
+         catalog,
+         args.name,
+         args.version,
+         limit=args.limit,
+         offset=args.offset,
+         columns=args.columns,
+         no_collapse=args.no_collapse,
+         schema=args.schema,
+     )
+
+
+ def handle_du_command(args, catalog, client_config):
+     du(
+         catalog,
+         args.sources,
+         show_bytes=args.bytes,
+         depth=args.depth,
+         si=args.si,
+         update=bool(args.update),
+         client_config=client_config,
+     )
+
+
+ def handle_find_command(args, catalog):
+     results_found = False
+     for result in catalog.find(
+         args.sources,
+         update=bool(args.update),
+         names=args.name,
+         inames=args.iname,
+         paths=args.path,
+         ipaths=args.ipath,
+         size=args.size,
+         typ=args.type,
+         columns=args.columns,
+     ):
+         print(result)
+         results_found = True
+     if not results_found:
+         print("No results")
+
+
+ def handle_index_command(args, catalog):
+     index(
+         catalog,
+         args.sources,
+         update=bool(args.update),
+     )
+
+
+ def handle_completion_command(args):
+     print(completion(args.shell))
+
+
+ def handle_query_command(args, catalog):
+     query(
+         catalog,
+         args.script,
+         parallel=args.parallel,
+         params=args.param,
+     )
+
+
+ def handle_broken_pipe_error(exc):
+     # Python flushes standard streams on exit; redirect remaining output
+     # to devnull to avoid another BrokenPipeError at shutdown
+     # See: https://docs.python.org/3/library/signal.html#note-on-sigpipe
+     error = str(exc)
+     devnull = os.open(os.devnull, os.O_WRONLY)
+     os.dup2(devnull, sys.stdout.fileno())
+     return error, 141  # 128 + 13 (SIGPIPE)
+
+
+ def handle_general_exception(exc, args, logging_level):
+     error = str(exc)
+     if isinstance(exc, KeyboardInterrupt):
+         msg = "Operation cancelled by the user"
+     else:
+         msg = str(exc)
+     print("Error:", msg, file=sys.stderr)
+     if logging_level <= logging.DEBUG:
+         traceback.print_exception(
+             type(exc),
+             exc,
+             exc.__traceback__,
+             file=sys.stderr,
+         )
+     if args.pdb:
+         import pdb  # noqa: T100
+
+         pdb.post_mortem()
+     return error, 1
+
+
+ def handle_udf(command):
+     if command == "internal-run-udf":
+         from datachain.query.dispatch import udf_entrypoint
+
+         return udf_entrypoint()
+
+     if command == "internal-run-udf-worker":
+         from datachain.query.dispatch import udf_worker_entrypoint
+
+         return udf_worker_entrypoint()
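
This first new file (its relative imports suggest datachain/cli/__init__.py) wires argument parsing, logging, telemetry, and error handling around a dispatch table. A minimal sketch of driving it programmatically, not part of the diff itself; the subcommand and flag are illustrative and assume the parser defines them as the handlers above imply:

    from datachain.cli import main

    # Dispatches through handle_command() just like the console script;
    # returns a process-style exit code (0 on success, 1 on error,
    # 141 after a broken pipe).
    exit_code = main(["ls", "--local"])
    raise SystemExit(exit_code)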
@@ -0,0 +1,29 @@
+ from .datasets import (
+     dataset_stats,
+     edit_dataset,
+     list_datasets,
+     list_datasets_local,
+     rm_dataset,
+ )
+ from .du import du
+ from .index import index
+ from .ls import ls
+ from .misc import clear_cache, completion, garbage_collect
+ from .query import query
+ from .show import show
+
+ __all__ = [
+     "clear_cache",
+     "completion",
+     "dataset_stats",
+     "du",
+     "edit_dataset",
+     "garbage_collect",
+     "index",
+     "list_datasets",
+     "list_datasets_local",
+     "ls",
+     "query",
+     "rm_dataset",
+     "show",
+ ]
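
This package __init__ makes every command implementation importable from one place. A sketch of using the re-exported helpers directly, assuming the inferred datachain.cli.commands package path:

    from datachain.catalog import get_catalog
    from datachain.cli.commands import clear_cache, garbage_collect

    catalog = get_catalog()
    clear_cache(catalog)      # empties the catalog's local cache
    garbage_collect(catalog)  # drops leftover temporary tables, if any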
@@ -0,0 +1,129 @@
+ import sys
+ from typing import TYPE_CHECKING, Optional
+
+ from tabulate import tabulate
+
+ from datachain import utils
+
+ if TYPE_CHECKING:
+     from datachain.catalog import Catalog
+
+ from datachain.cli.utils import determine_flavors
+ from datachain.config import Config
+ from datachain.error import DatasetNotFoundError
+
+
+ def list_datasets(
+     catalog: "Catalog",
+     studio: bool = False,
+     local: bool = False,
+     all: bool = True,
+     team: Optional[str] = None,
+ ):
+     from datachain.studio import list_datasets
+
+     token = Config().read().get("studio", {}).get("token")
+     all, local, studio = determine_flavors(studio, local, all, token)
+
+     local_datasets = set(list_datasets_local(catalog)) if all or local else set()
+     studio_datasets = (
+         set(list_datasets(team=team)) if (all or studio) and token else set()
+     )
+
+     rows = [
+         _datasets_tabulate_row(
+             name=name,
+             version=version,
+             both=(all or (local and studio)) and token,
+             local=(name, version) in local_datasets,
+             studio=(name, version) in studio_datasets,
+         )
+         for name, version in local_datasets.union(studio_datasets)
+     ]
+
+     print(tabulate(rows, headers="keys"))
+
+
+ def list_datasets_local(catalog: "Catalog"):
+     for d in catalog.ls_datasets():
+         for v in d.versions:
+             yield (d.name, v.version)
+
+
+ def _datasets_tabulate_row(name, version, both, local, studio):
+     row = {
+         "Name": name,
+         "Version": version,
+     }
+     if both:
+         row["Studio"] = "\u2714" if studio else "\u2716"
+         row["Local"] = "\u2714" if local else "\u2716"
+     return row
+
+
+ def rm_dataset(
+     catalog: "Catalog",
+     name: str,
+     version: Optional[int] = None,
+     force: Optional[bool] = False,
+     studio: bool = False,
+     local: bool = False,
+     all: bool = True,
+     team: Optional[str] = None,
+ ):
+     from datachain.studio import remove_studio_dataset
+
+     token = Config().read().get("studio", {}).get("token")
+     all, local, studio = determine_flavors(studio, local, all, token)
+
+     if all or local:
+         try:
+             catalog.remove_dataset(name, version=version, force=force)
+         except DatasetNotFoundError:
+             print("Dataset not found in local", file=sys.stderr)
+
+     if (all or studio) and token:
+         remove_studio_dataset(team, name, version, force)
+
+
+ def edit_dataset(
+     catalog: "Catalog",
+     name: str,
+     new_name: Optional[str] = None,
+     description: Optional[str] = None,
+     labels: Optional[list[str]] = None,
+     studio: bool = False,
+     local: bool = False,
+     all: bool = True,
+     team: Optional[str] = None,
+ ):
+     from datachain.studio import edit_studio_dataset
+
+     token = Config().read().get("studio", {}).get("token")
+     all, local, studio = determine_flavors(studio, local, all, token)
+
+     if all or local:
+         try:
+             catalog.edit_dataset(name, new_name, description, labels)
+         except DatasetNotFoundError:
+             print("Dataset not found in local", file=sys.stderr)
+
+     if (all or studio) and token:
+         edit_studio_dataset(team, name, new_name, description, labels)
+
+
+ def dataset_stats(
+     catalog: "Catalog",
+     name: str,
+     version: int,
+     show_bytes=False,
+     si=False,
+ ):
+     stats = catalog.dataset_stats(name, version)
+
+     if stats:
+         print(f"Number of objects: {stats.num_objects}")
+         if show_bytes:
+             print(f"Total objects size: {stats.size}")
+         else:
+             print(f"Total objects size: {utils.sizeof_fmt(stats.size, si=si): >7}")
@@ -0,0 +1,14 @@
+ from typing import TYPE_CHECKING
+
+ from datachain import utils
+
+ if TYPE_CHECKING:
+     from datachain.catalog import Catalog
+
+
+ def du(catalog: "Catalog", sources, show_bytes=False, si=False, **kwargs):
+     for path, size in catalog.du(sources, **kwargs):
+         if show_bytes:
+             print(f"{size} {path}")
+         else:
+             print(f"{utils.sizeof_fmt(size, si=si): >7} {path}")
@@ -0,0 +1,12 @@
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from datachain.catalog import Catalog
+
+
+ def index(
+     catalog: "Catalog",
+     sources,
+     **kwargs,
+ ):
+     catalog.index(sources, **kwargs)
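
index is a thin pass-through to Catalog.index. For completeness, a sketch with a placeholder URI:

    from datachain.catalog import get_catalog
    from datachain.cli.commands import index

    catalog = get_catalog()
    # update=True is forwarded via **kwargs, presumably refreshing
    # any previously cached listing of these sources.
    index(catalog, ["s3://my-bucket/images/"], update=True)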
@@ -0,0 +1,169 @@
+ import shlex
+ from collections.abc import Iterable, Iterator
+ from itertools import chain
+ from typing import TYPE_CHECKING, Optional
+
+ if TYPE_CHECKING:
+     from datachain.catalog import Catalog
+
+ from datachain.cli.utils import determine_flavors
+ from datachain.config import Config
+
+
+ def ls(
+     sources,
+     long: bool = False,
+     studio: bool = False,
+     local: bool = False,
+     all: bool = True,
+     team: Optional[str] = None,
+     **kwargs,
+ ):
+     token = Config().read().get("studio", {}).get("token")
+     all, local, studio = determine_flavors(studio, local, all, token)
+
+     if all or local:
+         ls_local(sources, long=long, **kwargs)
+
+     if (all or studio) and token:
+         ls_remote(sources, long=long, team=team)
+
+
+ def ls_local(
+     sources,
+     long: bool = False,
+     catalog: Optional["Catalog"] = None,
+     client_config=None,
+     **kwargs,
+ ):
+     from datachain import DataChain
+
+     if catalog is None:
+         from datachain.catalog import get_catalog
+
+         catalog = get_catalog(client_config=client_config)
+     if sources:
+         actual_sources = list(ls_urls(sources, catalog=catalog, long=long, **kwargs))
+         if len(actual_sources) == 1:
+             for _, entries in actual_sources:
+                 for entry in entries:
+                     print(format_ls_entry(entry))
+         else:
+             first = True
+             for source, entries in actual_sources:
+                 # print a newline between directory listings
+                 if first:
+                     first = False
+                 else:
+                     print()
+                 if source:
+                     print(f"{source}:")
+                 for entry in entries:
+                     print(format_ls_entry(entry))
+     else:
+         chain = DataChain.listings()
+         for ls in chain.collect("listing"):
+             print(format_ls_entry(f"{ls.uri}@v{ls.version}"))  # type: ignore[union-attr]
+
+
+ def format_ls_entry(entry: str) -> str:
+     if entry.endswith("/") or not entry:
+         entry = shlex.quote(entry[:-1])
+         return f"{entry}/"
+     return shlex.quote(entry)
+
+
+ def ls_remote(
+     paths: Iterable[str],
+     long: bool = False,
+     team: Optional[str] = None,
+ ):
+     from datachain.node import long_line_str
+     from datachain.remote.studio import StudioClient
+
+     client = StudioClient(team=team)
+     first = True
+     for path, response in client.ls(paths):
+         if not first:
+             print()
+         if not response.ok or response.data is None:
+             print(f"{path}:\n Error: {response.message}\n")
+             continue
+
+         print(f"{path}:")
+         if long:
+             for row in response.data:
+                 entry = long_line_str(
+                     row["name"] + ("/" if row["dir_type"] else ""),
+                     row["last_modified"],
+                 )
+                 print(format_ls_entry(entry))
+         else:
+             for row in response.data:
+                 entry = row["name"] + ("/" if row["dir_type"] else "")
+                 print(format_ls_entry(entry))
+         first = False
+
+
+ def ls_urls(
+     sources,
+     catalog: "Catalog",
+     long: bool = False,
+     **kwargs,
+ ) -> Iterator[tuple[str, Iterator[str]]]:
+     curr_dir = None
+     value_iterables = []
+     for next_dir, values in _ls_urls_flat(sources, long, catalog, **kwargs):
+         if curr_dir is None or next_dir == curr_dir:  # type: ignore[unreachable]
+             value_iterables.append(values)
+         else:
+             yield curr_dir, chain(*value_iterables)  # type: ignore[unreachable]
+             value_iterables = [values]
+         curr_dir = next_dir
+     if curr_dir is not None:
+         yield curr_dir, chain(*value_iterables)
+
+
+ def _node_data_to_ls_values(row, long_format=False):
+     from datachain.node import DirType, long_line_str
+
+     name = row[0]
+     is_dir = row[1] == DirType.DIR
+     ending = "/" if is_dir else ""
+     value = name + ending
+     if long_format:
+         last_modified = row[2]
+         timestamp = last_modified if not is_dir else None
+         return long_line_str(value, timestamp)
+     return value
+
+
+ def _ls_urls_flat(
+     sources,
+     long: bool,
+     catalog: "Catalog",
+     **kwargs,
+ ) -> Iterator[tuple[str, Iterator[str]]]:
+     from datachain.client import Client
+     from datachain.node import long_line_str
+
+     for source in sources:
+         client_cls = Client.get_implementation(source)
+         if client_cls.is_root_url(source):
+             buckets = client_cls.ls_buckets(**catalog.client_config)
+             if long:
+                 values = (long_line_str(b.name, b.created) for b in buckets)
+             else:
+                 values = (b.name for b in buckets)
+             yield source, values
+         else:
+             found = False
+             fields = ["name", "dir_type"]
+             if long:
+                 fields.append("last_modified")
+             for data_source, results in catalog.ls([source], fields=fields, **kwargs):
+                 values = (_node_data_to_ls_values(r, long) for r in results)
+                 found = True
+                 yield data_source.dirname(), values
+             if not found:
+                 raise FileNotFoundError(f"No such file or directory: {source}")
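
format_ls_entry shell-quotes each name while keeping a directory's trailing slash outside the quotes, so listings stay copy-pasteable. A quick check of that behavior (the module path is inferred from the diff):

    from datachain.cli.commands.ls import format_ls_entry

    print(format_ls_entry("plain.txt"))  # plain.txt
    print(format_ls_entry("my dir/"))    # 'my dir'/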
@@ -0,0 +1,28 @@
+ from typing import TYPE_CHECKING
+
+ import shtab
+
+ if TYPE_CHECKING:
+     from datachain.catalog import Catalog
+
+
+ def clear_cache(catalog: "Catalog"):
+     catalog.cache.clear()
+
+
+ def garbage_collect(catalog: "Catalog"):
+     temp_tables = catalog.get_temp_table_names()
+     if not temp_tables:
+         print("Nothing to clean up.")
+     else:
+         print(f"Garbage collecting {len(temp_tables)} tables.")
+         catalog.cleanup_tables(temp_tables)
+
+
+ def completion(shell: str) -> str:
+     from datachain.cli import get_parser
+
+     return shtab.complete(
+         get_parser(),
+         shell=shell,
+     )
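
completion builds a completion script from the same argparse parser the CLI uses, via shtab. A sketch:

    from datachain.cli.commands import completion

    # shtab renders the parser as a completion script for the requested
    # shell (e.g. "bash" or "zsh").
    print(completion("bash"))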