datachain 0.16.4__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain might be problematic.
- datachain/catalog/catalog.py +25 -92
- datachain/cli/__init__.py +11 -9
- datachain/cli/commands/datasets.py +1 -1
- datachain/cli/commands/query.py +1 -0
- datachain/cli/commands/show.py +1 -1
- datachain/cli/parser/__init__.py +11 -3
- datachain/data_storage/job.py +1 -0
- datachain/data_storage/metastore.py +105 -94
- datachain/data_storage/sqlite.py +8 -7
- datachain/data_storage/warehouse.py +58 -46
- datachain/dataset.py +88 -45
- datachain/lib/arrow.py +23 -1
- datachain/lib/dataset_info.py +2 -1
- datachain/lib/dc/csv.py +1 -0
- datachain/lib/dc/datachain.py +38 -16
- datachain/lib/dc/datasets.py +28 -7
- datachain/lib/dc/storage.py +10 -2
- datachain/lib/listing.py +2 -0
- datachain/lib/pytorch.py +2 -2
- datachain/lib/udf.py +17 -5
- datachain/listing.py +1 -1
- datachain/query/batch.py +40 -39
- datachain/query/dataset.py +42 -41
- datachain/query/dispatch.py +137 -75
- datachain/query/metrics.py +1 -2
- datachain/query/queue.py +1 -11
- datachain/query/session.py +2 -2
- datachain/query/udf.py +1 -1
- datachain/query/utils.py +8 -14
- datachain/remote/studio.py +4 -4
- datachain/semver.py +58 -0
- datachain/studio.py +1 -1
- datachain/utils.py +3 -0
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/METADATA +1 -1
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/RECORD +39 -38
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/WHEEL +1 -1
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/top_level.txt +0 -0
datachain/query/dispatch.py
CHANGED
@@ -3,7 +3,6 @@ from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
 from sys import stdin
-from threading import Timer
 from typing import TYPE_CHECKING, Literal, Optional

 import multiprocess
@@ -15,7 +14,6 @@ from datachain.catalog import Catalog
 from datachain.catalog.catalog import clone_catalog_with_cache
 from datachain.catalog.loader import DISTRIBUTED_IMPORT_PATH, get_udf_distributor_class
 from datachain.lib.udf import _get_cache
-from datachain.query.batch import RowsOutput, RowsOutputBatch
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
@@ -32,6 +30,7 @@ if TYPE_CHECKING:

     from datachain.data_storage import AbstractMetastore, AbstractWarehouse
     from datachain.lib.udf import UDFAdapter
+    from datachain.query.batch import RowsOutput

 DEFAULT_BATCH_SIZE = 10000
 STOP_SIGNAL = "STOP"
@@ -50,34 +49,30 @@ def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:


 def udf_entrypoint() -> int:
+    """Parallel processing (faster for more CPU-heavy UDFs)."""
     # Load UDF info from stdin
     udf_info: UdfInfo = load(stdin.buffer)

-    # Parallel processing (faster for more CPU-heavy UDFs)
-    dispatch = UDFDispatcher(udf_info)
-
     query = udf_info["query"]
-    rows_total = udf_info["rows_total"]
     batching = udf_info["batching"]
     is_generator = udf_info["is_generator"]
-
-
-
+
+    download_cb = get_download_callback()
+    processed_cb = get_processed_callback()
+    generated_cb = get_generated_callback(is_generator)

     wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"]
     warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs)

+    id_col = get_query_id_column(query)
+
     with contextlib.closing(
-        batching(warehouse.dataset_select_paginated, query,
+        batching(warehouse.dataset_select_paginated, query, id_col=id_col)
     ) as udf_inputs:
-        download_cb = get_download_callback()
-        processed_cb = get_processed_callback()
-        generated_cb = get_generated_callback(is_generator)
         try:
-
+            UDFDispatcher(udf_info).run_udf(
                 udf_inputs,
-
-                n_workers=n_workers,
+                ids_only=id_col is not None,
                 download_cb=download_cb,
                 processed_cb=processed_cb,
                 generated_cb=generated_cb,
@@ -90,17 +85,18 @@ def udf_entrypoint() -> int:
     return 0


-def udf_worker_entrypoint() -> int:
+def udf_worker_entrypoint(fd: Optional[int] = None) -> int:
     if not (udf_distributor_class := get_udf_distributor_class()):
         raise RuntimeError(
             f"{DISTRIBUTED_IMPORT_PATH} import path is required "
             "for distributed UDF processing."
         )
-
+
+    return udf_distributor_class.run_udf(fd)


 class UDFDispatcher:
-
+    _catalog: Optional[Catalog] = None
     task_queue: Optional[multiprocess.Queue] = None
     done_queue: Optional[multiprocess.Queue] = None

@@ -115,77 +111,147 @@ class UDFDispatcher:
         self.cache = udf_info["cache"]
         self.is_generator = udf_info["is_generator"]
         self.is_batching = udf_info["batching"].is_batching
+        self.processes = udf_info["processes"]
+        self.rows_total = udf_info["rows_total"]
         self.buffer_size = buffer_size
-        self.catalog = None
         self.task_queue = None
         self.done_queue = None
         self.ctx = get_context("spawn")

-
-
+    @property
+    def catalog(self) -> "Catalog":
+        if not self._catalog:
             ms_cls, ms_args, ms_kwargs = self.metastore_clone_params
             metastore: AbstractMetastore = ms_cls(*ms_args, **ms_kwargs)
             ws_cls, ws_args, ws_kwargs = self.warehouse_clone_params
             warehouse: AbstractWarehouse = ws_cls(*ws_args, **ws_kwargs)
-        self.
-
+            self._catalog = Catalog(metastore, warehouse, **self.catalog_init_params)
+        return self._catalog
+
+    def _create_worker(self) -> "UDFWorker":
+        udf: UDFAdapter = loads(self.udf_data)
         return UDFWorker(
             self.catalog,
-
+            udf,
             self.task_queue,
             self.done_queue,
             self.query,
             self.table,
-            self.is_batching,
             self.cache,
+            self.is_batching,
             self.udf_fields,
         )

-    def _run_worker(self) -> None:
+    def _run_worker(self, ids_only: bool) -> None:
         try:
             worker = self._create_worker()
-            worker.run()
+            worker.run(ids_only)
         except (Exception, KeyboardInterrupt) as e:
             if self.done_queue:
                 put_into_queue(
                     self.done_queue,
                     {"status": FAILED_STATUS, "exception": e},
                 )
+            if isinstance(e, KeyboardInterrupt):
+                return
             raise

-
-
+    def run_udf(
+        self,
+        input_rows: Iterable["RowsOutput"],
+        ids_only: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+        generated_cb: Callback = DEFAULT_CALLBACK,
+    ) -> None:
+        n_workers = self.processes
+        if n_workers is True:
+            n_workers = None  # Use default number of CPUs (cores)
+        elif not n_workers or n_workers < 1:
+            n_workers = 1  # Single-threaded (on this worker)
         n_workers = get_n_workers_from_arg(n_workers)
-        for _ in range(n_workers):
-            put_into_queue(task_queue, STOP_SIGNAL)

-
+        if n_workers == 1:
+            # no need to spawn worker processes if we are running in a single process
+            self.run_udf_single(
+                input_rows, ids_only, download_cb, processed_cb, generated_cb
+            )
+        else:
+            if self.buffer_size < n_workers:
+                raise RuntimeError(
+                    "Parallel run error: buffer size is smaller than "
+                    f"number of workers: {self.buffer_size} < {n_workers}"
+                )
+
+            self.run_udf_parallel(
+                n_workers, input_rows, ids_only, download_cb, processed_cb, generated_cb
+            )
+
+    def run_udf_single(
         self,
-        input_rows: Iterable[RowsOutput],
-
-        n_workers: Optional[int] = None,
+        input_rows: Iterable["RowsOutput"],
+        ids_only: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
         generated_cb: Callback = DEFAULT_CALLBACK,
     ) -> None:
-
+        udf: UDFAdapter = loads(self.udf_data)
+
+        if ids_only and not self.is_batching:
+            input_rows = flatten(input_rows)
+
+        def get_inputs() -> Iterable["RowsOutput"]:
+            warehouse = self.catalog.warehouse.clone()
+            if ids_only:
+                yield from warehouse.dataset_rows_select_from_ids(
+                    self.query, input_rows, self.is_batching
+                )
+            else:
+                yield from input_rows
+
+        prefetch = udf.prefetch
+        with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
+            udf_results = udf.run(
+                self.udf_fields,
+                get_inputs(),
+                self.catalog,
+                self.cache,
+                download_cb=download_cb,
+                processed_cb=processed_cb,
+            )
+            with safe_closing(udf_results):
+                process_udf_outputs(
+                    self.catalog.warehouse.clone(),
+                    self.table,
+                    udf_results,
+                    udf,
+                    cb=generated_cb,
+                )

-
+    def input_batch_size(self, n_workers: int) -> int:
+        input_batch_size = self.rows_total // n_workers
         if input_batch_size == 0:
             input_batch_size = 1
         elif input_batch_size > DEFAULT_BATCH_SIZE:
             input_batch_size = DEFAULT_BATCH_SIZE
+        return input_batch_size

-
-
-
-
-
-
+    def run_udf_parallel(  # noqa: C901, PLR0912
+        self,
+        n_workers: int,
+        input_rows: Iterable["RowsOutput"],
+        ids_only: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+        generated_cb: Callback = DEFAULT_CALLBACK,
+    ) -> None:
         self.task_queue = self.ctx.Queue()
         self.done_queue = self.ctx.Queue()
+
         pool = [
-            self.ctx.Process(
+            self.ctx.Process(
+                name=f"Worker-UDF-{i}", target=self._run_worker, args=[ids_only]
+            )
             for i in range(n_workers)
         ]
         for p in pool:
@@ -198,7 +264,8 @@ class UDFDispatcher:
         input_finished = False

         if not self.is_batching:
-
+            batch_size = self.input_batch_size(n_workers)
+            input_rows = batched(flatten(input_rows), batch_size)

         # Stop all workers after the input rows have finished processing
         input_data = chain(input_rows, [STOP_SIGNAL] * n_workers)
@@ -213,10 +280,15 @@ class UDFDispatcher:

         # Process all tasks
         while n_workers > 0:
-
+            try:
+                result = get_from_queue(self.done_queue)
+            except KeyboardInterrupt:
+                break

+            if bytes_downloaded := result.get("bytes_downloaded"):
+                download_cb.relative_update(bytes_downloaded)
             if downloaded := result.get("downloaded"):
-                download_cb.
+                download_cb.increment_file_count(downloaded)
             if processed := result.get("processed"):
                 processed_cb.relative_update(processed)
             if generated := result.get("generated"):
@@ -246,13 +318,12 @@ class UDFDispatcher:
             # Stop all workers if there is an unexpected exception
             for _ in pool:
                 put_into_queue(self.task_queue, STOP_SIGNAL)
-            self.task_queue.close()

             # This allows workers (and this process) to exit without
             # consuming any remaining data in the queues.
             # (If they exit due to an exception.)
-            self.task_queue.
-            self.
+            self.task_queue.close()
+            self.task_queue.join_thread()

             # Flush all items from the done queue.
             # This is needed if any workers are still running.
@@ -262,6 +333,9 @@ class UDFDispatcher:
                 if status != OK_STATUS:
                     n_workers -= 1

+            self.done_queue.close()
+            self.done_queue.join_thread()
+
             # Wait for workers to stop
             for p in pool:
                 p.join()
@@ -273,8 +347,7 @@ class DownloadCallback(Callback):
         super().__init__()

     def relative_update(self, inc: int = 1) -> None:
-
-        pass
+        put_into_queue(self.queue, {"status": NOTIFY_STATUS, "bytes_downloaded": inc})

     def increment_file_count(self, inc: int = 1) -> None:
         put_into_queue(self.queue, {"status": NOTIFY_STATUS, "downloaded": inc})
@@ -303,8 +376,8 @@ class UDFWorker:
         done_queue: "multiprocess.Queue",
         query: "Select",
         table: "Table",
-        is_batching: bool,
         cache: bool,
+        is_batching: bool,
         udf_fields: Sequence[str],
     ) -> None:
         self.catalog = catalog
@@ -313,21 +386,21 @@ class UDFWorker:
         self.done_queue = done_queue
         self.query = query
         self.table = table
-        self.is_batching = is_batching
         self.cache = cache
+        self.is_batching = is_batching
         self.udf_fields = udf_fields

         self.download_cb = DownloadCallback(self.done_queue)
         self.processed_cb = ProcessedCallback("processed", self.done_queue)
         self.generated_cb = ProcessedCallback("generated", self.done_queue)

-    def run(self) -> None:
+    def run(self, ids_only: bool) -> None:
         prefetch = self.udf.prefetch
         with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
             catalog = clone_catalog_with_cache(self.catalog, _cache)
             udf_results = self.udf.run(
                 self.udf_fields,
-                self.get_inputs(),
+                self.get_inputs(ids_only),
                 catalog,
                 self.cache,
                 download_cb=self.download_cb,
@@ -348,23 +421,12 @@ class UDFWorker:
             put_into_queue(self.done_queue, {"status": OK_STATUS})
             yield row

-    def get_inputs(self):
+    def get_inputs(self, ids_only: bool) -> Iterable["RowsOutput"]:
         warehouse = self.catalog.warehouse.clone()
-
-
-
-
-                ids = [row[0] for row in batch.rows]
-                rows = warehouse.dataset_rows_select(self.query.where(col_id.in_(ids)))
-                yield RowsOutputBatch(list(rows))
-        else:
-            while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
-                yield from warehouse.dataset_rows_select(
-                    self.query.where(col_id.in_(batch))
+        while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+            if ids_only:
+                yield from warehouse.dataset_rows_select_from_ids(
+                    self.query, batch, self.is_batching
                 )
-
-
-class RepeatTimer(Timer):
-    def run(self):
-        while not self.finished.wait(self.interval):
-            self.function(*self.args, **self.kwargs)
+            else:
+                yield from batch
datachain/query/metrics.py
CHANGED
@@ -15,11 +15,10 @@ def set(key: str, value: Union[str, int, float, bool, None]) -> None: # noqa: P
     metrics[key] = value

     if job_id := os.getenv("DATACHAIN_JOB_ID"):
-        from datachain.data_storage.job import JobStatus
         from datachain.query.session import Session

         metastore = Session.get().catalog.metastore
-        metastore.
+        metastore.update_job(job_id, metrics=metrics)


 def get(key: str) -> Optional[Union[str, int, float, bool]]:
datachain/query/queue.py
CHANGED
@@ -7,7 +7,7 @@ from typing import Any

 import msgpack

-from datachain.query.batch import RowsOutput, RowsOutputBatch
+from datachain.query.batch import RowsOutput

 DEFAULT_BATCH_SIZE = 10000
 STOP_SIGNAL = "STOP"
@@ -56,7 +56,6 @@ def put_into_queue(queue: Queue, item: Any) -> None:


 MSGPACK_EXT_TYPE_DATETIME = 42
-MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH = 43


 def _msgpack_pack_extended_types(obj: Any) -> msgpack.ExtType:
@@ -70,12 +69,6 @@ def _msgpack_pack_extended_types(obj: Any) -> msgpack.ExtType:
         data = (obj.timestamp(),)  # type: ignore # noqa: PGH003
         return msgpack.ExtType(MSGPACK_EXT_TYPE_DATETIME, pack("!d", *data))

-    if isinstance(obj, RowsOutputBatch):
-        return msgpack.ExtType(
-            MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH,
-            msgpack_pack(obj.rows),
-        )
-
     raise TypeError(f"Unknown type: {obj}")


@@ -100,9 +93,6 @@ def _msgpack_unpack_extended_types(code: int, data: bytes) -> Any:
         tz_info = datetime.timezone(datetime.timedelta(seconds=timezone_offset))
         return datetime.datetime.fromtimestamp(timestamp, tz=tz_info)

-    if code == MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH:
-        return RowsOutputBatch(msgpack_unpack(data))
-
     return msgpack.ExtType(code, data)

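queue.py now keeps only the datetime msgpack extension (code 42), since RowsOutputBatch no longer crosses process boundaries. For context, here is a minimal round-trip sketch of the ExtType mechanism it relies on; the encoder/decoder below are simplified stand-ins, not datachain's exact functions.

import datetime
from struct import pack, unpack

import msgpack

MSGPACK_EXT_TYPE_DATETIME = 42


def encode(obj):
    if isinstance(obj, datetime.datetime):
        # store the POSIX timestamp as a big-endian double
        return msgpack.ExtType(MSGPACK_EXT_TYPE_DATETIME, pack("!d", obj.timestamp()))
    raise TypeError(f"Unknown type: {obj}")


def decode(code, data):
    if code == MSGPACK_EXT_TYPE_DATETIME:
        (ts,) = unpack("!d", data)
        return datetime.datetime.fromtimestamp(ts, tz=datetime.timezone.utc)
    return msgpack.ExtType(code, data)


now = datetime.datetime.now(tz=datetime.timezone.utc)
packed = msgpack.packb({"created": now}, default=encode)
unpacked = msgpack.unpackb(packed, ext_hook=decode, raw=False)
print(abs((unpacked["created"] - now).total_seconds()) < 1e-3)  # True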
datachain/query/session.py
CHANGED
@@ -69,7 +69,7 @@ class Session:
         self.catalog = catalog or get_catalog(
             client_config=client_config, in_memory=in_memory
         )
-        self.dataset_versions: list[tuple[DatasetRecord,
+        self.dataset_versions: list[tuple[DatasetRecord, str, bool]] = []

     def __enter__(self):
         # Push the current context onto the stack
@@ -90,7 +90,7 @@
             Session.SESSION_CONTEXTS.pop()

     def add_dataset_version(
-        self, dataset: "DatasetRecord", version:
+        self, dataset: "DatasetRecord", version: str, listing: bool = False
     ) -> None:
         self.dataset_versions.append((dataset, version, listing))

datachain/query/udf.py
CHANGED
datachain/query/utils.py
CHANGED
@@ -1,33 +1,27 @@
-from typing import
+from typing import Optional, Union

-
+import sqlalchemy as sa

-
-from sqlalchemy import ColumnElement, Select, TextClause
-
-
-ColT = Union[Column, "ColumnElement", "TextClause"]
+ColT = Union[sa.Column, sa.ColumnElement, sa.TextClause]


 def column_name(col: ColT) -> str:
     """Returns column name from column element."""
-    return col.name if isinstance(col, Column) else str(col)
+    return col.name if isinstance(col, sa.Column) else str(col)


-def get_query_column(query:
+def get_query_column(query: sa.Select, name: str) -> Optional[ColT]:
     """Returns column element from query by name or None if column not found."""
     return next((col for col in query.inner_columns if column_name(col) == name), None)


-def get_query_id_column(query:
+def get_query_id_column(query: sa.Select) -> Optional[sa.ColumnElement]:
     """Returns ID column element from query or None if column not found."""
     col = get_query_column(query, "sys__id")
-    if col is None:
-        raise RuntimeError("sys__id column not found in query")
-    return col
+    return col if col is not None and isinstance(col, sa.ColumnElement) else None


-def select_only_columns(query:
+def select_only_columns(query: sa.Select, *names: str) -> sa.Select:
     """Returns query selecting defined columns only."""
     if not names:
         return query
datachain/remote/studio.py
CHANGED
@@ -307,7 +307,7 @@ class StudioClient:
     def rm_dataset(
         self,
         name: str,
-        version: Optional[
+        version: Optional[str] = None,
         force: Optional[bool] = False,
     ) -> Response[DatasetInfoData]:
         return self._send_request(
@@ -336,7 +336,7 @@ class StudioClient:
         return response

     def dataset_rows_chunk(
-        self, name: str, version:
+        self, name: str, version: str, offset: int
     ) -> Response[DatasetRowsData]:
         req_data = {"dataset_name": name, "dataset_version": version}
         return self._send_request_msgpack(
@@ -353,7 +353,7 @@ class StudioClient:
         )

     def export_dataset_table(
-        self, name: str, version:
+        self, name: str, version: str
     ) -> Response[DatasetExportSignedUrls]:
         return self._send_request(
             "datachain/datasets/export",
@@ -362,7 +362,7 @@ class StudioClient:
         )

     def dataset_export_status(
-        self, name: str, version:
+        self, name: str, version: str
     ) -> Response[DatasetExportStatus]:
         return self._send_request(
             "datachain/datasets/export-status",
datachain/semver.py
ADDED
@@ -0,0 +1,58 @@
+def parse(version: str) -> tuple[int, int, int]:
+    """Parsing semver into 3 integers: major, minor, patch"""
+    validate(version)
+    parts = version.split(".")
+    return (int(parts[0]), int(parts[1]), int(parts[2]))
+
+
+def validate(version: str) -> None:
+    """
+    Raises exception if version doesn't have valid semver format which is:
+    <major>.<minor>.<patch> or one of version parts is not positive integer
+    """
+    error_message = (
+        "Invalid version. It should be in format: <major>.<minor>.<patch> where"
+        " each version part is positive integer"
+    )
+    parts = version.split(".")
+    if len(parts) != 3:
+        raise ValueError(error_message)
+    for part in parts:
+        try:
+            val = int(part)
+            assert val >= 0
+        except (ValueError, AssertionError):
+            raise ValueError(error_message) from None
+
+
+def create(major: int = 0, minor: int = 0, patch: int = 0) -> str:
+    """Creates new semver from 3 integers: major, minor and patch"""
+    if major < 0 or minor < 0 or patch < 0:
+        raise ValueError("Major, minor and patch must be greater or equal to zero")
+
+    return ".".join([str(major), str(minor), str(patch)])
+
+
+def value(version: str) -> int:
+    """
+    Calculate integer value of a version. This is useful when comparing two versions
+    """
+    major, minor, patch = parse(version)
+    return major * 100 + minor * 10 + patch
+
+
+def compare(v1: str, v2: str) -> int:
+    """
+    Compares 2 versions and returns:
+        -1 if v1 < v2
+        0 if v1 == v2
+        1 if v1 > v2
+    """
+    v1_val = value(v1)
+    v2_val = value(v2)
+
+    if v1_val < v2_val:
+        return -1
+    if v1_val > v2_val:
+        return 1
+    return 0
datachain/studio.py
CHANGED
@@ -201,7 +201,7 @@ def edit_studio_dataset(
 def remove_studio_dataset(
     team_name: Optional[str],
     name: str,
-    version: Optional[
+    version: Optional[str] = None,
     force: Optional[bool] = False,
 ):
     client = StudioClient(team=team_name)
datachain/utils.py
CHANGED