datachain 0.11.11__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (44)
  1. datachain/catalog/catalog.py +39 -7
  2. datachain/catalog/loader.py +19 -13
  3. datachain/cli/__init__.py +2 -1
  4. datachain/cli/commands/ls.py +8 -6
  5. datachain/cli/commands/show.py +7 -0
  6. datachain/cli/parser/studio.py +13 -1
  7. datachain/client/fsspec.py +12 -16
  8. datachain/client/gcs.py +1 -1
  9. datachain/client/hf.py +36 -14
  10. datachain/client/local.py +1 -4
  11. datachain/client/s3.py +1 -1
  12. datachain/data_storage/metastore.py +6 -0
  13. datachain/data_storage/warehouse.py +3 -8
  14. datachain/dataset.py +8 -0
  15. datachain/error.py +0 -12
  16. datachain/fs/utils.py +30 -0
  17. datachain/func/__init__.py +5 -0
  18. datachain/func/func.py +2 -1
  19. datachain/lib/dc.py +59 -15
  20. datachain/lib/file.py +63 -18
  21. datachain/lib/image.py +30 -6
  22. datachain/lib/listing.py +21 -39
  23. datachain/lib/meta_formats.py +2 -2
  24. datachain/lib/signal_schema.py +65 -18
  25. datachain/lib/udf.py +3 -0
  26. datachain/lib/udf_signature.py +17 -9
  27. datachain/lib/video.py +7 -5
  28. datachain/model/bbox.py +209 -58
  29. datachain/model/pose.py +49 -37
  30. datachain/model/segment.py +22 -18
  31. datachain/model/ultralytics/bbox.py +9 -9
  32. datachain/model/ultralytics/pose.py +7 -7
  33. datachain/model/ultralytics/segment.py +7 -7
  34. datachain/model/utils.py +191 -0
  35. datachain/query/dataset.py +8 -2
  36. datachain/sql/sqlite/base.py +2 -2
  37. datachain/studio.py +8 -6
  38. datachain/utils.py +0 -16
  39. {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/METADATA +4 -2
  40. {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/RECORD +44 -42
  41. {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/WHEEL +1 -1
  42. {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/LICENSE +0 -0
  43. {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/entry_points.txt +0 -0
  44. {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/top_level.txt +0 -0
datachain/lib/dc.py CHANGED
@@ -6,6 +6,7 @@ import sys
 from collections.abc import Iterator, Sequence
 from functools import wraps
 from typing import (
+    IO,
     TYPE_CHECKING,
     Any,
     BinaryIO,
@@ -22,7 +23,6 @@ import orjson
 import sqlalchemy
 from pydantic import BaseModel
 from sqlalchemy.sql.functions import GenericFunction
-from sqlalchemy.sql.sqltypes import NullType
 from tqdm import tqdm

 from datachain.dataset import DatasetRecord
@@ -55,7 +55,6 @@ from datachain.query import Session
 from datachain.query.dataset import DatasetQuery, PartitionByType
 from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
 from datachain.sql.functions import path as pathfunc
-from datachain.telemetry import telemetry
 from datachain.utils import batched_it, inside_notebook, row_to_nested_dict

 if TYPE_CHECKING:
@@ -215,7 +214,7 @@ class DataChain:
         from mistralai.client import MistralClient
         from mistralai.models.chat_completion import ChatMessage

-        from datachain.dc import DataChain, Column
+        from datachain import DataChain, Column

         PROMPT = (
             "Was this bot dialog successful? "
@@ -272,6 +271,18 @@ class DataChain:
         self._setup: dict = setup or {}
         self._sys = _sys

+    def __repr__(self) -> str:
+        """Return a string representation of the chain."""
+        classname = self.__class__.__name__
+        if not self._effective_signals_schema.values:
+            return f"Empty {classname}"
+
+        import io
+
+        file = io.StringIO()
+        self.print_schema(file=file)
+        return file.getvalue()
+
     @property
     def schema(self) -> dict[str, DataType]:
         """Get schema of the chain."""
@@ -325,9 +336,9 @@ class DataChain:
         """Return `self.union(other)`."""
         return self.union(other)

-    def print_schema(self) -> None:
+    def print_schema(self, file: Optional[IO] = None) -> None:
         """Print schema of the chain."""
-        self._effective_signals_schema.print_tree()
+        self._effective_signals_schema.print_tree(file=file)

     def clone(self) -> "Self":
         """Make a copy of the chain in a new table."""
@@ -408,7 +419,7 @@ class DataChain:
     @classmethod
     def from_storage(
         cls,
-        uri,
+        uri: Union[str, os.PathLike[str]],
         *,
         type: FileType = "binary",
         session: Optional[Session] = None,
@@ -550,6 +561,8 @@ class DataChain:
            )
            ```
        """
+        from datachain.telemetry import telemetry
+
        query = DatasetQuery(
            name=name,
            version=version,
@@ -573,7 +586,7 @@ class DataChain:
     @classmethod
     def from_json(
         cls,
-        path,
+        path: Union[str, os.PathLike[str]],
         type: FileType = "text",
         spec: Optional[DataType] = None,
         schema_from: Optional[str] = "auto",
@@ -610,7 +623,7 @@ class DataChain:
        ```
        """
        if schema_from == "auto":
-           schema_from = path
+           schema_from = str(path)

        def jmespath_to_name(s: str):
            name_end = re.search(r"\W", s).start() if re.search(r"\W", s) else len(s)  # type: ignore[union-attr]
@@ -629,7 +642,8 @@ class DataChain:
                model_name=model_name,
                jmespath=jmespath,
                nrows=nrows,
-           )
+           ),
+           "params": {"file": File},
        }
        # disable prefetch if nrows is set
        settings = {"prefetch": 0} if nrows else {}
@@ -701,9 +715,22 @@ class DataChain:
         in_memory: bool = False,
         object_name: str = "dataset",
         include_listing: bool = False,
+        studio: bool = False,
     ) -> "DataChain":
         """Generate chain with list of registered datasets.

+        Args:
+            session: Optional session instance. If not provided, uses default session.
+            settings: Optional dictionary of settings to configure the chain.
+            in_memory: If True, creates an in-memory session. Defaults to False.
+            object_name: Name of the output object in the chain. Defaults to "dataset".
+            include_listing: If True, includes listing datasets. Defaults to False.
+            studio: If True, returns datasets from Studio only,
+                otherwise returns all local datasets. Defaults to False.
+
+        Returns:
+            DataChain: A new DataChain instance containing dataset information.
+
         Example:
             ```py
             from datachain import DataChain
@@ -719,7 +746,7 @@ class DataChain:
         datasets = [
             DatasetInfo.from_models(d, v, j)
             for d, v, j in catalog.list_datasets_versions(
-                include_listing=include_listing
+                include_listing=include_listing, studio=studio
             )
         ]

@@ -760,7 +787,12 @@ class DataChain:
         )

     def save(  # type: ignore[override]
-        self, name: Optional[str] = None, version: Optional[int] = None, **kwargs
+        self,
+        name: Optional[str] = None,
+        version: Optional[int] = None,
+        description: Optional[str] = None,
+        labels: Optional[list[str]] = None,
+        **kwargs,
     ) -> "Self":
         """Save to a Dataset. It returns the chain itself.

@@ -768,11 +800,18 @@ class DataChain:
             name : dataset name. Empty name saves to a temporary dataset that will be
                 removed after process ends. Temp dataset are useful for optimization.
             version : version of a dataset. Default - the last version that exist.
+            description : description of a dataset.
+            labels : labels of a dataset.
         """
         schema = self.signals_schema.clone_without_sys_signals().serialize()
         return self._evolve(
             query=self._query.save(
-                name=name, version=version, feature_schema=schema, **kwargs
+                name=name,
+                version=version,
+                description=description,
+                labels=labels,
+                feature_schema=schema,
+                **kwargs,
             )
         )

@@ -990,8 +1029,9 @@ class DataChain:
         func: Optional[Union[Callable, UDFObjT]],
         params: Union[None, str, Sequence[str]],
         output: OutputType,
-        signal_map,
+        signal_map: dict[str, Callable],
     ) -> UDFObjT:
+        is_batch = target_class.is_input_batched
         is_generator = target_class.is_output_batched
         name = self.name or ""

@@ -1002,7 +1042,9 @@ class DataChain:
         if self._sys:
             signals_schema = SignalSchema({"sys": Sys}) | signals_schema

-        params_schema = signals_schema.slice(sign.params, self._setup)
+        params_schema = signals_schema.slice(
+            sign.params, self._setup, is_batch=is_batch
+        )

         return target_class._create(sign, params_schema)

@@ -1195,6 +1237,8 @@ class DataChain:
            )
            ```
        """
+        from sqlalchemy.sql.sqltypes import NullType
+
        primitives = (bool, str, int, float)

        for col_name, expr in kwargs.items():
@@ -2542,7 +2586,7 @@ class DataChain:

     def to_storage(
         self,
-        output: str,
+        output: Union[str, os.PathLike[str]],
         signal: str = "file",
         placement: FileExportPlacement = "fullpath",
         link_type: Literal["copy", "symlink"] = "copy",
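
Taken together, the dc.py changes add path-like arguments, schema printing to arbitrary streams, dataset metadata on save, and a Studio filter for datasets(). A minimal sketch of the new surface (the dataset name and local path below are hypothetical):

```py
import io
from pathlib import Path

from datachain import DataChain

# from_storage / to_storage now accept os.PathLike as well as str
chain = DataChain.from_storage(Path("images"), type="image")

# save() forwards the new description/labels metadata to the dataset
chain = chain.save("demo-images", description="sample images", labels=["demo"])

# print_schema() can target any file-like object; __repr__ reuses it,
# so evaluating a chain in a REPL now prints its schema tree
buf = io.StringIO()
chain.print_schema(file=buf)
print(buf.getvalue())

# datasets() gained a studio flag to list Studio datasets only
DataChain.datasets(studio=True).show()
```
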
datachain/lib/file.py CHANGED
@@ -18,7 +18,6 @@ from urllib.request import url2pathname

 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from fsspec.utils import stringify_path
-from PIL import Image as PilImage
 from pydantic import Field, field_validator

 from datachain.client.fileslice import FileSlice
@@ -52,7 +51,7 @@ class FileExporter(NodesThreadPool):

     def __init__(
         self,
-        output: str,
+        output: Union[str, os.PathLike[str]],
         placement: ExportPlacement,
         use_cache: bool,
         link_type: Literal["copy", "symlink"],
@@ -194,7 +193,14 @@ class File(DataModel):
         "last_modified": DateTime,
         "location": JSON,
     }
-    _hidden_fields: ClassVar[list[str]] = ["version", "source"]
+    _hidden_fields: ClassVar[list[str]] = [
+        "source",
+        "version",
+        "etag",
+        "is_latest",
+        "last_modified",
+        "location",
+    ]

     _unique_id_keys: ClassVar[list[str]] = [
         "source",
@@ -243,6 +249,30 @@ class File(DataModel):
         self._catalog = None
         self._caching_enabled: bool = False

+    def as_text_file(self) -> "TextFile":
+        """Convert the file to a `TextFile` object."""
+        if isinstance(self, TextFile):
+            return self
+        file = TextFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
+    def as_image_file(self) -> "ImageFile":
+        """Convert the file to a `ImageFile` object."""
+        if isinstance(self, ImageFile):
+            return self
+        file = ImageFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
+    def as_video_file(self) -> "VideoFile":
+        """Convert the file to a `VideoFile` object."""
+        if isinstance(self, VideoFile):
+            return self
+        file = VideoFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
     @classmethod
     def upload(
         cls, data: bytes, path: str, catalog: Optional["Catalog"] = None
@@ -292,20 +322,20 @@ class File(DataModel):
         ) as f:
             yield io.TextIOWrapper(f) if mode == "r" else f

-    def read(self, length: int = -1):
-        """Returns file contents."""
+    def read_bytes(self, length: int = -1):
+        """Returns file contents as bytes."""
         with self.open() as stream:
             return stream.read(length)

-    def read_bytes(self):
-        """Returns file contents as bytes."""
-        return self.read()
-
     def read_text(self):
         """Returns file contents as text."""
         with self.open(mode="r") as stream:
             return stream.read()

+    def read(self, length: int = -1):
+        """Returns file contents."""
+        return self.read_bytes(length)
+
     def save(self, destination: str, client_config: Optional[dict] = None):
         """Writes it's content to destination"""
         destination = stringify_path(destination)
@@ -333,7 +363,7 @@ class File(DataModel):

     def export(
         self,
-        output: str,
+        output: Union[str, os.PathLike[str]],
         placement: ExportPlacement = "fullpath",
         use_cache: bool = True,
         link_type: Literal["copy", "symlink"] = "copy",
@@ -374,15 +404,10 @@ class File(DataModel):
         client.download(self, callback=self._download_cb)

     async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool:
-        from datachain.client.hf import HfClient
-
         if self._catalog is None:
             raise RuntimeError("cannot prefetch file because catalog is not setup")

         client = self._catalog.get_client(self.source)
-        if client.protocol == HfClient.protocol:
-            return False
-
         await client._download(self, callback=download_cb or self._download_cb)
         self._set_stream(
             self._catalog, caching_enabled=True, download_cb=DEFAULT_CALLBACK
@@ -430,7 +455,9 @@ class File(DataModel):
             path = url2pathname(path)
         return path

-    def get_destination_path(self, output: str, placement: ExportPlacement) -> str:
+    def get_destination_path(
+        self, output: Union[str, os.PathLike[str]], placement: ExportPlacement
+    ) -> str:
         """
         Returns full destination path of a file for exporting to some output
         based on export placement
@@ -551,18 +578,36 @@ class TextFile(File):
 class ImageFile(File):
     """`DataModel` for reading image files."""

+    def get_info(self) -> "Image":
+        """
+        Retrieves metadata and information about the image file.
+
+        Returns:
+            Image: A Model containing image metadata such as width, height and format.
+        """
+        from .image import image_info
+
+        return image_info(self)
+
     def read(self):
         """Returns `PIL.Image.Image` object."""
+        from PIL import Image as PilImage
+
         fobj = super().read()
         return PilImage.open(BytesIO(fobj))

-    def save(self, destination: str, client_config: Optional[dict] = None):
+    def save(  # type: ignore[override]
+        self,
+        destination: str,
+        format: Optional[str] = None,
+        client_config: Optional[dict] = None,
+    ):
         """Writes it's content to destination"""
         destination = stringify_path(destination)

         client: Client = self._catalog.get_client(destination, **(client_config or {}))
         with client.fs.open(destination, mode="wb") as f:
-            self.read().save(f)
+            self.read().save(f, format=format)


 class Image(DataModel):
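
The `as_*_file()` converters and the reworked `read`/`read_bytes` pair are easiest to see side by side. A hedged sketch; `probe` is a hypothetical UDF, not part of the library:

```py
from typing import Optional

from datachain.lib.file import File, Image

def probe(file: File) -> Optional[Image]:
    # read_bytes() now owns the optional length argument; read() delegates to it
    magic = file.read_bytes(8)
    if not magic.startswith(b"\x89PNG"):
        return None

    # as_image_file() re-types the File while carrying over its stream/catalog setup
    img = file.as_image_file()
    return img.get_info()  # Image model with width, height, format
```
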
datachain/lib/image.py CHANGED
@@ -1,17 +1,41 @@
 from typing import Callable, Optional, Union

 import torch
-from PIL import Image
+from PIL import Image as PILImage
+
+from datachain.lib.file import File, FileError, Image, ImageFile
+
+
+def image_info(file: Union[File, ImageFile]) -> Image:
+    """
+    Returns image file information.
+
+    Args:
+        file (ImageFile): Image file object.
+
+    Returns:
+        Image: Image file information.
+    """
+    try:
+        img = file.as_image_file().read()
+    except Exception as exc:
+        raise FileError(file, "unable to open image file") from exc
+
+    return Image(
+        width=img.width,
+        height=img.height,
+        format=img.format or "",
+    )


 def convert_image(
-    img: Image.Image,
+    img: PILImage.Image,
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
     encoder: Optional[Callable] = None,
     device: Optional[Union[str, torch.device]] = None,
-) -> Union[Image.Image, torch.Tensor]:
+) -> Union[PILImage.Image, torch.Tensor]:
     """
     Resize, transform, and otherwise convert an image.

@@ -47,13 +71,13 @@ def convert_image(


 def convert_images(
-    images: Union[Image.Image, list[Image.Image]],
+    images: Union[PILImage.Image, list[PILImage.Image]],
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
     encoder: Optional[Callable] = None,
     device: Optional[Union[str, torch.device]] = None,
-) -> Union[list[Image.Image], torch.Tensor]:
+) -> Union[list[PILImage.Image], torch.Tensor]:
     """
     Resize, transform, and otherwise convert one or more images.

@@ -65,7 +89,7 @@ def convert_images(
         encoder (Callable): Encode image using model.
         device (str or torch.device): Device to use.
     """
-    if isinstance(images, Image.Image):
+    if isinstance(images, PILImage.Image):
         images = [images]

     converted = [
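
`image_info` backs the new `ImageFile.get_info()` and can be mapped over a chain of images. A sketch under the assumption of a hypothetical bucket; the `image_meta` wrapper is illustrative:

```py
from datachain import DataChain
from datachain.lib.file import Image, ImageFile

def image_meta(file: ImageFile) -> Image:
    # delegates to the new image_info() helper via ImageFile.get_info()
    return file.get_info()

# output schema (Image) is inferred from the function's return annotation
meta = (
    DataChain.from_storage("s3://my-bucket/photos/", type="image")
    .map(meta=image_meta)
)
```
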
datachain/lib/listing.py CHANGED
@@ -1,19 +1,21 @@
+import glob
 import logging
 import os
 import posixpath
 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Callable, Optional, TypeVar
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union

 from fsspec.asyn import get_loop
 from sqlalchemy.sql.expression import true

+import datachain.fs.utils as fsutils
 from datachain.asyn import iter_over_async
 from datachain.client import Client
-from datachain.error import REMOTE_ERRORS, ClientError
+from datachain.error import ClientError
 from datachain.lib.file import File
 from datachain.query.schema import Column
 from datachain.sql.functions import path as pathfunc
-from datachain.telemetry import telemetry
 from datachain.utils import uses_glob

 if TYPE_CHECKING:
@@ -92,38 +94,6 @@ def ls(
     return dc.filter(pathfunc.parent(_file_c("path")) == path.lstrip("/").rstrip("/*"))


-def _isfile(client: "Client", path: str) -> bool:
-    """
-    Returns True if uri points to a file
-    """
-    try:
-        if "://" in path:
-            # This makes sure that the uppercase scheme is converted to lowercase
-            scheme, path = path.split("://", 1)
-            path = f"{scheme.lower()}://{path}"
-
-        if os.name == "nt" and "*" in path:
-            # On Windows, the glob pattern "*" is not supported
-            return False
-
-        info = client.fs.info(path)
-        name = info.get("name")
-        # case for special simulated directories on some clouds
-        # e.g. Google creates a zero byte file with the same name as the
-        # directory with a trailing slash at the end
-        if not name or name.endswith("/"):
-            return False
-
-        return info["type"] == "file"
-    except FileNotFoundError:
-        return False
-    except REMOTE_ERRORS as e:
-        raise ClientError(
-            message=str(e),
-            error_code=getattr(e, "code", None),
-        ) from e
-
-
 def parse_listing_uri(uri: str, client_config) -> tuple[str, str, str]:
     """
     Parsing uri and returns listing dataset name, listing uri and listing path
@@ -156,8 +126,16 @@ def listing_uri_from_name(dataset_name: str) -> str:
     return dataset_name.removeprefix(LISTING_PREFIX)


+@contextmanager
+def _reraise_as_client_error() -> Iterator[None]:
+    try:
+        yield
+    except Exception as e:
+        raise ClientError(message=str(e), error_code=getattr(e, "code", None)) from e
+
+
 def get_listing(
-    uri: str, session: "Session", update: bool = False
+    uri: Union[str, os.PathLike[str]], session: "Session", update: bool = False
 ) -> tuple[Optional[str], str, str, bool]:
     """Returns correct listing dataset name that must be used for saving listing
     operation. It takes into account existing listings and reusability of those.
@@ -167,6 +145,7 @@ def get_listing(
         be used to find rows based on uri.
     """
     from datachain.client.local import FileClient
+    from datachain.telemetry import telemetry

     catalog = session.catalog
     cache = catalog.cache
@@ -174,11 +153,14 @@ def get_listing(

     client = Client.get_client(uri, cache, **client_config)
     telemetry.log_param("client", client.PREFIX)
+    if not isinstance(uri, str):
+        uri = os.fspath(uri)

     # we don't want to use cached dataset (e.g. for a single file listing)
-    if not uri.endswith("/") and _isfile(client, uri):
-        storage_uri, path = Client.parse_url(uri)
-        return None, f"{storage_uri}/{path.lstrip('/')}", path, False
+    isfile = _reraise_as_client_error()(fsutils.isfile)
+    if not glob.has_magic(uri) and not uri.endswith("/") and isfile(client.fs, uri):
+        _, path = Client.parse_url(uri)
+        return None, uri, path, False

     ds_name, list_uri, list_path = parse_listing_uri(uri, client_config)
     listing = None
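
The removed `_isfile` helper moves into the new `datachain/fs/utils.py` module (the +30-line file in the list above), and `get_listing` wraps it with `_reraise_as_client_error`, relying on the fact that objects produced by `@contextmanager` inherit `ContextDecorator` and can be applied to functions. A standalone illustration of that pattern; the names here are illustrative, not from the codebase:

```py
from collections.abc import Iterator
from contextlib import contextmanager

@contextmanager
def reraise_as_runtime_error() -> Iterator[None]:
    try:
        yield
    except Exception as e:
        raise RuntimeError(str(e)) from e

def fragile() -> None:
    raise ValueError("boom")

# calling the context-manager instance on a function returns a wrapped
# function whose body runs inside the context manager
safe = reraise_as_runtime_error()(fragile)

try:
    safe()
except RuntimeError as e:
    print(e)  # boom
```
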
datachain/lib/meta_formats.py CHANGED
@@ -10,7 +10,7 @@ import jmespath as jsp
 from pydantic import BaseModel, ConfigDict, Field, ValidationError  # noqa: F401

 from datachain.lib.data_model import DataModel  # noqa: F401
-from datachain.lib.file import File
+from datachain.lib.file import TextFile


 class UserModel(BaseModel):
@@ -130,7 +130,7 @@ def read_meta(  # noqa: C901
    #

    def parse_data(
-       file: File,
+       file: TextFile,
        data_model=spec,
        format=format,
        jmespath=jmespath,
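
Callers of `from_json` are unaffected by this `File` → `TextFile` narrowing: the `"params": {"file": File}` mapping added in dc.py plus the subclass check in the reworked `SignalSchema.slice` (below) keep the two compatible. A hypothetical call for reference (the filename and JMESPath are illustrative):

```py
from datachain import DataChain

# path may now be a str or os.PathLike; parse_data receives a TextFile
chain = DataChain.from_json("annotations.json", jmespath="images[*]")
```
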
datachain/lib/signal_schema.py CHANGED
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 from datetime import datetime
 from inspect import isclass
 from typing import (  # noqa: UP035
+    IO,
     TYPE_CHECKING,
     Annotated,
     Any,
@@ -154,9 +155,9 @@ class SignalSchema:
         if not callable(func):
             raise SetupError(key, "value must be function or callable class")

-    def _init_setup_values(self):
+    def _init_setup_values(self) -> None:
         if self.setup_values is not None:
-            return self.setup_values
+            return

         res = {}
         for key, func in self.setup_func.items():
@@ -398,7 +399,7 @@ class SignalSchema:
         return SignalSchema(signals)

     @staticmethod
-    def get_flatten_hidden_fields(schema):
+    def get_flatten_hidden_fields(schema: dict):
         custom_types = schema.get("_custom_types", {})
         if not custom_types:
             return []
@@ -464,19 +465,61 @@ class SignalSchema:
         return False

     def slice(
-        self, keys: Sequence[str], setup: Optional[dict[str, Callable]] = None
+        self,
+        params: dict[str, Union[DataType, Any]],
+        setup: Optional[dict[str, Callable]] = None,
+        is_batch: bool = False,
     ) -> "SignalSchema":
-        # Make new schema that combines current schema and setup signals
-        setup = setup or {}
-        setup_no_types = dict.fromkeys(setup.keys(), str)
-        union = SignalSchema(self.values | setup_no_types)
-        # Slice combined schema by keys
-        schema = {}
-        for k in keys:
-            try:
-                schema[k] = union._find_in_tree(k.split("."))
-            except SignalResolvingError:
-                pass
+        """
+        Returns new schema that combines current schema and setup signals.
+        """
+        setup_params = setup.keys() if setup else []
+        schema: dict[str, DataType] = {}
+
+        for param, param_type in params.items():
+            # This is special case for setup params, they are always treated as strings
+            if param in setup_params:
+                schema[param] = str
+                continue
+
+            schema_type = self._find_in_tree(param.split("."))
+
+            if param_type is Any:
+                schema[param] = schema_type
+                continue
+
+            schema_origin = get_origin(schema_type)
+            param_origin = get_origin(param_type)
+
+            if schema_origin is Union and type(None) in get_args(schema_type):
+                schema_type = get_args(schema_type)[0]
+            if param_origin is Union and type(None) in get_args(param_type):
+                param_type = get_args(param_type)[0]
+
+            if is_batch:
+                if param_type is list:
+                    schema[param] = schema_type
+                    continue
+
+                if param_origin is not list:
+                    raise SignalResolvingError(param.split("."), "is not a list")
+
+                param_type = get_args(param_type)[0]
+
+            if param_type == schema_type or (
+                isclass(param_type)
+                and isclass(schema_type)
+                and issubclass(param_type, File)
+                and issubclass(schema_type, File)
+            ):
+                schema[param] = schema_type
+                continue
+
+            raise SignalResolvingError(
+                param.split("."),
+                f"types mismatch: {param_type} != {schema_type}",
+            )
+
         return SignalSchema(schema, setup)

     def row_to_features(
@@ -696,16 +739,20 @@ class SignalSchema:
                 substree, new_prefix, depth + 1, include_hidden
             )

-    def print_tree(self, indent: int = 4, start_at: int = 0):
+    def print_tree(self, indent: int = 2, start_at: int = 0, file: Optional[IO] = None):
         for path, type_, _, depth in self.get_flat_tree():
             total_indent = start_at + depth * indent
-            print(" " * total_indent, f"{path[-1]}:", SignalSchema._type_to_str(type_))
+            col_name = " " * total_indent + path[-1]
+            col_type = SignalSchema._type_to_str(type_)
+            print(col_name, col_type, sep=": ", file=file)

             if get_origin(type_) is list:
                 args = get_args(type_)
                 if len(args) > 0 and ModelStore.is_pydantic(args[0]):
                     sub_schema = SignalSchema({"* list of": args[0]})
-                    sub_schema.print_tree(indent=indent, start_at=total_indent + indent)
+                    sub_schema.print_tree(
+                        indent=indent, start_at=total_indent + indent, file=file
+                    )

     def get_headers_with_length(self, include_hidden: bool = True):
         paths = [
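
The rewritten `slice()` receives the UDF's annotated parameter types and validates them against the schema instead of silently dropping unresolved keys, and `is_batch=True` requires `list[...]` annotations. A minimal sketch of the new semantics, assuming a schema with a single `file` signal:

```py
from typing import Any

from datachain.lib.file import File, TextFile
from datachain.lib.signal_schema import SignalSchema

schema = SignalSchema({"file": File})

# Any defers to the type recorded in the schema
schema.slice({"file": Any})            # values == {"file": File}

# File subclasses are considered compatible with one another
schema.slice({"file": TextFile})       # values == {"file": File}

# batched UDFs must annotate their parameters as lists
schema.slice({"file": list[File]}, is_batch=True)

# a mismatched annotation now raises SignalResolvingError ("types mismatch")
# schema.slice({"file": int})
```
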