datachain 0.13.1__py3-none-any.whl → 0.14.1__py3-none-any.whl

This diff shows the content of the publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.

Potentially problematic release: this version of datachain has been flagged as possibly problematic.

@@ -0,0 +1,170 @@
+ import os.path
+ from typing import (
+     TYPE_CHECKING,
+     Optional,
+     Union,
+ )
+
+ from datachain.lib.file import (
+     FileType,
+     get_file_type,
+ )
+ from datachain.lib.listing import (
+     get_file_info,
+     get_listing,
+     list_bucket,
+     ls,
+ )
+ from datachain.query import Session
+
+ if TYPE_CHECKING:
+     from .datachain import DataChain
+
+
+ def from_storage(
+     uri: Union[str, os.PathLike[str], list[str], list[os.PathLike[str]]],
+     *,
+     type: FileType = "binary",
+     session: Optional[Session] = None,
+     settings: Optional[dict] = None,
+     in_memory: bool = False,
+     recursive: Optional[bool] = True,
+     object_name: str = "file",
+     update: bool = False,
+     anon: bool = False,
+     client_config: Optional[dict] = None,
+ ) -> "DataChain":
+     """Get data from storage(s) as a list of files with all file attributes.
+     It returns the chain itself as usual.
+
+     Parameters:
+         uri : storage URI with directory, or a list of URIs.
+             URIs must start with a storage prefix such
+             as `s3://`, `gs://`, `az://` or `file:///`.
+         type : read file as "binary", "text", or "image" data. Default is "binary".
+         recursive : search recursively for the given path.
+         object_name : created object column name.
+         update : force storage reindexing. Default is False.
+         anon : if True, the cloud bucket is treated as a public one.
+         client_config : optional client configuration for the storage client.
+
+     Returns:
+         DataChain: A DataChain object containing the file information.
+
+     Examples:
+         Simple call from s3:
+         ```python
+         import datachain as dc
+         chain = dc.from_storage("s3://my-bucket/my-dir")
+         ```
+
+         Multiple URIs:
+         ```python
+         chain = dc.from_storage([
+             "s3://bucket1/dir1",
+             "s3://bucket2/dir2"
+         ])
+         ```
+
+         With AWS S3-compatible storage:
+         ```python
+         chain = dc.from_storage(
+             "s3://my-bucket/my-dir",
+             client_config={"aws_endpoint_url": "<minio-endpoint-url>"}
+         )
+         ```
+
+         Pass an existing session:
+         ```py
+         session = Session.get()
+         chain = dc.from_storage([
+             "path/to/dir1",
+             "path/to/dir2"
+         ], session=session, recursive=True)
+         ```
+
+     Note:
+         When using multiple URIs with `update=True`, the function optimizes by
+         avoiding redundant updates for URIs pointing to the same storage location.
+     """
+     from .datachain import DataChain
+     from .datasets import from_dataset
+     from .records import from_records
+     from .values import from_values
+
+     file_type = get_file_type(type)
+
+     if anon:
+         client_config = (client_config or {}) | {"anon": True}
+     session = Session.get(session, client_config=client_config, in_memory=in_memory)
+     cache = session.catalog.cache
+     client_config = session.catalog.client_config
+
+     uris = uri if isinstance(uri, (list, tuple)) else [uri]
+
+     if not uris:
+         raise ValueError("No URIs provided")
+
+     storage_chain = None
+     listed_ds_name = set()
+     file_values = []
+
+     for single_uri in uris:
+         list_ds_name, list_uri, list_path, list_ds_exists = get_listing(
+             single_uri, session, update=update
+         )
+
+         # list_ds_name is None if object is a file, we don't want to use cache
+         # or do listing in that case - just read that single object
+         if not list_ds_name:
+             file_values.append(
+                 get_file_info(list_uri, cache, client_config=client_config)
+             )
+             continue
+
+         dc = from_dataset(list_ds_name, session=session, settings=settings)
+         dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
+
+         if update or not list_ds_exists:
+
+             def lst_fn(ds_name, lst_uri):
+                 # disable prefetch for listing, as it pre-downloads all files
+                 (
+                     from_records(
+                         DataChain.DEFAULT_FILE_RECORD,
+                         session=session,
+                         settings=settings,
+                         in_memory=in_memory,
+                     )
+                     .settings(prefetch=0)
+                     .gen(
+                         list_bucket(lst_uri, cache, client_config=client_config),
+                         output={f"{object_name}": file_type},
+                     )
+                     .save(ds_name, listing=True)
+                 )
+
+             dc._query.add_before_steps(
+                 lambda ds_name=list_ds_name, lst_uri=list_uri: lst_fn(ds_name, lst_uri)
+             )
+
+         chain = ls(dc, list_path, recursive=recursive, object_name=object_name)
+
+         storage_chain = storage_chain.union(chain) if storage_chain else chain
+         listed_ds_name.add(list_ds_name)
+
+     if file_values:
+         file_chain = from_values(
+             session=session,
+             settings=settings,
+             in_memory=in_memory,
+             file=file_values,
+         )
+         file_chain.signals_schema = file_chain.signals_schema.mutate(
+             {f"{object_name}": file_type}
+         )
+         storage_chain = storage_chain.union(file_chain) if storage_chain else file_chain
+
+     assert storage_chain is not None
+
+     return storage_chain
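
For context, a minimal usage sketch of the new module-level `from_storage` entry point (not part of the diff; it assumes the function is re-exported from the top-level `datachain` package, as the README changes below suggest, and the bucket URIs are placeholders):

```python
import datachain as dc

# single public bucket, read files as images
images = dc.from_storage("gs://datachain-demo/dogs-and-cats/", type="image", anon=True)

# several URIs are listed separately and combined into one chain via union
combined = dc.from_storage(
    ["s3://bucket1/dir1", "s3://bucket2/dir2"],
    recursive=True,
    update=True,  # force re-listing of both locations
)
print(combined.count())
```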
@@ -0,0 +1,128 @@
+ from collections.abc import Sequence
+ from functools import wraps
+ from typing import (
+     TYPE_CHECKING,
+     Callable,
+     Optional,
+     TypeVar,
+     Union,
+ )
+
+ import sqlalchemy
+ from sqlalchemy.sql.functions import GenericFunction
+
+ from datachain.func.base import Function
+ from datachain.lib.data_model import DataModel, DataType
+ from datachain.lib.utils import DataChainParamsError
+ from datachain.query.schema import DEFAULT_DELIMITER
+
+ if TYPE_CHECKING:
+     from typing_extensions import Concatenate, ParamSpec
+
+     from .datachain import DataChain
+
+     P = ParamSpec("P")
+
+ D = TypeVar("D", bound="DataChain")
+
+
+ def resolve_columns(
+     method: "Callable[Concatenate[D, P], D]",
+ ) -> "Callable[Concatenate[D, P], D]":
+     """Decorator that resolves input column names to their actual DB names. This is
+     especially important for nested columns, since users refer to them with dot
+     notation (e.g. file.name) while they are defined with the default delimiter
+     in the DB (e.g. file__name).
+     Any SQL functions among the arguments are passed through to the method as is.
+     """
+
+     @wraps(method)
+     def _inner(self: D, *args: "P.args", **kwargs: "P.kwargs") -> D:
+         resolved_args = self.signals_schema.resolve(
+             *[arg for arg in args if not isinstance(arg, GenericFunction)]  # type: ignore[arg-type]
+         ).db_signals()
+
+         for idx, arg in enumerate(args):
+             if isinstance(arg, GenericFunction):
+                 resolved_args.insert(idx, arg)  # type: ignore[arg-type]
+
+         return method(self, *resolved_args, **kwargs)
+
+     return _inner
+
+
+ class DatasetPrepareError(DataChainParamsError):
+     def __init__(self, name, msg, output=None):
+         name = f" '{name}'" if name else ""
+         output = f" output '{output}'" if output else ""
+         super().__init__(f"Dataset{name}{output} processing prepare error: {msg}")
+
+
+ class DatasetFromValuesError(DataChainParamsError):
+     def __init__(self, name, msg):
+         name = f" '{name}'" if name else ""
+         super().__init__(f"Dataset{name} from values error: {msg}")
+
+
+ MergeColType = Union[str, Function, sqlalchemy.ColumnElement]
+
+
+ def _validate_merge_on(
+     on: Union[MergeColType, Sequence[MergeColType]],
+     ds: "DataChain",
+ ) -> Sequence[MergeColType]:
+     if isinstance(on, (str, sqlalchemy.ColumnElement)):
+         return [on]
+     if isinstance(on, Function):
+         return [on.get_column(table=ds._query.table)]
+     if isinstance(on, Sequence):
+         return [
+             c.get_column(table=ds._query.table) if isinstance(c, Function) else c
+             for c in on
+         ]
+
+
+ def _get_merge_error_str(col: MergeColType) -> str:
+     if isinstance(col, str):
+         return col
+     if isinstance(col, Function):
+         return f"{col.name}()"
+     if isinstance(col, sqlalchemy.Column):
+         return col.name.replace(DEFAULT_DELIMITER, ".")
+     if isinstance(col, sqlalchemy.ColumnElement) and hasattr(col, "name"):
+         return f"{col.name} expression"
+     return str(col)
+
+
+ class DatasetMergeError(DataChainParamsError):
+     def __init__(
+         self,
+         on: Union[MergeColType, Sequence[MergeColType]],
+         right_on: Optional[Union[MergeColType, Sequence[MergeColType]]],
+         msg: str,
+     ):
+         def _get_str(
+             on: Union[MergeColType, Sequence[MergeColType]],
+         ) -> str:
+             if not isinstance(on, Sequence):
+                 return str(on)  # type: ignore[unreachable]
+             return ", ".join([_get_merge_error_str(col) for col in on])
+
+         on_str = _get_str(on)
+         right_on_str = (
+             ", right_on='" + _get_str(right_on) + "'"
+             if right_on and isinstance(right_on, Sequence)
+             else ""
+         )
+         super().__init__(f"Merge error on='{on_str}'{right_on_str}: {msg}")
+
+
+ OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]
+
+
+ class Sys(DataModel):
+     """Model for internal DataChain signals `id` and `rand`."""
+
+     id: int
+     rand: int
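
To illustrate the name mapping that `resolve_columns` performs, here is a small self-contained sketch (not part of the package); it assumes `DEFAULT_DELIMITER` is the double underscore used for flattened nested columns:

```python
DEFAULT_DELIMITER = "__"  # assumed value of datachain.query.schema.DEFAULT_DELIMITER

def to_db_name(user_column: str) -> str:
    # user-facing dot notation -> flat DB column name, e.g. "file.name" -> "file__name"
    return user_column.replace(".", DEFAULT_DELIMITER)

def to_user_name(db_column: str) -> str:
    # flat DB column name -> user-facing dot notation, e.g. "file__size" -> "file.size"
    return db_column.replace(DEFAULT_DELIMITER, ".")

assert to_db_name("file.name") == "file__name"
assert to_user_name("file__size") == "file.size"
```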
@@ -0,0 +1,53 @@
+ from collections.abc import Iterator
+ from typing import (
+     TYPE_CHECKING,
+     Optional,
+ )
+
+ from datachain.lib.convert.values_to_tuples import values_to_tuples
+ from datachain.lib.data_model import dict_to_data_model
+ from datachain.lib.dc.records import from_records
+ from datachain.lib.dc.utils import OutputType
+ from datachain.query import Session
+
+ if TYPE_CHECKING:
+     from typing_extensions import ParamSpec
+
+     from .datachain import DataChain
+
+     P = ParamSpec("P")
+
+
+ def from_values(
+     ds_name: str = "",
+     session: Optional[Session] = None,
+     settings: Optional[dict] = None,
+     in_memory: bool = False,
+     output: OutputType = None,
+     object_name: str = "",
+     **fr_map,
+ ) -> "DataChain":
+     """Generate chain from list of values.
+
+     Example:
+         ```py
+         import datachain as dc
+         dc.from_values(fib=[1, 2, 3, 5, 8])
+         ```
+     """
+     from .datachain import DataChain
+
+     tuple_type, output, tuples = values_to_tuples(ds_name, output, **fr_map)
+
+     def _func_fr() -> Iterator[tuple_type]:  # type: ignore[valid-type]
+         yield from tuples
+
+     chain = from_records(
+         DataChain.DEFAULT_FILE_RECORD,
+         session=session,
+         settings=settings,
+         in_memory=in_memory,
+     )
+     if object_name:
+         output = {object_name: dict_to_data_model(object_name, output)}  # type: ignore[arg-type]
+     return chain.gen(_func_fr, output=output)
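
A slightly fuller hedged sketch of calling `from_values` with several columns (column names are arbitrary; the value lists are assumed to be of equal length):

```python
import datachain as dc

# each keyword argument becomes a column; rows are built positionally
chain = dc.from_values(
    fib=[1, 1, 2, 3, 5, 8],
    idx=[0, 1, 2, 3, 4, 5],
)
print(chain.count())  # 6 rows
```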
@@ -103,12 +103,10 @@ def read_meta( # noqa: C901
      model_name=None,
      nrows=None,
  ) -> Callable:
-     from datachain.lib.dc import DataChain
+     from datachain import from_storage
 
      if schema_from:
-         file = next(
-             DataChain.from_storage(schema_from, type="text").limit(1).collect("file")
-         )
+         file = next(from_storage(schema_from, type="text").limit(1).collect("file"))
          model_code = gen_datamodel_code(
              file, format=format, jmespath=jmespath, model_name=model_name
          )
datachain/lib/pytorch.py CHANGED
@@ -14,7 +14,7 @@ from torchvision.transforms import v2
  from datachain import Session
  from datachain.cache import get_temp_cache
  from datachain.catalog import Catalog, get_catalog
- from datachain.lib.dc import DataChain
+ from datachain.lib.dc.datasets import from_dataset
  from datachain.lib.settings import Settings
  from datachain.lib.text import convert_text
  from datachain.progress import CombinedDownloadCallback
@@ -122,7 +122,7 @@ class PytorchDataset(IterableDataset):
      ) -> Generator[tuple[Any, ...], None, None]:
          catalog = self._get_catalog()
          session = Session("PyTorch", catalog=catalog)
-         ds = DataChain.from_dataset(
+         ds = from_dataset(
              name=self.name, version=self.version, session=session
          ).settings(cache=self.cache, prefetch=self.prefetch)
          ds = ds.remove_file_signals()
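
The user-facing path this internal change supports looks roughly like the sketch below (dataset name and loader settings are placeholders; it assumes `from_dataset` is re-exported at the package top level like `from_storage`, and that the chain exposes `to_pytorch` as in earlier releases):

```python
import datachain as dc
from torch.utils.data import DataLoader

# load a previously saved dataset by name and hand it to PyTorch
chain = dc.from_dataset("fashion-product-images").settings(cache=True, prefetch=2)
loader = DataLoader(chain.to_pytorch(), batch_size=16, num_workers=2)

for batch in loader:
    ...  # training loop goes here
```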
datachain/lib/udf.py CHANGED
@@ -123,10 +123,10 @@ class UDFBase(AbstractUDF):
 
      Example:
          ```py
-         from datachain import C, DataChain, Mapper
+         import datachain as dc
          import open_clip
 
-         class ImageEncoder(Mapper):
+         class ImageEncoder(dc.Mapper):
              def __init__(self, model_name: str, pretrained: str):
                  self.model_name = model_name
                  self.pretrained = pretrained
@@ -145,7 +145,7 @@ class UDFBase(AbstractUDF):
                  return emb[0].tolist()
 
          (
-             DataChain.from_storage(
+             dc.from_storage(
                  "gs://datachain-demo/fashion-product-images/images", type="image"
              )
              .limit(5)
@@ -47,6 +47,7 @@ from datachain.error import (
      QueryScriptCancelError,
  )
  from datachain.func.base import Function
+ from datachain.lib.listing import is_listing_dataset
  from datachain.lib.udf import UDFAdapter, _get_cache
  from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
  from datachain.query.schema import C, UDFParamSpec, normalize_param
@@ -151,13 +152,6 @@ def step_result(
  )
 
 
- class StartingStep(ABC):
-     """An initial query processing step, referencing a data source."""
-
-     @abstractmethod
-     def apply(self) -> "StepResult": ...
-
-
  @frozen
  class Step(ABC):
      """A query processing step (filtering, mutation, etc.)"""
@@ -170,7 +164,7 @@ class Step(ABC):
 
 
  @frozen
- class QueryStep(StartingStep):
+ class QueryStep:
      catalog: "Catalog"
      dataset_name: str
      dataset_version: int
@@ -1097,26 +1091,42 @@ class DatasetQuery:
          self.temp_table_names: list[str] = []
          self.dependencies: set[DatasetDependencyType] = set()
          self.table = self.get_table()
-         self.starting_step: StartingStep
+         self.starting_step: Optional[QueryStep] = None
          self.name: Optional[str] = None
          self.version: Optional[int] = None
          self.feature_schema: Optional[dict] = None
          self.column_types: Optional[dict[str, Any]] = None
+         self.before_steps: list[Callable] = []
 
-         self.name = name
+         self.list_ds_name: Optional[str] = None
 
-         if fallback_to_studio and is_token_set():
-             ds = self.catalog.get_dataset_with_remote_fallback(name, version)
+         self.name = name
+         self.dialect = self.catalog.warehouse.db.dialect
+         if version:
+             self.version = version
+
+         if is_listing_dataset(name):
+             # not setting query step yet as listing dataset might not exist at
+             # this point
+             self.list_ds_name = name
+         elif fallback_to_studio and is_token_set():
+             self._set_starting_step(
+                 self.catalog.get_dataset_with_remote_fallback(name, version)
+             )
          else:
-             ds = self.catalog.get_dataset(name)
+             self._set_starting_step(self.catalog.get_dataset(name))
+
+     def _set_starting_step(self, ds: "DatasetRecord") -> None:
+         if not self.version:
+             self.version = ds.latest_version
 
-         self.version = version or ds.latest_version
+         self.starting_step = QueryStep(self.catalog, ds.name, self.version)
+
+         # at this point we know our starting dataset so setting up schemas
          self.feature_schema = ds.get_version(self.version).feature_schema
          self.column_types = copy(ds.schema)
          if "sys__id" in self.column_types:
              self.column_types.pop("sys__id")
-         self.starting_step = QueryStep(self.catalog, name, self.version)
-         self.dialect = self.catalog.warehouse.db.dialect
 
      def __iter__(self):
          return iter(self.db_results())
@@ -1180,11 +1190,23 @@ class DatasetQuery:
          col.table = self.table
          return col
 
+     def add_before_steps(self, fn: Callable) -> None:
+         """
+         Set a custom function to be run before applying steps.
+         """
+         self.before_steps.append(fn)
+
 
      def apply_steps(self) -> QueryGenerator:
          """
          Apply the steps in the query and return the resulting
          sqlalchemy.SelectBase.
          """
+         for fn in self.before_steps:
+             fn()
+
+         if self.list_ds_name:
+             # at this point we know what is our starting listing dataset name
+             self._set_starting_step(self.catalog.get_dataset(self.list_ds_name))  # type: ignore [arg-type]
          query = self.clone()
 
          index = os.getenv("DATACHAIN_QUERY_CHUNK_INDEX", self._chunk_index)
@@ -1203,6 +1225,7 @@ class DatasetQuery:
              query = query.filter(C.sys__rand % total == index)
              query.steps = query.steps[-1:] + query.steps[:-1]
 
+         assert query.starting_step
          result = query.starting_step.apply()
          self.dependencies.update(result.dependencies)
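
An illustrative toy sketch (not the package's code) of the deferred-execution pattern introduced by `add_before_steps`: callables registered on the query run only when `apply_steps` is invoked, so a listing dataset can be created lazily right before it is needed:

```python
from typing import Callable

class ToyQuery:
    """Toy stand-in for DatasetQuery that demonstrates the before-steps hook."""

    def __init__(self) -> None:
        self.before_steps: list[Callable[[], None]] = []

    def add_before_steps(self, fn: Callable[[], None]) -> None:
        # queue deferred work, e.g. materializing a listing dataset
        self.before_steps.append(fn)

    def apply_steps(self) -> None:
        # hooks run first, so whatever they create exists before steps are applied
        for fn in self.before_steps:
            fn()
        print("applying steps")

q = ToyQuery()
q.add_before_steps(lambda: print("listing bucket lazily"))
q.apply_steps()  # -> "listing bucket lazily", then "applying steps"
```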
 
@@ -37,11 +37,11 @@ def train_test_split(
      Examples:
          Train-test split:
          ```python
-         from datachain import DataChain
+         import datachain as dc
          from datachain.toolkit import train_test_split
 
          # Load a DataChain from a storage source (e.g., S3 bucket)
-         dc = DataChain.from_storage("s3://bucket/dir/")
+         dc = dc.from_storage("s3://bucket/dir/")
 
          # Perform a 70/30 train-test split
          train, test = train_test_split(dc, [0.7, 0.3])
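
For completeness, a hedged sketch of a three-way split with the same helper (weights are illustrative and assumed to be treated as relative proportions, with any number of splits allowed):

```python
import datachain as dc
from datachain.toolkit import train_test_split

chain = dc.from_storage("s3://bucket/dir/")

# 70/20/10 train/validation/test split (assumes the helper accepts any number of weights)
train, val, test = train_test_split(chain, [0.7, 0.2, 0.1])

train.save("train")
val.save("val")
test.save("test")
```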
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: datachain
- Version: 0.13.1
+ Version: 0.14.1
  Summary: Wrangle unstructured AI data at scale
  Author-email: Dmitry Petrov <support@dvc.org>
- License: Apache-2.0
+ License-Expression: Apache-2.0
  Project-URL: Documentation, https://datachain.dvc.ai
  Project-URL: Issues, https://github.com/iterative/datachain/issues
  Project-URL: Source, https://github.com/iterative/datachain
@@ -169,16 +169,16 @@ high confidence scores.
 
  .. code:: py
 
-     from datachain import Column, DataChain
+     import datachain as dc
 
-     meta = DataChain.from_json("gs://datachain-demo/dogs-and-cats/*json", object_name="meta", anon=True)
-     images = DataChain.from_storage("gs://datachain-demo/dogs-and-cats/*jpg", anon=True)
+     meta = dc.from_json("gs://datachain-demo/dogs-and-cats/*json", object_name="meta", anon=True)
+     images = dc.from_storage("gs://datachain-demo/dogs-and-cats/*jpg", anon=True)
 
      images_id = images.map(id=lambda file: file.path.split('.')[-2])
      annotated = images_id.merge(meta, on="id", right_on="meta.id")
 
-     likely_cats = annotated.filter((Column("meta.inference.confidence") > 0.93) \
-                                    & (Column("meta.inference.class_") == "cat"))
+     likely_cats = annotated.filter((dc.Column("meta.inference.confidence") > 0.93) \
+                                    & (dc.Column("meta.inference.class_") == "cat"))
      likely_cats.to_storage("high-confidence-cats/", signal="file")
 
 
@@ -199,11 +199,11 @@ Python code:
 
      import os
      from mistralai import Mistral
-     from datachain import File, DataChain, Column
+     import datachain as dc
 
      PROMPT = "Was this dialog successful? Answer in a single word: Success or Failure."
 
-     def eval_dialogue(file: File) -> bool:
+     def eval_dialogue(file: dc.File) -> bool:
          client = Mistral(api_key = os.environ["MISTRAL_API_KEY"])
          response = client.chat.complete(
              model="open-mixtral-8x22b",
@@ -213,13 +213,13 @@ Python code:
          return result.lower().startswith("success")
 
      chain = (
-         DataChain.from_storage("gs://datachain-demo/chatbot-KiT/", object_name="file", anon=True)
+         dc.from_storage("gs://datachain-demo/chatbot-KiT/", object_name="file", anon=True)
          .settings(parallel=4, cache=True)
          .map(is_success=eval_dialogue)
          .save("mistral_files")
      )
 
-     successful_chain = chain.filter(Column("is_success") == True)
+     successful_chain = chain.filter(dc.Column("is_success") == True)
      successful_chain.to_storage("./output_mistral")
 
      print(f"{successful_chain.count()} files were exported")