datachain 0.7.11__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain was flagged as possibly problematic by the registry.
- datachain/catalog/catalog.py +56 -45
- datachain/cli.py +25 -3
- datachain/client/gcs.py +9 -0
- datachain/data_storage/sqlite.py +20 -6
- datachain/data_storage/warehouse.py +0 -1
- datachain/lib/arrow.py +82 -58
- datachain/lib/dc.py +167 -166
- datachain/lib/diff.py +197 -0
- datachain/lib/file.py +3 -1
- datachain/lib/listing.py +44 -0
- datachain/lib/meta_formats.py +38 -42
- datachain/lib/udf.py +0 -1
- datachain/query/batch.py +32 -6
- datachain/query/dataset.py +18 -17
- datachain/query/dispatch.py +125 -125
- datachain/query/session.py +8 -5
- datachain/query/udf.py +20 -0
- datachain/query/utils.py +42 -0
- datachain/remote/studio.py +53 -1
- datachain/studio.py +47 -2
- datachain/utils.py +1 -1
- {datachain-0.7.11.dist-info → datachain-0.8.1.dist-info}/METADATA +4 -3
- {datachain-0.7.11.dist-info → datachain-0.8.1.dist-info}/RECORD +27 -24
- {datachain-0.7.11.dist-info → datachain-0.8.1.dist-info}/LICENSE +0 -0
- {datachain-0.7.11.dist-info → datachain-0.8.1.dist-info}/WHEEL +0 -0
- {datachain-0.7.11.dist-info → datachain-0.8.1.dist-info}/entry_points.txt +0 -0
- {datachain-0.7.11.dist-info → datachain-0.8.1.dist-info}/top_level.txt +0 -0
datachain/query/dispatch.py
CHANGED
@@ -1,34 +1,37 @@
 import contextlib
-from collections.abc import
+from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
 from sys import stdin
-from
+from threading import Timer
+from typing import TYPE_CHECKING, Optional
 
 import attrs
 import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from multiprocess import get_context
+from sqlalchemy.sql import func
 
 from datachain.catalog import Catalog
 from datachain.catalog.loader import get_distributed_class
-from datachain.
+from datachain.query.batch import RowsOutput, RowsOutputBatch
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
     get_processed_callback,
     process_udf_outputs,
 )
-from datachain.query.queue import
-
-
-
-
-
-
-
-from datachain.
+from datachain.query.queue import get_from_queue, put_into_queue
+from datachain.query.udf import UdfInfo
+from datachain.query.utils import get_query_id_column
+from datachain.utils import batched, flatten
+
+if TYPE_CHECKING:
+    from sqlalchemy import Select, Table
+
+    from datachain.data_storage import AbstractMetastore, AbstractWarehouse
+    from datachain.lib.udf import UDFAdapter
 
 DEFAULT_BATCH_SIZE = 10000
 STOP_SIGNAL = "STOP"
@@ -38,10 +41,6 @@ FAILED_STATUS = "FAILED"
 NOTIFY_STATUS = "NOTIFY"
 
 
-def full_module_type_path(typ: type) -> str:
-    return f"{typ.__module__}.{typ.__qualname__}"
-
-
 def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:
     if not n_workers:
         return cpu_count()
@@ -52,55 +51,42 @@ def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:
 
 def udf_entrypoint() -> int:
     # Load UDF info from stdin
-    udf_info = load(stdin.buffer)
-
-    (
-        warehouse_class,
-        warehouse_args,
-        warehouse_kwargs,
-    ) = udf_info["warehouse_clone_params"]
-    warehouse = warehouse_class(*warehouse_args, **warehouse_kwargs)
+    udf_info: UdfInfo = load(stdin.buffer)
 
     # Parallel processing (faster for more CPU-heavy UDFs)
-    dispatch = UDFDispatcher(
-        udf_info["udf_data"],
-        udf_info["catalog_init"],
-        udf_info["metastore_clone_params"],
-        udf_info["warehouse_clone_params"],
-        udf_fields=udf_info["udf_fields"],
-        cache=udf_info["cache"],
-        is_generator=udf_info.get("is_generator", False),
-    )
+    dispatch = UDFDispatcher(udf_info)
 
     query = udf_info["query"]
     batching = udf_info["batching"]
-    table = udf_info["table"]
     n_workers = udf_info["processes"]
-    udf = loads(udf_info["udf_data"])
     if n_workers is True:
-        # Use default number of CPUs (cores)
-
+        n_workers = None  # Use default number of CPUs (cores)
+
+    wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"]
+    warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs)
+
+    total_rows = next(
+        warehouse.db.execute(
+            query.with_only_columns(func.count(query.c.sys__id)).order_by(None)
+        )
+    )[0]
 
     with contextlib.closing(
-        batching(warehouse.dataset_select_paginated, query)
+        batching(warehouse.dataset_select_paginated, query, ids_only=True)
     ) as udf_inputs:
         download_cb = get_download_callback()
         processed_cb = get_processed_callback()
-        generated_cb = get_generated_callback(dispatch.is_generator)
         try:
-
-
+            dispatch.run_udf_parallel(
+                udf_inputs,
+                total_rows=total_rows,
                 n_workers=n_workers,
                 processed_cb=processed_cb,
                 download_cb=download_cb,
             )
-            process_udf_outputs(warehouse, table, udf_results, udf, cb=generated_cb)
         finally:
             download_cb.close()
             processed_cb.close()
-            generated_cb.close()
-
-    warehouse.insert_rows_done(table)
 
     return 0
 
@@ -114,32 +100,17 @@ class UDFDispatcher:
     task_queue: Optional[multiprocess.Queue] = None
     done_queue: Optional[multiprocess.Queue] = None
 
-    def __init__(
-        self
-
-
-
-
-
-
-
-
-
-        self.udf_data = udf_data
-        self.catalog_init_params = catalog_init_params
-        (
-            self.metastore_class,
-            self.metastore_args,
-            self.metastore_kwargs,
-        ) = metastore_clone_params
-        (
-            self.warehouse_class,
-            self.warehouse_args,
-            self.warehouse_kwargs,
-        ) = warehouse_clone_params
-        self.udf_fields = udf_fields
-        self.cache = cache
-        self.is_generator = is_generator
+    def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
+        self.udf_data = udf_info["udf_data"]
+        self.catalog_init_params = udf_info["catalog_init"]
+        self.metastore_clone_params = udf_info["metastore_clone_params"]
+        self.warehouse_clone_params = udf_info["warehouse_clone_params"]
+        self.query = udf_info["query"]
+        self.table = udf_info["table"]
+        self.udf_fields = udf_info["udf_fields"]
+        self.cache = udf_info["cache"]
+        self.is_generator = udf_info["is_generator"]
+        self.is_batching = udf_info["batching"].is_batching
         self.buffer_size = buffer_size
         self.catalog = None
         self.task_queue = None
@@ -148,12 +119,10 @@ class UDFDispatcher:
 
     def _create_worker(self) -> "UDFWorker":
        if not self.catalog:
-
-
-
-            warehouse =
-                *self.warehouse_args, **self.warehouse_kwargs
-            )
+            ms_cls, ms_args, ms_kwargs = self.metastore_clone_params
+            metastore: AbstractMetastore = ms_cls(*ms_args, **ms_kwargs)
+            ws_cls, ws_args, ws_kwargs = self.warehouse_clone_params
+            warehouse: AbstractWarehouse = ws_cls(*ws_args, **ws_kwargs)
             self.catalog = Catalog(metastore, warehouse, **self.catalog_init_params)
         self.udf = loads(self.udf_data)
         return UDFWorker(
@@ -161,7 +130,10 @@ class UDFDispatcher:
             self.udf,
             self.task_queue,
             self.done_queue,
+            self.query,
+            self.table,
             self.is_generator,
+            self.is_batching,
             self.cache,
             self.udf_fields,
         )
@@ -189,26 +161,27 @@ class UDFDispatcher:
 
     def run_udf_parallel(  # noqa: C901, PLR0912
         self,
-        input_rows,
+        input_rows: Iterable[RowsOutput],
+        total_rows: int,
         n_workers: Optional[int] = None,
-        input_queue=None,
         processed_cb: Callback = DEFAULT_CALLBACK,
         download_cb: Callback = DEFAULT_CALLBACK,
-    ) ->
+    ) -> None:
         n_workers = get_n_workers_from_arg(n_workers)
 
+        input_batch_size = total_rows // n_workers
+        if input_batch_size == 0:
+            input_batch_size = 1
+        elif input_batch_size > DEFAULT_BATCH_SIZE:
+            input_batch_size = DEFAULT_BATCH_SIZE
+
         if self.buffer_size < n_workers:
             raise RuntimeError(
                 "Parallel run error: buffer size is smaller than "
                 f"number of workers: {self.buffer_size} < {n_workers}"
             )
 
-
-            streaming_mode = True
-            self.task_queue = input_queue
-        else:
-            streaming_mode = False
-            self.task_queue = self.ctx.Queue()
+        self.task_queue = self.ctx.Queue()
         self.done_queue = self.ctx.Queue()
         pool = [
             self.ctx.Process(name=f"Worker-UDF-{i}", target=self._run_worker)
@@ -223,41 +196,41 @@ class UDFDispatcher:
         # Will be set to True when the input is exhausted
         input_finished = False
 
-        if not
-
-            input_data = chain(input_rows, [STOP_SIGNAL] * n_workers)
+        if not self.is_batching:
+            input_rows = batched(flatten(input_rows), input_batch_size)
 
-
-
-
-
-
-
-
+        # Stop all workers after the input rows have finished processing
+        input_data = chain(input_rows, [STOP_SIGNAL] * n_workers)
+
+        # Add initial buffer of tasks
+        for _ in range(self.buffer_size):
+            try:
+                put_into_queue(self.task_queue, next(input_data))
+            except StopIteration:
+                input_finished = True
+                break
 
         # Process all tasks
         while n_workers > 0:
             result = get_from_queue(self.done_queue)
+
+            if downloaded := result.get("downloaded"):
+                download_cb.relative_update(downloaded)
+            if processed := result.get("processed"):
+                processed_cb.relative_update(processed)
+
             status = result["status"]
-            if status
-
-                download_cb.relative_update(downloaded)
-            if processed := result.get("processed"):
-                processed_cb.relative_update(processed)
+            if status in (OK_STATUS, NOTIFY_STATUS):
+                pass  # Do nothing here
             elif status == FINISHED_STATUS:
-                # Worker finished
-                n_workers -= 1
-            elif status == OK_STATUS:
-                if processed := result.get("processed"):
-                    processed_cb.relative_update(processed)
-                yield msgpack_unpack(result["result"])
+                n_workers -= 1  # Worker finished
             else:  # Failed / error
                 n_workers -= 1
                 if exc := result.get("exception"):
                     raise exc
                 raise RuntimeError("Internal error: Parallel UDF execution failed")
 
-            if status == OK_STATUS and not
+            if status == OK_STATUS and not input_finished:
                 try:
                     put_into_queue(self.task_queue, next(input_data))
                 except StopIteration:
@@ -311,11 +284,14 @@ class ProcessedCallback(Callback):
 
 @attrs.define
 class UDFWorker:
-    catalog: Catalog
-    udf: UDFAdapter
+    catalog: "Catalog"
+    udf: "UDFAdapter"
     task_queue: "multiprocess.Queue"
     done_queue: "multiprocess.Queue"
+    query: "Select"
+    table: "Table"
     is_generator: bool
+    is_batching: bool
     cache: bool
     udf_fields: Sequence[str]
     cb: Callback = attrs.field()
@@ -326,30 +302,54 @@ class UDFWorker:
 
     def run(self) -> None:
         processed_cb = ProcessedCallback()
+        generated_cb = get_generated_callback(self.is_generator)
+
         udf_results = self.udf.run(
             self.udf_fields,
-
+            self.get_inputs(),
             self.catalog,
-            self.is_generator,
             self.cache,
             download_cb=self.cb,
             processed_cb=processed_cb,
         )
-
-
-
-
-
-
-
-
-
+        process_udf_outputs(
+            self.catalog.warehouse,
+            self.table,
+            self.notify_and_process(udf_results, processed_cb),
+            self.udf,
+            cb=generated_cb,
+        )
+
+        put_into_queue(
+            self.done_queue,
+            {"status": FINISHED_STATUS, "processed": processed_cb.processed_rows},
+        )
+
+    def notify_and_process(self, udf_results, processed_cb):
+        for row in udf_results:
             put_into_queue(
                 self.done_queue,
-                {"status":
+                {"status": OK_STATUS, "processed": processed_cb.processed_rows},
             )
-
+            yield row
 
     def get_inputs(self):
-
-
+        warehouse = self.catalog.warehouse.clone()
+        col_id = get_query_id_column(self.query)
+
+        if self.is_batching:
+            while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+                ids = [row[0] for row in batch.rows]
+                rows = warehouse.dataset_rows_select(self.query.where(col_id.in_(ids)))
+                yield RowsOutputBatch(list(rows))
+        else:
+            while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+                yield from warehouse.dataset_rows_select(
+                    self.query.where(col_id.in_(batch))
+                )


+class RepeatTimer(Timer):
+    def run(self):
+        while not self.finished.wait(self.interval):
+            self.function(*self.args, **self.kwargs)
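For context on the diff above: run_udf_parallel now takes a row count up front, derives a per-worker batch size clamped to DEFAULT_BATCH_SIZE, and (when the batching strategy is not already batching) re-batches the flattened stream of sys__id values before queueing work. A minimal, self-contained sketch of that sizing logic; batched and flatten here are simplified stand-ins for datachain.utils.batched / flatten, and input_batches is a hypothetical helper written for illustration, not part of the package:

from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")
DEFAULT_BATCH_SIZE = 10000  # same constant as in dispatch.py


def flatten(nested: Iterable[Iterable[T]]) -> Iterator[T]:
    # Simplified stand-in for datachain.utils.flatten.
    for inner in nested:
        yield from inner


def batched(items: Iterable[T], size: int) -> Iterator[tuple[T, ...]]:
    # Simplified stand-in for datachain.utils.batched.
    it = iter(items)
    while batch := tuple(islice(it, size)):
        yield batch


def input_batches(id_pages, total_rows: int, n_workers: int):
    # Hypothetical helper mirroring run_udf_parallel's sizing:
    # at least 1 id per batch, at most DEFAULT_BATCH_SIZE.
    size = max(1, min(total_rows // n_workers, DEFAULT_BATCH_SIZE))
    return batched(flatten(id_pages), size)


# 10 ids paginated in pages of 4, split across 3 workers -> batches of 3 ids.
pages = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]]
print([list(b) for b in input_batches(pages, total_rows=10, n_workers=3)])
# [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]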
datachain/query/session.py
CHANGED
@@ -69,7 +69,7 @@ class Session:
         self.catalog = catalog or get_catalog(
             client_config=client_config, in_memory=in_memory
         )
-        self.dataset_versions: list[tuple[DatasetRecord, int]] = []
+        self.dataset_versions: list[tuple[DatasetRecord, int, bool]] = []
 
     def __enter__(self):
         # Push the current context onto the stack
@@ -89,8 +89,10 @@
         if Session.SESSION_CONTEXTS:
             Session.SESSION_CONTEXTS.pop()
 
-    def add_dataset_version(
-        self
+    def add_dataset_version(
+        self, dataset: "DatasetRecord", version: int, listing: bool = False
+    ) -> None:
+        self.dataset_versions.append((dataset, version, listing))
 
     def generate_temp_dataset_name(self) -> str:
         return self.get_temp_prefix() + uuid4().hex[: self.TEMP_TABLE_UUID_LEN]
@@ -111,8 +113,9 @@
         if not self.dataset_versions:
             return
 
-        for dataset, version in self.dataset_versions:
-
+        for dataset, version, listing in self.dataset_versions:
+            if not listing:
+                self.catalog.remove_dataset_version(dataset, version)
 
         self.dataset_versions.clear()
 
datachain/query/udf.py
ADDED
@@ -0,0 +1,20 @@
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict
+
+if TYPE_CHECKING:
+    from sqlalchemy import Select, Table
+
+    from datachain.query.batch import BatchingStrategy
+
+
+class UdfInfo(TypedDict):
+    udf_data: bytes
+    catalog_init: dict[str, Any]
+    metastore_clone_params: tuple[Callable[..., Any], list[Any], dict[str, Any]]
+    warehouse_clone_params: tuple[Callable[..., Any], list[Any], dict[str, Any]]
+    table: "Table"
+    query: "Select"
+    udf_fields: list[str]
+    batching: "BatchingStrategy"
+    processes: Optional[int]
+    is_generator: bool
+    cache: bool
datachain/query/utils.py
ADDED
@@ -0,0 +1,42 @@
+from typing import TYPE_CHECKING, Optional, Union
+
+from sqlalchemy import Column
+
+if TYPE_CHECKING:
+    from sqlalchemy import ColumnElement, Select, TextClause
+
+
+ColT = Union[Column, "ColumnElement", "TextClause"]
+
+
+def column_name(col: ColT) -> str:
+    """Returns column name from column element."""
+    return col.name if isinstance(col, Column) else str(col)
+
+
+def get_query_column(query: "Select", name: str) -> Optional[ColT]:
+    """Returns column element from query by name or None if column not found."""
+    return next((col for col in query.inner_columns if column_name(col) == name), None)
+
+
+def get_query_id_column(query: "Select") -> ColT:
+    """Returns ID column element from query or None if column not found."""
+    col = get_query_column(query, "sys__id")
+    if col is None:
+        raise RuntimeError("sys__id column not found in query")
+    return col
+
+
+def select_only_columns(query: "Select", *names: str) -> "Select":
+    """Returns query selecting defined columns only."""
+    if not names:
+        return query
+
+    cols: list[ColT] = []
+    for name in names:
+        col = get_query_column(query, name)
+        if col is None:
+            raise ValueError(f"Column '{name}' not found in query")
+        cols.append(col)
+
+    return query.with_only_columns(*cols)
datachain/remote/studio.py
CHANGED
@@ -2,7 +2,7 @@ import base64
 import json
 import logging
 import os
-from collections.abc import Iterable, Iterator
+from collections.abc import AsyncIterator, Iterable, Iterator
 from datetime import datetime, timedelta, timezone
 from struct import unpack
 from typing import (
@@ -11,6 +11,9 @@ from typing import (
     Optional,
     TypeVar,
 )
+from urllib.parse import urlparse, urlunparse
+
+import websockets
 
 from datachain.config import Config
 from datachain.dataset import DatasetStats
@@ -22,6 +25,7 @@ LsData = Optional[list[dict[str, Any]]]
 DatasetInfoData = Optional[dict[str, Any]]
 DatasetStatsData = Optional[DatasetStats]
 DatasetRowsData = Optional[Iterable[dict[str, Any]]]
+DatasetJobVersionsData = Optional[dict[str, Any]]
 DatasetExportStatus = Optional[dict[str, Any]]
 DatasetExportSignedUrls = Optional[list[str]]
 FileUploadData = Optional[dict[str, Any]]
@@ -231,6 +235,40 @@ class StudioClient:
 
         return msgpack.ExtType(code, data)
 
+    async def tail_job_logs(self, job_id: str) -> AsyncIterator[dict]:
+        """
+        Follow job logs via websocket connection.
+
+        Args:
+            job_id: ID of the job to follow logs for
+
+        Yields:
+            Dict containing either job status updates or log messages
+        """
+        parsed_url = urlparse(self.url)
+        ws_url = urlunparse(
+            parsed_url._replace(scheme="wss" if parsed_url.scheme == "https" else "ws")
+        )
+        ws_url = f"{ws_url}/logs/follow/?job_id={job_id}&team_name={self.team}"
+
+        async with websockets.connect(
+            ws_url,
+            additional_headers={"Authorization": f"token {self.token}"},
+        ) as websocket:
+            while True:
+                try:
+                    message = await websocket.recv()
+                    data = json.loads(message)
+
+                    # Yield the parsed message data
+                    yield data
+
+                except websockets.exceptions.ConnectionClosed:
+                    break
+                except Exception as e:  # noqa: BLE001
+                    logger.error("Error receiving websocket message: %s", e)
+                    break
+
     def ls(self, paths: Iterable[str]) -> Iterator[tuple[str, Response[LsData]]]:
         # TODO: change LsData (response.data value) to be list of lists
         # to handle cases where a path will be expanded (i.e. globs)
@@ -302,6 +340,13 @@ class StudioClient:
             method="GET",
         )
 
+    def dataset_job_versions(self, job_id: str) -> Response[DatasetJobVersionsData]:
+        return self._send_request(
+            "datachain/datasets/dataset_job_versions",
+            {"job_id": job_id},
+            method="GET",
+        )
+
     def dataset_stats(self, name: str, version: int) -> Response[DatasetStatsData]:
         response = self._send_request(
             "datachain/datasets/stats",
@@ -359,3 +404,10 @@ class StudioClient:
             "requirements": requirements,
         }
         return self._send_request("datachain/job", data)
+
+    def cancel_job(
+        self,
+        job_id: str,
+    ) -> Response[JobData]:
+        url = f"datachain/job/{job_id}/cancel"
+        return self._send_request(url, data={}, method="POST")
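A minimal async consumer for the new tail_job_logs method; create_job in datachain/studio.py (next section) drives it the same way. The team name and job id below are placeholders:

import asyncio

from datachain.remote.studio import StudioClient


async def follow_logs(job_id: str, team: str) -> None:
    client = StudioClient(team=team)
    async for message in client.tail_job_logs(job_id):
        if "logs" in message:
            for log in message["logs"]:
                print(log["message"], end="")
        elif "job" in message:
            print(f"\nJob status: {message['job']['status']}")


# asyncio.run(follow_logs("some-job-id", "my-team"))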
datachain/studio.py
CHANGED
@@ -1,3 +1,4 @@
+import asyncio
 import os
 from typing import TYPE_CHECKING, Optional
 
@@ -19,7 +20,7 @@ POST_LOGIN_MESSAGE = (
 )
 
 
-def process_studio_cli_args(args: "Namespace"):
+def process_studio_cli_args(args: "Namespace"):  # noqa: PLR0911
     if args.cmd == "login":
         return login(args)
     if args.cmd == "logout":
@@ -47,6 +48,9 @@ def process_studio_cli_args(args: "Namespace"):
             args.req_file,
         )
 
+    if args.cmd == "cancel":
+        return cancel_job(args.job_id, args.team)
+
     if args.cmd == "team":
         return set_team(args)
     raise DataChainError(f"Unknown command '{args.cmd}'.")
@@ -227,8 +231,34 @@ def create_job(
     if not response.data:
         raise DataChainError("Failed to create job")
 
-
+    job_id = response.data.get("job", {}).get("id")
+    print(f"Job {job_id} created")
     print("Open the job in Studio at", response.data.get("job", {}).get("url"))
+    print("=" * 40)
+
+    # Sync usage
+    async def _run():
+        async for message in client.tail_job_logs(job_id):
+            if "logs" in message:
+                for log in message["logs"]:
+                    print(log["message"], end="")
+            elif "job" in message:
+                print(f"\n>>>> Job is now in {message['job']['status']} status.")
+
+    asyncio.run(_run())
+
+    response = client.dataset_job_versions(job_id)
+    if not response.ok:
+        raise_remote_error(response.message)
+
+    response_data = response.data
+    if response_data:
+        dataset_versions = response_data.get("dataset_versions", [])
+        print("\n\n>>>> Dataset versions created during the job:")
+        for version in dataset_versions:
+            print(f" - {version.get('dataset_name')}@v{version.get('version')}")
+    else:
+        print("No dataset versions created during the job.")
 
 
 def upload_files(client: StudioClient, files: list[str]) -> list[str]:
@@ -248,3 +278,18 @@ def upload_files(client: StudioClient, files: list[str]) -> list[str]:
         if file_id:
             file_ids.append(str(file_id))
     return file_ids
+
+
+def cancel_job(job_id: str, team_name: Optional[str]):
+    token = Config().read().get("studio", {}).get("token")
+    if not token:
+        raise DataChainError(
+            "Not logged in to Studio. Log in with 'datachain studio login'."
+        )
+
+    client = StudioClient(team=team_name)
+    response = client.cancel_job(job_id)
+    if not response.ok:
+        raise_remote_error(response.message)
+
+    print(f"Job {job_id} canceled")
datachain/utils.py
CHANGED