datachain 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/catalog/catalog.py +13 -91
- datachain/cli.py +6 -38
- datachain/client/fsspec.py +3 -0
- datachain/client/hf.py +47 -0
- datachain/data_storage/metastore.py +2 -29
- datachain/data_storage/sqlite.py +3 -12
- datachain/data_storage/warehouse.py +20 -29
- datachain/dataset.py +44 -32
- datachain/lib/arrow.py +22 -6
- datachain/lib/dataset_info.py +4 -0
- datachain/lib/dc.py +149 -35
- datachain/lib/file.py +10 -33
- datachain/lib/hf.py +2 -1
- datachain/lib/listing.py +102 -94
- datachain/lib/listing_info.py +32 -0
- datachain/lib/meta_formats.py +4 -4
- datachain/lib/signal_schema.py +5 -2
- datachain/lib/webdataset.py +1 -1
- datachain/node.py +13 -0
- datachain/query/dataset.py +25 -87
- datachain/query/metrics.py +8 -0
- datachain/utils.py +5 -0
- {datachain-0.3.8.dist-info → datachain-0.3.10.dist-info}/METADATA +14 -14
- {datachain-0.3.8.dist-info → datachain-0.3.10.dist-info}/RECORD +28 -26
- {datachain-0.3.8.dist-info → datachain-0.3.10.dist-info}/WHEEL +1 -1
- {datachain-0.3.8.dist-info → datachain-0.3.10.dist-info}/LICENSE +0 -0
- {datachain-0.3.8.dist-info → datachain-0.3.10.dist-info}/entry_points.txt +0 -0
- {datachain-0.3.8.dist-info → datachain-0.3.10.dist-info}/top_level.txt +0 -0
datachain/lib/arrow.py
CHANGED
@@ -7,7 +7,9 @@ import pyarrow as pa
 from pyarrow.dataset import dataset
 from tqdm import tqdm
 
+from datachain.lib.data_model import dict_to_data_model
 from datachain.lib.file import File, IndexedFile
+from datachain.lib.model_store import ModelStore
 from datachain.lib.udf import Generator
 
 if TYPE_CHECKING:
@@ -59,7 +61,13 @@ class ArrowGenerator(Generator):
         vals = list(record.values())
         if self.output_schema:
             fields = self.output_schema.model_fields
-
+            vals_dict = {}
+            for (field, field_info), val in zip(fields.items(), vals):
+                if ModelStore.is_pydantic(field_info.annotation):
+                    vals_dict[field] = field_info.annotation(**val)  # type: ignore[misc]
+                else:
+                    vals_dict[field] = val
+            vals = [self.output_schema(**vals_dict)]
         if self.source:
             yield [IndexedFile(file=file, index=index), *vals]
         else:
@@ -95,15 +103,15 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = Non
         if not column:
             column = f"c{default_column}"
             default_column += 1
-        dtype = arrow_type_mapper(field.type)  # type: ignore[assignment]
-        if field.nullable:
+        dtype = arrow_type_mapper(field.type, column)  # type: ignore[assignment]
+        if field.nullable and not ModelStore.is_pydantic(dtype):
             dtype = Optional[dtype]  # type: ignore[assignment]
         output[column] = dtype
 
     return output
 
 
-def arrow_type_mapper(col_type: pa.DataType) -> type:  # noqa: PLR0911
+def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type:  # noqa: PLR0911
     """Convert pyarrow types to basic types."""
     from datetime import datetime
 
@@ -123,7 +131,15 @@ def arrow_type_mapper(col_type: pa.DataType) -> type:  # noqa: PLR0911
         return str
     if pa.types.is_list(col_type):
         return list[arrow_type_mapper(col_type.value_type)]  # type: ignore[return-value, misc]
-    if pa.types.is_struct(col_type)
+    if pa.types.is_struct(col_type):
+        type_dict = {}
+        for field in col_type:
+            dtype = arrow_type_mapper(field.type, field.name)
+            if field.nullable and not ModelStore.is_pydantic(dtype):
+                dtype = Optional[dtype]  # type: ignore[assignment]
+            type_dict[field.name] = dtype
+        return dict_to_data_model(column, type_dict)
+    if pa.types.is_map(col_type):
         return dict
     if isinstance(col_type, pa.lib.DictionaryType):
         return arrow_type_mapper(col_type.value_type)  # type: ignore[return-value]
@@ -131,7 +147,7 @@ def arrow_type_mapper(col_type: pa.DataType) -> type:  # noqa: PLR0911
 
 
 def _nrows_file(file: File, nrows: int) -> str:
-    tf = NamedTemporaryFile(delete=False)
+    tf = NamedTemporaryFile(delete=False)  # noqa: SIM115
     with file.open(mode="r") as reader:
         with open(tf.name, "a") as writer:
             for row, line in enumerate(reader):
datachain/lib/dataset_info.py
CHANGED
@@ -23,6 +23,8 @@ class DatasetInfo(DataModel):
     size: Optional[int] = Field(default=None)
     params: dict[str, str] = Field(default=dict)
     metrics: dict[str, Any] = Field(default=dict)
+    error_message: str = Field(default="")
+    error_stack: str = Field(default="")
 
     @staticmethod
     def _validate_dict(
@@ -67,4 +69,6 @@ class DatasetInfo(DataModel):
             size=version.size,
             params=job.params if job else {},
             metrics=job.metrics if job else {},
+            error_message=version.error_message,
+            error_stack=version.error_stack,
         )
datachain/lib/dc.py
CHANGED
@@ -27,7 +27,16 @@ from datachain.lib.convert.values_to_tuples import values_to_tuples
 from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
 from datachain.lib.dataset_info import DatasetInfo
 from datachain.lib.file import ExportPlacement as FileExportPlacement
-from datachain.lib.file import File, IndexedFile,
+from datachain.lib.file import File, IndexedFile, get_file_type
+from datachain.lib.listing import (
+    is_listing_dataset,
+    is_listing_expired,
+    is_listing_subset,
+    list_bucket,
+    ls,
+    parse_listing_uri,
+)
+from datachain.lib.listing_info import ListingInfo
 from datachain.lib.meta_formats import read_meta, read_schema
 from datachain.lib.model_store import ModelStore
 from datachain.lib.settings import Settings
@@ -311,7 +320,7 @@ class DataChain(DatasetQuery):
     @classmethod
     def from_storage(
         cls,
-
+        uri,
         *,
         type: Literal["binary", "text", "image"] = "binary",
         session: Optional[Session] = None,
@@ -320,41 +329,73 @@
         recursive: Optional[bool] = True,
         object_name: str = "file",
         update: bool = False,
-
+        anon: bool = False,
     ) -> "Self":
         """Get data from a storage as a list of file with all file attributes.
         It returns the chain itself as usual.
 
         Parameters:
-
+            uri : storage URI with directory. URI must start with storage prefix such
                 as `s3://`, `gs://`, `az://` or "file:///"
             type : read file as "binary", "text", or "image" data. Default is "binary".
             recursive : search recursively for the given path.
             object_name : Created object column name.
             update : force storage reindexing. Default is False.
+            anon : If True, we will treat cloud bucket as public one
 
         Example:
             ```py
             chain = DataChain.from_storage("s3://my-bucket/my-dir")
             ```
         """
-
-
-
-
-
-
-
-
-            in_memory=in_memory,
-            **kwargs,
-        )
-        .map(**{object_name: func})
-        .select(object_name)
+        file_type = get_file_type(type)
+
+        client_config = {"anon": True} if anon else None
+
+        session = Session.get(session, client_config=client_config, in_memory=in_memory)
+
+        list_dataset_name, list_uri, list_path = parse_listing_uri(
+            uri, session.catalog.cache, session.catalog.client_config
         )
+        need_listing = True
+
+        for ds in cls.listings(session=session, in_memory=in_memory).collect("listing"):
+            if (
+                not is_listing_expired(ds.created_at)  # type: ignore[union-attr]
+                and is_listing_subset(ds.name, list_dataset_name)  # type: ignore[union-attr]
+                and not update
+            ):
+                need_listing = False
+                list_dataset_name = ds.name  # type: ignore[union-attr]
+
+        if need_listing:
+            # caching new listing to special listing dataset
+            (
+                cls.from_records(
+                    DataChain.DEFAULT_FILE_RECORD,
+                    session=session,
+                    settings=settings,
+                    in_memory=in_memory,
+                )
+                .gen(
+                    list_bucket(list_uri, client_config=session.catalog.client_config),
+                    output={f"{object_name}": File},
+                )
+                .save(list_dataset_name, listing=True)
+            )
+
+        dc = cls.from_dataset(list_dataset_name, session=session)
+        dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
+
+        return ls(dc, list_path, recursive=recursive, object_name=object_name)
 
     @classmethod
-    def from_dataset(
+    def from_dataset(
+        cls,
+        name: str,
+        version: Optional[int] = None,
+        session: Optional[Session] = None,
+    ) -> "DataChain":
         """Get data from a saved Dataset. It returns the chain itself.
 
         Parameters:
@@ -366,7 +407,7 @@ class DataChain(DatasetQuery):
            chain = DataChain.from_dataset("my_cats")
            ```
         """
-        return DataChain(name=name, version=version)
+        return DataChain(name=name, version=version, session=session)
 
     @classmethod
     def from_json(
@@ -419,7 +460,7 @@ class DataChain(DatasetQuery):
         object_name = jmespath_to_name(jmespath)
         if not object_name:
             object_name = meta_type
-        chain = DataChain.from_storage(
+        chain = DataChain.from_storage(uri=path, type=type, **kwargs)
         signal_dict = {
             object_name: read_meta(
                 schema_from=schema_from,
@@ -479,7 +520,7 @@ class DataChain(DatasetQuery):
         object_name = jmespath_to_name(jmespath)
         if not object_name:
             object_name = meta_type
-        chain = DataChain.from_storage(
+        chain = DataChain.from_storage(uri=path, type=type, **kwargs)
         signal_dict = {
             object_name: read_meta(
                 schema_from=schema_from,
@@ -500,6 +541,7 @@ class DataChain(DatasetQuery):
         settings: Optional[dict] = None,
         in_memory: bool = False,
         object_name: str = "dataset",
+        include_listing: bool = False,
     ) -> "DataChain":
         """Generate chain with list of registered datasets.
 
@@ -517,7 +559,9 @@ class DataChain(DatasetQuery):
 
         datasets = [
            DatasetInfo.from_models(d, v, j)
-            for d, v, j in catalog.list_datasets_versions(
+            for d, v, j in catalog.list_datasets_versions(
+                include_listing=include_listing
+            )
        ]
 
        return cls.from_values(
@@ -528,6 +572,42 @@ class DataChain(DatasetQuery):
            **{object_name: datasets},  # type: ignore[arg-type]
        )
 
+    @classmethod
+    def listings(
+        cls,
+        session: Optional[Session] = None,
+        in_memory: bool = False,
+        object_name: str = "listing",
+        **kwargs,
+    ) -> "DataChain":
+        """Generate chain with list of cached listings.
+        Listing is a special kind of dataset which has directory listing data of
+        some underlying storage (e.g S3 bucket).
+
+        Example:
+            ```py
+            from datachain import DataChain
+            DataChain.listings().show()
+            ```
+        """
+        session = Session.get(session, in_memory=in_memory)
+        catalog = kwargs.get("catalog") or session.catalog
+
+        listings = [
+            ListingInfo.from_models(d, v, j)
+            for d, v, j in catalog.list_datasets_versions(
+                include_listing=True, **kwargs
+            )
+            if is_listing_dataset(d.name)
+        ]
+
+        return cls.from_values(
+            session=session,
+            in_memory=in_memory,
+            output={object_name: ListingInfo},
+            **{object_name: listings},  # type: ignore[arg-type]
+        )
+
     def print_json_schema(  # type: ignore[override]
         self, jmespath: Optional[str] = None, model_name: Optional[str] = None
     ) -> "Self":
@@ -570,7 +650,7 @@ class DataChain(DatasetQuery):
        )
 
     def save(  # type: ignore[override]
-        self, name: Optional[str] = None, version: Optional[int] = None
+        self, name: Optional[str] = None, version: Optional[int] = None, **kwargs
     ) -> "Self":
         """Save to a Dataset. It returns the chain itself.
 
@@ -580,7 +660,7 @@ class DataChain(DatasetQuery):
            version : version of a dataset. Default - the last version that exist.
         """
         schema = self.signals_schema.clone_without_sys_signals().serialize()
-        return super().save(name=name, version=version, feature_schema=schema)
+        return super().save(name=name, version=version, feature_schema=schema, **kwargs)
 
     def apply(self, func, *args, **kwargs):
         """Apply any function to the chain.
@@ -1153,17 +1233,35 @@ class DataChain(DatasetQuery):
         self,
         other: "DataChain",
         on: Optional[Union[str, Sequence[str]]] = None,
+        right_on: Optional[Union[str, Sequence[str]]] = None,
     ) -> "Self":
         """Remove rows that appear in another chain.
 
         Parameters:
             other: chain whose rows will be removed from `self`
-            on: columns to consider for determining row equality
-                defaults to all common columns
+            on: columns to consider for determining row equality in `self`.
+                If unspecified, defaults to all common columns
+                between `self` and `other`.
+            right_on: columns to consider for determining row equality in `other`.
+                If unspecified, defaults to the same values as `on`.
         """
         if isinstance(on, str):
+            if not on:
+                raise DataChainParamsError("'on' cannot be an empty string")
             on = [on]
-
+        elif isinstance(on, Sequence):
+            if not on or any(not col for col in on):
+                raise DataChainParamsError("'on' cannot contain empty strings")
+
+        if isinstance(right_on, str):
+            if not right_on:
+                raise DataChainParamsError("'right_on' cannot be an empty string")
+            right_on = [right_on]
+        elif isinstance(right_on, Sequence):
+            if not right_on or any(not col for col in right_on):
+                raise DataChainParamsError("'right_on' cannot contain empty strings")
+
+        if on is None and right_on is None:
            other_columns = set(other._effective_signals_schema.db_signals())
            signals = [
                c
@@ -1172,16 +1270,29 @@ class DataChain(DatasetQuery):
            ]
            if not signals:
                raise DataChainParamsError("subtract(): no common columns")
-        elif not
-
-
-
-        elif not on:
+        elif on is not None and right_on is None:
+            right_on = on
+            signals = list(self.signals_schema.resolve(*on).db_signals())
+        elif on is None and right_on is not None:
            raise DataChainParamsError(
-                "'on'
+                "'on' must be specified when 'right_on' is provided"
            )
        else:
-
+            if not isinstance(on, Sequence) or not isinstance(right_on, Sequence):
+                raise TypeError(
+                    "'on' and 'right_on' must be 'str' or 'Sequence' object"
+                )
+            if len(on) != len(right_on):
+                raise DataChainParamsError(
+                    "'on' and 'right_on' must have the same length"
+                )
+            signals = list(
+                zip(
+                    self.signals_schema.resolve(*on).db_signals(),
+                    other.signals_schema.resolve(*right_on).db_signals(),
+                )  # type: ignore[arg-type]
+            )
+
        return super()._subtract(other, signals)  # type: ignore[arg-type]
 
     @classmethod
@@ -1634,7 +1745,10 @@ class DataChain(DatasetQuery):
 
         if schema:
             signal_schema = SignalSchema(schema)
-            columns =
+            columns = [
+                sqlalchemy.Column(c.name, c.type)  # type: ignore[union-attr]
+                for c in signal_schema.db_signals(as_columns=True)  # type: ignore[assignment]
+            ]
         else:
             columns = [
                 sqlalchemy.Column(name, typ)
datachain/lib/file.py
CHANGED
@@ -349,39 +349,6 @@ class ImageFile(File):
         self.read().save(destination)
 
 
-def get_file(type_: Literal["binary", "text", "image"] = "binary"):
-    file: type[File] = File
-    if type_ == "text":
-        file = TextFile
-    elif type_ == "image":
-        file = ImageFile  # type: ignore[assignment]
-
-    def get_file_type(
-        source: str,
-        path: str,
-        size: int,
-        version: str,
-        etag: str,
-        is_latest: bool,
-        last_modified: datetime,
-        location: Optional[Union[dict, list[dict]]],
-        vtype: str,
-    ) -> file:  # type: ignore[valid-type]
-        return file(
-            source=source,
-            path=path,
-            size=size,
-            version=version,
-            etag=etag,
-            is_latest=is_latest,
-            last_modified=last_modified,
-            location=location,
-            vtype=vtype,
-        )
-
-    return get_file_type
-
-
 class IndexedFile(DataModel):
     """Metadata indexed from tabular files.
 
@@ -390,3 +357,13 @@ class IndexedFile(DataModel):
 
     file: File
     index: int
+
+
+def get_file_type(type_: Literal["binary", "text", "image"] = "binary") -> type[File]:
+    file: type[File] = File
+    if type_ == "text":
+        file = TextFile
+    elif type_ == "image":
+        file = ImageFile  # type: ignore[assignment]
+
+    return file
datachain/lib/hf.py
CHANGED
@@ -99,7 +99,8 @@ class HFGenerator(Generator):
 
 def stream_splits(ds: Union[str, HFDatasetType], *args, **kwargs):
     if isinstance(ds, str):
-
+        kwargs["streaming"] = True
+        ds = load_dataset(ds, *args, **kwargs)
     if isinstance(ds, (DatasetDict, IterableDatasetDict)):
         return ds
     return {"": ds}
datachain/lib/listing.py
CHANGED
@@ -1,103 +1,26 @@
-import
-from collections.abc import
-from
+import posixpath
+from collections.abc import Iterator
+from datetime import datetime, timedelta, timezone
+from typing import TYPE_CHECKING, Callable, Optional
 
-from botocore.exceptions import ClientError
 from fsspec.asyn import get_loop
+from sqlalchemy.sql.expression import true
 
 from datachain.asyn import iter_over_async
 from datachain.client import Client
-from datachain.error import ClientError as DataChainClientError
 from datachain.lib.file import File
+from datachain.query.schema import Column
+from datachain.sql.functions import path as pathfunc
+from datachain.utils import uses_glob
 
-
-
-
-
-
-
-
-
-    infos = await client.ls_dir(path)
-    files = []
-    subdirs = set()
-    for info in infos:
-        full_path = info["name"]
-        subprefix = client.rel_path(full_path)
-        if prefix.strip(DELIMITER) == subprefix.strip(DELIMITER):
-            continue
-        if info["type"] == "directory":
-            subdirs.add(subprefix)
-        else:
-            files.append(client.info_to_file(info, subprefix))
-    if files:
-        await result_queue.put(files)
-    return subdirs
-
-
-async def _fetch(
-    client, start_prefix: str, result_queue: ResultQueue, fetch_workers
-) -> None:
-    loop = get_loop()
-
-    queue: asyncio.Queue[str] = asyncio.Queue()
-    queue.put_nowait(start_prefix)
-
-    async def process(queue) -> None:
-        while True:
-            prefix = await queue.get()
-            try:
-                subdirs = await _fetch_dir(client, prefix, result_queue)
-                for subdir in subdirs:
-                    queue.put_nowait(subdir)
-            except Exception:
-                while not queue.empty():
-                    queue.get_nowait()
-                    queue.task_done()
-                raise
-
-            finally:
-                queue.task_done()
-
-    try:
-        workers: list[asyncio.Task] = [
-            loop.create_task(process(queue)) for _ in range(fetch_workers)
-        ]
-
-        # Wait for all fetch tasks to complete
-        await queue.join()
-        # Stop the workers
-        excs = []
-        for worker in workers:
-            if worker.done() and (exc := worker.exception()):
-                excs.append(exc)
-            else:
-                worker.cancel()
-        if excs:
-            raise excs[0]
-    except ClientError as exc:
-        raise DataChainClientError(
-            exc.response.get("Error", {}).get("Message") or exc,
-            exc.response.get("Error", {}).get("Code"),
-        ) from exc
-    finally:
-        # This ensures the progress bar is closed before any exceptions are raised
-        result_queue.put_nowait(None)
-
-
-async def _scandir(client, prefix, fetch_workers) -> AsyncIterator:
-    """Recursively goes through dir tree and yields files"""
-    result_queue: ResultQueue = asyncio.Queue()
-    loop = get_loop()
-    main_task = loop.create_task(_fetch(client, prefix, result_queue, fetch_workers))
-    while (files := await result_queue.get()) is not None:
-        for f in files:
-            yield f
-
-    await main_task
-
-
-def list_bucket(uri: str, client_config=None, fetch_workers=FETCH_WORKERS) -> Callable:
+if TYPE_CHECKING:
+    from datachain.lib.dc import DataChain
+
+LISTING_TTL = 4 * 60 * 60  # cached listing lasts 4 hours
+LISTING_PREFIX = "lst__"  # listing datasets start with this name
+
+
+def list_bucket(uri: str, client_config=None) -> Callable:
     """
     Function that returns another generator function that yields File objects
     from bucket where each File represents one bucket entry.
@@ -106,6 +29,91 @@ def list_bucket(uri: str, client_config=None, fetch_workers=FETCH_WORKERS) -> Ca
     def list_func() -> Iterator[File]:
         config = client_config or {}
         client, path = Client.parse_url(uri, None, **config)  # type: ignore[arg-type]
-
+        for entries in iter_over_async(client.scandir(path.rstrip("/")), get_loop()):
+            for entry in entries:
+                yield entry.to_file(client.uri)
 
     return list_func
+
+
+def ls(
+    dc: "DataChain",
+    path: str,
+    recursive: Optional[bool] = True,
+    object_name="file",
+):
+    """
+    Return files by some path from DataChain instance which contains bucket listing.
+    Path can have globs.
+    If recursive is set to False, only first level children will be returned by
+    specified path
+    """
+
+    def _file_c(name: str) -> Column:
+        return Column(f"{object_name}.{name}")
+
+    dc = dc.filter(_file_c("is_latest") == true())
+
+    if recursive:
+        if not path or path == "/":
+            # root of a bucket, returning all latest files from it
+            return dc
+
+        if not uses_glob(path):
+            # path is not glob, so it's pointing to some directory or a specific
+            # file and we are adding proper filter for it
+            return dc.filter(
+                (_file_c("path") == path)
+                | (_file_c("path").glob(path.rstrip("/") + "/*"))
+            )
+
+        # path has glob syntax so we are returning glob filter
+        return dc.filter(_file_c("path").glob(path))
+    # returning only first level children by path
+    return dc.filter(pathfunc.parent(_file_c("path")) == path.lstrip("/").rstrip("/*"))
+
+
+def parse_listing_uri(uri: str, cache, client_config) -> tuple[str, str, str]:
+    """
+    Parsing uri and returns listing dataset name, listing uri and listing path
+    """
+    client, path = Client.parse_url(uri, cache, **client_config)
+
+    # clean path without globs
+    lst_uri_path = (
+        posixpath.dirname(path) if uses_glob(path) or client.fs.isfile(uri) else path
+    )
+
+    lst_uri = f"{client.uri}/{lst_uri_path.lstrip('/')}"
+    ds_name = (
+        f"{LISTING_PREFIX}{client.uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}"
+    )
+
+    return ds_name, lst_uri, path
+
+
+def is_listing_dataset(name: str) -> bool:
+    """Returns True if it's special listing dataset"""
+    return name.startswith(LISTING_PREFIX)
+
+
+def listing_uri_from_name(dataset_name: str) -> str:
+    """Returns clean storage URI from listing dataset name"""
+    if not is_listing_dataset(dataset_name):
+        raise ValueError(f"Dataset {dataset_name} is not a listing")
+    return dataset_name.removeprefix(LISTING_PREFIX)
+
+
+def is_listing_expired(created_at: datetime) -> bool:
+    """Checks if listing has expired based on it's creation date"""
+    return datetime.now(timezone.utc) > created_at + timedelta(seconds=LISTING_TTL)
+
+
+def is_listing_subset(ds1_name: str, ds2_name: str) -> bool:
+    """
+    Checks if one listing contains another one by comparing corresponding dataset names
+    """
+    assert ds1_name.endswith("/")
+    assert ds2_name.endswith("/")
+
+    return ds2_name.startswith(ds1_name)