datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. datachain/__init__.py +4 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +5 -5
  4. datachain/catalog/__init__.py +0 -2
  5. datachain/catalog/catalog.py +276 -354
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +8 -3
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +10 -17
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +42 -27
  12. datachain/cli/commands/ls.py +15 -15
  13. datachain/cli/commands/show.py +2 -2
  14. datachain/cli/parser/__init__.py +3 -43
  15. datachain/cli/parser/job.py +1 -1
  16. datachain/cli/parser/utils.py +1 -2
  17. datachain/cli/utils.py +2 -15
  18. datachain/client/azure.py +2 -2
  19. datachain/client/fsspec.py +34 -23
  20. datachain/client/gcs.py +3 -3
  21. datachain/client/http.py +157 -0
  22. datachain/client/local.py +11 -7
  23. datachain/client/s3.py +3 -3
  24. datachain/config.py +4 -8
  25. datachain/data_storage/db_engine.py +12 -6
  26. datachain/data_storage/job.py +2 -0
  27. datachain/data_storage/metastore.py +716 -137
  28. datachain/data_storage/schema.py +20 -27
  29. datachain/data_storage/serializer.py +105 -15
  30. datachain/data_storage/sqlite.py +114 -114
  31. datachain/data_storage/warehouse.py +140 -48
  32. datachain/dataset.py +109 -89
  33. datachain/delta.py +117 -42
  34. datachain/diff/__init__.py +25 -33
  35. datachain/error.py +24 -0
  36. datachain/func/aggregate.py +9 -11
  37. datachain/func/array.py +12 -12
  38. datachain/func/base.py +7 -4
  39. datachain/func/conditional.py +9 -13
  40. datachain/func/func.py +63 -45
  41. datachain/func/numeric.py +5 -7
  42. datachain/func/string.py +2 -2
  43. datachain/hash_utils.py +123 -0
  44. datachain/job.py +11 -7
  45. datachain/json.py +138 -0
  46. datachain/lib/arrow.py +18 -15
  47. datachain/lib/audio.py +60 -59
  48. datachain/lib/clip.py +14 -13
  49. datachain/lib/convert/python_to_sql.py +6 -10
  50. datachain/lib/convert/values_to_tuples.py +151 -53
  51. datachain/lib/data_model.py +23 -19
  52. datachain/lib/dataset_info.py +7 -7
  53. datachain/lib/dc/__init__.py +2 -1
  54. datachain/lib/dc/csv.py +22 -26
  55. datachain/lib/dc/database.py +37 -34
  56. datachain/lib/dc/datachain.py +518 -324
  57. datachain/lib/dc/datasets.py +38 -30
  58. datachain/lib/dc/hf.py +16 -20
  59. datachain/lib/dc/json.py +17 -18
  60. datachain/lib/dc/listings.py +5 -8
  61. datachain/lib/dc/pandas.py +3 -6
  62. datachain/lib/dc/parquet.py +33 -21
  63. datachain/lib/dc/records.py +9 -13
  64. datachain/lib/dc/storage.py +103 -65
  65. datachain/lib/dc/storage_pattern.py +251 -0
  66. datachain/lib/dc/utils.py +17 -14
  67. datachain/lib/dc/values.py +3 -6
  68. datachain/lib/file.py +187 -50
  69. datachain/lib/hf.py +7 -5
  70. datachain/lib/image.py +13 -13
  71. datachain/lib/listing.py +5 -5
  72. datachain/lib/listing_info.py +1 -2
  73. datachain/lib/meta_formats.py +2 -3
  74. datachain/lib/model_store.py +20 -8
  75. datachain/lib/namespaces.py +59 -7
  76. datachain/lib/projects.py +51 -9
  77. datachain/lib/pytorch.py +31 -23
  78. datachain/lib/settings.py +188 -85
  79. datachain/lib/signal_schema.py +302 -64
  80. datachain/lib/text.py +8 -7
  81. datachain/lib/udf.py +103 -63
  82. datachain/lib/udf_signature.py +59 -34
  83. datachain/lib/utils.py +20 -0
  84. datachain/lib/video.py +3 -4
  85. datachain/lib/webdataset.py +31 -36
  86. datachain/lib/webdataset_laion.py +15 -16
  87. datachain/listing.py +12 -5
  88. datachain/model/bbox.py +3 -1
  89. datachain/namespace.py +22 -3
  90. datachain/node.py +6 -6
  91. datachain/nodes_thread_pool.py +0 -1
  92. datachain/plugins.py +24 -0
  93. datachain/project.py +4 -4
  94. datachain/query/batch.py +10 -12
  95. datachain/query/dataset.py +376 -194
  96. datachain/query/dispatch.py +112 -84
  97. datachain/query/metrics.py +3 -4
  98. datachain/query/params.py +2 -3
  99. datachain/query/queue.py +2 -1
  100. datachain/query/schema.py +7 -6
  101. datachain/query/session.py +190 -33
  102. datachain/query/udf.py +9 -6
  103. datachain/remote/studio.py +90 -53
  104. datachain/script_meta.py +12 -12
  105. datachain/sql/sqlite/base.py +37 -25
  106. datachain/sql/sqlite/types.py +1 -1
  107. datachain/sql/types.py +36 -5
  108. datachain/studio.py +49 -40
  109. datachain/toolkit/split.py +31 -10
  110. datachain/utils.py +39 -48
  111. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
  112. datachain-0.39.0.dist-info/RECORD +173 -0
  113. datachain/cli/commands/query.py +0 -54
  114. datachain/query/utils.py +0 -36
  115. datachain-0.30.5.dist-info/RECORD +0 -168
  116. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
  117. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  118. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  119. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/lib/file.py CHANGED
@@ -1,7 +1,6 @@
 import errno
 import hashlib
 import io
-import json
 import logging
 import os
 import posixpath
@@ -13,7 +12,7 @@ from datetime import datetime
 from functools import partial
 from io import BytesIO
 from pathlib import Path, PurePath, PurePosixPath
-from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, ClassVar, Literal
 from urllib.parse import unquote, urlparse
 from urllib.request import url2pathname
 
@@ -21,6 +20,7 @@ from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from fsspec.utils import stringify_path
 from pydantic import Field, field_validator
 
+from datachain import json
 from datachain.client.fileslice import FileSlice
 from datachain.lib.data_model import DataModel
 from datachain.lib.utils import DataChainError, rebase_path
@@ -35,13 +35,14 @@ if TYPE_CHECKING:
     from datachain.catalog import Catalog
     from datachain.client.fsspec import Client
     from datachain.dataset import RowDict
+    from datachain.query.session import Session
 
 sha256 = partial(hashlib.sha256, usedforsecurity=False)
 
 logger = logging.getLogger("datachain")
 
 # how to create file path when exporting
-ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]
+ExportPlacement = Literal["filename", "etag", "fullpath", "checksum", "filepath"]
 
 FileType = Literal["binary", "text", "image", "video", "audio"]
 EXPORT_FILES_MAX_THREADS = 5
@@ -52,12 +53,12 @@ class FileExporter(NodesThreadPool):
 
     def __init__(
         self,
-        output: Union[str, os.PathLike[str]],
+        output: str | os.PathLike[str],
         placement: ExportPlacement,
         use_cache: bool,
         link_type: Literal["copy", "symlink"],
         max_threads: int = EXPORT_FILES_MAX_THREADS,
-        client_config: Optional[dict] = None,
+        client_config: dict | None = None,
     ):
         super().__init__(max_threads)
         self.output = output
@@ -220,7 +221,7 @@ class File(DataModel):
     etag: str = Field(default="")
     is_latest: bool = Field(default=True)
     last_modified: datetime = Field(default=TIME_ZERO)
-    location: Optional[Union[dict, list[dict]]] = Field(default=None)
+    location: dict | list[dict] | None = Field(default=None)
 
     _datachain_column_types: ClassVar[dict[str, Any]] = {
         "source": String,
@@ -252,10 +253,19 @@ class File(DataModel):
         "last_modified",
     ]
 
+    # Allowed kwargs we forward to TextIOWrapper
+    _TEXT_WRAPPER_ALLOWED: ClassVar[tuple[str, ...]] = (
+        "encoding",
+        "errors",
+        "newline",
+        "line_buffering",
+        "write_through",
+    )
+
     @staticmethod
     def _validate_dict(
-        v: Optional[Union[str, dict, list[dict]]],
-    ) -> Optional[Union[str, dict, list[dict]]]:
+        v: str | dict | list[dict] | None,
+    ) -> str | dict | list[dict] | None:
         if v is None or v == "":
             return None
         if isinstance(v, str):
@@ -287,6 +297,16 @@ class File(DataModel):
         super().__init__(**kwargs)
         self._catalog = None
         self._caching_enabled: bool = False
+        self._download_cb: Callback = DEFAULT_CALLBACK
+
+    def __getstate__(self):
+        state = super().__getstate__()
+        # Exclude _catalog from pickling - it contains SQLAlchemy engine and other
+        # non-picklable objects. The catalog will be re-set by _set_stream() on the
+        # worker side when needed.
+        state["__dict__"] = state["__dict__"].copy()
+        state["__dict__"]["_catalog"] = None
+        return state
 
     def as_text_file(self) -> "TextFile":
         """Convert the file to a `TextFile` object."""
@@ -322,17 +342,21 @@ class File(DataModel):
 
     @classmethod
     def upload(
-        cls, data: bytes, path: str, catalog: Optional["Catalog"] = None
+        cls,
+        data: bytes,
+        path: str | os.PathLike[str],
+        catalog: "Catalog | None" = None,
     ) -> "Self":
         if catalog is None:
-            from datachain.catalog.loader import get_catalog
-
-            catalog = get_catalog()
+            from datachain.query.session import Session
 
+            catalog = Session.get().catalog
         from datachain.client.fsspec import Client
 
-        client_cls = Client.get_implementation(path)
-        source, rel_path = client_cls.split_url(path)
+        path_str = stringify_path(path)
+
+        client_cls = Client.get_implementation(path_str)
+        source, rel_path = client_cls.split_url(path_str)
 
         client = catalog.get_client(client_cls.get_uri(source))
         file = client.upload(data, rel_path)
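
`upload` now accepts `os.PathLike` paths and, when no catalog is given, resolves one from the active `Session` instead of `get_catalog()`. A minimal sketch of the new call style (the bucket name is illustrative, not from the diff):

    from datachain.lib.file import File

    # The bytes are uploaded to the URI and the returned File is already
    # bound to the session catalog, so it can be opened or read right away.
    file = File.upload(b"hello, world", "s3://my-bucket/examples/hello.txt")
    print(file.source, file.path)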
@@ -341,6 +365,35 @@ class File(DataModel):
         file._set_stream(catalog)
         return file
 
+    @classmethod
+    def at(
+        cls, uri: str | os.PathLike[str], session: "Session | None" = None
+    ) -> "Self":
+        """Construct a File from a full URI in one call.
+
+        Example:
+            file = File.at("s3://bucket/path/to/output.png")
+            with file.open("wb") as f: ...
+        """
+        from datachain.client.fsspec import Client
+        from datachain.query.session import Session
+
+        if session is None:
+            session = Session.get()
+        catalog = session.catalog
+        uri_str = stringify_path(uri)
+        if uri_str.endswith(("/", os.sep)):
+            raise ValueError(
+                f"File.at directory URL/path given (trailing slash), got: {uri_str}"
+            )
+        client_cls = Client.get_implementation(uri_str)
+        uri_str = client_cls.path_to_uri(uri_str)
+        source, rel_path = client_cls.split_url(uri_str)
+        source_uri = client_cls.get_uri(source)
+        file = cls(source=source_uri, path=rel_path)
+        file._set_stream(catalog)
+        return file
+
     @classmethod
     def _from_row(cls, row: "RowDict") -> "Self":
         return cls(**{key: row[key] for key in cls._datachain_column_types})
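
The new `File.at` constructor turns a full URI into a stream-ready `File` in one call, defaulting to the current `Session` unless one is passed explicitly. Expanding slightly on the docstring example above (URI illustrative):

    from datachain.lib.file import File

    file = File.at("s3://my-bucket/outputs/report.txt")  # a trailing "/" raises ValueError
    with file.open("w") as f:  # write mode, see the reworked open() below
        f.write("done")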
@@ -354,28 +407,93 @@ class File(DataModel):
         return str(PurePosixPath(self.path).parent)
 
     @contextmanager
-    def open(self, mode: Literal["rb", "r"] = "rb") -> Iterator[Any]:
-        """Open the file and return a file object."""
-        if self.location:
-            with VFileRegistry.open(self, self.location) as f:  # type: ignore[arg-type]
-                yield f
+    def open(
+        self,
+        mode: str = "rb",
+        *,
+        client_config: dict[str, Any] | None = None,
+        **open_kwargs,
+    ) -> Iterator[Any]:
+        """Open the file and return a file-like object.
+
+        Supports both read ("rb", "r") and write modes (e.g. "wb", "w", "ab").
+        When opened in a write mode, metadata is refreshed after closing.
+        """
+        writing = any(ch in mode for ch in "wax+")
+        if self.location and writing:
+            raise VFileError(
+                "Writing to virtual file is not supported",
+                self.source,
+                self.path,
+            )
 
-        else:
+        if self._catalog is None:
+            raise RuntimeError("Cannot open file: catalog is not set")
+
+        base_cfg = getattr(self._catalog, "client_config", {}) or {}
+        merged_cfg = {**base_cfg, **(client_config or {})}
+        client: Client = self._catalog.get_client(self.source, **merged_cfg)
+
+        if not writing:
+            if self.location:
+                with VFileRegistry.open(self, self.location) as f:  # type: ignore[arg-type]
+                    yield self._wrap_text(f, mode, open_kwargs)
+                return
             if self._caching_enabled:
                 self.ensure_cached()
-            client: Client = self._catalog.get_client(self.source)
             with client.open_object(
                 self, use_cache=self._caching_enabled, cb=self._download_cb
             ) as f:
-                yield io.TextIOWrapper(f) if mode == "r" else f
+                yield self._wrap_text(f, mode, open_kwargs)
+            return
+
+        # write path
+        full_path = client.get_full_path(self.get_path_normalized())
+        with client.fs.open(full_path, mode, **open_kwargs) as f:
+            yield self._wrap_text(f, mode, open_kwargs)
+
+        version_hint = self._extract_write_version(f)
+
+        # refresh metadata pinned to the version that was just written
+        refreshed = client.get_file_info(
+            self.get_path_normalized(), version_id=version_hint
+        )
+        for k, v in refreshed.model_dump().items():
+            setattr(self, k, v)
+
+    def _wrap_text(self, f: Any, mode: str, open_kwargs: dict[str, Any]) -> Any:
+        """Return stream possibly wrapped for text."""
+        if "b" in mode or isinstance(f, io.TextIOBase):
+            return f
+        filtered = {
+            k: open_kwargs[k] for k in self._TEXT_WRAPPER_ALLOWED if k in open_kwargs
+        }
+        return io.TextIOWrapper(f, **filtered)
+
+    def _extract_write_version(self, handle: Any) -> str | None:
+        """Best-effort extraction of object version after a write.
+
+        S3 (s3fs) and Azure (adlfs) populate version_id on the handle.
+        GCS (gcsfs) populates generation. Azure and GCS require upstream
+        fixes to be released.
+        """
+        for attr in ("version_id", "generation"):
+            if value := getattr(handle, attr, None):
+                return value
+        return None
 
     def read_bytes(self, length: int = -1):
         """Returns file contents as bytes."""
         with self.open() as stream:
             return stream.read(length)
 
-    def read_text(self):
-        """Returns file contents as text."""
+    def read_text(self, **open_kwargs):
+        """Return file contents decoded as text.
+
+        **open_kwargs : Any
+            Extra keyword arguments forwarded to ``open(mode="r", ...)``
+            (e.g. ``encoding="utf-8"``, ``errors="ignore"``)
+        """
         if self.location:
             raise VFileError(
                 "Reading text from virtual file is not supported",
@@ -383,14 +501,14 @@ class File(DataModel):
             self.path,
         )
 
-        with self.open(mode="r") as stream:
+        with self.open(mode="r", **open_kwargs) as stream:
             return stream.read()
 
     def read(self, length: int = -1):
         """Returns file contents."""
         return self.read_bytes(length)
 
-    def save(self, destination: str, client_config: Optional[dict] = None):
+    def save(self, destination: str, client_config: dict | None = None):
         """Writes it's content to destination"""
         destination = stringify_path(destination)
         client: Client = self._catalog.get_client(destination, **(client_config or {}))
@@ -417,11 +535,11 @@ class File(DataModel):
 
     def export(
         self,
-        output: Union[str, os.PathLike[str]],
+        output: str | os.PathLike[str],
         placement: ExportPlacement = "fullpath",
         use_cache: bool = True,
         link_type: Literal["copy", "symlink"] = "copy",
-        client_config: Optional[dict] = None,
+        client_config: dict | None = None,
     ) -> None:
         """Export file to new location."""
         self._caching_enabled = use_cache
@@ -457,7 +575,7 @@ class File(DataModel):
         client = self._catalog.get_client(self.source)
         client.download(self, callback=self._download_cb)
 
-    async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool:
+    async def _prefetch(self, download_cb: "Callback | None" = None) -> bool:
         if self._catalog is None:
             raise RuntimeError("cannot prefetch file because catalog is not setup")
@@ -472,7 +590,7 @@ class File(DataModel):
         )
         return True
 
-    def get_local_path(self) -> Optional[str]:
+    def get_local_path(self) -> str | None:
         """Return path to a file in a local cache.
 
         Returns None if file is not cached.
@@ -549,7 +667,7 @@ class File(DataModel):
         return path
 
     def get_destination_path(
-        self, output: Union[str, os.PathLike[str]], placement: ExportPlacement
+        self, output: str | os.PathLike[str], placement: ExportPlacement
     ) -> str:
         """
         Returns full destination path of a file for exporting to some output
@@ -564,6 +682,8 @@ class File(DataModel):
             source = urlparse(self.source)
             if source.scheme and source.scheme != "file":
                 path = posixpath.join(source.netloc, path)
+        elif placement == "filepath":
+            path = unquote(self.get_path_normalized())
         elif placement == "checksum":
             raise NotImplementedError("Checksum placement not implemented yet")
         else:
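
The new `"filepath"` placement (added to `ExportPlacement` near the top of this file) exports under the file's bucket-relative path only, whereas `"fullpath"` additionally prefixes the source netloc for remote sources. Roughly (values illustrative):

    # file.source == "s3://my-bucket", file.path == "images/cat.jpg"
    file.get_destination_path("out", "fullpath")  # -> "out/my-bucket/images/cat.jpg"
    file.get_destination_path("out", "filepath")  # -> "out/images/cat.jpg"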
@@ -601,7 +721,7 @@ class File(DataModel):
             normalized_path = self.get_path_normalized()
             info = client.fs.info(client.get_full_path(normalized_path))
             converted_info = client.info_to_file(info, normalized_path)
-            return type(self)(
+            res = type(self)(
                 path=self.path,
                 source=self.source,
                 size=converted_info.size,
@@ -611,6 +731,8 @@ class File(DataModel):
                 last_modified=converted_info.last_modified,
                 location=self.location,
             )
+            res._set_stream(self._catalog)
+            return res
 
         except FileError as e:
             logger.warning(
@@ -623,7 +745,7 @@ class File(DataModel):
                 str(e),
             )
 
-        return type(self)(
+        res = type(self)(
             path=self.path,
             source=self.source,
             size=0,
@@ -633,6 +755,8 @@ class File(DataModel):
             last_modified=TIME_ZERO,
             location=self.location,
         )
+        res._set_stream(self._catalog)
+        return res
 
     def rebase(
         self,
@@ -701,17 +825,30 @@ class TextFile(File):
     """`DataModel` for reading text files."""
 
    @contextmanager
-    def open(self, mode: Literal["rb", "r"] = "r"):
-        """Open the file and return a file object (default to text mode)."""
-        with super().open(mode=mode) as stream:
+    def open(
+        self,
+        mode: str = "r",
+        *,
+        client_config: dict[str, Any] | None = None,
+        **open_kwargs,
+    ) -> Iterator[Any]:
+        """Open the file and return a file-like object.
+        Default to text mode"""
+        with super().open(
+            mode=mode, client_config=client_config, **open_kwargs
+        ) as stream:
             yield stream
 
-    def read_text(self):
-        """Returns file contents as text."""
-        with self.open() as stream:
+    def read_text(self, **open_kwargs):
+        """Return file contents as text.
+
+        **open_kwargs : Any
+            Extra keyword arguments forwarded to ``open()`` (e.g. encoding).
+        """
+        with self.open(**open_kwargs) as stream:
             return stream.read()
 
-    def save(self, destination: str, client_config: Optional[dict] = None):
+    def save(self, destination: str, client_config: dict | None = None):
         """Writes it's content to destination"""
         destination = stringify_path(destination)
 
@@ -744,8 +881,8 @@ class ImageFile(File):
     def save(  # type: ignore[override]
         self,
         destination: str,
-        format: Optional[str] = None,
-        client_config: Optional[dict] = None,
+        format: str | None = None,
+        client_config: dict | None = None,
     ):
         """Writes it's content to destination"""
         destination = stringify_path(destination)
@@ -827,7 +964,7 @@ class VideoFile(File):
     def get_frames(
         self,
         start: int = 0,
-        end: Optional[int] = None,
+        end: int | None = None,
         step: int = 1,
     ) -> "Iterator[VideoFrame]":
         """
@@ -877,7 +1014,7 @@ class VideoFile(File):
         self,
         duration: float,
         start: float = 0,
-        end: Optional[float] = None,
+        end: float | None = None,
     ) -> "Iterator[VideoFragment]":
         """
         Splits the video into multiple fragments of a specified duration.
@@ -963,7 +1100,7 @@ class AudioFile(File):
         self,
         duration: float,
         start: float = 0,
-        end: Optional[float] = None,
+        end: float | None = None,
     ) -> "Iterator[AudioFragment]":
         """
         Splits the audio into multiple fragments of a specified duration.
@@ -1001,10 +1138,10 @@ class AudioFile(File):
     def save(  # type: ignore[override]
         self,
         output: str,
-        format: Optional[str] = None,
+        format: str | None = None,
         start: float = 0,
-        end: Optional[float] = None,
-        client_config: Optional[dict] = None,
+        end: float | None = None,
+        client_config: dict | None = None,
     ) -> "AudioFile":
         """Save audio file or extract fragment to specified format.
 
@@ -1075,7 +1212,7 @@ class AudioFragment(DataModel):
         duration = self.end - self.start
         return audio_to_bytes(self.audio, format, self.start, duration)
 
-    def save(self, output: str, format: Optional[str] = None) -> "AudioFile":
+    def save(self, output: str, format: str | None = None) -> "AudioFile":
         """
         Saves the audio fragment as a new audio file.
 
@@ -1178,7 +1315,7 @@ class VideoFragment(DataModel):
     start: float
     end: float
 
-    def save(self, output: str, format: Optional[str] = None) -> "VideoFile":
+    def save(self, output: str, format: str | None = None) -> "VideoFile":
         """
         Saves the video fragment as a new video file.
 
datachain/lib/hf.py CHANGED
@@ -26,7 +26,7 @@ except ImportError as exc:
     ) from exc
 
 from io import BytesIO
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, TypeAlias
 
 import PIL
 from tqdm.auto import tqdm
@@ -41,7 +41,9 @@ if TYPE_CHECKING:
     from pydantic import BaseModel
 
 
-HFDatasetType = Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]
+HFDatasetType: TypeAlias = (
+    str | DatasetDict | Dataset | IterableDatasetDict | IterableDataset
+)
 
 
 class HFClassLabel(DataModel):
@@ -67,7 +69,7 @@ class HFAudio(DataModel):
 class HFGenerator(Generator):
     def __init__(
         self,
-        ds: Union[str, HFDatasetType],
+        ds: HFDatasetType,
         output_schema: type["BaseModel"],
         limit: int = 0,
         *args,
@@ -117,7 +119,7 @@ class HFGenerator(Generator):
             pbar.update(1)
 
 
-def stream_splits(ds: Union[str, HFDatasetType], *args, **kwargs):
+def stream_splits(ds: HFDatasetType, *args, **kwargs):
     if isinstance(ds, str):
         ds = load_dataset(ds, *args, **kwargs)
     if isinstance(ds, (DatasetDict, IterableDatasetDict)):
@@ -153,7 +155,7 @@ def convert_feature(val: Any, feat: Any, anno: Any) -> Any:
 
 
 def get_output_schema(
-    features: Features, existing_column_names: Optional[list[str]] = None
+    features: Features, existing_column_names: list[str] | None = None
 ) -> tuple[dict[str, DataType], dict[str, str]]:
     """
     Generate UDF output schema from Hugging Face datasets features. It normalizes the
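
`HFDatasetType` now folds the `str` case into the alias itself, so the same parameter accepts a Hub dataset name or an already-loaded dataset object. A sketch (dataset name illustrative):

    from datasets import load_dataset
    from datachain.lib.hf import stream_splits

    splits = stream_splits("beans")                # str: resolved via load_dataset
    splits = stream_splits(load_dataset("beans"))  # or pass a loaded dataset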
datachain/lib/image.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Union
+from collections.abc import Callable
 
 import torch
 from PIL import Image as PILImage
@@ -6,7 +6,7 @@ from PIL import Image as PILImage
 from datachain.lib.file import File, FileError, Image, ImageFile
 
 
-def image_info(file: Union[File, ImageFile]) -> Image:
+def image_info(file: File | ImageFile) -> Image:
     """
     Returns image file information.
 
@@ -31,11 +31,11 @@ def image_info(file: Union[File, ImageFile]) -> Image:
 def convert_image(
     img: PILImage.Image,
     mode: str = "RGB",
-    size: Optional[tuple[int, int]] = None,
-    transform: Optional[Callable] = None,
-    encoder: Optional[Callable] = None,
-    device: Optional[Union[str, torch.device]] = None,
-) -> Union[PILImage.Image, torch.Tensor]:
+    size: tuple[int, int] | None = None,
+    transform: Callable | None = None,
+    encoder: Callable | None = None,
+    device: str | torch.device | None = None,
+) -> PILImage.Image | torch.Tensor:
     """
     Resize, transform, and otherwise convert an image.
 
@@ -71,13 +71,13 @@ def convert_image(
 
 
 def convert_images(
-    images: Union[PILImage.Image, list[PILImage.Image]],
+    images: PILImage.Image | list[PILImage.Image],
     mode: str = "RGB",
-    size: Optional[tuple[int, int]] = None,
-    transform: Optional[Callable] = None,
-    encoder: Optional[Callable] = None,
-    device: Optional[Union[str, torch.device]] = None,
-) -> Union[list[PILImage.Image], torch.Tensor]:
+    size: tuple[int, int] | None = None,
+    transform: Callable | None = None,
+    encoder: Callable | None = None,
+    device: str | torch.device | None = None,
+) -> list[PILImage.Image] | torch.Tensor:
     """
     Resize, transform, and otherwise convert one or more images.
 
datachain/lib/listing.py CHANGED
@@ -2,10 +2,10 @@ import glob
 import logging
 import os
 import posixpath
-from collections.abc import Iterator
+from collections.abc import Callable, Iterator
 from contextlib import contextmanager
 from datetime import datetime, timedelta, timezone
-from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, TypeVar
 
 from fsspec.asyn import get_loop
 from sqlalchemy.sql.expression import true
@@ -73,7 +73,7 @@ def get_file_info(uri: str, cache, client_config=None) -> File:
 def ls(
     dc: D,
     path: str,
-    recursive: Optional[bool] = True,
+    recursive: bool | None = True,
     column="file",
 ) -> D:
     """
@@ -150,8 +150,8 @@ def _reraise_as_client_error() -> Iterator[None]:
 
 
 def get_listing(
-    uri: Union[str, os.PathLike[str]], session: "Session", update: bool = False
-) -> tuple[Optional[str], str, str, bool]:
+    uri: str | os.PathLike[str], session: "Session", update: bool = False
+) -> tuple[str | None, str, str, bool]:
     """Returns correct listing dataset name that must be used for saving listing
     operation. It takes into account existing listings and reusability of those.
     It also returns boolean saying if returned dataset name is reused / already
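
Per the updated signature, `get_listing` returns a 4-tuple whose first element is the listing dataset name (None is possible per the annotation) and whose last element flags whether an existing listing was reused. A hedged sketch (URI illustrative; the meaning of the two middle strings is an assumption based on the annotation, not spelled out in this hunk):

    from datachain.query.session import Session
    from datachain.lib.listing import get_listing

    name, uri, path, reused = get_listing("s3://my-bucket/images/", Session.get())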
datachain/lib/listing_info.py CHANGED
@@ -1,5 +1,4 @@
 from datetime import datetime, timedelta, timezone
-from typing import Optional
 
 from datachain.client import Client
 from datachain.lib.dataset_info import DatasetInfo
@@ -17,7 +16,7 @@ class ListingInfo(DatasetInfo):
         return uri
 
     @property
-    def expires(self) -> Optional[datetime]:
+    def expires(self) -> datetime | None:
         if not self.finished_at:
             return None
         return self.finished_at + timedelta(seconds=LISTING_TTL)
datachain/lib/meta_formats.py CHANGED
@@ -1,14 +1,13 @@
 import csv
-import json
 import tempfile
 import uuid
-from collections.abc import Iterator
+from collections.abc import Callable, Iterator
 from pathlib import Path
-from typing import Callable
 
 import jmespath as jsp
 from pydantic import BaseModel, ConfigDict, Field, ValidationError  # noqa: F401
 
+from datachain import json
 from datachain.lib.data_model import DataModel  # noqa: F401
 from datachain.lib.file import TextFile
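
The stdlib `import json` removed here (and in file.py above) is replaced by the new `datachain.json` module introduced in this release (datachain/json.py, +138 lines). Both call sites use it as a drop-in, so it presumably preserves the stdlib `loads`/`dumps` surface; a hedged sketch under that assumption:

    from datachain import json  # assumption: stdlib-compatible interface

    payload = json.dumps({"a": 1})
    assert json.loads(payload) == {"a": 1}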