datachain 0.3.17__py3-none-any.whl → 0.3.19__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of datachain might be problematic.

datachain/__init__.py CHANGED
@@ -1,21 +1,23 @@
  from datachain.lib.data_model import DataModel, DataType, is_chain_type
  from datachain.lib.dc import C, Column, DataChain, Sys
  from datachain.lib.file import (
+ ArrowRow,
  File,
  FileError,
  ImageFile,
- IndexedFile,
  TarVFile,
  TextFile,
  )
  from datachain.lib.model_store import ModelStore
  from datachain.lib.udf import Aggregator, Generator, Mapper
  from datachain.lib.utils import AbstractUDF, DataChainError
+ from datachain.query import metrics, param
  from datachain.query.session import Session

  __all__ = [
  "AbstractUDF",
  "Aggregator",
+ "ArrowRow",
  "C",
  "Column",
  "DataChain",
@@ -26,7 +28,6 @@ __all__ = [
  "FileError",
  "Generator",
  "ImageFile",
- "IndexedFile",
  "Mapper",
  "ModelStore",
  "Session",
@@ -34,4 +35,6 @@ __all__ = [
  "TarVFile",
  "TextFile",
  "is_chain_type",
+ "metrics",
+ "param",
  ]
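The new top-level exports make the query helpers importable from the package root. A minimal sketch of how they might be used inside a query script, assuming `param()` reads a value passed via DATACHAIN_QUERY_PARAMS and `metrics.set()` records a named metric for the running job (assumptions based on the imports above, not verified against the released API):

    from datachain import C, DataChain, metrics, param

    # Assumption: param("threshold") returns the value passed through
    # DATACHAIN_QUERY_PARAMS, or None if it was not provided.
    threshold = int(param("threshold") or 1000)

    chain = DataChain.from_storage("s3://example-bucket/images/")
    big = chain.filter(C("file.size") > threshold)

    # Assumption: metrics.set(name, value) stores a metric for the current job.
    metrics.set("big_files", big.count())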
datachain/cache.py CHANGED
@@ -1,56 +1,15 @@
- import hashlib
- import json
  import os
- from datetime import datetime
- from functools import partial
  from typing import TYPE_CHECKING, Optional

- import attrs
  from dvc_data.hashfile.db.local import LocalHashFileDB
  from dvc_objects.fs.local import LocalFileSystem
  from fsspec.callbacks import Callback, TqdmCallback

- from datachain.utils import TIME_ZERO
-
  from .progress import Tqdm

  if TYPE_CHECKING:
  from datachain.client import Client
- from datachain.storage import StorageURI
-
- sha256 = partial(hashlib.sha256, usedforsecurity=False)
-
-
- @attrs.frozen
- class UniqueId:
- storage: "StorageURI"
- path: str
- size: int
- etag: str
- version: str = ""
- is_latest: bool = True
- location: Optional[str] = None
- last_modified: datetime = TIME_ZERO
-
- def get_parsed_location(self) -> Optional[dict]:
- if not self.location:
- return None
-
- loc_stack = (
- json.loads(self.location)
- if isinstance(self.location, str)
- else self.location
- )
- if len(loc_stack) > 1:
- raise NotImplementedError("Nested v-objects are not supported yet.")
-
- return loc_stack[0]
-
- def get_hash(self) -> str:
- fingerprint = f"{self.storage}/{self.path}/{self.version}/{self.etag}"
- if self.location:
- fingerprint += f"/{self.location}"
- return sha256(fingerprint.encode()).hexdigest()
+ from datachain.lib.file import File


  def try_scandir(path):
@@ -77,30 +36,30 @@ class DataChainCache:
  def tmp_dir(self):
  return self.odb.tmp_dir

- def get_path(self, uid: UniqueId) -> Optional[str]:
- if self.contains(uid):
- return self.path_from_checksum(uid.get_hash())
+ def get_path(self, file: "File") -> Optional[str]:
+ if self.contains(file):
+ return self.path_from_checksum(file.get_hash())
  return None

- def contains(self, uid: UniqueId) -> bool:
- return self.odb.exists(uid.get_hash())
+ def contains(self, file: "File") -> bool:
+ return self.odb.exists(file.get_hash())

  def path_from_checksum(self, checksum: str) -> str:
  assert checksum
  return self.odb.oid_to_path(checksum)

- def remove(self, uid: UniqueId) -> None:
- self.odb.delete(uid.get_hash())
+ def remove(self, file: "File") -> None:
+ self.odb.delete(file.get_hash())

  async def download(
- self, uid: UniqueId, client: "Client", callback: Optional[Callback] = None
+ self, file: "File", client: "Client", callback: Optional[Callback] = None
  ) -> None:
- from_path = f"{uid.storage}/{uid.path}"
+ from_path = f"{file.source}/{file.path}"
  from dvc_objects.fs.utils import tmp_fname

  odb_fs = self.odb.fs
  tmp_info = odb_fs.join(self.odb.tmp_dir, tmp_fname()) # type: ignore[arg-type]
- size = uid.size
+ size = file.size
  if size < 0:
  size = await client.get_size(from_path)
  cb = callback or TqdmCallback(
@@ -115,13 +74,13 @@ class DataChainCache:
  cb.close()

  try:
- oid = uid.get_hash()
+ oid = file.get_hash()
  self.odb.add(tmp_info, self.odb.fs, oid)
  finally:
  os.unlink(tmp_info)

- def store_data(self, uid: UniqueId, contents: bytes) -> None:
- checksum = uid.get_hash()
+ def store_data(self, file: "File", contents: bytes) -> None:
+ checksum = file.get_hash()
  dst = self.path_from_checksum(checksum)
  if not os.path.exists(dst):
  # Create the file only if it's not already in cache
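The cache is now keyed by `File` objects instead of the removed `UniqueId`, with `File.get_hash()` taking over as the cache key. A rough sketch of the new call pattern, assuming the default catalog exposes its cache as `.cache` and that `File` accepts the keyword fields shown (both assumptions; only the fields used in the diff are relied on):

    from datachain.catalog import get_catalog
    from datachain.lib.file import File

    # Assumption: the catalog exposes its DataChainCache as `.cache`,
    # as the Catalog code in this diff does internally.
    cache = get_catalog().cache

    # Assumption: File can be built from keyword fields like these.
    file = File(source="s3://example-bucket", path="images/cat.jpg", size=2048, etag="abc123")

    cache.store_data(file, b"raw bytes")   # stored under file.get_hash()
    assert cache.contains(file)
    print(cache.get_path(file))            # local path inside the cache, or None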
datachain/catalog/catalog.py CHANGED
@@ -1,4 +1,3 @@
- import ast
  import glob
  import io
  import json
@@ -34,7 +33,7 @@ import yaml
  from sqlalchemy import Column
  from tqdm import tqdm

- from datachain.cache import DataChainCache, UniqueId
+ from datachain.cache import DataChainCache
  from datachain.client import Client
  from datachain.config import get_remote_config, read_config
  from datachain.dataset import (
@@ -53,9 +52,9 @@ from datachain.error import (
  DataChainError,
  DatasetInvalidVersionError,
  DatasetNotFoundError,
+ DatasetVersionNotFoundError,
  PendingIndexingError,
  QueryScriptCancelError,
- QueryScriptCompileError,
  QueryScriptRunError,
  )
  from datachain.listing import Listing
@@ -588,44 +587,13 @@ class Catalog:
  def generate_query_dataset_name(cls) -> str:
  return f"{QUERY_DATASET_PREFIX}_{uuid4().hex}"

- def attach_query_wrapper(self, code_ast):
- if code_ast.body:
- last_expr = code_ast.body[-1]
- if isinstance(last_expr, ast.Expr):
- new_expressions = [
- ast.Import(
- names=[ast.alias(name="datachain.query.dataset", asname=None)]
- ),
- ast.Expr(
- value=ast.Call(
- func=ast.Attribute(
- value=ast.Attribute(
- value=ast.Attribute(
- value=ast.Name(id="datachain", ctx=ast.Load()),
- attr="query",
- ctx=ast.Load(),
- ),
- attr="dataset",
- ctx=ast.Load(),
- ),
- attr="query_wrapper",
- ctx=ast.Load(),
- ),
- args=[last_expr],
- keywords=[],
- )
- ),
- ]
- code_ast.body[-1:] = new_expressions
- return code_ast
-
- def get_client(self, uri: StorageURI, **config: Any) -> Client:
+ def get_client(self, uri: str, **config: Any) -> Client:
  """
  Return the client corresponding to the given source `uri`.
  """
  config = config or self.client_config
  cls = Client.get_implementation(uri)
- return cls.from_source(uri, self.cache, **config)
+ return cls.from_source(StorageURI(uri), self.cache, **config)

  def enlist_source(
  self,
@@ -1218,7 +1186,9 @@ class Catalog:

  dataset_version = dataset.get_version(version)
  if not dataset_version:
- raise ValueError(f"Dataset {dataset.name} does not have version {version}")
+ raise DatasetVersionNotFoundError(
+ f"Dataset {dataset.name} does not have version {version}"
+ )

  if not dataset_version.is_final_status():
  raise ValueError("Cannot register dataset version in non final status")
@@ -1431,7 +1401,7 @@ class Catalog:

  def get_file_signals(
  self, dataset_name: str, dataset_version: int, row: RowDict
- ) -> Optional[dict]:
+ ) -> Optional[RowDict]:
  """
  Function that returns file signals from dataset row.
  Note that signal names are without prefix, so if there was 'laion__file__source'
@@ -1448,7 +1418,7 @@ class Catalog:

  version = self.get_dataset(dataset_name).get_version(dataset_version)

- file_signals_values = {}
+ file_signals_values = RowDict()

  schema = SignalSchema.deserialize(version.feature_schema)
  for file_signals in schema.get_signals(File):
@@ -1476,6 +1446,8 @@ class Catalog:
  use_cache: bool = True,
  **config: Any,
  ):
+ from datachain.lib.file import File
+
  file_signals = self.get_file_signals(dataset_name, dataset_version, row)
  if not file_signals:
  raise RuntimeError("Cannot open object without file signals")
@@ -1483,22 +1455,10 @@ class Catalog:
  config = config or self.client_config
  client = self.get_client(file_signals["source"], **config)
  return client.open_object(
- self._get_row_uid(file_signals), # type: ignore [arg-type]
+ File._from_row(file_signals),
  use_cache=use_cache,
  )

- def _get_row_uid(self, row: RowDict) -> UniqueId:
- return UniqueId(
- row["source"],
- row["path"],
- row["size"],
- row["etag"],
- row["version"],
- row["is_latest"],
- row["location"],
- row["last_modified"],
- )
-
  def ls(
  self,
  sources: list[str],
@@ -1591,7 +1551,7 @@ class Catalog:

  try:
  remote_dataset_version = remote_dataset.get_version(version)
- except (ValueError, StopIteration) as exc:
+ except (DatasetVersionNotFoundError, StopIteration) as exc:
  raise DataChainError(
  f"Dataset {remote_dataset_name} doesn't have version {version}"
  " on server"
@@ -1732,64 +1692,24 @@ class Catalog:
  query_script: str,
  env: Optional[Mapping[str, str]] = None,
  python_executable: str = sys.executable,
- save: bool = False,
- capture_output: bool = True,
+ capture_output: bool = False,
  output_hook: Callable[[str], None] = noop,
  params: Optional[dict[str, str]] = None,
  job_id: Optional[str] = None,
- _execute_last_expression: bool = False,
  ) -> None:
- """
- Method to run custom user Python script to run a query and, as result,
- creates new dataset from the results of a query.
- Returns tuple of result dataset and script output.
-
- Constraints on query script:
- 1. datachain.query.DatasetQuery should be used in order to create query
- for a dataset
- 2. There should not be any .save() call on DatasetQuery since the idea
- is to create only one dataset as the outcome of the script
- 3. Last statement must be an instance of DatasetQuery
-
- If save is set to True, we are creating new dataset with results
- from dataset query. If it's set to False, we will just print results
- without saving anything
-
- Example of query script:
- from datachain.query import DatasetQuery, C
- DatasetQuery('s3://ldb-public/remote/datasets/mnist-tiny/').filter(
- C.size > 1000
- )
- """
- if _execute_last_expression:
- try:
- code_ast = ast.parse(query_script)
- code_ast = self.attach_query_wrapper(code_ast)
- query_script_compiled = ast.unparse(code_ast)
- except Exception as exc:
- raise QueryScriptCompileError(
- f"Query script failed to compile, reason: {exc}"
- ) from exc
- else:
- query_script_compiled = query_script
- assert not save
-
+ cmd = [python_executable, "-c", query_script]
  env = dict(env or os.environ)
  env.update(
  {
  "DATACHAIN_QUERY_PARAMS": json.dumps(params or {}),
- "PYTHONPATH": os.getcwd(), # For local imports
- "DATACHAIN_QUERY_SAVE": "1" if save else "",
- "PYTHONUNBUFFERED": "1",
  "DATACHAIN_JOB_ID": job_id or "",
  },
  )
- popen_kwargs = {}
+ popen_kwargs: dict[str, Any] = {}
  if capture_output:
  popen_kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.STDOUT}

- cmd = [python_executable, "-c", query_script_compiled]
- with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc: # type: ignore[call-overload] # noqa: S603
+ with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc: # noqa: S603
  if capture_output:
  args = (proc.stdout, output_hook)
  thread = Thread(target=_process_stream, args=args, daemon=True)
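After this change `Catalog.query()` no longer rewrites the script's last expression through the AST; it simply runs the script text in a subprocess, handing parameters and the job id over via environment variables. Roughly equivalent behaviour, sketched outside the class with only standard-library calls (the helper name and invocation are illustrative):

    import json
    import os
    import subprocess
    import sys
    from typing import Optional

    def run_query_script(query_script: str,
                         params: Optional[dict] = None,
                         job_id: Optional[str] = None) -> None:
        # Pass the script straight to `python -c`, as the simplified query() does.
        env = dict(os.environ)
        env.update(
            {
                "DATACHAIN_QUERY_PARAMS": json.dumps(params or {}),
                "DATACHAIN_JOB_ID": job_id or "",
            }
        )
        subprocess.run([sys.executable, "-c", query_script], env=env, check=True)

    run_query_script("print('hello from a query script')", params={"limit": "100"})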
datachain/cli.py CHANGED
@@ -15,6 +15,7 @@ import shtab
  from datachain import utils
  from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyValueArgs
  from datachain.lib.dc import DataChain
+ from datachain.telemetry import telemetry
  from datachain.utils import DataChainDir

  if TYPE_CHECKING:
@@ -803,7 +804,6 @@ def query(
  catalog.query(
  script_content,
  python_executable=python_executable,
- capture_output=False,
  params=params,
  job_id=job_id,
  )
@@ -872,6 +872,7 @@ def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR09
  # This also sets this environment variable for any subprocesses
  os.environ["DEBUG_SHOW_SQL_QUERIES"] = "True"

+ error = None
  try:
  catalog = get_catalog(client_config=client_config)
  if args.command == "cp":
@@ -1003,14 +1004,16 @@ def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR09
  print(f"invalid command: {args.command}", file=sys.stderr)
  return 1
  return 0
- except BrokenPipeError:
+ except BrokenPipeError as exc:
  # Python flushes standard streams on exit; redirect remaining output
  # to devnull to avoid another BrokenPipeError at shutdown
  # See: https://docs.python.org/3/library/signal.html#note-on-sigpipe
+ error = str(exc)
  devnull = os.open(os.devnull, os.O_WRONLY)
  os.dup2(devnull, sys.stdout.fileno())
  return 141 # 128 + 13 (SIGPIPE)
  except (KeyboardInterrupt, Exception) as exc:
+ error = str(exc)
  if isinstance(exc, KeyboardInterrupt):
  msg = "Operation cancelled by the user"
  else:
@@ -1028,3 +1031,5 @@ def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR09

  pdb.post_mortem()
  return 1
+ finally:
+ telemetry.send_cli_call(args.command, error=error)
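The CLI now reports every invocation from a `finally` block, so telemetry is sent whether the command succeeds, hits a broken pipe, is interrupted, or raises. The shape of that control flow, reduced to a standalone sketch (the `send_cli_call` signature comes from the diff; the surrounding function is illustrative):

    from datachain.telemetry import telemetry

    def main_sketch(command: str) -> int:
        error = None
        try:
            # ... dispatch to the requested subcommand here ...
            return 0
        except BrokenPipeError as exc:
            error = str(exc)
            return 141  # 128 + 13 (SIGPIPE)
        except (KeyboardInterrupt, Exception) as exc:
            error = str(exc)
            return 1
        finally:
            # Runs on every path, so each CLI call is reported exactly once.
            telemetry.send_cli_call(command, error=error)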
datachain/client/fsspec.py CHANGED
@@ -3,7 +3,6 @@ import functools
  import logging
  import multiprocessing
  import os
- import posixpath
  import re
  import sys
  from abc import ABC, abstractmethod
@@ -26,8 +25,8 @@ from fsspec.asyn import get_loop, sync
  from fsspec.callbacks import DEFAULT_CALLBACK, Callback
  from tqdm import tqdm

- from datachain.cache import DataChainCache, UniqueId
- from datachain.client.fileslice import FileSlice, FileWrapper
+ from datachain.cache import DataChainCache
+ from datachain.client.fileslice import FileWrapper
  from datachain.error import ClientError as DataChainClientError
  from datachain.lib.file import File
  from datachain.nodes_fetcher import NodesFetcher
@@ -187,8 +186,8 @@ class Client(ABC):
  def url(self, path: str, expires: int = 3600, **kwargs) -> str:
  return self.fs.sign(self.get_full_path(path), expiration=expires, **kwargs)

- async def get_current_etag(self, uid: UniqueId) -> str:
- info = await self.fs._info(self.get_full_path(uid.path))
+ async def get_current_etag(self, file: "File") -> str:
+ info = await self.fs._info(self.get_full_path(file.path))
  return self.info_to_file(info, "").etag

  async def get_size(self, path: str) -> int:
@@ -317,7 +316,7 @@ class Client(ABC):

  def instantiate_object(
  self,
- uid: UniqueId,
+ file: "File",
  dst: str,
  progress_bar: tqdm,
  force: bool = False,
@@ -328,10 +327,10 @@ class Client(ABC):
  else:
  progress_bar.close()
  raise FileExistsError(f"Path {dst} already exists")
- self.do_instantiate_object(uid, dst)
+ self.do_instantiate_object(file, dst)

- def do_instantiate_object(self, uid: "UniqueId", dst: str) -> None:
- src = self.cache.get_path(uid)
+ def do_instantiate_object(self, file: "File", dst: str) -> None:
+ src = self.cache.get_path(file)
  assert src is not None

  try:
@@ -341,66 +340,33 @@ class Client(ABC):
  copy2(src, dst)

  def open_object(
- self, uid: UniqueId, use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
+ self, file: File, use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
  ) -> BinaryIO:
  """Open a file, including files in tar archives."""
- location = uid.get_parsed_location()
- if use_cache and (cache_path := self.cache.get_path(uid)):
+ if use_cache and (cache_path := self.cache.get_path(file)):
  return open(cache_path, mode="rb") # noqa: SIM115
- if location and location["vtype"] == "tar":
- return self._open_tar(uid, use_cache=True)
- return FileWrapper(self.fs.open(self.get_full_path(uid.path)), cb) # type: ignore[return-value]
-
- def _open_tar(self, uid: UniqueId, use_cache: bool = True):
- location = uid.get_parsed_location()
- assert location
-
- offset = location["offset"]
- size = location["size"]
- parent = location["parent"]
-
- parent_uid = UniqueId(
- parent["source"],
- parent["path"],
- parent["size"],
- parent["etag"],
- location=parent["location"],
- )
- f = self.open_object(parent_uid, use_cache=use_cache)
- return FileSlice(f, offset, size, posixpath.basename(uid.path))
-
- def download(self, uid: UniqueId, *, callback: Callback = DEFAULT_CALLBACK) -> None:
- sync(get_loop(), functools.partial(self._download, uid, callback=callback))
-
- async def _download(self, uid: UniqueId, *, callback: "Callback" = None) -> None:
- if self.cache.contains(uid):
+ assert not file.location
+ return FileWrapper(self.fs.open(self.get_full_path(file.path)), cb) # type: ignore[return-value]
+
+ def download(self, file: File, *, callback: Callback = DEFAULT_CALLBACK) -> None:
+ sync(get_loop(), functools.partial(self._download, file, callback=callback))
+
+ async def _download(self, file: File, *, callback: "Callback" = None) -> None:
+ if self.cache.contains(file):
  # Already in cache, so there's nothing to do.
  return
- await self._put_in_cache(uid, callback=callback)
+ await self._put_in_cache(file, callback=callback)

- def put_in_cache(self, uid: UniqueId, *, callback: "Callback" = None) -> None:
- sync(get_loop(), functools.partial(self._put_in_cache, uid, callback=callback))
+ def put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
+ sync(get_loop(), functools.partial(self._put_in_cache, file, callback=callback))

- async def _put_in_cache(
- self, uid: UniqueId, *, callback: "Callback" = None
- ) -> None:
- location = uid.get_parsed_location()
- if location and location["vtype"] == "tar":
- loop = asyncio.get_running_loop()
- await loop.run_in_executor(
- None, functools.partial(self._download_from_tar, uid, callback=callback)
- )
- return
- if uid.etag:
- etag = await self.get_current_etag(uid)
- if uid.etag != etag:
+ async def _put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
+ assert not file.location
+ if file.etag:
+ etag = await self.get_current_etag(file)
+ if file.etag != etag:
  raise FileNotFoundError(
- f"Invalid etag for {uid.storage}/{uid.path}: "
- f"expected {uid.etag}, got {etag}"
+ f"Invalid etag for {file.source}/{file.path}: "
+ f"expected {file.etag}, got {etag}"
  )
- await self.cache.download(uid, self, callback=callback)
-
- def _download_from_tar(self, uid, *, callback: "Callback" = None):
- with self._open_tar(uid, use_cache=False) as f:
- contents = f.read()
- self.cache.store_data(uid, contents)
+ await self.cache.download(file, self, callback=callback)
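Client methods now operate on `File` objects, and the tar v-object handling has moved out of the client layer (`open_object` asserts that `file.location` is empty). A rough usage sketch, assuming the catalog exposes its cache as `.cache` and that the `File` keyword fields shown here are sufficient for a plain file; `get_implementation`, `from_source`, `download`, and `open_object` are used with the signatures visible in this diff:

    from datachain.catalog import get_catalog
    from datachain.client import Client
    from datachain.lib.file import File
    from datachain.storage import StorageURI

    catalog = get_catalog()
    cls = Client.get_implementation("s3://example-bucket")
    client = cls.from_source(StorageURI("s3://example-bucket"), catalog.cache)

    # Assumption: these fields are enough for a plain (non-v-object) file;
    # location must stay empty, since open_object() now asserts that.
    file = File(source="s3://example-bucket", path="images/cat.jpg", size=2048, etag="abc123")

    client.download(file)            # fills the cache; no-op if already cached
    fobj = client.open_object(file)  # served from the cache when possible
    data = fobj.read()
    fobj.close()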
datachain/client/local.py CHANGED
@@ -7,7 +7,6 @@ from urllib.parse import urlparse

  from fsspec.implementations.local import LocalFileSystem

- from datachain.cache import UniqueId
  from datachain.lib.file import File
  from datachain.storage import StorageURI

@@ -114,8 +113,8 @@ class FileClient(Client):
  use_symlinks=use_symlinks,
  )

- async def get_current_etag(self, uid: UniqueId) -> str:
- info = self.fs.info(self.get_full_path(uid.path))
+ async def get_current_etag(self, file: "File") -> str:
+ info = self.fs.info(self.get_full_path(file.path))
  return self.info_to_file(info, "").etag

  async def get_size(self, path: str) -> int:
datachain/dataset.py CHANGED
@@ -12,6 +12,7 @@ from typing import (
  from urllib.parse import urlparse

  from datachain.client import Client
+ from datachain.error import DatasetVersionNotFoundError
  from datachain.sql.types import NAME_TYPES_MAPPING, SQLType

  if TYPE_CHECKING:
@@ -417,7 +418,9 @@ class DatasetRecord:

  def get_version(self, version: int) -> DatasetVersion:
  if not self.has_version(version):
- raise ValueError(f"Dataset {self.name} does not have version {version}")
+ raise DatasetVersionNotFoundError(
+ f"Dataset {self.name} does not have version {version}"
+ )
  return next(
  v
  for v in self.versions # type: ignore [union-attr]
@@ -435,7 +438,9 @@ class DatasetRecord:
  Get identifier in the form my-dataset@v3
  """
  if not self.has_version(version):
- raise ValueError(f"Dataset {self.name} doesn't have a version {version}")
+ raise DatasetVersionNotFoundError(
+ f"Dataset {self.name} doesn't have a version {version}"
+ )
  return f"{self.name}@v{version}"

  def uri(self, version: int) -> str:
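Missing versions now raise a dedicated `DatasetVersionNotFoundError` (a `NotFoundError` subclass) instead of a bare `ValueError`, so callers can catch it specifically. A small sketch, assuming `get_catalog()` returns a catalog whose `get_dataset()` yields a `DatasetRecord` as used elsewhere in this diff:

    from datachain.catalog import get_catalog
    from datachain.error import DatasetVersionNotFoundError

    def describe_version(name: str, version: int) -> str:
        catalog = get_catalog()
        dataset = catalog.get_dataset(name)   # DatasetRecord
        try:
            dataset.get_version(version)
        except DatasetVersionNotFoundError:
            return f"{name} has no version {version}"
        return f"found {dataset.identifier(version)}"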
datachain/error.py CHANGED
@@ -10,6 +10,10 @@ class DatasetNotFoundError(NotFoundError):
  pass


+ class DatasetVersionNotFoundError(NotFoundError):
+ pass
+
+
  class DatasetInvalidVersionError(Exception):
  pass

@@ -32,14 +36,12 @@ class QueryScriptRunError(Exception):
  Attributes:
  message Explanation of the error
  return_code Code returned by the subprocess
- output STDOUT + STDERR output of the subprocess
  """

- def __init__(self, message: str, return_code: int = 0, output: str = ""):
+ def __init__(self, message: str, return_code: int = 0):
  self.message = message
  self.return_code = return_code
- self.output = output
- super().__init__(self.message)
+ super().__init__(message)


  class QueryScriptCancelError(QueryScriptRunError):
datachain/lib/arrow.py CHANGED
@@ -4,11 +4,11 @@ from tempfile import NamedTemporaryFile
  from typing import TYPE_CHECKING, Optional

  import pyarrow as pa
- from pyarrow.dataset import dataset
+ from pyarrow.dataset import CsvFileFormat, dataset
  from tqdm import tqdm

  from datachain.lib.data_model import dict_to_data_model
- from datachain.lib.file import File, IndexedFile
+ from datachain.lib.file import ArrowRow, File
  from datachain.lib.model_store import ModelStore
  from datachain.lib.udf import Generator

@@ -49,7 +49,8 @@ class ArrowGenerator(Generator):

  def process(self, file: File):
  if file._caching_enabled:
- path = file.get_local_path(download=True)
+ file.ensure_cached()
+ path = file.get_local_path()
  ds = dataset(path, schema=self.input_schema, **self.kwargs)
  elif self.nrows:
  path = _nrows_file(file, self.nrows)
@@ -83,7 +84,12 @@ class ArrowGenerator(Generator):
  vals_dict[field] = val
  vals = [self.output_schema(**vals_dict)]
  if self.source:
- yield [IndexedFile(file=file, index=index), *vals]
+ kwargs: dict = self.kwargs
+ # Can't serialize CsvFileFormat; may lose formatting options.
+ if isinstance(kwargs.get("format"), CsvFileFormat):
+ kwargs["format"] = "csv"
+ arrow_file = ArrowRow(file=file, index=index, kwargs=kwargs)
+ yield [arrow_file, *vals]
  else:
  yield vals
  index += 1
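`ArrowRow` replaces `IndexedFile` as the source signal and now carries the reader `kwargs` so the row can be re-read later; because a `pyarrow.dataset.CsvFileFormat` instance is not serializable, it is downgraded to the plain "csv" format name, at the cost of any custom parse options. A small standalone illustration of that substitution using only pyarrow (the file name and delimiter are made up):

    import pyarrow.csv as pacsv
    from pyarrow.dataset import CsvFileFormat, dataset

    # A CSV with a non-default delimiter, so the format carries custom options.
    with open("data.csv", "w") as f:
        f.write("a;b\n1;2\n")

    kwargs = {"format": CsvFileFormat(parse_options=pacsv.ParseOptions(delimiter=";"))}

    # Same substitution as in ArrowGenerator.process(): the CsvFileFormat
    # instance cannot be stored on the ArrowRow signal, so keep only the name.
    if isinstance(kwargs.get("format"), CsvFileFormat):
        kwargs["format"] = "csv"

    # Re-reading with the plain "csv" name still works, but the ";" delimiter
    # configured above is lost, which is the trade-off the inline comment notes.
    ds = dataset("data.csv", format=kwargs["format"])
    print(ds.to_table().column_names)  # ['a;b'], parsed as a single column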