datachain 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.


datachain/lib/listing.py CHANGED
@@ -1,7 +1,7 @@
  import posixpath
  from collections.abc import Iterator
  from datetime import datetime, timedelta, timezone
- from typing import TYPE_CHECKING, Callable, Optional
+ from typing import TYPE_CHECKING, Callable, Optional, TypeVar
 
  from fsspec.asyn import get_loop
  from sqlalchemy.sql.expression import true
@@ -20,6 +20,8 @@ if TYPE_CHECKING:
  LISTING_TTL = 4 * 60 * 60  # cached listing lasts 4 hours
  LISTING_PREFIX = "lst__"  # listing datasets start with this name
 
+ D = TypeVar("D", bound="DataChain")
+
 
  def list_bucket(uri: str, cache, client_config=None) -> Callable:
      """
@@ -38,11 +40,11 @@ def list_bucket(uri: str, cache, client_config=None) -> Callable:
 
 
  def ls(
-     dc: "DataChain",
+     dc: D,
      path: str,
      recursive: Optional[bool] = True,
      object_name="file",
- ):
+ ) -> D:
      """
      Return files by some path from DataChain instance which contains bucket listing.
      Path can have globs.
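The `TypeVar` bound to `DataChain` lets `ls()` report that it returns the same chain subclass it was given, rather than a plain `DataChain`. A minimal, self-contained sketch of the pattern; the `Chain`/`ImageChain` classes below are illustrative stand-ins, not datachain code:

    from typing import TypeVar

    C = TypeVar("C", bound="Chain")  # mirrors D = TypeVar("D", bound="DataChain")

    class Chain:
        def filter(self: C, keep: bool) -> C:
            return self

    class ImageChain(Chain):
        pass

    def ls(dc: C, path: str) -> C:
        # A type checker now knows ls(ImageChain(), ...) returns ImageChain,
        # whereas the old `dc: "DataChain"` annotation erased the subclass.
        return dc.filter(True)

    result: ImageChain = ls(ImageChain(), "s3://bucket/images/")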
datachain/lib/pytorch.py CHANGED
@@ -9,6 +9,7 @@ from torch.utils.data import IterableDataset, get_worker_info
  from torchvision.transforms import v2
  from tqdm import tqdm
 
+ from datachain import Session
  from datachain.catalog import Catalog, get_catalog
  from datachain.lib.dc import DataChain
  from datachain.lib.text import convert_text
@@ -87,8 +88,11 @@ class PytorchDataset(IterableDataset):
      def __iter__(self) -> Iterator[Any]:
          if self.catalog is None:
              self.catalog = self._get_catalog()
+         session = Session.get(catalog=self.catalog)
          total_rank, total_workers = self.get_rank_and_workers()
-         ds = DataChain(name=self.name, version=self.version, catalog=self.catalog)
+         ds = DataChain.from_dataset(
+             name=self.name, version=self.version, session=session
+         )
          ds = ds.remove_file_signals()
 
          if self.num_samples > 0:
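The dataset is now reopened through a `Session` bound to the worker's catalog instead of passing the catalog to the `DataChain` constructor. A hedged sketch of that access pattern, built only from the calls visible in the diff; `"my_dataset"`, `version=1`, and the bare `get_catalog()` call are placeholders:

    from datachain import Session
    from datachain.catalog import get_catalog
    from datachain.lib.dc import DataChain

    catalog = get_catalog()                 # worker-side catalog, as in __iter__
    session = Session.get(catalog=catalog)  # bind a session to that catalog
    ds = DataChain.from_dataset(name="my_dataset", version=1, session=session)
    ds = ds.remove_file_signals()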
datachain/lib/udf.py CHANGED
@@ -1,31 +1,33 @@
  import sys
  import traceback
- from typing import TYPE_CHECKING, Callable, Optional
+ from collections.abc import Iterable, Iterator, Mapping, Sequence
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Any, Callable, Optional
 
  from fsspec.callbacks import DEFAULT_CALLBACK, Callback
  from pydantic import BaseModel
 
  from datachain.dataset import RowDict
  from datachain.lib.convert.flatten import flatten
- from datachain.lib.convert.unflatten import unflatten_to_json
  from datachain.lib.file import File
- from datachain.lib.model_store import ModelStore
  from datachain.lib.signal_schema import SignalSchema
- from datachain.lib.udf_signature import UdfSignature
  from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
- from datachain.query.batch import UDFInputBatch
- from datachain.query.schema import ColumnParameter
- from datachain.query.udf import UDFBase as _UDFBase
- from datachain.query.udf import UDFProperties
+ from datachain.query.batch import (
+     Batch,
+     BatchingStrategy,
+     NoBatching,
+     Partition,
+     RowsOutputBatch,
+     UDFInputBatch,
+ )
+ from datachain.query.schema import ColumnParameter, UDFParameter
 
  if TYPE_CHECKING:
-     from collections.abc import Iterable, Iterator, Sequence
-
      from typing_extensions import Self
 
      from datachain.catalog import Catalog
+     from datachain.lib.udf_signature import UdfSignature
      from datachain.query.batch import RowsOutput, UDFInput
-     from datachain.query.udf import UDFResult
 
 
  class UdfError(DataChainParamsError):
@@ -33,14 +35,47 @@ class UdfError(DataChainParamsError):
          super().__init__(f"UDF error: {msg}")
 
 
- class UDFAdapter(_UDFBase):
+ ColumnType = Any
+
+ # Specification for the output of a UDF
+ UDFOutputSpec = Mapping[str, ColumnType]
+
+ # Result type when calling the UDF wrapper around the actual
+ # Python function / class implementing it.
+ UDFResult = dict[str, Any]
+
+
+ @dataclass
+ class UDFProperties:
+     """Container for basic UDF properties."""
+
+     params: list[UDFParameter]
+     output: UDFOutputSpec
+     batch: int = 1
+
+     def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
+         if use_partitioning:
+             return Partition()
+         if self.batch == 1:
+             return NoBatching()
+         if self.batch > 1:
+             return Batch(self.batch)
+         raise ValueError(f"invalid batch size {self.batch}")
+
+     def signal_names(self) -> Iterable[str]:
+         return self.output.keys()
+
+
+ class UDFAdapter:
      def __init__(
          self,
          inner: "UDFBase",
          properties: UDFProperties,
      ):
          self.inner = inner
-         super().__init__(properties)
+         self.properties = properties
+         self.signal_names = properties.signal_names()
+         self.output = properties.output
 
      def run(
          self,
@@ -51,20 +86,23 @@ class UDFAdapter(_UDFBase):
          cache: bool,
          download_cb: Callback = DEFAULT_CALLBACK,
          processed_cb: Callback = DEFAULT_CALLBACK,
-     ) -> "Iterator[Iterable[UDFResult]]":
-         self.inner._catalog = catalog
+     ) -> Iterator[Iterable[UDFResult]]:
+         self.inner.catalog = catalog
          if hasattr(self.inner, "setup") and callable(self.inner.setup):
              self.inner.setup()
 
-         yield from super().run(
-             udf_fields,
-             udf_inputs,
-             catalog,
-             is_generator,
-             cache,
-             download_cb,
-             processed_cb,
-         )
+         for batch in udf_inputs:
+             if isinstance(batch, RowsOutputBatch):
+                 n_rows = len(batch.rows)
+                 inputs: UDFInput = UDFInputBatch(
+                     [RowDict(zip(udf_fields, row)) for row in batch.rows]
+                 )
+             else:
+                 n_rows = 1
+                 inputs = RowDict(zip(udf_fields, batch))
+             output = self.run_once(catalog, inputs, is_generator, cache, cb=download_cb)
+             processed_cb.relative_update(n_rows)
+             yield output
 
          if hasattr(self.inner, "teardown") and callable(self.inner.teardown):
              self.inner.teardown()
@@ -76,23 +114,46 @@ class UDFAdapter(_UDFBase):
          is_generator: bool = False,
          cache: bool = False,
          cb: Callback = DEFAULT_CALLBACK,
-     ) -> "Iterable[UDFResult]":
+     ) -> Iterable[UDFResult]:
          if isinstance(arg, UDFInputBatch):
              udf_inputs = [
                  self.bind_parameters(catalog, row, cache=cache, cb=cb)
                  for row in arg.rows
              ]
-             udf_outputs = self.inner(udf_inputs, cache=cache, download_cb=cb)
+             udf_outputs = self.inner.run_once(udf_inputs, cache=cache, download_cb=cb)
              return self._process_results(arg.rows, udf_outputs, is_generator)
          if isinstance(arg, RowDict):
              udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
-             udf_outputs = self.inner(*udf_inputs, cache=cache, download_cb=cb)
+             udf_outputs = self.inner.run_once(udf_inputs, cache=cache, download_cb=cb)
              if not is_generator:
                  # udf_outputs is generator already if is_generator=True
                  udf_outputs = [udf_outputs]
              return self._process_results([arg], udf_outputs, is_generator)
          raise ValueError(f"Unexpected UDF argument: {arg}")
 
+     def bind_parameters(self, catalog: "Catalog", row: "RowDict", **kwargs) -> list:
+         return [p.get_value(catalog, row, **kwargs) for p in self.properties.params]
+
+     def _process_results(
+         self,
+         rows: Sequence["RowDict"],
+         results: Sequence[Sequence[Any]],
+         is_generator=False,
+     ) -> Iterable[UDFResult]:
+         """Create a list of dictionaries representing UDF results."""
+
+         # outputting rows
+         if is_generator:
+             # each row in results is a tuple of column values
+             return (dict(zip(self.signal_names, row)) for row in results)
+
+         # outputting signals
+         row_ids = [row["sys__id"] for row in rows]
+         return [
+             {"sys__id": row_id} | dict(zip(self.signal_names, signals))
+             for row_id, signals in zip(row_ids, results)
+         ]
+
 
  class UDFBase(AbstractUDF):
      """Base class for stateful user-defined functions.
@@ -146,14 +207,14 @@ class UDFBase(AbstractUDF):
      is_output_batched = False
      is_input_grouped = False
      params_spec: Optional[list[str]]
+     catalog: "Optional[Catalog]"
 
      def __init__(self):
          self.params = None
          self.output = None
          self.params_spec = None
         self.output_spec = None
-         self._contains_stream = None
-         self._catalog = None
+         self.catalog = None
          self._func = None
 
      def process(self, *args, **kwargs):
@@ -174,9 +235,9 @@
 
      def _init(
          self,
-         sign: UdfSignature,
+         sign: "UdfSignature",
          params: SignalSchema,
-         func: Callable,
+         func: Optional[Callable],
      ):
          self.params = params
          self.output = sign.output_schema
@@ -190,13 +251,13 @@
      @classmethod
      def _create(
          cls,
-         sign: UdfSignature,
+         sign: "UdfSignature",
          params: SignalSchema,
      ) -> "Self":
          if isinstance(sign.func, AbstractUDF):
              if not isinstance(sign.func, cls):  # type: ignore[unreachable]
                  raise UdfError(
-                     f"cannot create UDF: provided UDF '{sign.func.__name__}'"
+                     f"cannot create UDF: provided UDF '{type(sign.func).__name__}'"
                      f" must be a child of target class '{cls.__name__}'",
                  )
              result = sign.func
@@ -212,13 +273,6 @@
      def name(self):
          return self.__class__.__name__
 
-     def set_catalog(self, catalog):
-         self._catalog = catalog.copy(db=False)
-
-     @property
-     def catalog(self):
-         return self._catalog
-
      def to_udf_wrapper(self, batch: int = 1) -> UDFAdapter:
          assert self.params_spec is not None
          properties = UDFProperties(
@@ -229,11 +283,9 @@
      def validate_results(self, results, *args, **kwargs):
          return results
 
-     def __call__(self, *rows, cache, download_cb):
-         if self.is_input_grouped:
-             objs = self._parse_grouped_rows(rows[0], cache, download_cb)
-         elif self.is_input_batched:
-             objs = zip(*self._parse_rows(rows[0], cache, download_cb))
+     def run_once(self, rows, cache, download_cb):
+         if self.is_input_batched:
+             objs = zip(*self._parse_rows(rows, cache, download_cb))
          else:
              objs = self._parse_rows([rows], cache, download_cb)[0]
 
@@ -259,8 +311,8 @@
      ):
          res = list(res)
          assert len(res) == len(
-             rows[0]
-         ), f"{self.name} returns {len(res)} rows while len(rows[0]) expected"
+             rows
+         ), f"{self.name} returns {len(res)} rows while {len(rows)} expected"
 
          return res
 
@@ -283,41 +335,11 @@
              for obj in obj_row:
                  if isinstance(obj, File):
                      obj._set_stream(
-                         self._catalog, caching_enabled=cache, download_cb=download_cb
+                         self.catalog, caching_enabled=cache, download_cb=download_cb
                      )
              objs.append(obj_row)
          return objs
 
-     def _parse_grouped_rows(self, group, cache, download_cb):
-         spec_map = {}
-         output_map = {}
-         for name, (anno, subtree) in self.params.tree.items():
-             if ModelStore.is_pydantic(anno):
-                 length = sum(1 for _ in self.params._get_flat_tree(subtree, [], 0))
-             else:
-                 length = 1
-             spec_map[name] = anno, length
-             output_map[name] = []
-
-         for flat_obj in group:
-             position = 0
-             for signal, (cls, length) in spec_map.items():
-                 slice = flat_obj[position : position + length]
-                 position += length
-
-                 if ModelStore.is_pydantic(cls):
-                     obj = cls(**unflatten_to_json(cls, slice))
-                 else:
-                     obj = slice[0]
-
-                 if isinstance(obj, File):
-                     obj._set_stream(
-                         self._catalog, caching_enabled=cache, download_cb=download_cb
-                     )
-                 output_map[signal].append(obj)
-
-         return list(output_map.values())
-
      def process_safe(self, obj_rows):
          try:
              result_objs = self.process(*obj_rows)
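For reference, here is a small self-contained illustration (not datachain code) of the two result shapes `_process_results` produces: generator UDFs emit rows keyed by signal name, while signal-adding UDFs merge outputs back onto their source rows via `sys__id`:

    # Hypothetical signal names and values, used only to show the output shapes.
    signal_names = ["score", "label"]

    # Generator mode: each result tuple becomes {signal: value}.
    gen_results = [(0.9, "cat"), (0.1, "dog")]
    print([dict(zip(signal_names, row)) for row in gen_results])
    # [{'score': 0.9, 'label': 'cat'}, {'score': 0.1, 'label': 'dog'}]

    # Signal mode: outputs are attached to their source rows via sys__id.
    rows = [{"sys__id": 1}, {"sys__id": 2}]
    results = [(0.9, "cat"), (0.1, "dog")]
    print([
        {"sys__id": r["sys__id"]} | dict(zip(signal_names, vals))
        for r, vals in zip(rows, results)
    ])
    # [{'sys__id': 1, 'score': 0.9, 'label': 'cat'},
    #  {'sys__id': 2, 'score': 0.1, 'label': 'dog'}]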
@@ -1,10 +1,11 @@
  import inspect
  from collections.abc import Generator, Iterator, Sequence
  from dataclasses import dataclass
- from typing import Callable, Optional, Union, get_args, get_origin
+ from typing import Callable, Union, get_args, get_origin
 
  from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
  from datachain.lib.signal_schema import SignalSchema
+ from datachain.lib.udf import UDFBase
  from datachain.lib.utils import AbstractUDF, DataChainParamsError
 
 
@@ -16,7 +17,7 @@ class UdfSignatureError(DataChainParamsError):
 
  @dataclass
  class UdfSignature:
-     func: Callable
+     func: Union[Callable, UDFBase]
      params: Sequence[str]
      output_schema: SignalSchema
 
@@ -27,7 +28,7 @@ class UdfSignature:
          cls,
          chain: str,
          signal_map: dict[str, Callable],
-         func: Optional[Callable] = None,
+         func: Union[None, UDFBase, Callable] = None,
          params: Union[None, str, Sequence[str]] = None,
          output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
          is_generator: bool = True,
@@ -39,6 +40,7 @@ class UdfSignature:
                  f"multiple signals '{keys}' are not supported in processors."
                  " Chain multiple processors instead.",
              )
+         udf_func: Union[UDFBase, Callable]
          if len(signal_map) == 1:
              if func is not None:
                  raise UdfSignatureError(
@@ -53,7 +55,7 @@ class UdfSignature:
              udf_func = func
              signal_name = None
 
-         if not callable(udf_func):
+         if not isinstance(udf_func, UDFBase) and not callable(udf_func):
              raise UdfSignatureError(chain, f"UDF '{udf_func}' is not callable")
 
          func_params_map_sign, func_outs_sign, is_iterator = (
@@ -73,7 +75,7 @@ class UdfSignature:
          if not func_outs_sign:
              raise UdfSignatureError(
                  chain,
-                 f"outputs are not defined in function '{udf_func.__name__}'"
+                 f"outputs are not defined in function '{udf_func}'"
                  " hints or 'output'",
              )
 
@@ -154,7 +156,7 @@ class UdfSignature:
 
      @staticmethod
      def _func_signature(
-         chain: str, udf_func: Callable
+         chain: str, udf_func: Union[Callable, UDFBase]
      ) -> tuple[dict[str, type], Sequence[type], bool]:
          if isinstance(udf_func, AbstractUDF):
              func = udf_func.process  # type: ignore[unreachable]
@@ -42,6 +42,7 @@ from datachain.data_storage.schema import (
  )
  from datachain.dataset import DatasetStatus, RowDict
  from datachain.error import DatasetNotFoundError, QueryScriptCancelError
+ from datachain.lib.udf import UDFAdapter
  from datachain.progress import CombinedDownloadCallback
  from datachain.sql.functions import rand
  from datachain.utils import (
@@ -53,7 +54,6 @@ from datachain.utils import (
 
  from .schema import C, UDFParamSpec, normalize_param
  from .session import Session
- from .udf import UDFBase
 
  if TYPE_CHECKING:
      from sqlalchemy.sql.elements import ClauseElement
@@ -299,7 +299,7 @@ def adjust_outputs(
      return row
 
 
- def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFBase) -> list[tuple]:
+ def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFAdapter) -> list[tuple]:
      """Optimization: Precompute UDF column types so these don't have to be computed
      in the convert_type function for each row in a loop."""
      dialect = warehouse.db.dialect
@@ -320,7 +320,7 @@ def process_udf_outputs(
      warehouse: "AbstractWarehouse",
      udf_table: "Table",
      udf_results: Iterator[Iterable["UDFResult"]],
-     udf: UDFBase,
+     udf: UDFAdapter,
      batch_size: int = INSERT_BATCH_SIZE,
      cb: Callback = DEFAULT_CALLBACK,
  ) -> None:
@@ -364,7 +364,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
 
  @frozen
  class UDFStep(Step, ABC):
-     udf: UDFBase
+     udf: UDFAdapter
      catalog: "Catalog"
      partition_by: Optional[PartitionByType] = None
      parallel: Optional[int] = None
@@ -1037,7 +1037,7 @@ class DatasetQuery:
          session: Optional[Session] = None,
          indexing_column_types: Optional[dict[str, Any]] = None,
          in_memory: bool = False,
-     ):
+     ) -> None:
          self.session = Session.get(session, catalog=catalog, in_memory=in_memory)
          self.catalog = catalog or self.session.catalog
          self.steps: list[Step] = []
@@ -1465,7 +1465,7 @@ class DatasetQuery:
      @detach
      def add_signals(
          self,
-         udf: UDFBase,
+         udf: UDFAdapter,
          parallel: Optional[int] = None,
          workers: Union[bool, int] = False,
          min_task_size: Optional[int] = None,
@@ -1509,7 +1509,7 @@ class DatasetQuery:
      @detach
      def generate(
          self,
-         udf: UDFBase,
+         udf: UDFAdapter,
          parallel: Optional[int] = None,
          workers: Union[bool, int] = False,
          min_task_size: Optional[int] = None,
@@ -13,6 +13,7 @@ from multiprocess import get_context
 
  from datachain.catalog import Catalog
  from datachain.catalog.loader import get_distributed_class
+ from datachain.lib.udf import UDFAdapter, UDFResult
  from datachain.query.dataset import (
      get_download_callback,
      get_generated_callback,
@@ -27,7 +28,6 @@ from datachain.query.queue import (
      put_into_queue,
      unmarshal,
  )
- from datachain.query.udf import UDFBase, UDFResult
  from datachain.utils import batched_it
 
  DEFAULT_BATCH_SIZE = 10000
@@ -336,7 +336,7 @@ class ProcessedCallback(Callback):
  @attrs.define
  class UDFWorker:
      catalog: Catalog
-     udf: UDFBase
+     udf: UDFAdapter
      task_queue: "multiprocess.Queue"
      done_queue: "multiprocess.Queue"
      is_generator: bool
@@ -1,5 +1,8 @@
  import atexit
+ import logging
+ import os
  import re
+ import sys
  from typing import TYPE_CHECKING, Optional
  from uuid import uuid4
 
@@ -9,6 +12,8 @@ from datachain.error import TableMissingError
  if TYPE_CHECKING:
      from datachain.catalog import Catalog
 
+ logger = logging.getLogger("datachain")
+
 
  class Session:
      """
@@ -35,6 +40,7 @@ class Session:
 
      GLOBAL_SESSION_CTX: Optional["Session"] = None
      GLOBAL_SESSION: Optional["Session"] = None
+     ORIGINAL_EXCEPT_HOOK = None
 
      DATASET_PREFIX = "session_"
      GLOBAL_SESSION_NAME = "global"
@@ -58,6 +64,7 @@ class Session:
 
          session_uuid = uuid4().hex[: self.SESSION_UUID_LEN]
          self.name = f"{name}_{session_uuid}"
+         self.job_id = os.getenv("DATACHAIN_JOB_ID") or str(uuid4())
          self.is_new_catalog = not catalog
          self.catalog = catalog or get_catalog(
              client_config=client_config, in_memory=in_memory
@@ -67,6 +74,9 @@ class Session:
          return self
 
      def __exit__(self, exc_type, exc_val, exc_tb):
+         if exc_type:
+             self._cleanup_created_versions(self.name)
+
          self._cleanup_temp_datasets()
          if self.is_new_catalog:
              self.catalog.metastore.close_on_exit()
@@ -88,6 +98,21 @@ class Session:
          except TableMissingError:
              pass
 
+     def _cleanup_created_versions(self, job_id: str) -> None:
+         versions = self.catalog.metastore.get_job_dataset_versions(job_id)
+         if not versions:
+             return
+
+         datasets = {}
+         for dataset_name, version in versions:
+             if dataset_name not in datasets:
+                 datasets[dataset_name] = self.catalog.get_dataset(dataset_name)
+             dataset = datasets[dataset_name]
+             logger.info(
+                 "Removing dataset version %s@%s due to exception", dataset_name, version
+             )
+             self.catalog.remove_dataset_version(dataset, version)
+
      @classmethod
      def get(
          cls,
@@ -114,9 +139,23 @@ class Session:
                  in_memory=in_memory,
              )
              cls.GLOBAL_SESSION = cls.GLOBAL_SESSION_CTX.__enter__()
+
              atexit.register(cls._global_cleanup)
+             cls.ORIGINAL_EXCEPT_HOOK = sys.excepthook
+             sys.excepthook = cls.except_hook
+
          return cls.GLOBAL_SESSION
 
+     @staticmethod
+     def except_hook(exc_type, exc_value, exc_traceback):
+         Session._global_cleanup()
+         if Session.GLOBAL_SESSION_CTX is not None:
+             job_id = Session.GLOBAL_SESSION_CTX.job_id
+             Session.GLOBAL_SESSION_CTX._cleanup_created_versions(job_id)
+
+         if Session.ORIGINAL_EXCEPT_HOOK:
+             Session.ORIGINAL_EXCEPT_HOOK(exc_type, exc_value, exc_traceback)
+
      @classmethod
      def cleanup_for_tests(cls):
          if cls.GLOBAL_SESSION_CTX is not None:
@@ -125,6 +164,9 @@ class Session:
              cls.GLOBAL_SESSION_CTX = None
              atexit.unregister(cls._global_cleanup)
 
+         if cls.ORIGINAL_EXCEPT_HOOK:
+             sys.excepthook = cls.ORIGINAL_EXCEPT_HOOK
+
      @staticmethod
      def _global_cleanup():
          if Session.GLOBAL_SESSION_CTX is not None:
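The `Session.get` change relies on the standard excepthook-chaining pattern: save the interpreter's current `sys.excepthook`, install a wrapper that performs cleanup, and delegate to the original hook so normal traceback printing is preserved. A generic sketch of that pattern; the cleanup body here is a placeholder, not datachain's:

    import sys

    _original_hook = sys.excepthook

    def _cleanup_on_crash(exc_type, exc_value, exc_traceback):
        # e.g. drop dataset versions created by the failed job (placeholder)
        print("cleaning up temporary resources...")
        # delegate to the saved hook so the traceback is still printed
        _original_hook(exc_type, exc_value, exc_traceback)

    sys.excepthook = _cleanup_on_crash

    # Restoring the saved hook (as cleanup_for_tests does) undoes the override:
    sys.excepthook = _original_hook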
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: datachain
- Version: 0.4.0
+ Version: 0.5.1
  Summary: Wrangle unstructured AI data at scale
  Author-email: Dmitry Petrov <support@dvc.org>
  License: Apache-2.0