datachain 0.11.11__py3-none-any.whl → 0.13.0__py3-none-any.whl
- datachain/catalog/catalog.py +39 -7
- datachain/catalog/loader.py +19 -13
- datachain/cli/__init__.py +2 -1
- datachain/cli/commands/ls.py +8 -6
- datachain/cli/commands/show.py +7 -0
- datachain/cli/parser/studio.py +13 -1
- datachain/client/fsspec.py +12 -16
- datachain/client/gcs.py +1 -1
- datachain/client/hf.py +36 -14
- datachain/client/local.py +1 -4
- datachain/client/s3.py +1 -1
- datachain/data_storage/metastore.py +6 -0
- datachain/data_storage/warehouse.py +3 -8
- datachain/dataset.py +8 -0
- datachain/error.py +0 -12
- datachain/fs/utils.py +30 -0
- datachain/func/__init__.py +5 -0
- datachain/func/func.py +2 -1
- datachain/lib/dc.py +59 -15
- datachain/lib/file.py +63 -18
- datachain/lib/image.py +30 -6
- datachain/lib/listing.py +21 -39
- datachain/lib/meta_formats.py +2 -2
- datachain/lib/signal_schema.py +65 -18
- datachain/lib/udf.py +3 -0
- datachain/lib/udf_signature.py +17 -9
- datachain/lib/video.py +7 -5
- datachain/model/bbox.py +209 -58
- datachain/model/pose.py +49 -37
- datachain/model/segment.py +22 -18
- datachain/model/ultralytics/bbox.py +9 -9
- datachain/model/ultralytics/pose.py +7 -7
- datachain/model/ultralytics/segment.py +7 -7
- datachain/model/utils.py +191 -0
- datachain/query/dataset.py +8 -2
- datachain/sql/sqlite/base.py +2 -2
- datachain/studio.py +8 -6
- datachain/utils.py +0 -16
- {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/METADATA +4 -2
- {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/RECORD +44 -42
- {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/WHEEL +1 -1
- {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/LICENSE +0 -0
- {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/top_level.txt +0 -0
datachain/lib/dc.py
CHANGED

@@ -6,6 +6,7 @@ import sys
 from collections.abc import Iterator, Sequence
 from functools import wraps
 from typing import (
+    IO,
     TYPE_CHECKING,
     Any,
     BinaryIO,
@@ -22,7 +23,6 @@ import orjson
 import sqlalchemy
 from pydantic import BaseModel
 from sqlalchemy.sql.functions import GenericFunction
-from sqlalchemy.sql.sqltypes import NullType
 from tqdm import tqdm
 
 from datachain.dataset import DatasetRecord
@@ -55,7 +55,6 @@ from datachain.query import Session
 from datachain.query.dataset import DatasetQuery, PartitionByType
 from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
 from datachain.sql.functions import path as pathfunc
-from datachain.telemetry import telemetry
 from datachain.utils import batched_it, inside_notebook, row_to_nested_dict
 
 if TYPE_CHECKING:
@@ -215,7 +214,7 @@ class DataChain:
         from mistralai.client import MistralClient
         from mistralai.models.chat_completion import ChatMessage
 
-        from datachain …
+        from datachain import DataChain, Column
 
         PROMPT = (
             "Was this bot dialog successful? "
@@ -272,6 +271,18 @@ class DataChain:
         self._setup: dict = setup or {}
         self._sys = _sys
 
+    def __repr__(self) -> str:
+        """Return a string representation of the chain."""
+        classname = self.__class__.__name__
+        if not self._effective_signals_schema.values:
+            return f"Empty {classname}"
+
+        import io
+
+        file = io.StringIO()
+        self.print_schema(file=file)
+        return file.getvalue()
+
     @property
     def schema(self) -> dict[str, DataType]:
         """Get schema of the chain."""
@@ -325,9 +336,9 @@ class DataChain:
         """Return `self.union(other)`."""
         return self.union(other)
 
-    def print_schema(self) -> None:
+    def print_schema(self, file: Optional[IO] = None) -> None:
         """Print schema of the chain."""
-        self._effective_signals_schema.print_tree()
+        self._effective_signals_schema.print_tree(file=file)
 
     def clone(self) -> "Self":
         """Make a copy of the chain in a new table."""
@@ -408,7 +419,7 @@ class DataChain:
     @classmethod
     def from_storage(
         cls,
-        uri,
+        uri: Union[str, os.PathLike[str]],
         *,
         type: FileType = "binary",
         session: Optional[Session] = None,
@@ -550,6 +561,8 @@ class DataChain:
         )
         ```
         """
+        from datachain.telemetry import telemetry
+
         query = DatasetQuery(
             name=name,
             version=version,
@@ -573,7 +586,7 @@ class DataChain:
     @classmethod
     def from_json(
         cls,
-        path,
+        path: Union[str, os.PathLike[str]],
         type: FileType = "text",
         spec: Optional[DataType] = None,
         schema_from: Optional[str] = "auto",
@@ -610,7 +623,7 @@ class DataChain:
         ```
         """
         if schema_from == "auto":
-            schema_from = path
+            schema_from = str(path)
 
         def jmespath_to_name(s: str):
             name_end = re.search(r"\W", s).start() if re.search(r"\W", s) else len(s)  # type: ignore[union-attr]
@@ -629,7 +642,8 @@ class DataChain:
                 model_name=model_name,
                 jmespath=jmespath,
                 nrows=nrows,
-            )
+            ),
+            "params": {"file": File},
         }
         # disable prefetch if nrows is set
         settings = {"prefetch": 0} if nrows else {}
@@ -701,9 +715,22 @@ class DataChain:
         in_memory: bool = False,
         object_name: str = "dataset",
         include_listing: bool = False,
+        studio: bool = False,
     ) -> "DataChain":
         """Generate chain with list of registered datasets.
 
+        Args:
+            session: Optional session instance. If not provided, uses default session.
+            settings: Optional dictionary of settings to configure the chain.
+            in_memory: If True, creates an in-memory session. Defaults to False.
+            object_name: Name of the output object in the chain. Defaults to "dataset".
+            include_listing: If True, includes listing datasets. Defaults to False.
+            studio: If True, returns datasets from Studio only,
+                otherwise returns all local datasets. Defaults to False.
+
+        Returns:
+            DataChain: A new DataChain instance containing dataset information.
+
         Example:
         ```py
         from datachain import DataChain
@@ -719,7 +746,7 @@ class DataChain:
         datasets = [
             DatasetInfo.from_models(d, v, j)
             for d, v, j in catalog.list_datasets_versions(
-                include_listing=include_listing
+                include_listing=include_listing, studio=studio
             )
         ]
 
@@ -760,7 +787,12 @@ class DataChain:
         )
 
     def save(  # type: ignore[override]
-        self, name: Optional[str] = None, version: Optional[int] = None, **kwargs
+        self,
+        name: Optional[str] = None,
+        version: Optional[int] = None,
+        description: Optional[str] = None,
+        labels: Optional[list[str]] = None,
+        **kwargs,
     ) -> "Self":
         """Save to a Dataset. It returns the chain itself.
 
@@ -768,11 +800,18 @@ class DataChain:
             name : dataset name. Empty name saves to a temporary dataset that will be
                 removed after process ends. Temp dataset are useful for optimization.
             version : version of a dataset. Default - the last version that exist.
+            description : description of a dataset.
+            labels : labels of a dataset.
         """
         schema = self.signals_schema.clone_without_sys_signals().serialize()
         return self._evolve(
             query=self._query.save(
-                name=name, version=version, feature_schema=schema, **kwargs
+                name=name,
+                version=version,
+                description=description,
+                labels=labels,
+                feature_schema=schema,
+                **kwargs,
             )
         )
 
@@ -990,8 +1029,9 @@ class DataChain:
         func: Optional[Union[Callable, UDFObjT]],
         params: Union[None, str, Sequence[str]],
         output: OutputType,
-        signal_map,
+        signal_map: dict[str, Callable],
     ) -> UDFObjT:
+        is_batch = target_class.is_input_batched
         is_generator = target_class.is_output_batched
         name = self.name or ""
 
@@ -1002,7 +1042,9 @@ class DataChain:
         if self._sys:
             signals_schema = SignalSchema({"sys": Sys}) | signals_schema
 
-        params_schema = signals_schema.slice(sign.params, self._setup)
+        params_schema = signals_schema.slice(
+            sign.params, self._setup, is_batch=is_batch
+        )
 
         return target_class._create(sign, params_schema)
 
@@ -1195,6 +1237,8 @@ class DataChain:
         )
         ```
         """
+        from sqlalchemy.sql.sqltypes import NullType
+
         primitives = (bool, str, int, float)
 
         for col_name, expr in kwargs.items():
@@ -2542,7 +2586,7 @@ class DataChain:
 
     def to_storage(
         self,
-        output: str,
+        output: Union[str, os.PathLike[str]],
         signal: str = "file",
         placement: FileExportPlacement = "fullpath",
         link_type: Literal["copy", "symlink"] = "copy",
datachain/lib/file.py
CHANGED

@@ -18,7 +18,6 @@ from urllib.request import url2pathname
 
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from fsspec.utils import stringify_path
-from PIL import Image as PilImage
 from pydantic import Field, field_validator
 
 from datachain.client.fileslice import FileSlice
@@ -52,7 +51,7 @@ class FileExporter(NodesThreadPool):
 
     def __init__(
         self,
-        output: str,
+        output: Union[str, os.PathLike[str]],
         placement: ExportPlacement,
         use_cache: bool,
         link_type: Literal["copy", "symlink"],
@@ -194,7 +193,14 @@ class File(DataModel):
         "last_modified": DateTime,
         "location": JSON,
     }
-    _hidden_fields: ClassVar[list[str]] = [ …
+    _hidden_fields: ClassVar[list[str]] = [
+        "source",
+        "version",
+        "etag",
+        "is_latest",
+        "last_modified",
+        "location",
+    ]
 
     _unique_id_keys: ClassVar[list[str]] = [
         "source",
@@ -243,6 +249,30 @@ class File(DataModel):
         self._catalog = None
         self._caching_enabled: bool = False
 
+    def as_text_file(self) -> "TextFile":
+        """Convert the file to a `TextFile` object."""
+        if isinstance(self, TextFile):
+            return self
+        file = TextFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
+    def as_image_file(self) -> "ImageFile":
+        """Convert the file to a `ImageFile` object."""
+        if isinstance(self, ImageFile):
+            return self
+        file = ImageFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
+    def as_video_file(self) -> "VideoFile":
+        """Convert the file to a `VideoFile` object."""
+        if isinstance(self, VideoFile):
+            return self
+        file = VideoFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
     @classmethod
     def upload(
         cls, data: bytes, path: str, catalog: Optional["Catalog"] = None
@@ -292,20 +322,20 @@ class File(DataModel):
         ) as f:
             yield io.TextIOWrapper(f) if mode == "r" else f
 
-    def read(self, length: int = -1):
-        """Returns file contents."""
+    def read_bytes(self, length: int = -1):
+        """Returns file contents as bytes."""
         with self.open() as stream:
             return stream.read(length)
 
-    def read_bytes(self):
-        """Returns file contents as bytes."""
-        return self.read()
-
     def read_text(self):
         """Returns file contents as text."""
         with self.open(mode="r") as stream:
             return stream.read()
 
+    def read(self, length: int = -1):
+        """Returns file contents."""
+        return self.read_bytes(length)
+
     def save(self, destination: str, client_config: Optional[dict] = None):
         """Writes it's content to destination"""
         destination = stringify_path(destination)
@@ -333,7 +363,7 @@ class File(DataModel):
 
     def export(
         self,
-        output: str,
+        output: Union[str, os.PathLike[str]],
         placement: ExportPlacement = "fullpath",
         use_cache: bool = True,
         link_type: Literal["copy", "symlink"] = "copy",
@@ -374,15 +404,10 @@ class File(DataModel):
         client.download(self, callback=self._download_cb)
 
     async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool:
-        from datachain.client.hf import HfClient
-
         if self._catalog is None:
             raise RuntimeError("cannot prefetch file because catalog is not setup")
 
         client = self._catalog.get_client(self.source)
-        if client.protocol == HfClient.protocol:
-            return False
-
         await client._download(self, callback=download_cb or self._download_cb)
         self._set_stream(
             self._catalog, caching_enabled=True, download_cb=DEFAULT_CALLBACK
@@ -430,7 +455,9 @@ class File(DataModel):
         path = url2pathname(path)
         return path
 
-    def get_destination_path(self, output: str, placement: ExportPlacement) -> str:
+    def get_destination_path(
+        self, output: Union[str, os.PathLike[str]], placement: ExportPlacement
+    ) -> str:
         """
         Returns full destination path of a file for exporting to some output
         based on export placement
@@ -551,18 +578,36 @@ class TextFile(File):
 class ImageFile(File):
     """`DataModel` for reading image files."""
 
+    def get_info(self) -> "Image":
+        """
+        Retrieves metadata and information about the image file.
+
+        Returns:
+            Image: A Model containing image metadata such as width, height and format.
+        """
+        from .image import image_info
+
+        return image_info(self)
+
     def read(self):
         """Returns `PIL.Image.Image` object."""
+        from PIL import Image as PilImage
+
         fobj = super().read()
         return PilImage.open(BytesIO(fobj))
 
-    def save(self, destination: str, client_config: Optional[dict] = None):
+    def save(  # type: ignore[override]
+        self,
+        destination: str,
+        format: Optional[str] = None,
+        client_config: Optional[dict] = None,
+    ):
         """Writes it's content to destination"""
         destination = stringify_path(destination)
 
         client: Client = self._catalog.get_client(destination, **(client_config or {}))
         with client.fs.open(destination, mode="wb") as f:
-            self.read().save(f)
+            self.read().save(f, format=format)
 
 
 class Image(DataModel):
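
The `File` model gains `as_text_file()`/`as_image_file()`/`as_video_file()` converters, and `read()`/`read_bytes()` now take an optional `length`. A short sketch of how these might be combined (the bucket URI is hypothetical, and `collect` is used as in other datachain examples of this era):

```py
# Sketch: re-type a generic File and read a bounded prefix of its bytes.
from datachain import DataChain

chain = DataChain.from_storage("gs://bucket/images/")  # hypothetical bucket

for file in chain.collect("file"):
    header = file.read_bytes(8)           # read at most 8 bytes
    image = file.as_image_file().read()   # same object, re-typed -> PIL image
    print(file.path, header[:4], image.size)
    break
```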
datachain/lib/image.py
CHANGED

@@ -1,17 +1,41 @@
 from typing import Callable, Optional, Union
 
 import torch
-from PIL import Image
+from PIL import Image as PILImage
+
+from datachain.lib.file import File, FileError, Image, ImageFile
+
+
+def image_info(file: Union[File, ImageFile]) -> Image:
+    """
+    Returns image file information.
+
+    Args:
+        file (ImageFile): Image file object.
+
+    Returns:
+        Image: Image file information.
+    """
+    try:
+        img = file.as_image_file().read()
+    except Exception as exc:
+        raise FileError(file, "unable to open image file") from exc
+
+    return Image(
+        width=img.width,
+        height=img.height,
+        format=img.format or "",
+    )
 
 
 def convert_image(
-    img: Image.Image,
+    img: PILImage.Image,
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
    transform: Optional[Callable] = None,
    encoder: Optional[Callable] = None,
    device: Optional[Union[str, torch.device]] = None,
-) -> Union[Image.Image, torch.Tensor]:
+) -> Union[PILImage.Image, torch.Tensor]:
     """
     Resize, transform, and otherwise convert an image.
 
@@ -47,13 +71,13 @@ def convert_image(
 
 
 def convert_images(
-    images: Union[Image.Image, list[Image.Image]],
+    images: Union[PILImage.Image, list[PILImage.Image]],
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
     encoder: Optional[Callable] = None,
     device: Optional[Union[str, torch.device]] = None,
-) -> Union[list[Image.Image], torch.Tensor]:
+) -> Union[list[PILImage.Image], torch.Tensor]:
     """
     Resize, transform, and otherwise convert one or more images.
 
@@ -65,7 +89,7 @@ def convert_images(
         encoder (Callable): Encode image using model.
         device (str or torch.device): Device to use.
     """
-    if isinstance(images, Image.Image):
+    if isinstance(images, PILImage.Image):
         images = [images]
 
     converted = [
datachain/lib/listing.py
CHANGED

@@ -1,19 +1,21 @@
+import glob
 import logging
 import os
 import posixpath
 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Callable, Optional, TypeVar
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
 
 from fsspec.asyn import get_loop
 from sqlalchemy.sql.expression import true
 
+import datachain.fs.utils as fsutils
 from datachain.asyn import iter_over_async
 from datachain.client import Client
-from datachain.error import ClientError, …
+from datachain.error import ClientError
 from datachain.lib.file import File
 from datachain.query.schema import Column
 from datachain.sql.functions import path as pathfunc
-from datachain.telemetry import telemetry
 from datachain.utils import uses_glob
 
 if TYPE_CHECKING:
@@ -92,38 +94,6 @@ def ls(
     return dc.filter(pathfunc.parent(_file_c("path")) == path.lstrip("/").rstrip("/*"))
 
 
-def _isfile(client: "Client", path: str) -> bool:
-    """
-    Returns True if uri points to a file
-    """
-    try:
-        if "://" in path:
-            # This makes sure that the uppercase scheme is converted to lowercase
-            scheme, path = path.split("://", 1)
-            path = f"{scheme.lower()}://{path}"
-
-        if os.name == "nt" and "*" in path:
-            # On Windows, the glob pattern "*" is not supported
-            return False
-
-        info = client.fs.info(path)
-        name = info.get("name")
-        # case for special simulated directories on some clouds
-        # e.g. Google creates a zero byte file with the same name as the
-        # directory with a trailing slash at the end
-        if not name or name.endswith("/"):
-            return False
-
-        return info["type"] == "file"
-    except FileNotFoundError:
-        return False
-    except REMOTE_ERRORS as e:
-        raise ClientError(
-            message=str(e),
-            error_code=getattr(e, "code", None),
-        ) from e
-
-
 def parse_listing_uri(uri: str, client_config) -> tuple[str, str, str]:
     """
     Parsing uri and returns listing dataset name, listing uri and listing path
@@ -156,8 +126,16 @@ def listing_uri_from_name(dataset_name: str) -> str:
     return dataset_name.removeprefix(LISTING_PREFIX)
 
 
+@contextmanager
+def _reraise_as_client_error() -> Iterator[None]:
+    try:
+        yield
+    except Exception as e:
+        raise ClientError(message=str(e), error_code=getattr(e, "code", None)) from e
+
+
 def get_listing(
-    uri: str, session: "Session", update: bool = False
+    uri: Union[str, os.PathLike[str]], session: "Session", update: bool = False
 ) -> tuple[Optional[str], str, str, bool]:
     """Returns correct listing dataset name that must be used for saving listing
     operation. It takes into account existing listings and reusability of those.
@@ -167,6 +145,7 @@ def get_listing(
         be used to find rows based on uri.
     """
     from datachain.client.local import FileClient
+    from datachain.telemetry import telemetry
 
     catalog = session.catalog
     cache = catalog.cache
@@ -174,11 +153,14 @@ def get_listing(
 
     client = Client.get_client(uri, cache, **client_config)
     telemetry.log_param("client", client.PREFIX)
+    if not isinstance(uri, str):
+        uri = os.fspath(uri)
 
     # we don't want to use cached dataset (e.g. for a single file listing)
-    if not uses_glob(uri) and not uri.endswith("/") and _isfile(client, uri):
-        _, path = Client.parse_url(uri)
-        return None, uri, path, False
+    isfile = _reraise_as_client_error()(fsutils.isfile)
+    if not glob.has_magic(uri) and not uri.endswith("/") and isfile(client.fs, uri):
+        _, path = Client.parse_url(uri)
+        return None, uri, path, False
 
     ds_name, list_uri, list_path = parse_listing_uri(uri, client_config)
     listing = None
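
The `isfile = _reraise_as_client_error()(fsutils.isfile)` line works because objects produced by `contextlib.contextmanager` are also usable as decorators. A standalone illustration of that pattern (all names below are made up for this sketch, not datachain APIs):

```py
# Sketch of the contextmanager-as-decorator pattern used in get_listing().
from collections.abc import Iterator
from contextlib import contextmanager


class WrappedError(Exception):
    pass


@contextmanager
def _reraise_as_wrapped() -> Iterator[None]:
    try:
        yield  # the wrapped function body effectively runs here
    except Exception as e:
        raise WrappedError(str(e)) from e


def probe(path: str) -> bool:
    raise FileNotFoundError(path)


safe_probe = _reraise_as_wrapped()(probe)  # decorate an existing function
# safe_probe("x") now raises WrappedError instead of FileNotFoundError
```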
datachain/lib/meta_formats.py
CHANGED

@@ -10,7 +10,7 @@ import jmespath as jsp
 from pydantic import BaseModel, ConfigDict, Field, ValidationError  # noqa: F401
 
 from datachain.lib.data_model import DataModel  # noqa: F401
-from datachain.lib.file import File
+from datachain.lib.file import TextFile
 
 
 class UserModel(BaseModel):
@@ -130,7 +130,7 @@ def read_meta(  # noqa: C901
    #
 
    def parse_data(
-        file: File,
+        file: TextFile,
        data_model=spec,
        format=format,
        jmespath=jmespath,
datachain/lib/signal_schema.py
CHANGED

@@ -5,6 +5,7 @@ from dataclasses import dataclass
 from datetime import datetime
 from inspect import isclass
 from typing import (  # noqa: UP035
+    IO,
     TYPE_CHECKING,
     Annotated,
     Any,
@@ -154,9 +155,9 @@ class SignalSchema:
             if not callable(func):
                 raise SetupError(key, "value must be function or callable class")
 
-    def _init_setup_values(self):
+    def _init_setup_values(self) -> None:
         if self.setup_values is not None:
-            return
+            return
 
         res = {}
         for key, func in self.setup_func.items():
@@ -398,7 +399,7 @@ class SignalSchema:
         return SignalSchema(signals)
 
     @staticmethod
-    def get_flatten_hidden_fields(schema):
+    def get_flatten_hidden_fields(schema: dict):
         custom_types = schema.get("_custom_types", {})
         if not custom_types:
             return []
@@ -464,19 +465,61 @@ class SignalSchema:
         return False
 
     def slice(
-        self, …
+        self,
+        params: dict[str, Union[DataType, Any]],
+        setup: Optional[dict[str, Callable]] = None,
+        is_batch: bool = False,
     ) -> "SignalSchema":
-        …
-        for …
-        …
+        """
+        Returns new schema that combines current schema and setup signals.
+        """
+        setup_params = setup.keys() if setup else []
+        schema: dict[str, DataType] = {}
+
+        for param, param_type in params.items():
+            # This is special case for setup params, they are always treated as strings
+            if param in setup_params:
+                schema[param] = str
+                continue
+
+            schema_type = self._find_in_tree(param.split("."))
+
+            if param_type is Any:
+                schema[param] = schema_type
+                continue
+
+            schema_origin = get_origin(schema_type)
+            param_origin = get_origin(param_type)
+
+            if schema_origin is Union and type(None) in get_args(schema_type):
+                schema_type = get_args(schema_type)[0]
+            if param_origin is Union and type(None) in get_args(param_type):
+                param_type = get_args(param_type)[0]
+
+            if is_batch:
+                if param_type is list:
+                    schema[param] = schema_type
+                    continue
+
+                if param_origin is not list:
+                    raise SignalResolvingError(param.split("."), "is not a list")
+
+                param_type = get_args(param_type)[0]
+
+            if param_type == schema_type or (
+                isclass(param_type)
+                and isclass(schema_type)
+                and issubclass(param_type, File)
+                and issubclass(schema_type, File)
+            ):
+                schema[param] = schema_type
+                continue
+
+            raise SignalResolvingError(
+                param.split("."),
+                f"types mismatch: {param_type} != {schema_type}",
+            )
+
         return SignalSchema(schema, setup)
 
     def row_to_features(
@@ -696,16 +739,20 @@ class SignalSchema:
                 substree, new_prefix, depth + 1, include_hidden
             )
 
-    def print_tree(self, indent: int = 2, start_at: int = 0):
+    def print_tree(self, indent: int = 2, start_at: int = 0, file: Optional[IO] = None):
         for path, type_, _, depth in self.get_flat_tree():
             total_indent = start_at + depth * indent
-            …
+            col_name = " " * total_indent + path[-1]
+            col_type = SignalSchema._type_to_str(type_)
+            print(col_name, col_type, sep=": ", file=file)
 
             if get_origin(type_) is list:
                 args = get_args(type_)
                 if len(args) > 0 and ModelStore.is_pydantic(args[0]):
                     sub_schema = SignalSchema({"* list of": args[0]})
-                    sub_schema.print_tree( …
+                    sub_schema.print_tree(
+                        indent=indent, start_at=total_indent + indent, file=file
+                    )
 
     def get_headers_with_length(self, include_hidden: bool = True):
         paths = [