datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as published.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/query/dispatch.py CHANGED
@@ -1,23 +1,24 @@
 import contextlib
+import traceback
 from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
+from queue import Empty
 from sys import stdin
-from threading import Timer
-from typing import TYPE_CHECKING, Optional
+from time import monotonic, sleep
+from typing import TYPE_CHECKING, Literal

-import attrs
 import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
-from multiprocess import get_context
-from sqlalchemy.sql import func
+from multiprocess.context import Process
+from multiprocess.queues import Queue as MultiprocessQueue

 from datachain.catalog import Catalog
 from datachain.catalog.catalog import clone_catalog_with_cache
-from datachain.catalog.loader import get_distributed_class
-from datachain.lib.udf import _get_cache
-from datachain.query.batch import RowsOutput, RowsOutputBatch
+from datachain.catalog.loader import DISTRIBUTED_IMPORT_PATH, get_udf_distributor_class
+from datachain.lib.model_store import ModelStore
+from datachain.lib.udf import UdfRunError, _get_cache
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
@@ -26,7 +27,6 @@ from datachain.query.dataset import (
 )
 from datachain.query.queue import get_from_queue, put_into_queue
 from datachain.query.udf import UdfInfo
-from datachain.query.utils import get_query_id_column
 from datachain.utils import batched, flatten, safe_closing

 if TYPE_CHECKING:
@@ -34,6 +34,7 @@ if TYPE_CHECKING:

     from datachain.data_storage import AbstractMetastore, AbstractWarehouse
     from datachain.lib.udf import UDFAdapter
+    from datachain.query.batch import RowsOutput

 DEFAULT_BATCH_SIZE = 10000
 STOP_SIGNAL = "STOP"
@@ -43,7 +44,7 @@ FAILED_STATUS = "FAILED"
 NOTIFY_STATUS = "NOTIFY"


-def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:
+def get_n_workers_from_arg(n_workers: int | None = None) -> int:
     if not n_workers:
         return cpu_count()
     if n_workers < 1:
@@ -52,55 +53,60 @@ def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:


 def udf_entrypoint() -> int:
+    """Parallel processing (faster for more CPU-heavy UDFs)."""
     # Load UDF info from stdin
     udf_info: UdfInfo = load(stdin.buffer)

-    # Parallel processing (faster for more CPU-heavy UDFs)
-    dispatch = UDFDispatcher(udf_info)
-
     query = udf_info["query"]
+    if "sys__id" not in query.selected_columns:
+        raise RuntimeError("sys__id column is required in UDF query")
+
     batching = udf_info["batching"]
-    n_workers = udf_info["processes"]
-    if n_workers is True:
-        n_workers = None  # Use default number of CPUs (cores)
+    is_generator = udf_info["is_generator"]
+
+    download_cb = get_download_callback()
+    processed_cb = get_processed_callback()
+    generated_cb = get_generated_callback(is_generator)

     wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"]
     warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs)

-    total_rows = next(
-        warehouse.db.execute(
-            query.with_only_columns(func.count(query.c.sys__id)).order_by(None)
-        )
-    )[0]
-
     with contextlib.closing(
-        batching(warehouse.dataset_select_paginated, query, ids_only=True)
+        batching(
+            warehouse.dataset_select_paginated,
+            query,
+            id_col=query.selected_columns.sys__id,
+        )
     ) as udf_inputs:
-        download_cb = get_download_callback()
-        processed_cb = get_processed_callback()
         try:
-            dispatch.run_udf_parallel(
+            UDFDispatcher(udf_info).run_udf(
                 udf_inputs,
-                total_rows=total_rows,
-                n_workers=n_workers,
-                processed_cb=processed_cb,
                 download_cb=download_cb,
+                processed_cb=processed_cb,
+                generated_cb=generated_cb,
             )
         finally:
             download_cb.close()
             processed_cb.close()
+            generated_cb.close()

     return 0


 def udf_worker_entrypoint() -> int:
-    return get_distributed_class().run_worker()
+    if not (udf_distributor_class := get_udf_distributor_class()):
+        raise RuntimeError(
+            f"{DISTRIBUTED_IMPORT_PATH} import path is required "
+            "for distributed UDF processing."
+        )
+
+    return udf_distributor_class.run_udf()


 class UDFDispatcher:
-    catalog: Optional[Catalog] = None
-    task_queue: Optional[multiprocess.Queue] = None
-    done_queue: Optional[multiprocess.Queue] = None
+    _catalog: Catalog | None = None
+    task_queue: MultiprocessQueue | None = None
+    done_queue: MultiprocessQueue | None = None

     def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
         self.udf_data = udf_info["udf_data"]
@@ -113,30 +119,38 @@ class UDFDispatcher:
         self.cache = udf_info["cache"]
         self.is_generator = udf_info["is_generator"]
         self.is_batching = udf_info["batching"].is_batching
+        self.processes = udf_info["processes"]
+        self.rows_total = udf_info["rows_total"]
+        self.batch_size = udf_info["batch_size"]
         self.buffer_size = buffer_size
-        self.catalog = None
         self.task_queue = None
         self.done_queue = None
-        self.ctx = get_context("spawn")
+        self.ctx = multiprocess.get_context("spawn")

-    def _create_worker(self) -> "UDFWorker":
-        if not self.catalog:
+    @property
+    def catalog(self) -> "Catalog":
+        if not self._catalog:
             ms_cls, ms_args, ms_kwargs = self.metastore_clone_params
             metastore: AbstractMetastore = ms_cls(*ms_args, **ms_kwargs)
             ws_cls, ws_args, ws_kwargs = self.warehouse_clone_params
             warehouse: AbstractWarehouse = ws_cls(*ws_args, **ws_kwargs)
-            self.catalog = Catalog(metastore, warehouse, **self.catalog_init_params)
-        self.udf = loads(self.udf_data)
+            self._catalog = Catalog(metastore, warehouse, **self.catalog_init_params)
+        return self._catalog
+
+    def _create_worker(self) -> "UDFWorker":
+        udf: UDFAdapter = loads(self.udf_data)
+        # Ensure all registered DataModels have rebuilt schemas in worker processes.
+        ModelStore.rebuild_all()
         return UDFWorker(
             self.catalog,
-            self.udf,
+            udf,
             self.task_queue,
             self.done_queue,
             self.query,
             self.table,
-            self.is_generator,
-            self.is_batching,
             self.cache,
+            self.is_batching,
+            self.batch_size,
             self.udf_fields,
         )

@@ -146,45 +160,109 @@ class UDFDispatcher:
             worker.run()
         except (Exception, KeyboardInterrupt) as e:
             if self.done_queue:
+                # We put the exception into the done queue so the main process
+                # can handle it appropriately. We include the stacktrace to propagate
+                # it to the main process and show it to the user.
                 put_into_queue(
                     self.done_queue,
-                    {"status": FAILED_STATUS, "exception": e},
+                    {
+                        "status": FAILED_STATUS,
+                        "exception": e,
+                        "stacktrace": traceback.format_exc(),
+                    },
                 )
+            if isinstance(e, KeyboardInterrupt):
+                return
             raise

-    @staticmethod
-    def send_stop_signal_to_workers(task_queue, n_workers: Optional[int] = None):
+    def run_udf(
+        self,
+        input_rows: Iterable["RowsOutput"],
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+        generated_cb: Callback = DEFAULT_CALLBACK,
+    ) -> None:
+        n_workers = self.processes
+        if n_workers is True:
+            n_workers = None  # Use default number of CPUs (cores)
+        elif not n_workers or n_workers < 1:
+            n_workers = 1  # Single-threaded (on this worker)
         n_workers = get_n_workers_from_arg(n_workers)
-        for _ in range(n_workers):
-            put_into_queue(task_queue, STOP_SIGNAL)

-    def create_input_queue(self):
-        return self.ctx.Queue()
+        if n_workers == 1:
+            # no need to spawn worker processes if we are running in a single process
+            self.run_udf_single(input_rows, download_cb, processed_cb, generated_cb)
+        else:
+            if self.buffer_size < n_workers:
+                raise RuntimeError(
+                    "Parallel run error: buffer size is smaller than "
+                    f"number of workers: {self.buffer_size} < {n_workers}"
+                )

-    def run_udf_parallel(  # noqa: C901, PLR0912
+            self.run_udf_parallel(
+                n_workers, input_rows, download_cb, processed_cb, generated_cb
+            )
+
+    def run_udf_single(
         self,
-        input_rows: Iterable[RowsOutput],
-        total_rows: int,
-        n_workers: Optional[int] = None,
-        processed_cb: Callback = DEFAULT_CALLBACK,
+        input_rows: Iterable["RowsOutput"],
         download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+        generated_cb: Callback = DEFAULT_CALLBACK,
     ) -> None:
-        n_workers = get_n_workers_from_arg(n_workers)
+        udf: UDFAdapter = loads(self.udf_data)
+        # Rebuild schemas in single process too for consistency (cheap, idempotent).
+        ModelStore.rebuild_all()
+
+        if not self.is_batching:
+            input_rows = flatten(input_rows)
+
+        def get_inputs() -> Iterable["RowsOutput"]:
+            warehouse = self.catalog.warehouse.clone()
+            for ids in batched(input_rows, DEFAULT_BATCH_SIZE):
+                yield from warehouse.dataset_rows_select_from_ids(
+                    self.query, ids, self.is_batching
+                )

-        input_batch_size = total_rows // n_workers
+        prefetch = udf.prefetch
+        with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
+            udf_results = udf.run(
+                self.udf_fields,
+                get_inputs(),
+                self.catalog,
+                self.cache,
+                download_cb=download_cb,
+                processed_cb=processed_cb,
+            )
+            with safe_closing(udf_results):
+                process_udf_outputs(
+                    self.catalog.warehouse.clone(),
+                    self.table,
+                    udf_results,
+                    udf,
+                    cb=generated_cb,
+                    batch_size=self.batch_size,
+                )
+
+    def input_batch_size(self, n_workers: int) -> int:
+        input_batch_size = self.rows_total // n_workers
         if input_batch_size == 0:
             input_batch_size = 1
         elif input_batch_size > DEFAULT_BATCH_SIZE:
             input_batch_size = DEFAULT_BATCH_SIZE
+        return input_batch_size

-        if self.buffer_size < n_workers:
-            raise RuntimeError(
-                "Parallel run error: buffer size is smaller than "
-                f"number of workers: {self.buffer_size} < {n_workers}"
-            )
-
+    def run_udf_parallel(  # noqa: C901, PLR0912
+        self,
+        n_workers: int,
+        input_rows: Iterable["RowsOutput"],
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+        generated_cb: Callback = DEFAULT_CALLBACK,
+    ) -> None:
         self.task_queue = self.ctx.Queue()
         self.done_queue = self.ctx.Queue()
+
         pool = [
             self.ctx.Process(name=f"Worker-UDF-{i}", target=self._run_worker)
             for i in range(n_workers)
@@ -192,14 +270,14 @@ class UDFDispatcher:
         for p in pool:
             p.start()

-        # Will be set to True if all tasks complete normally
-        normal_completion = False
         try:
             # Will be set to True when the input is exhausted
             input_finished = False

-            if not self.is_batching:
-                input_rows = batched(flatten(input_rows), input_batch_size)
+            input_rows = batched(
+                input_rows if self.is_batching else flatten(input_rows),
+                self.input_batch_size(n_workers),
+            )

             # Stop all workers after the input rows have finished processing
             input_data = chain(input_rows, [STOP_SIGNAL] * n_workers)
@@ -214,12 +292,29 @@

             # Process all tasks
             while n_workers > 0:
-                result = get_from_queue(self.done_queue)
-
+                while True:
+                    try:
+                        result = self.done_queue.get_nowait()
+                        break
+                    except Empty:
+                        for p in pool:
+                            exitcode = p.exitcode
+                            if exitcode not in (None, 0):
+                                message = (
+                                    f"Worker {p.name} exited unexpectedly with "
+                                    f"code {exitcode}"
+                                )
+                                raise RuntimeError(message) from None
+                        sleep(0.01)
+
+                if bytes_downloaded := result.get("bytes_downloaded"):
+                    download_cb.relative_update(bytes_downloaded)
                 if downloaded := result.get("downloaded"):
-                    download_cb.relative_update(downloaded)
+                    download_cb.increment_file_count(downloaded)
                 if processed := result.get("processed"):
                     processed_cb.relative_update(processed)
+                if generated := result.get("generated"):
+                    generated_cb.relative_update(generated)

                 status = result["status"]
                 if status in (OK_STATUS, NOTIFY_STATUS):
@@ -229,7 +324,9 @@
                 else:  # Failed / error
                     n_workers -= 1
                     if exc := result.get("exception"):
-                        raise exc
+                        if isinstance(exc, KeyboardInterrupt):
+                            raise exc
+                        raise UdfRunError(exc, stacktrace=result.get("stacktrace"))
                     raise RuntimeError("Internal error: Parallel UDF execution failed")

                 if status == OK_STATUS and not input_finished:
@@ -237,75 +334,104 @@
                         put_into_queue(self.task_queue, next(input_data))
                     except StopIteration:
                         input_finished = True
-
-            # Finished with all tasks normally
-            normal_completion = True
         finally:
-            if not normal_completion:
-                # Stop all workers if there is an unexpected exception
-                for _ in pool:
-                    put_into_queue(self.task_queue, STOP_SIGNAL)
-                self.task_queue.close()
-
-                # This allows workers (and this process) to exit without
-                # consuming any remaining data in the queues.
-                # (If they exit due to an exception.)
-                self.task_queue.cancel_join_thread()
-                self.done_queue.cancel_join_thread()
-
-                # Flush all items from the done queue.
-                # This is needed if any workers are still running.
-                while n_workers > 0:
-                    result = get_from_queue(self.done_queue)
-                    status = result["status"]
-                    if status != OK_STATUS:
-                        n_workers -= 1
-
-            # Wait for workers to stop
-            for p in pool:
-                p.join()
-
-
-class WorkerCallback(Callback):
-    def __init__(self, queue: "multiprocess.Queue"):
+            self._shutdown_workers(pool)
+
+    def _shutdown_workers(self, pool: list[Process]) -> None:
+        self._terminate_pool(pool)
+        self._drain_queue(self.done_queue)
+        self._drain_queue(self.task_queue)
+        self._close_queue(self.done_queue)
+        self._close_queue(self.task_queue)
+
+    def _terminate_pool(self, pool: list[Process]) -> None:
+        for proc in pool:
+            if proc.is_alive():
+                proc.terminate()
+
+        deadline = monotonic() + 1.0
+        for proc in pool:
+            if not proc.is_alive():
+                continue
+            remaining = deadline - monotonic()
+            if remaining > 0:
+                proc.join(remaining)
+            if proc.is_alive():
+                proc.kill()
+                proc.join(timeout=0.2)
+
+    def _drain_queue(self, queue: MultiprocessQueue) -> None:
+        while True:
+            try:
+                queue.get_nowait()
+            except Empty:
+                return
+            except (OSError, ValueError):
+                return
+
+    def _close_queue(self, queue: MultiprocessQueue) -> None:
+        with contextlib.suppress(OSError, ValueError):
+            queue.close()
+        with contextlib.suppress(RuntimeError, AssertionError, ValueError):
+            queue.join_thread()
+
+
+class DownloadCallback(Callback):
+    def __init__(self, queue: MultiprocessQueue) -> None:
         self.queue = queue
         super().__init__()

     def relative_update(self, inc: int = 1) -> None:
+        put_into_queue(self.queue, {"status": NOTIFY_STATUS, "bytes_downloaded": inc})
+
+    def increment_file_count(self, inc: int = 1) -> None:
         put_into_queue(self.queue, {"status": NOTIFY_STATUS, "downloaded": inc})


 class ProcessedCallback(Callback):
-    def __init__(self):
-        self.processed_rows: Optional[int] = None
+    def __init__(
+        self,
+        name: Literal["processed", "generated"],
+        queue: MultiprocessQueue,
+    ) -> None:
+        self.name = name
+        self.queue = queue
         super().__init__()

     def relative_update(self, inc: int = 1) -> None:
-        self.processed_rows = inc
+        put_into_queue(self.queue, {"status": NOTIFY_STATUS, self.name: inc})


-@attrs.define
 class UDFWorker:
-    catalog: "Catalog"
-    udf: "UDFAdapter"
-    task_queue: "multiprocess.Queue"
-    done_queue: "multiprocess.Queue"
-    query: "Select"
-    table: "Table"
-    is_generator: bool
-    is_batching: bool
-    cache: bool
-    udf_fields: Sequence[str]
-    cb: Callback = attrs.field()
-
-    @cb.default
-    def _default_callback(self) -> WorkerCallback:
-        return WorkerCallback(self.done_queue)
+    def __init__(
+        self,
+        catalog: "Catalog",
+        udf: "UDFAdapter",
+        task_queue: MultiprocessQueue,
+        done_queue: MultiprocessQueue,
+        query: "Select",
+        table: "Table",
+        cache: bool,
+        is_batching: bool,
+        batch_size: int,
+        udf_fields: Sequence[str],
+    ) -> None:
+        self.catalog = catalog
+        self.udf = udf
+        self.task_queue = task_queue
+        self.done_queue = done_queue
+        self.query = query
+        self.table = table
+        self.cache = cache
+        self.is_batching = is_batching
+        self.batch_size = batch_size
+        self.udf_fields = udf_fields
+
+        self.download_cb = DownloadCallback(self.done_queue)
+        self.processed_cb = ProcessedCallback("processed", self.done_queue)
+        self.generated_cb = ProcessedCallback("generated", self.done_queue)

     def run(self) -> None:
-        processed_cb = ProcessedCallback()
-        generated_cb = get_generated_callback(self.is_generator)
-
         prefetch = self.udf.prefetch
         with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
             catalog = clone_catalog_with_cache(self.catalog, _cache)
@@ -314,48 +440,29 @@ class UDFWorker:
                 self.get_inputs(),
                 catalog,
                 self.cache,
-                download_cb=self.cb,
-                processed_cb=processed_cb,
+                download_cb=self.download_cb,
+                processed_cb=self.processed_cb,
             )
             with safe_closing(udf_results):
                 process_udf_outputs(
                     catalog.warehouse,
                     self.table,
-                    self.notify_and_process(udf_results, processed_cb),
+                    self.notify_and_process(udf_results),
                     self.udf,
-                    cb=generated_cb,
+                    cb=self.generated_cb,
+                    batch_size=self.batch_size,
                 )
+        put_into_queue(self.done_queue, {"status": FINISHED_STATUS})

-        put_into_queue(
-            self.done_queue,
-            {"status": FINISHED_STATUS, "processed": processed_cb.processed_rows},
-        )
-
-    def notify_and_process(self, udf_results, processed_cb):
+    def notify_and_process(self, udf_results):
         for row in udf_results:
-            put_into_queue(
-                self.done_queue,
-                {"status": OK_STATUS, "processed": processed_cb.processed_rows},
-            )
+            put_into_queue(self.done_queue, {"status": OK_STATUS})
             yield row

-    def get_inputs(self):
+    def get_inputs(self) -> Iterable["RowsOutput"]:
         warehouse = self.catalog.warehouse.clone()
-        col_id = get_query_id_column(self.query)
-
-        if self.is_batching:
-            while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
-                ids = [row[0] for row in batch.rows]
-                rows = warehouse.dataset_rows_select(self.query.where(col_id.in_(ids)))
-                yield RowsOutputBatch(list(rows))
-        else:
-            while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
-                yield from warehouse.dataset_rows_select(
-                    self.query.where(col_id.in_(batch))
+        while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+            for ids in batched(batch, DEFAULT_BATCH_SIZE):
+                yield from warehouse.dataset_rows_select_from_ids(
+                    self.query, ids, self.is_batching
                 )
-
-
-class RepeatTimer(Timer):
-    def run(self):
-        while not self.finished.wait(self.interval):
-            self.function(*self.args, **self.kwargs)
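Note on the new dispatch loop above: run_udf_parallel no longer blocks indefinitely on done_queue.get(); it polls the queue with get_nowait() and checks worker exit codes between attempts, so a worker that dies without reporting back surfaces as an error instead of a hang. A minimal, self-contained sketch of that pattern (illustrative only, not datachain code; it uses the stdlib multiprocessing in place of the multiprocess package, and wait_for_result is a hypothetical helper):

from multiprocessing import Process
from queue import Empty
from time import sleep


def wait_for_result(done_queue, pool: list[Process]) -> dict:
    # Poll the results queue; if any worker died with a non-zero exit code,
    # raise instead of waiting forever for a message that will never arrive.
    while True:
        try:
            return done_queue.get_nowait()
        except Empty:
            for p in pool:
                if p.exitcode not in (None, 0):
                    raise RuntimeError(
                        f"Worker {p.name} exited unexpectedly with code {p.exitcode}"
                    )
            sleep(0.01)  # brief pause to avoid a busy loop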
datachain/query/metrics.py CHANGED
@@ -1,10 +1,9 @@
 import os
-from typing import Optional, Union

-metrics: dict[str, Union[str, int, float, bool, None]] = {}
+metrics: dict[str, str | int | float | bool | None] = {}


-def set(key: str, value: Union[str, int, float, bool, None]) -> None:  # noqa: PYI041
+def set(key: str, value: str | int | float | bool | None) -> None:  # noqa: PYI041
     """Set a metric value."""
     if not isinstance(key, str):
         raise TypeError("Key must be a string")
@@ -15,13 +14,12 @@ def set(key: str, value: Union[str, int, float, bool, None]) -> None:  # noqa: P
     metrics[key] = value

     if job_id := os.getenv("DATACHAIN_JOB_ID"):
-        from datachain.data_storage.job import JobStatus
         from datachain.query.session import Session

         metastore = Session.get().catalog.metastore
-        metastore.set_job_status(job_id, JobStatus.RUNNING, metrics=metrics)
+        metastore.update_job(job_id, metrics=metrics)


-def get(key: str) -> Optional[Union[str, int, float, bool]]:
+def get(key: str) -> str | int | float | bool | None:
     """Get a metric value."""
     return metrics[key]
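The metrics helpers keep the same public surface (set()/get()); only the typing style and the metastore call change. A small usage sketch (an assumption for illustration: the module is imported from the path shown above, and the metric name is invented):

from datachain.query import metrics

# Stored in the in-memory dict; when DATACHAIN_JOB_ID is set, the set() call
# above also persists the metrics to the job record via metastore.update_job.
metrics.set("rows_processed", 1000)
assert metrics.get("rows_processed") == 1000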
datachain/query/params.py CHANGED
@@ -1,11 +1,10 @@
 import json
 import os
-from typing import Optional

-params_cache: Optional[dict[str, str]] = None
+params_cache: dict[str, str] | None = None


-def param(key: str, default: Optional[str] = None) -> Optional[str]:
+def param(key: str, default: str | None = None) -> str | None:
     """Get query parameter."""
     if not isinstance(key, str):
         raise TypeError("Param key must be a string")
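Likewise, param() keeps its behavior and only modernizes the annotations. An illustrative sketch (the import path follows the file location above; the parameter name "split" is invented for the example):

from datachain.query.params import param

# Returns the value supplied to the job for "split", or the default when unset.
split = param("split", default="train")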