datachain 0.16.3__py3-none-any.whl → 0.16.5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in that registry.

Potentially problematic release: this version of datachain might be problematic.

@@ -3,7 +3,6 @@ from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
 from sys import stdin
-from threading import Timer
 from typing import TYPE_CHECKING, Literal, Optional

 import multiprocess
@@ -15,7 +14,6 @@ from datachain.catalog import Catalog
 from datachain.catalog.catalog import clone_catalog_with_cache
 from datachain.catalog.loader import DISTRIBUTED_IMPORT_PATH, get_udf_distributor_class
 from datachain.lib.udf import _get_cache
-from datachain.query.batch import RowsOutput, RowsOutputBatch
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
@@ -32,6 +30,7 @@ if TYPE_CHECKING:

     from datachain.data_storage import AbstractMetastore, AbstractWarehouse
     from datachain.lib.udf import UDFAdapter
+    from datachain.query.batch import RowsOutput

 DEFAULT_BATCH_SIZE = 10000
 STOP_SIGNAL = "STOP"
@@ -50,34 +49,30 @@ def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:


 def udf_entrypoint() -> int:
+    """Parallel processing (faster for more CPU-heavy UDFs)."""
     # Load UDF info from stdin
     udf_info: UdfInfo = load(stdin.buffer)

-    # Parallel processing (faster for more CPU-heavy UDFs)
-    dispatch = UDFDispatcher(udf_info)
-
     query = udf_info["query"]
-    rows_total = udf_info["rows_total"]
     batching = udf_info["batching"]
     is_generator = udf_info["is_generator"]
-    n_workers = udf_info["processes"]
-    if n_workers is True:
-        n_workers = None  # Use default number of CPUs (cores)
+
+    download_cb = get_download_callback()
+    processed_cb = get_processed_callback()
+    generated_cb = get_generated_callback(is_generator)

     wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"]
     warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs)

+    id_col = get_query_id_column(query)
+
     with contextlib.closing(
-        batching(warehouse.dataset_select_paginated, query, ids_only=True)
+        batching(warehouse.dataset_select_paginated, query, id_col=id_col)
     ) as udf_inputs:
-        download_cb = get_download_callback()
-        processed_cb = get_processed_callback()
-        generated_cb = get_generated_callback(is_generator)
         try:
-            dispatch.run_udf_parallel(
+            UDFDispatcher(udf_info).run_udf(
                 udf_inputs,
-                rows_total=rows_total,
-                n_workers=n_workers,
+                ids_only=id_col is not None,
                 download_cb=download_cb,
                 processed_cb=processed_cb,
                 generated_cb=generated_cb,
@@ -90,17 +85,18 @@ def udf_entrypoint() -> int:
     return 0


-def udf_worker_entrypoint() -> int:
+def udf_worker_entrypoint(fd: Optional[int] = None) -> int:
     if not (udf_distributor_class := get_udf_distributor_class()):
         raise RuntimeError(
             f"{DISTRIBUTED_IMPORT_PATH} import path is required "
             "for distributed UDF processing."
         )
-    return udf_distributor_class.run_worker()
+
+    return udf_distributor_class.run_udf(fd)


 class UDFDispatcher:
-    catalog: Optional[Catalog] = None
+    _catalog: Optional[Catalog] = None
     task_queue: Optional[multiprocess.Queue] = None
     done_queue: Optional[multiprocess.Queue] = None

@@ -115,77 +111,147 @@ class UDFDispatcher:
         self.cache = udf_info["cache"]
         self.is_generator = udf_info["is_generator"]
         self.is_batching = udf_info["batching"].is_batching
+        self.processes = udf_info["processes"]
+        self.rows_total = udf_info["rows_total"]
         self.buffer_size = buffer_size
-        self.catalog = None
         self.task_queue = None
         self.done_queue = None
         self.ctx = get_context("spawn")

-    def _create_worker(self) -> "UDFWorker":
-        if not self.catalog:
+    @property
+    def catalog(self) -> "Catalog":
+        if not self._catalog:
             ms_cls, ms_args, ms_kwargs = self.metastore_clone_params
             metastore: AbstractMetastore = ms_cls(*ms_args, **ms_kwargs)
             ws_cls, ws_args, ws_kwargs = self.warehouse_clone_params
             warehouse: AbstractWarehouse = ws_cls(*ws_args, **ws_kwargs)
-            self.catalog = Catalog(metastore, warehouse, **self.catalog_init_params)
-        self.udf = loads(self.udf_data)
+            self._catalog = Catalog(metastore, warehouse, **self.catalog_init_params)
+        return self._catalog
+
+    def _create_worker(self) -> "UDFWorker":
+        udf: UDFAdapter = loads(self.udf_data)
         return UDFWorker(
             self.catalog,
-            self.udf,
+            udf,
             self.task_queue,
             self.done_queue,
             self.query,
             self.table,
-            self.is_batching,
             self.cache,
+            self.is_batching,
             self.udf_fields,
         )

-    def _run_worker(self) -> None:
+    def _run_worker(self, ids_only: bool) -> None:
         try:
             worker = self._create_worker()
-            worker.run()
+            worker.run(ids_only)
         except (Exception, KeyboardInterrupt) as e:
             if self.done_queue:
                 put_into_queue(
                     self.done_queue,
                     {"status": FAILED_STATUS, "exception": e},
                 )
+            if isinstance(e, KeyboardInterrupt):
+                return
             raise

-    @staticmethod
-    def send_stop_signal_to_workers(task_queue, n_workers: Optional[int] = None):
+    def run_udf(
+        self,
+        input_rows: Iterable["RowsOutput"],
+        ids_only: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+        generated_cb: Callback = DEFAULT_CALLBACK,
+    ) -> None:
+        n_workers = self.processes
+        if n_workers is True:
+            n_workers = None  # Use default number of CPUs (cores)
+        elif not n_workers or n_workers < 1:
+            n_workers = 1  # Single-threaded (on this worker)
         n_workers = get_n_workers_from_arg(n_workers)
-        for _ in range(n_workers):
-            put_into_queue(task_queue, STOP_SIGNAL)

-    def run_udf_parallel(  # noqa: C901, PLR0912
+        if n_workers == 1:
+            # no need to spawn worker processes if we are running in a single process
+            self.run_udf_single(
+                input_rows, ids_only, download_cb, processed_cb, generated_cb
+            )
+        else:
+            if self.buffer_size < n_workers:
+                raise RuntimeError(
+                    "Parallel run error: buffer size is smaller than "
+                    f"number of workers: {self.buffer_size} < {n_workers}"
+                )
+
+            self.run_udf_parallel(
+                n_workers, input_rows, ids_only, download_cb, processed_cb, generated_cb
+            )
+
+    def run_udf_single(
         self,
-        input_rows: Iterable[RowsOutput],
-        rows_total: int,
-        n_workers: Optional[int] = None,
+        input_rows: Iterable["RowsOutput"],
+        ids_only: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
         generated_cb: Callback = DEFAULT_CALLBACK,
     ) -> None:
-        n_workers = get_n_workers_from_arg(n_workers)
+        udf: UDFAdapter = loads(self.udf_data)
+
+        if ids_only and not self.is_batching:
+            input_rows = flatten(input_rows)
+
+        def get_inputs() -> Iterable["RowsOutput"]:
+            warehouse = self.catalog.warehouse.clone()
+            if ids_only:
+                yield from warehouse.dataset_rows_select_from_ids(
+                    self.query, input_rows, self.is_batching
+                )
+            else:
+                yield from input_rows
+
+        prefetch = udf.prefetch
+        with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
+            udf_results = udf.run(
+                self.udf_fields,
+                get_inputs(),
+                self.catalog,
+                self.cache,
+                download_cb=download_cb,
+                processed_cb=processed_cb,
+            )
+            with safe_closing(udf_results):
+                process_udf_outputs(
+                    self.catalog.warehouse.clone(),
+                    self.table,
+                    udf_results,
+                    udf,
+                    cb=generated_cb,
+                )

-        input_batch_size = rows_total // n_workers
+    def input_batch_size(self, n_workers: int) -> int:
+        input_batch_size = self.rows_total // n_workers
         if input_batch_size == 0:
             input_batch_size = 1
         elif input_batch_size > DEFAULT_BATCH_SIZE:
             input_batch_size = DEFAULT_BATCH_SIZE
+        return input_batch_size

-        if self.buffer_size < n_workers:
-            raise RuntimeError(
-                "Parallel run error: buffer size is smaller than "
-                f"number of workers: {self.buffer_size} < {n_workers}"
-            )
-
+    def run_udf_parallel(  # noqa: C901, PLR0912
+        self,
+        n_workers: int,
+        input_rows: Iterable["RowsOutput"],
+        ids_only: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+        generated_cb: Callback = DEFAULT_CALLBACK,
+    ) -> None:
         self.task_queue = self.ctx.Queue()
         self.done_queue = self.ctx.Queue()
+
         pool = [
-            self.ctx.Process(name=f"Worker-UDF-{i}", target=self._run_worker)
+            self.ctx.Process(
+                name=f"Worker-UDF-{i}", target=self._run_worker, args=[ids_only]
+            )
             for i in range(n_workers)
         ]
         for p in pool:
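To make the batch-size clamping above concrete, here is a small worked sketch of the `input_batch_size` logic, rewritten as a free function purely for illustration (the real method reads `self.rows_total`); the numbers are made up, and `DEFAULT_BATCH_SIZE = 10000` comes from the constants earlier in this file.

```python
# Illustrative only: mirrors the clamping in UDFDispatcher.input_batch_size above.
DEFAULT_BATCH_SIZE = 10000


def input_batch_size(rows_total: int, n_workers: int) -> int:
    size = rows_total // n_workers
    if size == 0:
        size = 1  # never hand a worker an empty batch
    elif size > DEFAULT_BATCH_SIZE:
        size = DEFAULT_BATCH_SIZE  # cap very large splits
    return size


assert input_batch_size(25_000, 4) == 6_250    # even split across workers
assert input_batch_size(3, 8) == 1             # clamped up to one row
assert input_batch_size(200_000, 4) == 10_000  # capped at DEFAULT_BATCH_SIZE
```

Note that the new `run_udf` falls back to `run_udf_single` whenever the resolved worker count is 1, so the buffer-size check only applies to the parallel path.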
@@ -198,7 +264,8 @@ class UDFDispatcher:
         input_finished = False

         if not self.is_batching:
-            input_rows = batched(flatten(input_rows), input_batch_size)
+            batch_size = self.input_batch_size(n_workers)
+            input_rows = batched(flatten(input_rows), batch_size)

         # Stop all workers after the input rows have finished processing
         input_data = chain(input_rows, [STOP_SIGNAL] * n_workers)
@@ -213,10 +280,15 @@ class UDFDispatcher:

         # Process all tasks
         while n_workers > 0:
-            result = get_from_queue(self.done_queue)
+            try:
+                result = get_from_queue(self.done_queue)
+            except KeyboardInterrupt:
+                break

+            if bytes_downloaded := result.get("bytes_downloaded"):
+                download_cb.relative_update(bytes_downloaded)
             if downloaded := result.get("downloaded"):
-                download_cb.relative_update(downloaded)
+                download_cb.increment_file_count(downloaded)
             if processed := result.get("processed"):
                 processed_cb.relative_update(processed)
             if generated := result.get("generated"):
@@ -246,13 +318,12 @@ class UDFDispatcher:
             # Stop all workers if there is an unexpected exception
             for _ in pool:
                 put_into_queue(self.task_queue, STOP_SIGNAL)
-            self.task_queue.close()

             # This allows workers (and this process) to exit without
             # consuming any remaining data in the queues.
             # (If they exit due to an exception.)
-            self.task_queue.cancel_join_thread()
-            self.done_queue.cancel_join_thread()
+            self.task_queue.close()
+            self.task_queue.join_thread()

             # Flush all items from the done queue.
             # This is needed if any workers are still running.
@@ -262,6 +333,9 @@ class UDFDispatcher:
                 if status != OK_STATUS:
                     n_workers -= 1

+            self.done_queue.close()
+            self.done_queue.join_thread()
+
             # Wait for workers to stop
             for p in pool:
                 p.join()
@@ -273,8 +347,7 @@ class DownloadCallback(Callback):
         super().__init__()

     def relative_update(self, inc: int = 1) -> None:
-        # This callback is used to notify the size of the downloaded files
-        pass
+        put_into_queue(self.queue, {"status": NOTIFY_STATUS, "bytes_downloaded": inc})

     def increment_file_count(self, inc: int = 1) -> None:
         put_into_queue(self.queue, {"status": NOTIFY_STATUS, "downloaded": inc})
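Together with the `run_udf_parallel` loop above, this gives the worker-to-dispatcher progress protocol two distinct keys: `bytes_downloaded` for byte totals and `downloaded` for file counts. A minimal, self-contained sketch of that round trip, using a plain `queue.Queue` and a stand-in string for `NOTIFY_STATUS` instead of the library's multiprocess queue and constants:

```python
# Hypothetical sketch of the worker -> dispatcher progress messages in this diff.
from queue import Queue

NOTIFY_STATUS = "notify"  # stand-in; the real constant lives in datachain

q: Queue = Queue()

# Worker side (DownloadCallback): bytes and file counts are reported separately.
q.put({"status": NOTIFY_STATUS, "bytes_downloaded": 2048})
q.put({"status": NOTIFY_STATUS, "downloaded": 1})

# Dispatcher side (run_udf_parallel): route each key to the matching update.
totals = {"bytes": 0, "files": 0}
while not q.empty():
    result = q.get()
    if n := result.get("bytes_downloaded"):
        totals["bytes"] += n   # -> download_cb.relative_update(n)
    if n := result.get("downloaded"):
        totals["files"] += n   # -> download_cb.increment_file_count(n)

assert totals == {"bytes": 2048, "files": 1}
```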
@@ -303,8 +376,8 @@ class UDFWorker:
         done_queue: "multiprocess.Queue",
         query: "Select",
         table: "Table",
-        is_batching: bool,
         cache: bool,
+        is_batching: bool,
         udf_fields: Sequence[str],
     ) -> None:
         self.catalog = catalog
@@ -313,21 +386,21 @@ class UDFWorker:
         self.done_queue = done_queue
         self.query = query
         self.table = table
-        self.is_batching = is_batching
         self.cache = cache
+        self.is_batching = is_batching
         self.udf_fields = udf_fields

         self.download_cb = DownloadCallback(self.done_queue)
         self.processed_cb = ProcessedCallback("processed", self.done_queue)
         self.generated_cb = ProcessedCallback("generated", self.done_queue)

-    def run(self) -> None:
+    def run(self, ids_only: bool) -> None:
         prefetch = self.udf.prefetch
         with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
             catalog = clone_catalog_with_cache(self.catalog, _cache)
             udf_results = self.udf.run(
                 self.udf_fields,
-                self.get_inputs(),
+                self.get_inputs(ids_only),
                 catalog,
                 self.cache,
                 download_cb=self.download_cb,
@@ -348,23 +421,12 @@ class UDFWorker:
             put_into_queue(self.done_queue, {"status": OK_STATUS})
             yield row

-    def get_inputs(self):
+    def get_inputs(self, ids_only: bool) -> Iterable["RowsOutput"]:
         warehouse = self.catalog.warehouse.clone()
-        col_id = get_query_id_column(self.query)
-
-        if self.is_batching:
-            while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
-                ids = [row[0] for row in batch.rows]
-                rows = warehouse.dataset_rows_select(self.query.where(col_id.in_(ids)))
-                yield RowsOutputBatch(list(rows))
-        else:
-            while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
-                yield from warehouse.dataset_rows_select(
-                    self.query.where(col_id.in_(batch))
+        while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+            if ids_only:
+                yield from warehouse.dataset_rows_select_from_ids(
+                    self.query, batch, self.is_batching
                 )
-
-
-class RepeatTimer(Timer):
-    def run(self):
-        while not self.finished.wait(self.interval):
-            self.function(*self.args, **self.kwargs)
+            else:
+                yield from batch
@@ -15,11 +15,10 @@ def set(key: str, value: Union[str, int, float, bool, None]) -> None:  # noqa: P
     metrics[key] = value

     if job_id := os.getenv("DATACHAIN_JOB_ID"):
-        from datachain.data_storage.job import JobStatus
         from datachain.query.session import Session

         metastore = Session.get().catalog.metastore
-        metastore.set_job_status(job_id, JobStatus.RUNNING, metrics=metrics)
+        metastore.update_job(job_id, metrics=metrics)


 def get(key: str) -> Optional[Union[str, int, float, bool]]:
datachain/query/queue.py CHANGED
@@ -7,7 +7,7 @@ from typing import Any

 import msgpack

-from datachain.query.batch import RowsOutput, RowsOutputBatch
+from datachain.query.batch import RowsOutput

 DEFAULT_BATCH_SIZE = 10000
 STOP_SIGNAL = "STOP"
@@ -56,7 +56,6 @@ def put_into_queue(queue: Queue, item: Any) -> None:


 MSGPACK_EXT_TYPE_DATETIME = 42
-MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH = 43


 def _msgpack_pack_extended_types(obj: Any) -> msgpack.ExtType:
@@ -70,12 +69,6 @@ def _msgpack_pack_extended_types(obj: Any) -> msgpack.ExtType:
         data = (obj.timestamp(),)  # type: ignore  # noqa: PGH003
         return msgpack.ExtType(MSGPACK_EXT_TYPE_DATETIME, pack("!d", *data))

-    if isinstance(obj, RowsOutputBatch):
-        return msgpack.ExtType(
-            MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH,
-            msgpack_pack(obj.rows),
-        )
-
     raise TypeError(f"Unknown type: {obj}")

@@ -100,9 +93,6 @@ def _msgpack_unpack_extended_types(code: int, data: bytes) -> Any:
         tz_info = datetime.timezone(datetime.timedelta(seconds=timezone_offset))
         return datetime.datetime.fromtimestamp(timestamp, tz=tz_info)

-    if code == MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH:
-        return RowsOutputBatch(msgpack_unpack(data))
-
     return msgpack.ExtType(code, data)

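With the `RowsOutputBatch` extension gone, the datetime extension (code 42) is the only custom msgpack type left in this module. A simplified, self-contained sketch of that pattern follows; it stores only a UTC timestamp, whereas the module's real codec also carries a timezone offset, as the unpack hunk above shows.

```python
# Simplified sketch of the msgpack ExtType pattern kept in queue.py (code 42).
import datetime
from struct import pack, unpack

import msgpack

MSGPACK_EXT_TYPE_DATETIME = 42


def _default(obj):
    if isinstance(obj, datetime.datetime):
        # Encode the POSIX timestamp as a big-endian double.
        return msgpack.ExtType(MSGPACK_EXT_TYPE_DATETIME, pack("!d", obj.timestamp()))
    raise TypeError(f"Unknown type: {obj}")


def _ext_hook(code, data):
    if code == MSGPACK_EXT_TYPE_DATETIME:
        (ts,) = unpack("!d", data)
        return datetime.datetime.fromtimestamp(ts, tz=datetime.timezone.utc)
    return msgpack.ExtType(code, data)


now = datetime.datetime.now(datetime.timezone.utc)
payload = msgpack.packb({"created": now}, default=_default)
restored = msgpack.unpackb(payload, ext_hook=_ext_hook, raw=False)
assert abs((restored["created"] - now).total_seconds()) < 1e-3
```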
datachain/query/udf.py CHANGED
@@ -46,4 +46,4 @@ class AbstractUDFDistributor(ABC):

     @staticmethod
     @abstractmethod
-    def run_worker() -> int: ...
+    def run_udf(fd: Optional[int] = None) -> int: ...
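For context, `udf_worker_entrypoint` above now forwards an optional file descriptor to this renamed static method. A minimal hypothetical subclass satisfying just this part of the interface; the ABC may define further abstract members that are not part of this diff.

```python
# Hypothetical distributor; only the run_udf staticmethod shown in this diff
# is implemented here.
from typing import Optional

from datachain.query.udf import AbstractUDFDistributor


class EchoUDFDistributor(AbstractUDFDistributor):
    @staticmethod
    def run_udf(fd: Optional[int] = None) -> int:
        # A real implementation would read work over the given descriptor and
        # return a process exit code; here we just report what we received.
        print(f"worker started, fd={fd}")
        return 0


# Mirrors the call site in udf_worker_entrypoint: the method is invoked on the class.
assert EchoUDFDistributor.run_udf(None) == 0
```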
datachain/query/utils.py CHANGED
@@ -1,33 +1,27 @@
-from typing import TYPE_CHECKING, Optional, Union
+from typing import Optional, Union

-from sqlalchemy import Column
+import sqlalchemy as sa

-if TYPE_CHECKING:
-    from sqlalchemy import ColumnElement, Select, TextClause
-
-
-ColT = Union[Column, "ColumnElement", "TextClause"]
+ColT = Union[sa.Column, sa.ColumnElement, sa.TextClause]


 def column_name(col: ColT) -> str:
     """Returns column name from column element."""
-    return col.name if isinstance(col, Column) else str(col)
+    return col.name if isinstance(col, sa.Column) else str(col)


-def get_query_column(query: "Select", name: str) -> Optional[ColT]:
+def get_query_column(query: sa.Select, name: str) -> Optional[ColT]:
     """Returns column element from query by name or None if column not found."""
     return next((col for col in query.inner_columns if column_name(col) == name), None)


-def get_query_id_column(query: "Select") -> ColT:
+def get_query_id_column(query: sa.Select) -> Optional[sa.ColumnElement]:
     """Returns ID column element from query or None if column not found."""
     col = get_query_column(query, "sys__id")
-    if col is None:
-        raise RuntimeError("sys__id column not found in query")
-    return col
+    return col if col is not None and isinstance(col, sa.ColumnElement) else None


-def select_only_columns(query: "Select", *names: str) -> "Select":
+def select_only_columns(query: sa.Select, *names: str) -> sa.Select:
     """Returns query selecting defined columns only."""
     if not names:
         return query
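The behavioural change here is that a missing `sys__id` column is no longer an error. A small sketch of the new contract, built with plain SQLAlchemy tables (the table and the columns other than `sys__id` are made up for the example):

```python
import sqlalchemy as sa

from datachain.query.utils import get_query_id_column  # path per the file header above

metadata = sa.MetaData()
rows = sa.Table(
    "rows",  # example table, not part of datachain
    metadata,
    sa.Column("sys__id", sa.Integer),
    sa.Column("name", sa.String),
)

with_id = sa.select(rows.c.sys__id, rows.c.name)
without_id = sa.select(rows.c.name)

assert get_query_id_column(with_id) is not None  # the sys__id column element
assert get_query_id_column(without_id) is None   # previously raised RuntimeError
```

`udf_entrypoint` above relies on exactly this: it passes `id_col` into the batching callable and derives `ids_only=id_col is not None` from it.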
@@ -387,6 +387,7 @@ class StudioClient:
         files: Optional[list[str]] = None,
         python_version: Optional[str] = None,
         requirements: Optional[str] = None,
+        repository: Optional[str] = None,
     ) -> Response[JobData]:
         data = {
             "query": query,
@@ -397,6 +398,7 @@
             "files": files,
             "python_version": python_version,
             "requirements": requirements,
+            "repository": repository,
         }
         return self._send_request("datachain/job", data)

datachain/studio.py CHANGED
@@ -35,6 +35,7 @@ def process_jobs_args(args: "Namespace"):
         args.workers,
         args.files,
         args.python_version,
+        args.repository,
         args.req,
         args.req_file,
     )
@@ -256,6 +257,7 @@ def create_job(
     workers: Optional[int] = None,
     files: Optional[list[str]] = None,
     python_version: Optional[str] = None,
+    repository: Optional[str] = None,
     req: Optional[list[str]] = None,
     req_file: Optional[str] = None,
 ):
@@ -284,6 +286,7 @@
         query_name=os.path.basename(query_file),
         files=file_ids,
         python_version=python_version,
+        repository=repository,
         requirements=requirements,
     )
     if not response.ok:
datachain/utils.py CHANGED
@@ -323,6 +323,9 @@ def determine_processes(
         return True
     if parallel < 0:
         return True
+    if parallel == 1:
+        # Disable parallel processing if only one process is requested.
+        return False
     return parallel


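A hypothetical mirror of just the branches visible in this hunk (the full `determine_processes` has earlier cases the diff does not show), to make the new `parallel == 1` behaviour concrete; presumably this value ends up in the `processes` field consumed by `UDFDispatcher.run_udf` above, where `False` resolves to a single in-process run.

```python
# Hypothetical: reproduces only the branches visible in the hunk above.
from typing import Union


def visible_branches(parallel: int) -> Union[bool, int]:
    if parallel < 0:
        return True   # negative count: use the default number of workers
    if parallel == 1:
        return False  # new in 0.16.5: one process means no worker pool at all
    return parallel   # any other positive count passes through unchanged


assert visible_branches(-1) is True
assert visible_branches(1) is False
assert visible_branches(4) == 4
```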
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: datachain
-Version: 0.16.3
+Version: 0.16.5
 Summary: Wrangle unstructured AI data at scale
 Author-email: Dmitry Petrov <support@dvc.org>
 License-Expression: Apache-2.0
  License-Expression: Apache-2.0