datachain 0.8.0__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain might be problematic.

@@ -52,6 +52,7 @@ from datachain.error import (
     QueryScriptCancelError,
     QueryScriptRunError,
 )
+from datachain.lib.listing import get_listing
 from datachain.node import DirType, Node, NodeWithPath
 from datachain.nodes_thread_pool import NodesThreadPool
 from datachain.remote.studio import StudioClient
@@ -599,7 +600,7 @@ class Catalog:
             source, session=self.session, update=update, object_name=object_name
         )

-        list_ds_name, list_uri, list_path, _ = DataChain.parse_uri(
+        list_ds_name, list_uri, list_path, _ = get_listing(
             source, self.session, update=update
         )

@@ -697,11 +698,9 @@ class Catalog:
         )
         indexed_sources = []
         for source in dataset_sources:
-            from datachain.lib.dc import DataChain
-
             client = self.get_client(source, **client_config)
             uri = client.uri
-            dataset_name, _, _, _ = DataChain.parse_uri(uri, self.session)
+            dataset_name, _, _, _ = get_listing(uri, self.session)
             listing = Listing(
                 self.metastore.clone(),
                 self.warehouse.clone(),
datachain/client/gcs.py CHANGED
@@ -32,6 +32,16 @@ class GCSClient(Client):

         return cast(GCSFileSystem, super().create_fs(**kwargs))

+    def url(self, path: str, expires: int = 3600, **kwargs) -> str:
+        """
+        Generate a signed URL for the given path.
+        If the client is anonymous, a public URL is returned instead
+        (see https://cloud.google.com/storage/docs/access-public-data#api-link).
+        """
+        if self.fs.storage_options.get("token") == "anon":
+            return f"https://storage.googleapis.com/{self.name}/{path}"
+        return self.fs.sign(self.get_full_path(path), expiration=expires, **kwargs)
+
     @staticmethod
     def parse_timestamp(timestamp: str) -> datetime:
         """
@@ -216,7 +216,6 @@ class AbstractWarehouse(ABC, Serializable):
         limit = query._limit
         paginated_query = query.limit(page_size)

-        results = None
         offset = 0
         num_yielded = 0

datachain/lib/arrow.py CHANGED
@@ -1,9 +1,11 @@
 from collections.abc import Sequence
-from tempfile import NamedTemporaryFile
+from itertools import islice
 from typing import TYPE_CHECKING, Any, Optional

+import fsspec.implementations.reference
 import orjson
 import pyarrow as pa
+from fsspec.core import split_protocol
 from pyarrow.dataset import CsvFileFormat, dataset
 from tqdm import tqdm

@@ -25,7 +27,18 @@ if TYPE_CHECKING:
 DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"


+class ReferenceFileSystem(fsspec.implementations.reference.ReferenceFileSystem):
+    def _open(self, path, mode="rb", *args, **kwargs):
+        # overriding because `fsspec`'s `ReferenceFileSystem._open`
+        # reads the whole file in-memory.
+        (uri,) = self.references[path]
+        protocol, _ = split_protocol(uri)
+        return self.fss[protocol]._open(uri, mode, *args, **kwargs)
+
+
 class ArrowGenerator(Generator):
+    DEFAULT_BATCH_SIZE = 2**17  # same as `pyarrow._dataset._DEFAULT_BATCH_SIZE`
+
     def __init__(
         self,
         input_schema: Optional["pa.Schema"] = None,
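
To illustrate how the override routes a read (the paths below are made up; this only restates the three lines of `_open` above): each logical path maps to exactly one backing URI, `split_protocol` decides which underlying filesystem serves it, and the file object is opened lazily instead of being read fully into memory as the stock `ReferenceFileSystem._open` does.

    from fsspec.core import split_protocol

    references = {"data/part-0.parquet": ["/tmp/datachain-cache/ab/cdef123"]}
    (uri,) = references["data/part-0.parquet"]
    protocol, _ = split_protocol(uri)  # None -> handled by the local filesystem
    print(protocol, uri)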
@@ -55,57 +68,80 @@ class ArrowGenerator(Generator):
     def process(self, file: File):
         if file._caching_enabled:
             file.ensure_cached()
-            path = file.get_local_path()
-            ds = dataset(path, schema=self.input_schema, **self.kwargs)
-        elif self.nrows:
-            path = _nrows_file(file, self.nrows)
-            ds = dataset(path, schema=self.input_schema, **self.kwargs)
+            cache_path = file.get_local_path()
+            fs_path = file.path
+            fs = ReferenceFileSystem({fs_path: [cache_path]})
         else:
-            path = file.get_path()
-            ds = dataset(
-                path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs
-            )
+            fs, fs_path = file.get_fs(), file.get_path()
+
+        ds = dataset(fs_path, schema=self.input_schema, filesystem=fs, **self.kwargs)
+
         hf_schema = _get_hf_schema(ds.schema)
         use_datachain_schema = (
             bool(ds.schema.metadata)
             and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in ds.schema.metadata
         )
-        index = 0
-        with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
-            for record_batch in ds.to_batches():
-                for record in record_batch.to_pylist():
-                    if use_datachain_schema and self.output_schema:
-                        vals = [_nested_model_instantiate(record, self.output_schema)]
-                    else:
-                        vals = list(record.values())
-                        if self.output_schema:
-                            fields = self.output_schema.model_fields
-                            vals_dict = {}
-                            for i, ((field, field_info), val) in enumerate(
-                                zip(fields.items(), vals)
-                            ):
-                                anno = field_info.annotation
-                                if hf_schema:
-                                    from datachain.lib.hf import convert_feature
-
-                                    feat = list(hf_schema[0].values())[i]
-                                    vals_dict[field] = convert_feature(val, feat, anno)
-                                elif ModelStore.is_pydantic(anno):
-                                    vals_dict[field] = anno(**val)  # type: ignore[misc]
-                                else:
-                                    vals_dict[field] = val
-                            vals = [self.output_schema(**vals_dict)]
-                    if self.source:
-                        kwargs: dict = self.kwargs
-                        # Can't serialize CsvFileFormat; may lose formatting options.
-                        if isinstance(kwargs.get("format"), CsvFileFormat):
-                            kwargs["format"] = "csv"
-                        arrow_file = ArrowRow(file=file, index=index, kwargs=kwargs)
-                        yield [arrow_file, *vals]
-                    else:
-                        yield vals
-                    index += 1
-                pbar.update(len(record_batch))
+
+        kw = {}
+        if self.nrows:
+            kw = {"batch_size": min(self.DEFAULT_BATCH_SIZE, self.nrows)}
+
+        def iter_records():
+            for record_batch in ds.to_batches(**kw):
+                yield from record_batch.to_pylist()
+
+        it = islice(iter_records(), self.nrows)
+        with tqdm(it, desc="Parsed by pyarrow", unit="rows", total=self.nrows) as pbar:
+            for index, record in enumerate(pbar):
+                yield self._process_record(
+                    record, file, index, hf_schema, use_datachain_schema
+                )
+
+    def _process_record(
+        self,
+        record: dict[str, Any],
+        file: File,
+        index: int,
+        hf_schema: Optional[tuple["Features", dict[str, "DataType"]]],
+        use_datachain_schema: bool,
+    ):
+        if use_datachain_schema and self.output_schema:
+            vals = [_nested_model_instantiate(record, self.output_schema)]
+        else:
+            vals = self._process_non_datachain_record(record, hf_schema)
+
+        if self.source:
+            kwargs: dict = self.kwargs
+            # Can't serialize CsvFileFormat; may lose formatting options.
+            if isinstance(kwargs.get("format"), CsvFileFormat):
+                kwargs["format"] = "csv"
+            arrow_file = ArrowRow(file=file, index=index, kwargs=kwargs)
+            return [arrow_file, *vals]
+        return vals
+
+    def _process_non_datachain_record(
+        self,
+        record: dict[str, Any],
+        hf_schema: Optional[tuple["Features", dict[str, "DataType"]]],
+    ):
+        vals = list(record.values())
+        if not self.output_schema:
+            return vals
+
+        fields = self.output_schema.model_fields
+        vals_dict = {}
+        for i, ((field, field_info), val) in enumerate(zip(fields.items(), vals)):
+            anno = field_info.annotation
+            if hf_schema:
+                from datachain.lib.hf import convert_feature
+
+                feat = list(hf_schema[0].values())[i]
+                vals_dict[field] = convert_feature(val, feat, anno)
+            elif ModelStore.is_pydantic(anno):
+                vals_dict[field] = anno(**val)  # type: ignore[misc]
+            else:
+                vals_dict[field] = val
+        return [self.output_schema(**vals_dict)]


 def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
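
The `_nrows_file` helper (removed in the hunk below) used to write a truncated temporary copy of the input just to honor `nrows`; the rewrite caps the row count in-stream instead. A small self-contained sketch of that idea (the numbers are arbitrary):

    from itertools import islice

    nrows = 1_000
    batch_size = min(2**17, nrows)  # ask pyarrow for batches no larger than nrows
    records = iter(range(10_000))   # stand-in for rows coming out of ds.to_batches()
    capped = list(islice(records, nrows))
    assert len(capped) == nrows     # iteration stops after exactly nrows records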
@@ -190,18 +226,6 @@ def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type:  # noqa:
     raise TypeError(f"{col_type!r} datatypes not supported, column: {column}")


-def _nrows_file(file: File, nrows: int) -> str:
-    tf = NamedTemporaryFile(delete=False)  # noqa: SIM115
-    with file.open(mode="r") as reader:
-        with open(tf.name, "a") as writer:
-            for row, line in enumerate(reader):
-                if row >= nrows:
-                    break
-                writer.write(line)
-                writer.write("\n")
-    return tf.name
-
-
 def _get_hf_schema(
     schema: "pa.Schema",
 ) -> Optional[tuple["Features", dict[str, "DataType"]]]:
datachain/lib/dc.py CHANGED
@@ -11,7 +11,6 @@ from typing import (
     BinaryIO,
     Callable,
     ClassVar,
-    Literal,
     Optional,
     TypeVar,
     Union,
@@ -24,8 +23,6 @@ from pydantic import BaseModel
 from sqlalchemy.sql.functions import GenericFunction
 from sqlalchemy.sql.sqltypes import NullType

-from datachain.client import Client
-from datachain.client.local import FileClient
 from datachain.dataset import DatasetRecord
 from datachain.func.base import Function
 from datachain.func.func import Func
@@ -33,13 +30,9 @@ from datachain.lib.convert.python_to_sql import python_to_sql
 from datachain.lib.convert.values_to_tuples import values_to_tuples
 from datachain.lib.data_model import DataModel, DataType, DataValue, dict_to_data_model
 from datachain.lib.dataset_info import DatasetInfo
-from datachain.lib.file import ArrowRow, File, get_file_type
+from datachain.lib.file import ArrowRow, File, FileType, get_file_type
 from datachain.lib.file import ExportPlacement as FileExportPlacement
-from datachain.lib.listing import (
-    list_bucket,
-    ls,
-    parse_listing_uri,
-)
+from datachain.lib.listing import get_listing, list_bucket, ls
 from datachain.lib.listing_info import ListingInfo
 from datachain.lib.meta_formats import read_meta
 from datachain.lib.model_store import ModelStore
@@ -403,53 +396,12 @@ class DataChain:
         self.signals_schema |= signals_schema
         return self

-    @classmethod
-    def parse_uri(
-        cls, uri: str, session: Session, update: bool = False
-    ) -> tuple[str, str, str, bool]:
-        """Returns correct listing dataset name that must be used for saving listing
-        operation. It takes into account existing listings and reusability of those.
-        It also returns boolean saying if returned dataset name is reused / already
-        exists or not, and it returns correct listing path that should be used to find
-        rows based on uri.
-        """
-        catalog = session.catalog
-        cache = catalog.cache
-        client_config = catalog.client_config
-
-        client = Client.get_client(uri, cache, **client_config)
-        ds_name, list_uri, list_path = parse_listing_uri(uri, cache, client_config)
-        listing = None
-
-        listings = [
-            ls
-            for ls in catalog.listings()
-            if not ls.is_expired and ls.contains(ds_name)
-        ]
-
-        if listings:
-            if update:
-                # choosing the smallest possible one to minimize update time
-                listing = sorted(listings, key=lambda ls: len(ls.name))[0]
-            else:
-                # no need to update, choosing the most recent one
-                listing = sorted(listings, key=lambda ls: ls.created_at)[-1]
-
-        if isinstance(client, FileClient) and listing and listing.name != ds_name:
-            # For local file system we need to fix listing path / prefix
-            # if we are reusing existing listing
-            list_path = f'{ds_name.strip("/").removeprefix(listing.name)}/{list_path}'
-
-        ds_name = listing.name if listing else ds_name
-
-        return ds_name, list_uri, list_path, bool(listing)
-
     @classmethod
     def from_storage(
         cls,
         uri,
         *,
-        type: Literal["binary", "text", "image"] = "binary",
+        type: FileType = "binary",
         session: Optional[Session] = None,
         settings: Optional[dict] = None,
         in_memory: bool = False,
@@ -482,7 +434,7 @@ class DataChain:
         cache = session.catalog.cache
         client_config = session.catalog.client_config

-        list_ds_name, list_uri, list_path, list_ds_exists = cls.parse_uri(
+        list_ds_name, list_uri, list_path, list_ds_exists = get_listing(
            uri, session, update=update
        )

@@ -548,7 +500,7 @@ class DataChain:
     def from_json(
         cls,
         path,
-        type: Literal["binary", "text", "image"] = "text",
+        type: FileType = "text",
         spec: Optional[DataType] = None,
         schema_from: Optional[str] = "auto",
         jmespath: Optional[str] = None,
@@ -605,7 +557,9 @@ class DataChain:
                 nrows=nrows,
             )
         }
-        return chain.gen(**signal_dict)  # type: ignore[misc, arg-type]
+        # disable prefetch if nrows is set
+        settings = {"prefetch": 0} if nrows else {}
+        return chain.settings(**settings).gen(**signal_dict)  # type: ignore[misc, arg-type]

     def explode(
         self,
@@ -1942,7 +1896,10 @@ class DataChain:

         if source:
             output = {"source": ArrowRow} | output  # type: ignore[assignment,operator]
-        return self.gen(
+
+        # disable prefetch if nrows is set
+        settings = {"prefetch": 0} if nrows else {}
+        return self.settings(**settings).gen(  # type: ignore[arg-type]
             ArrowGenerator(schema, model, source, nrows, **kwargs), output=output
         )

@@ -2024,8 +1981,6 @@ class DataChain:
                 else:
                     msg = f"error parsing csv - incompatible output type {type(output)}"
                     raise DatasetPrepareError(chain.name, msg)
-            elif nrows:
-                nrows += 1

         parse_options = ParseOptions(delimiter=delimiter)
         read_options = ReadOptions(column_names=column_names)
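
The prefetch guard used in `from_json` and `parse_tabular` above, shown in isolation (the helper name `_prefetch_settings` is only for this sketch and does not exist in the codebase): prefetching whole objects is wasted work when only the first `nrows` records will be parsed, so prefetch is switched off in that case.

    def _prefetch_settings(nrows):
        # mirrors the inline guard: disable prefetch whenever nrows is set
        return {"prefetch": 0} if nrows else {}

    assert _prefetch_settings(100) == {"prefetch": 0}
    assert _prefetch_settings(None) == {}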
datachain/lib/file.py CHANGED
@@ -39,6 +39,8 @@ logger = logging.getLogger("datachain")
 # how to create file path when exporting
 ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]

+FileType = Literal["binary", "text", "image"]
+

 class VFileError(DataChainError):
     def __init__(self, file: "File", message: str, vtype: str = ""):
@@ -470,7 +472,7 @@ class ArrowRow(DataModel):
         return record_batch.to_pylist()[0]


-def get_file_type(type_: Literal["binary", "text", "image"] = "binary") -> type[File]:
+def get_file_type(type_: FileType = "binary") -> type[File]:
     file: type[File] = File
     if type_ == "text":
         file = TextFile
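
The alias is what `from_storage` and `from_json` now accept for their `type` parameter, and `get_file_type` resolves it to the corresponding `File` model. A quick sketch based only on the code shown above:

    from datachain.lib.file import File, FileType, TextFile, get_file_type

    kind: FileType = "text"
    assert get_file_type(kind) is TextFile
    assert get_file_type() is File  # "binary" is the default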
datachain/lib/listing.py CHANGED
@@ -15,6 +15,7 @@ from datachain.utils import uses_glob

 if TYPE_CHECKING:
     from datachain.lib.dc import DataChain
+    from datachain.query.session import Session

 LISTING_TTL = 4 * 60 * 60  # cached listing lasts 4 hours
 LISTING_PREFIX = "lst__"  # listing datasets start with this name
@@ -108,3 +109,46 @@ def listing_uri_from_name(dataset_name: str) -> str:
     if not is_listing_dataset(dataset_name):
         raise ValueError(f"Dataset {dataset_name} is not a listing")
     return dataset_name.removeprefix(LISTING_PREFIX)
+
+
+def get_listing(
+    uri: str, session: "Session", update: bool = False
+) -> tuple[str, str, str, bool]:
+    """Returns the correct listing dataset name that must be used for saving a
+    listing operation, taking into account existing listings and their reusability.
+    It also returns a boolean saying whether the returned dataset name is reused /
+    already exists (on update it always returns False - simply because there was no
+    reason to complicate it so far), and the correct listing path that should
+    be used to find rows based on the uri.
+    """
+    from datachain.client.local import FileClient
+
+    catalog = session.catalog
+    cache = catalog.cache
+    client_config = catalog.client_config
+
+    client = Client.get_client(uri, cache, **client_config)
+    ds_name, list_uri, list_path = parse_listing_uri(uri, cache, client_config)
+    listing = None
+
+    listings = [
+        ls for ls in catalog.listings() if not ls.is_expired and ls.contains(ds_name)
+    ]
+
+    # if no update is needed, choose the most recent listing;
+    # otherwise stick with the exact original `ds_name`, because:
+    #   - if a "bigger" listing exists, we don't want to update it; it's better
+    #     to create a new "smaller" one on "update=True"
+    #   - if an exact listing exists, it will have the same name as `ds_name`
+    #     anyway below
+    if listings and not update:
+        listing = sorted(listings, key=lambda ls: ls.created_at)[-1]
+
+    # for the local file system we need to fix the listing path / prefix
+    # if we are reusing an existing listing
+    if isinstance(client, FileClient) and listing and listing.name != ds_name:
+        list_path = f'{ds_name.strip("/").removeprefix(listing.name)}/{list_path}'
+
+    ds_name = listing.name if listing else ds_name
+
+    return ds_name, list_uri, list_path, bool(listing)
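
A hedged usage sketch of the new helper (the URI is made up, and `Session.get()` is assumed here to be the usual way to obtain an active session):

    from datachain.lib.listing import get_listing
    from datachain.query.session import Session

    session = Session.get(in_memory=True)
    ds_name, list_uri, list_path, exists = get_listing(
        "s3://my-bucket/images/", session, update=False
    )
    # ds_name   -> listing dataset name (possibly a reused, broader listing)
    # list_path -> path prefix used to select rows from that listing
    # exists    -> True only when a non-expired existing listing is reused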
datachain/lib/udf.py CHANGED
@@ -85,7 +85,6 @@ class UDFAdapter:
         udf_fields: "Sequence[str]",
         udf_inputs: "Iterable[RowsOutput]",
         catalog: "Catalog",
-        is_generator: bool,
         cache: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
datachain/query/batch.py CHANGED
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union

 from datachain.data_storage.schema import PARTITION_COLUMN_ID
 from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
+from datachain.query.utils import get_query_column, get_query_id_column

 if TYPE_CHECKING:
     from sqlalchemy import Select
@@ -23,11 +24,14 @@ RowsOutput = Union[Sequence, RowsOutputBatch]
 class BatchingStrategy(ABC):
     """BatchingStrategy provides means of batching UDF executions."""

+    is_batching: bool
+
     @abstractmethod
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[RowsOutput, None, None]:
         """Apply the provided parameters to the UDF."""

@@ -38,11 +42,16 @@ class NoBatching(BatchingStrategy):
     batch UDF calls.
     """

+    is_batching = False
+
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[Sequence, None, None]:
+        if ids_only:
+            query = query.with_only_columns(get_query_id_column(query))
         return execute(query)


@@ -52,14 +61,20 @@ class Batch(BatchingStrategy):
     is passed a sequence of multiple parameter sets.
     """

+    is_batching = True
+
     def __init__(self, count: int):
         self.count = count

     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[RowsOutputBatch, None, None]:
+        if ids_only:
+            query = query.with_only_columns(get_query_id_column(query))
+
         # choose page size that is a multiple of the batch size
         page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count

@@ -84,19 +99,30 @@ class Partition(BatchingStrategy):
     Dataset rows need to be sorted by the grouping column.
     """

+    is_batching = True
+
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[RowsOutputBatch, None, None]:
+        id_col = get_query_id_column(query)
+        if (partition_col := get_query_column(query, PARTITION_COLUMN_ID)) is None:
+            raise RuntimeError("partition column not found in query")
+
+        if ids_only:
+            query = query.with_only_columns(id_col, partition_col)
+
         current_partition: Optional[int] = None
         batch: list[Sequence] = []

         query_fields = [str(c.name) for c in query.selected_columns]
+        id_column_idx = query_fields.index("sys__id")
         partition_column_idx = query_fields.index(PARTITION_COLUMN_ID)

         ordered_query = query.order_by(None).order_by(
-            PARTITION_COLUMN_ID,
+            partition_col,
             *query._order_by_clauses,
         )

@@ -108,7 +134,7 @@ class Partition(BatchingStrategy):
                 if len(batch) > 0:
                     yield RowsOutputBatch(batch)
                     batch = []
-            batch.append(row)
+            batch.append([row[id_column_idx]] if ids_only else row)

         if len(batch) > 0:
             yield RowsOutputBatch(batch)
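
To show what `ids_only` changes in isolation (the table and column names below are stand-ins, and `partition_id` is not claimed to be the real value of `PARTITION_COLUMN_ID`): each strategy narrows the select to the id column, plus the partition column for `Partition`, and leaves fetching full rows to a later step.

    import sqlalchemy as sa

    tbl = sa.table(
        "rows",
        sa.column("sys__id"),
        sa.column("partition_id"),  # stand-in for PARTITION_COLUMN_ID
        sa.column("payload"),
    )
    query = sa.select(tbl.c.sys__id, tbl.c.partition_id, tbl.c.payload)

    ids_query = query.with_only_columns(tbl.c.sys__id)
    print([c.name for c in ids_query.selected_columns])  # ['sys__id']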