datachain 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, exactly as they appear in their public registries, and is provided for informational purposes only.

Files changed (50)
  1. datachain/asyn.py +16 -6
  2. datachain/cache.py +32 -10
  3. datachain/catalog/catalog.py +17 -1
  4. datachain/cli/__init__.py +311 -0
  5. datachain/cli/commands/__init__.py +29 -0
  6. datachain/cli/commands/datasets.py +129 -0
  7. datachain/cli/commands/du.py +14 -0
  8. datachain/cli/commands/index.py +12 -0
  9. datachain/cli/commands/ls.py +169 -0
  10. datachain/cli/commands/misc.py +28 -0
  11. datachain/cli/commands/query.py +53 -0
  12. datachain/cli/commands/show.py +38 -0
  13. datachain/cli/parser/__init__.py +547 -0
  14. datachain/cli/parser/job.py +120 -0
  15. datachain/cli/parser/studio.py +126 -0
  16. datachain/cli/parser/utils.py +63 -0
  17. datachain/{cli_utils.py → cli/utils.py} +27 -1
  18. datachain/client/azure.py +6 -2
  19. datachain/client/fsspec.py +9 -3
  20. datachain/client/gcs.py +6 -2
  21. datachain/client/s3.py +16 -1
  22. datachain/data_storage/db_engine.py +9 -0
  23. datachain/data_storage/schema.py +4 -10
  24. datachain/data_storage/sqlite.py +7 -1
  25. datachain/data_storage/warehouse.py +6 -4
  26. datachain/{lib/diff.py → diff/__init__.py} +116 -12
  27. datachain/func/__init__.py +3 -2
  28. datachain/func/conditional.py +74 -0
  29. datachain/func/func.py +5 -1
  30. datachain/lib/arrow.py +7 -1
  31. datachain/lib/dc.py +8 -3
  32. datachain/lib/file.py +16 -5
  33. datachain/lib/hf.py +1 -1
  34. datachain/lib/listing.py +19 -1
  35. datachain/lib/pytorch.py +57 -13
  36. datachain/lib/signal_schema.py +89 -27
  37. datachain/lib/udf.py +82 -40
  38. datachain/listing.py +1 -0
  39. datachain/progress.py +20 -3
  40. datachain/query/dataset.py +122 -93
  41. datachain/query/dispatch.py +22 -16
  42. datachain/studio.py +58 -38
  43. datachain/utils.py +14 -3
  44. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/METADATA +9 -9
  45. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/RECORD +49 -37
  46. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/WHEEL +1 -1
  47. datachain/cli.py +0 -1475
  48. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/LICENSE +0 -0
  49. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/entry_points.txt +0 -0
  50. {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/top_level.txt +0 -0
datachain/lib/udf.py CHANGED
@@ -1,14 +1,16 @@
-import contextlib
 import sys
 import traceback
-from collections.abc import Iterable, Iterator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
+from contextlib import closing, nullcontext
+from functools import partial
+from typing import TYPE_CHECKING, Any, Optional, TypeVar

 import attrs
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from pydantic import BaseModel

 from datachain.asyn import AsyncMapper
+from datachain.cache import temporary_cache
 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
 from datachain.lib.data_model import DataValue
@@ -21,17 +23,22 @@ from datachain.query.batch import (
     Partition,
     RowsOutputBatch,
 )
+from datachain.utils import safe_closing

 if TYPE_CHECKING:
     from collections import abc
+    from contextlib import AbstractContextManager

     from typing_extensions import Self

+    from datachain.cache import DataChainCache as Cache
     from datachain.catalog import Catalog
     from datachain.lib.signal_schema import SignalSchema
     from datachain.lib.udf_signature import UdfSignature
     from datachain.query.batch import RowsOutput

+T = TypeVar("T", bound=Sequence[Any])
+

 class UdfError(DataChainParamsError):
     def __init__(self, msg):
@@ -98,6 +105,10 @@ class UDFAdapter:
             processed_cb,
         )

+    @property
+    def prefetch(self) -> int:
+        return self.inner.prefetch
+

 class UDFBase(AbstractUDF):
     """Base class for stateful user-defined functions.
@@ -148,12 +159,11 @@ class UDFBase(AbstractUDF):
     """

     is_output_batched = False
-    catalog: "Optional[Catalog]"
+    prefetch: int = 0

    def __init__(self):
        self.params: Optional[SignalSchema] = None
        self.output = None
-        self.catalog = None
        self._func = None

    def process(self, *args, **kwargs):
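UDFBase now carries a class-level prefetch knob instead of holding on to the catalog, and the UDFAdapter.prefetch property added above simply forwards to it. A hedged sketch of a subclass that opts in; how user-facing settings feed this attribute is not shown in this diff, so the class below is purely illustrative:

# Hypothetical Mapper subclass; only the `prefetch` attribute is taken from this diff.
from datachain.lib.file import File
from datachain.lib.udf import Mapper

class FileSize(Mapper):
    prefetch = 4  # keep up to 4 input files downloading ahead of process()

    def process(self, file: File) -> int:
        # with prefetch enabled, the file is typically already in the (temporary) cache here
        return len(file.read())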
@@ -242,26 +252,23 @@
         return flatten(obj) if isinstance(obj, BaseModel) else [obj]

     def _parse_row(
-        self, row_dict: RowDict, cache: bool, download_cb: Callback
+        self, row_dict: RowDict, catalog: "Catalog", cache: bool, download_cb: Callback
     ) -> list[DataValue]:
         assert self.params
         row = [row_dict[p] for p in self.params.to_udf_spec()]
         obj_row = self.params.row_to_objs(row)
         for obj in obj_row:
             if isinstance(obj, File):
-                assert self.catalog is not None
-                obj._set_stream(
-                    self.catalog, caching_enabled=cache, download_cb=download_cb
-                )
+                obj._set_stream(catalog, caching_enabled=cache, download_cb=download_cb)
         return obj_row

-    def _prepare_row(self, row, udf_fields, cache, download_cb):
+    def _prepare_row(self, row, udf_fields, catalog, cache, download_cb):
         row_dict = RowDict(zip(udf_fields, row))
-        return self._parse_row(row_dict, cache, download_cb)
+        return self._parse_row(row_dict, catalog, cache, download_cb)

-    def _prepare_row_and_id(self, row, udf_fields, cache, download_cb):
+    def _prepare_row_and_id(self, row, udf_fields, catalog, cache, download_cb):
         row_dict = RowDict(zip(udf_fields, row))
-        udf_input = self._parse_row(row_dict, cache, download_cb)
+        udf_input = self._parse_row(row_dict, catalog, cache, download_cb)
         return row_dict["sys__id"], *udf_input

     def process_safe(self, obj_rows):
@@ -279,13 +286,47 @@
         return result_objs


-async def _prefetch_input(row):
+def noop(*args, **kwargs):
+    pass
+
+
+async def _prefetch_input(
+    row: T,
+    download_cb: Optional["Callback"] = None,
+    after_prefetch: "Callable[[], None]" = noop,
+) -> T:
     for obj in row:
-        if isinstance(obj, File):
-            await obj._prefetch()
+        if isinstance(obj, File) and await obj._prefetch(download_cb):
+            after_prefetch()
     return row


+def _prefetch_inputs(
+    prepared_inputs: "Iterable[T]",
+    prefetch: int = 0,
+    download_cb: Optional["Callback"] = None,
+    after_prefetch: "Callable[[], None]" = noop,
+) -> "abc.Generator[T, None, None]":
+    if prefetch > 0:
+        f = partial(
+            _prefetch_input,
+            download_cb=download_cb,
+            after_prefetch=after_prefetch,
+        )
+        prepared_inputs = AsyncMapper(f, prepared_inputs, workers=prefetch).iterate()  # type: ignore[assignment]
+    yield from prepared_inputs
+
+
+def _get_cache(
+    cache: "Cache", prefetch: int = 0, use_cache: bool = False
+) -> "AbstractContextManager[Cache]":
+    tmp_dir = cache.tmp_dir
+    assert tmp_dir
+    if prefetch and not use_cache:
+        return temporary_cache(tmp_dir, prefix="prefetch-")
+    return nullcontext(cache)
+
+
 class Mapper(UDFBase):
     """Inherit from this class to pass to `DataChain.map()`."""
@@ -300,18 +341,18 @@ class Mapper(UDFBase):
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()
-        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
-            self._prepare_row_and_id(row, udf_fields, cache, download_cb)
-            for row in udf_inputs
-        )
-        if self.prefetch > 0:
-            prepared_inputs = AsyncMapper(
-                _prefetch_input, prepared_inputs, workers=self.prefetch
-            ).iterate()

-        with contextlib.closing(prepared_inputs):
+        def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]":
+            with safe_closing(udf_inputs):
+                for row in udf_inputs:
+                    yield self._prepare_row_and_id(
+                        row, udf_fields, catalog, cache, download_cb
+                    )
+
+        prepared_inputs = _prepare_rows(udf_inputs)
+        prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+        with closing(prepared_inputs):
             for id_, *udf_args in prepared_inputs:
                 result_objs = self.process_safe(udf_args)
                 udf_output = self._flatten_row(result_objs)
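Both Mapper.run and Generator.run now wrap the upstream iterator in a local _prepare_rows generator guarded by safe_closing, so closing prepared_inputs (the closing(...) block above) propagates through the prefetcher down to the paginated warehouse iterator even when the consumer stops early. safe_closing itself lives in datachain/utils.py (changed in this release but not shown here); a presumed minimal equivalent, offered only as a sketch:

# Presumed behavior only; the real implementation in datachain/utils.py may differ.
from contextlib import contextmanager
from typing import Any, Iterator

@contextmanager
def safe_closing(thing: Any) -> Iterator[Any]:
    try:
        yield thing
    finally:
        # close the wrapped object on exit, but only if it exposes close()
        if hasattr(thing, "close"):
            thing.close()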
@@ -336,14 +377,15 @@ class BatchMapper(UDFBase):
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()

         for batch in udf_inputs:
             n_rows = len(batch.rows)
             row_ids, *udf_args = zip(
                 *[
-                    self._prepare_row_and_id(row, udf_fields, cache, download_cb)
+                    self._prepare_row_and_id(
+                        row, udf_fields, catalog, cache, download_cb
+                    )
                     for row in batch.rows
                 ]
             )
@@ -378,17 +420,18 @@ class Generator(UDFBase):
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()
-        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
-            self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
-        )
-        if self.prefetch > 0:
-            prepared_inputs = AsyncMapper(
-                _prefetch_input, prepared_inputs, workers=self.prefetch
-            ).iterate()

-        with contextlib.closing(prepared_inputs):
+        def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]":
+            with safe_closing(udf_inputs):
+                for row in udf_inputs:
+                    yield self._prepare_row(
+                        row, udf_fields, catalog, cache, download_cb
+                    )
+
+        prepared_inputs = _prepare_rows(udf_inputs)
+        prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+        with closing(prepared_inputs):
             for row in prepared_inputs:
                 result_objs = self.process_safe(row)
                 udf_outputs = (self._flatten_row(row) for row in result_objs)
@@ -413,13 +456,12 @@ class Aggregator(UDFBase):
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()

         for batch in udf_inputs:
             udf_args = zip(
                 *[
-                    self._prepare_row(row, udf_fields, cache, download_cb)
+                    self._prepare_row(row, udf_fields, catalog, cache, download_cb)
                     for row in batch.rows
                 ]
             )
datachain/listing.py CHANGED
@@ -153,6 +153,7 @@ class Listing:
             unit_scale=True,
             unit_divisor=1000,
             total=total_files,
+            leave=False,
         )

         counter = 0
datachain/progress.py CHANGED
@@ -5,6 +5,7 @@ import sys
 from threading import RLock
 from typing import Any, ClassVar

+from fsspec import Callback
 from fsspec.callbacks import TqdmCallback
 from tqdm import tqdm

@@ -61,7 +62,7 @@ class Tqdm(tqdm):
        disable : If (default: None) or False,
            will be determined by logging level.
            May be overridden to `True` due to non-TTY status.
-            Skip override by specifying env var `DVC_IGNORE_ISATTY`.
+            Skip override by specifying env var `DATACHAIN_IGNORE_ISATTY`.
        kwargs : anything accepted by `tqdm.tqdm()`
        """
        kwargs = kwargs.copy()
@@ -77,7 +78,7 @@ class Tqdm(tqdm):
         # auto-disable based on TTY
         if (
             not disable
-            and not env2bool("DVC_IGNORE_ISATTY")
+            and not env2bool("DATACHAIN_IGNORE_ISATTY")
             and hasattr(file, "isatty")
         ):
             disable = not file.isatty()
@@ -132,8 +133,24 @@ class Tqdm(tqdm):
         return d


-class CombinedDownloadCallback(TqdmCallback):
+class CombinedDownloadCallback(Callback):
     def set_size(self, size):
         # This is a no-op to prevent fsspec's .get_file() from setting the combined
         # download size to the size of the current file.
         pass
+
+    def increment_file_count(self, n: int = 1) -> None:
+        pass
+
+
+class TqdmCombinedDownloadCallback(CombinedDownloadCallback, TqdmCallback):
+    def __init__(self, tqdm_kwargs=None, *args, **kwargs):
+        self.files_count = 0
+        tqdm_kwargs = tqdm_kwargs or {}
+        tqdm_kwargs.setdefault("postfix", {}).setdefault("files", self.files_count)
+        super().__init__(tqdm_kwargs, *args, **kwargs)
+
+    def increment_file_count(self, n: int = 1) -> None:
+        self.files_count += n
+        if self.tqdm is not None:
+            self.tqdm.postfix = f"{self.files_count} files"
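CombinedDownloadCallback is now a plain fsspec Callback with a no-op increment_file_count() hook, and the tqdm behaviour moves into TqdmCombinedDownloadCallback, which renders one shared byte-progress bar plus an "N files" postfix. A hedged usage sketch (the byte counts are made up for illustration):

from datachain.progress import TqdmCombinedDownloadCallback

cb = TqdmCombinedDownloadCallback(
    {"desc": "Download", "unit": "B", "unit_scale": True, "leave": False}
)
for size in (1024, 4096, 512):
    cb.relative_update(size)   # bytes from different files accumulate in one bar
    cb.increment_file_count()  # postfix ticks "1 files", "2 files", "3 files"
cb.close()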
datachain/query/dataset.py CHANGED
@@ -35,6 +35,7 @@ from sqlalchemy.sql.schema import TableClause
 from sqlalchemy.sql.selectable import Select

 from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
+from datachain.catalog.catalog import clone_catalog_with_cache
 from datachain.data_storage.schema import (
     PARTITION_COLUMN_ID,
     partition_col_names,
@@ -43,7 +44,8 @@ from datachain.data_storage.schema import (
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
-from datachain.progress import CombinedDownloadCallback
+from datachain.lib.udf import UDFAdapter, _get_cache
+from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
 from datachain.query.schema import C, UDFParamSpec, normalize_param
 from datachain.query.session import Session
 from datachain.sql.functions.random import rand
@@ -52,6 +54,7 @@ from datachain.utils import (
     determine_processes,
     filtered_cloudpickle_dumps,
     get_datachain_executable,
+    safe_closing,
 )

 if TYPE_CHECKING:
@@ -349,19 +352,26 @@ def process_udf_outputs(
     warehouse.insert_rows_done(udf_table)


-def get_download_callback() -> Callback:
-    return CombinedDownloadCallback(
-        {"desc": "Download", "unit": "B", "unit_scale": True, "unit_divisor": 1024}
+def get_download_callback(suffix: str = "", **kwargs) -> CombinedDownloadCallback:
+    return TqdmCombinedDownloadCallback(
+        {
+            "desc": "Download" + suffix,
+            "unit": "B",
+            "unit_scale": True,
+            "unit_divisor": 1024,
+            "leave": False,
+            **kwargs,
+        },
     )


 def get_processed_callback() -> Callback:
-    return TqdmCallback({"desc": "Processed", "unit": " rows"})
+    return TqdmCallback({"desc": "Processed", "unit": " rows", "leave": False})


 def get_generated_callback(is_generator: bool = False) -> Callback:
     if is_generator:
-        return TqdmCallback({"desc": "Generated", "unit": " rows"})
+        return TqdmCallback({"desc": "Generated", "unit": " rows", "leave": False})
     return DEFAULT_CALLBACK
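get_download_callback now returns the tqdm-backed combined callback and accepts a label suffix plus passthrough tqdm kwargs, which lets callers create separately labelled bars (for example for prefetch traffic). Hypothetical calls, not taken from this diff:

from datachain.query.dataset import get_download_callback

main_cb = get_download_callback()                               # bar labelled "Download"
prefetch_cb = get_download_callback(" (prefetch)", position=1)  # "Download (prefetch)" on a second line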
@@ -412,97 +422,109 @@ class UDFStep(Step, ABC):

         udf_fields = [str(c.name) for c in query.selected_columns]

-        try:
-            if workers:
-                if self.catalog.in_memory:
-                    raise RuntimeError(
-                        "In-memory databases cannot be used with "
-                        "distributed processing."
-                    )
+        prefetch = self.udf.prefetch
+        with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
+            catalog = clone_catalog_with_cache(self.catalog, _cache)
+            try:
+                if workers:
+                    if catalog.in_memory:
+                        raise RuntimeError(
+                            "In-memory databases cannot be used with "
+                            "distributed processing."
+                        )

-                from datachain.catalog.loader import get_distributed_class
-
-                distributor = get_distributed_class(min_task_size=self.min_task_size)
-                distributor(
-                    self.udf,
-                    self.catalog,
-                    udf_table,
-                    query,
-                    workers,
-                    processes,
-                    udf_fields=udf_fields,
-                    is_generator=self.is_generator,
-                    use_partitioning=use_partitioning,
-                    cache=self.cache,
-                )
-            elif processes:
-                # Parallel processing (faster for more CPU-heavy UDFs)
-                if self.catalog.in_memory:
-                    raise RuntimeError(
-                        "In-memory databases cannot be used with parallel processing."
-                    )
-                udf_info: UdfInfo = {
-                    "udf_data": filtered_cloudpickle_dumps(self.udf),
-                    "catalog_init": self.catalog.get_init_params(),
-                    "metastore_clone_params": self.catalog.metastore.clone_params(),
-                    "warehouse_clone_params": self.catalog.warehouse.clone_params(),
-                    "table": udf_table,
-                    "query": query,
-                    "udf_fields": udf_fields,
-                    "batching": batching,
-                    "processes": processes,
-                    "is_generator": self.is_generator,
-                    "cache": self.cache,
-                }
-
-                # Run the UDFDispatcher in another process to avoid needing
-                # if __name__ == '__main__': in user scripts
-                exec_cmd = get_datachain_executable()
-                cmd = [*exec_cmd, "internal-run-udf"]
-                envs = dict(os.environ)
-                envs.update({"PYTHONPATH": os.getcwd()})
-                process_data = filtered_cloudpickle_dumps(udf_info)
-
-                with subprocess.Popen(cmd, env=envs, stdin=subprocess.PIPE) as process:  # noqa: S603
-                    process.communicate(process_data)
-                    if retval := process.poll():
-                        raise RuntimeError(f"UDF Execution Failed! Exit code: {retval}")
-            else:
-                # Otherwise process single-threaded (faster for smaller UDFs)
-                warehouse = self.catalog.warehouse
-
-                udf_inputs = batching(warehouse.dataset_select_paginated, query)
-                download_cb = get_download_callback()
-                processed_cb = get_processed_callback()
-                generated_cb = get_generated_callback(self.is_generator)
-                try:
-                    udf_results = self.udf.run(
-                        udf_fields,
-                        udf_inputs,
-                        self.catalog,
-                        self.cache,
-                        download_cb,
-                        processed_cb,
-                    )
-                    process_udf_outputs(
-                        warehouse,
-                        udf_table,
-                        udf_results,
-                        self.udf,
-                        cb=generated_cb,
-                    )
-                finally:
-                    download_cb.close()
-                    processed_cb.close()
-                    generated_cb.close()
-
-        except QueryScriptCancelError:
-            self.catalog.warehouse.close()
-            sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE)
-        except (Exception, KeyboardInterrupt):
-            # Close any open database connections if an error is encountered
-            self.catalog.warehouse.close()
-            raise
+                    from datachain.catalog.loader import get_distributed_class
+
+                    distributor = get_distributed_class(
+                        min_task_size=self.min_task_size
+                    )
+                    distributor(
+                        self.udf,
+                        catalog,
+                        udf_table,
+                        query,
+                        workers,
+                        processes,
+                        udf_fields=udf_fields,
+                        is_generator=self.is_generator,
+                        use_partitioning=use_partitioning,
+                        cache=self.cache,
+                    )
+                elif processes:
+                    # Parallel processing (faster for more CPU-heavy UDFs)
+                    if catalog.in_memory:
+                        raise RuntimeError(
+                            "In-memory databases cannot be used "
+                            "with parallel processing."
+                        )
+                    udf_info: UdfInfo = {
+                        "udf_data": filtered_cloudpickle_dumps(self.udf),
+                        "catalog_init": catalog.get_init_params(),
+                        "metastore_clone_params": catalog.metastore.clone_params(),
+                        "warehouse_clone_params": catalog.warehouse.clone_params(),
+                        "table": udf_table,
+                        "query": query,
+                        "udf_fields": udf_fields,
+                        "batching": batching,
+                        "processes": processes,
+                        "is_generator": self.is_generator,
+                        "cache": self.cache,
+                    }
+
+                    # Run the UDFDispatcher in another process to avoid needing
+                    # if __name__ == '__main__': in user scripts
+                    exec_cmd = get_datachain_executable()
+                    cmd = [*exec_cmd, "internal-run-udf"]
+                    envs = dict(os.environ)
+                    envs.update({"PYTHONPATH": os.getcwd()})
+                    process_data = filtered_cloudpickle_dumps(udf_info)
+
+                    with subprocess.Popen(  # noqa: S603
+                        cmd, env=envs, stdin=subprocess.PIPE
+                    ) as process:
+                        process.communicate(process_data)
+                        if retval := process.poll():
+                            raise RuntimeError(
+                                f"UDF Execution Failed! Exit code: {retval}"
+                            )
+                else:
+                    # Otherwise process single-threaded (faster for smaller UDFs)
+                    warehouse = catalog.warehouse
+
+                    udf_inputs = batching(warehouse.dataset_select_paginated, query)
+                    download_cb = get_download_callback()
+                    processed_cb = get_processed_callback()
+                    generated_cb = get_generated_callback(self.is_generator)
+
+                    try:
+                        udf_results = self.udf.run(
+                            udf_fields,
+                            udf_inputs,
+                            catalog,
+                            self.cache,
+                            download_cb,
+                            processed_cb,
+                        )
+                        with safe_closing(udf_results):
+                            process_udf_outputs(
+                                warehouse,
+                                udf_table,
+                                udf_results,
+                                self.udf,
+                                cb=generated_cb,
+                            )
+                    finally:
+                        download_cb.close()
+                        processed_cb.close()
+                        generated_cb.close()
+
+            except QueryScriptCancelError:
+                self.catalog.warehouse.close()
+                sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE)
+            except (Exception, KeyboardInterrupt):
+                # Close any open database connections if an error is encountered
+                self.catalog.warehouse.close()
+                raise

     def create_partitions_table(self, query: Select) -> "Table":
         """
@@ -602,6 +624,13 @@ class UDFSignal(UDFStep):
         signal_name_cols = {c.name: c for c in signal_cols}
         cols = signal_cols

+        overlap = {c.name for c in original_cols} & {c.name for c in cols}
+        if overlap:
+            raise ValueError(
+                "Column already exists or added in the previous steps: "
+                + ", ".join(overlap)
+            )
+
         def q(*columns):
             cols1 = []
             cols2 = []
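UDFSignal now refuses to add output columns whose names collide with columns produced by earlier steps, raising immediately instead of building an ambiguous select. The check in isolation, with toy column objects standing in for the sqlalchemy columns used in the real code:

from types import SimpleNamespace as Col

original_cols = [Col(name="id"), Col(name="signal")]
signal_cols = [Col(name="signal"), Col(name="score")]  # "signal" collides

overlap = {c.name for c in original_cols} & {c.name for c in signal_cols}
if overlap:
    raise ValueError(
        "Column already exists or added in the previous steps: " + ", ".join(overlap)
    )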
datachain/query/dispatch.py CHANGED
@@ -14,7 +14,9 @@ from multiprocess import get_context
 from sqlalchemy.sql import func

 from datachain.catalog import Catalog
+from datachain.catalog.catalog import clone_catalog_with_cache
 from datachain.catalog.loader import get_distributed_class
+from datachain.lib.udf import _get_cache
 from datachain.query.batch import RowsOutput, RowsOutputBatch
 from datachain.query.dataset import (
     get_download_callback,
@@ -25,7 +27,7 @@ from datachain.query.dataset import (
 from datachain.query.queue import get_from_queue, put_into_queue
 from datachain.query.udf import UdfInfo
 from datachain.query.utils import get_query_id_column
-from datachain.utils import batched, flatten
+from datachain.utils import batched, flatten, safe_closing

 if TYPE_CHECKING:
     from sqlalchemy import Select, Table
@@ -304,21 +306,25 @@ class UDFWorker:
         processed_cb = ProcessedCallback()
         generated_cb = get_generated_callback(self.is_generator)

-        udf_results = self.udf.run(
-            self.udf_fields,
-            self.get_inputs(),
-            self.catalog,
-            self.cache,
-            download_cb=self.cb,
-            processed_cb=processed_cb,
-        )
-        process_udf_outputs(
-            self.catalog.warehouse,
-            self.table,
-            self.notify_and_process(udf_results, processed_cb),
-            self.udf,
-            cb=generated_cb,
-        )
+        prefetch = self.udf.prefetch
+        with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
+            catalog = clone_catalog_with_cache(self.catalog, _cache)
+            udf_results = self.udf.run(
+                self.udf_fields,
+                self.get_inputs(),
+                catalog,
+                self.cache,
+                download_cb=self.cb,
+                processed_cb=processed_cb,
+            )
+            with safe_closing(udf_results):
+                process_udf_outputs(
+                    catalog.warehouse,
+                    self.table,
+                    self.notify_and_process(udf_results, processed_cb),
+                    self.udf,
+                    cb=generated_cb,
+                )

         put_into_queue(
             self.done_queue,