datachain 0.3.16__py3-none-any.whl → 0.3.18__py3-none-any.whl

This diff compares the contents of publicly available package versions as they were released to their public registry. It is provided for informational purposes only.

datachain/cache.py CHANGED
@@ -1,56 +1,15 @@
-import hashlib
-import json
 import os
-from datetime import datetime
-from functools import partial
 from typing import TYPE_CHECKING, Optional
 
-import attrs
 from dvc_data.hashfile.db.local import LocalHashFileDB
 from dvc_objects.fs.local import LocalFileSystem
 from fsspec.callbacks import Callback, TqdmCallback
 
-from datachain.utils import TIME_ZERO
-
 from .progress import Tqdm
 
 if TYPE_CHECKING:
     from datachain.client import Client
-    from datachain.storage import StorageURI
-
-sha256 = partial(hashlib.sha256, usedforsecurity=False)
-
-
-@attrs.frozen
-class UniqueId:
-    storage: "StorageURI"
-    path: str
-    size: int
-    etag: str
-    version: str = ""
-    is_latest: bool = True
-    location: Optional[str] = None
-    last_modified: datetime = TIME_ZERO
-
-    def get_parsed_location(self) -> Optional[dict]:
-        if not self.location:
-            return None
-
-        loc_stack = (
-            json.loads(self.location)
-            if isinstance(self.location, str)
-            else self.location
-        )
-        if len(loc_stack) > 1:
-            raise NotImplementedError("Nested v-objects are not supported yet.")
-
-        return loc_stack[0]
-
-    def get_hash(self) -> str:
-        fingerprint = f"{self.storage}/{self.path}/{self.version}/{self.etag}"
-        if self.location:
-            fingerprint += f"/{self.location}"
-        return sha256(fingerprint.encode()).hexdigest()
+    from datachain.lib.file import File
 
 
 def try_scandir(path):
@@ -77,30 +36,30 @@ class DataChainCache:
     def tmp_dir(self):
         return self.odb.tmp_dir
 
-    def get_path(self, uid: UniqueId) -> Optional[str]:
-        if self.contains(uid):
-            return self.path_from_checksum(uid.get_hash())
+    def get_path(self, file: "File") -> Optional[str]:
+        if self.contains(file):
+            return self.path_from_checksum(file.get_hash())
         return None
 
-    def contains(self, uid: UniqueId) -> bool:
-        return self.odb.exists(uid.get_hash())
+    def contains(self, file: "File") -> bool:
+        return self.odb.exists(file.get_hash())
 
     def path_from_checksum(self, checksum: str) -> str:
         assert checksum
         return self.odb.oid_to_path(checksum)
 
-    def remove(self, uid: UniqueId) -> None:
-        self.odb.delete(uid.get_hash())
+    def remove(self, file: "File") -> None:
+        self.odb.delete(file.get_hash())
 
     async def download(
-        self, uid: UniqueId, client: "Client", callback: Optional[Callback] = None
+        self, file: "File", client: "Client", callback: Optional[Callback] = None
     ) -> None:
-        from_path = f"{uid.storage}/{uid.path}"
+        from_path = f"{file.source}/{file.path}"
         from dvc_objects.fs.utils import tmp_fname
 
         odb_fs = self.odb.fs
         tmp_info = odb_fs.join(self.odb.tmp_dir, tmp_fname())  # type: ignore[arg-type]
-        size = uid.size
+        size = file.size
         if size < 0:
            size = await client.get_size(from_path)
         cb = callback or TqdmCallback(
@@ -115,13 +74,13 @@ class DataChainCache:
                cb.close()
 
         try:
-            oid = uid.get_hash()
+            oid = file.get_hash()
             self.odb.add(tmp_info, self.odb.fs, oid)
         finally:
             os.unlink(tmp_info)
 
-    def store_data(self, uid: UniqueId, contents: bytes) -> None:
-        checksum = uid.get_hash()
+    def store_data(self, file: "File", contents: bytes) -> None:
+        checksum = file.get_hash()
         dst = self.path_from_checksum(checksum)
         if not os.path.exists(dst):
             # Create the file only if it's not already in cache
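
For context on the cache.py change above: the local cache is now keyed by File objects rather than the removed UniqueId dataclass, and all lookups go through File.get_hash(). A minimal usage sketch, assuming DataChainCache takes a cache directory and a tmp directory, and treating the paths and file metadata as placeholders:

    from datachain.cache import DataChainCache
    from datachain.lib.file import File

    # Placeholder directories; the constructor arguments are assumed here.
    cache = DataChainCache("/tmp/datachain/cache", "/tmp/datachain/tmp")

    # File carries the identity that UniqueId used to hold
    # (source, path, size, etag, version, location, ...).
    file = File(source="s3://bucket", path="images/cat.jpg", size=4, etag="abc123")

    cache.store_data(file, b"data")     # stored under file.get_hash()
    assert cache.contains(file)
    print(cache.get_path(file))         # local cache path, or None if absent
    cache.remove(file)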
@@ -34,7 +34,7 @@ import yaml
 from sqlalchemy import Column
 from tqdm import tqdm
 
-from datachain.cache import DataChainCache, UniqueId
+from datachain.cache import DataChainCache
 from datachain.client import Client
 from datachain.config import get_remote_config, read_config
 from datachain.dataset import (
@@ -68,8 +68,6 @@ from datachain.utils import (
     DataChainDir,
     batched,
     datachain_paths_join,
-    import_object,
-    parse_params_string,
 )
 
 from .datasource import DataSource
@@ -621,13 +619,13 @@ class Catalog:
         code_ast.body[-1:] = new_expressions
         return code_ast
 
-    def get_client(self, uri: StorageURI, **config: Any) -> Client:
+    def get_client(self, uri: str, **config: Any) -> Client:
         """
         Return the client corresponding to the given source `uri`.
         """
         config = config or self.client_config
         cls = Client.get_implementation(uri)
-        return cls.from_source(uri, self.cache, **config)
+        return cls.from_source(StorageURI(uri), self.cache, **config)
 
     def enlist_source(
         self,
@@ -843,7 +841,7 @@ class Catalog:
         from datachain.query import DatasetQuery
 
         def _row_to_node(d: dict[str, Any]) -> Node:
-            del d["source"]
+            del d["file__source"]
             return Node.from_dict(d)
 
         enlisted_sources: list[tuple[bool, bool, Any]] = []
@@ -1148,30 +1146,28 @@ class Catalog:
         if not sources:
             raise ValueError("Sources needs to be non empty list")
 
-        from datachain.query import DatasetQuery
+        from datachain.lib.dc import DataChain
+        from datachain.query.session import Session
+
+        session = Session.get(catalog=self, client_config=client_config)
 
-        dataset_queries = []
+        chains = []
         for source in sources:
             if source.startswith(DATASET_PREFIX):
-                dq = DatasetQuery(
-                    name=source[len(DATASET_PREFIX) :],
-                    catalog=self,
-                    client_config=client_config,
+                dc = DataChain.from_dataset(
+                    source[len(DATASET_PREFIX) :], session=session
                 )
             else:
-                dq = DatasetQuery(
-                    path=source,
-                    catalog=self,
-                    client_config=client_config,
-                    recursive=recursive,
+                dc = DataChain.from_storage(
+                    source, session=session, recursive=recursive
                 )
 
-            dataset_queries.append(dq)
+            chains.append(dc)
 
         # create union of all dataset queries created from sources
-        dq = reduce(lambda ds1, ds2: ds1.union(ds2), dataset_queries)
+        dc = reduce(lambda dc1, dc2: dc1.union(dc2), chains)
         try:
-            dq.save(name)
+            dc.save(name)
         except Exception as e:  # noqa: BLE001
             try:
                 ds = self.get_dataset(name)
@@ -1435,7 +1431,7 @@ class Catalog:
 
     def get_file_signals(
         self, dataset_name: str, dataset_version: int, row: RowDict
-    ) -> Optional[dict]:
+    ) -> Optional[RowDict]:
         """
         Function that returns file signals from dataset row.
         Note that signal names are without prefix, so if there was 'laion__file__source'
@@ -1452,7 +1448,7 @@ class Catalog:
 
         version = self.get_dataset(dataset_name).get_version(dataset_version)
 
-        file_signals_values = {}
+        file_signals_values = RowDict()
 
         schema = SignalSchema.deserialize(version.feature_schema)
         for file_signals in schema.get_signals(File):
@@ -1480,6 +1476,8 @@ class Catalog:
         use_cache: bool = True,
         **config: Any,
     ):
+        from datachain.lib.file import File
+
         file_signals = self.get_file_signals(dataset_name, dataset_version, row)
         if not file_signals:
             raise RuntimeError("Cannot open object without file signals")
@@ -1487,22 +1485,10 @@ class Catalog:
         config = config or self.client_config
         client = self.get_client(file_signals["source"], **config)
         return client.open_object(
-            self._get_row_uid(file_signals),  # type: ignore [arg-type]
+            File._from_row(file_signals),
             use_cache=use_cache,
         )
 
-    def _get_row_uid(self, row: RowDict) -> UniqueId:
-        return UniqueId(
-            row["source"],
-            row["path"],
-            row["size"],
-            row["etag"],
-            row["version"],
-            row["is_latest"],
-            row["location"],
-            row["last_modified"],
-        )
-
     def ls(
         self,
         sources: list[str],
@@ -1731,26 +1717,6 @@ class Catalog:
             output, sources, client_config=client_config, recursive=recursive
         )
 
-    def apply_udf(
-        self,
-        udf_location: str,
-        source: str,
-        target_name: str,
-        parallel: Optional[int] = None,
-        params: Optional[str] = None,
-    ):
-        from datachain.query import DatasetQuery
-
-        if source.startswith(DATASET_PREFIX):
-            ds = DatasetQuery(name=source[len(DATASET_PREFIX) :], catalog=self)
-        else:
-            ds = DatasetQuery(path=source, catalog=self)
-        udf = import_object(udf_location)
-        if params:
-            args, kwargs = parse_params_string(params)
-            udf = udf(*args, **kwargs)
-        ds.add_signals(udf, parallel=parallel).save(target_name)
-
     def query(
         self,
         query_script: str,
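
The Catalog hunk around old line 1148 swaps the low-level DatasetQuery objects for the higher-level DataChain API when a dataset is built from several sources. A rough standalone equivalent of the new flow, with a default session and illustrative source and dataset names:

    from functools import reduce

    from datachain.lib.dc import DataChain
    from datachain.query.session import Session

    # In the Catalog code the session is created with catalog=self and
    # client_config; a default session is assumed here.
    session = Session.get()

    chains = [
        DataChain.from_dataset("my-dataset", session=session),  # a saved dataset
        DataChain.from_storage(                                  # a storage URI
            "s3://bucket/images/", session=session, recursive=True
        ),
    ]

    # Union all chains and persist the result under a new dataset name.
    reduce(lambda a, b: a.union(b), chains).save("combined")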
datachain/cli.py CHANGED
@@ -15,6 +15,7 @@ import shtab
 from datachain import utils
 from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyValueArgs
 from datachain.lib.dc import DataChain
+from datachain.telemetry import telemetry
 from datachain.utils import DataChainDir
 
 if TYPE_CHECKING:
@@ -494,27 +495,6 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         help="Query parameters",
     )
 
-    apply_udf_parser = subp.add_parser(
-        "apply-udf", parents=[parent_parser], description="Apply UDF"
-    )
-    apply_udf_parser.add_argument("udf", type=str, help="UDF location")
-    apply_udf_parser.add_argument("source", type=str, help="Source storage or dataset")
-    apply_udf_parser.add_argument("target", type=str, help="Target dataset name")
-    apply_udf_parser.add_argument(
-        "--parallel",
-        nargs="?",
-        type=int,
-        const=-1,
-        default=None,
-        metavar="N",
-        help=(
-            "Use multiprocessing to run the UDF with N worker processes. "
-            "N defaults to the CPU count."
-        ),
-    )
-    apply_udf_parser.add_argument(
-        "--udf-params", type=str, default=None, help="UDF class parameters"
-    )
     subp.add_parser(
         "clear-cache", parents=[parent_parser], description="Clear the local file cache"
     )
@@ -893,6 +873,7 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
         # This also sets this environment variable for any subprocesses
         os.environ["DEBUG_SHOW_SQL_QUERIES"] = "True"
 
+    error = None
     try:
         catalog = get_catalog(client_config=client_config)
         if args.command == "cp":
@@ -1016,10 +997,6 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
                 parallel=args.parallel,
                 params=args.param,
             )
-        elif args.command == "apply-udf":
-            catalog.apply_udf(
-                args.udf, args.source, args.target, args.parallel, args.udf_params
-            )
         elif args.command == "clear-cache":
             clear_cache(catalog)
         elif args.command == "gc":
@@ -1028,14 +1005,16 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
             print(f"invalid command: {args.command}", file=sys.stderr)
             return 1
         return 0
-    except BrokenPipeError:
+    except BrokenPipeError as exc:
         # Python flushes standard streams on exit; redirect remaining output
         # to devnull to avoid another BrokenPipeError at shutdown
         # See: https://docs.python.org/3/library/signal.html#note-on-sigpipe
+        error = str(exc)
         devnull = os.open(os.devnull, os.O_WRONLY)
         os.dup2(devnull, sys.stdout.fileno())
         return 141  # 128 + 13 (SIGPIPE)
     except (KeyboardInterrupt, Exception) as exc:
+        error = str(exc)
         if isinstance(exc, KeyboardInterrupt):
             msg = "Operation cancelled by the user"
         else:
@@ -1053,3 +1032,5 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
 
         pdb.post_mortem()
         return 1
+    finally:
+        telemetry.send_cli_call(args.command, error=error)
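
The cli.py change above wires a telemetry call into the command entry point via try/except/finally, so the command name and any error string are reported exactly once regardless of how the command exits. A stripped-down sketch of the same pattern; dispatch stands in for the real command handling:

    from datachain.telemetry import telemetry


    def dispatch(command: str) -> None:
        """Stand-in for the CLI's command handling."""
        if command == "boom":
            raise RuntimeError("example failure")


    def run_command(command: str) -> int:
        error = None
        try:
            dispatch(command)
            return 0
        except BrokenPipeError as exc:
            error = str(exc)
            return 141  # 128 + 13 (SIGPIPE)
        except (KeyboardInterrupt, Exception) as exc:
            error = str(exc)
            return 1
        finally:
            # Runs on success and on every error path, so the call fires once.
            telemetry.send_cli_call(command, error=error)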
@@ -3,7 +3,6 @@ import functools
 import logging
 import multiprocessing
 import os
-import posixpath
 import re
 import sys
 from abc import ABC, abstractmethod
@@ -26,8 +25,8 @@ from fsspec.asyn import get_loop, sync
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from tqdm import tqdm
 
-from datachain.cache import DataChainCache, UniqueId
-from datachain.client.fileslice import FileSlice, FileWrapper
+from datachain.cache import DataChainCache
+from datachain.client.fileslice import FileWrapper
 from datachain.error import ClientError as DataChainClientError
 from datachain.lib.file import File
 from datachain.nodes_fetcher import NodesFetcher
@@ -187,8 +186,8 @@ class Client(ABC):
     def url(self, path: str, expires: int = 3600, **kwargs) -> str:
         return self.fs.sign(self.get_full_path(path), expiration=expires, **kwargs)
 
-    async def get_current_etag(self, uid: UniqueId) -> str:
-        info = await self.fs._info(self.get_full_path(uid.path))
+    async def get_current_etag(self, file: "File") -> str:
+        info = await self.fs._info(self.get_full_path(file.path))
         return self.info_to_file(info, "").etag
 
     async def get_size(self, path: str) -> int:
@@ -317,7 +316,7 @@ class Client(ABC):
 
     def instantiate_object(
         self,
-        uid: UniqueId,
+        file: "File",
         dst: str,
         progress_bar: tqdm,
         force: bool = False,
@@ -328,10 +327,10 @@ class Client(ABC):
         else:
             progress_bar.close()
             raise FileExistsError(f"Path {dst} already exists")
-        self.do_instantiate_object(uid, dst)
+        self.do_instantiate_object(file, dst)
 
-    def do_instantiate_object(self, uid: "UniqueId", dst: str) -> None:
-        src = self.cache.get_path(uid)
+    def do_instantiate_object(self, file: "File", dst: str) -> None:
+        src = self.cache.get_path(file)
         assert src is not None
 
         try:
@@ -341,66 +340,33 @@ class Client(ABC):
             copy2(src, dst)
 
     def open_object(
-        self, uid: UniqueId, use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
+        self, file: File, use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
     ) -> BinaryIO:
         """Open a file, including files in tar archives."""
-        location = uid.get_parsed_location()
-        if use_cache and (cache_path := self.cache.get_path(uid)):
+        if use_cache and (cache_path := self.cache.get_path(file)):
             return open(cache_path, mode="rb")  # noqa: SIM115
-        if location and location["vtype"] == "tar":
-            return self._open_tar(uid, use_cache=True)
-        return FileWrapper(self.fs.open(self.get_full_path(uid.path)), cb)  # type: ignore[return-value]
-
-    def _open_tar(self, uid: UniqueId, use_cache: bool = True):
-        location = uid.get_parsed_location()
-        assert location
-
-        offset = location["offset"]
-        size = location["size"]
-        parent = location["parent"]
-
-        parent_uid = UniqueId(
-            parent["source"],
-            parent["path"],
-            parent["size"],
-            parent["etag"],
-            location=parent["location"],
-        )
-        f = self.open_object(parent_uid, use_cache=use_cache)
-        return FileSlice(f, offset, size, posixpath.basename(uid.path))
-
-    def download(self, uid: UniqueId, *, callback: Callback = DEFAULT_CALLBACK) -> None:
-        sync(get_loop(), functools.partial(self._download, uid, callback=callback))
-
-    async def _download(self, uid: UniqueId, *, callback: "Callback" = None) -> None:
-        if self.cache.contains(uid):
+        assert not file.location
+        return FileWrapper(self.fs.open(self.get_full_path(file.path)), cb)  # type: ignore[return-value]
+
+    def download(self, file: File, *, callback: Callback = DEFAULT_CALLBACK) -> None:
+        sync(get_loop(), functools.partial(self._download, file, callback=callback))
+
+    async def _download(self, file: File, *, callback: "Callback" = None) -> None:
+        if self.cache.contains(file):
             # Already in cache, so there's nothing to do.
             return
-        await self._put_in_cache(uid, callback=callback)
+        await self._put_in_cache(file, callback=callback)
 
-    def put_in_cache(self, uid: UniqueId, *, callback: "Callback" = None) -> None:
-        sync(get_loop(), functools.partial(self._put_in_cache, uid, callback=callback))
+    def put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
+        sync(get_loop(), functools.partial(self._put_in_cache, file, callback=callback))
 
-    async def _put_in_cache(
-        self, uid: UniqueId, *, callback: "Callback" = None
-    ) -> None:
-        location = uid.get_parsed_location()
-        if location and location["vtype"] == "tar":
-            loop = asyncio.get_running_loop()
-            await loop.run_in_executor(
-                None, functools.partial(self._download_from_tar, uid, callback=callback)
-            )
-            return
-        if uid.etag:
-            etag = await self.get_current_etag(uid)
-            if uid.etag != etag:
+    async def _put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
+        assert not file.location
+        if file.etag:
+            etag = await self.get_current_etag(file)
+            if file.etag != etag:
                 raise FileNotFoundError(
-                    f"Invalid etag for {uid.storage}/{uid.path}: "
-                    f"expected {uid.etag}, got {etag}"
+                    f"Invalid etag for {file.source}/{file.path}: "
+                    f"expected {file.etag}, got {etag}"
                 )
-        await self.cache.download(uid, self, callback=callback)
-
-    def _download_from_tar(self, uid, *, callback: "Callback" = None):
-        with self._open_tar(uid, use_cache=False) as f:
-            contents = f.read()
-        self.cache.store_data(uid, contents)
+        await self.cache.download(file, self, callback=callback)
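
With UniqueId removed, Client.open_object, download, and put_in_cache now take a File, and the tar/v-object path is gone (hence the new assert not file.location). A rough sketch of fetching an object into the cache through a client, reusing only calls shown in this diff; the bucket, key, cache directories, and constructor arguments are placeholders and assumptions:

    from datachain.cache import DataChainCache
    from datachain.client import Client
    from datachain.lib.file import File
    from datachain.storage import StorageURI

    cache = DataChainCache("/tmp/datachain/cache", "/tmp/datachain/tmp")

    # Resolve a client implementation for the URI, as Catalog.get_client does.
    cls = Client.get_implementation("s3://bucket")
    client = cls.from_source(StorageURI("s3://bucket"), cache)

    # etag left empty so _put_in_cache skips the freshness check;
    # size=-1 makes the cache ask the client for the real size.
    file = File(source="s3://bucket", path="images/cat.jpg", size=-1, etag="")

    client.download(file)            # populate the local cache
    with open(cache.get_path(file), "rb") as f:
        data = f.read()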
datachain/client/local.py CHANGED
@@ -7,7 +7,6 @@ from urllib.parse import urlparse
 
 from fsspec.implementations.local import LocalFileSystem
 
-from datachain.cache import UniqueId
 from datachain.lib.file import File
 from datachain.storage import StorageURI
 
@@ -114,8 +113,8 @@ class FileClient(Client):
             use_symlinks=use_symlinks,
         )
 
-    async def get_current_etag(self, uid: UniqueId) -> str:
-        info = self.fs.info(self.get_full_path(uid.path))
+    async def get_current_etag(self, file: "File") -> str:
+        info = self.fs.info(self.get_full_path(file.path))
         return self.info_to_file(info, "").etag
 
     async def get_size(self, path: str) -> int:
@@ -297,39 +297,6 @@ class AbstractMetastore(ABC, Serializable):
     #
     # Dataset dependencies
     #
-
-    def add_dependency(
-        self,
-        dependency: DatasetDependency,
-        source_dataset_name: str,
-        source_dataset_version: int,
-    ) -> None:
-        """Add dependency to dataset or storage."""
-        if dependency.is_dataset:
-            self.add_dataset_dependency(
-                source_dataset_name,
-                source_dataset_version,
-                dependency.dataset_name,
-                int(dependency.version),
-            )
-        else:
-            self.add_storage_dependency(
-                source_dataset_name,
-                source_dataset_version,
-                StorageURI(dependency.name),
-                dependency.version,
-            )
-
-    @abstractmethod
-    def add_storage_dependency(
-        self,
-        source_dataset_name: str,
-        source_dataset_version: int,
-        storage_uri: StorageURI,
-        storage_timestamp_str: Optional[str] = None,
-    ) -> None:
-        """Adds storage dependency to dataset."""
-
     @abstractmethod
     def add_dataset_dependency(
         self,
@@ -1268,32 +1235,6 @@ class AbstractDBMetastore(AbstractMetastore):
     #
     # Dataset dependencies
     #
-
-    def _insert_dataset_dependency(self, data: dict[str, Any]) -> None:
-        """Method for inserting dependencies."""
-        self.db.execute(self._datasets_dependencies_insert().values(**data))
-
-    def add_storage_dependency(
-        self,
-        source_dataset_name: str,
-        source_dataset_version: int,
-        storage_uri: StorageURI,
-        storage_timestamp_str: Optional[str] = None,
-    ) -> None:
-        source_dataset = self.get_dataset(source_dataset_name)
-        storage = self.get_storage(storage_uri)
-
-        self._insert_dataset_dependency(
-            {
-                "source_dataset_id": source_dataset.id,
-                "source_dataset_version_id": (
-                    source_dataset.get_version(source_dataset_version).id
-                ),
-                "bucket_id": storage.id,
-                "bucket_version": storage_timestamp_str,
-            }
-        )
-
     def add_dataset_dependency(
         self,
         source_dataset_name: str,
@@ -1305,15 +1246,15 @@ class AbstractDBMetastore(AbstractMetastore):
         source_dataset = self.get_dataset(source_dataset_name)
         dataset = self.get_dataset(dataset_name)
 
-        self._insert_dataset_dependency(
-            {
-                "source_dataset_id": source_dataset.id,
-                "source_dataset_version_id": (
+        self.db.execute(
+            self._datasets_dependencies_insert().values(
+                source_dataset_id=source_dataset.id,
+                source_dataset_version_id=(
                     source_dataset.get_version(source_dataset_version).id
                 ),
-                "dataset_id": dataset.id,
-                "dataset_version_id": dataset.get_version(dataset_version).id,
-            }
+                dataset_id=dataset.id,
+                dataset_version_id=dataset.get_version(dataset_version).id,
+            )
         )
 
     def update_dataset_dependency_source(
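
The metastore change above drops the _insert_dataset_dependency helper and the storage-dependency path, and inlines the insert using SQLAlchemy's insert().values(...) with keyword arguments. The keyword form and the old dict form are equivalent in SQLAlchemy, as this self-contained toy example illustrates (the table is a stand-in, not the real schema):

    import sqlalchemy as sa

    metadata = sa.MetaData()
    deps = sa.Table(
        "datasets_dependencies",  # toy stand-in for the real dependencies table
        metadata,
        sa.Column("source_dataset_id", sa.Integer),
        sa.Column("source_dataset_version_id", sa.Integer),
        sa.Column("dataset_id", sa.Integer),
        sa.Column("dataset_version_id", sa.Integer),
    )

    engine = sa.create_engine("sqlite://")
    metadata.create_all(engine)

    row = {
        "source_dataset_id": 1,
        "source_dataset_version_id": 10,
        "dataset_id": 2,
        "dataset_version_id": 20,
    }

    with engine.begin() as conn:
        conn.execute(deps.insert().values(**row))  # keyword form, as in the new code
        conn.execute(deps.insert().values(row))    # dict form, as the removed helper used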
@@ -651,11 +651,14 @@ class SQLiteWarehouse(AbstractWarehouse):
         self, dataset: DatasetRecord, version: int
     ) -> list[StorageURI]:
         dr = self.dataset_rows(dataset, version)
-        query = dr.select(dr.c.source).distinct()
+        query = dr.select(dr.c.file__source).distinct()
         cur = self.db.cursor()
         cur.row_factory = sqlite3.Row  # type: ignore[assignment]
 
-        return [StorageURI(row["source"]) for row in self.db.execute(query, cursor=cur)]
+        return [
+            StorageURI(row["file__source"])
+            for row in self.db.execute(query, cursor=cur)
+        ]
 
     def merge_dataset_rows(
         self,
@@ -942,28 +942,6 @@ class AbstractWarehouse(ABC, Serializable):
                 self.db.drop_table(Table(name, self.db.metadata), if_exists=True)
                 pbar.update(1)
 
-    def changed_query(
-        self,
-        source_query: sa.sql.selectable.Select,
-        target_query: sa.sql.selectable.Select,
-    ) -> sa.sql.selectable.Select:
-        sq = source_query.alias("source_query")
-        tq = target_query.alias("target_query")
-
-        source_target_join = sa.join(
-            sq, tq, (sq.c.source == tq.c.source) & (sq.c.path == tq.c.path)
-        )
-
-        return (
-            select(*sq.c)
-            .select_from(source_target_join)
-            .where(
-                (sq.c.last_modified > tq.c.last_modified)
-                & (sq.c.is_latest == true())
-                & (tq.c.is_latest == true())
-            )
-        )
-
 
 def _random_string(length: int) -> str:
     return "".join(
datachain/lib/arrow.py CHANGED
@@ -49,7 +49,8 @@ class ArrowGenerator(Generator):
 
     def process(self, file: File):
         if file._caching_enabled:
-            path = file.get_local_path(download=True)
+            file.ensure_cached()
+            path = file.get_local_path()
             ds = dataset(path, schema=self.input_schema, **self.kwargs)
         elif self.nrows:
             path = _nrows_file(file, self.nrows)
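
The arrow.py change splits caching from path resolution: File.ensure_cached() downloads the object into the local cache, and get_local_path() only resolves the already-cached path instead of downloading implicitly. A minimal sketch of the same sequence, assuming the File is bound to a catalog/session the way it is inside a chain; the parquet location is a placeholder:

    import pyarrow.dataset as pa_ds

    from datachain.lib.file import File

    # Placeholder object; inside ArrowGenerator the File arrives from the chain
    # with its catalog attached and caching controlled by file._caching_enabled.
    file = File(source="s3://bucket", path="data/part-0000.parquet")

    file.ensure_cached()           # download into the local cache if needed
    path = file.get_local_path()   # cache path only; no implicit download anymore
    ds = pa_ds.dataset(path)       # read with pyarrow, as the generator does
    print(ds.schema)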