datachain 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain might be problematic.

Files changed (39)
  1. datachain/catalog/catalog.py +33 -5
  2. datachain/catalog/loader.py +19 -13
  3. datachain/cli/__init__.py +3 -1
  4. datachain/cli/commands/show.py +12 -1
  5. datachain/cli/parser/studio.py +13 -1
  6. datachain/cli/parser/utils.py +6 -0
  7. datachain/client/fsspec.py +12 -16
  8. datachain/client/hf.py +36 -14
  9. datachain/client/local.py +1 -4
  10. datachain/data_storage/warehouse.py +3 -8
  11. datachain/dataset.py +8 -0
  12. datachain/error.py +0 -12
  13. datachain/fs/utils.py +30 -0
  14. datachain/func/__init__.py +5 -0
  15. datachain/func/func.py +2 -1
  16. datachain/lib/data_model.py +6 -0
  17. datachain/lib/dc.py +114 -28
  18. datachain/lib/file.py +100 -25
  19. datachain/lib/image.py +30 -6
  20. datachain/lib/listing.py +21 -39
  21. datachain/lib/signal_schema.py +194 -15
  22. datachain/lib/video.py +7 -5
  23. datachain/model/bbox.py +209 -58
  24. datachain/model/pose.py +49 -37
  25. datachain/model/segment.py +22 -18
  26. datachain/model/ultralytics/bbox.py +9 -9
  27. datachain/model/ultralytics/pose.py +7 -7
  28. datachain/model/ultralytics/segment.py +7 -7
  29. datachain/model/utils.py +191 -0
  30. datachain/nodes_thread_pool.py +32 -11
  31. datachain/query/dataset.py +4 -2
  32. datachain/studio.py +8 -6
  33. datachain/utils.py +3 -16
  34. {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/METADATA +6 -4
  35. {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/RECORD +39 -37
  36. {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/WHEEL +1 -1
  37. {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/LICENSE +0 -0
  38. {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/entry_points.txt +0 -0
  39. {datachain-0.11.0.dist-info → datachain-0.12.0.dist-info}/top_level.txt +0 -0
datachain/lib/dc.py CHANGED
@@ -22,7 +22,7 @@ import orjson
 import sqlalchemy
 from pydantic import BaseModel
 from sqlalchemy.sql.functions import GenericFunction
-from sqlalchemy.sql.sqltypes import NullType
+from tqdm import tqdm

 from datachain.dataset import DatasetRecord
 from datachain.func import literal
@@ -32,7 +32,14 @@ from datachain.lib.convert.python_to_sql import python_to_sql
 from datachain.lib.convert.values_to_tuples import values_to_tuples
 from datachain.lib.data_model import DataModel, DataType, DataValue, dict_to_data_model
 from datachain.lib.dataset_info import DatasetInfo
-from datachain.lib.file import ArrowRow, File, FileType, get_file_type
+from datachain.lib.file import (
+    EXPORT_FILES_MAX_THREADS,
+    ArrowRow,
+    File,
+    FileExporter,
+    FileType,
+    get_file_type,
+)
 from datachain.lib.file import ExportPlacement as FileExportPlacement
 from datachain.lib.listing import get_file_info, get_listing, list_bucket, ls
 from datachain.lib.listing_info import ListingInfo
@@ -47,7 +54,6 @@ from datachain.query import Session
 from datachain.query.dataset import DatasetQuery, PartitionByType
 from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
 from datachain.sql.functions import path as pathfunc
-from datachain.telemetry import telemetry
 from datachain.utils import batched_it, inside_notebook, row_to_nested_dict

 if TYPE_CHECKING:
@@ -65,7 +71,6 @@ _T = TypeVar("_T")
 D = TypeVar("D", bound="DataChain")
 UDFObjT = TypeVar("UDFObjT", bound=UDFBase)

-
 DEFAULT_PARQUET_CHUNK_SIZE = 100_000


@@ -208,7 +213,7 @@ class DataChain:
         from mistralai.client import MistralClient
         from mistralai.models.chat_completion import ChatMessage

-        from datachain.dc import DataChain, Column
+        from datachain import DataChain, Column

         PROMPT = (
             "Was this bot dialog successful? "
@@ -401,7 +406,7 @@
     @classmethod
     def from_storage(
         cls,
-        uri,
+        uri: Union[str, os.PathLike[str]],
         *,
         type: FileType = "binary",
         session: Optional[Session] = None,
@@ -543,6 +548,8 @@
             )
             ```
         """
+        from datachain.telemetry import telemetry
+
         query = DatasetQuery(
             name=name,
             version=version,
@@ -566,7 +573,7 @@
     @classmethod
     def from_json(
         cls,
-        path,
+        path: Union[str, os.PathLike[str]],
         type: FileType = "text",
         spec: Optional[DataType] = None,
         schema_from: Optional[str] = "auto",
@@ -603,7 +610,7 @@
         ```
         """
         if schema_from == "auto":
-            schema_from = path
+            schema_from = str(path)

         def jmespath_to_name(s: str):
             name_end = re.search(r"\W", s).start() if re.search(r"\W", s) else len(s)  # type: ignore[union-attr]
@@ -694,9 +701,22 @@
         in_memory: bool = False,
         object_name: str = "dataset",
         include_listing: bool = False,
+        studio: bool = False,
     ) -> "DataChain":
         """Generate chain with list of registered datasets.

+        Args:
+            session: Optional session instance. If not provided, uses default session.
+            settings: Optional dictionary of settings to configure the chain.
+            in_memory: If True, creates an in-memory session. Defaults to False.
+            object_name: Name of the output object in the chain. Defaults to "dataset".
+            include_listing: If True, includes listing datasets. Defaults to False.
+            studio: If True, returns datasets from Studio only,
+                otherwise returns all local datasets. Defaults to False.
+
+        Returns:
+            DataChain: A new DataChain instance containing dataset information.
+
         Example:
             ```py
             from datachain import DataChain
@@ -712,7 +732,7 @@
         datasets = [
             DatasetInfo.from_models(d, v, j)
             for d, v, j in catalog.list_datasets_versions(
-                include_listing=include_listing
+                include_listing=include_listing, studio=studio
             )
         ]

@@ -1050,7 +1070,7 @@
     def select(self, *args: str, _sys: bool = True) -> "Self":
         """Select only a specified set of signals."""
         new_schema = self.signals_schema.resolve(*args)
-        if _sys:
+        if self._sys and _sys:
             new_schema = SignalSchema({"sys": Sys}) | new_schema
         columns = new_schema.db_signals()
         return self._evolve(
@@ -1093,6 +1113,7 @@
         partition_by_columns: list[Column] = []
         signal_columns: list[Column] = []
         schema_fields: dict[str, DataType] = {}
+        keep_columns: list[str] = []

         # validate partition_by columns and add them to the schema
         for col in partition_by:
@@ -1100,10 +1121,13 @@
                 col_db_name = ColumnMeta.to_db_name(col)
                 col_type = self.signals_schema.get_column_type(col_db_name)
                 column = Column(col_db_name, python_to_sql(col_type))
+                if col not in keep_columns:
+                    keep_columns.append(col)
             elif isinstance(col, Function):
                 column = col.get_column(self.signals_schema)
                 col_db_name = column.name
                 col_type = column.type.python_type
+                schema_fields[col_db_name] = col_type
             else:
                 raise DataChainColumnError(
                     col,
@@ -1113,7 +1137,6 @@
                     ),
                 )
             partition_by_columns.append(column)
-            schema_fields[col_db_name] = col_type

         # validate signal columns and add them to the schema
         if not kwargs:
@@ -1128,9 +1151,13 @@
             signal_columns.append(column)
             schema_fields[col_name] = func.get_result_type(self.signals_schema)

+        signal_schema = SignalSchema(schema_fields)
+        if keep_columns:
+            signal_schema |= self.signals_schema.to_partial(*keep_columns)
+
         return self._evolve(
             query=self._query.group_by(signal_columns, partition_by_columns),
-            signal_schema=SignalSchema(schema_fields),
+            signal_schema=signal_schema,
         )

     def mutate(self, **kwargs) -> "Self":
@@ -1181,6 +1208,8 @@
             )
             ```
         """
+        from sqlalchemy.sql.sqltypes import NullType
+
         primitives = (bool, str, int, float)

         for col_name, expr in kwargs.items():
@@ -1225,23 +1254,37 @@
     @overload
     def collect_flatten(self) -> Iterator[tuple[Any, ...]]: ...

+    @overload
+    def collect_flatten(self, *, include_hidden: bool) -> Iterator[tuple[Any, ...]]: ...
+
     @overload
     def collect_flatten(
         self, *, row_factory: Callable[[list[str], tuple[Any, ...]], _T]
     ) -> Iterator[_T]: ...

-    def collect_flatten(self, *, row_factory=None):
+    @overload
+    def collect_flatten(
+        self,
+        *,
+        row_factory: Callable[[list[str], tuple[Any, ...]], _T],
+        include_hidden: bool,
+    ) -> Iterator[_T]: ...
+
+    def collect_flatten(self, *, row_factory=None, include_hidden: bool = True):
         """Yields flattened rows of values as a tuple.

         Args:
             row_factory : A callable to convert row to a custom format.
                 It should accept two arguments: a list of column names and
                 a tuple of row values.
+            include_hidden: Whether to include hidden signals from the schema.
         """
-        db_signals = self._effective_signals_schema.db_signals()
+        db_signals = self._effective_signals_schema.db_signals(
+            include_hidden=include_hidden
+        )
         with self._query.ordered_select(*db_signals).as_iterable() as rows:
             if row_factory:
-                rows = (row_factory(db_signals, r) for r in rows)
+                rows = (row_factory(db_signals, r) for r in rows)  # type: ignore[assignment]
             yield from rows

     def to_columnar_data_with_names(
@@ -1275,10 +1318,23 @@
         self, *, row_factory: Callable[[list[str], tuple[Any, ...]], _T]
     ) -> list[_T]: ...

-    def results(self, *, row_factory=None):  # noqa: D102
+    @overload
+    def results(
+        self,
+        *,
+        row_factory: Callable[[list[str], tuple[Any, ...]], _T],
+        include_hidden: bool,
+    ) -> list[_T]: ...
+
+    @overload
+    def results(self, *, include_hidden: bool) -> list[tuple[Any, ...]]: ...
+
+    def results(self, *, row_factory=None, include_hidden=True):  # noqa: D102
         if row_factory is None:
-            return list(self.collect_flatten())
-        return list(self.collect_flatten(row_factory=row_factory))
+            return list(self.collect_flatten(include_hidden=include_hidden))
+        return list(
+            self.collect_flatten(row_factory=row_factory, include_hidden=include_hidden)
+        )

     def to_records(self) -> list[dict[str, Any]]:
         """Convert every row to a dictionary."""
@@ -1788,21 +1844,25 @@
             **fr_map,
         )

-    def to_pandas(self, flatten=False) -> "pd.DataFrame":
+    def to_pandas(self, flatten=False, include_hidden=True) -> "pd.DataFrame":
         """Return a pandas DataFrame from the chain.

         Parameters:
             flatten : Whether to use a multiindex or flatten column names.
+            include_hidden : Whether to include hidden columns.
         """
         import pandas as pd

-        headers, max_length = self._effective_signals_schema.get_headers_with_length()
+        headers, max_length = self._effective_signals_schema.get_headers_with_length(
+            include_hidden=include_hidden
+        )
         if flatten or max_length < 2:
             columns = [".".join(filter(None, header)) for header in headers]
         else:
             columns = pd.MultiIndex.from_tuples(map(tuple, headers))

-        return pd.DataFrame.from_records(self.results(), columns=columns)
+        results = self.results(include_hidden=include_hidden)
+        return pd.DataFrame.from_records(results, columns=columns)

     def show(
         self,
@@ -1810,6 +1870,7 @@
         flatten=False,
         transpose=False,
         truncate=True,
+        include_hidden=False,
     ) -> None:
         """Show a preview of the chain results.

@@ -1818,11 +1879,12 @@
             flatten : Whether to use a multiindex or flatten column names.
             transpose : Whether to transpose rows and columns.
             truncate : Whether or not to truncate the contents of columns.
+            include_hidden : Whether to include hidden columns.
         """
         import pandas as pd

         dc = self.limit(limit) if limit > 0 else self  # type: ignore[misc]
-        df = dc.to_pandas(flatten)
+        df = dc.to_pandas(flatten, include_hidden=include_hidden)

         if df.empty:
             print("Empty result")
@@ -2495,22 +2557,28 @@

     def to_storage(
         self,
-        output: str,
+        output: Union[str, os.PathLike[str]],
         signal: str = "file",
         placement: FileExportPlacement = "fullpath",
-        use_cache: bool = True,
         link_type: Literal["copy", "symlink"] = "copy",
+        num_threads: Optional[int] = EXPORT_FILES_MAX_THREADS,
+        anon: bool = False,
+        client_config: Optional[dict] = None,
     ) -> None:
-        """Export files from a specified signal to a directory.
+        """Export files from a specified signal to a directory. Files can be
+        exported to a local or cloud directory.

         Args:
            output: Path to the target directory for exporting files.
            signal: Name of the signal to export files from.
            placement: The method to use for naming exported files.
                The possible values are: "filename", "etag", "fullpath", and "checksum".
-           use_cache: If `True`, cache the files before exporting.
           link_type: Method to use for exporting files.
               Falls back to `'copy'` if symlinking fails.
+           num_threads : number of threads to use for exporting files.
+               By default it uses 5 threads.
+           anon: If true, we will treat cloud bucket as public one
+           client_config: Optional configuration for the destination storage client

        Example:
            Cross cloud transfer
@@ -2525,8 +2593,26 @@
         ):
             raise ValueError("Files with the same name found")

-        for file in self.collect(signal):
-            file.export(output, placement, use_cache, link_type=link_type)  # type: ignore[union-attr]
+        if anon:
+            client_config = (client_config or {}) | {"anon": True}
+
+        progress_bar = tqdm(
+            desc=f"Exporting files to {output}: ",
+            unit=" files",
+            unit_scale=True,
+            unit_divisor=10,
+            total=self.count(),
+            leave=False,
+        )
+        file_exporter = FileExporter(
+            output,
+            placement,
+            self._settings.cache if self._settings else False,
+            link_type,
+            max_threads=num_threads or 1,
+            client_config=client_config,
+        )
+        file_exporter.run(self.collect(signal), progress_bar)

     def shuffle(self) -> "Self":
         """Shuffle the rows of the chain deterministically."""
datachain/lib/file.py CHANGED
@@ -18,12 +18,12 @@ from urllib.request import url2pathname

 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from fsspec.utils import stringify_path
-from PIL import Image as PilImage
 from pydantic import Field, field_validator

 from datachain.client.fileslice import FileSlice
 from datachain.lib.data_model import DataModel
 from datachain.lib.utils import DataChainError
+from datachain.nodes_thread_pool import NodesThreadPool
 from datachain.sql.types import JSON, Boolean, DateTime, Int, String
 from datachain.utils import TIME_ZERO

@@ -43,6 +43,41 @@ logger = logging.getLogger("datachain")
 ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]

 FileType = Literal["binary", "text", "image", "video"]
+EXPORT_FILES_MAX_THREADS = 5
+
+
+class FileExporter(NodesThreadPool):
+    """Class that does file exporting concurrently with thread pool"""
+
+    def __init__(
+        self,
+        output: Union[str, os.PathLike[str]],
+        placement: ExportPlacement,
+        use_cache: bool,
+        link_type: Literal["copy", "symlink"],
+        max_threads: int = EXPORT_FILES_MAX_THREADS,
+        client_config: Optional[dict] = None,
+    ):
+        super().__init__(max_threads)
+        self.output = output
+        self.placement = placement
+        self.use_cache = use_cache
+        self.link_type = link_type
+        self.client_config = client_config
+
+    def done_task(self, done):
+        for task in done:
+            task.result()
+
+    def do_task(self, file):
+        file.export(
+            self.output,
+            self.placement,
+            self.use_cache,
+            link_type=self.link_type,
+            client_config=self.client_config,
+        )
+        self.increase_counter(1)


 class VFileError(DataChainError):
@@ -158,6 +193,7 @@ class File(DataModel):
         "last_modified": DateTime,
         "location": JSON,
     }
+    _hidden_fields: ClassVar[list[str]] = ["version", "source"]

     _unique_id_keys: ClassVar[list[str]] = [
         "source",
@@ -206,6 +242,30 @@ class File(DataModel):
         self._catalog = None
         self._caching_enabled: bool = False

+    def as_text_file(self) -> "TextFile":
+        """Convert the file to a `TextFile` object."""
+        if isinstance(self, TextFile):
+            return self
+        file = TextFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
+    def as_image_file(self) -> "ImageFile":
+        """Convert the file to a `ImageFile` object."""
+        if isinstance(self, ImageFile):
+            return self
+        file = ImageFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
+    def as_video_file(self) -> "VideoFile":
+        """Convert the file to a `VideoFile` object."""
+        if isinstance(self, VideoFile):
+            return self
+        file = VideoFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
     @classmethod
     def upload(
         cls, data: bytes, path: str, catalog: Optional["Catalog"] = None
@@ -255,24 +315,24 @@ class File(DataModel):
         ) as f:
             yield io.TextIOWrapper(f) if mode == "r" else f

-    def read(self, length: int = -1):
-        """Returns file contents."""
+    def read_bytes(self, length: int = -1):
+        """Returns file contents as bytes."""
         with self.open() as stream:
             return stream.read(length)

-    def read_bytes(self):
-        """Returns file contents as bytes."""
-        return self.read()
-
     def read_text(self):
         """Returns file contents as text."""
         with self.open(mode="r") as stream:
             return stream.read()

-    def save(self, destination: str):
+    def read(self, length: int = -1):
+        """Returns file contents."""
+        return self.read_bytes(length)
+
+    def save(self, destination: str, client_config: Optional[dict] = None):
         """Writes it's content to destination"""
         destination = stringify_path(destination)
-        client: Client = self._catalog.get_client(destination)
+        client: Client = self._catalog.get_client(destination, **(client_config or {}))

         if client.PREFIX == "file://" and not destination.startswith(client.PREFIX):
             destination = Path(destination).absolute().as_uri()
@@ -296,17 +356,17 @@

     def export(
         self,
-        output: str,
+        output: Union[str, os.PathLike[str]],
         placement: ExportPlacement = "fullpath",
         use_cache: bool = True,
         link_type: Literal["copy", "symlink"] = "copy",
+        client_config: Optional[dict] = None,
     ) -> None:
         """Export file to new location."""
-        if use_cache:
-            self._caching_enabled = use_cache
+        self._caching_enabled = use_cache
         dst = self.get_destination_path(output, placement)
         dst_dir = os.path.dirname(dst)
-        client: Client = self._catalog.get_client(dst_dir)
+        client: Client = self._catalog.get_client(dst_dir, **(client_config or {}))
         client.fs.makedirs(dst_dir, exist_ok=True)

         if link_type == "symlink":
@@ -316,7 +376,7 @@
             if exc.errno not in (errno.ENOTSUP, errno.EXDEV, errno.ENOSYS):
                 raise

-        self.save(dst)
+        self.save(dst, client_config=client_config)

     def _set_stream(
         self,
@@ -337,15 +397,10 @@
         client.download(self, callback=self._download_cb)

     async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool:
-        from datachain.client.hf import HfClient
-
         if self._catalog is None:
             raise RuntimeError("cannot prefetch file because catalog is not setup")

         client = self._catalog.get_client(self.source)
-        if client.protocol == HfClient.protocol:
-            return False
-
         await client._download(self, callback=download_cb or self._download_cb)
         self._set_stream(
             self._catalog, caching_enabled=True, download_cb=DEFAULT_CALLBACK
@@ -393,7 +448,9 @@ class File(DataModel):
             path = url2pathname(path)
         return path

-    def get_destination_path(self, output: str, placement: ExportPlacement) -> str:
+    def get_destination_path(
+        self, output: Union[str, os.PathLike[str]], placement: ExportPlacement
+    ) -> str:
         """
         Returns full destination path of a file for exporting to some output
         based on export placement
@@ -502,11 +559,11 @@ class TextFile(File):
         with self.open() as stream:
             return stream.read()

-    def save(self, destination: str):
+    def save(self, destination: str, client_config: Optional[dict] = None):
         """Writes it's content to destination"""
         destination = stringify_path(destination)

-        client: Client = self._catalog.get_client(destination)
+        client: Client = self._catalog.get_client(destination, **(client_config or {}))
         with client.fs.open(destination, mode="w") as f:
             f.write(self.read_text())

@@ -514,18 +571,36 @@ class TextFile(File):
 class ImageFile(File):
     """`DataModel` for reading image files."""

+    def get_info(self) -> "Image":
+        """
+        Retrieves metadata and information about the image file.
+
+        Returns:
+            Image: A Model containing image metadata such as width, height and format.
+        """
+        from .image import image_info
+
+        return image_info(self)
+
     def read(self):
         """Returns `PIL.Image.Image` object."""
+        from PIL import Image as PilImage
+
         fobj = super().read()
         return PilImage.open(BytesIO(fobj))

-    def save(self, destination: str):
+    def save(  # type: ignore[override]
+        self,
+        destination: str,
+        format: Optional[str] = None,
+        client_config: Optional[dict] = None,
+    ):
         """Writes it's content to destination"""
         destination = stringify_path(destination)

-        client: Client = self._catalog.get_client(destination)
+        client: Client = self._catalog.get_client(destination, **(client_config or {}))
         with client.fs.open(destination, mode="wb") as f:
-            self.read().save(f)
+            self.read().save(f, format=format)


 class Image(DataModel):
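Usage note: on the file.py side, the raw reader is now `read_bytes` (with `read` delegating to it), a `File` can be re-typed via `as_text_file` / `as_image_file` / `as_video_file` without losing its stream, and `save`/`export` accept a `client_config` for the destination client. A short sketch; the paths and the PNG sniffing below are illustrative assumptions, not package behavior:

```py
from datachain import DataChain

chain = DataChain.from_storage("data/blobs/", type="binary")  # placeholder path

for file in chain.collect("file"):
    header = file.read_bytes(8)        # read(length) now delegates here
    if header.startswith(b"\x89PNG"):  # PNG magic bytes
        image = file.as_image_file()   # re-types the File, keeps its stream
        image.save("out/preview.png", format="PNG")
    else:
        print(file.as_text_file().read_text()[:80])
```

The new `FileExporter` is what `DataChain.to_storage` drives internally: `do_task` exports one file per pooled thread and `increase_counter` feeds the shared progress bar.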
datachain/lib/image.py CHANGED
@@ -1,17 +1,41 @@
 from typing import Callable, Optional, Union

 import torch
-from PIL import Image
+from PIL import Image as PILImage
+
+from datachain.lib.file import File, FileError, Image, ImageFile
+
+
+def image_info(file: Union[File, ImageFile]) -> Image:
+    """
+    Returns image file information.
+
+    Args:
+        file (ImageFile): Image file object.
+
+    Returns:
+        Image: Image file information.
+    """
+    try:
+        img = file.as_image_file().read()
+    except Exception as exc:
+        raise FileError(file, "unable to open image file") from exc
+
+    return Image(
+        width=img.width,
+        height=img.height,
+        format=img.format or "",
+    )


 def convert_image(
-    img: Image.Image,
+    img: PILImage.Image,
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
     encoder: Optional[Callable] = None,
     device: Optional[Union[str, torch.device]] = None,
-) -> Union[Image.Image, torch.Tensor]:
+) -> Union[PILImage.Image, torch.Tensor]:
     """
     Resize, transform, and otherwise convert an image.

@@ -47,13 +71,13 @@ def convert_image(


 def convert_images(
-    images: Union[Image.Image, list[Image.Image]],
+    images: Union[PILImage.Image, list[PILImage.Image]],
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
     encoder: Optional[Callable] = None,
     device: Optional[Union[str, torch.device]] = None,
-) -> Union[list[Image.Image], torch.Tensor]:
+) -> Union[list[PILImage.Image], torch.Tensor]:
     """
     Resize, transform, and otherwise convert one or more images.

@@ -65,7 +89,7 @@
         encoder (Callable): Encode image using model.
         device (str or torch.device): Device to use.
     """
-    if isinstance(images, Image.Image):
+    if isinstance(images, PILImage.Image):
         images = [images]

     converted = [
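Usage note: `image_info` (surfaced as `ImageFile.get_info`) reads width, height, and format into the `Image` model, raising `FileError` with the offending file attached if decoding fails. A hedged sketch of mapping that metadata over a chain, assuming the usual `map` signal-map form; the storage path and the `info` signal name are placeholders:

```py
from datachain import DataChain
from datachain.lib.file import Image

chain = (
    DataChain.from_storage("data/photos/", type="image")  # placeholder path
    .map(info=lambda file: file.get_info(), output=Image)
)

for info in chain.collect("info"):
    print(info.width, info.height, info.format)
```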