datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/catalog/dependency.py ADDED
@@ -0,0 +1,164 @@
+ import builtins
+ from dataclasses import dataclass
+ from datetime import datetime
+ from typing import TypeVar
+
+ from datachain.dataset import DatasetDependency
+
+ DDN = TypeVar("DDN", bound="DatasetDependencyNode")
+
+
+ @dataclass
+ class DatasetDependencyNode:
+     namespace: str
+     project: str
+     id: int
+     dataset_id: int | None
+     dataset_version_id: int | None
+     dataset_name: str | None
+     dataset_version: str | None
+     created_at: datetime
+     source_dataset_id: int
+     source_dataset_version_id: int | None
+     depth: int
+
+     @classmethod
+     def parse(
+         cls: builtins.type[DDN],
+         namespace: str,
+         project: str,
+         id: int,
+         dataset_id: int | None,
+         dataset_version_id: int | None,
+         dataset_name: str | None,
+         dataset_version: str | None,
+         created_at: datetime,
+         source_dataset_id: int,
+         source_dataset_version_id: int | None,
+         depth: int,
+     ) -> "DatasetDependencyNode | None":
+         return cls(
+             namespace,
+             project,
+             id,
+             dataset_id,
+             dataset_version_id,
+             dataset_name,
+             dataset_version,
+             created_at,
+             source_dataset_id,
+             source_dataset_version_id,
+             depth,
+         )
+
+     def to_dependency(self) -> "DatasetDependency | None":
+         return DatasetDependency.parse(
+             namespace_name=self.namespace,
+             project_name=self.project,
+             id=self.id,
+             dataset_id=self.dataset_id,
+             dataset_version_id=self.dataset_version_id,
+             dataset_name=self.dataset_name,
+             dataset_version=self.dataset_version,
+             dataset_version_created_at=self.created_at,
+         )
+
+
+ def build_dependency_hierarchy(
+     dependency_nodes: list[DatasetDependencyNode | None],
+ ) -> tuple[
+     dict[int, DatasetDependency | None], dict[tuple[int, int | None], list[int]]
+ ]:
+     """
+     Build dependency hierarchy from dependency nodes.
+
+     Args:
+         dependency_nodes: List of DatasetDependencyNode objects from the database
+
+     Returns:
+         Tuple of (dependency_map, children_map) where:
+         - dependency_map: Maps dependency_id -> DatasetDependency
+         - children_map: Maps (source_dataset_id, source_version_id) ->
+           list of dependency_ids
+     """
+     dependency_map: dict[int, DatasetDependency | None] = {}
+     children_map: dict[tuple[int, int | None], list[int]] = {}
+
+     for node in dependency_nodes:
+         if node is None:
+             continue
+         dependency = node.to_dependency()
+         parent_key = (node.source_dataset_id, node.source_dataset_version_id)
+
+         if dependency is not None:
+             dependency_map[dependency.id] = dependency
+             children_map.setdefault(parent_key, []).append(dependency.id)
+         else:
+             # Handle case where dependency creation failed (e.g., deleted dependency)
+             dependency_map[node.id] = None
+             children_map.setdefault(parent_key, []).append(node.id)
+
+     return dependency_map, children_map
+
+
+ def populate_nested_dependencies(
+     dependency: DatasetDependency,
+     dependency_nodes: list[DatasetDependencyNode | None],
+     dependency_map: dict[int, DatasetDependency | None],
+     children_map: dict[tuple[int, int | None], list[int]],
+ ) -> None:
+     """
+     Recursively populate nested dependencies for a given dependency.
+
+     Args:
+         dependency: The dependency to populate nested dependencies for
+         dependency_nodes: All dependency nodes from the database
+         dependency_map: Maps dependency_id -> DatasetDependency
+         children_map: Maps (source_dataset_id, source_version_id) ->
+             list of dependency_ids
+     """
+     # Find the target dataset and version for this dependency
+     target_dataset_id, target_version_id = find_target_dataset_version(
+         dependency, dependency_nodes
+     )
+
+     if target_dataset_id is None or target_version_id is None:
+         return
+
+     # Get children for this target
+     target_key = (target_dataset_id, target_version_id)
+     if target_key not in children_map:
+         dependency.dependencies = []
+         return
+
+     child_dependency_ids = children_map[target_key]
+     child_dependencies = [dependency_map[child_id] for child_id in child_dependency_ids]
+
+     dependency.dependencies = child_dependencies
+
+     # Recursively populate children
+     for child_dependency in child_dependencies:
+         if child_dependency is not None:
+             populate_nested_dependencies(
+                 child_dependency, dependency_nodes, dependency_map, children_map
+             )
+
+
+ def find_target_dataset_version(
+     dependency: DatasetDependency,
+     dependency_nodes: list[DatasetDependencyNode | None],
+ ) -> tuple[int | None, int | None]:
+     """
+     Find the target dataset ID and version ID for a given dependency.
+
+     Args:
+         dependency: The dependency to find target for
+         dependency_nodes: All dependency nodes from the database
+
+     Returns:
+         Tuple of (target_dataset_id, target_version_id) or (None, None) if not found
+     """
+     for node in dependency_nodes:
+         if node is not None and node.id == dependency.id:
+             return node.dataset_id, node.dataset_version_id
+     return None, None
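A hypothetical usage sketch of the new helpers above (not part of the package): it wires DatasetDependencyNode rows into build_dependency_hierarchy and populate_nested_dependencies. The node values are made up, and it assumes DatasetDependency.parse returns a dependency carrying the id it is given.

    # Illustrative only: two nodes, dataset 1 v10 -> dataset 2 v20 -> dataset 3 v30
    from datetime import datetime, timezone

    from datachain.catalog.dependency import (
        DatasetDependencyNode,
        build_dependency_hierarchy,
        populate_nested_dependencies,
    )

    now = datetime.now(timezone.utc)
    nodes = [
        # depth 0: dependency row 100, owned by dataset 1 (version 10)
        DatasetDependencyNode("local", "main", 100, 2, 20, "raw", "1.0.0", now, 1, 10, 0),
        # depth 1: dependency row 101, owned by dataset 2 (version 20)
        DatasetDependencyNode("local", "main", 101, 3, 30, "source", "2.1.0", now, 2, 20, 1),
    ]

    dependency_map, children_map = build_dependency_hierarchy(nodes)

    # Direct dependencies of dataset 1, version 10:
    roots = [dependency_map[i] for i in children_map.get((1, 10), [])]
    for root in roots:
        if root is not None:
            populate_nested_dependencies(root, nodes, dependency_map, children_map)
            # root.dependencies now holds the nested "source" dependency
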
datachain/catalog/loader.py CHANGED
@@ -1,12 +1,15 @@
  import os
+ import sys
  from importlib import import_module
- from typing import TYPE_CHECKING, Any, Optional
+ from typing import TYPE_CHECKING, Any

+ from datachain.plugins import ensure_plugins_loaded
  from datachain.utils import get_envs_by_prefix

  if TYPE_CHECKING:
      from datachain.catalog import Catalog
      from datachain.data_storage import AbstractMetastore, AbstractWarehouse
+     from datachain.query.udf import AbstractUDFDistributor

  METASTORE_SERIALIZED = "DATACHAIN__METASTORE"
  METASTORE_IMPORT_PATH = "DATACHAIN_METASTORE"
@@ -14,13 +17,16 @@ METASTORE_ARG_PREFIX = "DATACHAIN_METASTORE_ARG_"
  WAREHOUSE_SERIALIZED = "DATACHAIN__WAREHOUSE"
  WAREHOUSE_IMPORT_PATH = "DATACHAIN_WAREHOUSE"
  WAREHOUSE_ARG_PREFIX = "DATACHAIN_WAREHOUSE_ARG_"
+ DISTRIBUTED_IMPORT_PYTHONPATH = "DATACHAIN_DISTRIBUTED_PYTHONPATH"
  DISTRIBUTED_IMPORT_PATH = "DATACHAIN_DISTRIBUTED"
- DISTRIBUTED_ARG_PREFIX = "DATACHAIN_DISTRIBUTED_ARG_"
+ DISTRIBUTED_DISABLED = "DATACHAIN_DISTRIBUTED_DISABLED"

  IN_MEMORY_ERROR_MESSAGE = "In-memory is only supported on SQLite"


  def get_metastore(in_memory: bool = False) -> "AbstractMetastore":
+     ensure_plugins_loaded()
+
      from datachain.data_storage import AbstractMetastore
      from datachain.data_storage.serializer import deserialize

@@ -61,6 +67,8 @@ def get_metastore(in_memory: bool = False) -> "AbstractMetastore":


  def get_warehouse(in_memory: bool = False) -> "AbstractWarehouse":
+     ensure_plugins_loaded()
+
      from datachain.data_storage import AbstractWarehouse
      from datachain.data_storage.serializer import deserialize

@@ -100,31 +108,32 @@ def get_warehouse(in_memory: bool = False) -> "AbstractWarehouse":
      return warehouse_class(**warehouse_args)


- def get_distributed_class(**kwargs):
-     distributed_import_path = os.environ.get(DISTRIBUTED_IMPORT_PATH)
-     distributed_arg_envs = get_envs_by_prefix(DISTRIBUTED_ARG_PREFIX)
-     # Convert env variable names to keyword argument names by lowercasing them
-     distributed_args = {k.lower(): v for k, v in distributed_arg_envs.items()}
+ def get_udf_distributor_class() -> type["AbstractUDFDistributor"] | None:
+     if os.environ.get(DISTRIBUTED_DISABLED) == "True":
+         return None

-     if not distributed_import_path:
-         raise RuntimeError(
-             f"{DISTRIBUTED_IMPORT_PATH} import path is required "
-             "for distributed UDF processing."
-         )
-     # Distributed class paths are specified as (for example):
-     # module.classname
+     if not (distributed_import_path := os.environ.get(DISTRIBUTED_IMPORT_PATH)):
+         return None
+
+     # Distributed class paths are specified as (for example): module.classname
      if "." not in distributed_import_path:
          raise RuntimeError(
              f"Invalid {DISTRIBUTED_IMPORT_PATH} import path: {distributed_import_path}"
          )
+
+     # Optional: set the Python path to look for the module
+     distributed_import_pythonpath = os.environ.get(DISTRIBUTED_IMPORT_PYTHONPATH)
+     if distributed_import_pythonpath and distributed_import_pythonpath not in sys.path:
+         sys.path.insert(0, distributed_import_pythonpath)
+
      module_name, _, class_name = distributed_import_path.rpartition(".")
      distributed = import_module(module_name)
-     distributed_class = getattr(distributed, class_name)
-     return distributed_class(**distributed_args | kwargs)
+     return getattr(distributed, class_name)


  def get_catalog(
-     client_config: Optional[dict[str, Any]] = None, in_memory: bool = False
+     client_config: dict[str, Any] | None = None,
+     in_memory: bool = False,
  ) -> "Catalog":
      """
      Function that creates Catalog instance with appropriate metastore
@@ -139,8 +148,9 @@ def get_catalog(
      """
      from datachain.catalog import Catalog

+     metastore = get_metastore(in_memory=in_memory)
      return Catalog(
-         metastore=get_metastore(in_memory=in_memory),
+         metastore=metastore,
          warehouse=get_warehouse(in_memory=in_memory),
          client_config=client_config,
          in_memory=in_memory,
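A minimal sketch of how the rewritten loader resolves a UDF distributor purely from environment variables. Only the env var names and the "module.classname" convention come from the diff above; the commented-out module path is hypothetical.

    import os

    from datachain.catalog.loader import get_udf_distributor_class

    # With no configuration, or with the feature explicitly disabled, the new
    # function reports "no distributor" instead of raising like the old
    # get_distributed_class did:
    os.environ.pop("DATACHAIN_DISTRIBUTED", None)
    assert get_udf_distributor_class() is None

    os.environ["DATACHAIN_DISTRIBUTED_DISABLED"] = "True"
    assert get_udf_distributor_class() is None

    # To enable distributed UDF processing, point DATACHAIN_DISTRIBUTED at a
    # "module.classname" path; DATACHAIN_DISTRIBUTED_PYTHONPATH is optionally
    # prepended to sys.path before the import. Values below are hypothetical.
    # os.environ["DATACHAIN_DISTRIBUTED_PYTHONPATH"] = "/opt/plugins"
    # os.environ["DATACHAIN_DISTRIBUTED"] = "my_distributor.Distributor"
    # distributor_cls = get_udf_distributor_class()  # returns the class, not an instance
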
datachain/checkpoint.py ADDED
@@ -0,0 +1,43 @@
+ import uuid
+ from dataclasses import dataclass
+ from datetime import datetime
+
+
+ @dataclass
+ class Checkpoint:
+     """
+     Represents a checkpoint within a job run.
+
+     A checkpoint marks a successfully completed stage of execution. In the event
+     of a failure, the job can resume from the most recent checkpoint rather than
+     starting over from the beginning.
+
+     Checkpoints can also be created in a "partial" mode, which indicates that the
+     work at this stage was only partially completed. For example, if a failure
+     occurs halfway through running a UDF, already computed results can still be
+     saved, allowing the job to resume from that partially completed state on
+     restart.
+     """
+
+     id: str
+     job_id: str
+     hash: str
+     partial: bool
+     created_at: datetime
+
+     @classmethod
+     def parse(
+         cls,
+         id: str | uuid.UUID,
+         job_id: str,
+         _hash: str,
+         partial: bool,
+         created_at: datetime,
+     ) -> "Checkpoint":
+         return cls(
+             str(id),
+             job_id,
+             _hash,
+             bool(partial),
+             created_at,
+         )
datachain/cli/__init__.py CHANGED
@@ -3,7 +3,6 @@ import os
  import sys
  import traceback
  from multiprocessing import freeze_support
- from typing import Optional

  from datachain.cli.utils import get_logging_level

@@ -16,7 +15,6 @@ from .commands import (
      index,
      list_datasets,
      ls,
-     query,
      rm_dataset,
      show,
  )
@@ -25,7 +23,7 @@ from .parser import get_parser
  logger = logging.getLogger("datachain")


- def main(argv: Optional[list[str]] = None) -> int:
+ def main(argv: list[str] | None = None) -> int:
      from datachain.catalog import get_catalog

      # Required for Windows multiprocessing support
@@ -34,8 +32,10 @@ def main(argv: Optional[list[str]] = None) -> int:
      datachain_parser = get_parser()
      args = datachain_parser.parse_args(argv)

-     if args.command in ("internal-run-udf", "internal-run-udf-worker"):
-         return handle_udf(args.command)
+     if args.command == "internal-run-udf":
+         return handle_udf()
+     if args.command == "internal-run-udf-worker":
+         return handle_udf_runner()

      if args.command is None:
          datachain_parser.print_help(sys.stderr)
@@ -59,16 +59,22 @@ def main(argv: Optional[list[str]] = None) -> int:

      error = None

+     catalog = None
      try:
          catalog = get_catalog(client_config=client_config)
          return handle_command(args, catalog, client_config)
      except BrokenPipeError as exc:
          error, return_code = handle_broken_pipe_error(exc)
          return return_code
-     except (KeyboardInterrupt, Exception) as exc:
+     except (KeyboardInterrupt, Exception) as exc:  # noqa: BLE001
          error, return_code = handle_general_exception(exc, args, logging_level)
          return return_code
      finally:
+         if catalog is not None:
+             try:
+                 catalog.close()
+             except Exception:
+                 logger.exception("Failed to close catalog")
          from datachain.telemetry import telemetry

          telemetry.send_cli_call(args.command, error=error)
@@ -89,7 +95,6 @@ def handle_command(args, catalog, client_config) -> int:
          "find": lambda: handle_find_command(args, catalog),
          "index": lambda: handle_index_command(args, catalog),
          "completion": lambda: handle_completion_command(args),
-         "query": lambda: handle_query_command(args, catalog),
          "clear-cache": lambda: clear_cache(catalog),
          "gc": lambda: garbage_collect(catalog),
          "auth": lambda: process_auth_cli_args(args),
@@ -98,8 +103,10 @@ def handle_command(args, catalog, client_config) -> int:

      handler = command_handlers.get(args.command)
      if handler:
-         handler()
-         return 0
+         return_code = handler()
+         if return_code is None:
+             return 0
+         return return_code
      print(f"invalid command: {args.command}", file=sys.stderr)
      return 1

@@ -149,10 +156,7 @@ def handle_dataset_command(args, catalog):
              args.name,
              new_name=args.new_name,
              description=args.description,
-             labels=args.labels,
-             studio=args.studio,
-             local=args.local,
-             all=args.all,
+             attrs=args.attrs,
              team=args.team,
          ),
          "ls": lambda: list_datasets(
@@ -170,8 +174,6 @@ def handle_dataset_command(args, catalog):
              version=args.version,
              force=args.force,
              studio=args.studio,
-             local=args.local,
-             all=args.all,
              team=args.team,
          ),
          "remove": lambda: rm_dataset(
@@ -180,8 +182,6 @@ def handle_dataset_command(args, catalog):
              version=args.version,
              force=args.force,
              studio=args.studio,
-             local=args.local,
-             all=args.all,
              team=args.team,
          ),
      }
@@ -263,15 +263,6 @@ def handle_completion_command(args):
      print(completion(args.shell))


- def handle_query_command(args, catalog):
-     query(
-         catalog,
-         args.script,
-         parallel=args.parallel,
-         params=args.param,
-     )
-
-
  def handle_broken_pipe_error(exc):
      # Python flushes standard streams on exit; redirect remaining output
      # to devnull to avoid another BrokenPipeError at shutdown
@@ -303,13 +294,13 @@ def handle_general_exception(exc, args, logging_level):
      return error, 1


- def handle_udf(command):
-     if command == "internal-run-udf":
-         from datachain.query.dispatch import udf_entrypoint
+ def handle_udf() -> int:
+     from datachain.query.dispatch import udf_entrypoint
+
+     return udf_entrypoint()

-         return udf_entrypoint()

-     if command == "internal-run-udf-worker":
-         from datachain.query.dispatch import udf_worker_entrypoint
+ def handle_udf_runner() -> int:
+     from datachain.query.dispatch import udf_worker_entrypoint

-         return udf_worker_entrypoint()
+     return udf_worker_entrypoint()
datachain/cli/commands/__init__.py CHANGED
@@ -1,14 +1,8 @@
- from .datasets import (
-     edit_dataset,
-     list_datasets,
-     list_datasets_local,
-     rm_dataset,
- )
+ from .datasets import edit_dataset, list_datasets, list_datasets_local, rm_dataset
  from .du import du
  from .index import index
  from .ls import ls
  from .misc import clear_cache, completion, garbage_collect
- from .query import query
  from .show import show

  __all__ = [
@@ -21,7 +15,6 @@ __all__ = [
      "list_datasets",
      "list_datasets_local",
      "ls",
-     "query",
      "rm_dataset",
      "show",
  ]
datachain/cli/commands/datasets.py CHANGED
@@ -1,29 +1,41 @@
  import sys
- from typing import TYPE_CHECKING, Optional
+ from collections.abc import Iterable, Iterator
+ from typing import TYPE_CHECKING

  from tabulate import tabulate

- if TYPE_CHECKING:
-     from datachain.catalog import Catalog
-
+ from datachain import semver
+ from datachain.catalog import is_namespace_local
  from datachain.cli.utils import determine_flavors
  from datachain.config import Config
- from datachain.error import DatasetNotFoundError
+ from datachain.error import DataChainError, DatasetNotFoundError
  from datachain.studio import list_datasets as list_datasets_studio

+ if TYPE_CHECKING:
+     from datachain.catalog import Catalog
+
+
+ def group_dataset_versions(
+     datasets: Iterable[tuple[str, str]], latest_only=True
+ ) -> dict[str, str | list[str]]:
+     grouped: dict[str, list[tuple[int, int, int]]] = {}

- def group_dataset_versions(datasets, latest_only=True):
-     grouped = {}
      # Sort to ensure groupby works as expected
      # (groupby expects consecutive items with the same key)
      for name, version in sorted(datasets):
-         grouped.setdefault(name, []).append(version)
+         grouped.setdefault(name, []).append(semver.parse(version))

      if latest_only:
          # For each dataset name, pick the highest version.
-         return {name: max(versions) for name, versions in grouped.items()}
+         return {
+             name: semver.create(*(max(versions))) for name, versions in grouped.items()
+         }
+
      # For each dataset name, return a sorted list of unique versions.
-     return {name: sorted(set(versions)) for name, versions in grouped.items()}
+     return {
+         name: [semver.create(*v) for v in sorted(set(versions))]
+         for name, versions in grouped.items()
+     }


  def list_datasets(
@@ -31,10 +43,10 @@ def list_datasets(
      studio: bool = False,
      local: bool = False,
      all: bool = True,
-     team: Optional[str] = None,
+     team: str | None = None,
      latest_only: bool = True,
-     name: Optional[str] = None,
- ):
+     name: str | None = None,
+ ) -> None:
      token = Config().read().get("studio", {}).get("token")
      all, local, studio = determine_flavors(studio, local, all, token)
      if name:
@@ -94,23 +106,31 @@ def list_datasets(
      print(tabulate(rows, headers="keys"))


- def list_datasets_local(catalog: "Catalog", name: Optional[str] = None):
+ def list_datasets_local(
+     catalog: "Catalog", name: str | None = None
+ ) -> Iterator[tuple[str, str]]:
      if name:
          yield from list_datasets_local_versions(catalog, name)
          return

      for d in catalog.ls_datasets():
          for v in d.versions:
-             yield (d.name, v.version)
+             yield d.full_name, v.version


- def list_datasets_local_versions(catalog: "Catalog", name: str):
-     ds = catalog.get_dataset(name)
+ def list_datasets_local_versions(
+     catalog: "Catalog", name: str
+ ) -> Iterator[tuple[str, str]]:
+     namespace_name, project_name, name = catalog.get_full_dataset_name(name)
+
+     ds = catalog.get_dataset(
+         name, namespace_name=namespace_name, project_name=project_name
+     )
      for v in ds.versions:
-         yield (name, v.version)
+         yield name, v.version


- def _datasets_tabulate_row(name, both, local_version, studio_version):
+ def _datasets_tabulate_row(name, both, local_version, studio_version) -> dict[str, str]:
      row = {
          "Name": name,
      }
@@ -127,49 +147,60 @@ def _datasets_tabulate_row(name, both, local_version, studio_version):
  def rm_dataset(
      catalog: "Catalog",
      name: str,
-     version: Optional[int] = None,
-     force: Optional[bool] = False,
-     studio: bool = False,
-     local: bool = False,
-     all: bool = True,
-     team: Optional[str] = None,
- ):
-     from datachain.studio import remove_studio_dataset
-
-     token = Config().read().get("studio", {}).get("token")
-     all, local, studio = determine_flavors(studio, local, all, token)
-
-     if all or local:
+     version: str | None = None,
+     force: bool | None = False,
+     studio: bool | None = False,
+     team: str | None = None,
+ ) -> None:
+     namespace_name, project_name, name = catalog.get_full_dataset_name(name)
+
+     if studio:
+         # removing Studio dataset from CLI
+         from datachain.studio import remove_studio_dataset
+
+         if Config().read().get("studio", {}).get("token"):
+             remove_studio_dataset(
+                 team, name, namespace_name, project_name, version, force
+             )
+         else:
+             raise DataChainError(
+                 "Not logged in to Studio. Log in with 'datachain auth login'."
+             )
+     else:
          try:
-             catalog.remove_dataset(name, version=version, force=force)
+             project = catalog.metastore.get_project(project_name, namespace_name)
+             catalog.remove_dataset(name, project, version=version, force=force)
          except DatasetNotFoundError:
              print("Dataset not found in local", file=sys.stderr)

-     if (all or studio) and token:
-         remove_studio_dataset(team, name, version, force)
-

  def edit_dataset(
      catalog: "Catalog",
      name: str,
-     new_name: Optional[str] = None,
-     description: Optional[str] = None,
-     labels: Optional[list[str]] = None,
-     studio: bool = False,
-     local: bool = False,
-     all: bool = True,
-     team: Optional[str] = None,
- ):
-     from datachain.studio import edit_studio_dataset
+     new_name: str | None = None,
+     description: str | None = None,
+     attrs: list[str] | None = None,
+     team: str | None = None,
+ ) -> None:
+     from datachain.lib.dc.utils import is_studio

-     token = Config().read().get("studio", {}).get("token")
-     all, local, studio = determine_flavors(studio, local, all, token)
+     namespace_name, project_name, name = catalog.get_full_dataset_name(name)

-     if all or local:
+     if is_studio() or is_namespace_local(namespace_name):
          try:
-             catalog.edit_dataset(name, new_name, description, labels)
+             catalog.edit_dataset(
+                 name, catalog.metastore.default_project, new_name, description, attrs
+             )
          except DatasetNotFoundError:
              print("Dataset not found in local", file=sys.stderr)
-
-     if (all or studio) and token:
-         edit_studio_dataset(team, name, new_name, description, labels)
+     else:
+         from datachain.studio import edit_studio_dataset
+
+         if Config().read().get("studio", {}).get("token"):
+             edit_studio_dataset(
+                 team, name, namespace_name, project_name, new_name, description, attrs
+             )
+         else:
+             raise DataChainError(
+                 "Not logged in to Studio. Log in with 'datachain auth login'."
+             )
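An illustrative sketch of the reworked group_dataset_versions shown above: versions are now compared as parsed (major, minor, patch) tuples rather than as raw values, so "1.10.0" correctly outranks "1.2.0". The dataset names and versions are made up, and the expected output assumes datachain.semver.create renders plain "major.minor.patch" strings.

    from datachain.cli.commands.datasets import group_dataset_versions

    versions = [
        ("cats", "1.2.0"),
        ("cats", "1.10.0"),  # numerically newer than 1.2.0, lexically "smaller"
        ("dogs", "0.1.0"),
    ]

    print(group_dataset_versions(versions))
    # expected: {"cats": "1.10.0", "dogs": "0.1.0"}

    print(group_dataset_versions(versions, latest_only=False))
    # expected: {"cats": ["1.2.0", "1.10.0"], "dogs": ["0.1.0"]}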