datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/query/queue.py CHANGED
@@ -1,13 +1,14 @@
 import datetime
 from collections.abc import Iterable, Iterator
-from queue import Empty, Full, Queue
+from queue import Empty, Full
 from struct import pack, unpack
 from time import sleep
 from typing import Any
 
 import msgpack
+from multiprocess.queues import Queue
 
-from datachain.query.batch import RowsOutput, RowsOutputBatch
+from datachain.query.batch import RowsOutput
 
 DEFAULT_BATCH_SIZE = 10000
 STOP_SIGNAL = "STOP"
@@ -56,7 +57,6 @@ def put_into_queue(queue: Queue, item: Any) -> None:
 
 
 MSGPACK_EXT_TYPE_DATETIME = 42
-MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH = 43
 
 
 def _msgpack_pack_extended_types(obj: Any) -> msgpack.ExtType:
@@ -70,12 +70,6 @@ def _msgpack_pack_extended_types(obj: Any) -> msgpack.ExtType:
         data = (obj.timestamp(),)  # type: ignore # noqa: PGH003
         return msgpack.ExtType(MSGPACK_EXT_TYPE_DATETIME, pack("!d", *data))
 
-    if isinstance(obj, RowsOutputBatch):
-        return msgpack.ExtType(
-            MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH,
-            msgpack_pack(obj.rows),
-        )
-
     raise TypeError(f"Unknown type: {obj}")
 
 
@@ -100,9 +94,6 @@ def _msgpack_unpack_extended_types(code: int, data: bytes) -> Any:
         tz_info = datetime.timezone(datetime.timedelta(seconds=timezone_offset))
         return datetime.datetime.fromtimestamp(timestamp, tz=tz_info)
 
-    if code == MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH:
-        return RowsOutputBatch(msgpack_unpack(data))
-
     return msgpack.ExtType(code, data)
 
 
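The retained datetime handling above follows msgpack's standard extension-type pattern (a custom ExtType code plus a packer default and an unpacker ext_hook). The sketch below is a standalone, simplified illustration of that round trip, not code from the package: it assumes UTC only and uses illustrative helper names; only msgpack.ExtType, msgpack.packb, and msgpack.unpackb are real msgpack API.

import datetime
from struct import pack, unpack

import msgpack

MSGPACK_EXT_TYPE_DATETIME = 42  # same extension code as in the diff above


def _default(obj):
    # Encode datetimes as an extension type carrying the POSIX timestamp.
    if isinstance(obj, datetime.datetime):
        return msgpack.ExtType(MSGPACK_EXT_TYPE_DATETIME, pack("!d", obj.timestamp()))
    raise TypeError(f"Unknown type: {obj}")


def _ext_hook(code, data):
    # Decode the extension type back into an aware UTC datetime.
    if code == MSGPACK_EXT_TYPE_DATETIME:
        (timestamp,) = unpack("!d", data)
        return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc)
    return msgpack.ExtType(code, data)


payload = msgpack.packb(
    {"created": datetime.datetime.now(datetime.timezone.utc)}, default=_default
)
restored = msgpack.unpackb(payload, ext_hook=_ext_hook)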
datachain/query/schema.py CHANGED
@@ -1,7 +1,8 @@
 import functools
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from fnmatch import fnmatch
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any
 
 import attrs
 import sqlalchemy as sa
@@ -36,9 +37,13 @@ class ColumnMeta(type):
     def __getattr__(cls, name: str):
         return cls(ColumnMeta.to_db_name(name))
 
+    @staticmethod
+    def is_nested(name: str) -> bool:
+        return DEFAULT_DELIMITER in name
+
 
 class Column(sa.ColumnClause, metaclass=ColumnMeta):
-    inherit_cache: Optional[bool] = True
+    inherit_cache: bool | None = True
 
     def __init__(self, text, type_=None, is_literal=False, _selectable=None):
         """Dataset column."""
@@ -173,7 +178,7 @@ class LocalFilename(UDFParameter):
     otherwise None will be returned.
     """
 
-    glob: Optional[str] = None
+    glob: str | None = None
 
     def get_value(
         self,
@@ -182,7 +187,7 @@ class LocalFilename(UDFParameter):
         *,
         cb: Callback = DEFAULT_CALLBACK,
         **kwargs,
-    ) -> Optional[str]:
+    ) -> str | None:
        if self.glob and not fnmatch(row["name"], self.glob):  # type: ignore[type-var]
            # If the glob pattern is specified and the row filename
            # does not match it, then return None
@@ -201,7 +206,7 @@ class LocalFilename(UDFParameter):
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
         **kwargs,
-    ) -> Optional[str]:
+    ) -> str | None:
        if self.glob and not fnmatch(row["name"], self.glob):  # type: ignore[type-var]
            # If the glob pattern is specified and the row filename
            # does not match it, then return None
@@ -212,7 +217,7 @@ class LocalFilename(UDFParameter):
         return client.cache.get_path(file)
 
 
-UDFParamSpec = Union[str, Column, UDFParameter]
+UDFParamSpec = str | Column | UDFParameter
 
 
 def normalize_param(param: UDFParamSpec) -> UDFParameter:
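The ColumnMeta change above adds is_nested next to the existing __getattr__ hook, which turns attribute access on the Column class into column instances with DB-safe names. The toy sketch below illustrates that metaclass pattern in isolation; _Column, the dot-to-delimiter mapping in to_db_name, and the value of DEFAULT_DELIMITER are assumptions for illustration, not datachain's actual definitions (the real Column subclasses sa.ColumnClause).

DEFAULT_DELIMITER = "__"  # assumed value, for illustration only


class _ColumnMeta(type):
    @staticmethod
    def to_db_name(name: str) -> str:
        # e.g. "file.path" -> "file__path" (assumed mapping)
        return name.replace(".", DEFAULT_DELIMITER)

    @staticmethod
    def is_nested(name: str) -> bool:
        return DEFAULT_DELIMITER in name

    def __getattr__(cls, name: str):
        # Attribute access on the class itself builds an instance.
        return cls(_ColumnMeta.to_db_name(name))


class _Column(str, metaclass=_ColumnMeta):
    """Toy stand-in for datachain's Column."""


col = _Column.file                              # builds _Column("file")
nested = _ColumnMeta.is_nested("file__path")    # True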
datachain/query/session.py CHANGED
@@ -1,21 +1,37 @@
 import atexit
-import gc
 import logging
+import os
 import re
 import sys
-from typing import TYPE_CHECKING, ClassVar, Optional
+import traceback
+from collections.abc import Callable
+from typing import TYPE_CHECKING, ClassVar
 from uuid import uuid4
+from weakref import WeakSet
 
 from datachain.catalog import get_catalog
-from datachain.error import TableMissingError
+from datachain.data_storage import JobQueryType, JobStatus
+from datachain.error import JobNotFoundError, TableMissingError
 
 if TYPE_CHECKING:
     from datachain.catalog import Catalog
-    from datachain.dataset import DatasetRecord
+    from datachain.job import Job
 
 logger = logging.getLogger("datachain")
 
 
+def is_script_run() -> bool:
+    """
+    Returns True if this was ran as python script, e.g python my_script.py.
+    Otherwise (if interactive or module run) returns False.
+    """
+    try:
+        argv0 = sys.argv[0]
+    except (IndexError, AttributeError):
+        return False
+    return bool(argv0) and argv0 not in ("-c", "-m", "ipython")
+
+
 class Session:
     """
     Session is a context that keeps track of temporary DataChain datasets for a proper
@@ -39,10 +55,18 @@ class Session:
         catalog (Catalog): Catalog object.
     """
 
-    GLOBAL_SESSION_CTX: Optional["Session"] = None
+    GLOBAL_SESSION_CTX: "Session | None" = None
     SESSION_CONTEXTS: ClassVar[list["Session"]] = []
+    _ALL_SESSIONS: ClassVar[WeakSet["Session"]] = WeakSet()
     ORIGINAL_EXCEPT_HOOK = None
 
+    # Job management - class-level to ensure one job per process
+    _CURRENT_JOB: ClassVar["Job | None"] = None
+    _JOB_STATUS: ClassVar[JobStatus | None] = None
+    _OWNS_JOB: ClassVar[bool | None] = None
+    _JOB_HOOKS_REGISTERED: ClassVar[bool] = False
+    _JOB_FINALIZE_HOOK: ClassVar[Callable[[], None] | None] = None
+
     DATASET_PREFIX = "session_"
     GLOBAL_SESSION_NAME = "global"
     SESSION_UUID_LEN = 6
@@ -51,8 +75,8 @@ class Session:
     def __init__(
         self,
         name="",
-        catalog: Optional["Catalog"] = None,
-        client_config: Optional[dict] = None,
+        catalog: "Catalog | None" = None,
+        client_config: dict | None = None,
         in_memory: bool = False,
     ):
         if re.match(r"^[0-9a-zA-Z]*$", name) is None:
@@ -69,7 +93,7 @@ class Session:
         self.catalog = catalog or get_catalog(
             client_config=client_config, in_memory=in_memory
         )
-        self.dataset_versions: list[tuple[DatasetRecord, int, bool]] = []
+        Session._ALL_SESSIONS.add(self)
 
     def __enter__(self):
         # Push the current context onto the stack
@@ -78,9 +102,8 @@ class Session:
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        if exc_type:
-            self._cleanup_created_versions()
-
+        # Don't cleanup created versions on exception
+        # Datasets should persist even if the session fails
         self._cleanup_temp_datasets()
         if self.is_new_catalog:
             self.catalog.metastore.close_on_exit()
@@ -88,11 +111,116 @@ class Session:
 
         if Session.SESSION_CONTEXTS:
             Session.SESSION_CONTEXTS.pop()
+        Session._ALL_SESSIONS.discard(self)
 
-    def add_dataset_version(
-        self, dataset: "DatasetRecord", version: int, listing: bool = False
-    ) -> None:
-        self.dataset_versions.append((dataset, version, listing))
+    def get_or_create_job(self) -> "Job":
+        """
+        Get or create a Job for this process.
+
+        Returns:
+            Job: The active Job instance.
+
+        Behavior:
+            - If a job already exists, it is returned.
+            - If ``DATACHAIN_JOB_ID`` is set, the corresponding job is fetched.
+            - Otherwise, a new job is created:
+                * Name = absolute path to the Python script.
+                * Query = empty string.
+                * Parent = last job with the same name, if available.
+                * Status = "running".
+              Exit hooks are registered to finalize the job.
+
+        Note:
+            Job is shared across all Session instances to ensure one job per process.
+        """
+        if Session._CURRENT_JOB:
+            return Session._CURRENT_JOB
+
+        if env_job_id := os.getenv("DATACHAIN_JOB_ID"):
+            # SaaS run: just fetch existing job
+            Session._CURRENT_JOB = self.catalog.metastore.get_job(env_job_id)
+            if not Session._CURRENT_JOB:
+                raise JobNotFoundError(
+                    f"Job {env_job_id} from DATACHAIN_JOB_ID env not found"
+                )
+            Session._OWNS_JOB = False
+        else:
+            # Local run: create new job
+            if is_script_run():
+                script = os.path.abspath(sys.argv[0])
+            else:
+                # Interactive session or module run - use unique name to avoid
+                # linking unrelated sessions
+                script = str(uuid4())
+            python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
+
+            # try to find the parent job
+            parent = self.catalog.metastore.get_last_job_by_name(script)
+
+            job_id = self.catalog.metastore.create_job(
+                name=script,
+                query="",
+                query_type=JobQueryType.PYTHON,
+                status=JobStatus.RUNNING,
+                python_version=python_version,
+                parent_job_id=parent.id if parent else None,
+            )
+            Session._CURRENT_JOB = self.catalog.metastore.get_job(job_id)
+            Session._OWNS_JOB = True
+        Session._JOB_STATUS = JobStatus.RUNNING
+
+        # register cleanup hooks only once
+        if not Session._JOB_HOOKS_REGISTERED:
+
+            def _finalize_success_hook() -> None:
+                self._finalize_job_success()
+
+            Session._JOB_FINALIZE_HOOK = _finalize_success_hook
+            atexit.register(Session._JOB_FINALIZE_HOOK)
+            Session._JOB_HOOKS_REGISTERED = True
+
+        assert Session._CURRENT_JOB is not None
+        return Session._CURRENT_JOB
+
+    def _finalize_job_success(self):
+        """Mark the current job as completed."""
+        if (
+            Session._CURRENT_JOB
+            and Session._OWNS_JOB
+            and Session._JOB_STATUS == JobStatus.RUNNING
+        ):
+            self.catalog.metastore.set_job_status(
+                Session._CURRENT_JOB.id, JobStatus.COMPLETE
+            )
+            Session._JOB_STATUS = JobStatus.COMPLETE
+
+    def _finalize_job_as_canceled(self):
+        """Mark the current job as canceled."""
+        if (
+            Session._CURRENT_JOB
+            and Session._OWNS_JOB
+            and Session._JOB_STATUS == JobStatus.RUNNING
+        ):
+            self.catalog.metastore.set_job_status(
+                Session._CURRENT_JOB.id, JobStatus.CANCELED
+            )
+            Session._JOB_STATUS = JobStatus.CANCELED
+
+    def _finalize_job_as_failed(self, exc_type, exc_value, tb):
+        """Mark the current job as failed with error details."""
+        if (
+            Session._CURRENT_JOB
+            and Session._OWNS_JOB
+            and Session._JOB_STATUS == JobStatus.RUNNING
+        ):
+            error_stack = "".join(traceback.format_exception(exc_type, exc_value, tb))
+            self.catalog.metastore.set_job_status(
+                Session._CURRENT_JOB.id,
+                JobStatus.FAILED,
+                error_message=str(exc_value),
+                error_stack=error_stack,
+            )
+            Session._JOB_STATUS = JobStatus.FAILED
 
     def generate_temp_dataset_name(self) -> str:
         return self.get_temp_prefix() + uuid4().hex[: self.TEMP_TABLE_UUID_LEN]
@@ -100,31 +228,25 @@ class Session:
     def get_temp_prefix(self) -> str:
         return f"{self.DATASET_PREFIX}{self.name}_"
 
+    @classmethod
+    def is_temp_dataset(cls, name) -> bool:
+        return name.startswith(cls.DATASET_PREFIX)
+
     def _cleanup_temp_datasets(self) -> None:
         prefix = self.get_temp_prefix()
         try:
             for dataset in list(self.catalog.metastore.list_datasets_by_prefix(prefix)):
-                self.catalog.remove_dataset(dataset.name, force=True)
+                self.catalog.remove_dataset(dataset.name, dataset.project, force=True)
         # suppress error when metastore has been reset during testing
         except TableMissingError:
             pass
 
-    def _cleanup_created_versions(self) -> None:
-        if not self.dataset_versions:
-            return
-
-        for dataset, version, listing in self.dataset_versions:
-            if not listing:
-                self.catalog.remove_dataset_version(dataset, version)
-
-        self.dataset_versions.clear()
-
     @classmethod
     def get(
         cls,
-        session: Optional["Session"] = None,
-        catalog: Optional["Catalog"] = None,
-        client_config: Optional[dict] = None,
+        session: "Session | None" = None,
+        catalog: "Catalog | None" = None,
+        client_config: dict | None = None,
         in_memory: bool = False,
     ) -> "Session":
         """Creates a Session() object from a catalog.
@@ -169,27 +291,72 @@
 
     @staticmethod
     def except_hook(exc_type, exc_value, exc_traceback):
-        Session.GLOBAL_SESSION_CTX.__exit__(exc_type, exc_value, exc_traceback)
+        if Session.GLOBAL_SESSION_CTX:
+            # Handle KeyboardInterrupt specially - mark as canceled and exit with
+            # signal code
+            if exc_type is KeyboardInterrupt:
+                Session.GLOBAL_SESSION_CTX._finalize_job_as_canceled()
+            else:
+                Session.GLOBAL_SESSION_CTX._finalize_job_as_failed(
+                    exc_type, exc_value, exc_traceback
+                )
+            Session.GLOBAL_SESSION_CTX.__exit__(exc_type, exc_value, exc_traceback)
+
         Session._global_cleanup()
 
+        # Always delegate to original hook if it exists
         if Session.ORIGINAL_EXCEPT_HOOK:
             Session.ORIGINAL_EXCEPT_HOOK(exc_type, exc_value, exc_traceback)
 
+        if exc_type is KeyboardInterrupt:
+            # Exit with SIGINT signal code (128 + 2 = 130, or -2 in subprocess terms)
+            sys.exit(130)
+
     @classmethod
     def cleanup_for_tests(cls):
+        cls._close_all_contexts()
         if cls.GLOBAL_SESSION_CTX is not None:
             cls.GLOBAL_SESSION_CTX.__exit__(None, None, None)
             cls.GLOBAL_SESSION_CTX = None
             atexit.unregister(cls._global_cleanup)
 
+        # Reset job-related class variables
+        if cls._JOB_FINALIZE_HOOK:
+            try:
+                atexit.unregister(cls._JOB_FINALIZE_HOOK)
+            except ValueError:
+                pass  # Hook was not registered
+        cls._CURRENT_JOB = None
+        cls._JOB_STATUS = None
+        cls._OWNS_JOB = None
+        cls._JOB_HOOKS_REGISTERED = False
+        cls._JOB_FINALIZE_HOOK = None
+
         if cls.ORIGINAL_EXCEPT_HOOK:
             sys.excepthook = cls.ORIGINAL_EXCEPT_HOOK
 
     @staticmethod
     def _global_cleanup():
+        Session._close_all_contexts()
         if Session.GLOBAL_SESSION_CTX is not None:
             Session.GLOBAL_SESSION_CTX.__exit__(None, None, None)
 
-        for obj in gc.get_objects():  # Get all tracked objects
-            if isinstance(obj, Session):  # Cleanup temp dataset for session variables.
-                obj.__exit__(None, None, None)
+        for session in list(Session._ALL_SESSIONS):
+            try:
+                session.__exit__(None, None, None)
+            except ReferenceError:
+                continue  # Object has been finalized already
+            except Exception as e:  # noqa: BLE001
+                logger.error(f"Exception while cleaning up session: {e}")  # noqa: G004
+
+    @classmethod
+    def _close_all_contexts(cls) -> None:
+        while cls.SESSION_CONTEXTS:
+            session = cls.SESSION_CONTEXTS.pop()
+            try:
+                session.__exit__(None, None, None)
+            except Exception as exc:  # noqa: BLE001
+                logger.error(
                    "Exception while closing session context during cleanup: %s",
                    exc,
                )
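Based only on the Session API visible in this diff (the constructor signature and the new get_or_create_job method), a hedged usage sketch might look like the following; in_memory=True and printing job.id are assumptions about typical use, not documented behavior of the package.

from datachain.query.session import Session

# Use an in-memory catalog so the sketch does not touch an on-disk metastore.
with Session("demo", in_memory=True) as session:
    job = session.get_or_create_job()  # creates or reuses the per-process Job
    print(job.id)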
datachain/query/udf.py CHANGED
@@ -1,8 +1,11 @@
-from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any, TypedDict
 
 if TYPE_CHECKING:
     from sqlalchemy import Select, Table
 
+    from datachain.catalog import Catalog
     from datachain.query.batch import BatchingStrategy
 
 
@@ -15,6 +18,35 @@ class UdfInfo(TypedDict):
     query: "Select"
     udf_fields: list[str]
     batching: "BatchingStrategy"
-    processes: Optional[int]
+    processes: int | None
     is_generator: bool
     cache: bool
+    rows_total: int
+    batch_size: int
+
+
+class AbstractUDFDistributor(ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        catalog: "Catalog",
+        table: "Table",
+        query: "Select",
+        udf_data: bytes,
+        batching: "BatchingStrategy",
+        workers: bool | int,
+        processes: bool | int,
+        udf_fields: list[str],
+        rows_total: int,
+        use_cache: bool,
+        is_generator: bool = False,
+        min_task_size: str | int | None = None,
+        batch_size: int | None = None,
+    ) -> None: ...
+
+    @abstractmethod
+    def __call__(self) -> None: ...
+
+    @staticmethod
+    @abstractmethod
+    def run_udf() -> int: ...
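AbstractUDFDistributor defines the hook a distributed UDF executor must implement: construction with the catalog, query, and serialized UDF, a __call__ that drives execution, and a static run_udf entry point returning an exit code. A minimal, purely illustrative subclass satisfying that interface might look like this; the no-op behavior is an assumption, not datachain's actual distributor.

class LocalUDFDistributor(AbstractUDFDistributor):
    def __init__(self, catalog, table, query, udf_data, batching, workers,
                 processes, udf_fields, rows_total, use_cache,
                 is_generator=False, min_task_size=None, batch_size=None) -> None:
        # Keep only what this sketch needs; a real distributor would hold
        # everything required to dispatch the serialized UDF (udf_data).
        self.rows_total = rows_total

    def __call__(self) -> None:
        # A real implementation would fan the query's rows out to workers here.
        pass

    @staticmethod
    def run_udf() -> int:
        # Entry point for a worker process; returns an exit code.
        return 0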