datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. datachain/__init__.py +4 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +5 -5
  4. datachain/catalog/__init__.py +0 -2
  5. datachain/catalog/catalog.py +276 -354
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +8 -3
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +10 -17
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +42 -27
  12. datachain/cli/commands/ls.py +15 -15
  13. datachain/cli/commands/show.py +2 -2
  14. datachain/cli/parser/__init__.py +3 -43
  15. datachain/cli/parser/job.py +1 -1
  16. datachain/cli/parser/utils.py +1 -2
  17. datachain/cli/utils.py +2 -15
  18. datachain/client/azure.py +2 -2
  19. datachain/client/fsspec.py +34 -23
  20. datachain/client/gcs.py +3 -3
  21. datachain/client/http.py +157 -0
  22. datachain/client/local.py +11 -7
  23. datachain/client/s3.py +3 -3
  24. datachain/config.py +4 -8
  25. datachain/data_storage/db_engine.py +12 -6
  26. datachain/data_storage/job.py +2 -0
  27. datachain/data_storage/metastore.py +716 -137
  28. datachain/data_storage/schema.py +20 -27
  29. datachain/data_storage/serializer.py +105 -15
  30. datachain/data_storage/sqlite.py +114 -114
  31. datachain/data_storage/warehouse.py +140 -48
  32. datachain/dataset.py +109 -89
  33. datachain/delta.py +117 -42
  34. datachain/diff/__init__.py +25 -33
  35. datachain/error.py +24 -0
  36. datachain/func/aggregate.py +9 -11
  37. datachain/func/array.py +12 -12
  38. datachain/func/base.py +7 -4
  39. datachain/func/conditional.py +9 -13
  40. datachain/func/func.py +63 -45
  41. datachain/func/numeric.py +5 -7
  42. datachain/func/string.py +2 -2
  43. datachain/hash_utils.py +123 -0
  44. datachain/job.py +11 -7
  45. datachain/json.py +138 -0
  46. datachain/lib/arrow.py +18 -15
  47. datachain/lib/audio.py +60 -59
  48. datachain/lib/clip.py +14 -13
  49. datachain/lib/convert/python_to_sql.py +6 -10
  50. datachain/lib/convert/values_to_tuples.py +151 -53
  51. datachain/lib/data_model.py +23 -19
  52. datachain/lib/dataset_info.py +7 -7
  53. datachain/lib/dc/__init__.py +2 -1
  54. datachain/lib/dc/csv.py +22 -26
  55. datachain/lib/dc/database.py +37 -34
  56. datachain/lib/dc/datachain.py +518 -324
  57. datachain/lib/dc/datasets.py +38 -30
  58. datachain/lib/dc/hf.py +16 -20
  59. datachain/lib/dc/json.py +17 -18
  60. datachain/lib/dc/listings.py +5 -8
  61. datachain/lib/dc/pandas.py +3 -6
  62. datachain/lib/dc/parquet.py +33 -21
  63. datachain/lib/dc/records.py +9 -13
  64. datachain/lib/dc/storage.py +103 -65
  65. datachain/lib/dc/storage_pattern.py +251 -0
  66. datachain/lib/dc/utils.py +17 -14
  67. datachain/lib/dc/values.py +3 -6
  68. datachain/lib/file.py +187 -50
  69. datachain/lib/hf.py +7 -5
  70. datachain/lib/image.py +13 -13
  71. datachain/lib/listing.py +5 -5
  72. datachain/lib/listing_info.py +1 -2
  73. datachain/lib/meta_formats.py +2 -3
  74. datachain/lib/model_store.py +20 -8
  75. datachain/lib/namespaces.py +59 -7
  76. datachain/lib/projects.py +51 -9
  77. datachain/lib/pytorch.py +31 -23
  78. datachain/lib/settings.py +188 -85
  79. datachain/lib/signal_schema.py +302 -64
  80. datachain/lib/text.py +8 -7
  81. datachain/lib/udf.py +103 -63
  82. datachain/lib/udf_signature.py +59 -34
  83. datachain/lib/utils.py +20 -0
  84. datachain/lib/video.py +3 -4
  85. datachain/lib/webdataset.py +31 -36
  86. datachain/lib/webdataset_laion.py +15 -16
  87. datachain/listing.py +12 -5
  88. datachain/model/bbox.py +3 -1
  89. datachain/namespace.py +22 -3
  90. datachain/node.py +6 -6
  91. datachain/nodes_thread_pool.py +0 -1
  92. datachain/plugins.py +24 -0
  93. datachain/project.py +4 -4
  94. datachain/query/batch.py +10 -12
  95. datachain/query/dataset.py +376 -194
  96. datachain/query/dispatch.py +112 -84
  97. datachain/query/metrics.py +3 -4
  98. datachain/query/params.py +2 -3
  99. datachain/query/queue.py +2 -1
  100. datachain/query/schema.py +7 -6
  101. datachain/query/session.py +190 -33
  102. datachain/query/udf.py +9 -6
  103. datachain/remote/studio.py +90 -53
  104. datachain/script_meta.py +12 -12
  105. datachain/sql/sqlite/base.py +37 -25
  106. datachain/sql/sqlite/types.py +1 -1
  107. datachain/sql/types.py +36 -5
  108. datachain/studio.py +49 -40
  109. datachain/toolkit/split.py +31 -10
  110. datachain/utils.py +39 -48
  111. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
  112. datachain-0.39.0.dist-info/RECORD +173 -0
  113. datachain/cli/commands/query.py +0 -54
  114. datachain/query/utils.py +0 -36
  115. datachain-0.30.5.dist-info/RECORD +0 -168
  116. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
  117. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  118. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  119. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
@@ -1,28 +1,16 @@
  import io
- import json
  import logging
  import os
  import os.path
  import posixpath
- import signal
- import subprocess
- import sys
  import time
  import traceback
- from collections.abc import Iterable, Iterator, Mapping, Sequence
+ from collections.abc import Callable, Iterable, Iterator, Sequence
+ from contextlib import contextmanager, suppress
  from copy import copy
  from dataclasses import dataclass
  from functools import cached_property, reduce
- from threading import Thread
- from typing import (
-     IO,
-     TYPE_CHECKING,
-     Any,
-     Callable,
-     NoReturn,
-     Optional,
-     Union,
- )
+ from typing import TYPE_CHECKING, Any
  from uuid import uuid4

  import sqlalchemy as sa
@@ -43,6 +31,7 @@ from datachain.dataset import (
      create_dataset_uri,
      parse_dataset_name,
      parse_dataset_uri,
+     parse_schema,
  )
  from datachain.error import (
      DataChainError,
@@ -51,8 +40,6 @@ from datachain.error import (
      DatasetVersionNotFoundError,
      NamespaceNotFoundError,
      ProjectNotFoundError,
-     QueryScriptCancelError,
-     QueryScriptRunError,
  )
  from datachain.lib.listing import get_listing
  from datachain.node import DirType, Node, NodeWithPath
@@ -62,12 +49,10 @@ from datachain.sql.types import DateTime, SQLType
  from datachain.utils import DataChainDir

  from .datasource import DataSource
+ from .dependency import build_dependency_hierarchy, populate_nested_dependencies

  if TYPE_CHECKING:
-     from datachain.data_storage import (
-         AbstractMetastore,
-         AbstractWarehouse,
-     )
+     from datachain.data_storage import AbstractMetastore, AbstractWarehouse
      from datachain.dataset import DatasetListVersion
      from datachain.job import Job
      from datachain.lib.listing_info import ListingInfo
@@ -81,8 +66,6 @@ TTL_INT = 4 * 60 * 60

  INDEX_INTERNAL_ERROR_MESSAGE = "Internal error on indexing"
  DATASET_INTERNAL_ERROR_MESSAGE = "Internal error on creating dataset"
- # exit code we use if last statement in query script is not instance of DatasetQuery
- QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE = 10
  # exit code we use if query script was canceled
  QUERY_SCRIPT_CANCELED_EXIT_CODE = 11
  QUERY_SCRIPT_SIGTERM_EXIT_CODE = -15  # if query script was terminated by SIGTERM
@@ -94,71 +77,11 @@ PULL_DATASET_SLEEP_INTERVAL = 0.1 # sleep time while waiting for chunk to be av
  PULL_DATASET_CHECK_STATUS_INTERVAL = 20  # interval to check export status in Studio


- def noop(_: str):
-     pass
-
-
- class TerminationSignal(RuntimeError):  # noqa: N818
-     def __init__(self, signal):
-         self.signal = signal
-         super().__init__("Received termination signal", signal)
-
-     def __repr__(self):
-         return f"{self.__class__.__name__}({self.signal})"
-
-
- if sys.platform == "win32":
-     SIGINT = signal.CTRL_C_EVENT
- else:
-     SIGINT = signal.SIGINT
-
-
  def is_namespace_local(namespace_name) -> bool:
      """Checks if namespace is from local environment, i.e. is `local`"""
      return namespace_name == "local"


- def shutdown_process(
-     proc: subprocess.Popen,
-     interrupt_timeout: Optional[int] = None,
-     terminate_timeout: Optional[int] = None,
- ) -> int:
-     """Shut down the process gracefully with SIGINT -> SIGTERM -> SIGKILL."""
-
-     logger.info("sending interrupt signal to the process %s", proc.pid)
-     proc.send_signal(SIGINT)
-
-     logger.info("waiting for the process %s to finish", proc.pid)
-     try:
-         return proc.wait(interrupt_timeout)
-     except subprocess.TimeoutExpired:
-         logger.info(
-             "timed out waiting, sending terminate signal to the process %s", proc.pid
-         )
-         proc.terminate()
-         try:
-             return proc.wait(terminate_timeout)
-         except subprocess.TimeoutExpired:
-             logger.info("timed out waiting, killing the process %s", proc.pid)
-             proc.kill()
-             return proc.wait()
-
-
- def _process_stream(stream: "IO[bytes]", callback: Callable[[str], None]) -> None:
-     buffer = b""
-     while byt := stream.read(1):  # Read one byte at a time
-         buffer += byt
-
-         if byt in (b"\n", b"\r"):  # Check for newline or carriage return
-             line = buffer.decode("utf-8")
-             callback(line)
-             buffer = b""  # Clear buffer for next line
-
-     if buffer:  # Handle any remaining data in the buffer
-         line = buffer.decode("utf-8")
-         callback(line)
-
-
  class DatasetRowsFetcher(NodesThreadPool):
      def __init__(
          self,
@@ -168,7 +91,7 @@ class DatasetRowsFetcher(NodesThreadPool):
          remote_ds_version: str,
          local_ds: DatasetRecord,
          local_ds_version: str,
-         schema: dict[str, Union[SQLType, type[SQLType]]],
+         schema: dict[str, SQLType | type[SQLType]],
          max_threads: int = PULL_DATASET_MAX_THREADS,
          progress_bar=None,
      ):
@@ -183,7 +106,7 @@ class DatasetRowsFetcher(NodesThreadPool):
          self.local_ds = local_ds
          self.local_ds_version = local_ds_version
          self.schema = schema
-         self.last_status_check: Optional[float] = None
+         self.last_status_check: float | None = None
          self.studio_client = StudioClient()
          self.progress_bar = progress_bar

@@ -287,16 +210,16 @@ class DatasetRowsFetcher(NodesThreadPool):
  class NodeGroup:
      """Class for a group of nodes from the same source"""

-     listing: Optional["Listing"]
-     client: "Client"
+     listing: "Listing | None"
+     client: Client
      sources: list[DataSource]

      # The source path within the bucket
      # (not including the bucket name or s3:// prefix)
      source_path: str = ""
-     dataset_name: Optional[str] = None
-     dataset_version: Optional[str] = None
-     instantiated_nodes: Optional[list[NodeWithPath]] = None
+     dataset_name: str | None = None
+     dataset_version: str | None = None
+     instantiated_nodes: list[NodeWithPath] | None = None

      @property
      def is_dataset(self) -> bool:
@@ -317,13 +240,23 @@ class NodeGroup:
          if self.sources:
              self.client.fetch_nodes(self.iternodes(recursive), shared_progress_bar=pbar)

+     def close(self) -> None:
+         if self.listing:
+             self.listing.close()
+
+     def __enter__(self) -> "NodeGroup":
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback) -> None:
+         self.close()
+

  def prepare_output_for_cp(
      node_groups: list[NodeGroup],
      output: str,
      force: bool = False,
      no_cp: bool = False,
- ) -> tuple[bool, Optional[str]]:
+ ) -> tuple[bool, str | None]:
      total_node_count = 0
      for node_group in node_groups:
          if not node_group.sources:
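Note on the NodeGroup change above: close() shuts the group's Listing, and __enter__/__exit__ make a NodeGroup usable in a with block. A minimal usage sketch, not taken from the package itself (the node_groups list and the download arguments are assumptions for illustration):

    # hypothetical sketch: make sure each group's listing is closed after downloading
    for node_group in node_groups:
        with node_group:  # __exit__ calls close(), which closes node_group.listing
            node_group.download(recursive=True, pbar=None)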
@@ -372,7 +305,7 @@ def collect_nodes_for_cp(

      # Collect all sources to process
      for node_group in node_groups:
-         listing: Optional[Listing] = node_group.listing
+         listing: Listing | None = node_group.listing
          valid_sources: list[DataSource] = []
          for dsrc in node_group.sources:
              if dsrc.is_single_object():
@@ -416,7 +349,7 @@ def instantiate_node_groups(
      recursive: bool = False,
      virtual_only: bool = False,
      always_copy_dir_contents: bool = False,
-     copy_to_filename: Optional[str] = None,
+     copy_to_filename: str | None = None,
  ) -> None:
      instantiate_progress_bar = (
          None
@@ -444,7 +377,7 @@ def instantiate_node_groups(
      for node_group in node_groups:
          if not node_group.sources:
              continue
-         listing: Optional[Listing] = node_group.listing
+         listing: Listing | None = node_group.listing
          source_path: str = node_group.source_path

          copy_dir_contents = always_copy_dir_contents or source_path.endswith("/")
@@ -527,10 +460,8 @@ class Catalog:
          warehouse: "AbstractWarehouse",
          cache_dir=None,
          tmp_dir=None,
-         client_config: Optional[dict[str, Any]] = None,
-         warehouse_ready_callback: Optional[
-             Callable[["AbstractWarehouse"], None]
-         ] = None,
+         client_config: dict[str, Any] | None = None,
+         warehouse_ready_callback: Callable[["AbstractWarehouse"], None] | None = None,
          in_memory: bool = False,
      ):
          datachain_dir = DataChainDir(cache=cache_dir, tmp=tmp_dir)
@@ -545,6 +476,7 @@ class Catalog:
          }
          self._warehouse_ready_callback = warehouse_ready_callback
          self.in_memory = in_memory
+         self._owns_connections = True  # False for copies, prevents double-close

      @cached_property
      def warehouse(self) -> "AbstractWarehouse":
@@ -566,13 +498,36 @@ class Catalog:
          }

      def copy(self, cache=True, db=True):
+         """
+         Create a shallow copy of this catalog.
+
+         The copy shares metastore and warehouse with the original but will not
+         close them - only the original catalog owns the connections.
+         """
          result = copy(self)
+         result._owns_connections = False
          if not db:
              result.metastore = None
              result._warehouse = None
              result.warehouse = None
          return result

+     def close(self) -> None:
+         if not self._owns_connections:
+             return
+         if self.metastore is not None:
+             with suppress(Exception):
+                 self.metastore.close_on_exit()
+         if self._warehouse is not None:
+             with suppress(Exception):
+                 self._warehouse.close_on_exit()
+
+     def __enter__(self) -> "Catalog":
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback) -> None:
+         self.close()
+
      @classmethod
      def generate_query_dataset_name(cls) -> str:
          return f"{QUERY_DATASET_PREFIX}_{uuid4().hex}"
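Note on the Catalog changes above: close() releases the metastore and warehouse via their close_on_exit() methods, __enter__/__exit__ turn the catalog into a context manager, and copies made with copy() clear _owns_connections so they never close connections shared with the original. A minimal usage sketch; the get_catalog() helper from datachain.catalog.loader is assumed here and is not part of this diff:

    from datachain.catalog.loader import get_catalog

    # hypothetical sketch: scope the catalog's DB connections to a block
    with get_catalog() as catalog:
        dataset = catalog.get_dataset("my-dataset")
    # __exit__ -> close() -> metastore.close_on_exit() / warehouse.close_on_exit()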
@@ -592,7 +547,7 @@ class Catalog:
          client_config=None,
          column="file",
          skip_indexing=False,
-     ) -> tuple[Optional["Listing"], "Client", str]:
+     ) -> tuple["Listing | None", Client, str]:
          from datachain import read_storage
          from datachain.listing import Listing

@@ -626,6 +581,7 @@ class Catalog:
              **kwargs,
          )

+     @contextmanager
      def enlist_sources(
          self,
          sources: list[str],
@@ -633,34 +589,41 @@
          skip_indexing=False,
          client_config=None,
          only_index=False,
-     ) -> Optional[list["DataSource"]]:
-         enlisted_sources = []
-         for src in sources:  # Opt: parallel
-             listing, client, file_path = self.enlist_source(
-                 src,
-                 update,
-                 client_config=client_config or self.client_config,
-                 skip_indexing=skip_indexing,
-             )
-             enlisted_sources.append((listing, client, file_path))
-
-         if only_index:
-             # sometimes we don't really need listing result (e.g on indexing process)
-             # so this is to improve performance
-             return None
-
-         dsrc_all: list[DataSource] = []
-         for listing, client, file_path in enlisted_sources:
-             if not listing:
-                 nodes = [Node.from_file(client.get_file_info(file_path))]
-                 dir_only = False
-             else:
-                 nodes = listing.expand_path(file_path)
-                 dir_only = file_path.endswith("/")
-             dsrc_all.extend(
-                 DataSource(listing, client, node, dir_only) for node in nodes
-             )
-         return dsrc_all
+     ) -> Iterator[list["DataSource"] | None]:
+         enlisted_sources: list[tuple[Listing | None, Client, str]] = []
+         try:
+             for src in sources:  # Opt: parallel
+                 listing, client, file_path = self.enlist_source(
+                     src,
+                     update,
+                     client_config=client_config or self.client_config,
+                     skip_indexing=skip_indexing,
+                 )
+                 enlisted_sources.append((listing, client, file_path))
+
+             if only_index:
+                 # sometimes we don't really need listing result (e.g. on indexing
+                 # process) so this is to improve performance
+                 yield None
+                 return
+
+             dsrc_all: list[DataSource] = []
+             for listing, client, file_path in enlisted_sources:
+                 if not listing:
+                     nodes = [Node.from_file(client.get_file_info(file_path))]
+                     dir_only = False
+                 else:
+                     nodes = listing.expand_path(file_path)
+                     dir_only = file_path.endswith("/")
+                 dsrc_all.extend(
+                     DataSource(listing, client, node, dir_only) for node in nodes
+                 )
+             yield dsrc_all
+         finally:
+             for listing, _, _ in enlisted_sources:
+                 if listing:
+                     with suppress(Exception):
+                         listing.close()

      def enlist_sources_grouped(
          self,
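Note on the enlist_sources change above: it is now a @contextmanager generator that yields either None (when only_index=True) or the list of DataSource objects, and closes every opened Listing in its finally block. Call sites later in this diff (ls, du, find, index) switch to with blocks accordingly. A minimal sketch of the new calling convention; the URI and the positional update flag are illustrative assumptions:

    # hypothetical sketch mirroring the call sites shown further down
    with catalog.enlist_sources(["s3://bucket/prefix/"], False) as data_sources:
        if data_sources is not None:
            for src in data_sources:
                print(src.node.path)
    # every Listing opened for these sources is closed on exit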
@@ -679,7 +642,7 @@ class Catalog:
          enlisted_sources: list[tuple[bool, bool, Any]] = []
          client_config = client_config or self.client_config
          for src in sources:  # Opt: parallel
-             listing: Optional[Listing]
+             listing: Listing | None
              if src.startswith("ds://"):
                  ds_name, ds_version = parse_dataset_uri(src)
                  ds_namespace, ds_project, ds_name = parse_dataset_name(ds_name)
@@ -785,19 +748,20 @@ class Catalog:
      def create_dataset(
          self,
          name: str,
-         project: Optional[Project] = None,
-         version: Optional[str] = None,
+         project: Project | None = None,
+         version: str | None = None,
          *,
          columns: Sequence[Column],
-         feature_schema: Optional[dict] = None,
+         feature_schema: dict | None = None,
          query_script: str = "",
-         create_rows: Optional[bool] = True,
-         validate_version: Optional[bool] = True,
-         listing: Optional[bool] = False,
-         uuid: Optional[str] = None,
-         description: Optional[str] = None,
-         attrs: Optional[list[str]] = None,
-         update_version: Optional[str] = "patch",
+         create_rows: bool | None = True,
+         validate_version: bool | None = True,
+         listing: bool | None = False,
+         uuid: str | None = None,
+         description: str | None = None,
+         attrs: list[str] | None = None,
+         update_version: str | None = "patch",
+         job_id: str | None = None,
      ) -> "DatasetRecord":
          """
          Creates new dataset of a specific version.
@@ -863,7 +827,7 @@ class Catalog:
                  f"Version {version} must be higher than the current latest one"
              )

-         return self.create_new_dataset_version(
+         return self.create_dataset_version(
              dataset,
              version,
              feature_schema=feature_schema,
@@ -871,9 +835,10 @@ class Catalog:
              create_rows_table=create_rows,
              columns=columns,
              uuid=uuid,
+             job_id=job_id,
          )

-     def create_new_dataset_version(
+     def create_dataset_version(
          self,
          dataset: DatasetRecord,
          version: str,
@@ -886,8 +851,8 @@ class Catalog:
          error_stack="",
          script_output="",
          create_rows_table=True,
-         job_id: Optional[str] = None,
-         uuid: Optional[str] = None,
+         job_id: str | None = None,
+         uuid: str | None = None,
      ) -> DatasetRecord:
          """
          Creates dataset version if it doesn't exist.
@@ -901,7 +866,7 @@ class Catalog:
          dataset = self.metastore.create_dataset_version(
              dataset,
              version,
-             status=DatasetStatus.PENDING,
+             status=DatasetStatus.CREATED,
              sources=sources,
              feature_schema=feature_schema,
              query_script=query_script,
@@ -971,7 +936,7 @@ class Catalog:
          return dataset_updated

      def remove_dataset_version(
-         self, dataset: DatasetRecord, version: str, drop_rows: Optional[bool] = True
+         self, dataset: DatasetRecord, version: str, drop_rows: bool | None = True
      ) -> None:
          """
          Deletes one single dataset version.
@@ -999,7 +964,7 @@ class Catalog:
          self,
          name: str,
          sources: list[str],
-         project: Optional[Project] = None,
+         project: Project | None = None,
          client_config=None,
          recursive=False,
      ) -> DatasetRecord:
@@ -1068,8 +1033,8 @@ class Catalog:
      def get_full_dataset_name(
          self,
          name: str,
-         project_name: Optional[str] = None,
-         namespace_name: Optional[str] = None,
+         project_name: str | None = None,
+         namespace_name: str | None = None,
      ) -> tuple[str, str, str]:
          """
          Returns dataset name together with separated namespace and project name.
@@ -1101,8 +1066,8 @@ class Catalog:
      def get_dataset(
          self,
          name: str,
-         namespace_name: Optional[str] = None,
-         project_name: Optional[str] = None,
+         namespace_name: str | None = None,
+         project_name: str | None = None,
      ) -> DatasetRecord:
          from datachain.lib.listing import is_listing_dataset

@@ -1122,7 +1087,7 @@ class Catalog:
          name: str,
          namespace_name: str,
          project_name: str,
-         version: Optional[str] = None,
+         version: str | None = None,
          pull_dataset: bool = False,
          update: bool = False,
      ) -> DatasetRecord:
@@ -1209,49 +1174,73 @@ class Catalog:
          assert isinstance(dataset_info, dict)
          return DatasetRecord.from_dict(dataset_info)

+     def get_dataset_dependencies_by_ids(
+         self,
+         dataset_id: int,
+         version_id: int,
+         indirect: bool = True,
+     ) -> list[DatasetDependency | None]:
+         dependency_nodes = self.metastore.get_dataset_dependency_nodes(
+             dataset_id=dataset_id,
+             version_id=version_id,
+         )
+
+         if not dependency_nodes:
+             return []
+
+         dependency_map, children_map = build_dependency_hierarchy(dependency_nodes)
+
+         root_key = (dataset_id, version_id)
+         if root_key not in children_map:
+             return []
+
+         root_dependency_ids = children_map[root_key]
+         root_dependencies = [dependency_map[dep_id] for dep_id in root_dependency_ids]
+
+         if indirect:
+             for dependency in root_dependencies:
+                 if dependency is not None:
+                     populate_nested_dependencies(
+                         dependency, dependency_nodes, dependency_map, children_map
+                     )
+
+         return root_dependencies
+
      def get_dataset_dependencies(
          self,
          name: str,
          version: str,
-         namespace_name: Optional[str] = None,
-         project_name: Optional[str] = None,
+         namespace_name: str | None = None,
+         project_name: str | None = None,
          indirect=False,
-     ) -> list[Optional[DatasetDependency]]:
+     ) -> list[DatasetDependency | None]:
          dataset = self.get_dataset(
              name,
              namespace_name=namespace_name,
              project_name=project_name,
          )
-
-         direct_dependencies = self.metastore.get_direct_dataset_dependencies(
-             dataset, version
-         )
+         dataset_version = dataset.get_version(version)
+         dataset_id = dataset.id
+         dataset_version_id = dataset_version.id

          if not indirect:
-             return direct_dependencies
-
-         for d in direct_dependencies:
-             if not d:
-                 # dependency has been removed
-                 continue
-             if d.is_dataset:
-                 # only datasets can have dependencies
-                 d.dependencies = self.get_dataset_dependencies(
-                     d.name,
-                     d.version,
-                     namespace_name=d.namespace,
-                     project_name=d.project,
-                     indirect=indirect,
-                 )
+             return self.metastore.get_direct_dataset_dependencies(
+                 dataset,
+                 version,
+             )

-         return direct_dependencies
+         return self.get_dataset_dependencies_by_ids(
+             dataset_id,
+             dataset_version_id,
+             indirect,
+         )

      def ls_datasets(
          self,
-         prefix: Optional[str] = None,
+         prefix: str | None = None,
          include_listing: bool = False,
          studio: bool = False,
-         project: Optional[Project] = None,
+         project: Project | None = None,
      ) -> Iterator[DatasetListRecord]:
          from datachain.remote.studio import StudioClient

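Note on the dependency changes above: get_dataset_dependencies_by_ids fetches all dependency nodes with a single metastore.get_dataset_dependency_nodes() call and reassembles the tree with build_dependency_hierarchy / populate_nested_dependencies from the new datachain/catalog/dependency.py, replacing the per-dependency recursion that get_dataset_dependencies used before. A minimal sketch of the public wrapper; the dataset name and version are placeholders:

    # hypothetical call using the signature shown in this diff
    deps = catalog.get_dataset_dependencies("my-dataset", "1.0.0", indirect=True)
    for dep in deps:
        if dep is not None:  # None marks a dependency that was removed
            print(dep.name, dep.version, dep.dependencies)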
@@ -1283,12 +1272,12 @@ class Catalog:

      def list_datasets_versions(
          self,
-         prefix: Optional[str] = None,
+         prefix: str | None = None,
          include_listing: bool = False,
          with_job: bool = True,
          studio: bool = False,
-         project: Optional[Project] = None,
-     ) -> Iterator[tuple[DatasetListRecord, "DatasetListVersion", Optional["Job"]]]:
+         project: Project | None = None,
+     ) -> Iterator[tuple[DatasetListRecord, "DatasetListVersion", "Job | None"]]:
          """Iterate over all dataset versions with related jobs."""
          datasets = list(
              self.ls_datasets(
@@ -1316,7 +1305,7 @@ class Catalog:
              for v in d.versions
          )

-     def listings(self, prefix: Optional[str] = None) -> list["ListingInfo"]:
+     def listings(self, prefix: str | None = None) -> list["ListingInfo"]:
          """
          Returns list of ListingInfo objects which are representing specific
          storage listing datasets
@@ -1367,9 +1356,9 @@ class Catalog:
          self,
          source: str,
          path: str,
-         version_id: Optional[str] = None,
+         version_id: str | None = None,
          client_config=None,
-         content_disposition: Optional[str] = None,
+         content_disposition: str | None = None,
          **kwargs,
      ) -> str:
          client_config = client_config or self.client_config
@@ -1388,7 +1377,7 @@ class Catalog:
          bucket_uri: str,
          name: str,
          version: str,
-         project: Optional[Project] = None,
+         project: Project | None = None,
          client_config=None,
      ) -> list[str]:
          dataset = self.get_dataset(
@@ -1402,7 +1391,7 @@ class Catalog:
          )

      def dataset_table_export_file_names(
-         self, name: str, version: str, project: Optional[Project] = None
+         self, name: str, version: str, project: Project | None = None
      ) -> list[str]:
          dataset = self.get_dataset(
              name,
@@ -1414,9 +1403,9 @@ class Catalog:
      def remove_dataset(
          self,
          name: str,
-         project: Optional[Project] = None,
-         version: Optional[str] = None,
-         force: Optional[bool] = False,
+         project: Project | None = None,
+         version: str | None = None,
+         force: bool | None = False,
      ):
          dataset = self.get_dataset(
              name,
@@ -1444,10 +1433,10 @@ class Catalog:
      def edit_dataset(
          self,
          name: str,
-         project: Optional[Project] = None,
-         new_name: Optional[str] = None,
-         description: Optional[str] = None,
-         attrs: Optional[list[str]] = None,
+         project: Project | None = None,
+         new_name: str | None = None,
+         description: str | None = None,
+         attrs: list[str] | None = None,
      ) -> DatasetRecord:
          update_data = {}
          if new_name:
@@ -1474,22 +1463,24 @@ class Catalog:
          *,
          client_config=None,
      ) -> Iterator[tuple[DataSource, Iterable[tuple]]]:
-         data_sources = self.enlist_sources(
+         with self.enlist_sources(
              sources,
              update,
              skip_indexing=skip_indexing,
              client_config=client_config or self.client_config,
-         )
+         ) as data_sources:
+             if data_sources is None:
+                 return

-         for source in data_sources:  # type: ignore [union-attr]
-             yield source, source.ls(fields)
+             for source in data_sources:
+                 yield source, source.ls(fields)

      def pull_dataset(  # noqa: C901, PLR0915
          self,
          remote_ds_uri: str,
-         output: Optional[str] = None,
-         local_ds_name: Optional[str] = None,
-         local_ds_version: Optional[str] = None,
+         output: str | None = None,
+         local_ds_name: str | None = None,
+         local_ds_version: str | None = None,
          cp: bool = False,
          force: bool = False,
          *,
@@ -1636,7 +1627,7 @@ class Catalog:
              leave=False,
          )

-         schema = DatasetRecord.parse_schema(remote_ds_version.schema)
+         schema = parse_schema(remote_ds_version.schema)

          local_ds = self.create_dataset(
              local_ds_name,
@@ -1746,11 +1737,12 @@ class Catalog:
          else:
              # since we don't call cp command, which does listing implicitly,
              # it needs to be done here
-             self.enlist_sources(
+             with self.enlist_sources(
                  sources,
                  update,
                  client_config=client_config or self.client_config,
-             )
+             ):
+                 pass

          self.create_dataset_from_sources(
              output,
@@ -1760,86 +1752,6 @@ class Catalog:
              recursive=recursive,
          )

-     def query(
-         self,
-         query_script: str,
-         env: Optional[Mapping[str, str]] = None,
-         python_executable: str = sys.executable,
-         capture_output: bool = False,
-         output_hook: Callable[[str], None] = noop,
-         params: Optional[dict[str, str]] = None,
-         job_id: Optional[str] = None,
-         interrupt_timeout: Optional[int] = None,
-         terminate_timeout: Optional[int] = None,
-     ) -> None:
-         cmd = [python_executable, "-c", query_script]
-         env = dict(env or os.environ)
-         env.update(
-             {
-                 "DATACHAIN_QUERY_PARAMS": json.dumps(params or {}),
-                 "DATACHAIN_JOB_ID": job_id or "",
-             },
-         )
-         popen_kwargs: dict[str, Any] = {}
-         if capture_output:
-             popen_kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.STDOUT}
-
-         def raise_termination_signal(sig: int, _: Any) -> NoReturn:
-             raise TerminationSignal(sig)
-
-         thread: Optional[Thread] = None
-         with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc:  # noqa: S603
-             logger.info("Starting process %s", proc.pid)
-
-             orig_sigint_handler = signal.getsignal(signal.SIGINT)
-             # ignore SIGINT in the main process.
-             # In the terminal, SIGINTs are received by all the processes in
-             # the foreground process group, so the script will receive the signal too.
-             # (If we forward the signal to the child, it will receive it twice.)
-             signal.signal(signal.SIGINT, signal.SIG_IGN)
-
-             orig_sigterm_handler = signal.getsignal(signal.SIGTERM)
-             signal.signal(signal.SIGTERM, raise_termination_signal)
-             try:
-                 if capture_output:
-                     args = (proc.stdout, output_hook)
-                     thread = Thread(target=_process_stream, args=args, daemon=True)
-                     thread.start()
-
-                 proc.wait()
-             except TerminationSignal as exc:
-                 signal.signal(signal.SIGTERM, orig_sigterm_handler)
-                 signal.signal(signal.SIGINT, orig_sigint_handler)
-                 logger.info("Shutting down process %s, received %r", proc.pid, exc)
-                 # Rather than forwarding the signal to the child, we try to shut it down
-                 # gracefully. This is because we consider the script to be interactive
-                 # and special, so we give it time to cleanup before exiting.
-                 shutdown_process(proc, interrupt_timeout, terminate_timeout)
-                 if proc.returncode:
-                     raise QueryScriptCancelError(
-                         "Query script was canceled by user", return_code=proc.returncode
-                     ) from exc
-             finally:
-                 signal.signal(signal.SIGTERM, orig_sigterm_handler)
-                 signal.signal(signal.SIGINT, orig_sigint_handler)
-                 if thread:
-                     thread.join()  # wait for the reader thread
-
-         logger.info("Process %s exited with return code %s", proc.pid, proc.returncode)
-         if proc.returncode in (
-             QUERY_SCRIPT_CANCELED_EXIT_CODE,
-             QUERY_SCRIPT_SIGTERM_EXIT_CODE,
-         ):
-             raise QueryScriptCancelError(
-                 "Query script was canceled by user",
-                 return_code=proc.returncode,
-             )
-         if proc.returncode:
-             raise QueryScriptRunError(
-                 f"Query script exited with error code {proc.returncode}",
-                 return_code=proc.returncode,
-             )
-
      def cp(
          self,
          sources: list[str],
@@ -1850,7 +1762,7 @@ class Catalog:
          no_cp: bool = False,
          no_glob: bool = False,
          *,
-         client_config: Optional["dict"] = None,
+         client_config: dict | None = None,
      ) -> None:
          """
          This function copies files from cloud sources to local destination directory
@@ -1863,38 +1775,42 @@ class Catalog:
              no_glob,
              client_config=client_config,
          )
+         try:
+             always_copy_dir_contents, copy_to_filename = prepare_output_for_cp(
+                 node_groups, output, force, no_cp
+             )
+             total_size, total_files = collect_nodes_for_cp(node_groups, recursive)
+             if not total_files:
+                 return

-         always_copy_dir_contents, copy_to_filename = prepare_output_for_cp(
-             node_groups, output, force, no_cp
-         )
-         total_size, total_files = collect_nodes_for_cp(node_groups, recursive)
-         if not total_files:
-             return
-
-         desc_max_len = max(len(output) + 16, 19)
-         bar_format = (
-             "{desc:<"
-             f"{desc_max_len}"
-             "}{percentage:3.0f}%|{bar}| {n_fmt:>5}/{total_fmt:<5} "
-             "[{elapsed}<{remaining}, {rate_fmt:>8}]"
-         )
+             desc_max_len = max(len(output) + 16, 19)
+             bar_format = (
+                 "{desc:<"
+                 f"{desc_max_len}"
+                 "}{percentage:3.0f}%|{bar}| {n_fmt:>5}/{total_fmt:<5} "
+                 "[{elapsed}<{remaining}, {rate_fmt:>8}]"
+             )

-         if not no_cp:
-             with get_download_bar(bar_format, total_size) as pbar:
-                 for node_group in node_groups:
-                     node_group.download(recursive=recursive, pbar=pbar)
+             if not no_cp:
+                 with get_download_bar(bar_format, total_size) as pbar:
+                     for node_group in node_groups:
+                         node_group.download(recursive=recursive, pbar=pbar)

-         instantiate_node_groups(
-             node_groups,
-             output,
-             bar_format,
-             total_files,
-             force,
-             recursive,
-             no_cp,
-             always_copy_dir_contents,
-             copy_to_filename,
-         )
+             instantiate_node_groups(
+                 node_groups,
+                 output,
+                 bar_format,
+                 total_files,
+                 force,
+                 recursive,
+                 no_cp,
+                 always_copy_dir_contents,
+                 copy_to_filename,
+             )
+         finally:
+             for node_group in node_groups:
+                 with suppress(Exception):
+                     node_group.close()

      def du(
          self,
@@ -1904,24 +1820,26 @@ class Catalog:
          *,
          client_config=None,
      ) -> Iterable[tuple[str, float]]:
-         sources = self.enlist_sources(
+         with self.enlist_sources(
              sources,
              update,
              client_config=client_config or self.client_config,
-         )
+         ) as matched_sources:
+             if matched_sources is None:
+                 return

-         def du_dirs(src, node, subdepth):
-             if subdepth > 0:
-                 subdirs = src.listing.get_dirs_by_parent_path(node.path)
-                 for sd in subdirs:
-                     yield from du_dirs(src, sd, subdepth - 1)
-             yield (
-                 src.get_node_full_path(node),
-                 src.listing.du(node)[0],
-             )
+             def du_dirs(src, node, subdepth):
+                 if subdepth > 0:
+                     subdirs = src.listing.get_dirs_by_parent_path(node.path)
+                     for sd in subdirs:
+                         yield from du_dirs(src, sd, subdepth - 1)
+                 yield (
+                     src.get_node_full_path(node),
+                     src.listing.du(node)[0],
+                 )

-         for src in sources:
-             yield from du_dirs(src, src.node, depth)
+             for src in matched_sources:
+                 yield from du_dirs(src, src.node, depth)

      def find(
          self,
@@ -1937,39 +1855,42 @@ class Catalog:
          *,
          client_config=None,
      ) -> Iterator[str]:
-         sources = self.enlist_sources(
+         with self.enlist_sources(
              sources,
              update,
              client_config=client_config or self.client_config,
-         )
-         if not columns:
-             columns = ["path"]
-         field_set = set()
-         for column in columns:
-             if column == "du":
-                 field_set.add("dir_type")
-                 field_set.add("size")
-                 field_set.add("path")
-             elif column == "name":
-                 field_set.add("path")
-             elif column == "path":
-                 field_set.add("dir_type")
-                 field_set.add("path")
-             elif column == "size":
-                 field_set.add("size")
-             elif column == "type":
-                 field_set.add("dir_type")
-         fields = list(field_set)
-         field_lookup = {f: i for i, f in enumerate(fields)}
-         for src in sources:
-             results = src.listing.find(
-                 src.node, fields, names, inames, paths, ipaths, size, typ
-             )
-             for row in results:
-                 yield "\t".join(
-                     find_column_to_str(row, field_lookup, src, column)
-                     for column in columns
+         ) as matched_sources:
+             if matched_sources is None:
+                 return
+
+             if not columns:
+                 columns = ["path"]
+             field_set = set()
+             for column in columns:
+                 if column == "du":
+                     field_set.add("dir_type")
+                     field_set.add("size")
+                     field_set.add("path")
+                 elif column == "name":
+                     field_set.add("path")
+                 elif column == "path":
+                     field_set.add("dir_type")
+                     field_set.add("path")
+                 elif column == "size":
+                     field_set.add("size")
+                 elif column == "type":
+                     field_set.add("dir_type")
+             fields = list(field_set)
+             field_lookup = {f: i for i, f in enumerate(fields)}
+             for src in matched_sources:
+                 results = src.listing.find(
+                     src.node, fields, names, inames, paths, ipaths, size, typ
                  )
+                 for row in results:
+                     yield "\t".join(
+                         find_column_to_str(row, field_lookup, src, column)
+                         for column in columns
+                     )

      def index(
          self,
@@ -1978,9 +1899,10 @@ class Catalog:
          *,
          client_config=None,
      ) -> None:
-         self.enlist_sources(
+         with self.enlist_sources(
              sources,
              update,
              client_config=client_config or self.client_config,
              only_index=True,
-         )
+         ):
+             pass