datachain 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain might be problematic.
- datachain/catalog/catalog.py +57 -212
- datachain/cli.py +6 -38
- datachain/client/fsspec.py +3 -0
- datachain/client/hf.py +47 -0
- datachain/data_storage/metastore.py +2 -29
- datachain/data_storage/sqlite.py +3 -12
- datachain/data_storage/warehouse.py +20 -29
- datachain/dataset.py +44 -32
- datachain/job.py +4 -3
- datachain/lib/arrow.py +21 -5
- datachain/lib/dataset_info.py +4 -0
- datachain/lib/dc.py +183 -59
- datachain/lib/file.py +10 -33
- datachain/lib/hf.py +2 -1
- datachain/lib/listing.py +102 -94
- datachain/lib/listing_info.py +32 -0
- datachain/lib/meta_formats.py +39 -56
- datachain/lib/signal_schema.py +5 -2
- datachain/node.py +13 -0
- datachain/query/dataset.py +12 -105
- datachain/query/metrics.py +8 -0
- datachain/utils.py +5 -0
- {datachain-0.3.9.dist-info → datachain-0.3.11.dist-info}/METADATA +7 -3
- {datachain-0.3.9.dist-info → datachain-0.3.11.dist-info}/RECORD +28 -27
- {datachain-0.3.9.dist-info → datachain-0.3.11.dist-info}/WHEEL +1 -1
- datachain/catalog/subclass.py +0 -60
- {datachain-0.3.9.dist-info → datachain-0.3.11.dist-info}/LICENSE +0 -0
- {datachain-0.3.9.dist-info → datachain-0.3.11.dist-info}/entry_points.txt +0 -0
- {datachain-0.3.9.dist-info → datachain-0.3.11.dist-info}/top_level.txt +0 -0
datachain/lib/dc.py
CHANGED
@@ -27,7 +27,16 @@ from datachain.lib.convert.values_to_tuples import values_to_tuples
from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
from datachain.lib.dataset_info import DatasetInfo
from datachain.lib.file import ExportPlacement as FileExportPlacement
-from datachain.lib.file import File, IndexedFile,
+from datachain.lib.file import File, IndexedFile, get_file_type
+from datachain.lib.listing import (
+    is_listing_dataset,
+    is_listing_expired,
+    is_listing_subset,
+    list_bucket,
+    ls,
+    parse_listing_uri,
+)
+from datachain.lib.listing_info import ListingInfo
from datachain.lib.meta_formats import read_meta, read_schema
from datachain.lib.model_store import ModelStore
from datachain.lib.settings import Settings
@@ -47,7 +56,7 @@ from datachain.query.dataset import (
    PartitionByType,
    detach,
)
-from datachain.query.schema import Column, DatasetRow
+from datachain.query.schema import DEFAULT_DELIMITER, Column, DatasetRow
from datachain.sql.functions import path as pathfunc
from datachain.utils import inside_notebook

@@ -103,11 +112,31 @@ class DatasetFromValuesError(DataChainParamsError):  # noqa: D101
        super().__init__(f"Dataset{name} from values error: {msg}")


+def _get_merge_error_str(col: Union[str, sqlalchemy.ColumnElement]) -> str:
+    if isinstance(col, str):
+        return col
+    if isinstance(col, sqlalchemy.Column):
+        return col.name.replace(DEFAULT_DELIMITER, ".")
+    if isinstance(col, sqlalchemy.ColumnElement) and hasattr(col, "name"):
+        return f"{col.name} expression"
+    return str(col)
+
+
class DatasetMergeError(DataChainParamsError):  # noqa: D101
-    def __init__(
-
+    def __init__(  # noqa: D107
+        self,
+        on: Sequence[Union[str, sqlalchemy.ColumnElement]],
+        right_on: Optional[Sequence[Union[str, sqlalchemy.ColumnElement]]],
+        msg: str,
+    ):
+        def _get_str(on: Sequence[Union[str, sqlalchemy.ColumnElement]]) -> str:
+            if not isinstance(on, Sequence):
+                return str(on)  # type: ignore[unreachable]
+            return ", ".join([_get_merge_error_str(col) for col in on])
+
+        on_str = _get_str(on)
        right_on_str = (
-            ", right_on='" +
+            ", right_on='" + _get_str(right_on) + "'"
            if right_on and isinstance(right_on, Sequence)
            else ""
        )
@@ -130,7 +159,7 @@ class Sys(DataModel):


class DataChain(DatasetQuery):
-    """
+    """DataChain - a data structure for batch data processing and evaluation.

    It represents a sequence of data manipulation steps such as reading data from
    storages, running AI or LLM models or calling external services API to validate or
@@ -243,13 +272,24 @@ class DataChain(DatasetQuery):
        """Returns Column instance with a type if name is found in current schema,
        otherwise raises an exception.
        """
-
+        if "." in name:
+            name_path = name.split(".")
+        elif DEFAULT_DELIMITER in name:
+            name_path = name.split(DEFAULT_DELIMITER)
+        else:
+            name_path = [name]
        for path, type_, _, _ in self.signals_schema.get_flat_tree():
            if path == name_path:
                return Column(name, python_to_sql(type_))

        raise ValueError(f"Column with name {name} not found in the schema")

+    def c(self, column: Union[str, Column]) -> Column:
+        """Returns Column instance attached to the current chain."""
+        c = self.column(column) if isinstance(column, str) else self.column(column.name)
+        c.table = self.table
+        return c
+
    def print_schema(self) -> None:
        """Print schema of the chain."""
        self._effective_signals_schema.print_tree()
@@ -311,7 +351,7 @@ class DataChain(DatasetQuery):
    @classmethod
    def from_storage(
        cls,
-
+        uri,
        *,
        type: Literal["binary", "text", "image"] = "binary",
        session: Optional[Session] = None,
@@ -320,41 +360,73 @@ class DataChain(DatasetQuery):
        recursive: Optional[bool] = True,
        object_name: str = "file",
        update: bool = False,
-
+        anon: bool = False,
    ) -> "Self":
        """Get data from a storage as a list of file with all file attributes.
        It returns the chain itself as usual.

        Parameters:
-
+            uri : storage URI with directory. URI must start with storage prefix such
                as `s3://`, `gs://`, `az://` or "file:///"
            type : read file as "binary", "text", or "image" data. Default is "binary".
            recursive : search recursively for the given path.
            object_name : Created object column name.
            update : force storage reindexing. Default is False.
+            anon : If True, we will treat cloud bucket as public one

        Example:
            ```py
            chain = DataChain.from_storage("s3://my-bucket/my-dir")
            ```
        """
-
-
-
-
-
-
-
-
-                in_memory=in_memory,
-                **kwargs,
-            )
-            .map(**{object_name: func})
-            .select(object_name)
+        file_type = get_file_type(type)
+
+        client_config = {"anon": True} if anon else None
+
+        session = Session.get(session, client_config=client_config, in_memory=in_memory)
+
+        list_dataset_name, list_uri, list_path = parse_listing_uri(
+            uri, session.catalog.cache, session.catalog.client_config
        )
+        need_listing = True
+
+        for ds in cls.listings(session=session, in_memory=in_memory).collect("listing"):
+            if (
+                not is_listing_expired(ds.created_at)  # type: ignore[union-attr]
+                and is_listing_subset(ds.name, list_dataset_name)  # type: ignore[union-attr]
+                and not update
+            ):
+                need_listing = False
+                list_dataset_name = ds.name  # type: ignore[union-attr]
+
+        if need_listing:
+            # caching new listing to special listing dataset
+            (
+                cls.from_records(
+                    DataChain.DEFAULT_FILE_RECORD,
+                    session=session,
+                    settings=settings,
+                    in_memory=in_memory,
+                )
+                .gen(
+                    list_bucket(list_uri, client_config=session.catalog.client_config),
+                    output={f"{object_name}": File},
+                )
+                .save(list_dataset_name, listing=True)
+            )
+
+        dc = cls.from_dataset(list_dataset_name, session=session)
+        dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
+
+        return ls(dc, list_path, recursive=recursive, object_name=object_name)

    @classmethod
-    def from_dataset(
+    def from_dataset(
+        cls,
+        name: str,
+        version: Optional[int] = None,
+        session: Optional[Session] = None,
+    ) -> "DataChain":
        """Get data from a saved Dataset. It returns the chain itself.

        Parameters:
@@ -366,7 +438,7 @@ class DataChain(DatasetQuery):
            chain = DataChain.from_dataset("my_cats")
            ```
        """
-        return DataChain(name=name, version=version)
+        return DataChain(name=name, version=version, session=session)

    @classmethod
    def from_json(
@@ -419,7 +491,7 @@ class DataChain(DatasetQuery):
        object_name = jmespath_to_name(jmespath)
        if not object_name:
            object_name = meta_type
-        chain = DataChain.from_storage(
+        chain = DataChain.from_storage(uri=path, type=type, **kwargs)
        signal_dict = {
            object_name: read_meta(
                schema_from=schema_from,
@@ -479,7 +551,7 @@ class DataChain(DatasetQuery):
        object_name = jmespath_to_name(jmespath)
        if not object_name:
            object_name = meta_type
-        chain = DataChain.from_storage(
+        chain = DataChain.from_storage(uri=path, type=type, **kwargs)
        signal_dict = {
            object_name: read_meta(
                schema_from=schema_from,
@@ -500,6 +572,7 @@ class DataChain(DatasetQuery):
        settings: Optional[dict] = None,
        in_memory: bool = False,
        object_name: str = "dataset",
+        include_listing: bool = False,
    ) -> "DataChain":
        """Generate chain with list of registered datasets.

@@ -517,7 +590,9 @@ class DataChain(DatasetQuery):

        datasets = [
            DatasetInfo.from_models(d, v, j)
-            for d, v, j in catalog.list_datasets_versions(
+            for d, v, j in catalog.list_datasets_versions(
+                include_listing=include_listing
+            )
        ]

        return cls.from_values(
@@ -528,6 +603,42 @@ class DataChain(DatasetQuery):
            **{object_name: datasets},  # type: ignore[arg-type]
        )

+    @classmethod
+    def listings(
+        cls,
+        session: Optional[Session] = None,
+        in_memory: bool = False,
+        object_name: str = "listing",
+        **kwargs,
+    ) -> "DataChain":
+        """Generate chain with list of cached listings.
+        Listing is a special kind of dataset which has directory listing data of
+        some underlying storage (e.g S3 bucket).
+
+        Example:
+            ```py
+            from datachain import DataChain
+            DataChain.listings().show()
+            ```
+        """
+        session = Session.get(session, in_memory=in_memory)
+        catalog = kwargs.get("catalog") or session.catalog
+
+        listings = [
+            ListingInfo.from_models(d, v, j)
+            for d, v, j in catalog.list_datasets_versions(
+                include_listing=True, **kwargs
+            )
+            if is_listing_dataset(d.name)
+        ]
+
+        return cls.from_values(
+            session=session,
+            in_memory=in_memory,
+            output={object_name: ListingInfo},
+            **{object_name: listings},  # type: ignore[arg-type]
+        )
+
    def print_json_schema(  # type: ignore[override]
        self, jmespath: Optional[str] = None, model_name: Optional[str] = None
    ) -> "Self":
@@ -570,7 +681,7 @@ class DataChain(DatasetQuery):
        )

    def save(  # type: ignore[override]
-        self, name: Optional[str] = None, version: Optional[int] = None
+        self, name: Optional[str] = None, version: Optional[int] = None, **kwargs
    ) -> "Self":
        """Save to a Dataset. It returns the chain itself.

@@ -580,7 +691,7 @@ class DataChain(DatasetQuery):
            version : version of a dataset. Default - the last version that exist.
        """
        schema = self.signals_schema.clone_without_sys_signals().serialize()
-        return super().save(name=name, version=version, feature_schema=schema)
+        return super().save(name=name, version=version, feature_schema=schema, **kwargs)

    def apply(self, func, *args, **kwargs):
        """Apply any function to the chain.
@@ -1060,8 +1171,17 @@ class DataChain(DatasetQuery):
    def merge(
        self,
        right_ds: "DataChain",
-        on: Union[
-
+        on: Union[
+            str,
+            sqlalchemy.ColumnElement,
+            Sequence[Union[str, sqlalchemy.ColumnElement]],
+        ],
+        right_on: Union[
+            str,
+            sqlalchemy.ColumnElement,
+            Sequence[Union[str, sqlalchemy.ColumnElement]],
+            None,
+        ] = None,
        inner=False,
        rname="right_",
    ) -> "Self":
@@ -1086,7 +1206,7 @@ class DataChain(DatasetQuery):
        if on is None:
            raise DatasetMergeError(["None"], None, "'on' must be specified")

-        if isinstance(on, str):
+        if isinstance(on, (str, sqlalchemy.ColumnElement)):
            on = [on]
        elif not isinstance(on, Sequence):
            raise DatasetMergeError(
@@ -1095,19 +1215,15 @@ class DataChain(DatasetQuery):
                f"'on' must be 'str' or 'Sequence' object but got type '{type(on)}'",
            )

-        signals_schema = self.signals_schema.clone_without_sys_signals()
-        on_columns: list[str] = signals_schema.resolve(*on).db_signals()  # type: ignore[assignment]
-
-        right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
        if right_on is not None:
-            if isinstance(right_on, str):
+            if isinstance(right_on, (str, sqlalchemy.ColumnElement)):
                right_on = [right_on]
            elif not isinstance(right_on, Sequence):
                raise DatasetMergeError(
                    on,
                    right_on,
                    "'right_on' must be 'str' or 'Sequence' object"
-                    f" but got type '{right_on}'",
+                    f" but got type '{type(right_on)}'",
                )

            if len(right_on) != len(on):
@@ -1115,34 +1231,39 @@ class DataChain(DatasetQuery):
                    on, right_on, "'on' and 'right_on' must have the same length'"
                )

-            right_on_columns: list[str] = right_signals_schema.resolve(
-                *right_on
-            ).db_signals()  # type: ignore[assignment]
-
-            if len(right_on_columns) != len(on_columns):
-                on_str = ", ".join(right_on_columns)
-                right_on_str = ", ".join(right_on_columns)
-                raise DatasetMergeError(
-                    on,
-                    right_on,
-                    "'on' and 'right_on' must have the same number of columns in db'."
-                    f" on -> {on_str}, right_on -> {right_on_str}",
-                )
-        else:
-            right_on = on
-            right_on_columns = on_columns
-
        if self == right_ds:
            right_ds = right_ds.clone(new_table=True)

+        errors = []
+
+        def _resolve(
+            ds: DataChain,
+            col: Union[str, sqlalchemy.ColumnElement],
+            side: Union[str, None],
+        ):
+            try:
+                return ds.c(col) if isinstance(col, (str, C)) else col
+            except ValueError:
+                if side:
+                    errors.append(f"{_get_merge_error_str(col)} in {side}")
+
        ops = [
-            self
-
+            _resolve(self, left, "left")
+            == _resolve(right_ds, right, "right" if right_on else None)
+            for left, right in zip(on, right_on or on)
        ]

+        if errors:
+            raise DatasetMergeError(
+                on, right_on, f"Could not resolve {', '.join(errors)}"
+            )
+
        ds = self.join(right_ds, sqlalchemy.and_(*ops), inner, rname + "{name}")

        ds.feature_schema = None
+
+        signals_schema = self.signals_schema.clone_without_sys_signals()
+        right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
        ds.signals_schema = SignalSchema({"sys": Sys}) | signals_schema.merge(
            right_signals_schema, rname
        )
@@ -1665,7 +1786,10 @@ class DataChain(DatasetQuery):

        if schema:
            signal_schema = SignalSchema(schema)
-            columns =
+            columns = [
+                sqlalchemy.Column(c.name, c.type)  # type: ignore[union-attr]
+                for c in signal_schema.db_signals(as_columns=True)  # type: ignore[assignment]
+            ]
        else:
            columns = [
                sqlalchemy.Column(name, typ)
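Taken together, the `from_storage` and `listings` changes above replace eager storage indexing with cached listing datasets. A minimal usage sketch based only on the signatures and docstrings shown in this diff (the bucket URI is a placeholder):

```py
from datachain import DataChain

# Anonymous read of a public bucket; the listing is saved as a special
# "lst__..." dataset and reused on later calls unless update=True is passed.
chain = DataChain.from_storage("s3://my-bucket/my-dir", type="text", anon=True)

# Cached listings can be inspected as a chain of ListingInfo rows.
DataChain.listings().show()
```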
datachain/lib/file.py
CHANGED
@@ -349,39 +349,6 @@ class ImageFile(File):
        self.read().save(destination)


-def get_file(type_: Literal["binary", "text", "image"] = "binary"):
-    file: type[File] = File
-    if type_ == "text":
-        file = TextFile
-    elif type_ == "image":
-        file = ImageFile  # type: ignore[assignment]
-
-    def get_file_type(
-        source: str,
-        path: str,
-        size: int,
-        version: str,
-        etag: str,
-        is_latest: bool,
-        last_modified: datetime,
-        location: Optional[Union[dict, list[dict]]],
-        vtype: str,
-    ) -> file:  # type: ignore[valid-type]
-        return file(
-            source=source,
-            path=path,
-            size=size,
-            version=version,
-            etag=etag,
-            is_latest=is_latest,
-            last_modified=last_modified,
-            location=location,
-            vtype=vtype,
-        )
-
-    return get_file_type
-
-
class IndexedFile(DataModel):
    """Metadata indexed from tabular files.

@@ -390,3 +357,13 @@ class IndexedFile(DataModel):

    file: File
    index: int
+
+
+def get_file_type(type_: Literal["binary", "text", "image"] = "binary") -> type[File]:
+    file: type[File] = File
+    if type_ == "text":
+        file = TextFile
+    elif type_ == "image":
+        file = ImageFile  # type: ignore[assignment]
+
+    return file
datachain/lib/hf.py
CHANGED
@@ -99,7 +99,8 @@ class HFGenerator(Generator):


def stream_splits(ds: Union[str, HFDatasetType], *args, **kwargs):
    if isinstance(ds, str):
-
+        kwargs["streaming"] = True
+        ds = load_dataset(ds, *args, **kwargs)
    if isinstance(ds, (DatasetDict, IterableDatasetDict)):
        return ds
    return {"": ds}
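This change forces `streaming=True` whenever a dataset is referenced by name, so `load_dataset` returns iterable splits instead of downloading the full dataset up front. A hedged sketch (the dataset name is illustrative):

```py
from datachain.lib.hf import stream_splits

# Equivalent to load_dataset("squad", streaming=True); returns a mapping of
# split name -> IterableDataset, or {"": ds} when there is a single split.
splits = stream_splits("squad")
```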
datachain/lib/listing.py
CHANGED
@@ -1,103 +1,26 @@
-import
-from collections.abc import
-from
+import posixpath
+from collections.abc import Iterator
+from datetime import datetime, timedelta, timezone
+from typing import TYPE_CHECKING, Callable, Optional

-from botocore.exceptions import ClientError
from fsspec.asyn import get_loop
+from sqlalchemy.sql.expression import true

from datachain.asyn import iter_over_async
from datachain.client import Client
-from datachain.error import ClientError as DataChainClientError
from datachain.lib.file import File
+from datachain.query.schema import Column
+from datachain.sql.functions import path as pathfunc
+from datachain.utils import uses_glob

-
-
-
-
-
-
-
-
-    infos = await client.ls_dir(path)
-    files = []
-    subdirs = set()
-    for info in infos:
-        full_path = info["name"]
-        subprefix = client.rel_path(full_path)
-        if prefix.strip(DELIMITER) == subprefix.strip(DELIMITER):
-            continue
-        if info["type"] == "directory":
-            subdirs.add(subprefix)
-        else:
-            files.append(client.info_to_file(info, subprefix))
-    if files:
-        await result_queue.put(files)
-    return subdirs
-
-
-async def _fetch(
-    client, start_prefix: str, result_queue: ResultQueue, fetch_workers
-) -> None:
-    loop = get_loop()
-
-    queue: asyncio.Queue[str] = asyncio.Queue()
-    queue.put_nowait(start_prefix)
-
-    async def process(queue) -> None:
-        while True:
-            prefix = await queue.get()
-            try:
-                subdirs = await _fetch_dir(client, prefix, result_queue)
-                for subdir in subdirs:
-                    queue.put_nowait(subdir)
-            except Exception:
-                while not queue.empty():
-                    queue.get_nowait()
-                    queue.task_done()
-                raise
-
-            finally:
-                queue.task_done()
-
-    try:
-        workers: list[asyncio.Task] = [
-            loop.create_task(process(queue)) for _ in range(fetch_workers)
-        ]
-
-        # Wait for all fetch tasks to complete
-        await queue.join()
-        # Stop the workers
-        excs = []
-        for worker in workers:
-            if worker.done() and (exc := worker.exception()):
-                excs.append(exc)
-            else:
-                worker.cancel()
-        if excs:
-            raise excs[0]
-    except ClientError as exc:
-        raise DataChainClientError(
-            exc.response.get("Error", {}).get("Message") or exc,
-            exc.response.get("Error", {}).get("Code"),
-        ) from exc
-    finally:
-        # This ensures the progress bar is closed before any exceptions are raised
-        result_queue.put_nowait(None)
-
-
-async def _scandir(client, prefix, fetch_workers) -> AsyncIterator:
-    """Recursively goes through dir tree and yields files"""
-    result_queue: ResultQueue = asyncio.Queue()
-    loop = get_loop()
-    main_task = loop.create_task(_fetch(client, prefix, result_queue, fetch_workers))
-    while (files := await result_queue.get()) is not None:
-        for f in files:
-            yield f
-
-    await main_task
-
-
-def list_bucket(uri: str, client_config=None, fetch_workers=FETCH_WORKERS) -> Callable:
+if TYPE_CHECKING:
+    from datachain.lib.dc import DataChain
+
+LISTING_TTL = 4 * 60 * 60  # cached listing lasts 4 hours
+LISTING_PREFIX = "lst__"  # listing datasets start with this name
+
+
+def list_bucket(uri: str, client_config=None) -> Callable:
    """
    Function that returns another generator function that yields File objects
    from bucket where each File represents one bucket entry.
@@ -106,6 +29,91 @@ def list_bucket(uri: str, client_config=None, fetch_workers=FETCH_WORKERS) -> Callable:
    def list_func() -> Iterator[File]:
        config = client_config or {}
        client, path = Client.parse_url(uri, None, **config)  # type: ignore[arg-type]
-
+        for entries in iter_over_async(client.scandir(path.rstrip("/")), get_loop()):
+            for entry in entries:
+                yield entry.to_file(client.uri)

    return list_func
+
+
+def ls(
+    dc: "DataChain",
+    path: str,
+    recursive: Optional[bool] = True,
+    object_name="file",
+):
+    """
+    Return files by some path from DataChain instance which contains bucket listing.
+    Path can have globs.
+    If recursive is set to False, only first level children will be returned by
+    specified path
+    """
+
+    def _file_c(name: str) -> Column:
+        return Column(f"{object_name}.{name}")
+
+    dc = dc.filter(_file_c("is_latest") == true())
+
+    if recursive:
+        if not path or path == "/":
+            # root of a bucket, returning all latest files from it
+            return dc
+
+        if not uses_glob(path):
+            # path is not glob, so it's pointing to some directory or a specific
+            # file and we are adding proper filter for it
+            return dc.filter(
+                (_file_c("path") == path)
+                | (_file_c("path").glob(path.rstrip("/") + "/*"))
+            )
+
+        # path has glob syntax so we are returning glob filter
+        return dc.filter(_file_c("path").glob(path))
+    # returning only first level children by path
+    return dc.filter(pathfunc.parent(_file_c("path")) == path.lstrip("/").rstrip("/*"))
+
+
+def parse_listing_uri(uri: str, cache, client_config) -> tuple[str, str, str]:
+    """
+    Parsing uri and returns listing dataset name, listing uri and listing path
+    """
+    client, path = Client.parse_url(uri, cache, **client_config)
+
+    # clean path without globs
+    lst_uri_path = (
+        posixpath.dirname(path) if uses_glob(path) or client.fs.isfile(uri) else path
+    )
+
+    lst_uri = f"{client.uri}/{lst_uri_path.lstrip('/')}"
+    ds_name = (
+        f"{LISTING_PREFIX}{client.uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}"
+    )
+
+    return ds_name, lst_uri, path
+
+
+def is_listing_dataset(name: str) -> bool:
+    """Returns True if it's special listing dataset"""
+    return name.startswith(LISTING_PREFIX)
+
+
+def listing_uri_from_name(dataset_name: str) -> str:
+    """Returns clean storage URI from listing dataset name"""
+    if not is_listing_dataset(dataset_name):
+        raise ValueError(f"Dataset {dataset_name} is not a listing")
+    return dataset_name.removeprefix(LISTING_PREFIX)
+
+
+def is_listing_expired(created_at: datetime) -> bool:
+    """Checks if listing has expired based on it's creation date"""
+    return datetime.now(timezone.utc) > created_at + timedelta(seconds=LISTING_TTL)
+
+
+def is_listing_subset(ds1_name: str, ds2_name: str) -> bool:
+    """
+    Checks if one listing contains another one by comparing corresponding dataset names
+    """
+    assert ds1_name.endswith("/")
+    assert ds2_name.endswith("/")
+
+    return ds2_name.startswith(ds1_name)