datachain 0.6.1__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of datachain has been flagged as potentially problematic.

datachain/lib/dc.py CHANGED
@@ -1,6 +1,8 @@
  import copy
  import os
+ import os.path
  import re
+ import sys
  from collections.abc import Iterator, Sequence
  from functools import wraps
  from typing import (
@@ -23,6 +25,9 @@ from pydantic import BaseModel
  from sqlalchemy.sql.functions import GenericFunction
  from sqlalchemy.sql.sqltypes import NullType

+ from datachain.client import Client
+ from datachain.client.local import FileClient
+ from datachain.dataset import DatasetRecord
  from datachain.lib.convert.python_to_sql import python_to_sql
  from datachain.lib.convert.values_to_tuples import values_to_tuples
  from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
@@ -31,9 +36,6 @@ from datachain.lib.file import ArrowRow, File, get_file_type
  from datachain.lib.file import ExportPlacement as FileExportPlacement
  from datachain.lib.func import Func
  from datachain.lib.listing import (
- is_listing_dataset,
- is_listing_expired,
- is_listing_subset,
  list_bucket,
  ls,
  parse_listing_uri,
@@ -51,7 +53,7 @@ from datachain.query.dataset import DatasetQuery, PartitionByType
  from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
  from datachain.sql.functions import path as pathfunc
  from datachain.telemetry import telemetry
- from datachain.utils import batched_it, inside_notebook
+ from datachain.utils import batched_it, inside_notebook, row_to_nested_dict

  if TYPE_CHECKING:
  from pyarrow import DataType as ArrowDataType
@@ -287,6 +289,13 @@ class DataChain:
  """Version of the underlying dataset, if there is one."""
  return self._query.version

+ @property
+ def dataset(self) -> Optional[DatasetRecord]:
+ """Underlying dataset, if there is one."""
+ if not self.name:
+ return None
+ return self.session.catalog.get_dataset(self.name)
+
  def __or__(self, other: "Self") -> "Self":
  """Return `self.union(other)`."""
  return self.union(other)
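The new `dataset` property returns the underlying `DatasetRecord` for a chain that is bound to a saved dataset, or `None` otherwise. A minimal sketch of how it could be used (the dataset name and the assumption that `save()` returns a chain carrying that name are illustrative, not taken from this diff):

```py
from datachain.lib.dc import DataChain

# Hypothetical: persist a chain under a name, then inspect its dataset record.
chain = DataChain.from_values(num=[1, 2, 3]).save("my-dataset")

record = chain.dataset  # DatasetRecord for "my-dataset", or None for an unnamed chain
if record is not None:
    print(record.name)
```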
@@ -367,6 +376,47 @@ class DataChain:
  self.signals_schema |= signals_schema
  return self

+ @classmethod
+ def parse_uri(
+ cls, uri: str, session: Session, update: bool = False
+ ) -> tuple[str, str, str, bool]:
+ """Returns correct listing dataset name that must be used for saving listing
+ operation. It takes into account existing listings and reusability of those.
+ It also returns boolean saying if returned dataset name is reused / already
+ exists or not, and it returns correct listing path that should be used to find
+ rows based on uri.
+ """
+ catalog = session.catalog
+ cache = catalog.cache
+ client_config = catalog.client_config
+
+ client = Client.get_client(uri, cache, **client_config)
+ ds_name, list_uri, list_path = parse_listing_uri(uri, cache, client_config)
+ listing = None
+
+ listings = [
+ ls
+ for ls in catalog.listings()
+ if not ls.is_expired and ls.contains(ds_name)
+ ]
+
+ if listings:
+ if update:
+ # choosing the smallest possible one to minimize update time
+ listing = sorted(listings, key=lambda ls: len(ls.name))[0]
+ else:
+ # no need to update, choosing the most recent one
+ listing = sorted(listings, key=lambda ls: ls.created_at)[-1]
+
+ if isinstance(client, FileClient) and listing and listing.name != ds_name:
+ # For local file system we need to fix listing path / prefix
+ # if we are reusing existing listing
+ list_path = f'{ds_name.strip("/").removeprefix(listing.name)}/{list_path}'
+
+ ds_name = listing.name if listing else ds_name
+
+ return ds_name, list_uri, list_path, bool(listing)
+
  @classmethod
  def from_records(
  cls,
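The `parse_uri` classmethod added above resolves a storage URI to the listing dataset name, listing URI, and listing path, plus a flag saying whether a reusable listing already exists. A rough sketch of a call (the bucket URI is illustrative, and obtaining a session via `Session.get(None)` is an assumption about the `Session` API, mirroring how `from_storage` acquires one):

```py
from datachain.lib.dc import DataChain, Session

session = Session.get(None)  # assumed way to obtain a default session
ds_name, list_uri, list_path, exists = DataChain.parse_uri(
    "s3://my-bucket/images/", session, update=False
)
# exists is True when a non-expired listing dataset can be reused as-is.
```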
@@ -401,25 +451,15 @@ class DataChain:
  file_type = get_file_type(type)

  client_config = {"anon": True} if anon else None
-
  session = Session.get(session, client_config=client_config, in_memory=in_memory)
+ cache = session.catalog.cache
+ client_config = session.catalog.client_config

- list_dataset_name, list_uri, list_path = parse_listing_uri(
- uri, session.catalog.cache, session.catalog.client_config
+ list_ds_name, list_uri, list_path, list_ds_exists = cls.parse_uri(
+ uri, session, update=update
  )
- need_listing = True
-
- for ds in cls.listings(session=session, in_memory=in_memory).collect("listing"):
- if (
- not is_listing_expired(ds.created_at) # type: ignore[union-attr]
- and is_listing_subset(ds.name, list_dataset_name) # type: ignore[union-attr]
- and not update
- ):
- need_listing = False
- list_dataset_name = ds.name # type: ignore[union-attr]

- if need_listing:
- # caching new listing to special listing dataset
+ if update or not list_ds_exists:
  (
  cls.from_records(
  DataChain.DEFAULT_FILE_RECORD,
@@ -428,17 +468,13 @@
  in_memory=in_memory,
  )
  .gen(
- list_bucket(
- list_uri,
- session.catalog.cache,
- client_config=session.catalog.client_config,
- ),
+ list_bucket(list_uri, cache, client_config=client_config),
  output={f"{object_name}": File},
  )
- .save(list_dataset_name, listing=True)
+ .save(list_ds_name, listing=True)
  )

- dc = cls.from_dataset(list_dataset_name, session=session, settings=settings)
+ dc = cls.from_dataset(list_ds_name, session=session, settings=settings)
  dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})

  return ls(dc, list_path, recursive=recursive, object_name=object_name)
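The rewritten `from_storage` re-lists a bucket only when there is no valid cached listing or when `update=True` is passed; otherwise it reuses the existing listing dataset chosen by `parse_uri`. A sketch (the URI is illustrative):

```py
from datachain.lib.dc import DataChain

# First call lists the bucket and caches the listing dataset;
# subsequent calls reuse it while it has not expired.
chain = DataChain.from_storage("s3://my-bucket/images/")

# Force a fresh listing of the same location.
chain = DataChain.from_storage("s3://my-bucket/images/", update=True)
```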
@@ -665,19 +701,11 @@
  session = Session.get(session, in_memory=in_memory)
  catalog = kwargs.get("catalog") or session.catalog

- listings = [
- ListingInfo.from_models(d, v, j)
- for d, v, j in catalog.list_datasets_versions(
- include_listing=True, **kwargs
- )
- if is_listing_dataset(d.name)
- ]
-
  return cls.from_values(
  session=session,
  in_memory=in_memory,
  output={object_name: ListingInfo},
- **{object_name: listings}, # type: ignore[arg-type]
+ **{object_name: catalog.listings()}, # type: ignore[arg-type]
  )

  def print_json_schema( # type: ignore[override]
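`DataChain.listings()` now builds its rows straight from `catalog.listings()` instead of filtering dataset versions itself. Consuming it is unchanged; a sketch (iterating via `collect("listing")` mirrors how the previous `from_storage` implementation consumed this chain, and the printed attributes are assumptions about `ListingInfo`):

```py
from datachain.lib.dc import DataChain

for info in DataChain.listings().collect("listing"):
    print(info.name, info.created_at)  # ListingInfo fields assumed from the old code
```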
@@ -1006,6 +1034,9 @@ class DataChain:
  """Group rows by specified set of signals and return new signals
  with aggregated values.

+ The supported functions:
+ count(), sum(), avg(), min(), max(), any_value(), collect(), concat()
+
  Example:
  ```py
  chain = chain.group_by(
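The `group_by` docstring now enumerates the supported aggregators. A hedged example of a few of them together (the column names and the `partition_by` keyword are assumptions; the docstring's own example is truncated in this diff):

```py
from datachain.lib import func
from datachain.lib.dc import DataChain

chain = DataChain.from_storage("s3://my-bucket/files/")
summary = chain.group_by(
    cnt=func.count(),
    total_size=func.sum("file.size"),
    partition_by="file.parent",
)
```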
@@ -1069,13 +1100,22 @@
  Filename: name(), parent(), file_stem(), file_ext()
  Array: length(), sip_hash_64(), euclidean_distance(),
  cosine_distance()
+ Window: row_number(), rank(), dense_rank(), first()

  Example:
  ```py
  dc.mutate(
- area=Column("image.height") * Column("image.width"),
- extension=file_ext(Column("file.name")),
- dist=cosine_distance(embedding_text, embedding_image)
+ area=Column("image.height") * Column("image.width"),
+ extension=file_ext(Column("file.name")),
+ dist=cosine_distance(embedding_text, embedding_image)
+ )
+ ```
+
+ Window function example:
+ ```py
+ window = func.window(partition_by="file.parent", order_by="file.size")
+ dc.mutate(
+ row_number=func.row_number().over(window),
  )
  ```

@@ -1086,20 +1126,12 @@
  Example:
  ```py
  dc.mutate(
- newkey=Column("oldkey")
+ newkey=Column("oldkey")
  )
  ```
  """
- existing_columns = set(self.signals_schema.values.keys())
- for col_name in kwargs:
- if col_name in existing_columns:
- raise DataChainColumnError(
- col_name,
- "Cannot modify existing column with mutate(). "
- "Use a different name for the new column.",
- )
  for col_name, expr in kwargs.items():
- if not isinstance(expr, Column) and isinstance(expr.type, NullType):
+ if not isinstance(expr, (Column, Func)) and isinstance(expr.type, NullType):
  raise DataChainColumnError(
  col_name, f"Cannot infer type with expression {expr}"
  )
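Note that the guard raising `DataChainColumnError` for names that already exist has been dropped, and `Func` expressions are now accepted by the remaining type check. Assuming this means `mutate()` may now target an existing column (the diff shows only the removal of the check, not the replacement semantics), a call like the following would no longer be rejected up front:

```py
# Hypothetical: "size" already exists as a signal on the chain.
dc = dc.mutate(size=Column("size") * 2)
```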
@@ -1111,6 +1143,9 @@
  # renaming existing column
  for signal in schema.db_signals(name=value.name, as_columns=True):
  mutated[signal.name.replace(value.name, name, 1)] = signal # type: ignore[union-attr]
+ elif isinstance(value, Func):
+ # adding new signal
+ mutated[name] = value.get_column(schema)
  else:
  # adding new signal
  mutated[name] = value
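Because `mutate()` now routes `Func` values through `value.get_column(schema)`, window functions can be attached as new signals directly. A sketch built only from names shown in the updated docstring (the chain `dc` and its `file.*` columns are assumed):

```py
window = func.window(partition_by="file.parent", order_by="file.size")
dc = dc.mutate(
    row_number=func.row_number().over(window),
    rank=func.rank().over(window),
    dense_rank=func.dense_rank().over(window),
)
```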
@@ -1887,21 +1922,48 @@
  path: Union[str, os.PathLike[str], BinaryIO],
  partition_cols: Optional[Sequence[str]] = None,
  chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE,
+ fs_kwargs: Optional[dict[str, Any]] = None,
  **kwargs,
  ) -> None:
  """Save chain to parquet file with SignalSchema metadata.

  Parameters:
- path : Path or a file-like binary object to save the file.
+ path : Path or a file-like binary object to save the file. This supports
+ local paths as well as remote paths, such as s3:// or hf:// with fsspec.
  partition_cols : Column names by which to partition the dataset.
  chunk_size : The chunk size of results to read and convert to columnar
  data, to avoid running out of memory.
+ fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
+ write, for fsspec-type URLs, such as s3:// or hf:// when
+ provided as the destination path.
  """
  import pyarrow as pa
  import pyarrow.parquet as pq

  from datachain.lib.arrow import DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY

+ fsspec_fs = None
+
+ if isinstance(path, str) and "://" in path:
+ from datachain.client.fsspec import Client
+
+ fs_kwargs = {
+ **self._query.catalog.client_config,
+ **(fs_kwargs or {}),
+ }
+
+ client = Client.get_implementation(path)
+
+ if path.startswith("file://"):
+ # pyarrow does not handle file:// uris, and needs a direct path instead.
+ from urllib.parse import urlparse
+
+ path = urlparse(path).path
+ if sys.platform == "win32":
+ path = os.path.normpath(path.lstrip("/"))
+
+ fsspec_fs = client.create_fs(**fs_kwargs)
+
  _partition_cols = list(partition_cols) if partition_cols else None
  signal_schema_metadata = orjson.dumps(
  self._effective_signals_schema.serialize()
@@ -1936,12 +1998,15 @@
  table,
  root_path=path,
  partition_cols=_partition_cols,
+ filesystem=fsspec_fs,
  **kwargs,
  )
  else:
  if first_chunk:
  # Write to a single parquet file.
- parquet_writer = pq.ParquetWriter(path, parquet_schema, **kwargs)
+ parquet_writer = pq.ParquetWriter(
+ path, parquet_schema, filesystem=fsspec_fs, **kwargs
+ )
  first_chunk = False

  assert parquet_writer
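With the fsspec branch above, `to_parquet` can write straight to object storage when given a URL, merging the catalog's client config with any `fs_kwargs`. A sketch (the bucket path and the specific filesystem option are illustrative):

```py
chain.to_parquet(
    "s3://my-bucket/exports/data.parquet",
    fs_kwargs={"anon": False},  # extra options forwarded to the fsspec filesystem
)
```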
@@ -1954,28 +2019,122 @@
  self,
  path: Union[str, os.PathLike[str]],
  delimiter: str = ",",
+ fs_kwargs: Optional[dict[str, Any]] = None,
  **kwargs,
  ) -> None:
  """Save chain to a csv (comma-separated values) file.

  Parameters:
- path : Path to save the file.
+ path : Path to save the file. This supports local paths as well as
+ remote paths, such as s3:// or hf:// with fsspec.
  delimiter : Delimiter to use for the resulting file.
+ fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
+ write, for fsspec-type URLs, such as s3:// or hf:// when
+ provided as the destination path.
  """
  import csv

+ opener = open
+
+ if isinstance(path, str) and "://" in path:
+ from datachain.client.fsspec import Client
+
+ fs_kwargs = {
+ **self._query.catalog.client_config,
+ **(fs_kwargs or {}),
+ }
+
+ client = Client.get_implementation(path)
+
+ fsspec_fs = client.create_fs(**fs_kwargs)
+
+ opener = fsspec_fs.open
+
  headers, _ = self._effective_signals_schema.get_headers_with_length()
  column_names = [".".join(filter(None, header)) for header in headers]

  results_iter = self.collect_flatten()

- with open(path, "w", newline="") as f:
+ with opener(path, "w", newline="") as f:
  writer = csv.writer(f, delimiter=delimiter, **kwargs)
  writer.writerow(column_names)

  for row in results_iter:
  writer.writerow(row)

+ def to_json(
+ self,
+ path: Union[str, os.PathLike[str]],
+ fs_kwargs: Optional[dict[str, Any]] = None,
+ include_outer_list: bool = True,
+ ) -> None:
+ """Save chain to a JSON file.
+
+ Parameters:
+ path : Path to save the file. This supports local paths as well as
+ remote paths, such as s3:// or hf:// with fsspec.
+ fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
+ write, for fsspec-type URLs, such as s3:// or hf:// when
+ provided as the destination path.
+ include_outer_list : Sets whether to include an outer list for all rows.
+ Setting this to True makes the file valid JSON, while False instead
+ writes in the JSON lines format.
+ """
+ opener = open
+
+ if isinstance(path, str) and "://" in path:
+ from datachain.client.fsspec import Client
+
+ fs_kwargs = {
+ **self._query.catalog.client_config,
+ **(fs_kwargs or {}),
+ }
+
+ client = Client.get_implementation(path)
+
+ fsspec_fs = client.create_fs(**fs_kwargs)
+
+ opener = fsspec_fs.open
+
+ headers, _ = self._effective_signals_schema.get_headers_with_length()
+ headers = [list(filter(None, header)) for header in headers]
+
+ is_first = True
+
+ with opener(path, "wb") as f:
+ if include_outer_list:
+ # This makes the file JSON instead of JSON lines.
+ f.write(b"[\n")
+ for row in self.collect_flatten():
+ if not is_first:
+ if include_outer_list:
+ # This makes the file JSON instead of JSON lines.
+ f.write(b",\n")
+ else:
+ f.write(b"\n")
+ else:
+ is_first = False
+ f.write(orjson.dumps(row_to_nested_dict(headers, row)))
+ if include_outer_list:
+ # This makes the file JSON instead of JSON lines.
+ f.write(b"\n]\n")
+
+ def to_jsonl(
+ self,
+ path: Union[str, os.PathLike[str]],
+ fs_kwargs: Optional[dict[str, Any]] = None,
+ ) -> None:
+ """Save chain to a JSON lines file.
+
+ Parameters:
+ path : Path to save the file. This supports local paths as well as
+ remote paths, such as s3:// or hf:// with fsspec.
+ fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
+ write, for fsspec-type URLs, such as s3:// or hf:// when
+ provided as the destination path.
+ """
+ self.to_json(path, fs_kwargs, include_outer_list=False)
+
  @classmethod
  def from_records(
  cls,
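The new `to_json` and `to_jsonl` writers follow the same pattern as `to_csv`, including remote destinations via `fs_kwargs`. A sketch (paths are illustrative):

```py
# One valid JSON document: an outer list with a nested dict per row.
chain.to_json("s3://my-bucket/exports/data.json")

# JSON lines: one JSON object per line, no outer list.
chain.to_jsonl("exports/data.jsonl")
```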
@@ -1,5 +1,18 @@
- from .aggregate import any_value, avg, collect, concat, count, max, min, sum
- from .func import Func
+ from .aggregate import (
+ any_value,
+ avg,
+ collect,
+ concat,
+ count,
+ dense_rank,
+ first,
+ max,
+ min,
+ rank,
+ row_number,
+ sum,
+ )
+ from .func import Func, window

  __all__ = [
  "Func",
@@ -8,7 +21,12 @@ __all__ = [
  "collect",
  "concat",
  "count",
+ "dense_rank",
+ "first",
  "max",
  "min",
+ "rank",
+ "row_number",
  "sum",
+ "window",
  ]
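With the expanded exports, the window helpers are importable next to the aggregates. A short sketch (the module path `datachain.lib.func` is assumed from the `from datachain.lib.func import Func` import in `dc.py` above):

```py
from datachain.lib.func import count, rank, row_number, window

w = window(partition_by="file.parent", order_by="file.size")
# rank().over(w) and row_number().over(w) produce Func objects for DataChain.mutate();
# count() and the other aggregates are used with DataChain.group_by().
```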