datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
@@ -1,28 +1,16 @@
1
1
  import io
2
- import json
3
2
  import logging
4
3
  import os
5
4
  import os.path
6
5
  import posixpath
7
- import signal
8
- import subprocess
9
- import sys
10
6
  import time
11
7
  import traceback
12
- from collections.abc import Iterable, Iterator, Mapping, Sequence
8
+ from collections.abc import Callable, Iterable, Iterator, Sequence
9
+ from contextlib import contextmanager, suppress
13
10
  from copy import copy
14
11
  from dataclasses import dataclass
15
12
  from functools import cached_property, reduce
16
- from threading import Thread
17
- from typing import (
18
- IO,
19
- TYPE_CHECKING,
20
- Any,
21
- Callable,
22
- NoReturn,
23
- Optional,
24
- Union,
25
- )
13
+ from typing import TYPE_CHECKING, Any
26
14
  from uuid import uuid4
27
15
 
28
16
  import sqlalchemy as sa
@@ -33,6 +21,7 @@ from datachain.cache import Cache
33
21
  from datachain.client import Client
34
22
  from datachain.dataset import (
35
23
  DATASET_PREFIX,
24
+ DEFAULT_DATASET_VERSION,
36
25
  QUERY_DATASET_PREFIX,
37
26
  DatasetDependency,
38
27
  DatasetListRecord,
@@ -40,31 +29,33 @@ from datachain.dataset import (
40
29
  DatasetStatus,
41
30
  StorageURI,
42
31
  create_dataset_uri,
32
+ parse_dataset_name,
43
33
  parse_dataset_uri,
34
+ parse_schema,
44
35
  )
45
36
  from datachain.error import (
46
37
  DataChainError,
47
38
  DatasetInvalidVersionError,
48
39
  DatasetNotFoundError,
49
40
  DatasetVersionNotFoundError,
50
- QueryScriptCancelError,
51
- QueryScriptRunError,
41
+ NamespaceNotFoundError,
42
+ ProjectNotFoundError,
52
43
  )
53
44
  from datachain.lib.listing import get_listing
54
45
  from datachain.node import DirType, Node, NodeWithPath
55
46
  from datachain.nodes_thread_pool import NodesThreadPool
47
+ from datachain.project import Project
56
48
  from datachain.sql.types import DateTime, SQLType
57
49
  from datachain.utils import DataChainDir
58
50
 
59
51
  from .datasource import DataSource
52
+ from .dependency import build_dependency_hierarchy, populate_nested_dependencies
60
53
 
61
54
  if TYPE_CHECKING:
62
- from datachain.data_storage import (
63
- AbstractMetastore,
64
- AbstractWarehouse,
65
- )
55
+ from datachain.data_storage import AbstractMetastore, AbstractWarehouse
66
56
  from datachain.dataset import DatasetListVersion
67
57
  from datachain.job import Job
58
+ from datachain.lib.listing_info import ListingInfo
68
59
  from datachain.listing import Listing
69
60
 
70
61
  logger = logging.getLogger("datachain")
@@ -75,10 +66,9 @@ TTL_INT = 4 * 60 * 60
75
66
 
76
67
  INDEX_INTERNAL_ERROR_MESSAGE = "Internal error on indexing"
77
68
  DATASET_INTERNAL_ERROR_MESSAGE = "Internal error on creating dataset"
78
- # exit code we use if last statement in query script is not instance of DatasetQuery
79
- QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE = 10
80
69
  # exit code we use if query script was canceled
81
70
  QUERY_SCRIPT_CANCELED_EXIT_CODE = 11
71
+ QUERY_SCRIPT_SIGTERM_EXIT_CODE = -15 # if query script was terminated by SIGTERM
82
72
 
83
73
  # dataset pull
84
74
  PULL_DATASET_MAX_THREADS = 5
@@ -87,64 +77,9 @@ PULL_DATASET_SLEEP_INTERVAL = 0.1 # sleep time while waiting for chunk to be av
87
77
  PULL_DATASET_CHECK_STATUS_INTERVAL = 20 # interval to check export status in Studio
88
78
 
89
79
 
90
- def noop(_: str):
91
- pass
92
-
93
-
94
- class TerminationSignal(RuntimeError): # noqa: N818
95
- def __init__(self, signal):
96
- self.signal = signal
97
- super().__init__("Received termination signal", signal)
98
-
99
- def __repr__(self):
100
- return f"{self.__class__.__name__}({self.signal})"
101
-
102
-
103
- if sys.platform == "win32":
104
- SIGINT = signal.CTRL_C_EVENT
105
- else:
106
- SIGINT = signal.SIGINT
107
-
108
-
109
- def shutdown_process(
110
- proc: subprocess.Popen,
111
- interrupt_timeout: Optional[int] = None,
112
- terminate_timeout: Optional[int] = None,
113
- ) -> int:
114
- """Shut down the process gracefully with SIGINT -> SIGTERM -> SIGKILL."""
115
-
116
- logger.info("sending interrupt signal to the process %s", proc.pid)
117
- proc.send_signal(SIGINT)
118
-
119
- logger.info("waiting for the process %s to finish", proc.pid)
120
- try:
121
- return proc.wait(interrupt_timeout)
122
- except subprocess.TimeoutExpired:
123
- logger.info(
124
- "timed out waiting, sending terminate signal to the process %s", proc.pid
125
- )
126
- proc.terminate()
127
- try:
128
- return proc.wait(terminate_timeout)
129
- except subprocess.TimeoutExpired:
130
- logger.info("timed out waiting, killing the process %s", proc.pid)
131
- proc.kill()
132
- return proc.wait()
133
-
134
-
135
- def _process_stream(stream: "IO[bytes]", callback: Callable[[str], None]) -> None:
136
- buffer = b""
137
- while byt := stream.read(1): # Read one byte at a time
138
- buffer += byt
139
-
140
- if byt in (b"\n", b"\r"): # Check for newline or carriage return
141
- line = buffer.decode("utf-8")
142
- callback(line)
143
- buffer = b"" # Clear buffer for next line
144
-
145
- if buffer: # Handle any remaining data in the buffer
146
- line = buffer.decode("utf-8")
147
- callback(line)
80
+ def is_namespace_local(namespace_name) -> bool:
81
+ """Checks if namespace is from local environment, i.e. is `local`"""
82
+ return namespace_name == "local"
148
83
 
149
84
 
150
85
  class DatasetRowsFetcher(NodesThreadPool):
@@ -152,11 +87,11 @@ class DatasetRowsFetcher(NodesThreadPool):
152
87
  self,
153
88
  metastore: "AbstractMetastore",
154
89
  warehouse: "AbstractWarehouse",
155
- remote_ds_name: str,
156
- remote_ds_version: int,
157
- local_ds_name: str,
158
- local_ds_version: int,
159
- schema: dict[str, Union[SQLType, type[SQLType]]],
90
+ remote_ds: DatasetRecord,
91
+ remote_ds_version: str,
92
+ local_ds: DatasetRecord,
93
+ local_ds_version: str,
94
+ schema: dict[str, SQLType | type[SQLType]],
160
95
  max_threads: int = PULL_DATASET_MAX_THREADS,
161
96
  progress_bar=None,
162
97
  ):
@@ -166,12 +101,12 @@ class DatasetRowsFetcher(NodesThreadPool):
166
101
  self._check_dependencies()
167
102
  self.metastore = metastore
168
103
  self.warehouse = warehouse
169
- self.remote_ds_name = remote_ds_name
104
+ self.remote_ds = remote_ds
170
105
  self.remote_ds_version = remote_ds_version
171
- self.local_ds_name = local_ds_name
106
+ self.local_ds = local_ds
172
107
  self.local_ds_version = local_ds_version
173
108
  self.schema = schema
174
- self.last_status_check: Optional[float] = None
109
+ self.last_status_check: float | None = None
175
110
  self.studio_client = StudioClient()
176
111
  self.progress_bar = progress_bar
177
112
 
@@ -204,7 +139,7 @@ class DatasetRowsFetcher(NodesThreadPool):
204
139
  Checks are done every PULL_DATASET_CHECK_STATUS_INTERVAL seconds
205
140
  """
206
141
  export_status_response = self.studio_client.dataset_export_status(
207
- self.remote_ds_name, self.remote_ds_version
142
+ self.remote_ds, self.remote_ds_version
208
143
  )
209
144
  if not export_status_response.ok:
210
145
  raise DataChainError(export_status_response.message)
@@ -251,9 +186,7 @@ class DatasetRowsFetcher(NodesThreadPool):
251
186
  import pandas as pd
252
187
 
253
188
  # metastore and warehouse are not thread safe
254
- with self.metastore.clone() as metastore, self.warehouse.clone() as warehouse:
255
- local_ds = metastore.get_dataset(self.local_ds_name)
256
-
189
+ with self.warehouse.clone() as warehouse:
257
190
  urls = list(urls)
258
191
 
259
192
  for url in urls:
@@ -266,7 +199,7 @@ class DatasetRowsFetcher(NodesThreadPool):
266
199
  df = self.fix_columns(df)
267
200
 
268
201
  inserted = warehouse.insert_dataset_rows(
269
- df, local_ds, self.local_ds_version
202
+ df, self.local_ds, self.local_ds_version
270
203
  )
271
204
  self.increase_counter(inserted) # type: ignore [arg-type]
272
205
  # sometimes progress bar doesn't get updated so manually updating it
@@ -277,16 +210,16 @@ class DatasetRowsFetcher(NodesThreadPool):
277
210
  class NodeGroup:
278
211
  """Class for a group of nodes from the same source"""
279
212
 
280
- listing: Optional["Listing"]
281
- client: "Client"
213
+ listing: "Listing | None"
214
+ client: Client
282
215
  sources: list[DataSource]
283
216
 
284
217
  # The source path within the bucket
285
218
  # (not including the bucket name or s3:// prefix)
286
219
  source_path: str = ""
287
- dataset_name: Optional[str] = None
288
- dataset_version: Optional[int] = None
289
- instantiated_nodes: Optional[list[NodeWithPath]] = None
220
+ dataset_name: str | None = None
221
+ dataset_version: str | None = None
222
+ instantiated_nodes: list[NodeWithPath] | None = None
290
223
 
291
224
  @property
292
225
  def is_dataset(self) -> bool:
@@ -307,13 +240,23 @@ class NodeGroup:
307
240
  if self.sources:
308
241
  self.client.fetch_nodes(self.iternodes(recursive), shared_progress_bar=pbar)
309
242
 
243
+ def close(self) -> None:
244
+ if self.listing:
245
+ self.listing.close()
246
+
247
+ def __enter__(self) -> "NodeGroup":
248
+ return self
249
+
250
+ def __exit__(self, exc_type, exc_value, traceback) -> None:
251
+ self.close()
252
+
310
253
 
311
254
  def prepare_output_for_cp(
312
255
  node_groups: list[NodeGroup],
313
256
  output: str,
314
257
  force: bool = False,
315
258
  no_cp: bool = False,
316
- ) -> tuple[bool, Optional[str]]:
259
+ ) -> tuple[bool, str | None]:
317
260
  total_node_count = 0
318
261
  for node_group in node_groups:
319
262
  if not node_group.sources:
@@ -362,7 +305,7 @@ def collect_nodes_for_cp(
362
305
 
363
306
  # Collect all sources to process
364
307
  for node_group in node_groups:
365
- listing: Optional[Listing] = node_group.listing
308
+ listing: Listing | None = node_group.listing
366
309
  valid_sources: list[DataSource] = []
367
310
  for dsrc in node_group.sources:
368
311
  if dsrc.is_single_object():
@@ -406,7 +349,7 @@ def instantiate_node_groups(
406
349
  recursive: bool = False,
407
350
  virtual_only: bool = False,
408
351
  always_copy_dir_contents: bool = False,
409
- copy_to_filename: Optional[str] = None,
352
+ copy_to_filename: str | None = None,
410
353
  ) -> None:
411
354
  instantiate_progress_bar = (
412
355
  None
@@ -434,7 +377,7 @@ def instantiate_node_groups(
434
377
  for node_group in node_groups:
435
378
  if not node_group.sources:
436
379
  continue
437
- listing: Optional[Listing] = node_group.listing
380
+ listing: Listing | None = node_group.listing
438
381
  source_path: str = node_group.source_path
439
382
 
440
383
  copy_dir_contents = always_copy_dir_contents or source_path.endswith("/")
@@ -517,10 +460,8 @@ class Catalog:
517
460
  warehouse: "AbstractWarehouse",
518
461
  cache_dir=None,
519
462
  tmp_dir=None,
520
- client_config: Optional[dict[str, Any]] = None,
521
- warehouse_ready_callback: Optional[
522
- Callable[["AbstractWarehouse"], None]
523
- ] = None,
463
+ client_config: dict[str, Any] | None = None,
464
+ warehouse_ready_callback: Callable[["AbstractWarehouse"], None] | None = None,
524
465
  in_memory: bool = False,
525
466
  ):
526
467
  datachain_dir = DataChainDir(cache=cache_dir, tmp=tmp_dir)
@@ -535,6 +476,7 @@ class Catalog:
535
476
  }
536
477
  self._warehouse_ready_callback = warehouse_ready_callback
537
478
  self.in_memory = in_memory
479
+ self._owns_connections = True # False for copies, prevents double-close
538
480
 
539
481
  @cached_property
540
482
  def warehouse(self) -> "AbstractWarehouse":
@@ -556,13 +498,36 @@ class Catalog:
556
498
  }
557
499
 
558
500
  def copy(self, cache=True, db=True):
501
+ """
502
+ Create a shallow copy of this catalog.
503
+
504
+ The copy shares metastore and warehouse with the original but will not
505
+ close them - only the original catalog owns the connections.
506
+ """
559
507
  result = copy(self)
508
+ result._owns_connections = False
560
509
  if not db:
561
510
  result.metastore = None
562
511
  result._warehouse = None
563
512
  result.warehouse = None
564
513
  return result
565
514
 
515
+ def close(self) -> None:
516
+ if not self._owns_connections:
517
+ return
518
+ if self.metastore is not None:
519
+ with suppress(Exception):
520
+ self.metastore.close_on_exit()
521
+ if self._warehouse is not None:
522
+ with suppress(Exception):
523
+ self._warehouse.close_on_exit()
524
+
525
+ def __enter__(self) -> "Catalog":
526
+ return self
527
+
528
+ def __exit__(self, exc_type, exc_value, traceback) -> None:
529
+ self.close()
530
+
566
531
  @classmethod
567
532
  def generate_query_dataset_name(cls) -> str:
568
533
  return f"{QUERY_DATASET_PREFIX}_{uuid4().hex}"
@@ -580,15 +545,13 @@ class Catalog:
580
545
  source: str,
581
546
  update=False,
582
547
  client_config=None,
583
- object_name="file",
548
+ column="file",
584
549
  skip_indexing=False,
585
- ) -> tuple[Optional["Listing"], "Client", str]:
550
+ ) -> tuple["Listing | None", Client, str]:
586
551
  from datachain import read_storage
587
552
  from datachain.listing import Listing
588
553
 
589
- read_storage(
590
- source, session=self.session, update=update, object_name=object_name
591
- ).exec()
554
+ read_storage(source, session=self.session, update=update, column=column).exec()
592
555
 
593
556
  list_ds_name, list_uri, list_path, _ = get_listing(
594
557
  source, self.session, update=update
@@ -602,13 +565,13 @@ class Catalog:
602
565
  self.warehouse.clone(),
603
566
  client,
604
567
  dataset_name=list_ds_name,
605
- object_name=object_name,
568
+ column=column,
606
569
  )
607
570
 
608
571
  return lst, client, list_path
609
572
 
610
573
  def _remove_dataset_rows_and_warehouse_info(
611
- self, dataset: DatasetRecord, version: int, **kwargs
574
+ self, dataset: DatasetRecord, version: str, **kwargs
612
575
  ):
613
576
  self.warehouse.drop_dataset_rows_table(dataset, version)
614
577
  self.update_dataset_version_with_warehouse_info(
@@ -618,6 +581,7 @@ class Catalog:
618
581
  **kwargs,
619
582
  )
620
583
 
584
+ @contextmanager
621
585
  def enlist_sources(
622
586
  self,
623
587
  sources: list[str],
@@ -625,34 +589,41 @@ class Catalog:
625
589
  skip_indexing=False,
626
590
  client_config=None,
627
591
  only_index=False,
628
- ) -> Optional[list["DataSource"]]:
629
- enlisted_sources = []
630
- for src in sources: # Opt: parallel
631
- listing, client, file_path = self.enlist_source(
632
- src,
633
- update,
634
- client_config=client_config or self.client_config,
635
- skip_indexing=skip_indexing,
636
- )
637
- enlisted_sources.append((listing, client, file_path))
638
-
639
- if only_index:
640
- # sometimes we don't really need listing result (e.g on indexing process)
641
- # so this is to improve performance
642
- return None
643
-
644
- dsrc_all: list[DataSource] = []
645
- for listing, client, file_path in enlisted_sources:
646
- if not listing:
647
- nodes = [Node.from_file(client.get_file_info(file_path))]
648
- dir_only = False
649
- else:
650
- nodes = listing.expand_path(file_path)
651
- dir_only = file_path.endswith("/")
652
- dsrc_all.extend(
653
- DataSource(listing, client, node, dir_only) for node in nodes
654
- )
655
- return dsrc_all
592
+ ) -> Iterator[list["DataSource"] | None]:
593
+ enlisted_sources: list[tuple[Listing | None, Client, str]] = []
594
+ try:
595
+ for src in sources: # Opt: parallel
596
+ listing, client, file_path = self.enlist_source(
597
+ src,
598
+ update,
599
+ client_config=client_config or self.client_config,
600
+ skip_indexing=skip_indexing,
601
+ )
602
+ enlisted_sources.append((listing, client, file_path))
603
+
604
+ if only_index:
605
+ # sometimes we don't really need listing result (e.g. on indexing
606
+ # process) so this is to improve performance
607
+ yield None
608
+ return
609
+
610
+ dsrc_all: list[DataSource] = []
611
+ for listing, client, file_path in enlisted_sources:
612
+ if not listing:
613
+ nodes = [Node.from_file(client.get_file_info(file_path))]
614
+ dir_only = False
615
+ else:
616
+ nodes = listing.expand_path(file_path)
617
+ dir_only = file_path.endswith("/")
618
+ dsrc_all.extend(
619
+ DataSource(listing, client, node, dir_only) for node in nodes
620
+ )
621
+ yield dsrc_all
622
+ finally:
623
+ for listing, _, _ in enlisted_sources:
624
+ if listing:
625
+ with suppress(Exception):
626
+ listing.close()
656
627
 
657
628
  def enlist_sources_grouped(
658
629
  self,
@@ -671,10 +642,15 @@ class Catalog:
671
642
  enlisted_sources: list[tuple[bool, bool, Any]] = []
672
643
  client_config = client_config or self.client_config
673
644
  for src in sources: # Opt: parallel
674
- listing: Optional[Listing]
645
+ listing: Listing | None
675
646
  if src.startswith("ds://"):
676
647
  ds_name, ds_version = parse_dataset_uri(src)
677
- dataset = self.get_dataset(ds_name)
648
+ ds_namespace, ds_project, ds_name = parse_dataset_name(ds_name)
649
+ assert ds_namespace
650
+ assert ds_project
651
+ dataset = self.get_dataset(
652
+ ds_name, namespace_name=ds_namespace, project_name=ds_project
653
+ )
678
654
  if not ds_version:
679
655
  ds_version = dataset.latest_version
680
656
  dataset_sources = self.warehouse.get_dataset_sources(
@@ -694,7 +670,11 @@ class Catalog:
694
670
  dataset_name=dataset_name,
695
671
  )
696
672
  rows = DatasetQuery(
697
- name=dataset.name, version=ds_version, catalog=self
673
+ name=dataset.name,
674
+ namespace_name=dataset.project.namespace.name,
675
+ project_name=dataset.project.name,
676
+ version=ds_version,
677
+ catalog=self,
698
678
  ).to_db_records()
699
679
  indexed_sources.append(
700
680
  (
@@ -768,44 +748,56 @@ class Catalog:
768
748
  def create_dataset(
769
749
  self,
770
750
  name: str,
771
- version: Optional[int] = None,
751
+ project: Project | None = None,
752
+ version: str | None = None,
772
753
  *,
773
754
  columns: Sequence[Column],
774
- feature_schema: Optional[dict] = None,
755
+ feature_schema: dict | None = None,
775
756
  query_script: str = "",
776
- create_rows: Optional[bool] = True,
777
- validate_version: Optional[bool] = True,
778
- listing: Optional[bool] = False,
779
- uuid: Optional[str] = None,
780
- description: Optional[str] = None,
781
- labels: Optional[list[str]] = None,
757
+ create_rows: bool | None = True,
758
+ validate_version: bool | None = True,
759
+ listing: bool | None = False,
760
+ uuid: str | None = None,
761
+ description: str | None = None,
762
+ attrs: list[str] | None = None,
763
+ update_version: str | None = "patch",
764
+ job_id: str | None = None,
782
765
  ) -> "DatasetRecord":
783
766
  """
784
767
  Creates new dataset of a specific version.
785
768
  If dataset is not yet created, it will create it with version 1
786
769
  If version is None, then next unused version is created.
787
- If version is given, then it must be an unused version number.
770
+ If version is given, then it must be an unused version.
788
771
  """
772
+ DatasetRecord.validate_name(name)
789
773
  assert [c.name for c in columns if c.name != "sys__id"], f"got {columns=}"
790
774
  if not listing and Client.is_data_source_uri(name):
791
775
  raise RuntimeError(
792
776
  "Cannot create dataset that starts with source prefix, e.g s3://"
793
777
  )
794
- default_version = 1
778
+ default_version = DEFAULT_DATASET_VERSION
795
779
  try:
796
- dataset = self.get_dataset(name)
797
- default_version = dataset.next_version
798
-
799
- if (description or labels) and (
800
- dataset.description != description or dataset.labels != labels
780
+ dataset = self.get_dataset(
781
+ name,
782
+ namespace_name=project.namespace.name if project else None,
783
+ project_name=project.name if project else None,
784
+ )
785
+ default_version = dataset.next_version_patch
786
+ if update_version == "major":
787
+ default_version = dataset.next_version_major
788
+ if update_version == "minor":
789
+ default_version = dataset.next_version_minor
790
+
791
+ if (description or attrs) and (
792
+ dataset.description != description or dataset.attrs != attrs
801
793
  ):
802
794
  description = description or dataset.description
803
- labels = labels or dataset.labels
795
+ attrs = attrs or dataset.attrs
804
796
 
805
797
  self.update_dataset(
806
798
  dataset,
807
799
  description=description,
808
- labels=labels,
800
+ attrs=attrs,
809
801
  )
810
802
 
811
803
  except DatasetNotFoundError:
@@ -814,12 +806,13 @@ class Catalog:
814
806
  }
815
807
  dataset = self.metastore.create_dataset(
816
808
  name,
809
+ project.id if project else None,
817
810
  feature_schema=feature_schema,
818
811
  query_script=query_script,
819
812
  schema=schema,
820
813
  ignore_if_exists=True,
821
814
  description=description,
822
- labels=labels,
815
+ attrs=attrs,
823
816
  )
824
817
 
825
818
  version = version or default_version
@@ -834,7 +827,7 @@ class Catalog:
834
827
  f"Version {version} must be higher than the current latest one"
835
828
  )
836
829
 
837
- return self.create_new_dataset_version(
830
+ return self.create_dataset_version(
838
831
  dataset,
839
832
  version,
840
833
  feature_schema=feature_schema,
@@ -842,12 +835,13 @@ class Catalog:
842
835
  create_rows_table=create_rows,
843
836
  columns=columns,
844
837
  uuid=uuid,
838
+ job_id=job_id,
845
839
  )
846
840
 
847
- def create_new_dataset_version(
841
+ def create_dataset_version(
848
842
  self,
849
843
  dataset: DatasetRecord,
850
- version: int,
844
+ version: str,
851
845
  *,
852
846
  columns: Sequence[Column],
853
847
  sources="",
@@ -857,8 +851,8 @@ class Catalog:
857
851
  error_stack="",
858
852
  script_output="",
859
853
  create_rows_table=True,
860
- job_id: Optional[str] = None,
861
- uuid: Optional[str] = None,
854
+ job_id: str | None = None,
855
+ uuid: str | None = None,
862
856
  ) -> DatasetRecord:
863
857
  """
864
858
  Creates dataset version if it doesn't exist.
@@ -872,7 +866,7 @@ class Catalog:
872
866
  dataset = self.metastore.create_dataset_version(
873
867
  dataset,
874
868
  version,
875
- status=DatasetStatus.PENDING,
869
+ status=DatasetStatus.CREATED,
876
870
  sources=sources,
877
871
  feature_schema=feature_schema,
878
872
  query_script=query_script,
@@ -886,14 +880,14 @@ class Catalog:
886
880
  )
887
881
 
888
882
  if create_rows_table:
889
- table_name = self.warehouse.dataset_table_name(dataset.name, version)
883
+ table_name = self.warehouse.dataset_table_name(dataset, version)
890
884
  self.warehouse.create_dataset_rows_table(table_name, columns=columns)
891
885
  self.update_dataset_version_with_warehouse_info(dataset, version)
892
886
 
893
887
  return dataset
894
888
 
895
889
  def update_dataset_version_with_warehouse_info(
896
- self, dataset: DatasetRecord, version: int, rows_dropped=False, **kwargs
890
+ self, dataset: DatasetRecord, version: str, rows_dropped=False, **kwargs
897
891
  ) -> None:
898
892
  from datachain.query.dataset import DatasetQuery
899
893
 
@@ -905,11 +899,7 @@ class Catalog:
905
899
  values["num_objects"] = None
906
900
  values["size"] = None
907
901
  values["preview"] = None
908
- self.metastore.update_dataset_version(
909
- dataset,
910
- version,
911
- **values,
912
- )
902
+ self.metastore.update_dataset_version(dataset, version, **values)
913
903
  return
914
904
 
915
905
  if not dataset_version.num_objects:
@@ -921,7 +911,13 @@ class Catalog:
921
911
 
922
912
  if not dataset_version.preview:
923
913
  values["preview"] = (
924
- DatasetQuery(name=dataset.name, version=version, catalog=self)
914
+ DatasetQuery(
915
+ name=dataset.name,
916
+ namespace_name=dataset.project.namespace.name,
917
+ project_name=dataset.project.name,
918
+ version=version,
919
+ catalog=self,
920
+ )
925
921
  .limit(20)
926
922
  .to_db_records()
927
923
  )
@@ -929,38 +925,18 @@ class Catalog:
929
925
  if not values:
930
926
  return
931
927
 
932
- self.metastore.update_dataset_version(
933
- dataset,
934
- version,
935
- **values,
936
- )
928
+ self.metastore.update_dataset_version(dataset, version, **values)
937
929
 
938
930
  def update_dataset(
939
931
  self, dataset: DatasetRecord, conn=None, **kwargs
940
932
  ) -> DatasetRecord:
941
933
  """Updates dataset fields."""
942
- old_name = None
943
- new_name = None
944
- if "name" in kwargs and kwargs["name"] != dataset.name:
945
- old_name = dataset.name
946
- new_name = kwargs["name"]
947
-
948
- dataset = self.metastore.update_dataset(dataset, conn=conn, **kwargs)
949
-
950
- if old_name and new_name:
951
- # updating name must result in updating dataset table names as well
952
- for version in [v.version for v in dataset.versions]:
953
- self.warehouse.rename_dataset_table(
954
- old_name,
955
- new_name,
956
- old_version=version,
957
- new_version=version,
958
- )
959
-
960
- return dataset
934
+ dataset_updated = self.metastore.update_dataset(dataset, conn=conn, **kwargs)
935
+ self.warehouse.rename_dataset_tables(dataset, dataset_updated)
936
+ return dataset_updated
961
937
 
962
938
  def remove_dataset_version(
963
- self, dataset: DatasetRecord, version: int, drop_rows: Optional[bool] = True
939
+ self, dataset: DatasetRecord, version: str, drop_rows: bool | None = True
964
940
  ) -> None:
965
941
  """
966
942
  Deletes one single dataset version.
@@ -988,6 +964,7 @@ class Catalog:
988
964
  self,
989
965
  name: str,
990
966
  sources: list[str],
967
+ project: Project | None = None,
991
968
  client_config=None,
992
969
  recursive=False,
993
970
  ) -> DatasetRecord:
@@ -996,6 +973,8 @@ class Catalog:
996
973
 
997
974
  from datachain import read_dataset, read_storage
998
975
 
976
+ project = project or self.metastore.default_project
977
+
999
978
  chains = []
1000
979
  for source in sources:
1001
980
  if source.startswith(DATASET_PREFIX):
@@ -1008,10 +987,15 @@ class Catalog:
1008
987
  # create union of all dataset queries created from sources
1009
988
  dc = reduce(lambda dc1, dc2: dc1.union(dc2), chains)
1010
989
  try:
990
+ dc = dc.settings(project=project.name, namespace=project.namespace.name)
1011
991
  dc.save(name)
1012
992
  except Exception as e: # noqa: BLE001
1013
993
  try:
1014
- ds = self.get_dataset(name)
994
+ ds = self.get_dataset(
995
+ name,
996
+ namespace_name=project.namespace.name,
997
+ project_name=project.name,
998
+ )
1015
999
  self.metastore.update_dataset_status(
1016
1000
  ds,
1017
1001
  DatasetStatus.FAILED,
@@ -1028,7 +1012,11 @@ class Catalog:
1028
1012
  except DatasetNotFoundError:
1029
1013
  raise e from None
1030
1014
 
1031
- ds = self.get_dataset(name)
1015
+ ds = self.get_dataset(
1016
+ name,
1017
+ namespace_name=project.namespace.name,
1018
+ project_name=project.name,
1019
+ )
1032
1020
 
1033
1021
  self.update_dataset_version_with_warehouse_info(
1034
1022
  ds,
@@ -1036,159 +1024,231 @@ class Catalog:
1036
1024
  sources="\n".join(sources),
1037
1025
  )
1038
1026
 
1039
- return self.get_dataset(name)
1027
+ return self.get_dataset(
1028
+ name,
1029
+ namespace_name=project.namespace.name,
1030
+ project_name=project.name,
1031
+ )
1040
1032
 
1041
- def register_dataset(
1033
+ def get_full_dataset_name(
1042
1034
  self,
1043
- dataset: DatasetRecord,
1044
- version: int,
1045
- target_dataset: DatasetRecord,
1046
- target_version: Optional[int] = None,
1047
- ) -> DatasetRecord:
1035
+ name: str,
1036
+ project_name: str | None = None,
1037
+ namespace_name: str | None = None,
1038
+ ) -> tuple[str, str, str]:
1048
1039
  """
1049
- Registers dataset version of one dataset as dataset version of another
1050
- one (it can be new version of existing one).
1051
- It also removes original dataset version
1040
+ Returns dataset name together with separated namespace and project name.
1041
+ It takes into account all the ways namespace and project can be added.
1052
1042
  """
1053
- target_version = target_version or target_dataset.next_version
1054
-
1055
- if not target_dataset.is_valid_next_version(target_version):
1056
- raise DatasetInvalidVersionError(
1057
- f"Version {target_version} must be higher than the current latest one"
1058
- )
1059
-
1060
- dataset_version = dataset.get_version(version)
1061
- if not dataset_version:
1062
- raise DatasetVersionNotFoundError(
1063
- f"Dataset {dataset.name} does not have version {version}"
1064
- )
1065
-
1066
- if not dataset_version.is_final_status():
1067
- raise ValueError("Cannot register dataset version in non final status")
1068
-
1069
- # copy dataset version
1070
- target_dataset = self.metastore.create_dataset_version(
1071
- target_dataset,
1072
- target_version,
1073
- sources=dataset_version.sources,
1074
- status=dataset_version.status,
1075
- query_script=dataset_version.query_script,
1076
- error_message=dataset_version.error_message,
1077
- error_stack=dataset_version.error_stack,
1078
- script_output=dataset_version.script_output,
1079
- created_at=dataset_version.created_at,
1080
- finished_at=dataset_version.finished_at,
1081
- schema=dataset_version.serialized_schema,
1082
- num_objects=dataset_version.num_objects,
1083
- size=dataset_version.size,
1084
- preview=dataset_version.preview,
1085
- job_id=dataset_version.job_id,
1086
- )
1087
-
1088
- # to avoid re-creating rows table, we are just renaming it for a new version
1089
- # of target dataset
1090
- self.warehouse.rename_dataset_table(
1091
- dataset.name,
1092
- target_dataset.name,
1093
- old_version=version,
1094
- new_version=target_version,
1043
+ parsed_namespace_name, parsed_project_name, name = parse_dataset_name(name)
1044
+
1045
+ namespace_env = os.environ.get("DATACHAIN_NAMESPACE")
1046
+ project_env = os.environ.get("DATACHAIN_PROJECT")
1047
+ if project_env and len(project_env.split(".")) == 2:
1048
+ # we allow setting both namespace and project in DATACHAIN_PROJECT
1049
+ namespace_env, project_env = project_env.split(".")
1050
+
1051
+ namespace_name = (
1052
+ parsed_namespace_name
1053
+ or namespace_name
1054
+ or namespace_env
1055
+ or self.metastore.default_namespace_name
1095
1056
  )
1096
- self.metastore.update_dataset_dependency_source(
1097
- dataset,
1098
- version,
1099
- new_source_dataset=target_dataset,
1100
- new_source_dataset_version=target_version,
1057
+ project_name = (
1058
+ parsed_project_name
1059
+ or project_name
1060
+ or project_env
1061
+ or self.metastore.default_project_name
1101
1062
  )
1102
1063
 
1103
- if dataset.id == target_dataset.id:
1104
- # we are updating the same dataset so we need to refresh it to have newly
1105
- # added version in step before
1106
- dataset = self.get_dataset(dataset.name)
1064
+ return namespace_name, project_name, name
1065
+
1066
+ def get_dataset(
1067
+ self,
1068
+ name: str,
1069
+ namespace_name: str | None = None,
1070
+ project_name: str | None = None,
1071
+ ) -> DatasetRecord:
1072
+ from datachain.lib.listing import is_listing_dataset
1107
1073
 
1108
- self.remove_dataset_version(dataset, version, drop_rows=False)
1074
+ namespace_name = namespace_name or self.metastore.default_namespace_name
1075
+ project_name = project_name or self.metastore.default_project_name
1109
1076
 
1110
- return self.get_dataset(target_dataset.name)
1077
+ if is_listing_dataset(name):
1078
+ namespace_name = self.metastore.system_namespace_name
1079
+ project_name = self.metastore.listing_project_name
1111
1080
 
1112
- def get_dataset(self, name: str) -> DatasetRecord:
1113
- return self.metastore.get_dataset(name)
1081
+ return self.metastore.get_dataset(
1082
+ name, namespace_name=namespace_name, project_name=project_name
1083
+ )
1114
1084
 
1115
1085
  def get_dataset_with_remote_fallback(
1116
- self, name: str, version: Optional[int] = None
1086
+ self,
1087
+ name: str,
1088
+ namespace_name: str,
1089
+ project_name: str,
1090
+ version: str | None = None,
1091
+ pull_dataset: bool = False,
1092
+ update: bool = False,
1117
1093
  ) -> DatasetRecord:
1118
- try:
1119
- ds = self.get_dataset(name)
1120
- if version and not ds.has_version(version):
1121
- raise DatasetVersionNotFoundError(
1122
- f"Dataset {name} does not have version {version}"
1094
+ from datachain.lib.dc.utils import is_studio
1095
+
1096
+ # Intentionally ignore update flag is version is provided. Here only exact
1097
+ # version can be provided and update then doesn't make sense.
1098
+ # It corresponds to a query like this for example:
1099
+ #
1100
+ # dc.read_dataset("some.remote.dataset", version="1.0.0", update=True)
1101
+ if version:
1102
+ update = False
1103
+
1104
+ # we don't do Studio fallback is script is already ran in Studio, or if we try
1105
+ # to fetch dataset with local namespace as that one cannot
1106
+ # exist in Studio in the first place
1107
+ no_fallback = is_studio() or is_namespace_local(namespace_name)
1108
+
1109
+ if no_fallback or not update:
1110
+ try:
1111
+ ds = self.get_dataset(
1112
+ name,
1113
+ namespace_name=namespace_name,
1114
+ project_name=project_name,
1123
1115
  )
1124
- return ds
1116
+ if not version or ds.has_version(version):
1117
+ return ds
1118
+ except (NamespaceNotFoundError, ProjectNotFoundError, DatasetNotFoundError):
1119
+ pass
1120
+
1121
+ if no_fallback:
1122
+ raise DatasetNotFoundError(
1123
+ f"Dataset {name}"
1124
+ + (f" version {version} " if version else " ")
1125
+ + f"not found in namespace {namespace_name} and project {project_name}"
1126
+ )
1125
1127
 
1126
- except (DatasetNotFoundError, DatasetVersionNotFoundError):
1128
+ if pull_dataset:
1127
1129
  print("Dataset not found in local catalog, trying to get from studio")
1128
-
1129
- remote_ds_uri = f"{DATASET_PREFIX}{name}"
1130
- if version:
1131
- remote_ds_uri += f"@v{version}"
1130
+ remote_ds_uri = create_dataset_uri(
1131
+ name, namespace_name, project_name, version
1132
+ )
1132
1133
 
1133
1134
  self.pull_dataset(
1134
1135
  remote_ds_uri=remote_ds_uri,
1135
1136
  local_ds_name=name,
1136
1137
  local_ds_version=version,
1137
1138
  )
1138
- return self.get_dataset(name)
1139
+ return self.get_dataset(
1140
+ name,
1141
+ namespace_name=namespace_name,
1142
+ project_name=project_name,
1143
+ )
1144
+
1145
+ return self.get_remote_dataset(namespace_name, project_name, name)
1139
1146
 
1140
1147
  def get_dataset_with_version_uuid(self, uuid: str) -> DatasetRecord:
1141
1148
  """Returns dataset that contains version with specific uuid"""
1142
1149
  for dataset in self.ls_datasets():
1143
1150
  if dataset.has_version_with_uuid(uuid):
1144
- return self.get_dataset(dataset.name)
1151
+ return self.get_dataset(
1152
+ dataset.name,
1153
+ namespace_name=dataset.project.namespace.name,
1154
+ project_name=dataset.project.name,
1155
+ )
1145
1156
  raise DatasetNotFoundError(f"Dataset with version uuid {uuid} not found.")
1146
1157
 
1147
- def get_remote_dataset(self, name: str) -> DatasetRecord:
1158
+ def get_remote_dataset(
1159
+ self, namespace: str, project: str, name: str
1160
+ ) -> DatasetRecord:
1148
1161
  from datachain.remote.studio import StudioClient
1149
1162
 
1150
1163
  studio_client = StudioClient()
1151
1164
 
1152
- info_response = studio_client.dataset_info(name)
1165
+ info_response = studio_client.dataset_info(namespace, project, name)
1153
1166
  if not info_response.ok:
1167
+ if info_response.status == 404:
1168
+ raise DatasetNotFoundError(
1169
+ f"Dataset {namespace}.{project}.{name} not found"
1170
+ )
1154
1171
  raise DataChainError(info_response.message)
1155
1172
 
1156
1173
  dataset_info = info_response.data
1157
1174
  assert isinstance(dataset_info, dict)
1158
1175
  return DatasetRecord.from_dict(dataset_info)
1159
1176
 
1160
- def get_dataset_dependencies(
1161
- self, name: str, version: int, indirect=False
1162
- ) -> list[Optional[DatasetDependency]]:
1163
- dataset = self.get_dataset(name)
1177
+ def get_dataset_dependencies_by_ids(
1178
+ self,
1179
+ dataset_id: int,
1180
+ version_id: int,
1181
+ indirect: bool = True,
1182
+ ) -> list[DatasetDependency | None]:
1183
+ dependency_nodes = self.metastore.get_dataset_dependency_nodes(
1184
+ dataset_id=dataset_id,
1185
+ version_id=version_id,
1186
+ )
1187
+
1188
+ if not dependency_nodes:
1189
+ return []
1190
+
1191
+ dependency_map, children_map = build_dependency_hierarchy(dependency_nodes)
1164
1192
 
1165
- direct_dependencies = self.metastore.get_direct_dataset_dependencies(
1166
- dataset, version
1193
+ root_key = (dataset_id, version_id)
1194
+ if root_key not in children_map:
1195
+ return []
1196
+
1197
+ root_dependency_ids = children_map[root_key]
1198
+ root_dependencies = [dependency_map[dep_id] for dep_id in root_dependency_ids]
1199
+
1200
+ if indirect:
1201
+ for dependency in root_dependencies:
1202
+ if dependency is not None:
1203
+ populate_nested_dependencies(
1204
+ dependency, dependency_nodes, dependency_map, children_map
1205
+ )
1206
+
1207
+ return root_dependencies
1208
+
1209
+ def get_dataset_dependencies(
1210
+ self,
1211
+ name: str,
1212
+ version: str,
1213
+ namespace_name: str | None = None,
1214
+ project_name: str | None = None,
1215
+ indirect=False,
1216
+ ) -> list[DatasetDependency | None]:
1217
+ dataset = self.get_dataset(
1218
+ name,
1219
+ namespace_name=namespace_name,
1220
+ project_name=project_name,
1167
1221
  )
1222
+ dataset_version = dataset.get_version(version)
1223
+ dataset_id = dataset.id
1224
+ dataset_version_id = dataset_version.id
1168
1225
 
1169
1226
  if not indirect:
1170
- return direct_dependencies
1171
-
1172
- for d in direct_dependencies:
1173
- if not d:
1174
- # dependency has been removed
1175
- continue
1176
- if d.is_dataset:
1177
- # only datasets can have dependencies
1178
- d.dependencies = self.get_dataset_dependencies(
1179
- d.name, int(d.version), indirect=indirect
1180
- )
1227
+ return self.metastore.get_direct_dataset_dependencies(
1228
+ dataset,
1229
+ version,
1230
+ )
1181
1231
 
1182
- return direct_dependencies
1232
+ return self.get_dataset_dependencies_by_ids(
1233
+ dataset_id,
1234
+ dataset_version_id,
1235
+ indirect,
1236
+ )
1183
1237
 
1184
1238
  def ls_datasets(
1185
- self, include_listing: bool = False, studio: bool = False
1239
+ self,
1240
+ prefix: str | None = None,
1241
+ include_listing: bool = False,
1242
+ studio: bool = False,
1243
+ project: Project | None = None,
1186
1244
  ) -> Iterator[DatasetListRecord]:
1187
1245
  from datachain.remote.studio import StudioClient
1188
1246
 
1247
+ project_id = project.id if project else None
1248
+
1189
1249
  if studio:
1190
1250
  client = StudioClient()
1191
- response = client.ls_datasets()
1251
+ response = client.ls_datasets(prefix=prefix)
1192
1252
  if not response.ok:
1193
1253
  raise DataChainError(response.message)
1194
1254
  if not response.data:
@@ -1199,8 +1259,12 @@ class Catalog:
1199
1259
  for d in response.data
1200
1260
  if not d.get("name", "").startswith(QUERY_DATASET_PREFIX)
1201
1261
  )
1262
+ elif prefix:
1263
+ datasets = self.metastore.list_datasets_by_prefix(
1264
+ prefix, project_id=project_id
1265
+ )
1202
1266
  else:
1203
- datasets = self.metastore.list_datasets()
1267
+ datasets = self.metastore.list_datasets(project_id=project_id)
1204
1268
 
1205
1269
  for d in datasets:
1206
1270
  if not d.is_bucket_listing or include_listing:
@@ -1208,50 +1272,79 @@ class Catalog:
1208
1272
 
1209
1273
  def list_datasets_versions(
1210
1274
  self,
1275
+ prefix: str | None = None,
1211
1276
  include_listing: bool = False,
1277
+ with_job: bool = True,
1212
1278
  studio: bool = False,
1213
- ) -> Iterator[tuple[DatasetListRecord, "DatasetListVersion", Optional["Job"]]]:
1279
+ project: Project | None = None,
1280
+ ) -> Iterator[tuple[DatasetListRecord, "DatasetListVersion", "Job | None"]]:
1214
1281
  """Iterate over all dataset versions with related jobs."""
1215
1282
  datasets = list(
1216
- self.ls_datasets(include_listing=include_listing, studio=studio)
1283
+ self.ls_datasets(
1284
+ prefix=prefix,
1285
+ include_listing=include_listing,
1286
+ studio=studio,
1287
+ project=project,
1288
+ )
1217
1289
  )
1218
1290
 
1219
1291
  # preselect dataset versions jobs from db to avoid multiple queries
1220
- jobs_ids: set[str] = {
1221
- v.job_id for ds in datasets for v in ds.versions if v.job_id
1222
- }
1223
1292
  jobs: dict[str, Job] = {}
1224
- if jobs_ids:
1225
- jobs = {j.id: j for j in self.metastore.list_jobs_by_ids(list(jobs_ids))}
1293
+ if with_job:
1294
+ jobs_ids: set[str] = {
1295
+ v.job_id for ds in datasets for v in ds.versions if v.job_id
1296
+ }
1297
+ if jobs_ids:
1298
+ jobs = {
1299
+ j.id: j for j in self.metastore.list_jobs_by_ids(list(jobs_ids))
1300
+ }
1226
1301
 
1227
1302
  for d in datasets:
1228
1303
  yield from (
1229
- (d, v, jobs.get(str(v.job_id)) if v.job_id else None)
1304
+ (d, v, jobs.get(str(v.job_id)) if with_job and v.job_id else None)
1230
1305
  for v in d.versions
1231
1306
  )
1232
1307
 
1233
- def listings(self):
1308
+ def listings(self, prefix: str | None = None) -> list["ListingInfo"]:
1234
1309
  """
1235
1310
  Returns list of ListingInfo objects which are representing specific
1236
1311
  storage listing datasets
1237
1312
  """
1238
- from datachain.lib.listing import is_listing_dataset
1313
+ from datachain.lib.listing import LISTING_PREFIX, is_listing_dataset
1239
1314
  from datachain.lib.listing_info import ListingInfo
1240
1315
 
1316
+ if prefix and not prefix.startswith(LISTING_PREFIX):
1317
+ prefix = LISTING_PREFIX + prefix
1318
+
1319
+ listing_datasets_versions = self.list_datasets_versions(
1320
+ prefix=prefix,
1321
+ include_listing=True,
1322
+ with_job=False,
1323
+ project=self.metastore.listing_project,
1324
+ )
1325
+
1241
1326
  return [
1242
1327
  ListingInfo.from_models(d, v, j)
1243
- for d, v, j in self.list_datasets_versions(include_listing=True)
1328
+ for d, v, j in listing_datasets_versions
1244
1329
  if is_listing_dataset(d.name)
1245
1330
  ]
1246
1331
 
1247
1332
  def ls_dataset_rows(
1248
- self, name: str, version: int, offset=None, limit=None
1333
+ self,
1334
+ dataset: DatasetRecord,
1335
+ version: str,
1336
+ offset=None,
1337
+ limit=None,
1249
1338
  ) -> list[dict]:
1250
1339
  from datachain.query.dataset import DatasetQuery
1251
1340
 
1252
- dataset = self.get_dataset(name)
1253
-
1254
- q = DatasetQuery(name=dataset.name, version=version, catalog=self)
1341
+ q = DatasetQuery(
1342
+ name=dataset.name,
1343
+ namespace_name=dataset.project.namespace.name,
1344
+ project_name=dataset.project.name,
1345
+ version=version,
1346
+ catalog=self,
1347
+ )
1255
1348
  if limit:
1256
1349
  q = q.limit(limit)
1257
1350
  if offset:
@@ -1263,9 +1356,9 @@ class Catalog:
1263
1356
  self,
1264
1357
  source: str,
1265
1358
  path: str,
1266
- version_id: Optional[str] = None,
1359
+ version_id: str | None = None,
1267
1360
  client_config=None,
1268
- content_disposition: Optional[str] = None,
1361
+ content_disposition: str | None = None,
1269
1362
  **kwargs,
1270
1363
  ) -> str:
1271
1364
  client_config = client_config or self.client_config
@@ -1283,26 +1376,42 @@ class Catalog:
1283
1376
  self,
1284
1377
  bucket_uri: str,
1285
1378
  name: str,
1286
- version: int,
1379
+ version: str,
1380
+ project: Project | None = None,
1287
1381
  client_config=None,
1288
1382
  ) -> list[str]:
1289
- dataset = self.get_dataset(name)
1383
+ dataset = self.get_dataset(
1384
+ name,
1385
+ namespace_name=project.namespace.name if project else None,
1386
+ project_name=project.name if project else None,
1387
+ )
1290
1388
 
1291
1389
  return self.warehouse.export_dataset_table(
1292
1390
  bucket_uri, dataset, version, client_config
1293
1391
  )
1294
1392
 
1295
- def dataset_table_export_file_names(self, name: str, version: int) -> list[str]:
1296
- dataset = self.get_dataset(name)
1393
+ def dataset_table_export_file_names(
1394
+ self, name: str, version: str, project: Project | None = None
1395
+ ) -> list[str]:
1396
+ dataset = self.get_dataset(
1397
+ name,
1398
+ namespace_name=project.namespace.name if project else None,
1399
+ project_name=project.name if project else None,
1400
+ )
1297
1401
  return self.warehouse.dataset_table_export_file_names(dataset, version)
1298
1402
 
1299
1403
  def remove_dataset(
1300
1404
  self,
1301
1405
  name: str,
1302
- version: Optional[int] = None,
1303
- force: Optional[bool] = False,
1406
+ project: Project | None = None,
1407
+ version: str | None = None,
1408
+ force: bool | None = False,
1304
1409
  ):
1305
- dataset = self.get_dataset(name)
1410
+ dataset = self.get_dataset(
1411
+ name,
1412
+ namespace_name=project.namespace.name if project else None,
1413
+ project_name=project.name if project else None,
1414
+ )
1306
1415
  if not version and not force:
1307
1416
  raise ValueError(f"Missing dataset version from input for dataset {name}")
1308
1417
  if version and not dataset.has_version(version):
@@ -1324,19 +1433,25 @@ class Catalog:
1324
1433
  def edit_dataset(
1325
1434
  self,
1326
1435
  name: str,
1327
- new_name: Optional[str] = None,
1328
- description: Optional[str] = None,
1329
- labels: Optional[list[str]] = None,
1436
+ project: Project | None = None,
1437
+ new_name: str | None = None,
1438
+ description: str | None = None,
1439
+ attrs: list[str] | None = None,
1330
1440
  ) -> DatasetRecord:
1331
1441
  update_data = {}
1332
1442
  if new_name:
1443
+ DatasetRecord.validate_name(new_name)
1333
1444
  update_data["name"] = new_name
1334
1445
  if description is not None:
1335
1446
  update_data["description"] = description
1336
- if labels is not None:
1337
- update_data["labels"] = labels # type: ignore[assignment]
1447
+ if attrs is not None:
1448
+ update_data["attrs"] = attrs # type: ignore[assignment]
1338
1449
 
1339
- dataset = self.get_dataset(name)
1450
+ dataset = self.get_dataset(
1451
+ name,
1452
+ namespace_name=project.namespace.name if project else None,
1453
+ project_name=project.name if project else None,
1454
+ )
1340
1455
  return self.update_dataset(dataset, **update_data)
1341
1456
 
1342
1457
  def ls(
@@ -1348,22 +1463,24 @@ class Catalog:
1348
1463
  *,
1349
1464
  client_config=None,
1350
1465
  ) -> Iterator[tuple[DataSource, Iterable[tuple]]]:
1351
- data_sources = self.enlist_sources(
1466
+ with self.enlist_sources(
1352
1467
  sources,
1353
1468
  update,
1354
1469
  skip_indexing=skip_indexing,
1355
1470
  client_config=client_config or self.client_config,
1356
- )
1471
+ ) as data_sources:
1472
+ if data_sources is None:
1473
+ return
1357
1474
 
1358
- for source in data_sources: # type: ignore [union-attr]
1359
- yield source, source.ls(fields)
1475
+ for source in data_sources:
1476
+ yield source, source.ls(fields)
1360
1477
 
1361
1478
  def pull_dataset( # noqa: C901, PLR0915
1362
1479
  self,
1363
1480
  remote_ds_uri: str,
1364
- output: Optional[str] = None,
1365
- local_ds_name: Optional[str] = None,
1366
- local_ds_version: Optional[int] = None,
1481
+ output: str | None = None,
1482
+ local_ds_name: str | None = None,
1483
+ local_ds_version: str | None = None,
1367
1484
  cp: bool = False,
1368
1485
  force: bool = False,
1369
1486
  *,
@@ -1393,7 +1510,29 @@ class Catalog:
1393
1510
  except Exception as e:
1394
1511
  raise DataChainError("Error when parsing dataset uri") from e
1395
1512
 
1396
- remote_ds = self.get_remote_dataset(remote_ds_name)
1513
+ remote_namespace, remote_project, remote_ds_name = parse_dataset_name(
1514
+ remote_ds_name
1515
+ )
1516
+ if not remote_namespace or not remote_project:
1517
+ raise DataChainError(
1518
+ f"Invalid fully qualified dataset name {remote_ds_name}, namespace"
1519
+ f" or project missing"
1520
+ )
1521
+
1522
+ if local_ds_name:
1523
+ local_namespace, local_project, local_ds_name = parse_dataset_name(
1524
+ local_ds_name
1525
+ )
1526
+ if local_namespace and local_namespace != remote_namespace:
1527
+ raise DataChainError(
1528
+ "Local namespace must be the same to remote namespace"
1529
+ )
1530
+ if local_project and local_project != remote_project:
1531
+ raise DataChainError("Local project must be the same to remote project")
1532
+
1533
+ remote_ds = self.get_remote_dataset(
1534
+ remote_namespace, remote_project, remote_ds_name
1535
+ )
1397
1536
 
1398
1537
  try:
1399
1538
  # if version is not specified in uri, take the latest one
@@ -1401,7 +1540,12 @@ class Catalog:
1401
1540
  version = remote_ds.latest_version
1402
1541
  print(f"Version not specified, pulling the latest one (v{version})")
1403
1542
  # updating dataset uri with latest version
1404
- remote_ds_uri = create_dataset_uri(remote_ds_name, version)
1543
+ remote_ds_uri = create_dataset_uri(
1544
+ remote_ds.name,
1545
+ remote_ds.project.namespace.name,
1546
+ remote_ds.project.name,
1547
+ version,
1548
+ )
1405
1549
  remote_ds_version = remote_ds.get_version(version)
1406
1550
  except (DatasetVersionNotFoundError, StopIteration) as exc:
1407
1551
  raise DataChainError(
@@ -1410,7 +1554,13 @@ class Catalog:
1410
1554
 
1411
1555
  local_ds_name = local_ds_name or remote_ds.name
1412
1556
  local_ds_version = local_ds_version or remote_ds_version.version
1413
- local_ds_uri = create_dataset_uri(local_ds_name, local_ds_version)
1557
+
1558
+ local_ds_uri = create_dataset_uri(
1559
+ local_ds_name,
1560
+ remote_ds.project.namespace.name,
1561
+ remote_ds.project.name,
1562
+ local_ds_version,
1563
+ )
1414
1564
 
1415
1565
  try:
1416
1566
  # try to find existing dataset with the same uuid to avoid pulling again
@@ -1419,7 +1569,10 @@ class Catalog:
1419
1569
  remote_ds_version.uuid
1420
1570
  )
1421
1571
  existing_ds_uri = create_dataset_uri(
1422
- existing_ds.name, existing_ds_version.version
1572
+ existing_ds.name,
1573
+ existing_ds.project.namespace.name,
1574
+ existing_ds.project.name,
1575
+ existing_ds_version.version,
1423
1576
  )
1424
1577
  if existing_ds_uri == remote_ds_uri:
1425
1578
  print(f"Local copy of dataset {remote_ds_uri} already present")
@@ -1433,8 +1586,30 @@ class Catalog:
  except DatasetNotFoundError:
  pass

+ # Create namespace and project if doesn't exist
+ print(
+ f"Creating namespace {remote_ds.project.namespace.name} and project"
+ f" {remote_ds.project.name}"
+ )
+
+ namespace = self.metastore.create_namespace(
+ remote_ds.project.namespace.name,
+ description=remote_ds.project.namespace.descr,
+ uuid=remote_ds.project.namespace.uuid,
+ validate=False,
+ )
+ project = self.metastore.create_project(
+ namespace.name,
+ remote_ds.project.name,
+ description=remote_ds.project.descr,
+ uuid=remote_ds.project.uuid,
+ validate=False,
+ )
+
  try:
- local_dataset = self.get_dataset(local_ds_name)
+ local_dataset = self.get_dataset(
+ local_ds_name, namespace_name=namespace.name, project_name=project.name
+ )
  if local_dataset and local_dataset.has_version(local_ds_version):
  raise DataChainError(
  f"Local dataset {local_ds_uri} already exists with different uuid,"
@@ -1452,10 +1627,11 @@ class Catalog:
  leave=False,
  )

- schema = DatasetRecord.parse_schema(remote_ds_version.schema)
+ schema = parse_schema(remote_ds_version.schema)

  local_ds = self.create_dataset(
  local_ds_name,
+ project,
  local_ds_version,
  query_script=remote_ds_version.query_script,
  create_rows=True,
@@ -1468,7 +1644,7 @@ class Catalog:
  # asking remote to export dataset rows table to s3 and to return signed
  # urls of exported parts, which are in parquet format
  export_response = studio_client.export_dataset_table(
- remote_ds_name, remote_ds_version.version
+ remote_ds, remote_ds_version.version
  )
  if not export_response.ok:
  raise DataChainError(export_response.message)
@@ -1499,9 +1675,9 @@ class Catalog:
  rows_fetcher = DatasetRowsFetcher(
  metastore,
  warehouse,
- remote_ds_name,
+ remote_ds,
  remote_ds_version.version,
- local_ds_name,
+ local_ds,
  local_ds_version,
  schema,
  progress_bar=dataset_save_progress_bar,
@@ -1511,7 +1687,7 @@ class Catalog:
  iter(batch(signed_urls)), dataset_save_progress_bar
  )
  except:
- self.remove_dataset(local_ds_name, local_ds_version)
+ self.remove_dataset(local_ds_name, project, local_ds_version)
  raise

  local_ds = self.metastore.update_dataset_status(
@@ -1561,92 +1737,20 @@ class Catalog:
  else:
  # since we don't call cp command, which does listing implicitly,
  # it needs to be done here
- self.enlist_sources(
+ with self.enlist_sources(
  sources,
  update,
  client_config=client_config or self.client_config,
- )
+ ):
+ pass

  self.create_dataset_from_sources(
- output, sources, client_config=client_config, recursive=recursive
- )
-
- def query(
- self,
- query_script: str,
- env: Optional[Mapping[str, str]] = None,
- python_executable: str = sys.executable,
- capture_output: bool = False,
- output_hook: Callable[[str], None] = noop,
- params: Optional[dict[str, str]] = None,
- job_id: Optional[str] = None,
- interrupt_timeout: Optional[int] = None,
- terminate_timeout: Optional[int] = None,
- ) -> None:
- cmd = [python_executable, "-c", query_script]
- env = dict(env or os.environ)
- env.update(
- {
- "DATACHAIN_QUERY_PARAMS": json.dumps(params or {}),
- "DATACHAIN_JOB_ID": job_id or "",
- },
+ output,
+ sources,
+ self.metastore.default_project,
+ client_config=client_config,
+ recursive=recursive,
  )
- popen_kwargs: dict[str, Any] = {}
- if capture_output:
- popen_kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.STDOUT}
-
- def raise_termination_signal(sig: int, _: Any) -> NoReturn:
- raise TerminationSignal(sig)
-
- thread: Optional[Thread] = None
- with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc: # noqa: S603
- logger.info("Starting process %s", proc.pid)
-
- orig_sigint_handler = signal.getsignal(signal.SIGINT)
- # ignore SIGINT in the main process.
- # In the terminal, SIGINTs are received by all the processes in
- # the foreground process group, so the script will receive the signal too.
- # (If we forward the signal to the child, it will receive it twice.)
- signal.signal(signal.SIGINT, signal.SIG_IGN)
-
- orig_sigterm_handler = signal.getsignal(signal.SIGTERM)
- signal.signal(signal.SIGTERM, raise_termination_signal)
- try:
- if capture_output:
- args = (proc.stdout, output_hook)
- thread = Thread(target=_process_stream, args=args, daemon=True)
- thread.start()
-
- proc.wait()
- except TerminationSignal as exc:
- signal.signal(signal.SIGTERM, orig_sigterm_handler)
- signal.signal(signal.SIGINT, orig_sigint_handler)
- logger.info("Shutting down process %s, received %r", proc.pid, exc)
- # Rather than forwarding the signal to the child, we try to shut it down
- # gracefully. This is because we consider the script to be interactive
- # and special, so we give it time to cleanup before exiting.
- shutdown_process(proc, interrupt_timeout, terminate_timeout)
- if proc.returncode:
- raise QueryScriptCancelError(
- "Query script was canceled by user", return_code=proc.returncode
- ) from exc
- finally:
- signal.signal(signal.SIGTERM, orig_sigterm_handler)
- signal.signal(signal.SIGINT, orig_sigint_handler)
- if thread:
- thread.join() # wait for the reader thread
-
- logger.info("Process %s exited with return code %s", proc.pid, proc.returncode)
- if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE:
- raise QueryScriptCancelError(
- "Query script was canceled by user",
- return_code=proc.returncode,
- )
- if proc.returncode:
- raise QueryScriptRunError(
- f"Query script exited with error code {proc.returncode}",
- return_code=proc.returncode,
- )

  def cp(
  self,
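Two changes land in this hunk: the Catalog.query method, with its subprocess and signal handling, is deleted from this class, and enlist_sources is now consumed as a context manager. A hedged sketch of the new call pattern, based only on the usages visible in this diff; the catalog instance and the in-scope sources/update variables are assumptions.

    # Side-effect-only listing, mirroring the "with ...: pass" call added above.
    with catalog.enlist_sources(
        sources,
        update,
        client_config=catalog.client_config,
    ):
        pass  # listing happens on enter; any cleanup happens on exit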
@@ -1658,7 +1762,7 @@ class Catalog:
  no_cp: bool = False,
  no_glob: bool = False,
  *,
- client_config: Optional["dict"] = None,
+ client_config: dict | None = None,
  ) -> None:
  """
  This function copies files from cloud sources to local destination directory
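A small side note on the annotation change above: Optional["dict"] becomes the PEP 604 form dict | None. As a general Python fact rather than anything specific to this package, that syntax is native to annotations on Python 3.10+ and otherwise needs postponed annotation evaluation; the function below is a hypothetical stand-in, not the real cp signature.

    from __future__ import annotations  # keeps "dict | None" valid in annotations before 3.10

    def copy_files(output: str, *, client_config: dict | None = None) -> None:
        # hypothetical stand-in for the cp() signature shown above
        ...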
@@ -1671,38 +1775,42 @@ class Catalog:
  no_glob,
  client_config=client_config,
  )
+ try:
+ always_copy_dir_contents, copy_to_filename = prepare_output_for_cp(
+ node_groups, output, force, no_cp
+ )
+ total_size, total_files = collect_nodes_for_cp(node_groups, recursive)
+ if not total_files:
+ return

- always_copy_dir_contents, copy_to_filename = prepare_output_for_cp(
- node_groups, output, force, no_cp
- )
- total_size, total_files = collect_nodes_for_cp(node_groups, recursive)
- if not total_files:
- return
-
- desc_max_len = max(len(output) + 16, 19)
- bar_format = (
- "{desc:<"
- f"{desc_max_len}"
- "}{percentage:3.0f}%|{bar}| {n_fmt:>5}/{total_fmt:<5} "
- "[{elapsed}<{remaining}, {rate_fmt:>8}]"
- )
+ desc_max_len = max(len(output) + 16, 19)
+ bar_format = (
+ "{desc:<"
+ f"{desc_max_len}"
+ "}{percentage:3.0f}%|{bar}| {n_fmt:>5}/{total_fmt:<5} "
+ "[{elapsed}<{remaining}, {rate_fmt:>8}]"
+ )

- if not no_cp:
- with get_download_bar(bar_format, total_size) as pbar:
- for node_group in node_groups:
- node_group.download(recursive=recursive, pbar=pbar)
+ if not no_cp:
+ with get_download_bar(bar_format, total_size) as pbar:
+ for node_group in node_groups:
+ node_group.download(recursive=recursive, pbar=pbar)

- instantiate_node_groups(
- node_groups,
- output,
- bar_format,
- total_files,
- force,
- recursive,
- no_cp,
- always_copy_dir_contents,
- copy_to_filename,
- )
+ instantiate_node_groups(
+ node_groups,
+ output,
+ bar_format,
+ total_files,
+ force,
+ recursive,
+ no_cp,
+ always_copy_dir_contents,
+ copy_to_filename,
+ )
+ finally:
+ for node_group in node_groups:
+ with suppress(Exception):
+ node_group.close()

  def du(
  self,
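The reshaped cp above wraps the copy work in try/finally so every node group is closed even when preparation or download fails, with contextlib.suppress(Exception) keeping a failing close() from masking the original error. A generic, hedged sketch of that cleanup pattern follows; the copy_all name and the groups argument are invented for illustration.

    from contextlib import suppress

    def copy_all(groups):
        # "groups" stands in for the node_groups above: objects with download()/close().
        try:
            for group in groups:
                group.download()
        finally:
            for group in groups:
                with suppress(Exception):  # best-effort cleanup, never hides the first error
                    group.close()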
@@ -1712,24 +1820,26 @@ class Catalog:
  *,
  client_config=None,
  ) -> Iterable[tuple[str, float]]:
- sources = self.enlist_sources(
+ with self.enlist_sources(
  sources,
  update,
  client_config=client_config or self.client_config,
- )
+ ) as matched_sources:
+ if matched_sources is None:
+ return

- def du_dirs(src, node, subdepth):
- if subdepth > 0:
- subdirs = src.listing.get_dirs_by_parent_path(node.path)
- for sd in subdirs:
- yield from du_dirs(src, sd, subdepth - 1)
- yield (
- src.get_node_full_path(node),
- src.listing.du(node)[0],
- )
+ def du_dirs(src, node, subdepth):
+ if subdepth > 0:
+ subdirs = src.listing.get_dirs_by_parent_path(node.path)
+ for sd in subdirs:
+ yield from du_dirs(src, sd, subdepth - 1)
+ yield (
+ src.get_node_full_path(node),
+ src.listing.du(node)[0],
+ )

- for src in sources:
- yield from du_dirs(src, src.node, depth)
+ for src in matched_sources:
+ yield from du_dirs(src, src.node, depth)

  def find(
  self,
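du above, and find in the next hunk, now consume enlist_sources through "with ... as matched_sources" and bail out early when nothing matched. A hedged sketch of that consumer shape, using only names visible in this diff; the sizes wrapper and its arguments are illustrative assumptions.

    # Sketch of the new generator shape; the early return makes it yield nothing.
    def sizes(catalog, sources, update=False):
        with catalog.enlist_sources(
            sources, update, client_config=catalog.client_config
        ) as matched_sources:
            if matched_sources is None:
                return  # nothing matched
            for src in matched_sources:
                yield src.get_node_full_path(src.node), src.listing.du(src.node)[0]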
@@ -1745,39 +1855,42 @@ class Catalog:
  *,
  client_config=None,
  ) -> Iterator[str]:
- sources = self.enlist_sources(
+ with self.enlist_sources(
  sources,
  update,
  client_config=client_config or self.client_config,
- )
- if not columns:
- columns = ["path"]
- field_set = set()
- for column in columns:
- if column == "du":
- field_set.add("dir_type")
- field_set.add("size")
- field_set.add("path")
- elif column == "name":
- field_set.add("path")
- elif column == "path":
- field_set.add("dir_type")
- field_set.add("path")
- elif column == "size":
- field_set.add("size")
- elif column == "type":
- field_set.add("dir_type")
- fields = list(field_set)
- field_lookup = {f: i for i, f in enumerate(fields)}
- for src in sources:
- results = src.listing.find(
- src.node, fields, names, inames, paths, ipaths, size, typ
- )
- for row in results:
- yield "\t".join(
- find_column_to_str(row, field_lookup, src, column)
- for column in columns
+ ) as matched_sources:
+ if matched_sources is None:
+ return
+
+ if not columns:
+ columns = ["path"]
+ field_set = set()
+ for column in columns:
+ if column == "du":
+ field_set.add("dir_type")
+ field_set.add("size")
+ field_set.add("path")
+ elif column == "name":
+ field_set.add("path")
+ elif column == "path":
+ field_set.add("dir_type")
+ field_set.add("path")
+ elif column == "size":
+ field_set.add("size")
+ elif column == "type":
+ field_set.add("dir_type")
+ fields = list(field_set)
+ field_lookup = {f: i for i, f in enumerate(fields)}
+ for src in matched_sources:
+ results = src.listing.find(
+ src.node, fields, names, inames, paths, ipaths, size, typ
  )
+ for row in results:
+ yield "\t".join(
+ find_column_to_str(row, field_lookup, src, column)
+ for column in columns
+ )

  def index(
  self,
@@ -1786,9 +1899,10 @@ class Catalog:
  *,
  client_config=None,
  ) -> None:
- self.enlist_sources(
+ with self.enlist_sources(
  sources,
  update,
  client_config=client_config or self.client_config,
  only_index=True,
- )
+ ):
+ pass