datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff compares the contents of two publicly released package versions. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- datachain/__init__.py +4 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +5 -5
- datachain/catalog/__init__.py +0 -2
- datachain/catalog/catalog.py +276 -354
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +8 -3
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +10 -17
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +42 -27
- datachain/cli/commands/ls.py +15 -15
- datachain/cli/commands/show.py +2 -2
- datachain/cli/parser/__init__.py +3 -43
- datachain/cli/parser/job.py +1 -1
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +2 -2
- datachain/client/fsspec.py +34 -23
- datachain/client/gcs.py +3 -3
- datachain/client/http.py +157 -0
- datachain/client/local.py +11 -7
- datachain/client/s3.py +3 -3
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +2 -0
- datachain/data_storage/metastore.py +716 -137
- datachain/data_storage/schema.py +20 -27
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +114 -114
- datachain/data_storage/warehouse.py +140 -48
- datachain/dataset.py +109 -89
- datachain/delta.py +117 -42
- datachain/diff/__init__.py +25 -33
- datachain/error.py +24 -0
- datachain/func/aggregate.py +9 -11
- datachain/func/array.py +12 -12
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +9 -13
- datachain/func/func.py +63 -45
- datachain/func/numeric.py +5 -7
- datachain/func/string.py +2 -2
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +18 -15
- datachain/lib/audio.py +60 -59
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/values_to_tuples.py +151 -53
- datachain/lib/data_model.py +23 -19
- datachain/lib/dataset_info.py +7 -7
- datachain/lib/dc/__init__.py +2 -1
- datachain/lib/dc/csv.py +22 -26
- datachain/lib/dc/database.py +37 -34
- datachain/lib/dc/datachain.py +518 -324
- datachain/lib/dc/datasets.py +38 -30
- datachain/lib/dc/hf.py +16 -20
- datachain/lib/dc/json.py +17 -18
- datachain/lib/dc/listings.py +5 -8
- datachain/lib/dc/pandas.py +3 -6
- datachain/lib/dc/parquet.py +33 -21
- datachain/lib/dc/records.py +9 -13
- datachain/lib/dc/storage.py +103 -65
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +17 -14
- datachain/lib/dc/values.py +3 -6
- datachain/lib/file.py +187 -50
- datachain/lib/hf.py +7 -5
- datachain/lib/image.py +13 -13
- datachain/lib/listing.py +5 -5
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +2 -3
- datachain/lib/model_store.py +20 -8
- datachain/lib/namespaces.py +59 -7
- datachain/lib/projects.py +51 -9
- datachain/lib/pytorch.py +31 -23
- datachain/lib/settings.py +188 -85
- datachain/lib/signal_schema.py +302 -64
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +103 -63
- datachain/lib/udf_signature.py +59 -34
- datachain/lib/utils.py +20 -0
- datachain/lib/video.py +3 -4
- datachain/lib/webdataset.py +31 -36
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +12 -5
- datachain/model/bbox.py +3 -1
- datachain/namespace.py +22 -3
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +4 -4
- datachain/query/batch.py +10 -12
- datachain/query/dataset.py +376 -194
- datachain/query/dispatch.py +112 -84
- datachain/query/metrics.py +3 -4
- datachain/query/params.py +2 -3
- datachain/query/queue.py +2 -1
- datachain/query/schema.py +7 -6
- datachain/query/session.py +190 -33
- datachain/query/udf.py +9 -6
- datachain/remote/studio.py +90 -53
- datachain/script_meta.py +12 -12
- datachain/sql/sqlite/base.py +37 -25
- datachain/sql/sqlite/types.py +1 -1
- datachain/sql/types.py +36 -5
- datachain/studio.py +49 -40
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +39 -48
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
- datachain-0.39.0.dist-info/RECORD +173 -0
- datachain/cli/commands/query.py +0 -54
- datachain/query/utils.py +0 -36
- datachain-0.30.5.dist-info/RECORD +0 -168
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/query/session.py
CHANGED

@@ -1,21 +1,37 @@
 import atexit
-import gc
 import logging
+import os
 import re
 import sys
-
+import traceback
+from collections.abc import Callable
+from typing import TYPE_CHECKING, ClassVar
 from uuid import uuid4
+from weakref import WeakSet

 from datachain.catalog import get_catalog
-from datachain.
+from datachain.data_storage import JobQueryType, JobStatus
+from datachain.error import JobNotFoundError, TableMissingError

 if TYPE_CHECKING:
     from datachain.catalog import Catalog
-    from datachain.
+    from datachain.job import Job

 logger = logging.getLogger("datachain")


+def is_script_run() -> bool:
+    """
+    Returns True if this was ran as python script, e.g python my_script.py.
+    Otherwise (if interactive or module run) returns False.
+    """
+    try:
+        argv0 = sys.argv[0]
+    except (IndexError, AttributeError):
+        return False
+    return bool(argv0) and argv0 not in ("-c", "-m", "ipython")
+
+
 class Session:
     """
     Session is a context that keeps track of temporary DataChain datasets for a proper
@@ -39,10 +55,18 @@ class Session:
         catalog (Catalog): Catalog object.
     """

-    GLOBAL_SESSION_CTX:
+    GLOBAL_SESSION_CTX: "Session | None" = None
     SESSION_CONTEXTS: ClassVar[list["Session"]] = []
+    _ALL_SESSIONS: ClassVar[WeakSet["Session"]] = WeakSet()
     ORIGINAL_EXCEPT_HOOK = None

+    # Job management - class-level to ensure one job per process
+    _CURRENT_JOB: ClassVar["Job | None"] = None
+    _JOB_STATUS: ClassVar[JobStatus | None] = None
+    _OWNS_JOB: ClassVar[bool | None] = None
+    _JOB_HOOKS_REGISTERED: ClassVar[bool] = False
+    _JOB_FINALIZE_HOOK: ClassVar[Callable[[], None] | None] = None
+
     DATASET_PREFIX = "session_"
     GLOBAL_SESSION_NAME = "global"
     SESSION_UUID_LEN = 6
@@ -51,8 +75,8 @@ class Session:
     def __init__(
         self,
         name="",
-        catalog:
-        client_config:
+        catalog: "Catalog | None" = None,
+        client_config: dict | None = None,
         in_memory: bool = False,
     ):
         if re.match(r"^[0-9a-zA-Z]*$", name) is None:
@@ -69,7 +93,7 @@ class Session:
         self.catalog = catalog or get_catalog(
             client_config=client_config, in_memory=in_memory
         )
-        self
+        Session._ALL_SESSIONS.add(self)

     def __enter__(self):
         # Push the current context onto the stack
@@ -78,9 +102,8 @@ class Session:
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
-
-
-
+        # Don't cleanup created versions on exception
+        # Datasets should persist even if the session fails
         self._cleanup_temp_datasets()
         if self.is_new_catalog:
             self.catalog.metastore.close_on_exit()
@@ -88,11 +111,116 @@ class Session:

         if Session.SESSION_CONTEXTS:
             Session.SESSION_CONTEXTS.pop()
+        Session._ALL_SESSIONS.discard(self)

-    def
-
-
-
+    def get_or_create_job(self) -> "Job":
+        """
+        Get or create a Job for this process.
+
+        Returns:
+            Job: The active Job instance.
+
+        Behavior:
+        - If a job already exists, it is returned.
+        - If ``DATACHAIN_JOB_ID`` is set, the corresponding job is fetched.
+        - Otherwise, a new job is created:
+            * Name = absolute path to the Python script.
+            * Query = empty string.
+            * Parent = last job with the same name, if available.
+            * Status = "running".
+          Exit hooks are registered to finalize the job.
+
+        Note:
+            Job is shared across all Session instances to ensure one job per process.
+        """
+        if Session._CURRENT_JOB:
+            return Session._CURRENT_JOB
+
+        if env_job_id := os.getenv("DATACHAIN_JOB_ID"):
+            # SaaS run: just fetch existing job
+            Session._CURRENT_JOB = self.catalog.metastore.get_job(env_job_id)
+            if not Session._CURRENT_JOB:
+                raise JobNotFoundError(
+                    f"Job {env_job_id} from DATACHAIN_JOB_ID env not found"
+                )
+            Session._OWNS_JOB = False
+        else:
+            # Local run: create new job
+            if is_script_run():
+                script = os.path.abspath(sys.argv[0])
+            else:
+                # Interactive session or module run - use unique name to avoid
+                # linking unrelated sessions
+                script = str(uuid4())
+            python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
+
+            # try to find the parent job
+            parent = self.catalog.metastore.get_last_job_by_name(script)
+
+            job_id = self.catalog.metastore.create_job(
+                name=script,
+                query="",
+                query_type=JobQueryType.PYTHON,
+                status=JobStatus.RUNNING,
+                python_version=python_version,
+                parent_job_id=parent.id if parent else None,
+            )
+            Session._CURRENT_JOB = self.catalog.metastore.get_job(job_id)
+            Session._OWNS_JOB = True
+        Session._JOB_STATUS = JobStatus.RUNNING
+
+        # register cleanup hooks only once
+        if not Session._JOB_HOOKS_REGISTERED:
+
+            def _finalize_success_hook() -> None:
+                self._finalize_job_success()
+
+            Session._JOB_FINALIZE_HOOK = _finalize_success_hook
+            atexit.register(Session._JOB_FINALIZE_HOOK)
+            Session._JOB_HOOKS_REGISTERED = True
+
+        assert Session._CURRENT_JOB is not None
+        return Session._CURRENT_JOB
+
+    def _finalize_job_success(self):
+        """Mark the current job as completed."""
+        if (
+            Session._CURRENT_JOB
+            and Session._OWNS_JOB
+            and Session._JOB_STATUS == JobStatus.RUNNING
+        ):
+            self.catalog.metastore.set_job_status(
+                Session._CURRENT_JOB.id, JobStatus.COMPLETE
+            )
+            Session._JOB_STATUS = JobStatus.COMPLETE
+
+    def _finalize_job_as_canceled(self):
+        """Mark the current job as canceled."""
+        if (
+            Session._CURRENT_JOB
+            and Session._OWNS_JOB
+            and Session._JOB_STATUS == JobStatus.RUNNING
+        ):
+            self.catalog.metastore.set_job_status(
+                Session._CURRENT_JOB.id, JobStatus.CANCELED
+            )
+            Session._JOB_STATUS = JobStatus.CANCELED
+
+    def _finalize_job_as_failed(self, exc_type, exc_value, tb):
+        """Mark the current job as failed with error details."""
+        if (
+            Session._CURRENT_JOB
+            and Session._OWNS_JOB
+            and Session._JOB_STATUS == JobStatus.RUNNING
+        ):
+            error_stack = "".join(traceback.format_exception(exc_type, exc_value, tb))
+            self.catalog.metastore.set_job_status(
+                Session._CURRENT_JOB.id,
+                JobStatus.FAILED,
+                error_message=str(exc_value),
+                error_stack=error_stack,
+            )
+            Session._JOB_STATUS = JobStatus.FAILED

     def generate_temp_dataset_name(self) -> str:
         return self.get_temp_prefix() + uuid4().hex[: self.TEMP_TABLE_UUID_LEN]
@@ -113,22 +241,12 @@ class Session:
         except TableMissingError:
             pass

-    def _cleanup_created_versions(self) -> None:
-        if not self.dataset_versions:
-            return
-
-        for dataset, version, listing in self.dataset_versions:
-            if not listing:
-                self.catalog.remove_dataset_version(dataset, version)
-
-        self.dataset_versions.clear()
-
     @classmethod
     def get(
         cls,
-        session:
-        catalog:
-        client_config:
+        session: "Session | None" = None,
+        catalog: "Catalog | None" = None,
+        client_config: dict | None = None,
         in_memory: bool = False,
     ) -> "Session":
         """Creates a Session() object from a catalog.
@@ -173,33 +291,72 @@ class Session:

     @staticmethod
     def except_hook(exc_type, exc_value, exc_traceback):
-        Session.GLOBAL_SESSION_CTX
+        if Session.GLOBAL_SESSION_CTX:
+            # Handle KeyboardInterrupt specially - mark as canceled and exit with
+            # signal code
+            if exc_type is KeyboardInterrupt:
+                Session.GLOBAL_SESSION_CTX._finalize_job_as_canceled()
+            else:
+                Session.GLOBAL_SESSION_CTX._finalize_job_as_failed(
+                    exc_type, exc_value, exc_traceback
+                )
+            Session.GLOBAL_SESSION_CTX.__exit__(exc_type, exc_value, exc_traceback)
+
         Session._global_cleanup()

+        # Always delegate to original hook if it exists
         if Session.ORIGINAL_EXCEPT_HOOK:
             Session.ORIGINAL_EXCEPT_HOOK(exc_type, exc_value, exc_traceback)

+        if exc_type is KeyboardInterrupt:
+            # Exit with SIGINT signal code (128 + 2 = 130, or -2 in subprocess terms)
+            sys.exit(130)
+
     @classmethod
     def cleanup_for_tests(cls):
+        cls._close_all_contexts()
         if cls.GLOBAL_SESSION_CTX is not None:
             cls.GLOBAL_SESSION_CTX.__exit__(None, None, None)
             cls.GLOBAL_SESSION_CTX = None
             atexit.unregister(cls._global_cleanup)

+        # Reset job-related class variables
+        if cls._JOB_FINALIZE_HOOK:
+            try:
+                atexit.unregister(cls._JOB_FINALIZE_HOOK)
+            except ValueError:
+                pass  # Hook was not registered
+        cls._CURRENT_JOB = None
+        cls._JOB_STATUS = None
+        cls._OWNS_JOB = None
+        cls._JOB_HOOKS_REGISTERED = False
+        cls._JOB_FINALIZE_HOOK = None
+
         if cls.ORIGINAL_EXCEPT_HOOK:
             sys.excepthook = cls.ORIGINAL_EXCEPT_HOOK

     @staticmethod
     def _global_cleanup():
+        Session._close_all_contexts()
         if Session.GLOBAL_SESSION_CTX is not None:
             Session.GLOBAL_SESSION_CTX.__exit__(None, None, None)

-        for
+        for session in list(Session._ALL_SESSIONS):
             try:
-
-                # Cleanup temp dataset for session variables.
-                obj.__exit__(None, None, None)
+                session.__exit__(None, None, None)
             except ReferenceError:
                 continue  # Object has been finalized already
             except Exception as e:  # noqa: BLE001
                 logger.error(f"Exception while cleaning up session: {e}")  # noqa: G004
+
+    @classmethod
+    def _close_all_contexts(cls) -> None:
+        while cls.SESSION_CONTEXTS:
+            session = cls.SESSION_CONTEXTS.pop()
+            try:
+                session.__exit__(None, None, None)
+            except Exception as exc:  # noqa: BLE001
+                logger.error(
+                    "Exception while closing session context during cleanup: %s",
+                    exc,
+                )
datachain/query/udf.py
CHANGED

@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
-from
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any, TypedDict

 if TYPE_CHECKING:
     from sqlalchemy import Select, Table
@@ -17,10 +18,11 @@ class UdfInfo(TypedDict):
     query: "Select"
     udf_fields: list[str]
     batching: "BatchingStrategy"
-    processes:
+    processes: int | None
     is_generator: bool
     cache: bool
     rows_total: int
+    batch_size: int


 class AbstractUDFDistributor(ABC):
@@ -32,13 +34,14 @@ class AbstractUDFDistributor(ABC):
         query: "Select",
         udf_data: bytes,
         batching: "BatchingStrategy",
-        workers:
-        processes:
+        workers: bool | int,
+        processes: bool | int,
         udf_fields: list[str],
         rows_total: int,
         use_cache: bool,
         is_generator: bool = False,
-        min_task_size:
+        min_task_size: str | int | None = None,
+        batch_size: int | None = None,
     ) -> None: ...

     @abstractmethod
@@ -46,4 +49,4 @@ class AbstractUDFDistributor(ABC):

     @staticmethod
     @abstractmethod
-    def run_udf(
+    def run_udf() -> int: ...
datachain/remote/studio.py
CHANGED

@@ -1,16 +1,10 @@
-import base64
 import json
 import logging
 import os
 from collections.abc import AsyncIterator, Iterable, Iterator
 from datetime import datetime, timedelta, timezone
 from struct import unpack
-from typing import
-    Any,
-    Generic,
-    Optional,
-    TypeVar,
-)
+from typing import Any, BinaryIO, Generic, TypeVar
 from urllib.parse import urlparse, urlunparse

 import websockets
@@ -22,16 +16,17 @@ from datachain.error import DataChainError
 from datachain.utils import STUDIO_URL, retry_with_backoff

 T = TypeVar("T")
-LsData =
-DatasetInfoData =
-DatasetRowsData =
-DatasetJobVersionsData =
-DatasetExportStatus =
-DatasetExportSignedUrls =
-FileUploadData =
-JobData =
-JobListData = dict[str, Any]
-ClusterListData = dict[str, Any]
+LsData = list[dict[str, Any]] | None
+DatasetInfoData = dict[str, Any] | None
+DatasetRowsData = Iterable[dict[str, Any]] | None
+DatasetJobVersionsData = dict[str, Any] | None
+DatasetExportStatus = dict[str, Any] | None
+DatasetExportSignedUrls = list[str] | None
+FileUploadData = dict[str, Any] | None
+JobData = dict[str, Any] | None
+JobListData = list[dict[str, Any]]
+ClusterListData = list[dict[str, Any]]
+
 logger = logging.getLogger("datachain")

 DATASET_ROWS_CHUNK_SIZE = 8192
@@ -92,7 +87,7 @@ class Response(Generic[T]):


 class StudioClient:
-    def __init__(self, timeout: float = 3600.0, team:
+    def __init__(self, timeout: float = 3600.0, team: str | None = None) -> None:
         self._check_dependencies()
         self.timeout = timeout
         self._config = None
@@ -153,7 +148,7 @@ class StudioClient:
         ) from None

     def _send_request_msgpack(
-        self, route: str, data: dict[str, Any], method:
+        self, route: str, data: dict[str, Any], method: str | None = "POST"
     ) -> Response[Any]:
         import msgpack
         import requests
@@ -191,7 +186,7 @@ class StudioClient:

     @retry_with_backoff(retries=3, errors=(HTTPError, Timeout))
     def _send_request(
-        self, route: str, data: dict[str, Any], method:
+        self, route: str, data: dict[str, Any], method: str | None = "POST"
     ) -> Response[Any]:
         """
         Function that communicate Studio API.
@@ -239,6 +234,45 @@ class StudioClient:

         return Response(data, ok, message, response.status_code)

+    def _send_multipart_request(
+        self, route: str, files: dict[str, Any], params: dict[str, Any] | None = None
+    ) -> Response[Any]:
+        """
+        Function that communicates with Studio API using multipart/form-data.
+        It will raise an exception, and try to retry, if 5xx status code is
+        returned, or if Timeout exceptions is thrown from the requests lib
+        """
+        import requests
+
+        # Add team_name to params
+        request_params = {**(params or {}), "team_name": self.team}
+
+        response = requests.post(
+            url=f"{self.url}/{route}",
+            files=files,
+            params=request_params,
+            headers={
+                "Authorization": f"token {self.token}",
+            },
+            timeout=self.timeout,
+        )
+
+        ok = response.ok
+        try:
+            data = json.loads(response.content.decode("utf-8"))
+        except json.decoder.JSONDecodeError:
+            data = {}
+
+        if not ok:
+            if response.status_code == 403:
+                message = f"Not authorized for the team {self.team}"
+            else:
+                message = data.get("message", "")
+        else:
+            message = ""
+
+        return Response(data, ok, message, response.status_code)
+
     @staticmethod
     def _unpacker_hook(code, data):
         import msgpack
@@ -305,7 +339,7 @@ class StudioClient:
             response = self._send_request_msgpack("datachain/ls", {"source": path})
             yield path, response

-    def ls_datasets(self, prefix:
+    def ls_datasets(self, prefix: str | None = None) -> Response[LsData]:
         return self._send_request(
             "datachain/datasets", {"prefix": prefix}, method="GET"
         )
@@ -315,9 +349,9 @@ class StudioClient:
         name: str,
         namespace: str,
         project: str,
-        new_name:
-        description:
-        attrs:
+        new_name: str | None = None,
+        description: str | None = None,
+        attrs: list[str] | None = None,
     ) -> Response[DatasetInfoData]:
         body = {
             "new_name": new_name,
@@ -338,8 +372,8 @@ class StudioClient:
         name: str,
         namespace: str,
         project: str,
-        version:
-        force:
+        version: str | None = None,
+        force: bool | None = False,
     ) -> Response[DatasetInfoData]:
         return self._send_request(
             "datachain/datasets",
@@ -409,29 +443,30 @@ class StudioClient:
             method="GET",
         )

-    def upload_file(
-
-
-
-        }
-
+    def upload_file(
+        self, file_obj: BinaryIO, file_name: str
+    ) -> Response[FileUploadData]:
+        # Prepare multipart form data
+        files = {"file": (file_name, file_obj, "application/octet-stream")}
+
+        return self._send_multipart_request("datachain/jobs/files", files)

     def create_job(
         self,
         query: str,
         query_type: str,
-        environment:
-        workers:
-        query_name:
-        files:
-        python_version:
-        requirements:
-        repository:
-        priority:
-        cluster:
-        start_time:
-        cron:
-        credentials_name:
+        environment: str | None = None,
+        workers: int | None = None,
+        query_name: str | None = None,
+        files: list[str] | None = None,
+        python_version: str | None = None,
+        requirements: str | None = None,
+        repository: str | None = None,
+        priority: int | None = None,
+        cluster: str | None = None,
+        start_time: str | None = None,
+        cron: str | None = None,
+        credentials_name: str | None = None,
     ) -> Response[JobData]:
         data = {
             "query": query,
@@ -449,25 +484,27 @@ class StudioClient:
             "cron_expression": cron,
             "credentials_name": credentials_name,
         }
-        return self._send_request("datachain/
+        return self._send_request("datachain/jobs/", data)

     def get_jobs(
         self,
-        status:
+        status: str | None = None,
         limit: int = 20,
+        job_id: str | None = None,
     ) -> Response[JobListData]:
-
-
-
-
-
+        params: dict[str, Any] = {"limit": limit}
+        if status is not None:
+            params["status"] = status
+        if job_id is not None:
+            params["job_id"] = job_id
+        return self._send_request("datachain/jobs/", params, method="GET")

     def cancel_job(
         self,
         job_id: str,
     ) -> Response[JobData]:
-        url = f"datachain/
+        url = f"datachain/jobs/{job_id}/cancel"
         return self._send_request(url, data={}, method="POST")

     def get_clusters(self) -> Response[ClusterListData]:
-        return self._send_request("datachain/clusters", {}, method="GET")
+        return self._send_request("datachain/clusters/", {}, method="GET")
datachain/script_meta.py
CHANGED

@@ -1,6 +1,6 @@
 import re
 from dataclasses import dataclass
-from typing import Any
+from typing import Any

 try:
     import tomllib
@@ -59,23 +59,23 @@ class ScriptConfig:

     """

-    python_version:
+    python_version: str | None
     dependencies: list[str]
     attachments: dict[str, str]
     params: dict[str, Any]
     inputs: dict[str, Any]
     outputs: dict[str, Any]
-    num_workers:
+    num_workers: int | None = None

     def __init__(
         self,
-        python_version:
-        dependencies:
-        attachments:
-        params:
-        inputs:
-        outputs:
-        num_workers:
+        python_version: str | None = None,
+        dependencies: list[str] | None = None,
+        attachments: dict[str, str] | None = None,
+        params: dict[str, Any] | None = None,
+        inputs: dict[str, Any] | None = None,
+        outputs: dict[str, Any] | None = None,
+        num_workers: int | None = None,
     ):
         self.python_version = python_version
         self.dependencies = dependencies or []
@@ -98,7 +98,7 @@ class ScriptConfig:
         return self.attachments.get(name, default)

     @staticmethod
-    def read(script: str) ->
+    def read(script: str) -> dict | None:
         """Converts inline script metadata to dict with all found data"""
         regex = (
             r"(?m)^# \/\/\/ (?P<type>[a-zA-Z0-9-]+)[ \t]*$[\r\n|\r|\n]"
@@ -119,7 +119,7 @@ class ScriptConfig:
             return None

     @staticmethod
-    def parse(script: str) ->
+    def parse(script: str) -> "ScriptConfig | None":
         """
         Method that is parsing inline script metadata from datachain script and
         instantiating ScriptConfig class with found data. If no inline metadata is