datachain 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of datachain might be problematic.

datachain/lib/dc.py CHANGED
@@ -27,7 +27,16 @@ from datachain.lib.convert.values_to_tuples import values_to_tuples
  from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
  from datachain.lib.dataset_info import DatasetInfo
  from datachain.lib.file import ExportPlacement as FileExportPlacement
- from datachain.lib.file import File, IndexedFile, get_file
+ from datachain.lib.file import File, IndexedFile, get_file_type
+ from datachain.lib.listing import (
+     is_listing_dataset,
+     is_listing_expired,
+     is_listing_subset,
+     list_bucket,
+     ls,
+     parse_listing_uri,
+ )
+ from datachain.lib.listing_info import ListingInfo
  from datachain.lib.meta_formats import read_meta, read_schema
  from datachain.lib.model_store import ModelStore
  from datachain.lib.settings import Settings
@@ -47,7 +56,7 @@ from datachain.query.dataset import (
      PartitionByType,
      detach,
  )
- from datachain.query.schema import Column, DatasetRow
+ from datachain.query.schema import DEFAULT_DELIMITER, Column, DatasetRow
  from datachain.sql.functions import path as pathfunc
  from datachain.utils import inside_notebook

@@ -103,11 +112,31 @@ class DatasetFromValuesError(DataChainParamsError): # noqa: D101
          super().__init__(f"Dataset{name} from values error: {msg}")


+ def _get_merge_error_str(col: Union[str, sqlalchemy.ColumnElement]) -> str:
+     if isinstance(col, str):
+         return col
+     if isinstance(col, sqlalchemy.Column):
+         return col.name.replace(DEFAULT_DELIMITER, ".")
+     if isinstance(col, sqlalchemy.ColumnElement) and hasattr(col, "name"):
+         return f"{col.name} expression"
+     return str(col)
+
+
  class DatasetMergeError(DataChainParamsError):  # noqa: D101
-     def __init__(self, on: Sequence[str], right_on: Optional[Sequence[str]], msg: str):  # noqa: D107
-         on_str = ", ".join(on) if isinstance(on, Sequence) else ""
+     def __init__(  # noqa: D107
+         self,
+         on: Sequence[Union[str, sqlalchemy.ColumnElement]],
+         right_on: Optional[Sequence[Union[str, sqlalchemy.ColumnElement]]],
+         msg: str,
+     ):
+         def _get_str(on: Sequence[Union[str, sqlalchemy.ColumnElement]]) -> str:
+             if not isinstance(on, Sequence):
+                 return str(on)  # type: ignore[unreachable]
+             return ", ".join([_get_merge_error_str(col) for col in on])
+
+         on_str = _get_str(on)
          right_on_str = (
-             ", right_on='" + ", ".join(right_on) + "'"
+             ", right_on='" + _get_str(right_on) + "'"
              if right_on and isinstance(right_on, Sequence)
              else ""
          )
@@ -130,7 +159,7 @@ class Sys(DataModel):


  class DataChain(DatasetQuery):
-     """AI 🔗 DataChain - a data structure for batch data processing and evaluation.
+     """DataChain - a data structure for batch data processing and evaluation.

      It represents a sequence of data manipulation steps such as reading data from
      storages, running AI or LLM models or calling external services API to validate or
@@ -243,13 +272,24 @@ class DataChain(DatasetQuery):
          """Returns Column instance with a type if name is found in current schema,
          otherwise raises an exception.
          """
-         name_path = name.split(".")
+         if "." in name:
+             name_path = name.split(".")
+         elif DEFAULT_DELIMITER in name:
+             name_path = name.split(DEFAULT_DELIMITER)
+         else:
+             name_path = [name]
          for path, type_, _, _ in self.signals_schema.get_flat_tree():
              if path == name_path:
                  return Column(name, python_to_sql(type_))

          raise ValueError(f"Column with name {name} not found in the schema")

+     def c(self, column: Union[str, Column]) -> Column:
+         """Returns Column instance attached to the current chain."""
+         c = self.column(column) if isinstance(column, str) else self.column(column.name)
+         c.table = self.table
+         return c
+
      def print_schema(self) -> None:
          """Print schema of the chain."""
          self._effective_signals_schema.print_tree()
@@ -311,7 +351,7 @@ class DataChain(DatasetQuery):
      @classmethod
      def from_storage(
          cls,
-         path,
+         uri,
          *,
          type: Literal["binary", "text", "image"] = "binary",
          session: Optional[Session] = None,
@@ -320,41 +360,73 @@ class DataChain(DatasetQuery):
          recursive: Optional[bool] = True,
          object_name: str = "file",
          update: bool = False,
-         **kwargs,
+         anon: bool = False,
      ) -> "Self":
          """Get data from a storage as a list of file with all file attributes.
          It returns the chain itself as usual.

          Parameters:
-             path : storage URI with directory. URI must start with storage prefix such
+             uri : storage URI with directory. URI must start with storage prefix such
                  as `s3://`, `gs://`, `az://` or "file:///"
              type : read file as "binary", "text", or "image" data. Default is "binary".
              recursive : search recursively for the given path.
              object_name : Created object column name.
              update : force storage reindexing. Default is False.
+             anon : If True, we will treat cloud bucket as public one

          Example:
              ```py
              chain = DataChain.from_storage("s3://my-bucket/my-dir")
              ```
          """
-         func = get_file(type)
-         return (
-             cls(
-                 path,
-                 session=session,
-                 settings=settings,
-                 recursive=recursive,
-                 update=update,
-                 in_memory=in_memory,
-                 **kwargs,
-             )
-             .map(**{object_name: func})
-             .select(object_name)
+         file_type = get_file_type(type)
+
+         client_config = {"anon": True} if anon else None
+
+         session = Session.get(session, client_config=client_config, in_memory=in_memory)
+
+         list_dataset_name, list_uri, list_path = parse_listing_uri(
+             uri, session.catalog.cache, session.catalog.client_config
          )
+         need_listing = True
+
+         for ds in cls.listings(session=session, in_memory=in_memory).collect("listing"):
+             if (
+                 not is_listing_expired(ds.created_at)  # type: ignore[union-attr]
+                 and is_listing_subset(ds.name, list_dataset_name)  # type: ignore[union-attr]
+                 and not update
+             ):
+                 need_listing = False
+                 list_dataset_name = ds.name  # type: ignore[union-attr]
+
+         if need_listing:
+             # caching new listing to special listing dataset
+             (
+                 cls.from_records(
+                     DataChain.DEFAULT_FILE_RECORD,
+                     session=session,
+                     settings=settings,
+                     in_memory=in_memory,
+                 )
+                 .gen(
+                     list_bucket(list_uri, client_config=session.catalog.client_config),
+                     output={f"{object_name}": File},
+                 )
+                 .save(list_dataset_name, listing=True)
+             )
+
+         dc = cls.from_dataset(list_dataset_name, session=session)
+         dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
+
+         return ls(dc, list_path, recursive=recursive, object_name=object_name)

      @classmethod
-     def from_dataset(cls, name: str, version: Optional[int] = None) -> "DataChain":
+     def from_dataset(
+         cls,
+         name: str,
+         version: Optional[int] = None,
+         session: Optional[Session] = None,
+     ) -> "DataChain":
          """Get data from a saved Dataset. It returns the chain itself.

          Parameters:
@@ -366,7 +438,7 @@ class DataChain(DatasetQuery):
              chain = DataChain.from_dataset("my_cats")
              ```
          """
-         return DataChain(name=name, version=version)
+         return DataChain(name=name, version=version, session=session)

      @classmethod
      def from_json(
@@ -419,7 +491,7 @@ class DataChain(DatasetQuery):
              object_name = jmespath_to_name(jmespath)
          if not object_name:
              object_name = meta_type
-         chain = DataChain.from_storage(path=path, type=type, **kwargs)
+         chain = DataChain.from_storage(uri=path, type=type, **kwargs)
          signal_dict = {
              object_name: read_meta(
                  schema_from=schema_from,
@@ -479,7 +551,7 @@ class DataChain(DatasetQuery):
              object_name = jmespath_to_name(jmespath)
          if not object_name:
              object_name = meta_type
-         chain = DataChain.from_storage(path=path, type=type, **kwargs)
+         chain = DataChain.from_storage(uri=path, type=type, **kwargs)
          signal_dict = {
              object_name: read_meta(
                  schema_from=schema_from,
@@ -500,6 +572,7 @@ class DataChain(DatasetQuery):
          settings: Optional[dict] = None,
          in_memory: bool = False,
          object_name: str = "dataset",
+         include_listing: bool = False,
      ) -> "DataChain":
          """Generate chain with list of registered datasets.

@@ -517,7 +590,9 @@ class DataChain(DatasetQuery):

          datasets = [
              DatasetInfo.from_models(d, v, j)
-             for d, v, j in catalog.list_datasets_versions()
+             for d, v, j in catalog.list_datasets_versions(
+                 include_listing=include_listing
+             )
          ]

          return cls.from_values(
@@ -528,6 +603,42 @@ class DataChain(DatasetQuery):
              **{object_name: datasets},  # type: ignore[arg-type]
          )

+     @classmethod
+     def listings(
+         cls,
+         session: Optional[Session] = None,
+         in_memory: bool = False,
+         object_name: str = "listing",
+         **kwargs,
+     ) -> "DataChain":
+         """Generate chain with list of cached listings.
+         Listing is a special kind of dataset which has directory listing data of
+         some underlying storage (e.g S3 bucket).
+
+         Example:
+             ```py
+             from datachain import DataChain
+             DataChain.listings().show()
+             ```
+         """
+         session = Session.get(session, in_memory=in_memory)
+         catalog = kwargs.get("catalog") or session.catalog
+
+         listings = [
+             ListingInfo.from_models(d, v, j)
+             for d, v, j in catalog.list_datasets_versions(
+                 include_listing=True, **kwargs
+             )
+             if is_listing_dataset(d.name)
+         ]
+
+         return cls.from_values(
+             session=session,
+             in_memory=in_memory,
+             output={object_name: ListingInfo},
+             **{object_name: listings},  # type: ignore[arg-type]
+         )
+
      def print_json_schema(  # type: ignore[override]
          self, jmespath: Optional[str] = None, model_name: Optional[str] = None
      ) -> "Self":
@@ -570,7 +681,7 @@ class DataChain(DatasetQuery):
          )

      def save(  # type: ignore[override]
-         self, name: Optional[str] = None, version: Optional[int] = None
+         self, name: Optional[str] = None, version: Optional[int] = None, **kwargs
      ) -> "Self":
          """Save to a Dataset. It returns the chain itself.

@@ -580,7 +691,7 @@ class DataChain(DatasetQuery):
              version : version of a dataset. Default - the last version that exist.
          """
          schema = self.signals_schema.clone_without_sys_signals().serialize()
-         return super().save(name=name, version=version, feature_schema=schema)
+         return super().save(name=name, version=version, feature_schema=schema, **kwargs)

      def apply(self, func, *args, **kwargs):
          """Apply any function to the chain.
@@ -1060,8 +1171,17 @@ class DataChain(DatasetQuery):
      def merge(
          self,
          right_ds: "DataChain",
-         on: Union[str, Sequence[str]],
-         right_on: Union[str, Sequence[str], None] = None,
+         on: Union[
+             str,
+             sqlalchemy.ColumnElement,
+             Sequence[Union[str, sqlalchemy.ColumnElement]],
+         ],
+         right_on: Union[
+             str,
+             sqlalchemy.ColumnElement,
+             Sequence[Union[str, sqlalchemy.ColumnElement]],
+             None,
+         ] = None,
          inner=False,
          rname="right_",
      ) -> "Self":
@@ -1086,7 +1206,7 @@ class DataChain(DatasetQuery):
          if on is None:
              raise DatasetMergeError(["None"], None, "'on' must be specified")

-         if isinstance(on, str):
+         if isinstance(on, (str, sqlalchemy.ColumnElement)):
              on = [on]
          elif not isinstance(on, Sequence):
              raise DatasetMergeError(
@@ -1095,19 +1215,15 @@ class DataChain(DatasetQuery):
                  f"'on' must be 'str' or 'Sequence' object but got type '{type(on)}'",
              )

-         signals_schema = self.signals_schema.clone_without_sys_signals()
-         on_columns: list[str] = signals_schema.resolve(*on).db_signals()  # type: ignore[assignment]
-
-         right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
          if right_on is not None:
-             if isinstance(right_on, str):
+             if isinstance(right_on, (str, sqlalchemy.ColumnElement)):
                  right_on = [right_on]
              elif not isinstance(right_on, Sequence):
                  raise DatasetMergeError(
                      on,
                      right_on,
                      "'right_on' must be 'str' or 'Sequence' object"
-                     f" but got type '{right_on}'",
+                     f" but got type '{type(right_on)}'",
                  )

              if len(right_on) != len(on):
@@ -1115,34 +1231,39 @@ class DataChain(DatasetQuery):
                      on, right_on, "'on' and 'right_on' must have the same length'"
                  )

-             right_on_columns: list[str] = right_signals_schema.resolve(
-                 *right_on
-             ).db_signals()  # type: ignore[assignment]
-
-             if len(right_on_columns) != len(on_columns):
-                 on_str = ", ".join(right_on_columns)
-                 right_on_str = ", ".join(right_on_columns)
-                 raise DatasetMergeError(
-                     on,
-                     right_on,
-                     "'on' and 'right_on' must have the same number of columns in db'."
-                     f" on -> {on_str}, right_on -> {right_on_str}",
-                 )
-         else:
-             right_on = on
-             right_on_columns = on_columns
-
          if self == right_ds:
              right_ds = right_ds.clone(new_table=True)

+         errors = []
+
+         def _resolve(
+             ds: DataChain,
+             col: Union[str, sqlalchemy.ColumnElement],
+             side: Union[str, None],
+         ):
+             try:
+                 return ds.c(col) if isinstance(col, (str, C)) else col
+             except ValueError:
+                 if side:
+                     errors.append(f"{_get_merge_error_str(col)} in {side}")
+
          ops = [
-             self.c(left) == right_ds.c(right)
-             for left, right in zip(on_columns, right_on_columns)
+             _resolve(self, left, "left")
+             == _resolve(right_ds, right, "right" if right_on else None)
+             for left, right in zip(on, right_on or on)
          ]

+         if errors:
+             raise DatasetMergeError(
+                 on, right_on, f"Could not resolve {', '.join(errors)}"
+             )
+
          ds = self.join(right_ds, sqlalchemy.and_(*ops), inner, rname + "{name}")

          ds.feature_schema = None
+
+         signals_schema = self.signals_schema.clone_without_sys_signals()
+         right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
          ds.signals_schema = SignalSchema({"sys": Sys}) | signals_schema.merge(
              right_signals_schema, rname
          )
@@ -1665,7 +1786,10 @@ class DataChain(DatasetQuery):

          if schema:
              signal_schema = SignalSchema(schema)
-             columns = signal_schema.db_signals(as_columns=True)  # type: ignore[assignment]
+             columns = [
+                 sqlalchemy.Column(c.name, c.type)  # type: ignore[union-attr]
+                 for c in signal_schema.db_signals(as_columns=True)  # type: ignore[assignment]
+             ]
          else:
              columns = [
                  sqlalchemy.Column(name, typ)
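
A minimal usage sketch (not part of the diff) of the reworked APIs above; the bucket URI and column names are placeholders:

```py
from datachain import DataChain

# from_storage() now takes `uri` (previously `path`) plus an `anon` flag for
# public buckets; the directory listing is cached in a special "lst__" dataset
# and reused on later calls until it expires.
chain = DataChain.from_storage("s3://my-bucket/my-dir", anon=True)

# listings() enumerates those cached listing datasets.
DataChain.listings().show()

# merge() accepts plain signal names as before, and now also column expressions.
left = DataChain.from_values(id=[1, 2], key=["a", "b"])
right = DataChain.from_values(id=[1, 2], value=[10.0, 20.0])
merged = left.merge(right, on="id")
```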
datachain/lib/file.py CHANGED
@@ -349,39 +349,6 @@ class ImageFile(File):
          self.read().save(destination)


- def get_file(type_: Literal["binary", "text", "image"] = "binary"):
-     file: type[File] = File
-     if type_ == "text":
-         file = TextFile
-     elif type_ == "image":
-         file = ImageFile  # type: ignore[assignment]
-
-     def get_file_type(
-         source: str,
-         path: str,
-         size: int,
-         version: str,
-         etag: str,
-         is_latest: bool,
-         last_modified: datetime,
-         location: Optional[Union[dict, list[dict]]],
-         vtype: str,
-     ) -> file:  # type: ignore[valid-type]
-         return file(
-             source=source,
-             path=path,
-             size=size,
-             version=version,
-             etag=etag,
-             is_latest=is_latest,
-             last_modified=last_modified,
-             location=location,
-             vtype=vtype,
-         )
-
-     return get_file_type
-
-
  class IndexedFile(DataModel):
      """Metadata indexed from tabular files.

@@ -390,3 +357,13 @@ class IndexedFile(DataModel):

      file: File
      index: int
+
+
+ def get_file_type(type_: Literal["binary", "text", "image"] = "binary") -> type[File]:
+     file: type[File] = File
+     if type_ == "text":
+         file = TextFile
+     elif type_ == "image":
+         file = ImageFile  # type: ignore[assignment]
+
+     return file
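
For reference, a small sketch (not part of the diff) of the relocated helper: `get_file_type` now simply maps the mode string to a `File` subclass instead of building a row-constructor closure:

```py
from datachain.lib.file import File, ImageFile, TextFile, get_file_type

# Each mode string resolves to the corresponding File model class.
assert get_file_type("binary") is File
assert get_file_type("text") is TextFile
assert get_file_type("image") is ImageFile
```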
datachain/lib/hf.py CHANGED
@@ -99,7 +99,8 @@ class HFGenerator(Generator):

  def stream_splits(ds: Union[str, HFDatasetType], *args, **kwargs):
      if isinstance(ds, str):
-         ds = load_dataset(ds, *args, streaming=True, **kwargs)
+         kwargs["streaming"] = True
+         ds = load_dataset(ds, *args, **kwargs)
      if isinstance(ds, (DatasetDict, IterableDatasetDict)):
          return ds
      return {"": ds}
datachain/lib/listing.py CHANGED
@@ -1,103 +1,26 @@
- import asyncio
- from collections.abc import AsyncIterator, Iterator, Sequence
- from typing import Callable, Optional
+ import posixpath
+ from collections.abc import Iterator
+ from datetime import datetime, timedelta, timezone
+ from typing import TYPE_CHECKING, Callable, Optional

- from botocore.exceptions import ClientError
  from fsspec.asyn import get_loop
+ from sqlalchemy.sql.expression import true

  from datachain.asyn import iter_over_async
  from datachain.client import Client
- from datachain.error import ClientError as DataChainClientError
  from datachain.lib.file import File
+ from datachain.query.schema import Column
+ from datachain.sql.functions import path as pathfunc
+ from datachain.utils import uses_glob

- ResultQueue = asyncio.Queue[Optional[Sequence[File]]]
-
- DELIMITER = "/"  # Path delimiter
- FETCH_WORKERS = 100
-
-
- async def _fetch_dir(client, prefix, result_queue) -> set[str]:
-     path = f"{client.name}/{prefix}"
-     infos = await client.ls_dir(path)
-     files = []
-     subdirs = set()
-     for info in infos:
-         full_path = info["name"]
-         subprefix = client.rel_path(full_path)
-         if prefix.strip(DELIMITER) == subprefix.strip(DELIMITER):
-             continue
-         if info["type"] == "directory":
-             subdirs.add(subprefix)
-         else:
-             files.append(client.info_to_file(info, subprefix))
-     if files:
-         await result_queue.put(files)
-     return subdirs
-
-
- async def _fetch(
-     client, start_prefix: str, result_queue: ResultQueue, fetch_workers
- ) -> None:
-     loop = get_loop()
-
-     queue: asyncio.Queue[str] = asyncio.Queue()
-     queue.put_nowait(start_prefix)
-
-     async def process(queue) -> None:
-         while True:
-             prefix = await queue.get()
-             try:
-                 subdirs = await _fetch_dir(client, prefix, result_queue)
-                 for subdir in subdirs:
-                     queue.put_nowait(subdir)
-             except Exception:
-                 while not queue.empty():
-                     queue.get_nowait()
-                     queue.task_done()
-                 raise
-
-             finally:
-                 queue.task_done()
-
-     try:
-         workers: list[asyncio.Task] = [
-             loop.create_task(process(queue)) for _ in range(fetch_workers)
-         ]
-
-         # Wait for all fetch tasks to complete
-         await queue.join()
-         # Stop the workers
-         excs = []
-         for worker in workers:
-             if worker.done() and (exc := worker.exception()):
-                 excs.append(exc)
-             else:
-                 worker.cancel()
-         if excs:
-             raise excs[0]
-     except ClientError as exc:
-         raise DataChainClientError(
-             exc.response.get("Error", {}).get("Message") or exc,
-             exc.response.get("Error", {}).get("Code"),
-         ) from exc
-     finally:
-         # This ensures the progress bar is closed before any exceptions are raised
-         result_queue.put_nowait(None)
-
-
- async def _scandir(client, prefix, fetch_workers) -> AsyncIterator:
-     """Recursively goes through dir tree and yields files"""
-     result_queue: ResultQueue = asyncio.Queue()
-     loop = get_loop()
-     main_task = loop.create_task(_fetch(client, prefix, result_queue, fetch_workers))
-     while (files := await result_queue.get()) is not None:
-         for f in files:
-             yield f
-
-     await main_task
-
-
- def list_bucket(uri: str, client_config=None, fetch_workers=FETCH_WORKERS) -> Callable:
+ if TYPE_CHECKING:
+     from datachain.lib.dc import DataChain
+
+ LISTING_TTL = 4 * 60 * 60  # cached listing lasts 4 hours
+ LISTING_PREFIX = "lst__"  # listing datasets start with this name
+
+
+ def list_bucket(uri: str, client_config=None) -> Callable:
      """
      Function that returns another generator function that yields File objects
      from bucket where each File represents one bucket entry.
@@ -106,6 +29,91 @@ def list_bucket(uri: str, client_config=None, fetch_workers=FETCH_WORKERS) -> Ca
      def list_func() -> Iterator[File]:
          config = client_config or {}
          client, path = Client.parse_url(uri, None, **config)  # type: ignore[arg-type]
-         yield from iter_over_async(_scandir(client, path, fetch_workers), get_loop())
+         for entries in iter_over_async(client.scandir(path.rstrip("/")), get_loop()):
+             for entry in entries:
+                 yield entry.to_file(client.uri)

      return list_func
+
+
+ def ls(
+     dc: "DataChain",
+     path: str,
+     recursive: Optional[bool] = True,
+     object_name="file",
+ ):
+     """
+     Return files by some path from DataChain instance which contains bucket listing.
+     Path can have globs.
+     If recursive is set to False, only first level children will be returned by
+     specified path
+     """
+
+     def _file_c(name: str) -> Column:
+         return Column(f"{object_name}.{name}")
+
+     dc = dc.filter(_file_c("is_latest") == true())
+
+     if recursive:
+         if not path or path == "/":
+             # root of a bucket, returning all latest files from it
+             return dc
+
+         if not uses_glob(path):
+             # path is not glob, so it's pointing to some directory or a specific
+             # file and we are adding proper filter for it
+             return dc.filter(
+                 (_file_c("path") == path)
+                 | (_file_c("path").glob(path.rstrip("/") + "/*"))
+             )
+
+         # path has glob syntax so we are returning glob filter
+         return dc.filter(_file_c("path").glob(path))
+     # returning only first level children by path
+     return dc.filter(pathfunc.parent(_file_c("path")) == path.lstrip("/").rstrip("/*"))
+
+
+ def parse_listing_uri(uri: str, cache, client_config) -> tuple[str, str, str]:
+     """
+     Parsing uri and returns listing dataset name, listing uri and listing path
+     """
+     client, path = Client.parse_url(uri, cache, **client_config)
+
+     # clean path without globs
+     lst_uri_path = (
+         posixpath.dirname(path) if uses_glob(path) or client.fs.isfile(uri) else path
+     )
+
+     lst_uri = f"{client.uri}/{lst_uri_path.lstrip('/')}"
+     ds_name = (
+         f"{LISTING_PREFIX}{client.uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}"
+     )
+
+     return ds_name, lst_uri, path
+
+
+ def is_listing_dataset(name: str) -> bool:
+     """Returns True if it's special listing dataset"""
+     return name.startswith(LISTING_PREFIX)
+
+
+ def listing_uri_from_name(dataset_name: str) -> str:
+     """Returns clean storage URI from listing dataset name"""
+     if not is_listing_dataset(dataset_name):
+         raise ValueError(f"Dataset {dataset_name} is not a listing")
+     return dataset_name.removeprefix(LISTING_PREFIX)
+
+
+ def is_listing_expired(created_at: datetime) -> bool:
+     """Checks if listing has expired based on it's creation date"""
+     return datetime.now(timezone.utc) > created_at + timedelta(seconds=LISTING_TTL)
+
+
+ def is_listing_subset(ds1_name: str, ds2_name: str) -> bool:
+     """
+     Checks if one listing contains another one by comparing corresponding dataset names
+     """
+     assert ds1_name.endswith("/")
+     assert ds2_name.endswith("/")
+
+     return ds2_name.startswith(ds1_name)
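
To make the new listing-cache helpers concrete, a hedged sketch of the predicates (bucket names are placeholders):

```py
from datetime import datetime, timedelta, timezone

from datachain.lib.listing import (
    LISTING_PREFIX,
    is_listing_dataset,
    is_listing_expired,
    is_listing_subset,
)

# Listing datasets are identified purely by their name prefix.
assert is_listing_dataset(f"{LISTING_PREFIX}s3://my-bucket/my-dir/")

# A listing counts as stale once it is older than LISTING_TTL (4 hours).
created = datetime.now(timezone.utc) - timedelta(hours=5)
assert is_listing_expired(created)

# A listing of a parent prefix covers a listing of a child prefix.
assert is_listing_subset(
    f"{LISTING_PREFIX}s3://my-bucket/",
    f"{LISTING_PREFIX}s3://my-bucket/my-dir/",
)
```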