datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff covers publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
Files changed (119)
  1. datachain/__init__.py +4 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +5 -5
  4. datachain/catalog/__init__.py +0 -2
  5. datachain/catalog/catalog.py +276 -354
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +8 -3
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +10 -17
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +42 -27
  12. datachain/cli/commands/ls.py +15 -15
  13. datachain/cli/commands/show.py +2 -2
  14. datachain/cli/parser/__init__.py +3 -43
  15. datachain/cli/parser/job.py +1 -1
  16. datachain/cli/parser/utils.py +1 -2
  17. datachain/cli/utils.py +2 -15
  18. datachain/client/azure.py +2 -2
  19. datachain/client/fsspec.py +34 -23
  20. datachain/client/gcs.py +3 -3
  21. datachain/client/http.py +157 -0
  22. datachain/client/local.py +11 -7
  23. datachain/client/s3.py +3 -3
  24. datachain/config.py +4 -8
  25. datachain/data_storage/db_engine.py +12 -6
  26. datachain/data_storage/job.py +2 -0
  27. datachain/data_storage/metastore.py +716 -137
  28. datachain/data_storage/schema.py +20 -27
  29. datachain/data_storage/serializer.py +105 -15
  30. datachain/data_storage/sqlite.py +114 -114
  31. datachain/data_storage/warehouse.py +140 -48
  32. datachain/dataset.py +109 -89
  33. datachain/delta.py +117 -42
  34. datachain/diff/__init__.py +25 -33
  35. datachain/error.py +24 -0
  36. datachain/func/aggregate.py +9 -11
  37. datachain/func/array.py +12 -12
  38. datachain/func/base.py +7 -4
  39. datachain/func/conditional.py +9 -13
  40. datachain/func/func.py +63 -45
  41. datachain/func/numeric.py +5 -7
  42. datachain/func/string.py +2 -2
  43. datachain/hash_utils.py +123 -0
  44. datachain/job.py +11 -7
  45. datachain/json.py +138 -0
  46. datachain/lib/arrow.py +18 -15
  47. datachain/lib/audio.py +60 -59
  48. datachain/lib/clip.py +14 -13
  49. datachain/lib/convert/python_to_sql.py +6 -10
  50. datachain/lib/convert/values_to_tuples.py +151 -53
  51. datachain/lib/data_model.py +23 -19
  52. datachain/lib/dataset_info.py +7 -7
  53. datachain/lib/dc/__init__.py +2 -1
  54. datachain/lib/dc/csv.py +22 -26
  55. datachain/lib/dc/database.py +37 -34
  56. datachain/lib/dc/datachain.py +518 -324
  57. datachain/lib/dc/datasets.py +38 -30
  58. datachain/lib/dc/hf.py +16 -20
  59. datachain/lib/dc/json.py +17 -18
  60. datachain/lib/dc/listings.py +5 -8
  61. datachain/lib/dc/pandas.py +3 -6
  62. datachain/lib/dc/parquet.py +33 -21
  63. datachain/lib/dc/records.py +9 -13
  64. datachain/lib/dc/storage.py +103 -65
  65. datachain/lib/dc/storage_pattern.py +251 -0
  66. datachain/lib/dc/utils.py +17 -14
  67. datachain/lib/dc/values.py +3 -6
  68. datachain/lib/file.py +187 -50
  69. datachain/lib/hf.py +7 -5
  70. datachain/lib/image.py +13 -13
  71. datachain/lib/listing.py +5 -5
  72. datachain/lib/listing_info.py +1 -2
  73. datachain/lib/meta_formats.py +2 -3
  74. datachain/lib/model_store.py +20 -8
  75. datachain/lib/namespaces.py +59 -7
  76. datachain/lib/projects.py +51 -9
  77. datachain/lib/pytorch.py +31 -23
  78. datachain/lib/settings.py +188 -85
  79. datachain/lib/signal_schema.py +302 -64
  80. datachain/lib/text.py +8 -7
  81. datachain/lib/udf.py +103 -63
  82. datachain/lib/udf_signature.py +59 -34
  83. datachain/lib/utils.py +20 -0
  84. datachain/lib/video.py +3 -4
  85. datachain/lib/webdataset.py +31 -36
  86. datachain/lib/webdataset_laion.py +15 -16
  87. datachain/listing.py +12 -5
  88. datachain/model/bbox.py +3 -1
  89. datachain/namespace.py +22 -3
  90. datachain/node.py +6 -6
  91. datachain/nodes_thread_pool.py +0 -1
  92. datachain/plugins.py +24 -0
  93. datachain/project.py +4 -4
  94. datachain/query/batch.py +10 -12
  95. datachain/query/dataset.py +376 -194
  96. datachain/query/dispatch.py +112 -84
  97. datachain/query/metrics.py +3 -4
  98. datachain/query/params.py +2 -3
  99. datachain/query/queue.py +2 -1
  100. datachain/query/schema.py +7 -6
  101. datachain/query/session.py +190 -33
  102. datachain/query/udf.py +9 -6
  103. datachain/remote/studio.py +90 -53
  104. datachain/script_meta.py +12 -12
  105. datachain/sql/sqlite/base.py +37 -25
  106. datachain/sql/sqlite/types.py +1 -1
  107. datachain/sql/types.py +36 -5
  108. datachain/studio.py +49 -40
  109. datachain/toolkit/split.py +31 -10
  110. datachain/utils.py +39 -48
  111. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
  112. datachain-0.39.0.dist-info/RECORD +173 -0
  113. datachain/cli/commands/query.py +0 -54
  114. datachain/query/utils.py +0 -36
  115. datachain-0.30.5.dist-info/RECORD +0 -168
  116. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
  117. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  118. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  119. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/query/dispatch.py CHANGED
@@ -1,20 +1,24 @@
 import contextlib
+import traceback
 from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
+from queue import Empty
 from sys import stdin
-from typing import TYPE_CHECKING, Literal, Optional
+from time import monotonic, sleep
+from typing import TYPE_CHECKING, Literal
 
 import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
-from multiprocess import get_context
+from multiprocess.context import Process
+from multiprocess.queues import Queue as MultiprocessQueue
 
 from datachain.catalog import Catalog
 from datachain.catalog.catalog import clone_catalog_with_cache
 from datachain.catalog.loader import DISTRIBUTED_IMPORT_PATH, get_udf_distributor_class
 from datachain.lib.model_store import ModelStore
-from datachain.lib.udf import _get_cache
+from datachain.lib.udf import UdfRunError, _get_cache
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
@@ -23,7 +27,6 @@ from datachain.query.dataset import (
 )
 from datachain.query.queue import get_from_queue, put_into_queue
 from datachain.query.udf import UdfInfo
-from datachain.query.utils import get_query_id_column
 from datachain.utils import batched, flatten, safe_closing
 
 if TYPE_CHECKING:
@@ -41,7 +44,7 @@ FAILED_STATUS = "FAILED"
 NOTIFY_STATUS = "NOTIFY"
 
 
-def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:
+def get_n_workers_from_arg(n_workers: int | None = None) -> int:
     if not n_workers:
         return cpu_count()
     if n_workers < 1:
@@ -55,6 +58,9 @@ def udf_entrypoint() -> int:
     udf_info: UdfInfo = load(stdin.buffer)
 
     query = udf_info["query"]
+    if "sys__id" not in query.selected_columns:
+        raise RuntimeError("sys__id column is required in UDF query")
+
     batching = udf_info["batching"]
     is_generator = udf_info["is_generator"]
 
@@ -65,15 +71,16 @@ def udf_entrypoint() -> int:
     wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"]
     warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs)
 
-    id_col = get_query_id_column(query)
-
     with contextlib.closing(
-        batching(warehouse.dataset_select_paginated, query, id_col=id_col)
+        batching(
+            warehouse.dataset_select_paginated,
+            query,
+            id_col=query.selected_columns.sys__id,
+        )
     ) as udf_inputs:
         try:
             UDFDispatcher(udf_info).run_udf(
                 udf_inputs,
-                ids_only=id_col is not None,
                 download_cb=download_cb,
                 processed_cb=processed_cb,
                 generated_cb=generated_cb,
@@ -86,20 +93,20 @@ def udf_entrypoint() -> int:
     return 0
 
 
-def udf_worker_entrypoint(fd: Optional[int] = None) -> int:
+def udf_worker_entrypoint() -> int:
     if not (udf_distributor_class := get_udf_distributor_class()):
         raise RuntimeError(
             f"{DISTRIBUTED_IMPORT_PATH} import path is required "
             "for distributed UDF processing."
         )
 
-    return udf_distributor_class.run_udf(fd)
+    return udf_distributor_class.run_udf()
 
 
 class UDFDispatcher:
-    _catalog: Optional[Catalog] = None
-    task_queue: Optional[multiprocess.Queue] = None
-    done_queue: Optional[multiprocess.Queue] = None
+    _catalog: Catalog | None = None
+    task_queue: MultiprocessQueue | None = None
+    done_queue: MultiprocessQueue | None = None
 
     def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
         self.udf_data = udf_info["udf_data"]
@@ -114,10 +121,11 @@ class UDFDispatcher:
         self.is_batching = udf_info["batching"].is_batching
         self.processes = udf_info["processes"]
         self.rows_total = udf_info["rows_total"]
+        self.batch_size = udf_info["batch_size"]
         self.buffer_size = buffer_size
         self.task_queue = None
         self.done_queue = None
-        self.ctx = get_context("spawn")
+        self.ctx = multiprocess.get_context("spawn")
 
     @property
     def catalog(self) -> "Catalog":
@@ -142,18 +150,26 @@ class UDFDispatcher:
             self.table,
             self.cache,
             self.is_batching,
+            self.batch_size,
             self.udf_fields,
         )
 
-    def _run_worker(self, ids_only: bool) -> None:
+    def _run_worker(self) -> None:
         try:
             worker = self._create_worker()
-            worker.run(ids_only)
+            worker.run()
         except (Exception, KeyboardInterrupt) as e:
             if self.done_queue:
+                # We put the exception into the done queue so the main process
+                # can handle it appropriately. We include the stacktrace to propagate
+                # it to the main process and show it to the user.
                 put_into_queue(
                     self.done_queue,
-                    {"status": FAILED_STATUS, "exception": e},
+                    {
+                        "status": FAILED_STATUS,
+                        "exception": e,
+                        "stacktrace": traceback.format_exc(),
+                    },
                 )
             if isinstance(e, KeyboardInterrupt):
                 return
@@ -162,7 +178,6 @@
     def run_udf(
         self,
         input_rows: Iterable["RowsOutput"],
-        ids_only: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
         generated_cb: Callback = DEFAULT_CALLBACK,
@@ -176,9 +191,7 @@
 
         if n_workers == 1:
             # no need to spawn worker processes if we are running in a single process
-            self.run_udf_single(
-                input_rows, ids_only, download_cb, processed_cb, generated_cb
-            )
+            self.run_udf_single(input_rows, download_cb, processed_cb, generated_cb)
         else:
            if self.buffer_size < n_workers:
                raise RuntimeError(
@@ -187,13 +200,12 @@
                 )
 
             self.run_udf_parallel(
-                n_workers, input_rows, ids_only, download_cb, processed_cb, generated_cb
+                n_workers, input_rows, download_cb, processed_cb, generated_cb
             )
 
     def run_udf_single(
         self,
         input_rows: Iterable["RowsOutput"],
-        ids_only: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
         generated_cb: Callback = DEFAULT_CALLBACK,
@@ -202,18 +214,15 @@
         # Rebuild schemas in single process too for consistency (cheap, idempotent).
         ModelStore.rebuild_all()
 
-        if ids_only and not self.is_batching:
+        if not self.is_batching:
             input_rows = flatten(input_rows)
 
         def get_inputs() -> Iterable["RowsOutput"]:
             warehouse = self.catalog.warehouse.clone()
-            if ids_only:
-                for ids in batched(input_rows, DEFAULT_BATCH_SIZE):
-                    yield from warehouse.dataset_rows_select_from_ids(
-                        self.query, ids, self.is_batching
-                    )
-            else:
-                yield from input_rows
+            for ids in batched(input_rows, DEFAULT_BATCH_SIZE):
+                yield from warehouse.dataset_rows_select_from_ids(
+                    self.query, ids, self.is_batching
+                )
 
         prefetch = udf.prefetch
         with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
@@ -232,6 +241,7 @@
                     udf_results,
                     udf,
                     cb=generated_cb,
+                    batch_size=self.batch_size,
                 )
 
     def input_batch_size(self, n_workers: int) -> int:
@@ -246,7 +256,6 @@
         self,
         n_workers: int,
         input_rows: Iterable["RowsOutput"],
-        ids_only: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
         generated_cb: Callback = DEFAULT_CALLBACK,
@@ -255,16 +264,12 @@
         self.done_queue = self.ctx.Queue()
 
         pool = [
-            self.ctx.Process(
-                name=f"Worker-UDF-{i}", target=self._run_worker, args=[ids_only]
-            )
+            self.ctx.Process(name=f"Worker-UDF-{i}", target=self._run_worker)
             for i in range(n_workers)
         ]
         for p in pool:
             p.start()
 
-        # Will be set to True if all tasks complete normally
-        normal_completion = False
        try:
             # Will be set to True when the input is exhausted
             input_finished = False
@@ -287,10 +292,20 @@
 
             # Process all tasks
             while n_workers > 0:
-                try:
-                    result = get_from_queue(self.done_queue)
-                except KeyboardInterrupt:
-                    break
+                while True:
+                    try:
+                        result = self.done_queue.get_nowait()
+                        break
+                    except Empty:
+                        for p in pool:
+                            exitcode = p.exitcode
+                            if exitcode not in (None, 0):
+                                message = (
+                                    f"Worker {p.name} exited unexpectedly with "
+                                    f"code {exitcode}"
+                                )
+                                raise RuntimeError(message) from None
+                        sleep(0.01)
 
                 if bytes_downloaded := result.get("bytes_downloaded"):
                     download_cb.relative_update(bytes_downloaded)
@@ -309,7 +324,9 @@
                 else:  # Failed / error
                     n_workers -= 1
                     if exc := result.get("exception"):
-                        raise exc
+                        if isinstance(exc, KeyboardInterrupt):
+                            raise exc
+                        raise UdfRunError(exc, stacktrace=result.get("stacktrace"))
                     raise RuntimeError("Internal error: Parallel UDF execution failed")
 
                 if status == OK_STATUS and not input_finished:
@@ -317,39 +334,50 @@
                         put_into_queue(self.task_queue, next(input_data))
                     except StopIteration:
                         input_finished = True
-
-            # Finished with all tasks normally
-            normal_completion = True
         finally:
-            if not normal_completion:
-                # Stop all workers if there is an unexpected exception
-                for _ in pool:
-                    put_into_queue(self.task_queue, STOP_SIGNAL)
-
-                # This allows workers (and this process) to exit without
-                # consuming any remaining data in the queues.
-                # (If they exit due to an exception.)
-                self.task_queue.close()
-                self.task_queue.join_thread()
-
-                # Flush all items from the done queue.
-                # This is needed if any workers are still running.
-                while n_workers > 0:
-                    result = get_from_queue(self.done_queue)
-                    status = result["status"]
-                    if status != OK_STATUS:
-                        n_workers -= 1
-
-                self.done_queue.close()
-                self.done_queue.join_thread()
+            self._shutdown_workers(pool)
+
+    def _shutdown_workers(self, pool: list[Process]) -> None:
+        self._terminate_pool(pool)
+        self._drain_queue(self.done_queue)
+        self._drain_queue(self.task_queue)
+        self._close_queue(self.done_queue)
+        self._close_queue(self.task_queue)
+
+    def _terminate_pool(self, pool: list[Process]) -> None:
+        for proc in pool:
+            if proc.is_alive():
+                proc.terminate()
+
+        deadline = monotonic() + 1.0
+        for proc in pool:
+            if not proc.is_alive():
+                continue
+            remaining = deadline - monotonic()
+            if remaining > 0:
+                proc.join(remaining)
+            if proc.is_alive():
+                proc.kill()
+                proc.join(timeout=0.2)
+
+    def _drain_queue(self, queue: MultiprocessQueue) -> None:
+        while True:
+            try:
+                queue.get_nowait()
+            except Empty:
+                return
+            except (OSError, ValueError):
+                return
 
-        # Wait for workers to stop
-        for p in pool:
-            p.join()
+    def _close_queue(self, queue: MultiprocessQueue) -> None:
+        with contextlib.suppress(OSError, ValueError):
+            queue.close()
+        with contextlib.suppress(RuntimeError, AssertionError, ValueError):
+            queue.join_thread()
 
 
 class DownloadCallback(Callback):
-    def __init__(self, queue: "multiprocess.Queue") -> None:
+    def __init__(self, queue: MultiprocessQueue) -> None:
         self.queue = queue
         super().__init__()
 
@@ -364,7 +392,7 @@ class ProcessedCallback(Callback):
     def __init__(
         self,
         name: Literal["processed", "generated"],
-        queue: "multiprocess.Queue",
+        queue: MultiprocessQueue,
     ) -> None:
         self.name = name
         self.queue = queue
@@ -379,12 +407,13 @@ class UDFWorker:
         self,
         catalog: "Catalog",
         udf: "UDFAdapter",
-        task_queue: "multiprocess.Queue",
-        done_queue: "multiprocess.Queue",
+        task_queue: MultiprocessQueue,
+        done_queue: MultiprocessQueue,
         query: "Select",
         table: "Table",
         cache: bool,
         is_batching: bool,
+        batch_size: int,
         udf_fields: Sequence[str],
     ) -> None:
         self.catalog = catalog
@@ -395,19 +424,20 @@ class UDFWorker:
         self.table = table
         self.cache = cache
         self.is_batching = is_batching
+        self.batch_size = batch_size
         self.udf_fields = udf_fields
 
         self.download_cb = DownloadCallback(self.done_queue)
         self.processed_cb = ProcessedCallback("processed", self.done_queue)
         self.generated_cb = ProcessedCallback("generated", self.done_queue)
 
-    def run(self, ids_only: bool) -> None:
+    def run(self) -> None:
         prefetch = self.udf.prefetch
         with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
             catalog = clone_catalog_with_cache(self.catalog, _cache)
             udf_results = self.udf.run(
                 self.udf_fields,
-                self.get_inputs(ids_only),
+                self.get_inputs(),
                 catalog,
                 self.cache,
                 download_cb=self.download_cb,
@@ -420,6 +450,7 @@ class UDFWorker:
                     self.notify_and_process(udf_results),
                     self.udf,
                     cb=self.generated_cb,
+                    batch_size=self.batch_size,
                 )
         put_into_queue(self.done_queue, {"status": FINISHED_STATUS})
 
@@ -428,13 +459,10 @@ class UDFWorker:
             put_into_queue(self.done_queue, {"status": OK_STATUS})
             yield row
 
-    def get_inputs(self, ids_only: bool) -> Iterable["RowsOutput"]:
+    def get_inputs(self) -> Iterable["RowsOutput"]:
         warehouse = self.catalog.warehouse.clone()
         while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
-            if ids_only:
-                for ids in batched(batch, DEFAULT_BATCH_SIZE):
-                    yield from warehouse.dataset_rows_select_from_ids(
-                        self.query, ids, self.is_batching
-                    )
-            else:
-                yield from batch
+            for ids in batched(batch, DEFAULT_BATCH_SIZE):
+                yield from warehouse.dataset_rows_select_from_ids(
+                    self.query, ids, self.is_batching
+                )
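
The dispatch.py hunks above replace the blocking get_from_queue() read of the done queue with non-blocking polling, so a worker that dies without reporting a result is detected by its exit code instead of hanging the parent process. Below is a minimal, self-contained sketch of that polling pattern; it uses the standard library's multiprocessing rather than the multiprocess package datachain itself uses, and the worker body and queue payloads are illustrative only.

import time
from multiprocessing import get_context
from queue import Empty


def worker(done_queue) -> None:
    # Illustrative worker: report a result and exit cleanly.
    done_queue.put({"status": "FINISHED"})


def wait_for_results(pool, done_queue) -> None:
    # Poll the done queue without blocking; if no result is ready,
    # check whether any worker died with a non-zero exit code.
    finished = 0
    while finished < len(pool):
        try:
            result = done_queue.get_nowait()
        except Empty:
            for p in pool:
                if p.exitcode not in (None, 0):
                    raise RuntimeError(
                        f"Worker {p.name} exited unexpectedly with code {p.exitcode}"
                    )
            time.sleep(0.01)
            continue
        if result["status"] == "FINISHED":
            finished += 1


if __name__ == "__main__":
    ctx = get_context("spawn")
    done_queue = ctx.Queue()
    pool = [ctx.Process(target=worker, args=(done_queue,)) for _ in range(2)]
    for p in pool:
        p.start()
    wait_for_results(pool, done_queue)
    for p in pool:
        p.join()
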
datachain/query/metrics.py CHANGED
@@ -1,10 +1,9 @@
 import os
-from typing import Optional, Union
 
-metrics: dict[str, Union[str, int, float, bool, None]] = {}
+metrics: dict[str, str | int | float | bool | None] = {}
 
 
-def set(key: str, value: Union[str, int, float, bool, None]) -> None:  # noqa: PYI041
+def set(key: str, value: str | int | float | bool | None) -> None:  # noqa: PYI041
     """Set a metric value."""
     if not isinstance(key, str):
         raise TypeError("Key must be a string")
@@ -21,6 +20,6 @@ def set(key: str, value: Union[str, int, float, bool, None]) -> None:  # noqa: PYI041
     metastore.update_job(job_id, metrics=metrics)
 
 
-def get(key: str) -> Optional[Union[str, int, float, bool]]:
+def get(key: str) -> str | int | float | bool | None:
     """Get a metric value."""
     return metrics[key]
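
metrics.py above, and params.py and schema.py below, share the same modernization: Optional[...] and Union[...] annotations from typing are rewritten with PEP 604 union syntax, which reads more directly and drops the typing imports. A small side-by-side illustration (not code from the package; the new spelling needs Python 3.10+ or from __future__ import annotations):

from typing import Optional, Union


def get_old(key: str) -> Optional[Union[str, int, float, bool]]:
    """typing-based spelling used before the change."""
    return None


def get_new(key: str) -> str | int | float | bool | None:
    """PEP 604 spelling used after the change."""
    return None
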
datachain/query/params.py CHANGED
@@ -1,11 +1,10 @@
 import json
 import os
-from typing import Optional
 
-params_cache: Optional[dict[str, str]] = None
+params_cache: dict[str, str] | None = None
 
 
-def param(key: str, default: Optional[str] = None) -> Optional[str]:
+def param(key: str, default: str | None = None) -> str | None:
     """Get query parameter."""
     if not isinstance(key, str):
         raise TypeError("Param key must be a string")
datachain/query/queue.py CHANGED
@@ -1,11 +1,12 @@
 import datetime
 from collections.abc import Iterable, Iterator
-from queue import Empty, Full, Queue
+from queue import Empty, Full
 from struct import pack, unpack
 from time import sleep
 from typing import Any
 
 import msgpack
+from multiprocess.queues import Queue
 
 from datachain.query.batch import RowsOutput
 
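
queue.py keeps Empty and Full from the standard queue module but now takes the Queue class itself from multiprocess.queues, matching the MultiprocessQueue annotations introduced in dispatch.py above. Its get_from_queue and put_into_queue helpers (imported by dispatch.py) are not shown in this diff, so the sketch below is only an assumption about their general shape, a retry loop around non-blocking get/put, and ignores the msgpack/struct row serialization the other imports suggest.

from queue import Empty, Full
from time import sleep

from multiprocess.queues import Queue


def put_into_queue(queue: Queue, item) -> None:
    # Retry a non-blocking put until the queue has room.
    while True:
        try:
            queue.put_nowait(item)
            return
        except Full:
            sleep(0.01)


def get_from_queue(queue: Queue):
    # Retry a non-blocking get until an item is available.
    while True:
        try:
            return queue.get_nowait()
        except Empty:
            sleep(0.01)
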
datachain/query/schema.py CHANGED
@@ -1,7 +1,8 @@
 import functools
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from fnmatch import fnmatch
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any
 
 import attrs
 import sqlalchemy as sa
@@ -42,7 +43,7 @@ class ColumnMeta(type):
 
 
 class Column(sa.ColumnClause, metaclass=ColumnMeta):
-    inherit_cache: Optional[bool] = True
+    inherit_cache: bool | None = True
 
     def __init__(self, text, type_=None, is_literal=False, _selectable=None):
         """Dataset column."""
@@ -177,7 +178,7 @@ class LocalFilename(UDFParameter):
     otherwise None will be returned.
     """
 
-    glob: Optional[str] = None
+    glob: str | None = None
 
     def get_value(
         self,
@@ -186,7 +187,7 @@ class LocalFilename(UDFParameter):
         *,
         cb: Callback = DEFAULT_CALLBACK,
         **kwargs,
-    ) -> Optional[str]:
+    ) -> str | None:
         if self.glob and not fnmatch(row["name"], self.glob):  # type: ignore[type-var]
             # If the glob pattern is specified and the row filename
             # does not match it, then return None
@@ -205,7 +206,7 @@ class LocalFilename(UDFParameter):
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
         **kwargs,
-    ) -> Optional[str]:
+    ) -> str | None:
         if self.glob and not fnmatch(row["name"], self.glob):  # type: ignore[type-var]
             # If the glob pattern is specified and the row filename
             # does not match it, then return None
@@ -216,7 +217,7 @@ class LocalFilename(UDFParameter):
         return client.cache.get_path(file)
 
 
-UDFParamSpec = Union[str, Column, UDFParameter]
+UDFParamSpec = str | Column | UDFParameter
 
 
 def normalize_param(param: UDFParamSpec) -> UDFParameter: