datachain 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of datachain might be problematic.

Files changed (50)
  1. datachain/asyn.py +16 -6
  2. datachain/cache.py +32 -10
  3. datachain/catalog/catalog.py +17 -1
  4. datachain/cli/__init__.py +311 -0
  5. datachain/cli/commands/__init__.py +29 -0
  6. datachain/cli/commands/datasets.py +129 -0
  7. datachain/cli/commands/du.py +14 -0
  8. datachain/cli/commands/index.py +12 -0
  9. datachain/cli/commands/ls.py +169 -0
  10. datachain/cli/commands/misc.py +28 -0
  11. datachain/cli/commands/query.py +53 -0
  12. datachain/cli/commands/show.py +38 -0
  13. datachain/cli/parser/__init__.py +547 -0
  14. datachain/cli/parser/job.py +120 -0
  15. datachain/cli/parser/studio.py +126 -0
  16. datachain/cli/parser/utils.py +63 -0
  17. datachain/{cli_utils.py → cli/utils.py} +27 -1
  18. datachain/client/azure.py +6 -2
  19. datachain/client/fsspec.py +9 -3
  20. datachain/client/gcs.py +6 -2
  21. datachain/client/s3.py +16 -1
  22. datachain/data_storage/db_engine.py +9 -0
  23. datachain/data_storage/schema.py +4 -10
  24. datachain/data_storage/sqlite.py +7 -1
  25. datachain/data_storage/warehouse.py +6 -4
  26. datachain/{lib/diff.py → diff/__init__.py} +116 -12
  27. datachain/func/__init__.py +3 -2
  28. datachain/func/conditional.py +74 -0
  29. datachain/func/func.py +5 -1
  30. datachain/lib/arrow.py +7 -1
  31. datachain/lib/dc.py +8 -3
  32. datachain/lib/file.py +16 -5
  33. datachain/lib/hf.py +1 -1
  34. datachain/lib/listing.py +19 -1
  35. datachain/lib/pytorch.py +57 -13
  36. datachain/lib/signal_schema.py +89 -27
  37. datachain/lib/udf.py +82 -40
  38. datachain/listing.py +1 -0
  39. datachain/progress.py +20 -3
  40. datachain/query/dataset.py +122 -93
  41. datachain/query/dispatch.py +22 -16
  42. datachain/studio.py +58 -38
  43. datachain/utils.py +14 -3
  44. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/METADATA +9 -9
  45. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/RECORD +49 -37
  46. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/WHEEL +1 -1
  47. datachain/cli.py +0 -1475
  48. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/LICENSE +0 -0
  49. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/entry_points.txt +0 -0
  50. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/top_level.txt +0 -0
datachain/asyn.py CHANGED
@@ -8,12 +8,14 @@ from collections.abc import (
     Iterable,
     Iterator,
 )
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, wait
 from heapq import heappop, heappush
 from typing import Any, Callable, Generic, Optional, TypeVar
 
 from fsspec.asyn import get_loop
 
+from datachain.utils import safe_closing
+
 ASYNC_WORKERS = 20
 
 InputT = TypeVar("InputT", contravariant=True)  # noqa: PLC0105
@@ -56,6 +58,7 @@ class AsyncMapper(Generic[InputT, ResultT]):
         self.pool = ThreadPoolExecutor(workers)
         self._tasks: set[asyncio.Task] = set()
         self._shutdown_producer = threading.Event()
+        self._producer_is_shutdown = threading.Event()
 
     def start_task(self, coro: Coroutine) -> asyncio.Task:
         task = self.loop.create_task(coro)
@@ -64,11 +67,16 @@ class AsyncMapper(Generic[InputT, ResultT]):
         return task
 
     def _produce(self) -> None:
-        for item in self.iterable:
-            if self._shutdown_producer.is_set():
-                return
-            fut = asyncio.run_coroutine_threadsafe(self.work_queue.put(item), self.loop)
-            fut.result()  # wait until the item is in the queue
+        try:
+            with safe_closing(self.iterable):
+                for item in self.iterable:
+                    if self._shutdown_producer.is_set():
+                        return
+                    coro = self.work_queue.put(item)
+                    fut = asyncio.run_coroutine_threadsafe(coro, self.loop)
+                    fut.result()  # wait until the item is in the queue
+        finally:
+            self._producer_is_shutdown.set()
 
     async def produce(self) -> None:
         await self.to_thread(self._produce)
@@ -179,6 +187,8 @@ class AsyncMapper(Generic[InputT, ResultT]):
             self.shutdown_producer()
             if not async_run.done():
                 async_run.cancel()
+                wait([async_run])
+            self._producer_is_shutdown.wait()
 
     def __iter__(self):
         return self.iterate()
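The change above makes shutdown of the producer thread explicit: the iterable is closed via safe_closing, and a second event (_producer_is_shutdown) is always set in a finally block so the consumer can wait for the producer to be fully gone before returning. The sketch below shows the same handshake with plain threading and queue; it is illustrative only (make_producer, the queue size, and the timeout are made up, not datachain code).

import queue
import threading


def make_producer(items, maxsize=2):
    work_queue = queue.Queue(maxsize=maxsize)
    shutdown_requested = threading.Event()   # consumer -> producer: please stop
    producer_is_shutdown = threading.Event() # producer -> consumer: I have stopped

    def _produce():
        try:
            for item in items:
                if shutdown_requested.is_set():
                    return
                work_queue.put(item)  # blocks while the queue is full
        finally:
            # Always signal, even on error, so shutdown() can never hang.
            producer_is_shutdown.set()

    threading.Thread(target=_produce, daemon=True).start()

    def shutdown():
        shutdown_requested.set()
        # Drain until the producer confirms it has exited, so a producer
        # blocked on a full queue is always released.
        while not producer_is_shutdown.wait(timeout=0.1):
            try:
                work_queue.get_nowait()
            except queue.Empty:
                pass

    return work_queue, shutdown


if __name__ == "__main__":
    q, stop = make_producer(range(1000))
    print(q.get(), q.get())  # consume a couple of items
    stop()  # returns only after the producer thread has finished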
datachain/cache.py CHANGED
@@ -1,8 +1,12 @@
 import os
+from collections.abc import Iterator
+from contextlib import contextmanager
+from tempfile import mkdtemp
 from typing import TYPE_CHECKING, Optional
 
 from dvc_data.hashfile.db.local import LocalHashFileDB
 from dvc_objects.fs.local import LocalFileSystem
+from dvc_objects.fs.utils import remove
 from fsspec.callbacks import Callback, TqdmCallback
 
 from .progress import Tqdm
@@ -20,6 +24,23 @@ def try_scandir(path):
         pass
 
 
+def get_temp_cache(tmp_dir: str, prefix: Optional[str] = None) -> "DataChainCache":
+    cache_dir = mkdtemp(prefix=prefix, dir=tmp_dir)
+    return DataChainCache(cache_dir, tmp_dir=tmp_dir)
+
+
+@contextmanager
+def temporary_cache(
+    tmp_dir: str, prefix: Optional[str] = None, delete: bool = True
+) -> Iterator["DataChainCache"]:
+    cache = get_temp_cache(tmp_dir, prefix=prefix)
+    try:
+        yield cache
+    finally:
+        if delete:
+            cache.destroy()
+
+
 class DataChainCache:
     def __init__(self, cache_dir: str, tmp_dir: str):
         self.odb = LocalHashFileDB(
@@ -28,6 +49,9 @@ class DataChainCache:
             tmp_dir=tmp_dir,
         )
 
+    def __eq__(self, other) -> bool:
+        return self.odb == other.odb
+
     @property
     def cache_dir(self):
         return self.odb.path
@@ -63,7 +87,7 @@ class DataChainCache:
         if size < 0:
             size = await client.get_size(from_path, version_id=file.version)
         cb = callback or TqdmCallback(
-            tqdm_kwargs={"desc": odb_fs.name(from_path), "bytes": True},
+            tqdm_kwargs={"desc": odb_fs.name(from_path), "bytes": True, "leave": False},
             tqdm_cls=Tqdm,
             size=size,
         )
@@ -82,20 +106,18 @@ class DataChainCache:
         os.unlink(tmp_info)
 
     def store_data(self, file: "File", contents: bytes) -> None:
-        checksum = file.get_hash()
-        dst = self.path_from_checksum(checksum)
-        if not os.path.exists(dst):
-            # Create the file only if it's not already in cache
-            os.makedirs(os.path.dirname(dst), exist_ok=True)
-            with open(dst, mode="wb") as f:
-                f.write(contents)
-
-    def clear(self):
+        self.odb.add_bytes(file.get_hash(), contents)
+
+    def clear(self) -> None:
         """
         Completely clear the cache.
         """
         self.odb.clear()
 
+    def destroy(self) -> None:
+        # `clear` leaves the prefix directory structure intact.
+        remove(self.cache_dir)
+
     def get_total_size(self) -> int:
         total = 0
         for subdir in try_scandir(self.odb.path):
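The hunks above add get_temp_cache(), a temporary_cache() context manager, DataChainCache.destroy(), and an odb-backed store_data(). A minimal usage sketch, based only on the signatures shown above; the scratch directory and prefix are made up:

import tempfile

from datachain.cache import temporary_cache

tmp_dir = tempfile.mkdtemp()  # any scratch directory
with temporary_cache(tmp_dir, prefix="query-") as cache:
    print(cache.cache_dir)  # a fresh mkdtemp() subdirectory under tmp_dir
    # ... download or store objects into `cache` ...
# With delete=True (the default), cache.destroy() runs here and removes
# cache_dir entirely, whereas cache.clear() would keep the prefix directories.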
datachain/catalog/catalog.py CHANGED
@@ -405,6 +405,7 @@ def get_download_bar(bar_format: str, total_size: int):
         unit_scale=True,
         unit_divisor=1000,
         total=total_size,
+        leave=False,
     )
 
 
@@ -429,6 +430,7 @@ def instantiate_node_groups(
             unit_scale=True,
             unit_divisor=1000,
             total=total_files,
+            leave=False,
         )
     )
 
@@ -534,6 +536,12 @@ def find_column_to_str(  # noqa: PLR0911
     return ""
 
 
+def clone_catalog_with_cache(catalog: "Catalog", cache: "DataChainCache") -> "Catalog":
+    clone = catalog.copy()
+    clone.cache = cache
+    return clone
+
+
 class Catalog:
     def __init__(
         self,
@@ -1242,10 +1250,17 @@ class Catalog:
         path: str,
         version_id: Optional[str] = None,
         client_config=None,
+        content_disposition: Optional[str] = None,
+        **kwargs,
     ) -> str:
         client_config = client_config or self.client_config
         client = Client.get_client(source, self.cache, **client_config)
-        return client.url(path, version_id=version_id)
+        return client.url(
+            path,
+            version_id=version_id,
+            content_disposition=content_disposition,
+            **kwargs,
+        )
 
     def export_dataset_table(
         self,
@@ -1437,6 +1452,7 @@ class Catalog:
             unit_scale=True,
             unit_divisor=1000,
             total=ds_stats.num_objects,  # type: ignore [union-attr]
+            leave=False,
         )
 
         schema = DatasetRecord.parse_schema(remote_ds_version.schema)
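catalog.py also gains clone_catalog_with_cache(), which copies a Catalog and swaps in a different cache. A hypothetical pairing with the temporary_cache() helper from cache.py, assuming the function is importable from datachain.catalog.catalog as the hunk suggests, that `catalog` is an existing Catalog instance, and that the temp directory is made up:

from datachain.cache import temporary_cache
from datachain.catalog.catalog import clone_catalog_with_cache

with temporary_cache("/tmp/datachain", prefix="udf-") as tmp_cache:
    tmp_catalog = clone_catalog_with_cache(catalog, tmp_cache)
    # tmp_catalog behaves like `catalog`, but anything it caches lands in the
    # throwaway directory, which is destroyed when the block exits.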
datachain/cli/__init__.py ADDED
@@ -0,0 +1,311 @@
+import logging
+import os
+import sys
+import traceback
+from multiprocessing import freeze_support
+from typing import Optional
+
+from datachain.cli.utils import get_logging_level
+from datachain.telemetry import telemetry
+
+from .commands import (
+    clear_cache,
+    completion,
+    dataset_stats,
+    du,
+    edit_dataset,
+    garbage_collect,
+    index,
+    list_datasets,
+    ls,
+    query,
+    rm_dataset,
+    show,
+)
+from .parser import get_parser
+
+logger = logging.getLogger("datachain")
+
+
+def main(argv: Optional[list[str]] = None) -> int:
+    from datachain.catalog import get_catalog
+
+    # Required for Windows multiprocessing support
+    freeze_support()
+
+    datachain_parser = get_parser()
+    args = datachain_parser.parse_args(argv)
+
+    if args.command in ("internal-run-udf", "internal-run-udf-worker"):
+        return handle_udf(args.command)
+
+    logger.addHandler(logging.StreamHandler())
+    logging_level = get_logging_level(args)
+    logger.setLevel(logging_level)
+
+    client_config = {
+        "aws_endpoint_url": args.aws_endpoint_url,
+        "anon": args.anon,
+    }
+
+    if args.debug_sql:
+        # This also sets this environment variable for any subprocesses
+        os.environ["DEBUG_SHOW_SQL_QUERIES"] = "True"
+
+    error = None
+
+    try:
+        catalog = get_catalog(client_config=client_config)
+        return handle_command(args, catalog, client_config)
+    except BrokenPipeError as exc:
+        error, return_code = handle_broken_pipe_error(exc)
+        return return_code
+    except (KeyboardInterrupt, Exception) as exc:
+        error, return_code = handle_general_exception(exc, args, logging_level)
+        return return_code
+    finally:
+        telemetry.send_cli_call(args.command, error=error)
+
+
+def handle_command(args, catalog, client_config) -> int:
+    """Handle the different CLI commands."""
+    from datachain.studio import process_jobs_args, process_studio_cli_args
+
+    command_handlers = {
+        "cp": lambda: handle_cp_command(args, catalog),
+        "clone": lambda: handle_clone_command(args, catalog),
+        "dataset": lambda: handle_dataset_command(args, catalog),
+        "ds": lambda: handle_dataset_command(args, catalog),
+        "ls": lambda: handle_ls_command(args, client_config),
+        "show": lambda: handle_show_command(args, catalog),
+        "du": lambda: handle_du_command(args, catalog, client_config),
+        "find": lambda: handle_find_command(args, catalog),
+        "index": lambda: handle_index_command(args, catalog),
+        "completion": lambda: handle_completion_command(args),
+        "query": lambda: handle_query_command(args, catalog),
+        "clear-cache": lambda: clear_cache(catalog),
+        "gc": lambda: garbage_collect(catalog),
+        "studio": lambda: process_studio_cli_args(args),
+        "job": lambda: process_jobs_args(args),
+    }
+
+    handler = command_handlers.get(args.command)
+    if handler:
+        handler()
+        return 0
+    print(f"invalid command: {args.command}", file=sys.stderr)
+    return 1
+
+
+def handle_cp_command(args, catalog):
+    catalog.cp(
+        args.sources,
+        args.output,
+        force=bool(args.force),
+        update=bool(args.update),
+        recursive=bool(args.recursive),
+        edatachain_file=None,
+        edatachain_only=False,
+        no_edatachain_file=True,
+        no_glob=args.no_glob,
+    )
+
+
+def handle_clone_command(args, catalog):
+    catalog.clone(
+        args.sources,
+        args.output,
+        force=bool(args.force),
+        update=bool(args.update),
+        recursive=bool(args.recursive),
+        no_glob=args.no_glob,
+        no_cp=args.no_cp,
+        edatachain=args.edatachain,
+        edatachain_file=args.edatachain_file,
+    )
+
+
+def handle_dataset_command(args, catalog):
+    dataset_commands = {
+        "pull": lambda: catalog.pull_dataset(
+            args.dataset,
+            args.output,
+            local_ds_name=args.local_name,
+            local_ds_version=args.local_version,
+            cp=args.cp,
+            force=bool(args.force),
+            edatachain=args.edatachain,
+            edatachain_file=args.edatachain_file,
+        ),
+        "edit": lambda: edit_dataset(
+            catalog,
+            args.name,
+            new_name=args.new_name,
+            description=args.description,
+            labels=args.labels,
+            studio=args.studio,
+            local=args.local,
+            all=args.all,
+            team=args.team,
+        ),
+        "ls": lambda: list_datasets(
+            catalog=catalog,
+            studio=args.studio,
+            local=args.local,
+            all=args.all,
+            team=args.team,
+        ),
+        "rm": lambda: rm_dataset(
+            catalog,
+            args.name,
+            version=args.version,
+            force=args.force,
+            studio=args.studio,
+            local=args.local,
+            all=args.all,
+            team=args.team,
+        ),
+        "remove": lambda: rm_dataset(
+            catalog,
+            args.name,
+            version=args.version,
+            force=args.force,
+            studio=args.studio,
+            local=args.local,
+            all=args.all,
+            team=args.team,
+        ),
+        "stats": lambda: dataset_stats(
+            catalog,
+            args.name,
+            args.version,
+            show_bytes=args.bytes,
+            si=args.si,
+        ),
+    }
+
+    handler = dataset_commands.get(args.datasets_cmd)
+    if handler:
+        return handler()
+    raise Exception(f"Unexpected command {args.datasets_cmd}")
+
+
+def handle_ls_command(args, client_config):
+    ls(
+        args.sources,
+        long=bool(args.long),
+        studio=args.studio,
+        local=args.local,
+        all=args.all,
+        team=args.team,
+        update=bool(args.update),
+        client_config=client_config,
+    )
+
+
+def handle_show_command(args, catalog):
+    show(
+        catalog,
+        args.name,
+        args.version,
+        limit=args.limit,
+        offset=args.offset,
+        columns=args.columns,
+        no_collapse=args.no_collapse,
+        schema=args.schema,
+    )
+
+
+def handle_du_command(args, catalog, client_config):
+    du(
+        catalog,
+        args.sources,
+        show_bytes=args.bytes,
+        depth=args.depth,
+        si=args.si,
+        update=bool(args.update),
+        client_config=client_config,
+    )
+
+
+def handle_find_command(args, catalog):
+    results_found = False
+    for result in catalog.find(
+        args.sources,
+        update=bool(args.update),
+        names=args.name,
+        inames=args.iname,
+        paths=args.path,
+        ipaths=args.ipath,
+        size=args.size,
+        typ=args.type,
+        columns=args.columns,
+    ):
+        print(result)
+        results_found = True
+    if not results_found:
+        print("No results")
+
+
+def handle_index_command(args, catalog):
+    index(
+        catalog,
+        args.sources,
+        update=bool(args.update),
+    )
+
+
+def handle_completion_command(args):
+    print(completion(args.shell))
+
+
+def handle_query_command(args, catalog):
+    query(
+        catalog,
+        args.script,
+        parallel=args.parallel,
+        params=args.param,
+    )
+
+
+def handle_broken_pipe_error(exc):
+    # Python flushes standard streams on exit; redirect remaining output
+    # to devnull to avoid another BrokenPipeError at shutdown
+    # See: https://docs.python.org/3/library/signal.html#note-on-sigpipe
+    error = str(exc)
+    devnull = os.open(os.devnull, os.O_WRONLY)
+    os.dup2(devnull, sys.stdout.fileno())
+    return error, 141  # 128 + 13 (SIGPIPE)
+
+
+def handle_general_exception(exc, args, logging_level):
+    error = str(exc)
+    if isinstance(exc, KeyboardInterrupt):
+        msg = "Operation cancelled by the user"
+    else:
+        msg = str(exc)
+    print("Error:", msg, file=sys.stderr)
+    if logging_level <= logging.DEBUG:
+        traceback.print_exception(
+            type(exc),
+            exc,
+            exc.__traceback__,
+            file=sys.stderr,
+        )
+    if args.pdb:
+        import pdb  # noqa: T100
+
+        pdb.post_mortem()
+    return error, 1
+
+
+def handle_udf(command):
+    if command == "internal-run-udf":
+        from datachain.query.dispatch import udf_entrypoint
+
+        return udf_entrypoint()
+
+    if command == "internal-run-udf-worker":
+        from datachain.query.dispatch import udf_worker_entrypoint
+
+        return udf_worker_entrypoint()
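The monolithic datachain/cli.py is replaced by a datachain.cli package whose main() takes an optional argv list and returns an exit code, so the CLI can also be driven programmatically or from tests. A hypothetical call (the bucket URI is made up):

from datachain.cli import main

# Parses the argv list exactly as the `datachain` console script would.
exit_code = main(["ls", "s3://example-bucket/"])
print("datachain ls exited with", exit_code)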
datachain/cli/commands/__init__.py ADDED
@@ -0,0 +1,29 @@
+from .datasets import (
+    dataset_stats,
+    edit_dataset,
+    list_datasets,
+    list_datasets_local,
+    rm_dataset,
+)
+from .du import du
+from .index import index
+from .ls import ls
+from .misc import clear_cache, completion, garbage_collect
+from .query import query
+from .show import show
+
+__all__ = [
+    "clear_cache",
+    "completion",
+    "dataset_stats",
+    "du",
+    "edit_dataset",
+    "garbage_collect",
+    "index",
+    "list_datasets",
+    "list_datasets_local",
+    "ls",
+    "query",
+    "rm_dataset",
+    "show",
+]
datachain/cli/commands/datasets.py ADDED
@@ -0,0 +1,129 @@
+import sys
+from typing import TYPE_CHECKING, Optional
+
+from tabulate import tabulate
+
+from datachain import utils
+
+if TYPE_CHECKING:
+    from datachain.catalog import Catalog
+
+from datachain.cli.utils import determine_flavors
+from datachain.config import Config
+from datachain.error import DatasetNotFoundError
+
+
+def list_datasets(
+    catalog: "Catalog",
+    studio: bool = False,
+    local: bool = False,
+    all: bool = True,
+    team: Optional[str] = None,
+):
+    from datachain.studio import list_datasets
+
+    token = Config().read().get("studio", {}).get("token")
+    all, local, studio = determine_flavors(studio, local, all, token)
+
+    local_datasets = set(list_datasets_local(catalog)) if all or local else set()
+    studio_datasets = (
+        set(list_datasets(team=team)) if (all or studio) and token else set()
+    )
+
+    rows = [
+        _datasets_tabulate_row(
+            name=name,
+            version=version,
+            both=(all or (local and studio)) and token,
+            local=(name, version) in local_datasets,
+            studio=(name, version) in studio_datasets,
+        )
+        for name, version in local_datasets.union(studio_datasets)
+    ]
+
+    print(tabulate(rows, headers="keys"))
+
+
+def list_datasets_local(catalog: "Catalog"):
+    for d in catalog.ls_datasets():
+        for v in d.versions:
+            yield (d.name, v.version)
+
+
+def _datasets_tabulate_row(name, version, both, local, studio):
+    row = {
+        "Name": name,
+        "Version": version,
+    }
+    if both:
+        row["Studio"] = "\u2714" if studio else "\u2716"
+        row["Local"] = "\u2714" if local else "\u2716"
+    return row
+
+
+def rm_dataset(
+    catalog: "Catalog",
+    name: str,
+    version: Optional[int] = None,
+    force: Optional[bool] = False,
+    studio: bool = False,
+    local: bool = False,
+    all: bool = True,
+    team: Optional[str] = None,
+):
+    from datachain.studio import remove_studio_dataset
+
+    token = Config().read().get("studio", {}).get("token")
+    all, local, studio = determine_flavors(studio, local, all, token)
+
+    if all or local:
+        try:
+            catalog.remove_dataset(name, version=version, force=force)
+        except DatasetNotFoundError:
+            print("Dataset not found in local", file=sys.stderr)
+
+    if (all or studio) and token:
+        remove_studio_dataset(team, name, version, force)
+
+
+def edit_dataset(
+    catalog: "Catalog",
+    name: str,
+    new_name: Optional[str] = None,
+    description: Optional[str] = None,
+    labels: Optional[list[str]] = None,
+    studio: bool = False,
+    local: bool = False,
+    all: bool = True,
+    team: Optional[str] = None,
+):
+    from datachain.studio import edit_studio_dataset
+
+    token = Config().read().get("studio", {}).get("token")
+    all, local, studio = determine_flavors(studio, local, all, token)
+
+    if all or local:
+        try:
+            catalog.edit_dataset(name, new_name, description, labels)
+        except DatasetNotFoundError:
+            print("Dataset not found in local", file=sys.stderr)
+
+    if (all or studio) and token:
+        edit_studio_dataset(team, name, new_name, description, labels)
+
+
+def dataset_stats(
+    catalog: "Catalog",
+    name: str,
+    version: int,
+    show_bytes=False,
+    si=False,
+):
+    stats = catalog.dataset_stats(name, version)
+
+    if stats:
+        print(f"Number of objects: {stats.num_objects}")
+        if show_bytes:
+            print(f"Total objects size: {stats.size}")
+        else:
+            print(f"Total objects size: {utils.sizeof_fmt(stats.size, si=si): >7}")
datachain/cli/commands/du.py ADDED
@@ -0,0 +1,14 @@
+from typing import TYPE_CHECKING
+
+from datachain import utils
+
+if TYPE_CHECKING:
+    from datachain.catalog import Catalog
+
+
+def du(catalog: "Catalog", sources, show_bytes=False, si=False, **kwargs):
+    for path, size in catalog.du(sources, **kwargs):
+        if show_bytes:
+            print(f"{size} {path}")
+        else:
+            print(f"{utils.sizeof_fmt(size, si=si): >7} {path}")
datachain/cli/commands/index.py ADDED
@@ -0,0 +1,12 @@
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from datachain.catalog import Catalog
+
+
+def index(
+    catalog: "Catalog",
+    sources,
+    **kwargs,
+):
+    catalog.index(sources, **kwargs)