datachain 0.13.1__py3-none-any.whl → 0.14.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain might be problematic.
- datachain/__init__.py +28 -1
- datachain/catalog/catalog.py +6 -10
- datachain/cli/commands/ls.py +2 -2
- datachain/cli/commands/show.py +2 -3
- datachain/client/fsspec.py +3 -3
- datachain/lib/dc/__init__.py +32 -0
- datachain/lib/dc/csv.py +127 -0
- datachain/lib/{dc.py → dc/datachain.py} +144 -733
- datachain/lib/dc/datasets.py +149 -0
- datachain/lib/dc/hf.py +73 -0
- datachain/lib/dc/json.py +91 -0
- datachain/lib/dc/listings.py +43 -0
- datachain/lib/dc/pandas.py +56 -0
- datachain/lib/dc/parquet.py +65 -0
- datachain/lib/dc/records.py +90 -0
- datachain/lib/dc/storage.py +170 -0
- datachain/lib/dc/utils.py +128 -0
- datachain/lib/dc/values.py +53 -0
- datachain/lib/meta_formats.py +2 -4
- datachain/lib/pytorch.py +2 -2
- datachain/lib/udf.py +3 -3
- datachain/query/dataset.py +39 -16
- datachain/toolkit/split.py +2 -2
- {datachain-0.13.1.dist-info → datachain-0.14.1.dist-info}/METADATA +11 -11
- {datachain-0.13.1.dist-info → datachain-0.14.1.dist-info}/RECORD +29 -17
- {datachain-0.13.1.dist-info → datachain-0.14.1.dist-info}/WHEEL +1 -1
- {datachain-0.13.1.dist-info → datachain-0.14.1.dist-info}/entry_points.txt +0 -0
- {datachain-0.13.1.dist-info → datachain-0.14.1.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.13.1.dist-info → datachain-0.14.1.dist-info}/top_level.txt +0 -0
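The headline change in 0.14.1 is that `datachain/lib/dc.py` is split into the `datachain/lib/dc/` package, with the chain constructors (`from_storage`, `from_values`, `from_dataset`, `from_records`, ...) exposed as module-level functions rather than `DataChain` classmethods. A rough usage sketch, based only on the docstrings in the new modules shown below:

```python
import datachain as dc

# Module-level constructors replace the old DataChain.from_* classmethods
# (examples taken from the docstrings in datachain/lib/dc/storage.py and values.py).
chain = dc.from_storage("s3://my-bucket/my-dir")
fib = dc.from_values(fib=[1, 2, 3, 5, 8])
```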
datachain/lib/dc/storage.py ADDED

@@ -0,0 +1,170 @@
+import os.path
+from typing import (
+    TYPE_CHECKING,
+    Optional,
+    Union,
+)
+
+from datachain.lib.file import (
+    FileType,
+    get_file_type,
+)
+from datachain.lib.listing import (
+    get_file_info,
+    get_listing,
+    list_bucket,
+    ls,
+)
+from datachain.query import Session
+
+if TYPE_CHECKING:
+    from .datachain import DataChain
+
+
+def from_storage(
+    uri: Union[str, os.PathLike[str], list[str], list[os.PathLike[str]]],
+    *,
+    type: FileType = "binary",
+    session: Optional[Session] = None,
+    settings: Optional[dict] = None,
+    in_memory: bool = False,
+    recursive: Optional[bool] = True,
+    object_name: str = "file",
+    update: bool = False,
+    anon: bool = False,
+    client_config: Optional[dict] = None,
+) -> "DataChain":
+    """Get data from storage(s) as a list of file with all file attributes.
+    It returns the chain itself as usual.
+
+    Parameters:
+        uri : storage URI with directory or list of URIs.
+            URIs must start with storage prefix such
+            as `s3://`, `gs://`, `az://` or "file:///"
+        type : read file as "binary", "text", or "image" data. Default is "binary".
+        recursive : search recursively for the given path.
+        object_name : Created object column name.
+        update : force storage reindexing. Default is False.
+        anon : If True, we will treat cloud bucket as public one
+        client_config : Optional client configuration for the storage client.
+
+    Returns:
+        DataChain: A DataChain object containing the file information.
+
+    Examples:
+        Simple call from s3:
+        ```python
+        import datachain as dc
+        chain = dc.from_storage("s3://my-bucket/my-dir")
+        ```
+
+        Multiple URIs:
+        ```python
+        chain = dc.from_storage([
+            "s3://bucket1/dir1",
+            "s3://bucket2/dir2"
+        ])
+        ```
+
+        With AWS S3-compatible storage:
+        ```python
+        chain = dc.from_storage(
+            "s3://my-bucket/my-dir",
+            client_config = {"aws_endpoint_url": "<minio-endpoint-url>"}
+        )
+        ```
+
+        Pass existing session
+        ```py
+        session = Session.get()
+        chain = dc.from_storage([
+            "path/to/dir1",
+            "path/to/dir2"
+        ], session=session, recursive=True)
+        ```
+
+    Note:
+        When using multiple URIs with `update=True`, the function optimizes by
+        avoiding redundant updates for URIs pointing to the same storage location.
+    """
+    from .datachain import DataChain
+    from .datasets import from_dataset
+    from .records import from_records
+    from .values import from_values
+
+    file_type = get_file_type(type)
+
+    if anon:
+        client_config = (client_config or {}) | {"anon": True}
+    session = Session.get(session, client_config=client_config, in_memory=in_memory)
+    cache = session.catalog.cache
+    client_config = session.catalog.client_config
+
+    uris = uri if isinstance(uri, (list, tuple)) else [uri]
+
+    if not uris:
+        raise ValueError("No URIs provided")
+
+    storage_chain = None
+    listed_ds_name = set()
+    file_values = []
+
+    for single_uri in uris:
+        list_ds_name, list_uri, list_path, list_ds_exists = get_listing(
+            single_uri, session, update=update
+        )
+
+        # list_ds_name is None if object is a file, we don't want to use cache
+        # or do listing in that case - just read that single object
+        if not list_ds_name:
+            file_values.append(
+                get_file_info(list_uri, cache, client_config=client_config)
+            )
+            continue
+
+        dc = from_dataset(list_ds_name, session=session, settings=settings)
+        dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
+
+        if update or not list_ds_exists:
+
+            def lst_fn(ds_name, lst_uri):
+                # disable prefetch for listing, as it pre-downloads all files
+                (
+                    from_records(
+                        DataChain.DEFAULT_FILE_RECORD,
+                        session=session,
+                        settings=settings,
+                        in_memory=in_memory,
+                    )
+                    .settings(prefetch=0)
+                    .gen(
+                        list_bucket(lst_uri, cache, client_config=client_config),
+                        output={f"{object_name}": file_type},
+                    )
+                    .save(ds_name, listing=True)
+                )
+
+            dc._query.add_before_steps(
+                lambda ds_name=list_ds_name, lst_uri=list_uri: lst_fn(ds_name, lst_uri)
+            )
+
+        chain = ls(dc, list_path, recursive=recursive, object_name=object_name)
+
+        storage_chain = storage_chain.union(chain) if storage_chain else chain
+        listed_ds_name.add(list_ds_name)
+
+    if file_values:
+        file_chain = from_values(
+            session=session,
+            settings=settings,
+            in_memory=in_memory,
+            file=file_values,
+        )
+        file_chain.signals_schema = file_chain.signals_schema.mutate(
+            {f"{object_name}": file_type}
+        )
+        storage_chain = storage_chain.union(file_chain) if storage_chain else file_chain
+
+    assert storage_chain is not None
+
+    return storage_chain
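As the loop above shows, each URI is either listed (directories, reusing a cached listing dataset unless `update=True`) or read directly as a single file, and the per-URI chains are combined with `union`. A hedged sketch of what that implies for callers (the paths here are illustrative):

```python
import datachain as dc

# Directories are expanded via a (possibly cached) listing dataset, while a URI
# that points to a single object is read directly; everything ends up in one chain.
chain = dc.from_storage(
    [
        "s3://my-bucket/images/",      # listed (recursive=True by default)
        "s3://my-bucket/labels.json",  # single file, no listing step
    ],
    update=False,  # set True to force re-listing of the directories
)
```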
datachain/lib/dc/utils.py ADDED

@@ -0,0 +1,128 @@
+from collections.abc import Sequence
+from functools import wraps
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    Optional,
+    TypeVar,
+    Union,
+)
+
+import sqlalchemy
+from sqlalchemy.sql.functions import GenericFunction
+
+from datachain.func.base import Function
+from datachain.lib.data_model import DataModel, DataType
+from datachain.lib.utils import DataChainParamsError
+from datachain.query.schema import DEFAULT_DELIMITER
+
+if TYPE_CHECKING:
+    from typing_extensions import Concatenate, ParamSpec
+
+    from .datachain import DataChain
+
+    P = ParamSpec("P")
+
+D = TypeVar("D", bound="DataChain")
+
+
+def resolve_columns(
+    method: "Callable[Concatenate[D, P], D]",
+) -> "Callable[Concatenate[D, P], D]":
+    """Decorator that resolvs input column names to their actual DB names. This is
+    specially important for nested columns as user works with them by using dot
+    notation e.g (file.name) but are actually defined with default delimiter
+    in DB, e.g file__name.
+    If there are any sql functions in arguments, they will just be transferred as is
+    to a method.
+    """
+
+    @wraps(method)
+    def _inner(self: D, *args: "P.args", **kwargs: "P.kwargs") -> D:
+        resolved_args = self.signals_schema.resolve(
+            *[arg for arg in args if not isinstance(arg, GenericFunction)]  # type: ignore[arg-type]
+        ).db_signals()
+
+        for idx, arg in enumerate(args):
+            if isinstance(arg, GenericFunction):
+                resolved_args.insert(idx, arg)  # type: ignore[arg-type]
+
+        return method(self, *resolved_args, **kwargs)
+
+    return _inner
+
+
+class DatasetPrepareError(DataChainParamsError):
+    def __init__(self, name, msg, output=None):
+        name = f" '{name}'" if name else ""
+        output = f" output '{output}'" if output else ""
+        super().__init__(f"Dataset{name}{output} processing prepare error: {msg}")
+
+
+class DatasetFromValuesError(DataChainParamsError):
+    def __init__(self, name, msg):
+        name = f" '{name}'" if name else ""
+        super().__init__(f"Dataset{name} from values error: {msg}")
+
+
+MergeColType = Union[str, Function, sqlalchemy.ColumnElement]
+
+
+def _validate_merge_on(
+    on: Union[MergeColType, Sequence[MergeColType]],
+    ds: "DataChain",
+) -> Sequence[MergeColType]:
+    if isinstance(on, (str, sqlalchemy.ColumnElement)):
+        return [on]
+    if isinstance(on, Function):
+        return [on.get_column(table=ds._query.table)]
+    if isinstance(on, Sequence):
+        return [
+            c.get_column(table=ds._query.table) if isinstance(c, Function) else c
+            for c in on
+        ]
+
+
+def _get_merge_error_str(col: MergeColType) -> str:
+    if isinstance(col, str):
+        return col
+    if isinstance(col, Function):
+        return f"{col.name}()"
+    if isinstance(col, sqlalchemy.Column):
+        return col.name.replace(DEFAULT_DELIMITER, ".")
+    if isinstance(col, sqlalchemy.ColumnElement) and hasattr(col, "name"):
+        return f"{col.name} expression"
+    return str(col)
+
+
+class DatasetMergeError(DataChainParamsError):
+    def __init__(
+        self,
+        on: Union[MergeColType, Sequence[MergeColType]],
+        right_on: Optional[Union[MergeColType, Sequence[MergeColType]]],
+        msg: str,
+    ):
+        def _get_str(
+            on: Union[MergeColType, Sequence[MergeColType]],
+        ) -> str:
+            if not isinstance(on, Sequence):
+                return str(on)  # type: ignore[unreachable]
+            return ", ".join([_get_merge_error_str(col) for col in on])
+
+        on_str = _get_str(on)
+        right_on_str = (
+            ", right_on='" + _get_str(right_on) + "'"
+            if right_on and isinstance(right_on, Sequence)
+            else ""
+        )
+        super().__init__(f"Merge error on='{on_str}'{right_on_str}: {msg}")
+
+
+OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]
+
+
+class Sys(DataModel):
+    """Model for internal DataChain signals `id` and `rand`."""
+
+    id: int
+    rand: int
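The `resolve_columns` decorator above maps user-facing dot notation to the flattened DB column names built with `DEFAULT_DELIMITER`. A minimal standalone sketch of that mapping (a toy, not datachain's implementation; it assumes the delimiter is `__`, matching the `file__name` example in the docstring):

```python
DEFAULT_DELIMITER = "__"  # assumed value, matching the file__name example above

def to_db_name(user_column: str) -> str:
    # "file.name" -> "file__name": nested fields are flattened in the DB schema.
    return user_column.replace(".", DEFAULT_DELIMITER)

assert to_db_name("file.name") == "file__name"
assert to_db_name("meta.inference.class_") == "meta__inference__class_"
```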
datachain/lib/dc/values.py ADDED

@@ -0,0 +1,53 @@
+from collections.abc import Iterator
+from typing import (
+    TYPE_CHECKING,
+    Optional,
+)
+
+from datachain.lib.convert.values_to_tuples import values_to_tuples
+from datachain.lib.data_model import dict_to_data_model
+from datachain.lib.dc.records import from_records
+from datachain.lib.dc.utils import OutputType
+from datachain.query import Session
+
+if TYPE_CHECKING:
+    from typing_extensions import ParamSpec
+
+    from .datachain import DataChain
+
+    P = ParamSpec("P")
+
+
+def from_values(
+    ds_name: str = "",
+    session: Optional[Session] = None,
+    settings: Optional[dict] = None,
+    in_memory: bool = False,
+    output: OutputType = None,
+    object_name: str = "",
+    **fr_map,
+) -> "DataChain":
+    """Generate chain from list of values.
+
+    Example:
+        ```py
+        import datachain as dc
+        dc.from_values(fib=[1, 2, 3, 5, 8])
+        ```
+    """
+    from .datachain import DataChain
+
+    tuple_type, output, tuples = values_to_tuples(ds_name, output, **fr_map)
+
+    def _func_fr() -> Iterator[tuple_type]:  # type: ignore[valid-type]
+        yield from tuples
+
+    chain = from_records(
+        DataChain.DEFAULT_FILE_RECORD,
+        session=session,
+        settings=settings,
+        in_memory=in_memory,
+    )
+    if object_name:
+        output = {object_name: dict_to_data_model(object_name, output)}  # type: ignore[arg-type]
+    return chain.gen(_func_fr, output=output)
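Because `from_values` forwards arbitrary keyword arguments (`**fr_map`) as columns, several same-length lists can be passed together. A small hedged example (column names here are made up; the multi-column zipping is inferred from `values_to_tuples`):

```python
import datachain as dc

# Each keyword becomes a signal; values_to_tuples() turns the lists into rows.
chain = dc.from_values(
    name=["cat.jpg", "dog.jpg"],
    size=[1024, 2048],
)
```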
datachain/lib/meta_formats.py CHANGED

@@ -103,12 +103,10 @@ def read_meta(  # noqa: C901
     model_name=None,
     nrows=None,
 ) -> Callable:
-    from datachain
+    from datachain import from_storage
 
     if schema_from:
-        file = next(
-            DataChain.from_storage(schema_from, type="text").limit(1).collect("file")
-        )
+        file = next(from_storage(schema_from, type="text").limit(1).collect("file"))
         model_code = gen_datamodel_code(
             file, format=format, jmespath=jmespath, model_name=model_name
         )
datachain/lib/pytorch.py CHANGED

@@ -14,7 +14,7 @@ from torchvision.transforms import v2
 from datachain import Session
 from datachain.cache import get_temp_cache
 from datachain.catalog import Catalog, get_catalog
-from datachain.lib.dc import
+from datachain.lib.dc.datasets import from_dataset
 from datachain.lib.settings import Settings
 from datachain.lib.text import convert_text
 from datachain.progress import CombinedDownloadCallback

@@ -122,7 +122,7 @@ class PytorchDataset(IterableDataset):
     ) -> Generator[tuple[Any, ...], None, None]:
         catalog = self._get_catalog()
         session = Session("PyTorch", catalog=catalog)
-        ds =
+        ds = from_dataset(
             name=self.name, version=self.version, session=session
         ).settings(cache=self.cache, prefetch=self.prefetch)
         ds = ds.remove_file_signals()
datachain/lib/udf.py CHANGED

@@ -123,10 +123,10 @@ class UDFBase(AbstractUDF):
 
     Example:
         ```py
-
+        import datachain as dc
         import open_clip
 
-        class ImageEncoder(Mapper):
+        class ImageEncoder(dc.Mapper):
            def __init__(self, model_name: str, pretrained: str):
                self.model_name = model_name
                self.pretrained = pretrained

@@ -145,7 +145,7 @@ class UDFBase(AbstractUDF):
            return emb[0].tolist()
 
        (
-
+           dc.from_storage(
                "gs://datachain-demo/fashion-product-images/images", type="image"
            )
            .limit(5)
datachain/query/dataset.py CHANGED

@@ -47,6 +47,7 @@ from datachain.error import (
     QueryScriptCancelError,
 )
 from datachain.func.base import Function
+from datachain.lib.listing import is_listing_dataset
 from datachain.lib.udf import UDFAdapter, _get_cache
 from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
 from datachain.query.schema import C, UDFParamSpec, normalize_param

@@ -151,13 +152,6 @@ def step_result(
     )
 
 
-class StartingStep(ABC):
-    """An initial query processing step, referencing a data source."""
-
-    @abstractmethod
-    def apply(self) -> "StepResult": ...
-
-
 @frozen
 class Step(ABC):
     """A query processing step (filtering, mutation, etc.)"""

@@ -170,7 +164,7 @@ class Step(ABC):
 
 
 @frozen
-class QueryStep
+class QueryStep:
     catalog: "Catalog"
     dataset_name: str
     dataset_version: int

@@ -1097,26 +1091,42 @@ class DatasetQuery:
         self.temp_table_names: list[str] = []
         self.dependencies: set[DatasetDependencyType] = set()
         self.table = self.get_table()
-        self.starting_step:
+        self.starting_step: Optional[QueryStep] = None
         self.name: Optional[str] = None
         self.version: Optional[int] = None
         self.feature_schema: Optional[dict] = None
         self.column_types: Optional[dict[str, Any]] = None
+        self.before_steps: list[Callable] = []
 
-        self.
+        self.list_ds_name: Optional[str] = None
 
-
-
+        self.name = name
+        self.dialect = self.catalog.warehouse.db.dialect
+        if version:
+            self.version = version
+
+        if is_listing_dataset(name):
+            # not setting query step yet as listing dataset might not exist at
+            # this point
+            self.list_ds_name = name
+        elif fallback_to_studio and is_token_set():
+            self._set_starting_step(
+                self.catalog.get_dataset_with_remote_fallback(name, version)
+            )
         else:
-
+            self._set_starting_step(self.catalog.get_dataset(name))
+
+    def _set_starting_step(self, ds: "DatasetRecord") -> None:
+        if not self.version:
+            self.version = ds.latest_version
 
-        self.
+        self.starting_step = QueryStep(self.catalog, ds.name, self.version)
+
+        # at this point we know our starting dataset so setting up schemas
         self.feature_schema = ds.get_version(self.version).feature_schema
         self.column_types = copy(ds.schema)
         if "sys__id" in self.column_types:
             self.column_types.pop("sys__id")
-        self.starting_step = QueryStep(self.catalog, name, self.version)
-        self.dialect = self.catalog.warehouse.db.dialect
 
     def __iter__(self):
         return iter(self.db_results())

@@ -1180,11 +1190,23 @@ class DatasetQuery:
             col.table = self.table
         return col
 
+    def add_before_steps(self, fn: Callable) -> None:
+        """
+        Setting custom function to be run before applying steps
+        """
+        self.before_steps.append(fn)
+
     def apply_steps(self) -> QueryGenerator:
         """
         Apply the steps in the query and return the resulting
         sqlalchemy.SelectBase.
         """
+        for fn in self.before_steps:
+            fn()
+
+        if self.list_ds_name:
+            # at this point we know what is our starting listing dataset name
+            self._set_starting_step(self.catalog.get_dataset(self.list_ds_name))  # type: ignore [arg-type]
         query = self.clone()
 
         index = os.getenv("DATACHAIN_QUERY_CHUNK_INDEX", self._chunk_index)

@@ -1203,6 +1225,7 @@ class DatasetQuery:
         query = query.filter(C.sys__rand % total == index)
         query.steps = query.steps[-1:] + query.steps[:-1]
 
+        assert query.starting_step
         result = query.starting_step.apply()
         self.dependencies.update(result.dependencies)
 
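These `DatasetQuery` changes back the lazy listing used by `from_storage` above: listing work is registered with `add_before_steps`, and only when `apply_steps` runs are those callables executed and the listing dataset resolved into `starting_step`. A toy illustration of that hook pattern (not datachain's actual class):

```python
from typing import Callable

class ToyQuery:
    """Toy model of the before-steps hook added in this release."""

    def __init__(self) -> None:
        self.before_steps: list[Callable[[], None]] = []

    def add_before_steps(self, fn: Callable[[], None]) -> None:
        # Register deferred work, e.g. materializing a bucket listing.
        self.before_steps.append(fn)

    def apply_steps(self) -> None:
        # Deferred work runs only when the query is actually executed.
        for fn in self.before_steps:
            fn()

q = ToyQuery()
q.add_before_steps(lambda: print("listing s3://bucket/dir ..."))
q.apply_steps()  # prints: listing s3://bucket/dir ...
```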
datachain/toolkit/split.py CHANGED

@@ -37,11 +37,11 @@ def train_test_split(
     Examples:
         Train-test split:
         ```python
-
+        import datachain as dc
         from datachain.toolkit import train_test_split
 
         # Load a DataChain from a storage source (e.g., S3 bucket)
-        dc =
+        dc = dc.from_storage("s3://bucket/dir/")
 
         # Perform a 70/30 train-test split
         train, test = train_test_split(dc, [0.7, 0.3])
{datachain-0.13.1.dist-info → datachain-0.14.1.dist-info}/METADATA CHANGED

@@ -1,9 +1,9 @@
 Metadata-Version: 2.4
 Name: datachain
-Version: 0.
+Version: 0.14.1
 Summary: Wrangle unstructured AI data at scale
 Author-email: Dmitry Petrov <support@dvc.org>
-License: Apache-2.0
+License-Expression: Apache-2.0
 Project-URL: Documentation, https://datachain.dvc.ai
 Project-URL: Issues, https://github.com/iterative/datachain/issues
 Project-URL: Source, https://github.com/iterative/datachain

@@ -169,16 +169,16 @@ high confidence scores.
 
 .. code:: py
 
-
+    import datachain as dc
 
-    meta =
-    images =
+    meta = dc.from_json("gs://datachain-demo/dogs-and-cats/*json", object_name="meta", anon=True)
+    images = dc.from_storage("gs://datachain-demo/dogs-and-cats/*jpg", anon=True)
 
     images_id = images.map(id=lambda file: file.path.split('.')[-2])
     annotated = images_id.merge(meta, on="id", right_on="meta.id")
 
-    likely_cats = annotated.filter((Column("meta.inference.confidence") > 0.93) \
-                                   & (Column("meta.inference.class_") == "cat"))
+    likely_cats = annotated.filter((dc.Column("meta.inference.confidence") > 0.93) \
+                                   & (dc.Column("meta.inference.class_") == "cat"))
     likely_cats.to_storage("high-confidence-cats/", signal="file")
 
 

@@ -199,11 +199,11 @@ Python code:
 
     import os
     from mistralai import Mistral
-
+    import datachain as dc
 
     PROMPT = "Was this dialog successful? Answer in a single word: Success or Failure."
 
-    def eval_dialogue(file: File) -> bool:
+    def eval_dialogue(file: dc.File) -> bool:
        client = Mistral(api_key = os.environ["MISTRAL_API_KEY"])
        response = client.chat.complete(
            model="open-mixtral-8x22b",

@@ -213,13 +213,13 @@ Python code:
        return result.lower().startswith("success")
 
    chain = (
-
+       dc.from_storage("gs://datachain-demo/chatbot-KiT/", object_name="file", anon=True)
        .settings(parallel=4, cache=True)
        .map(is_success=eval_dialogue)
        .save("mistral_files")
    )
 
-    successful_chain = chain.filter(Column("is_success") == True)
+    successful_chain = chain.filter(dc.Column("is_success") == True)
     successful_chain.to_storage("./output_mistral")
 
     print(f"{successful_chain.count()} files were exported")