datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/lib/dc/storage.py
CHANGED
@@ -1,20 +1,17 @@
-import os
-from …
-…
-…
-    Union,
-)
+import os
+from collections.abc import Sequence
+from functools import reduce
+from typing import TYPE_CHECKING
 
-from datachain.lib.…
-…
-…
-…
-…
-…
-    get_listing,
-    list_bucket,
-    ls,
+from datachain.lib.dc.storage_pattern import (
+    apply_glob_filter,
+    expand_brace_pattern,
+    should_use_recursion,
+    split_uri_pattern,
+    validate_cloud_bucket_name,
 )
+from datachain.lib.file import FileType, get_file_type
+from datachain.lib.listing import get_file_info, get_listing, list_bucket, ls
 from datachain.query import Session
 
 if TYPE_CHECKING:
@@ -22,31 +19,68 @@ if TYPE_CHECKING:
 
 
 def read_storage(
-    uri: …
+    uri: str | os.PathLike[str] | list[str] | list[os.PathLike[str]],
     *,
     type: FileType = "binary",
-    session: …
-    settings: …
+    session: Session | None = None,
+    settings: dict | None = None,
     in_memory: bool = False,
-    recursive: …
-    …
+    recursive: bool | None = True,
+    column: str = "file",
     update: bool = False,
-    anon: bool = …
-    …
+    anon: bool | None = None,
+    delta: bool | None = False,
+    delta_on: str | Sequence[str] | None = (
+        "file.path",
+        "file.etag",
+        "file.version",
+    ),
+    delta_result_on: str | Sequence[str] | None = None,
+    delta_compare: str | Sequence[str] | None = None,
+    delta_retry: bool | str | None = None,
+    delta_unsafe: bool = False,
+    client_config: dict | None = None,
 ) -> "DataChain":
     """Get data from storage(s) as a list of files with all file attributes.
     It returns the chain itself as usual.
 
     Parameters:
-        uri…
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
+        uri: Storage path(s) or URI(s). Can be a local path or start with a
+            storage prefix like `s3://`, `gs://`, `az://`, `hf://` or "file:///".
+            Supports glob patterns:
+            - `*` : wildcard
+            - `**` : recursive wildcard
+            - `?` : single character
+            - `{a,b}` : brace expansion list
+            - `{1..9}` : brace numeric or alphabetic range
+        type: read file as "binary", "text", or "image" data. Default is "binary".
+        recursive: search recursively for the given path.
+        column: Column name that will contain File objects. Default is "file".
+        update: force storage reindexing. Default is False.
+        anon: If True, the cloud bucket is treated as a public one.
+        client_config: Optional client configuration for the storage client.
+        delta: If True, only process new or changed files instead of reprocessing
+            everything. This saves time by skipping files that were already
+            processed in previous versions. The optimization applies when a new
+            version of the dataset is created. Default is False.
+        delta_on: Field(s) that uniquely identify each record in the source data.
+            Used to detect which records are new or changed.
+            Default is ("file.path", "file.etag", "file.version").
+        delta_result_on: Field(s) in the result dataset that match `delta_on`
+            fields. Only needed if you rename the identifying fields during
+            processing. Default is None.
+        delta_compare: Field(s) used to detect if a record has changed.
+            If not specified, all fields except `delta_on` fields are used.
+            Default is None.
+        delta_retry: Controls retry behavior for failed records:
+            - String (field name): Reprocess records where this field is not
+              empty (error mode)
+            - True: Reprocess records missing from the result dataset
+              (missing mode)
+            - None: No retry processing (default)
+        delta_unsafe: Allow restricted ops in delta: merge, agg, union, group_by,
+            distinct. Caller must ensure datasets are consistent and not
+            partially updated.
 
     Returns:
         DataChain: A DataChain object containing the file information.
@@ -55,37 +89,36 @@ def read_storage(
         Simple call from s3:
        ```python
        import datachain as dc
-        …
+        dc.read_storage("s3://my-bucket/my-dir")
+        ```
+
+        Match all .json files recursively using a glob pattern:
+        ```py
+        dc.read_storage("gs://bucket/meta/**/*.json")
+        ```
+
+        Match image file extensions for directories with a pattern:
+        ```py
+        dc.read_storage("s3://bucket/202?/**/*.{jpg,jpeg,png}")
+        ```
+
+        By ranges in filenames:
+        ```py
+        dc.read_storage("s3://bucket/202{1..4}/**/*.{jpg,jpeg,png}")
        ```
 
        Multiple URIs:
        ```python
-        …
-            "s3://bucket1/dir1",
-            "s3://bucket2/dir2"
-        ])
+        dc.read_storage(["s3://my-bkt/dir1", "s3://bucket2/dir2/dir3"])
        ```
 
        With AWS S3-compatible storage:
        ```python
-        …
+        dc.read_storage(
            "s3://my-bucket/my-dir",
            client_config = {"aws_endpoint_url": "<minio-endpoint-url>"}
        )
        ```
-
-        Pass existing session:
-        ```py
-        session = Session.get()
-        chain = dc.read_storage([
-            "path/to/dir1",
-            "path/to/dir2"
-        ], session=session, recursive=True)
-        ```
-
-    Note:
-        When using multiple URIs with `update=True`, the function optimizes by
-        avoiding redundant updates for URIs pointing to the same storage location.
     """
     from .datachain import DataChain
     from .datasets import read_dataset
@@ -94,24 +127,50 @@ def read_storage(
 
     file_type = get_file_type(type)
 
-    if anon:
-        client_config = (client_config or {}) | {"anon": …
+    if anon is not None:
+        client_config = (client_config or {}) | {"anon": anon}
     session = Session.get(session, client_config=client_config, in_memory=in_memory)
-    …
+    catalog = session.catalog
+    cache = catalog.cache
     client_config = session.catalog.client_config
+    listing_namespace_name = catalog.metastore.system_namespace_name
+    listing_project_name = catalog.metastore.listing_project_name
 
     uris = uri if isinstance(uri, (list, tuple)) else [uri]
 
     if not uris:
         raise ValueError("No URIs provided")
 
-    …
+    # Then expand all URIs that contain brace patterns
+    expanded_uris = []
+    for single_uri in uris:
+        uri_str = str(single_uri)
+        validate_cloud_bucket_name(uri_str)
+        expanded_uris.extend(expand_brace_pattern(uri_str))
+
+    # Now process each expanded URI
+    chains = []
     listed_ds_name = set()
     file_values = []
 
-    …
+    updated_uris = set()
+
+    for single_uri in expanded_uris:
+        # Check if URI contains glob patterns and split them
+        base_uri, glob_pattern = split_uri_pattern(single_uri)
+
+        # If a pattern is found, use the base_uri for listing
+        # The pattern will be used for filtering later
+        list_uri_to_use = base_uri if glob_pattern else single_uri
+
+        # Avoid double updates for the same URI
+        update_single_uri = False
+        if update and (list_uri_to_use not in updated_uris):
+            updated_uris.add(list_uri_to_use)
+            update_single_uri = True
+
         list_ds_name, list_uri, list_path, list_ds_exists = get_listing(
-            …
+            list_uri_to_use, session, update=update_single_uri
         )
 
         # list_ds_name is None if object is a file, we don't want to use cache
@@ -122,9 +181,21 @@ def read_storage(
             )
             continue
 
-        dc = read_dataset(…
+        dc = read_dataset(
+            list_ds_name,
+            namespace=listing_namespace_name,
+            project=listing_project_name,
+            session=session,
+            settings=settings,
+            delta=delta,
+            delta_on=delta_on,
+            delta_result_on=delta_result_on,
+            delta_compare=delta_compare,
+            delta_retry=delta_retry,
+            delta_unsafe=delta_unsafe,
+        )
         dc._query.update = update
-        dc.signals_schema = dc.signals_schema.mutate({f"{…
+        dc.signals_schema = dc.signals_schema.mutate({f"{column}": file_type})
 
         if update or not list_ds_exists:
 
@@ -137,23 +208,42 @@ def read_storage(
                     settings=settings,
                     in_memory=in_memory,
                 )
-                .settings(…
+                .settings(
+                    prefetch=0,
+                    namespace=listing_namespace_name,
+                    project=listing_project_name,
+                )
                 .gen(
                     list_bucket(lst_uri, cache, client_config=client_config),
-                    output={f"{…
+                    output={f"{column}": file_type},
                 )
-                …
+                # for internal listing datasets, we always bump major version
+                .save(ds_name, listing=True, update_version="major")
             )
 
         dc._query.set_listing_fn(
             lambda ds_name=list_ds_name, lst_uri=list_uri: lst_fn(ds_name, lst_uri)
        )
 
-        …
+        # If a glob pattern was detected, use it for filtering
+        # Otherwise, use the original list_path from get_listing
+        if glob_pattern:
+            # Determine if we should use recursive listing based on the pattern
+            use_recursive = should_use_recursion(glob_pattern, recursive or False)
+
+            # Apply glob filter - no need for brace expansion here as it's done above
+            chain = apply_glob_filter(
+                dc, glob_pattern, list_path, use_recursive, column
+            )
+            chains.append(chain)
+        else:
+            # No glob pattern detected, use normal ls behavior
+            chains.append(ls(dc, list_path, recursive=recursive, column=column))
 
-        storage_chain = storage_chain.union(chain) if storage_chain else chain
         listed_ds_name.add(list_ds_name)
 
+    storage_chain = None if not chains else reduce(lambda x, y: x.union(y), chains)
+
     if file_values:
         file_chain = read_values(
             session=session,
@@ -162,7 +252,7 @@ def read_storage(
             file=file_values,
         )
         file_chain.signals_schema = file_chain.signals_schema.mutate(
-            {f"{…
+            {f"{column}": file_type}
        )
        storage_chain = storage_chain.union(file_chain) if storage_chain else file_chain
 
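The rewritten `read_storage` above chains three new behaviors: brace patterns are expanded into concrete URIs up front (`expand_brace_pattern`), each URI is split into a listing base plus a glob filter (`split_uri_pattern`), and the `delta_*` arguments are forwarded to `read_dataset` so only new or changed files get reprocessed. A minimal sketch of how these combine; the bucket layout and the error-field name are illustrative, not from this diff:

```python
import datachain as dc

chain = dc.read_storage(
    # {01..12} expands to twelve URIs before listing; the glob tail
    # **/*.{jpg,png} is applied as a filter after listing.
    "s3://my-bucket/batch-{01..12}/**/*.{jpg,png}",  # hypothetical layout
    type="image",
    delta=True,                           # skip already-processed files
    delta_on=("file.path", "file.etag"),  # identity of a source record
    delta_retry="error",                  # hypothetical error-field name
)
```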
datachain/lib/dc/storage_pattern.py
ADDED
@@ -0,0 +1,251 @@
+import glob
+from typing import TYPE_CHECKING
+
+from datachain.client.fsspec import is_cloud_uri
+from datachain.lib.listing import ls
+
+if TYPE_CHECKING:
+    from .datachain import DataChain
+
+
+def validate_cloud_bucket_name(uri: str) -> None:
+    """
+    Validate that cloud storage bucket names don't contain glob patterns.
+
+    Raises:
+        ValueError: If a cloud storage bucket name contains glob patterns
+    """
+    if not is_cloud_uri(uri):
+        return
+
+    if "://" in uri:
+        scheme_end = uri.index("://") + 3
+        path_part = uri[scheme_end:]
+
+        if "/" in path_part:
+            bucket_name = path_part.split("/")[0]
+        else:
+            bucket_name = path_part
+
+        glob_chars = ["*", "?", "[", "]", "{", "}"]
+        if any(char in bucket_name for char in glob_chars):
+            raise ValueError(f"Glob patterns in bucket names are not supported: {uri}")
+
+
+def split_uri_pattern(uri: str) -> tuple[str, str | None]:
+    """Split a URI into base path and glob pattern."""
+    if not any(char in uri for char in ["*", "?", "[", "{", "}"]):
+        return uri, None
+
+    if "://" in uri:
+        scheme_end = uri.index("://") + 3
+        scheme_part = uri[:scheme_end]
+        path_part = uri[scheme_end:]
+        path_segments = path_part.split("/")
+
+        pattern_start_idx = None
+        for i, segment in enumerate(path_segments):
+            # Check for glob patterns including brace expansion
+            if glob.has_magic(segment) or "{" in segment:
+                pattern_start_idx = i
+                break
+
+        if pattern_start_idx is None:
+            return uri, None
+
+        if pattern_start_idx == 0:
+            base = scheme_part + path_segments[0]
+            pattern = "/".join(path_segments[1:]) if len(path_segments) > 1 else "*"
+        else:
+            base = scheme_part + "/".join(path_segments[:pattern_start_idx])
+            pattern = "/".join(path_segments[pattern_start_idx:])
+
+        return base, pattern
+
+    path_segments = uri.split("/")
+
+    pattern_start_idx = None
+    for i, segment in enumerate(path_segments):
+        if glob.has_magic(segment) or "{" in segment:
+            pattern_start_idx = i
+            break
+
+    if pattern_start_idx is None:
+        return uri, None
+
+    base = "/".join(path_segments[:pattern_start_idx]) if pattern_start_idx > 0 else "/"
+    pattern = "/".join(path_segments[pattern_start_idx:])
+
+    return base, pattern
+
+
+def should_use_recursion(pattern: str, user_recursive: bool) -> bool:
+    if not user_recursive:
+        return False
+
+    if "**" in pattern:
+        return True
+
+    return "/" in pattern
+
+
+def expand_brace_pattern(pattern: str) -> list[str]:
+    """
+    Recursively expand brace patterns into multiple glob patterns.
+    Supports:
+    - Comma-separated lists: *.{mp3,wav}
+    - Numeric ranges: file{1..10}
+    - Zero-padded numeric ranges: file{01..10}
+    - Character ranges: file{a..z}
+
+    Examples:
+        "*.{mp3,wav}" -> ["*.mp3", "*.wav"]
+        "file{1..3}" -> ["file1", "file2", "file3"]
+        "file{01..03}" -> ["file01", "file02", "file03"]
+        "file{a..c}" -> ["filea", "fileb", "filec"]
+        "{a,b}/{c,d}" -> ["a/c", "a/d", "b/c", "b/d"]
+    """
+    if "{" not in pattern or "}" not in pattern:
+        return [pattern]
+
+    return _expand_single_braces(pattern)
+
+
+def _expand_single_braces(pattern: str) -> list[str]:
+    if "{" not in pattern or "}" not in pattern:
+        return [pattern]
+
+    start = pattern.index("{")
+    end = start
+    depth = 0
+    for i in range(start, len(pattern)):
+        if pattern[i] == "{":
+            depth += 1
+        elif pattern[i] == "}":
+            depth -= 1
+            if depth == 0:
+                end = i
+                break
+
+    if start >= end:
+        return [pattern]
+
+    prefix = pattern[:start]
+    suffix = pattern[end + 1 :]
+    brace_content = pattern[start + 1 : end]
+
+    if ".." in brace_content:
+        options = _expand_range(brace_content)
+    else:
+        options = [opt.strip() for opt in brace_content.split(",")]
+
+    expanded = []
+    for option in options:
+        combined = prefix + option + suffix
+        expanded.extend(_expand_single_braces(combined))
+
+    return expanded
+
+
+def _expand_range(range_spec: str) -> list[str]:  # noqa: PLR0911
+    if ".." not in range_spec:
+        return [range_spec]
+
+    parts = range_spec.split("..")
+    if len(parts) != 2:
+        return [range_spec]
+
+    start, end = parts[0], parts[1]
+
+    if start.isdigit() and end.isdigit():
+        pad_width = max(len(start), len(end)) if start[0] == "0" or end[0] == "0" else 0
+        start_num = int(start)
+        end_num = int(end)
+
+        if start_num <= end_num:
+            if pad_width > 0:
+                return [str(i).zfill(pad_width) for i in range(start_num, end_num + 1)]
+            return [str(i) for i in range(start_num, end_num + 1)]
+        if pad_width > 0:
+            return [str(i).zfill(pad_width) for i in range(start_num, end_num - 1, -1)]
+        return [str(i) for i in range(start_num, end_num - 1, -1)]
+
+    if len(start) == 1 and len(end) == 1 and start.isalpha() and end.isalpha():
+        start_ord = ord(start)
+        end_ord = ord(end)
+
+        if start_ord <= end_ord:
+            return [chr(i) for i in range(start_ord, end_ord + 1)]
+        return [chr(i) for i in range(start_ord, end_ord - 1, -1)]
+
+    return [range_spec]
+
+
+def convert_globstar_to_glob(filter_pattern: str) -> str:
+    if "**" not in filter_pattern:
+        return filter_pattern
+
+    parts = filter_pattern.split("/")
+    globstar_positions = [i for i, p in enumerate(parts) if p == "**"]
+
+    num_globstars = len(globstar_positions)
+
+    if num_globstars <= 1:
+        if filter_pattern == "**/*":
+            return "*"
+        if filter_pattern.startswith("**/"):
+            remaining = filter_pattern[3:]
+            if "/" not in remaining:
+                # Pattern like **/*.ext or **/temp?.*
+                # The ** means zero or more directories
+                # For zero directories: pattern should be just the filename pattern
+                # For one or more: pattern should be */filename
+                # Since we can't OR in GLOB, we choose the more permissive option
+                # that works with recursive listing
+                # Special handling: if it's a simple extension pattern, match broadly
+                if remaining.startswith("*."):
+                    return remaining
+                return f"*/{remaining}"
+
+        return filter_pattern.replace("**", "*")
+
+    middle_parts = []
+    start_idx = globstar_positions[0] + 1
+    end_idx = globstar_positions[-1]
+    for i in range(start_idx, end_idx):
+        if parts[i] != "**":
+            middle_parts.append(parts[i])
+
+    if not middle_parts:
+        result = filter_pattern.replace("**", "*")
+    else:
+        middle_pattern = "/".join(middle_parts)
+        last_part = parts[-1] if parts[-1] != "**" else "*"
+
+        if last_part != "*":
+            result = f"*{middle_pattern}*{last_part}"
+        else:
+            result = f"*{middle_pattern}*"
+
+    return result
+
+
+def apply_glob_filter(
+    dc: "DataChain",
+    pattern: str,
+    list_path: str,
+    use_recursive: bool,
+    column: str,
+) -> "DataChain":
+    from datachain.query.schema import Column
+
+    chain = ls(dc, list_path, recursive=use_recursive, column=column)
+
+    if list_path and "/" not in pattern:
+        filter_pattern = f"{list_path.rstrip('/')}/{pattern}"
+    else:
+        filter_pattern = pattern
+
+    glob_pattern = convert_globstar_to_glob(filter_pattern)
+
+    return chain.filter(Column(f"{column}.path").glob(glob_pattern))
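The helpers above do the pattern work for `read_storage`: `expand_brace_pattern` runs first, `split_uri_pattern` separates the listable base from the glob tail, and `convert_globstar_to_glob` approximates `**` for the SQL GLOB filter. A short behavioral sketch; the expected values follow the docstrings and code above:

```python
from datachain.lib.dc.storage_pattern import (
    convert_globstar_to_glob,
    expand_brace_pattern,
    split_uri_pattern,
)

# Brace expansion is recursive and preserves zero padding in ranges.
assert expand_brace_pattern("*.{mp3,wav}") == ["*.mp3", "*.wav"]
assert expand_brace_pattern("file{01..03}") == ["file01", "file02", "file03"]

# Each expanded URI is split at the first path segment with glob magic.
base, pattern = split_uri_pattern("s3://bucket/meta/**/*.json")
assert (base, pattern) == ("s3://bucket/meta", "**/*.json")

# SQL GLOB has no `**`; a leading globstar over a simple extension
# pattern collapses to the bare extension match.
assert convert_globstar_to_glob("**/*.json") == "*.json"
```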
datachain/lib/dc/utils.py
CHANGED
@@ -1,12 +1,6 @@
 from collections.abc import Sequence
 from functools import wraps
-from typing import (
-    TYPE_CHECKING,
-    Callable,
-    Optional,
-    TypeVar,
-    Union,
-)
+from typing import TYPE_CHECKING, TypeVar
 
 import sqlalchemy
 from sqlalchemy.sql.functions import GenericFunction
@@ -15,9 +9,13 @@ from datachain.func.base import Function
 from datachain.lib.data_model import DataModel, DataType
 from datachain.lib.utils import DataChainParamsError
 from datachain.query.schema import DEFAULT_DELIMITER
+from datachain.utils import getenv_bool
 
 if TYPE_CHECKING:
-    from …
+    from collections.abc import Callable
+    from typing import Concatenate
+
+    from typing_extensions import ParamSpec
 
     from .datachain import DataChain
 
@@ -26,13 +24,23 @@ if TYPE_CHECKING:
 D = TypeVar("D", bound="DataChain")
 
 
+def is_studio() -> bool:
+    """Check if the runtime environment is Studio (not local)."""
+    return getenv_bool("DATACHAIN_IS_STUDIO", default=False)
+
+
+def is_local() -> bool:
+    """Check if the runtime environment is local (not Studio)."""
+    return not is_studio()
+
+
 def resolve_columns(
     method: "Callable[Concatenate[D, P], D]",
 ) -> "Callable[Concatenate[D, P], D]":
     """Decorator that resolves input column names to their actual DB names. This is
     especially important for nested columns, as the user works with them using dot
-    notation e.g (file.…
-    in DB, e.g …
+    notation, e.g. (file.path), but they are actually defined with the default
+    delimiter in the DB, e.g. file__path.
     If there are any sql functions in arguments, they will just be transferred as is
     to a method.
     """
@@ -65,11 +73,11 @@ class DatasetFromValuesError(DataChainParamsError):
         super().__init__(f"Dataset{name} from values error: {msg}")
 
 
-MergeColType = …
+MergeColType = str | Function | sqlalchemy.ColumnElement
 
 
 def _validate_merge_on(
-    on: …
+    on: MergeColType | Sequence[MergeColType],
     ds: "DataChain",
 ) -> Sequence[MergeColType]:
     if isinstance(on, (str, sqlalchemy.ColumnElement)):
@@ -98,12 +106,12 @@ def _get_merge_error_str(col: MergeColType) -> str:
 class DatasetMergeError(DataChainParamsError):
     def __init__(
         self,
-        on: …
-        right_on: …
+        on: MergeColType | Sequence[MergeColType],
+        right_on: MergeColType | Sequence[MergeColType] | None,
         msg: str,
     ):
         def _get_str(
-            on: …
+            on: MergeColType | Sequence[MergeColType],
         ) -> str:
             if not isinstance(on, Sequence):
                 return str(on)  # type: ignore[unreachable]
@@ -118,7 +126,7 @@ class DatasetMergeError(DataChainParamsError):
         super().__init__(f"Merge error on='{on_str}'{right_on_str}: {msg}")
 
 
-OutputType = …
+OutputType = DataType | Sequence[str] | dict[str, DataType] | None
 
 
 class Sys(DataModel):