datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/query/queue.py
CHANGED
@@ -1,13 +1,14 @@
 import datetime
 from collections.abc import Iterable, Iterator
-from queue import Empty, Full
+from queue import Empty, Full
 from struct import pack, unpack
 from time import sleep
 from typing import Any
 
 import msgpack
+from multiprocess.queues import Queue
 
-from datachain.query.batch import RowsOutput
+from datachain.query.batch import RowsOutput
 
 DEFAULT_BATCH_SIZE = 10000
 STOP_SIGNAL = "STOP"
@@ -56,7 +57,6 @@ def put_into_queue(queue: Queue, item: Any) -> None:
 
 
 MSGPACK_EXT_TYPE_DATETIME = 42
-MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH = 43
 
 
 def _msgpack_pack_extended_types(obj: Any) -> msgpack.ExtType:
@@ -70,12 +70,6 @@ def _msgpack_pack_extended_types(obj: Any) -> msgpack.ExtType:
         data = (obj.timestamp(),)  # type: ignore # noqa: PGH003
         return msgpack.ExtType(MSGPACK_EXT_TYPE_DATETIME, pack("!d", *data))
 
-    if isinstance(obj, RowsOutputBatch):
-        return msgpack.ExtType(
-            MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH,
-            msgpack_pack(obj.rows),
-        )
-
     raise TypeError(f"Unknown type: {obj}")
 
 
@@ -100,9 +94,6 @@ def _msgpack_unpack_extended_types(code: int, data: bytes) -> Any:
         tz_info = datetime.timezone(datetime.timedelta(seconds=timezone_offset))
         return datetime.datetime.fromtimestamp(timestamp, tz=tz_info)
 
-    if code == MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH:
-        return RowsOutputBatch(msgpack_unpack(data))
-
     return msgpack.ExtType(code, data)
 
 
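The change drops the custom msgpack extension for RowsOutputBatch (ext code 43), leaving only the datetime extension (ext code 42). For context, here is a minimal sketch of how such an extension type round-trips through msgpack — simplified to UTC-only timestamps and using hypothetical helper names, not the package's exact implementation (which also encodes the timezone offset):

import datetime
from struct import pack, unpack

import msgpack

MSGPACK_EXT_TYPE_DATETIME = 42


def _default(obj):
    # Serialize timezone-aware datetimes as an 8-byte big-endian double.
    if isinstance(obj, datetime.datetime):
        return msgpack.ExtType(MSGPACK_EXT_TYPE_DATETIME, pack("!d", obj.timestamp()))
    raise TypeError(f"Unknown type: {obj}")


def _ext_hook(code, data):
    # Restore ext code 42 to a datetime; leave unknown codes as raw ExtType.
    if code == MSGPACK_EXT_TYPE_DATETIME:
        (timestamp,) = unpack("!d", data)
        return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc)
    return msgpack.ExtType(code, data)


now = datetime.datetime.now(datetime.timezone.utc)
packed = msgpack.packb(now, default=_default)
print(msgpack.unpackb(packed, ext_hook=_ext_hook))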
datachain/query/schema.py
CHANGED
@@ -1,7 +1,8 @@
 import functools
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from fnmatch import fnmatch
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any
 
 import attrs
 import sqlalchemy as sa
@@ -36,9 +37,13 @@ class ColumnMeta(type):
     def __getattr__(cls, name: str):
         return cls(ColumnMeta.to_db_name(name))
 
+    @staticmethod
+    def is_nested(name: str) -> bool:
+        return DEFAULT_DELIMITER in name
+
 
 class Column(sa.ColumnClause, metaclass=ColumnMeta):
-    inherit_cache:
+    inherit_cache: bool | None = True
 
     def __init__(self, text, type_=None, is_literal=False, _selectable=None):
         """Dataset column."""
@@ -173,7 +178,7 @@ class LocalFilename(UDFParameter):
     otherwise None will be returned.
     """
 
-    glob:
+    glob: str | None = None
 
     def get_value(
         self,
@@ -182,7 +187,7 @@ class LocalFilename(UDFParameter):
         *,
         cb: Callback = DEFAULT_CALLBACK,
         **kwargs,
-    ) ->
+    ) -> str | None:
         if self.glob and not fnmatch(row["name"], self.glob):  # type: ignore[type-var]
             # If the glob pattern is specified and the row filename
             # does not match it, then return None
@@ -201,7 +206,7 @@ class LocalFilename(UDFParameter):
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
         **kwargs,
-    ) ->
+    ) -> str | None:
         if self.glob and not fnmatch(row["name"], self.glob):  # type: ignore[type-var]
             # If the glob pattern is specified and the row filename
             # does not match it, then return None
@@ -212,7 +217,7 @@ class LocalFilename(UDFParameter):
         return client.cache.get_path(file)
 
 
-UDFParamSpec =
+UDFParamSpec = str | Column | UDFParameter
 
 
 def normalize_param(param: UDFParamSpec) -> UDFParameter:
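Besides migrating annotations to PEP 604 unions (str | None in place of the truncated older forms), the change adds ColumnMeta.is_nested(). A hedged, standalone sketch of the metaclass pattern — Column here is a plain str subclass and DEFAULT_DELIMITER is assumed to be "__", both for illustration only:

DEFAULT_DELIMITER = "__"  # assumed value, for illustration


class ColumnMeta(type):
    @staticmethod
    def to_db_name(name: str) -> str:
        # Flatten a dotted attribute path into a single DB column name.
        return name.replace(".", DEFAULT_DELIMITER)

    @staticmethod
    def is_nested(name: str) -> bool:
        return DEFAULT_DELIMITER in name

    def __getattr__(cls, name: str):
        # Unknown class attributes become column instances:
        # Column.file -> Column("file")
        return cls(ColumnMeta.to_db_name(name))


class Column(str, metaclass=ColumnMeta):
    """Stand-in for the real sa.ColumnClause-based Column."""


print(Column.file)                     # "file"
print(Column.is_nested("file__path"))  # True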
datachain/query/session.py
CHANGED
@@ -1,21 +1,37 @@
 import atexit
-import gc
 import logging
+import os
 import re
 import sys
-
+import traceback
+from collections.abc import Callable
+from typing import TYPE_CHECKING, ClassVar
 from uuid import uuid4
+from weakref import WeakSet
 
 from datachain.catalog import get_catalog
-from datachain.
+from datachain.data_storage import JobQueryType, JobStatus
+from datachain.error import JobNotFoundError, TableMissingError
 
 if TYPE_CHECKING:
     from datachain.catalog import Catalog
-    from datachain.
+    from datachain.job import Job
 
 logger = logging.getLogger("datachain")
 
 
+def is_script_run() -> bool:
+    """
+    Returns True if this was run as a Python script, e.g. python my_script.py.
+    Otherwise (if interactive or module run) returns False.
+    """
+    try:
+        argv0 = sys.argv[0]
+    except (IndexError, AttributeError):
+        return False
+    return bool(argv0) and argv0 not in ("-c", "-m", "ipython")
+
+
 class Session:
     """
     Session is a context that keeps track of temporary DataChain datasets for a proper
@@ -39,10 +55,18 @@ class Session:
         catalog (Catalog): Catalog object.
     """
 
-    GLOBAL_SESSION_CTX:
+    GLOBAL_SESSION_CTX: "Session | None" = None
     SESSION_CONTEXTS: ClassVar[list["Session"]] = []
+    _ALL_SESSIONS: ClassVar[WeakSet["Session"]] = WeakSet()
     ORIGINAL_EXCEPT_HOOK = None
 
+    # Job management - class-level to ensure one job per process
+    _CURRENT_JOB: ClassVar["Job | None"] = None
+    _JOB_STATUS: ClassVar[JobStatus | None] = None
+    _OWNS_JOB: ClassVar[bool | None] = None
+    _JOB_HOOKS_REGISTERED: ClassVar[bool] = False
+    _JOB_FINALIZE_HOOK: ClassVar[Callable[[], None] | None] = None
+
     DATASET_PREFIX = "session_"
     GLOBAL_SESSION_NAME = "global"
     SESSION_UUID_LEN = 6
@@ -51,8 +75,8 @@ class Session:
     def __init__(
         self,
        name="",
-        catalog:
-        client_config:
+        catalog: "Catalog | None" = None,
+        client_config: dict | None = None,
         in_memory: bool = False,
     ):
         if re.match(r"^[0-9a-zA-Z]*$", name) is None:
@@ -69,7 +93,7 @@ class Session:
         self.catalog = catalog or get_catalog(
             client_config=client_config, in_memory=in_memory
         )
-        self
+        Session._ALL_SESSIONS.add(self)
 
     def __enter__(self):
         # Push the current context onto the stack
@@ -78,9 +102,8 @@ class Session:
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-
-
-
+        # Don't cleanup created versions on exception
+        # Datasets should persist even if the session fails
         self._cleanup_temp_datasets()
         if self.is_new_catalog:
             self.catalog.metastore.close_on_exit()
@@ -88,11 +111,116 @@ class Session:
 
         if Session.SESSION_CONTEXTS:
             Session.SESSION_CONTEXTS.pop()
+        Session._ALL_SESSIONS.discard(self)
 
-    def
-
-
+    def get_or_create_job(self) -> "Job":
+        """
+        Get or create a Job for this process.
+
+        Returns:
+            Job: The active Job instance.
+
+        Behavior:
+        - If a job already exists, it is returned.
+        - If ``DATACHAIN_JOB_ID`` is set, the corresponding job is fetched.
+        - Otherwise, a new job is created:
+            * Name = absolute path to the Python script.
+            * Query = empty string.
+            * Parent = last job with the same name, if available.
+            * Status = "running".
+        Exit hooks are registered to finalize the job.
+
+        Note:
+            Job is shared across all Session instances to ensure one job per process.
+        """
+        if Session._CURRENT_JOB:
+            return Session._CURRENT_JOB
+
+        if env_job_id := os.getenv("DATACHAIN_JOB_ID"):
+            # SaaS run: just fetch existing job
+            Session._CURRENT_JOB = self.catalog.metastore.get_job(env_job_id)
+            if not Session._CURRENT_JOB:
+                raise JobNotFoundError(
+                    f"Job {env_job_id} from DATACHAIN_JOB_ID env not found"
+                )
+            Session._OWNS_JOB = False
+        else:
+            # Local run: create new job
+            if is_script_run():
+                script = os.path.abspath(sys.argv[0])
+            else:
+                # Interactive session or module run - use unique name to avoid
+                # linking unrelated sessions
+                script = str(uuid4())
+            python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
+
+            # try to find the parent job
+            parent = self.catalog.metastore.get_last_job_by_name(script)
+
+            job_id = self.catalog.metastore.create_job(
+                name=script,
+                query="",
+                query_type=JobQueryType.PYTHON,
+                status=JobStatus.RUNNING,
+                python_version=python_version,
+                parent_job_id=parent.id if parent else None,
+            )
+            Session._CURRENT_JOB = self.catalog.metastore.get_job(job_id)
+            Session._OWNS_JOB = True
+        Session._JOB_STATUS = JobStatus.RUNNING
+
+        # register cleanup hooks only once
+        if not Session._JOB_HOOKS_REGISTERED:
+
+            def _finalize_success_hook() -> None:
+                self._finalize_job_success()
+
+            Session._JOB_FINALIZE_HOOK = _finalize_success_hook
+            atexit.register(Session._JOB_FINALIZE_HOOK)
+            Session._JOB_HOOKS_REGISTERED = True
+
+        assert Session._CURRENT_JOB is not None
+        return Session._CURRENT_JOB
+
+    def _finalize_job_success(self):
+        """Mark the current job as completed."""
+        if (
+            Session._CURRENT_JOB
+            and Session._OWNS_JOB
+            and Session._JOB_STATUS == JobStatus.RUNNING
+        ):
+            self.catalog.metastore.set_job_status(
+                Session._CURRENT_JOB.id, JobStatus.COMPLETE
+            )
+            Session._JOB_STATUS = JobStatus.COMPLETE
+
+    def _finalize_job_as_canceled(self):
+        """Mark the current job as canceled."""
+        if (
+            Session._CURRENT_JOB
+            and Session._OWNS_JOB
+            and Session._JOB_STATUS == JobStatus.RUNNING
+        ):
+            self.catalog.metastore.set_job_status(
+                Session._CURRENT_JOB.id, JobStatus.CANCELED
+            )
+            Session._JOB_STATUS = JobStatus.CANCELED
+
+    def _finalize_job_as_failed(self, exc_type, exc_value, tb):
+        """Mark the current job as failed with error details."""
+        if (
+            Session._CURRENT_JOB
+            and Session._OWNS_JOB
+            and Session._JOB_STATUS == JobStatus.RUNNING
+        ):
+            error_stack = "".join(traceback.format_exception(exc_type, exc_value, tb))
+            self.catalog.metastore.set_job_status(
+                Session._CURRENT_JOB.id,
+                JobStatus.FAILED,
+                error_message=str(exc_value),
+                error_stack=error_stack,
+            )
+            Session._JOB_STATUS = JobStatus.FAILED
 
     def generate_temp_dataset_name(self) -> str:
         return self.get_temp_prefix() + uuid4().hex[: self.TEMP_TABLE_UUID_LEN]
@@ -100,31 +228,25 @@ class Session:
     def get_temp_prefix(self) -> str:
         return f"{self.DATASET_PREFIX}{self.name}_"
 
+    @classmethod
+    def is_temp_dataset(cls, name) -> bool:
+        return name.startswith(cls.DATASET_PREFIX)
+
    def _cleanup_temp_datasets(self) -> None:
         prefix = self.get_temp_prefix()
         try:
             for dataset in list(self.catalog.metastore.list_datasets_by_prefix(prefix)):
-                self.catalog.remove_dataset(dataset.name, force=True)
+                self.catalog.remove_dataset(dataset.name, dataset.project, force=True)
         # suppress error when metastore has been reset during testing
         except TableMissingError:
             pass
 
-    def _cleanup_created_versions(self) -> None:
-        if not self.dataset_versions:
-            return
-
-        for dataset, version, listing in self.dataset_versions:
-            if not listing:
-                self.catalog.remove_dataset_version(dataset, version)
-
-        self.dataset_versions.clear()
-
     @classmethod
     def get(
         cls,
-        session:
-        catalog:
-        client_config:
+        session: "Session | None" = None,
+        catalog: "Catalog | None" = None,
+        client_config: dict | None = None,
         in_memory: bool = False,
     ) -> "Session":
         """Creates a Session() object from a catalog.
@@ -169,27 +291,72 @@ class Session:
 
     @staticmethod
     def except_hook(exc_type, exc_value, exc_traceback):
-        Session.GLOBAL_SESSION_CTX
+        if Session.GLOBAL_SESSION_CTX:
+            # Handle KeyboardInterrupt specially - mark as canceled and exit with
+            # signal code
+            if exc_type is KeyboardInterrupt:
+                Session.GLOBAL_SESSION_CTX._finalize_job_as_canceled()
+            else:
+                Session.GLOBAL_SESSION_CTX._finalize_job_as_failed(
+                    exc_type, exc_value, exc_traceback
+                )
+            Session.GLOBAL_SESSION_CTX.__exit__(exc_type, exc_value, exc_traceback)
+
         Session._global_cleanup()
 
+        # Always delegate to original hook if it exists
         if Session.ORIGINAL_EXCEPT_HOOK:
             Session.ORIGINAL_EXCEPT_HOOK(exc_type, exc_value, exc_traceback)
 
+        if exc_type is KeyboardInterrupt:
+            # Exit with SIGINT signal code (128 + 2 = 130, or -2 in subprocess terms)
+            sys.exit(130)
+
     @classmethod
     def cleanup_for_tests(cls):
+        cls._close_all_contexts()
         if cls.GLOBAL_SESSION_CTX is not None:
             cls.GLOBAL_SESSION_CTX.__exit__(None, None, None)
             cls.GLOBAL_SESSION_CTX = None
             atexit.unregister(cls._global_cleanup)
 
+        # Reset job-related class variables
+        if cls._JOB_FINALIZE_HOOK:
+            try:
+                atexit.unregister(cls._JOB_FINALIZE_HOOK)
+            except ValueError:
+                pass  # Hook was not registered
+        cls._CURRENT_JOB = None
+        cls._JOB_STATUS = None
+        cls._OWNS_JOB = None
+        cls._JOB_HOOKS_REGISTERED = False
+        cls._JOB_FINALIZE_HOOK = None
+
         if cls.ORIGINAL_EXCEPT_HOOK:
             sys.excepthook = cls.ORIGINAL_EXCEPT_HOOK
 
     @staticmethod
     def _global_cleanup():
+        Session._close_all_contexts()
         if Session.GLOBAL_SESSION_CTX is not None:
             Session.GLOBAL_SESSION_CTX.__exit__(None, None, None)
 
-        for
-
-
+        for session in list(Session._ALL_SESSIONS):
+            try:
+                session.__exit__(None, None, None)
+            except ReferenceError:
+                continue  # Object has been finalized already
+            except Exception as e:  # noqa: BLE001
+                logger.error(f"Exception while cleaning up session: {e}")  # noqa: G004
+
+    @classmethod
+    def _close_all_contexts(cls) -> None:
+        while cls.SESSION_CONTEXTS:
+            session = cls.SESSION_CONTEXTS.pop()
+            try:
+                session.__exit__(None, None, None)
+            except Exception as exc:  # noqa: BLE001
+                logger.error(
+                    "Exception while closing session context during cleanup: %s",
+                    exc,
+                )
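The bulk of the new session code gives each process a single tracked job: created lazily (or fetched from DATACHAIN_JOB_ID under SaaS), marked complete via an atexit hook, and marked canceled or failed via the exception hook. A simplified, self-contained sketch of that finalization pattern — plain module globals stand in for the metastore-backed job record:

import atexit
import sys
import traceback

_job_status = "running"  # hypothetical stand-in for the persisted job status


def _finalize_success() -> None:
    # atexit hook: a job still "running" at normal interpreter exit completed.
    global _job_status
    if _job_status == "running":
        _job_status = "complete"


def _except_hook(exc_type, exc_value, exc_tb):
    # Uncaught exception: mark the job canceled (Ctrl-C) or failed, keep the
    # formatted traceback, then exit with the SIGINT convention (128 + 2).
    global _job_status
    if _job_status == "running":
        _job_status = "canceled" if exc_type is KeyboardInterrupt else "failed"
        sys.stderr.write("".join(traceback.format_exception(exc_type, exc_value, exc_tb)))
    if exc_type is KeyboardInterrupt:
        sys.exit(130)


atexit.register(_finalize_success)
sys.excepthook = _except_hook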
datachain/query/udf.py
CHANGED
@@ -1,8 +1,11 @@
-from
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any, TypedDict
 
 if TYPE_CHECKING:
     from sqlalchemy import Select, Table
 
+    from datachain.catalog import Catalog
     from datachain.query.batch import BatchingStrategy
 
 
@@ -15,6 +18,35 @@ class UdfInfo(TypedDict):
     query: "Select"
     udf_fields: list[str]
     batching: "BatchingStrategy"
-    processes:
+    processes: int | None
     is_generator: bool
     cache: bool
+    rows_total: int
+    batch_size: int
+
+
+class AbstractUDFDistributor(ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        catalog: "Catalog",
+        table: "Table",
+        query: "Select",
+        udf_data: bytes,
+        batching: "BatchingStrategy",
+        workers: bool | int,
+        processes: bool | int,
+        udf_fields: list[str],
+        rows_total: int,
+        use_cache: bool,
+        is_generator: bool = False,
+        min_task_size: str | int | None = None,
+        batch_size: int | None = None,
+    ) -> None: ...
+
+    @abstractmethod
+    def __call__(self) -> None: ...
+
+    @staticmethod
+    @abstractmethod
+    def run_udf() -> int: ...
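The new AbstractUDFDistributor defines the contract a distributed executor must satisfy. A hypothetical no-op subclass, purely to illustrate the three required members (the real abstract signature is in the diff above):

from datachain.query.udf import AbstractUDFDistributor


class NoopUDFDistributor(AbstractUDFDistributor):
    def __init__(
        self,
        catalog,
        table,
        query,
        udf_data,
        batching,
        workers,
        processes,
        udf_fields,
        rows_total,
        use_cache,
        is_generator=False,
        min_task_size=None,
        batch_size=None,
    ) -> None:
        # A real distributor would store what it needs to fan UDF batches
        # out to workers/processes according to the batching strategy.
        self.rows_total = rows_total

    def __call__(self) -> None:
        # Dispatch the distributed run and wait for completion; no-op here.
        pass

    @staticmethod
    def run_udf() -> int:
        # Worker entry point; returns the worker's exit code.
        return 0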
|