datachain 0.11.11__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain might be problematic.

Files changed (44)
  1. datachain/catalog/catalog.py +39 -7
  2. datachain/catalog/loader.py +19 -13
  3. datachain/cli/__init__.py +2 -1
  4. datachain/cli/commands/ls.py +8 -6
  5. datachain/cli/commands/show.py +7 -0
  6. datachain/cli/parser/studio.py +13 -1
  7. datachain/client/fsspec.py +12 -16
  8. datachain/client/gcs.py +1 -1
  9. datachain/client/hf.py +36 -14
  10. datachain/client/local.py +1 -4
  11. datachain/client/s3.py +1 -1
  12. datachain/data_storage/metastore.py +6 -0
  13. datachain/data_storage/warehouse.py +3 -8
  14. datachain/dataset.py +8 -0
  15. datachain/error.py +0 -12
  16. datachain/fs/utils.py +30 -0
  17. datachain/func/__init__.py +5 -0
  18. datachain/func/func.py +2 -1
  19. datachain/lib/dc.py +59 -15
  20. datachain/lib/file.py +63 -18
  21. datachain/lib/image.py +30 -6
  22. datachain/lib/listing.py +21 -39
  23. datachain/lib/meta_formats.py +2 -2
  24. datachain/lib/signal_schema.py +65 -18
  25. datachain/lib/udf.py +3 -0
  26. datachain/lib/udf_signature.py +17 -9
  27. datachain/lib/video.py +7 -5
  28. datachain/model/bbox.py +209 -58
  29. datachain/model/pose.py +49 -37
  30. datachain/model/segment.py +22 -18
  31. datachain/model/ultralytics/bbox.py +9 -9
  32. datachain/model/ultralytics/pose.py +7 -7
  33. datachain/model/ultralytics/segment.py +7 -7
  34. datachain/model/utils.py +191 -0
  35. datachain/query/dataset.py +8 -2
  36. datachain/sql/sqlite/base.py +2 -2
  37. datachain/studio.py +8 -6
  38. datachain/utils.py +0 -16
  39. {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/METADATA +4 -2
  40. {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/RECORD +44 -42
  41. {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/WHEEL +1 -1
  42. {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/LICENSE +0 -0
  43. {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/entry_points.txt +0 -0
  44. {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/top_level.txt +0 -0
datachain/catalog/catalog.py CHANGED
@@ -25,7 +25,6 @@ from typing import (
 )
 from uuid import uuid4

-import requests
 import sqlalchemy as sa
 from sqlalchemy import Column
 from tqdm.auto import tqdm
@@ -54,7 +53,6 @@ from datachain.error import (
 from datachain.lib.listing import get_listing
 from datachain.node import DirType, Node, NodeWithPath
 from datachain.nodes_thread_pool import NodesThreadPool
-from datachain.remote.studio import StudioClient
 from datachain.sql.types import DateTime, SQLType
 from datachain.utils import DataChainDir

@@ -162,6 +160,8 @@ class DatasetRowsFetcher(NodesThreadPool):
         max_threads: int = PULL_DATASET_MAX_THREADS,
         progress_bar=None,
     ):
+        from datachain.remote.studio import StudioClient
+
         super().__init__(max_threads)
         self._check_dependencies()
         self.metastore = metastore
@@ -234,6 +234,8 @@ class DatasetRowsFetcher(NodesThreadPool):
         return df.drop("sys__id", axis=1)

     def get_parquet_content(self, url: str):
+        import requests
+
         while True:
             if self.should_check_for_status():
                 self.check_for_status()
@@ -775,6 +777,8 @@ class Catalog:
         validate_version: Optional[bool] = True,
         listing: Optional[bool] = False,
         uuid: Optional[str] = None,
+        description: Optional[str] = None,
+        labels: Optional[list[str]] = None,
     ) -> "DatasetRecord":
         """
         Creates new dataset of a specific version.
@@ -801,6 +805,8 @@ class Catalog:
             query_script=query_script,
             schema=schema,
             ignore_if_exists=True,
+            description=description,
+            labels=labels,
         )

         version = version or default_version
@@ -1130,6 +1136,8 @@ class Catalog:
             raise DatasetNotFoundError(f"Dataset with version uuid {uuid} not found.")

     def get_remote_dataset(self, name: str) -> DatasetRecord:
+        from datachain.remote.studio import StudioClient
+
         studio_client = StudioClient()

         info_response = studio_client.dataset_info(name)
@@ -1164,8 +1172,27 @@ class Catalog:

         return direct_dependencies

-    def ls_datasets(self, include_listing: bool = False) -> Iterator[DatasetListRecord]:
-        datasets = self.metastore.list_datasets()
+    def ls_datasets(
+        self, include_listing: bool = False, studio: bool = False
+    ) -> Iterator[DatasetListRecord]:
+        from datachain.remote.studio import StudioClient
+
+        if studio:
+            client = StudioClient()
+            response = client.ls_datasets()
+            if not response.ok:
+                raise DataChainError(response.message)
+            if not response.data:
+                return
+
+            datasets: Iterator[DatasetListRecord] = (
+                DatasetListRecord.from_dict(d)
+                for d in response.data
+                if not d.get("name", "").startswith(QUERY_DATASET_PREFIX)
+            )
+        else:
+            datasets = self.metastore.list_datasets()
+
         for d in datasets:
             if not d.is_bucket_listing or include_listing:
                 yield d
@@ -1173,9 +1200,12 @@ class Catalog:
     def list_datasets_versions(
         self,
         include_listing: bool = False,
+        studio: bool = False,
     ) -> Iterator[tuple[DatasetListRecord, "DatasetListVersion", Optional["Job"]]]:
         """Iterate over all dataset versions with related jobs."""
-        datasets = list(self.ls_datasets(include_listing=include_listing))
+        datasets = list(
+            self.ls_datasets(include_listing=include_listing, studio=studio)
+        )

         # preselect dataset versions jobs from db to avoid multiple queries
         jobs_ids: set[str] = {
@@ -1345,6 +1375,8 @@ class Catalog:
         if cp and not output:
             raise ValueError("Please provide output directory for instantiation")

+        from datachain.remote.studio import StudioClient
+
         studio_client = StudioClient()

         try:
@@ -1580,7 +1612,7 @@ class Catalog:
         except TerminationSignal as exc:
             signal.signal(signal.SIGTERM, orig_sigterm_handler)
             signal.signal(signal.SIGINT, orig_sigint_handler)
-            logging.info("Shutting down process %s, received %r", proc.pid, exc)
+            logger.info("Shutting down process %s, received %r", proc.pid, exc)
             # Rather than forwarding the signal to the child, we try to shut it down
             # gracefully. This is because we consider the script to be interactive
             # and special, so we give it time to cleanup before exiting.
@@ -1595,7 +1627,7 @@ class Catalog:
         if thread:
             thread.join()  # wait for the reader thread

-        logging.info("Process %s exited with return code %s", proc.pid, proc.returncode)
+        logger.info("Process %s exited with return code %s", proc.pid, proc.returncode)
         if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE:
             raise QueryScriptCancelError(
                 "Query script was canceled by user",
datachain/catalog/loader.py CHANGED
@@ -1,19 +1,13 @@
 import os
 from importlib import import_module
-from typing import Any, Optional
-
-from datachain.catalog import Catalog
-from datachain.data_storage import (
-    AbstractMetastore,
-    AbstractWarehouse,
-)
-from datachain.data_storage.serializer import deserialize
-from datachain.data_storage.sqlite import (
-    SQLiteMetastore,
-    SQLiteWarehouse,
-)
+from typing import TYPE_CHECKING, Any, Optional
+
 from datachain.utils import get_envs_by_prefix

+if TYPE_CHECKING:
+    from datachain.catalog import Catalog
+    from datachain.data_storage import AbstractMetastore, AbstractWarehouse
+
 METASTORE_SERIALIZED = "DATACHAIN__METASTORE"
 METASTORE_IMPORT_PATH = "DATACHAIN_METASTORE"
 METASTORE_ARG_PREFIX = "DATACHAIN_METASTORE_ARG_"
@@ -27,6 +21,9 @@ IN_MEMORY_ERROR_MESSAGE = "In-memory is only supported on SQLite"


 def get_metastore(in_memory: bool = False) -> "AbstractMetastore":
+    from datachain.data_storage import AbstractMetastore
+    from datachain.data_storage.serializer import deserialize
+
     metastore_serialized = os.environ.get(METASTORE_SERIALIZED)
     if metastore_serialized:
         metastore_obj = deserialize(metastore_serialized)
@@ -45,6 +42,8 @@ def get_metastore(in_memory: bool = False) -> "AbstractMetastore":
     }

     if not metastore_import_path:
+        from datachain.data_storage.sqlite import SQLiteMetastore
+
         metastore_args["in_memory"] = in_memory
         return SQLiteMetastore(**metastore_args)
     if in_memory:
@@ -62,6 +61,9 @@ def get_metastore(in_memory: bool = False) -> "AbstractMetastore":


 def get_warehouse(in_memory: bool = False) -> "AbstractWarehouse":
+    from datachain.data_storage import AbstractWarehouse
+    from datachain.data_storage.serializer import deserialize
+
     warehouse_serialized = os.environ.get(WAREHOUSE_SERIALIZED)
     if warehouse_serialized:
         warehouse_obj = deserialize(warehouse_serialized)
@@ -80,6 +82,8 @@ def get_warehouse(in_memory: bool = False) -> "AbstractWarehouse":
     }

     if not warehouse_import_path:
+        from datachain.data_storage.sqlite import SQLiteWarehouse
+
         warehouse_args["in_memory"] = in_memory
         return SQLiteWarehouse(**warehouse_args)
     if in_memory:
@@ -121,7 +125,7 @@ def get_distributed_class(**kwargs):

 def get_catalog(
     client_config: Optional[dict[str, Any]] = None, in_memory: bool = False
-) -> Catalog:
+) -> "Catalog":
     """
     Function that creates Catalog instance with appropriate metastore
     and warehouse classes. Metastore class can be provided with env variable
@@ -133,6 +137,8 @@ def get_catalog(
     and name of variable after, e.g. if it accepts team_id as kwargs
     we can provide DATACHAIN_METASTORE_ARG_TEAM_ID=12345 env variable.
     """
+    from datachain.catalog import Catalog
+
     return Catalog(
         metastore=get_metastore(in_memory=in_memory),
         warehouse=get_warehouse(in_memory=in_memory),
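The loader.py changes follow the same lazy-import pattern used throughout this release: heavy modules are imported inside the functions that need them, while a TYPE_CHECKING block keeps the string annotations resolvable for type checkers. A generic sketch of the pattern (heavy_module and HeavyClass are hypothetical placeholders, not datachain names):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime.
    from heavy_module import HeavyClass


def build() -> "HeavyClass":
    # Deferred import: the cost is paid on the first call instead of at module
    # import time, which keeps `import datachain` and CLI startup fast.
    from heavy_module import HeavyClass

    return HeavyClass()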
datachain/cli/__init__.py CHANGED
@@ -6,7 +6,6 @@ from multiprocessing import freeze_support
 from typing import Optional

 from datachain.cli.utils import get_logging_level
-from datachain.telemetry import telemetry

 from .commands import (
     clear_cache,
@@ -70,6 +69,8 @@ def main(argv: Optional[list[str]] = None) -> int:
         error, return_code = handle_general_exception(exc, args, logging_level)
         return return_code
     finally:
+        from datachain.telemetry import telemetry
+
         telemetry.send_cli_call(args.command, error=error)


datachain/cli/commands/ls.py CHANGED
@@ -38,11 +38,12 @@
 ):
     from datachain import DataChain

-    if catalog is None:
-        from datachain.catalog import get_catalog
-
-        catalog = get_catalog(client_config=client_config)
     if sources:
+        if catalog is None:
+            from datachain.catalog import get_catalog
+
+            catalog = get_catalog(client_config=client_config)
+
         actual_sources = list(ls_urls(sources, catalog=catalog, long=long, **kwargs))
         if len(actual_sources) == 1:
             for _, entries in actual_sources:
@@ -61,8 +62,9 @@
             for entry in entries:
                 print(format_ls_entry(entry))
     else:
-        chain = DataChain.listings()
-        for ls in chain.collect("listing"):
+        # Collect results in a list here to prevent interference from `tqdm` and `print`
+        listing = list(DataChain.listings().collect("listing"))
+        for ls in listing:
             print(format_ls_entry(f"{ls.uri}@v{ls.version}"))  # type: ignore[union-attr]


datachain/cli/commands/show.py CHANGED
@@ -40,6 +40,13 @@ def show(
         .offset(offset)
     )
     records = query.to_db_records()
+    print("Name: ", name)
+    if dataset.description:
+        print("Description: ", dataset.description)
+    if dataset.labels:
+        print("Labels: ", ",".join(dataset.labels))
+    print("\n")
+
     show_records(records, collapse_columns=not no_collapse, hidden_fields=hidden_fields)

     if schema and dataset_version.feature_schema:
datachain/cli/parser/studio.py CHANGED
@@ -63,19 +63,31 @@ def add_auth_parser(subparsers, parent_parser) -> None:
         default=False,
         help="Use code-based authentication without browser",
     )
+    login_parser.add_argument(
+        "--local",
+        action="store_true",
+        default=False,
+        help="Save the token in the local project config",
+    )

     auth_logout_help = "Log out from Studio"
     auth_logout_description = (
         "Remove the Studio authentication token from global config."
     )

-    auth_subparser.add_parser(
+    logout_parser = auth_subparser.add_parser(
         "logout",
         parents=[parent_parser],
         description=auth_logout_description,
         help=auth_logout_help,
         formatter_class=CustomHelpFormatter,
     )
+    logout_parser.add_argument(
+        "--local",
+        action="store_true",
+        default=False,
+        help="Remove the token from the local project config",
+    )

     auth_team_help = "Set default team for Studio operations"
     auth_team_description = "Set the default team for Studio operations."
datachain/client/fsspec.py CHANGED
@@ -17,10 +17,10 @@ from typing import (
     ClassVar,
     NamedTuple,
     Optional,
+    Union,
 )
 from urllib.parse import urlparse

-from botocore.exceptions import ClientError
 from dvc_objects.fs.system import reflink
 from fsspec.asyn import get_loop, sync
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
@@ -28,7 +28,6 @@ from tqdm.auto import tqdm

 from datachain.cache import Cache
 from datachain.client.fileslice import FileWrapper
-from datachain.error import ClientError as DataChainClientError
 from datachain.nodes_fetcher import NodesFetcher
 from datachain.nodes_thread_pool import NodeChunk

@@ -83,19 +82,17 @@ class Client(ABC):
         self.uri = self.get_uri(self.name)

     @staticmethod
-    def get_implementation(url: str) -> type["Client"]:
+    def get_implementation(url: Union[str, os.PathLike[str]]) -> type["Client"]:
         from .azure import AzureClient
         from .gcs import GCSClient
         from .hf import HfClient
         from .local import FileClient
         from .s3 import ClientS3

-        protocol = urlparse(url).scheme
+        protocol = urlparse(str(url)).scheme

-        if not protocol or _is_win_local_path(url):
+        if not protocol or _is_win_local_path(str(url)):
             return FileClient
-
-        protocol = protocol.lower()
         if protocol == ClientS3.protocol:
             return ClientS3
         if protocol == GCSClient.protocol:
@@ -121,9 +118,11 @@ class Client(ABC):
         return cls.get_uri(storage_name), rel_path

     @staticmethod
-    def get_client(source: str, cache: Cache, **kwargs) -> "Client":
+    def get_client(
+        source: Union[str, os.PathLike[str]], cache: Cache, **kwargs
+    ) -> "Client":
         cls = Client.get_implementation(source)
-        storage_url, _ = cls.split_url(source)
+        storage_url, _ = cls.split_url(str(source))
         if os.name == "nt":
             storage_url = storage_url.removeprefix("/")

@@ -209,7 +208,7 @@

     async def get_current_etag(self, file: "File") -> str:
         kwargs = {}
-        if self.fs.version_aware:
+        if getattr(self.fs, "version_aware", False):
             kwargs["version_id"] = file.version
         info = await self.fs._info(
             self.get_full_path(file.path, file.version), **kwargs
@@ -286,11 +285,6 @@
                 worker.cancel()
             if excs:
                 raise excs[0]
-        except ClientError as exc:
-            raise DataChainClientError(
-                exc.response.get("Error", {}).get("Message") or exc,
-                exc.response.get("Error", {}).get("Code"),
-            ) from exc
         finally:
             # This ensures the progress bar is closed before any exceptions are raised
             progress_bar.close()
@@ -333,7 +327,9 @@
         return not (key.startswith("/") or key.endswith("/") or "//" in key)

     async def ls_dir(self, path):
-        return await self.fs._ls(path, detail=True, versions=True)
+        if getattr(self.fs, "version_aware", False):
+            kwargs = {"versions": True}
+        return await self.fs._ls(path, detail=True, **kwargs)

     def rel_path(self, path: str) -> str:
         return self.fs.split_path(path)[1]
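The switch from self.fs.version_aware to getattr(self.fs, "version_aware", False) guards against filesystem implementations that simply do not define the attribute, such as a wrapped filesystem. A tiny illustration of why the default matters (the class below is a hypothetical stand-in, not a datachain type):

class _PlainFS:
    # Stand-in filesystem that does not define `version_aware`.
    pass


fs = _PlainFS()
# fs.version_aware would raise AttributeError; getattr with a default does not.
kwargs = {"versions": True} if getattr(fs, "version_aware", False) else {}
print(kwargs)  # -> {}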
datachain/client/gcs.py CHANGED
@@ -30,7 +30,7 @@ class GCSClient(Client):
         if kwargs.pop("anon", False):
             kwargs["token"] = "anon"  # noqa: S105

-        return cast(GCSFileSystem, super().create_fs(**kwargs))
+        return cast("GCSFileSystem", super().create_fs(**kwargs))

     def url(self, path: str, expires: int = 3600, **kwargs) -> str:
         """
datachain/client/hf.py CHANGED
@@ -1,25 +1,50 @@
-import os
+import functools
 import posixpath
-from typing import Any, cast
-
-from huggingface_hub import HfFileSystem
+from typing import Any

 from datachain.lib.file import File

 from .fsspec import Client


+class classproperty:  # noqa: N801
+    def __init__(self, func):
+        self.fget = func
+
+    def __get__(self, instance, owner):
+        return self.fget(owner)
+
+
+@functools.cache
+def get_hf_filesystem_cls():
+    import fsspec
+    from packaging.version import Version, parse
+
+    fsspec_version = parse(fsspec.__version__)
+    minver = Version("2024.12.0")
+
+    if fsspec_version < minver:
+        raise ImportError(
+            f"datachain requires 'fsspec>={minver}' but version "
+            f"{fsspec_version} is installed."
+        )
+
+    from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper
+    from huggingface_hub import HfFileSystem
+
+    fs_cls = AsyncFileSystemWrapper.wrap_class(HfFileSystem)
+    # AsyncFileSystemWrapper does not set class properties, so we need to set them back.
+    fs_cls.protocol = HfFileSystem.protocol
+    return fs_cls
+
+
 class HfClient(Client):
-    FS_CLASS = HfFileSystem
     PREFIX = "hf://"
     protocol = "hf"

-    @classmethod
-    def create_fs(cls, **kwargs) -> HfFileSystem:
-        if os.environ.get("HF_TOKEN"):
-            kwargs["token"] = os.environ["HF_TOKEN"]
-
-        return cast(HfFileSystem, super().create_fs(**kwargs))
+    @classproperty
+    def FS_CLASS(cls):  # noqa: N802, N805
+        return get_hf_filesystem_cls()

     def info_to_file(self, v: dict[str, Any], path: str) -> File:
         return File(
@@ -31,8 +56,5 @@ class HfClient(Client):
             last_modified=v["last_commit"].date,
         )

-    async def ls_dir(self, path):
-        return self.fs.ls(path, detail=True)
-
     def rel_path(self, path):
         return posixpath.relpath(path, self.name)
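The classproperty descriptor combined with functools.cache turns FS_CLASS into a lazily computed, memoized class attribute: the fsspec version check and the AsyncFileSystemWrapper wrapping only run on first access. A self-contained sketch of the same pattern (the names below are illustrative, not datachain APIs):

import functools


class classproperty:  # same descriptor shape as in the hunk above
    def __init__(self, func):
        self.fget = func

    def __get__(self, instance, owner):
        return self.fget(owner)


@functools.cache
def _expensive_lookup():
    print("computed once")
    return dict  # stand-in for the wrapped filesystem class


class Demo:
    @classproperty
    def FS_CLASS(cls):  # noqa: N802, N805
        return _expensive_lookup()


assert Demo.FS_CLASS is Demo.FS_CLASS  # "computed once" is printed a single time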
datachain/client/local.py CHANGED
@@ -67,10 +67,7 @@ class FileClient(Client):
     @classmethod
     def split_url(cls, url: str) -> tuple[str, str]:
         parsed = urlparse(url)
-        if parsed.scheme == "file":
-            scheme, rest = url.split(":", 1)
-            url = f"{scheme.lower()}:{rest}"
-        else:
+        if parsed.scheme != "file":
             url = cls.path_to_uri(url)

         fill_path = url[len(cls.PREFIX) :]
datachain/client/s3.py CHANGED
@@ -55,7 +55,7 @@ class ClientS3(Client):
         except NotImplementedError:
             pass

-        return cast(S3FileSystem, super().create_fs(**kwargs))
+        return cast("S3FileSystem", super().create_fs(**kwargs))

     def url(self, path: str, expires: int = 3600, **kwargs) -> str:
         """
datachain/data_storage/metastore.py CHANGED
@@ -119,6 +119,8 @@ class AbstractMetastore(ABC, Serializable):
         query_script: str = "",
         schema: Optional[dict[str, Any]] = None,
         ignore_if_exists: bool = False,
+        description: Optional[str] = None,
+        labels: Optional[list[str]] = None,
     ) -> DatasetRecord:
         """Creates new dataset."""

@@ -518,6 +520,8 @@ class AbstractDBMetastore(AbstractMetastore):
         query_script: str = "",
         schema: Optional[dict[str, Any]] = None,
         ignore_if_exists: bool = False,
+        description: Optional[str] = None,
+        labels: Optional[list[str]] = None,
         **kwargs,  # TODO registered = True / False
     ) -> DatasetRecord:
         """Creates new dataset."""
@@ -533,6 +537,8 @@
             sources="\n".join(sources) if sources else "",
             query_script=query_script,
             schema=json.dumps(schema or {}),
+            description=description,
+            labels=json.dumps(labels or []),
         )
         if ignore_if_exists and hasattr(query, "on_conflict_do_nothing"):
             # SQLite and PostgreSQL both support 'on_conflict_do_nothing',
datachain/data_storage/warehouse.py CHANGED
@@ -39,13 +39,6 @@ if TYPE_CHECKING:
     from datachain.data_storage.schema import DataTable
     from datachain.lib.file import File

-try:
-    import numpy as np
-
-    numpy_imported = True
-except ImportError:
-    numpy_imported = False
-

 logger = logging.getLogger("datachain")

@@ -96,7 +89,9 @@ class AbstractWarehouse(ABC, Serializable):
         If value is a list or some other iterable, it tries to convert sub elements
         as well
         """
-        if numpy_imported and isinstance(val, (np.ndarray, np.generic)):
+        import numpy as np
+
+        if isinstance(val, (np.ndarray, np.generic)):
             val = val.tolist()

         # Optimization: Precompute all the column type variables.
datachain/dataset.py CHANGED
@@ -302,6 +302,7 @@ class DatasetListVersion:
         size: Optional[int],
         query_script: str = "",
         job_id: Optional[str] = None,
+        **kwargs,
     ):
         return cls(
             id,
@@ -648,6 +649,13 @@ class DatasetListRecord:
     def has_version_with_uuid(self, uuid: str) -> bool:
         return any(v.uuid == uuid for v in self.versions)

+    @classmethod
+    def from_dict(cls, d: dict[str, Any]) -> "DatasetListRecord":
+        versions = [DatasetListVersion.parse(**v) for v in d.get("versions", [])]
+        kwargs = {f.name: d[f.name] for f in fields(cls) if f.name in d}
+        kwargs["versions"] = versions
+        return cls(**kwargs)
+

 class RowDict(dict):
     pass
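DatasetListRecord.from_dict filters the incoming payload by the dataclass's own fields, so extra keys returned by the Studio API are ignored rather than raising TypeError, and the **kwargs added to DatasetListVersion.parse plays the same role for version entries. A generic sketch of the filtering idea (the Example dataclass is a stand-in, not the real record type):

from dataclasses import dataclass, fields


@dataclass
class Example:
    name: str
    versions: list


payload = {"name": "cats", "versions": [], "unknown_api_field": 123}
# Keep only keys that map to dataclass fields, mirroring from_dict above.
kwargs = {f.name: payload[f.name] for f in fields(Example) if f.name in payload}
print(Example(**kwargs))  # Example(name='cats', versions=[])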
datachain/error.py CHANGED
@@ -1,15 +1,3 @@
-import botocore.errorfactory
-import botocore.exceptions
-import gcsfs.retry
-
-REMOTE_ERRORS = (
-    gcsfs.retry.HttpError,  # GCS
-    OSError,  # GCS
-    botocore.exceptions.BotoCoreError,  # S3
-    ValueError,  # Azure
-)
-
-
 class DataChainError(RuntimeError):
     pass

datachain/fs/utils.py ADDED
@@ -0,0 +1,30 @@
+from typing import TYPE_CHECKING
+
+from fsspec.implementations.local import LocalFileSystem
+
+if TYPE_CHECKING:
+    from fsspec import AbstractFileSystem
+
+
+def _isdir(fs: "AbstractFileSystem", path: str) -> bool:
+    info = fs.info(path)
+    return info["type"] == "directory" or (
+        info["size"] == 0 and info["type"] == "file" and info["name"].endswith("/")
+    )
+
+
+def isfile(fs: "AbstractFileSystem", path: str) -> bool:
+    """
+    Returns True if uri points to a file.
+
+    Supports special directories on object storages, e.g.:
+    Google creates a zero byte file with the same name as the directory with a trailing
+    slash at the end.
+    """
+    if isinstance(fs, LocalFileSystem):
+        return fs.isfile(path)
+
+    try:
+        return not _isdir(fs, path)
+    except FileNotFoundError:
+        return False
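A short usage sketch of the new helper, assuming the module path datachain/fs/utils.py shown above: local filesystems defer to fs.isfile(), while object stores go through the _isdir heuristic that treats a zero-byte "name/" placeholder object as a directory marker.

from fsspec.implementations.local import LocalFileSystem

from datachain.fs.utils import isfile

fs = LocalFileSystem()
print(isfile(fs, "/etc/hosts"))     # True on most Unix systems
print(isfile(fs, "/no/such/path"))  # False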
datachain/func/__init__.py CHANGED
@@ -18,6 +18,7 @@ from .aggregate import (
 from .array import contains, cosine_distance, euclidean_distance, length, sip_hash_64
 from .conditional import and_, case, greatest, ifelse, isnone, least, or_
 from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
+from .path import file_ext, file_stem, name, parent
 from .random import rand
 from .string import byte_hamming_distance
 from .window import window
@@ -40,6 +41,8 @@ __all__ = [
     "count",
     "dense_rank",
     "euclidean_distance",
+    "file_ext",
+    "file_stem",
     "first",
     "greatest",
     "ifelse",
@@ -50,7 +53,9 @@ __all__ = [
     "literal",
     "max",
     "min",
+    "name",
     "or_",
+    "parent",
     "path",
     "rand",
     "random",
datachain/func/func.py CHANGED
@@ -3,7 +3,6 @@ from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 from sqlalchemy import BindParameter, Case, ColumnElement, Integer, cast, desc
-from sqlalchemy.ext.hybrid import Comparator
 from sqlalchemy.sql import func as sa_func

 from datachain.lib.convert.python_to_sql import python_to_sql
@@ -75,6 +74,8 @@ class Func(Function):

     @property
     def _db_cols(self) -> Sequence[ColT]:
+        from sqlalchemy.ext.hybrid import Comparator
+
         return (
             [
                 col