datachain 0.34.6__py3-none-any.whl → 0.34.7__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of datachain might be problematic.

Files changed (105)
  1. datachain/asyn.py +11 -12
  2. datachain/cache.py +5 -5
  3. datachain/catalog/catalog.py +75 -83
  4. datachain/catalog/loader.py +3 -3
  5. datachain/checkpoint.py +1 -2
  6. datachain/cli/__init__.py +2 -4
  7. datachain/cli/commands/datasets.py +13 -13
  8. datachain/cli/commands/ls.py +4 -4
  9. datachain/cli/commands/query.py +3 -3
  10. datachain/cli/commands/show.py +2 -2
  11. datachain/cli/parser/job.py +1 -1
  12. datachain/cli/parser/utils.py +1 -2
  13. datachain/cli/utils.py +1 -2
  14. datachain/client/azure.py +2 -2
  15. datachain/client/fsspec.py +11 -21
  16. datachain/client/gcs.py +3 -3
  17. datachain/client/http.py +4 -4
  18. datachain/client/local.py +4 -4
  19. datachain/client/s3.py +3 -3
  20. datachain/config.py +4 -8
  21. datachain/data_storage/db_engine.py +5 -5
  22. datachain/data_storage/metastore.py +107 -107
  23. datachain/data_storage/schema.py +18 -24
  24. datachain/data_storage/sqlite.py +21 -28
  25. datachain/data_storage/warehouse.py +13 -13
  26. datachain/dataset.py +64 -70
  27. datachain/delta.py +21 -18
  28. datachain/diff/__init__.py +13 -13
  29. datachain/func/aggregate.py +9 -11
  30. datachain/func/array.py +12 -12
  31. datachain/func/base.py +7 -4
  32. datachain/func/conditional.py +9 -13
  33. datachain/func/func.py +45 -42
  34. datachain/func/numeric.py +5 -7
  35. datachain/func/string.py +2 -2
  36. datachain/hash_utils.py +54 -81
  37. datachain/job.py +8 -8
  38. datachain/lib/arrow.py +17 -14
  39. datachain/lib/audio.py +6 -6
  40. datachain/lib/clip.py +5 -4
  41. datachain/lib/convert/python_to_sql.py +4 -22
  42. datachain/lib/convert/values_to_tuples.py +4 -9
  43. datachain/lib/data_model.py +20 -19
  44. datachain/lib/dataset_info.py +6 -6
  45. datachain/lib/dc/csv.py +10 -10
  46. datachain/lib/dc/database.py +28 -29
  47. datachain/lib/dc/datachain.py +98 -97
  48. datachain/lib/dc/datasets.py +22 -22
  49. datachain/lib/dc/hf.py +4 -4
  50. datachain/lib/dc/json.py +9 -10
  51. datachain/lib/dc/listings.py +5 -8
  52. datachain/lib/dc/pandas.py +3 -6
  53. datachain/lib/dc/parquet.py +5 -5
  54. datachain/lib/dc/records.py +5 -5
  55. datachain/lib/dc/storage.py +12 -12
  56. datachain/lib/dc/storage_pattern.py +2 -2
  57. datachain/lib/dc/utils.py +11 -14
  58. datachain/lib/dc/values.py +3 -6
  59. datachain/lib/file.py +26 -26
  60. datachain/lib/hf.py +7 -5
  61. datachain/lib/image.py +13 -13
  62. datachain/lib/listing.py +5 -5
  63. datachain/lib/listing_info.py +1 -2
  64. datachain/lib/meta_formats.py +1 -2
  65. datachain/lib/model_store.py +3 -3
  66. datachain/lib/namespaces.py +4 -6
  67. datachain/lib/projects.py +5 -9
  68. datachain/lib/pytorch.py +10 -10
  69. datachain/lib/settings.py +23 -23
  70. datachain/lib/signal_schema.py +52 -44
  71. datachain/lib/text.py +8 -7
  72. datachain/lib/udf.py +25 -17
  73. datachain/lib/udf_signature.py +11 -11
  74. datachain/lib/video.py +3 -4
  75. datachain/lib/webdataset.py +30 -35
  76. datachain/lib/webdataset_laion.py +15 -16
  77. datachain/listing.py +4 -4
  78. datachain/model/bbox.py +3 -1
  79. datachain/namespace.py +4 -4
  80. datachain/node.py +6 -6
  81. datachain/nodes_thread_pool.py +0 -1
  82. datachain/plugins.py +1 -7
  83. datachain/project.py +4 -4
  84. datachain/query/batch.py +7 -8
  85. datachain/query/dataset.py +80 -87
  86. datachain/query/dispatch.py +7 -7
  87. datachain/query/metrics.py +3 -4
  88. datachain/query/params.py +2 -3
  89. datachain/query/schema.py +7 -6
  90. datachain/query/session.py +7 -7
  91. datachain/query/udf.py +8 -7
  92. datachain/query/utils.py +3 -5
  93. datachain/remote/studio.py +33 -39
  94. datachain/script_meta.py +12 -12
  95. datachain/sql/sqlite/base.py +6 -9
  96. datachain/studio.py +30 -30
  97. datachain/toolkit/split.py +1 -2
  98. datachain/utils.py +21 -21
  99. {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/METADATA +2 -3
  100. datachain-0.34.7.dist-info/RECORD +173 -0
  101. datachain-0.34.6.dist-info/RECORD +0 -173
  102. {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/WHEEL +0 -0
  103. {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/entry_points.txt +0 -0
  104. {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/licenses/LICENSE +0 -0
  105. {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/top_level.txt +0 -0
datachain/asyn.py CHANGED
@@ -3,6 +3,7 @@ import threading
 from collections.abc import (
     AsyncIterable,
     Awaitable,
+    Callable,
     Coroutine,
     Generator,
     Iterable,
@@ -10,7 +11,7 @@ from collections.abc import (
 )
 from concurrent.futures import ThreadPoolExecutor, wait
 from heapq import heappop, heappush
-from typing import Any, Callable, Generic, Optional, TypeVar
+from typing import Any, Generic, TypeVar

 from fsspec.asyn import get_loop

@@ -49,7 +50,7 @@ class AsyncMapper(Generic[InputT, ResultT]):
         iterable: Iterable[InputT],
         *,
         workers: int = ASYNC_WORKERS,
-        loop: Optional[asyncio.AbstractEventLoop] = None,
+        loop: asyncio.AbstractEventLoop | None = None,
     ):
         self.func = func
         self.iterable = iterable
@@ -107,9 +108,7 @@ class AsyncMapper(Generic[InputT, ResultT]):

     async def init(self) -> None:
         self.work_queue = asyncio.Queue(2 * self.workers)
-        self.result_queue: asyncio.Queue[Optional[ResultT]] = asyncio.Queue(
-            self.workers
-        )
+        self.result_queue: asyncio.Queue[ResultT | None] = asyncio.Queue(self.workers)

     async def run(self) -> None:
         producer = self.start_task(self.produce())
@@ -149,10 +148,10 @@ class AsyncMapper(Generic[InputT, ResultT]):
         if exc:
             raise exc

-    async def _pop_result(self) -> Optional[ResultT]:
+    async def _pop_result(self) -> ResultT | None:
         return await self.result_queue.get()

-    def next_result(self, timeout=None) -> Optional[ResultT]:
+    def next_result(self, timeout=None) -> ResultT | None:
         """
         Return the next available result.

@@ -212,17 +211,17 @@ class OrderedMapper(AsyncMapper[InputT, ResultT]):
         iterable: Iterable[InputT],
         *,
         workers: int = ASYNC_WORKERS,
-        loop: Optional[asyncio.AbstractEventLoop] = None,
+        loop: asyncio.AbstractEventLoop | None = None,
     ):
         super().__init__(func, iterable, workers=workers, loop=loop)
         self._waiters: dict[int, Any] = {}
-        self._getters: dict[int, asyncio.Future[Optional[ResultT]]] = {}
-        self.heap: list[tuple[int, Optional[ResultT]]] = []
+        self._getters: dict[int, asyncio.Future[ResultT | None]] = {}
+        self.heap: list[tuple[int, ResultT | None]] = []
         self._next_yield = 0
         self._items_seen = 0
         self._window = 2 * workers

-    def _push_result(self, i: int, result: Optional[ResultT]) -> None:
+    def _push_result(self, i: int, result: ResultT | None) -> None:
         if i in self._getters:
             future = self._getters.pop(i)
             future.set_result(result)
@@ -243,7 +242,7 @@ class OrderedMapper(AsyncMapper[InputT, ResultT]):
     async def init(self) -> None:
         self.work_queue = asyncio.Queue(2 * self.workers)

-    async def _pop_result(self) -> Optional[ResultT]:
+    async def _pop_result(self) -> ResultT | None:
         if self.heap and self.heap[0][0] == self._next_yield:
             _i, out = heappop(self.heap)
         else:
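
The change above is the pattern that repeats throughout this release: Callable now comes from collections.abc instead of typing, and Optional[X] / Union[X, Y] annotations are rewritten as PEP 604 unions (X | None, X | Y). A minimal before/after sketch of the same modernization, using illustrative names rather than datachain code:

# Before: typing-module generics
from typing import Callable, Optional, Union

def fetch(url: str, on_done: Optional[Callable[[bytes], None]] = None) -> Union[bytes, None]:
    ...

# After: collections.abc.Callable and PEP 604 unions (Python >= 3.10)
from collections.abc import Callable

def fetch(url: str, on_done: Callable[[bytes], None] | None = None) -> bytes | None:
    ...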
datachain/cache.py CHANGED
@@ -2,7 +2,7 @@ import os
 from collections.abc import Iterator
 from contextlib import contextmanager
 from tempfile import mkdtemp
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING

 from dvc_data.hashfile.db.local import LocalHashFileDB
 from dvc_objects.fs.local import LocalFileSystem
@@ -22,14 +22,14 @@ def try_scandir(path):
         pass


-def get_temp_cache(tmp_dir: str, prefix: Optional[str] = None) -> "Cache":
+def get_temp_cache(tmp_dir: str, prefix: str | None = None) -> "Cache":
     cache_dir = mkdtemp(prefix=prefix, dir=tmp_dir)
     return Cache(cache_dir, tmp_dir=tmp_dir)


 @contextmanager
 def temporary_cache(
-    tmp_dir: str, prefix: Optional[str] = None, delete: bool = True
+    tmp_dir: str, prefix: str | None = None, delete: bool = True
 ) -> Iterator["Cache"]:
     cache = get_temp_cache(tmp_dir, prefix=prefix)
     try:
@@ -58,7 +58,7 @@ class Cache: # noqa: PLW1641
     def tmp_dir(self):
         return self.odb.tmp_dir

-    def get_path(self, file: "File") -> Optional[str]:
+    def get_path(self, file: "File") -> str | None:
         if self.contains(file):
             return self.path_from_checksum(file.get_hash())
         return None
@@ -74,7 +74,7 @@ class Cache: # noqa: PLW1641
         self.odb.delete(file.get_hash())

     async def download(
-        self, file: "File", client: "Client", callback: Optional[Callback] = None
+        self, file: "File", client: "Client", callback: Callback | None = None
     ) -> None:
         from dvc_objects.fs.utils import tmp_fname

datachain/catalog/catalog.py CHANGED
@@ -9,20 +9,12 @@ import subprocess
 import sys
 import time
 import traceback
-from collections.abc import Iterable, Iterator, Mapping, Sequence
+from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
 from copy import copy
 from dataclasses import dataclass
 from functools import cached_property, reduce
 from threading import Thread
-from typing import (
-    IO,
-    TYPE_CHECKING,
-    Any,
-    NoReturn,
-    Optional,
-    Union,
-)
+from typing import IO, TYPE_CHECKING, Any, NoReturn
 from uuid import uuid4

 import sqlalchemy as sa
@@ -64,10 +56,7 @@ from datachain.utils import DataChainDir
 from .datasource import DataSource

 if TYPE_CHECKING:
-    from datachain.data_storage import (
-        AbstractMetastore,
-        AbstractWarehouse,
-    )
+    from datachain.data_storage import AbstractMetastore, AbstractWarehouse
     from datachain.dataset import DatasetListVersion
     from datachain.job import Job
     from datachain.lib.listing_info import ListingInfo
@@ -120,8 +109,8 @@ def is_namespace_local(namespace_name) -> bool:

 def shutdown_process(
     proc: subprocess.Popen,
-    interrupt_timeout: Optional[int] = None,
-    terminate_timeout: Optional[int] = None,
+    interrupt_timeout: int | None = None,
+    terminate_timeout: int | None = None,
 ) -> int:
     """Shut down the process gracefully with SIGINT -> SIGTERM -> SIGKILL."""

@@ -168,7 +157,7 @@ class DatasetRowsFetcher(NodesThreadPool):
         remote_ds_version: str,
         local_ds: DatasetRecord,
         local_ds_version: str,
-        schema: dict[str, Union[SQLType, type[SQLType]]],
+        schema: dict[str, SQLType | type[SQLType]],
         max_threads: int = PULL_DATASET_MAX_THREADS,
         progress_bar=None,
     ):
@@ -183,7 +172,7 @@ class DatasetRowsFetcher(NodesThreadPool):
         self.local_ds = local_ds
         self.local_ds_version = local_ds_version
         self.schema = schema
-        self.last_status_check: Optional[float] = None
+        self.last_status_check: float | None = None
         self.studio_client = StudioClient()
         self.progress_bar = progress_bar

@@ -287,16 +276,16 @@ class DatasetRowsFetcher(NodesThreadPool):
 class NodeGroup:
     """Class for a group of nodes from the same source"""

-    listing: Optional["Listing"]
-    client: "Client"
+    listing: "Listing | None"
+    client: Client
     sources: list[DataSource]

     # The source path within the bucket
     # (not including the bucket name or s3:// prefix)
     source_path: str = ""
-    dataset_name: Optional[str] = None
-    dataset_version: Optional[str] = None
-    instantiated_nodes: Optional[list[NodeWithPath]] = None
+    dataset_name: str | None = None
+    dataset_version: str | None = None
+    instantiated_nodes: list[NodeWithPath] | None = None

     @property
     def is_dataset(self) -> bool:
@@ -323,7 +312,7 @@ def prepare_output_for_cp(
     output: str,
     force: bool = False,
     no_cp: bool = False,
-) -> tuple[bool, Optional[str]]:
+) -> tuple[bool, str | None]:
     total_node_count = 0
     for node_group in node_groups:
         if not node_group.sources:
@@ -372,7 +361,7 @@ def collect_nodes_for_cp(

     # Collect all sources to process
     for node_group in node_groups:
-        listing: Optional[Listing] = node_group.listing
+        listing: Listing | None = node_group.listing
         valid_sources: list[DataSource] = []
         for dsrc in node_group.sources:
             if dsrc.is_single_object():
@@ -416,7 +405,7 @@ def instantiate_node_groups(
     recursive: bool = False,
     virtual_only: bool = False,
     always_copy_dir_contents: bool = False,
-    copy_to_filename: Optional[str] = None,
+    copy_to_filename: str | None = None,
 ) -> None:
     instantiate_progress_bar = (
         None
@@ -444,7 +433,7 @@ def instantiate_node_groups(
     for node_group in node_groups:
         if not node_group.sources:
             continue
-        listing: Optional[Listing] = node_group.listing
+        listing: Listing | None = node_group.listing
         source_path: str = node_group.source_path

         copy_dir_contents = always_copy_dir_contents or source_path.endswith("/")
@@ -527,10 +516,8 @@ class Catalog:
         warehouse: "AbstractWarehouse",
         cache_dir=None,
         tmp_dir=None,
-        client_config: Optional[dict[str, Any]] = None,
-        warehouse_ready_callback: Optional[
-            Callable[["AbstractWarehouse"], None]
-        ] = None,
+        client_config: dict[str, Any] | None = None,
+        warehouse_ready_callback: Callable[["AbstractWarehouse"], None] | None = None,
         in_memory: bool = False,
     ):
         datachain_dir = DataChainDir(cache=cache_dir, tmp=tmp_dir)
@@ -592,7 +579,7 @@ class Catalog:
         client_config=None,
         column="file",
         skip_indexing=False,
-    ) -> tuple[Optional["Listing"], "Client", str]:
+    ) -> tuple["Listing | None", Client, str]:
         from datachain import read_storage
         from datachain.listing import Listing

@@ -633,7 +620,7 @@ class Catalog:
         skip_indexing=False,
         client_config=None,
         only_index=False,
-    ) -> Optional[list["DataSource"]]:
+    ) -> list["DataSource"] | None:
         enlisted_sources = []
         for src in sources:  # Opt: parallel
             listing, client, file_path = self.enlist_source(
@@ -679,7 +666,7 @@ class Catalog:
         enlisted_sources: list[tuple[bool, bool, Any]] = []
         client_config = client_config or self.client_config
         for src in sources:  # Opt: parallel
-            listing: Optional[Listing]
+            listing: Listing | None
             if src.startswith("ds://"):
                 ds_name, ds_version = parse_dataset_uri(src)
                 ds_namespace, ds_project, ds_name = parse_dataset_name(ds_name)
@@ -785,19 +772,19 @@ class Catalog:
     def create_dataset(
         self,
         name: str,
-        project: Optional[Project] = None,
-        version: Optional[str] = None,
+        project: Project | None = None,
+        version: str | None = None,
         *,
         columns: Sequence[Column],
-        feature_schema: Optional[dict] = None,
+        feature_schema: dict | None = None,
         query_script: str = "",
-        create_rows: Optional[bool] = True,
-        validate_version: Optional[bool] = True,
-        listing: Optional[bool] = False,
-        uuid: Optional[str] = None,
-        description: Optional[str] = None,
-        attrs: Optional[list[str]] = None,
-        update_version: Optional[str] = "patch",
+        create_rows: bool | None = True,
+        validate_version: bool | None = True,
+        listing: bool | None = False,
+        uuid: str | None = None,
+        description: str | None = None,
+        attrs: list[str] | None = None,
+        update_version: str | None = "patch",
     ) -> "DatasetRecord":
         """
         Creates new dataset of a specific version.
@@ -886,8 +873,8 @@ class Catalog:
         error_stack="",
         script_output="",
         create_rows_table=True,
-        job_id: Optional[str] = None,
-        uuid: Optional[str] = None,
+        job_id: str | None = None,
+        uuid: str | None = None,
     ) -> DatasetRecord:
         """
         Creates dataset version if it doesn't exist.
@@ -971,7 +958,7 @@ class Catalog:
         return dataset_updated

     def remove_dataset_version(
-        self, dataset: DatasetRecord, version: str, drop_rows: Optional[bool] = True
+        self, dataset: DatasetRecord, version: str, drop_rows: bool | None = True
     ) -> None:
         """
         Deletes one single dataset version.
@@ -999,7 +986,7 @@ class Catalog:
         self,
         name: str,
         sources: list[str],
-        project: Optional[Project] = None,
+        project: Project | None = None,
         client_config=None,
         recursive=False,
     ) -> DatasetRecord:
@@ -1068,8 +1055,8 @@ class Catalog:
     def get_full_dataset_name(
         self,
         name: str,
-        project_name: Optional[str] = None,
-        namespace_name: Optional[str] = None,
+        project_name: str | None = None,
+        namespace_name: str | None = None,
     ) -> tuple[str, str, str]:
         """
         Returns dataset name together with separated namespace and project name.
@@ -1101,8 +1088,8 @@ class Catalog:
     def get_dataset(
         self,
         name: str,
-        namespace_name: Optional[str] = None,
-        project_name: Optional[str] = None,
+        namespace_name: str | None = None,
+        project_name: str | None = None,
     ) -> DatasetRecord:
         from datachain.lib.listing import is_listing_dataset

@@ -1122,7 +1109,7 @@ class Catalog:
         name: str,
         namespace_name: str,
         project_name: str,
-        version: Optional[str] = None,
+        version: str | None = None,
         pull_dataset: bool = False,
         update: bool = False,
     ) -> DatasetRecord:
@@ -1213,10 +1200,10 @@ class Catalog:
         self,
         name: str,
         version: str,
-        namespace_name: Optional[str] = None,
-        project_name: Optional[str] = None,
+        namespace_name: str | None = None,
+        project_name: str | None = None,
         indirect=False,
-    ) -> list[Optional[DatasetDependency]]:
+    ) -> list[DatasetDependency | None]:
         dataset = self.get_dataset(
             name,
             namespace_name=namespace_name,
@@ -1248,10 +1235,10 @@ class Catalog:

     def ls_datasets(
         self,
-        prefix: Optional[str] = None,
+        prefix: str | None = None,
         include_listing: bool = False,
         studio: bool = False,
-        project: Optional[Project] = None,
+        project: Project | None = None,
     ) -> Iterator[DatasetListRecord]:
         from datachain.remote.studio import StudioClient

@@ -1283,12 +1270,12 @@ class Catalog:

     def list_datasets_versions(
         self,
-        prefix: Optional[str] = None,
+        prefix: str | None = None,
         include_listing: bool = False,
         with_job: bool = True,
         studio: bool = False,
-        project: Optional[Project] = None,
-    ) -> Iterator[tuple[DatasetListRecord, "DatasetListVersion", Optional["Job"]]]:
+        project: Project | None = None,
+    ) -> Iterator[tuple[DatasetListRecord, "DatasetListVersion", "Job | None"]]:
         """Iterate over all dataset versions with related jobs."""
         datasets = list(
             self.ls_datasets(
@@ -1316,7 +1303,7 @@ class Catalog:
             for v in d.versions
         )

-    def listings(self, prefix: Optional[str] = None) -> list["ListingInfo"]:
+    def listings(self, prefix: str | None = None) -> list["ListingInfo"]:
         """
         Returns list of ListingInfo objects which are representing specific
         storage listing datasets
@@ -1367,9 +1354,9 @@ class Catalog:
         self,
         source: str,
         path: str,
-        version_id: Optional[str] = None,
+        version_id: str | None = None,
         client_config=None,
-        content_disposition: Optional[str] = None,
+        content_disposition: str | None = None,
         **kwargs,
     ) -> str:
         client_config = client_config or self.client_config
@@ -1388,7 +1375,7 @@ class Catalog:
         bucket_uri: str,
         name: str,
         version: str,
-        project: Optional[Project] = None,
+        project: Project | None = None,
         client_config=None,
     ) -> list[str]:
         dataset = self.get_dataset(
@@ -1402,7 +1389,7 @@ class Catalog:
         )

     def dataset_table_export_file_names(
-        self, name: str, version: str, project: Optional[Project] = None
+        self, name: str, version: str, project: Project | None = None
     ) -> list[str]:
         dataset = self.get_dataset(
             name,
@@ -1414,9 +1401,9 @@ class Catalog:
     def remove_dataset(
         self,
         name: str,
-        project: Optional[Project] = None,
-        version: Optional[str] = None,
-        force: Optional[bool] = False,
+        project: Project | None = None,
+        version: str | None = None,
+        force: bool | None = False,
     ):
         dataset = self.get_dataset(
             name,
@@ -1444,10 +1431,10 @@ class Catalog:
     def edit_dataset(
         self,
         name: str,
-        project: Optional[Project] = None,
-        new_name: Optional[str] = None,
-        description: Optional[str] = None,
-        attrs: Optional[list[str]] = None,
+        project: Project | None = None,
+        new_name: str | None = None,
+        description: str | None = None,
+        attrs: list[str] | None = None,
     ) -> DatasetRecord:
         update_data = {}
         if new_name:
@@ -1487,9 +1474,9 @@ class Catalog:
     def pull_dataset(  # noqa: C901, PLR0915
         self,
         remote_ds_uri: str,
-        output: Optional[str] = None,
-        local_ds_name: Optional[str] = None,
-        local_ds_version: Optional[str] = None,
+        output: str | None = None,
+        local_ds_name: str | None = None,
+        local_ds_version: str | None = None,
         cp: bool = False,
         force: bool = False,
         *,
@@ -1763,21 +1750,26 @@ class Catalog:
     def query(
         self,
         query_script: str,
-        env: Optional[Mapping[str, str]] = None,
+        env: Mapping[str, str] | None = None,
         python_executable: str = sys.executable,
         capture_output: bool = False,
         output_hook: Callable[[str], None] = noop,
-        params: Optional[dict[str, str]] = None,
-        job_id: Optional[str] = None,
-        interrupt_timeout: Optional[int] = None,
-        terminate_timeout: Optional[int] = None,
+        params: dict[str, str] | None = None,
+        job_id: str | None = None,
+        reset: bool = False,
+        interrupt_timeout: int | None = None,
+        terminate_timeout: int | None = None,
     ) -> None:
+        if not isinstance(reset, bool):
+            raise TypeError(f"reset must be a bool, got {type(reset).__name__}")
+
         cmd = [python_executable, "-c", query_script]
         env = dict(env or os.environ)
         env.update(
             {
                 "DATACHAIN_QUERY_PARAMS": json.dumps(params or {}),
                 "DATACHAIN_JOB_ID": job_id or "",
+                "DATACHAIN_CHECKPOINTS_RESET": str(reset),
             },
         )
         popen_kwargs: dict[str, Any] = {}
@@ -1787,7 +1779,7 @@ class Catalog:
         def raise_termination_signal(sig: int, _: Any) -> NoReturn:
             raise TerminationSignal(sig)

-        thread: Optional[Thread] = None
+        thread: Thread | None = None
         with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc:  # noqa: S603
             logger.info("Starting process %s", proc.pid)

@@ -1850,7 +1842,7 @@ class Catalog:
         no_cp: bool = False,
         no_glob: bool = False,
         *,
-        client_config: Optional["dict"] = None,
+        client_config: dict | None = None,
     ) -> None:
         """
         This function copies files from cloud sources to local destination directory
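
Besides the annotation cleanup, Catalog.query gains a reset keyword (rejected with a TypeError unless it is a real bool) that is forwarded to the query subprocess as the DATACHAIN_CHECKPOINTS_RESET environment variable via str(reset). A hedged sketch of how such a flag round-trips through the environment; only the writing side appears in this diff, and the reading side is an assumption for illustration:

import os

def checkpoints_reset_env(reset: bool = False) -> dict[str, str]:
    # Writer side, mirroring the diff: str(True) -> "True", str(False) -> "False"
    if not isinstance(reset, bool):
        raise TypeError(f"reset must be a bool, got {type(reset).__name__}")
    return {"DATACHAIN_CHECKPOINTS_RESET": str(reset)}

def checkpoints_reset_requested() -> bool:
    # Hypothetical reader side: the child process parses the string back into a bool
    return os.environ.get("DATACHAIN_CHECKPOINTS_RESET", "False") == "True"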
datachain/catalog/loader.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import sys
 from importlib import import_module
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any

 from datachain.plugins import ensure_plugins_loaded
 from datachain.utils import get_envs_by_prefix
@@ -108,7 +108,7 @@ def get_warehouse(in_memory: bool = False) -> "AbstractWarehouse":
     return warehouse_class(**warehouse_args)


-def get_udf_distributor_class() -> Optional[type["AbstractUDFDistributor"]]:
+def get_udf_distributor_class() -> type["AbstractUDFDistributor"] | None:
     if os.environ.get(DISTRIBUTED_DISABLED) == "True":
         return None

@@ -132,7 +132,7 @@ def get_udf_distributor_class() -> Optional[type["AbstractUDFDistributor"]]:


 def get_catalog(
-    client_config: Optional[dict[str, Any]] = None,
+    client_config: dict[str, Any] | None = None,
     in_memory: bool = False,
 ) -> "Catalog":
     """
datachain/checkpoint.py CHANGED
@@ -1,7 +1,6 @@
 import uuid
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Union


 @dataclass
@@ -29,7 +28,7 @@ class Checkpoint:
     @classmethod
     def parse(
         cls,
-        id: Union[str, uuid.UUID],
+        id: str | uuid.UUID,
         job_id: str,
         _hash: str,
         partial: bool,
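
Checkpoint.parse now annotates id as str | uuid.UUID instead of Union[str, uuid.UUID]. A small sketch of the usual way to accept either form and normalize it; the helper below is hypothetical and not part of this diff:

import uuid

def normalize_checkpoint_id(id: str | uuid.UUID) -> uuid.UUID:
    # Accept a UUID object or its string form; always return a uuid.UUID
    return id if isinstance(id, uuid.UUID) else uuid.UUID(id)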
datachain/cli/__init__.py CHANGED
@@ -3,10 +3,8 @@ import os
 import sys
 import traceback
 from multiprocessing import freeze_support
-from typing import Optional

 from datachain.cli.utils import get_logging_level
-from datachain.error import DataChainError as DataChainError

 from .commands import (
     clear_cache,
@@ -26,7 +24,7 @@ from .parser import get_parser
 logger = logging.getLogger("datachain")


-def main(argv: Optional[list[str]] = None) -> int:
+def main(argv: list[str] | None = None) -> int:
     from datachain.catalog import get_catalog

     # Required for Windows multiprocessing support
@@ -307,7 +305,7 @@ def handle_udf() -> int:
     return udf_entrypoint()


-def handle_udf_runner(fd: Optional[int] = None) -> int:
+def handle_udf_runner(fd: int | None = None) -> int:
     from datachain.query.dispatch import udf_worker_entrypoint

     return udf_worker_entrypoint(fd)
datachain/cli/commands/datasets.py CHANGED
@@ -1,6 +1,6 @@
 import sys
 from collections.abc import Iterable, Iterator
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING

 from tabulate import tabulate

@@ -17,7 +17,7 @@ if TYPE_CHECKING:

 def group_dataset_versions(
     datasets: Iterable[tuple[str, str]], latest_only=True
-) -> dict[str, Union[str, list[str]]]:
+) -> dict[str, str | list[str]]:
     grouped: dict[str, list[tuple[int, int, int]]] = {}

     # Sort to ensure groupby works as expected
@@ -43,9 +43,9 @@ def list_datasets(
     studio: bool = False,
     local: bool = False,
     all: bool = True,
-    team: Optional[str] = None,
+    team: str | None = None,
     latest_only: bool = True,
-    name: Optional[str] = None,
+    name: str | None = None,
 ) -> None:
     token = Config().read().get("studio", {}).get("token")
     all, local, studio = determine_flavors(studio, local, all, token)
@@ -107,7 +107,7 @@ def list_datasets(


 def list_datasets_local(
-    catalog: "Catalog", name: Optional[str] = None
+    catalog: "Catalog", name: str | None = None
 ) -> Iterator[tuple[str, str]]:
     if name:
         yield from list_datasets_local_versions(catalog, name)
@@ -147,10 +147,10 @@ def _datasets_tabulate_row(name, both, local_version, studio_version) -> dict[st
 def rm_dataset(
     catalog: "Catalog",
     name: str,
-    version: Optional[str] = None,
-    force: Optional[bool] = False,
-    studio: Optional[bool] = False,
-    team: Optional[str] = None,
+    version: str | None = None,
+    force: bool | None = False,
+    studio: bool | None = False,
+    team: str | None = None,
 ) -> None:
     namespace_name, project_name, name = catalog.get_full_dataset_name(name)

@@ -177,10 +177,10 @@ def rm_dataset(
 def edit_dataset(
     catalog: "Catalog",
     name: str,
-    new_name: Optional[str] = None,
-    description: Optional[str] = None,
-    attrs: Optional[list[str]] = None,
-    team: Optional[str] = None,
+    new_name: str | None = None,
+    description: str | None = None,
+    attrs: list[str] | None = None,
+    team: str | None = None,
 ) -> None:
     from datachain.lib.dc.utils import is_studio
