datachain 0.16.3__py3-none-any.whl → 0.16.5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.


@@ -1636,18 +1636,27 @@ class DataChain:
         """
         from pyarrow.dataset import CsvFileFormat, JsonFileFormat

-        from datachain.lib.arrow import ArrowGenerator, infer_schema, schema_to_output
+        from datachain.lib.arrow import (
+            ArrowGenerator,
+            fix_pyarrow_format,
+            infer_schema,
+            schema_to_output,
+        )

-        if nrows:
-            format = kwargs.get("format")
-            if format not in ["csv", "json"] and not isinstance(
-                format, (CsvFileFormat, JsonFileFormat)
-            ):
-                raise DatasetPrepareError(
-                    self.name,
-                    "error in `parse_tabular` - "
-                    "`nrows` only supported for csv and json formats.",
-                )
+        parse_options = kwargs.pop("parse_options", None)
+        if format := kwargs.get("format"):
+            kwargs["format"] = fix_pyarrow_format(format, parse_options)
+
+        if (
+            nrows
+            and format not in ["csv", "json"]
+            and not isinstance(format, (CsvFileFormat, JsonFileFormat))
+        ):
+            raise DatasetPrepareError(
+                self.name,
+                "error in `parse_tabular` - "
+                "`nrows` only supported for csv and json formats.",
+            )

         if "file" not in self.schema or not self.count():
             raise DatasetPrepareError(self.name, "no files to parse.")
@@ -1656,7 +1665,7 @@ class DataChain:
         col_names = output if isinstance(output, Sequence) else None
         if col_names or not output:
             try:
-                schema = infer_schema(self, **kwargs)
+                schema = infer_schema(self, **kwargs, parse_options=parse_options)
                 output, _ = schema_to_output(schema, col_names)
             except ValueError as e:
                 raise DatasetPrepareError(self.name, e) from e
@@ -1682,7 +1691,15 @@ class DataChain:
         # disable prefetch if nrows is set
         settings = {"prefetch": 0} if nrows else {}
         return self.settings(**settings).gen(  # type: ignore[arg-type]
-            ArrowGenerator(schema, model, source, nrows, **kwargs), output=output
+            ArrowGenerator(
+                schema,
+                model,
+                source,
+                nrows,
+                parse_options=parse_options,
+                **kwargs,
+            ),
+            output=output,
         )

     @classmethod
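In effect, parse_tabular now accepts a parse_options argument: the format in kwargs is normalized with fix_pyarrow_format, and the same options are forwarded to schema inference and to the ArrowGenerator. A minimal usage sketch (the bucket path and delimiter are illustrative, not taken from the diff):

import datachain as dc
from pyarrow.csv import ParseOptions

# Read semicolon-delimited CSVs; parse_options flows through to
# infer_schema and ArrowGenerator via the code added above.
chain = dc.read_storage("s3://example-bucket/tabular/").parse_tabular(
    format="csv",
    parse_options=ParseOptions(delimiter=";"),
)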
datachain/lib/listing.py CHANGED
@@ -56,6 +56,8 @@ def list_bucket(uri: str, cache, client_config=None) -> Callable:
         for entries in iter_over_async(client.scandir(path.rstrip("/")), get_loop()):
             yield from entries

+    list_func.__name__ = "read_storage"
+
     return list_func

datachain/lib/udf.py CHANGED
@@ -218,6 +218,18 @@ class UDFBase(AbstractUDF):
     def name(self):
         return self.__class__.__name__

+    @property
+    def verbose_name(self):
+        """Returns the name of the function or class that implements the UDF."""
+        if self._func and callable(self._func):
+            if hasattr(self._func, "__name__"):
+                return self._func.__name__
+            if hasattr(self._func, "__class__") and hasattr(
+                self._func.__class__, "__name__"
+            ):
+                return self._func.__class__.__name__
+        return "<unknown>"
+
     @property
     def signal_names(self) -> Iterable[str]:
         return self.output.to_udf_spec().keys()
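Together with the list_func.__name__ = "read_storage" assignment in listing.py above, this property lets progress output refer to a UDF by the name of the user's function or callable class. A standalone illustration of the lookup order (not datachain code):

def resolve_verbose_name(func):
    # Mirrors the property added above: prefer the callable's __name__,
    # fall back to its class name, otherwise report "<unknown>".
    if func and callable(func):
        if hasattr(func, "__name__"):
            return func.__name__
        if hasattr(func, "__class__") and hasattr(func.__class__, "__name__"):
            return func.__class__.__name__
    return "<unknown>"

def embed(row):
    return row

class Embedder:
    def __call__(self, row):
        return row

print(resolve_verbose_name(embed))       # embed
print(resolve_verbose_name(Embedder()))  # Embedder (instances have no __name__)
print(resolve_verbose_name(None))        # <unknown>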
@@ -411,13 +423,13 @@ class BatchMapper(UDFBase):
         self.setup()

         for batch in udf_inputs:
-            n_rows = len(batch.rows)
+            n_rows = len(batch)
             row_ids, *udf_args = zip(
                 *[
                     self._prepare_row_and_id(
                         row, udf_fields, catalog, cache, download_cb
                     )
-                    for row in batch.rows
+                    for row in batch
                 ]
             )
             result_objs = list(self.process_safe(udf_args))
@@ -489,7 +501,7 @@ class Aggregator(UDFBase):

     def run(
         self,
-        udf_fields: "Sequence[str]",
+        udf_fields: Sequence[str],
         udf_inputs: Iterable[RowsOutputBatch],
         catalog: "Catalog",
         cache: bool,
@@ -502,13 +514,13 @@ class Aggregator(UDFBase):
             udf_args = zip(
                 *[
                     self._prepare_row(row, udf_fields, catalog, cache, download_cb)
-                    for row in batch.rows
+                    for row in batch
                 ]
             )
             result_objs = self.process_safe(udf_args)
             udf_outputs = (self._flatten_row(row) for row in result_objs)
             output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
-            processed_cb.relative_update(len(batch.rows))
+            processed_cb.relative_update(len(batch))
             yield output

         self.teardown()
datachain/query/batch.py CHANGED
@@ -2,22 +2,14 @@ import contextlib
 import math
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Sequence
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Optional, Union
-
-from datachain.data_storage.schema import PARTITION_COLUMN_ID
-from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
-from datachain.query.utils import get_query_column, get_query_id_column
-
-if TYPE_CHECKING:
-    from sqlalchemy import Select
+from typing import Callable, Optional, Union

+import sqlalchemy as sa

-@dataclass
-class RowsOutputBatch:
-    rows: Sequence[Sequence]
-
+from datachain.data_storage.schema import PARTITION_COLUMN_ID
+from datachain.query.utils import get_query_column

+RowsOutputBatch = Sequence[Sequence]
 RowsOutput = Union[Sequence, RowsOutputBatch]


@@ -30,8 +22,8 @@ class BatchingStrategy(ABC):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
+        query: sa.Select,
+        id_col: Optional[sa.ColumnElement] = None,
     ) -> Generator[RowsOutput, None, None]:
         """Apply the provided parameters to the UDF."""

@@ -47,12 +39,16 @@ class NoBatching(BatchingStrategy):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
+        query: sa.Select,
+        id_col: Optional[sa.ColumnElement] = None,
     ) -> Generator[Sequence, None, None]:
-        if ids_only:
-            query = query.with_only_columns(get_query_id_column(query))
-        return execute(query)
+        ids_only = False
+        if id_col is not None:
+            query = query.with_only_columns(id_col)
+            ids_only = True
+
+        rows = execute(query)
+        yield from (r[0] for r in rows) if ids_only else rows


 class Batch(BatchingStrategy):
@@ -69,27 +65,31 @@ class Batch(BatchingStrategy):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
-    ) -> Generator[RowsOutputBatch, None, None]:
-        if ids_only:
-            query = query.with_only_columns(get_query_id_column(query))
+        query: sa.Select,
+        id_col: Optional[sa.ColumnElement] = None,
+    ) -> Generator[RowsOutput, None, None]:
+        from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
+
+        ids_only = False
+        if id_col is not None:
+            query = query.with_only_columns(id_col)
+            ids_only = True

         # choose page size that is a multiple of the batch size
         page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count

         # select rows in batches
-        results: list[Sequence] = []
+        results = []

-        with contextlib.closing(execute(query, page_size=page_size)) as rows:
-            for row in rows:
+        with contextlib.closing(execute(query, page_size=page_size)) as batch_rows:
+            for row in batch_rows:
                 results.append(row)
                 if len(results) >= self.count:
                     batch, results = results[: self.count], results[self.count :]
-                    yield RowsOutputBatch(batch)
+                    yield [r[0] for r in batch] if ids_only else batch

         if len(results) > 0:
-            yield RowsOutputBatch(results)
+            yield [r[0] for r in results] if ids_only else results


 class Partition(BatchingStrategy):
@@ -104,18 +104,19 @@ class Partition(BatchingStrategy):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
-    ) -> Generator[RowsOutputBatch, None, None]:
-        id_col = get_query_id_column(query)
+        query: sa.Select,
+        id_col: Optional[sa.ColumnElement] = None,
+    ) -> Generator[RowsOutput, None, None]:
         if (partition_col := get_query_column(query, PARTITION_COLUMN_ID)) is None:
             raise RuntimeError("partition column not found in query")

-        if ids_only:
+        ids_only = False
+        if id_col is not None:
             query = query.with_only_columns(id_col, partition_col)
+            ids_only = True

         current_partition: Optional[int] = None
-        batch: list[Sequence] = []
+        batch: list = []

         query_fields = [str(c.name) for c in query.selected_columns]
         id_column_idx = query_fields.index("sys__id")
@@ -132,9 +133,9 @@
                 if current_partition != partition:
                     current_partition = partition
                     if len(batch) > 0:
-                        yield RowsOutputBatch(batch)
+                        yield batch
                         batch = []
-                batch.append([row[id_column_idx]] if ids_only else row)
+                batch.append(row[id_column_idx] if ids_only else row)

         if len(batch) > 0:
-            yield RowsOutputBatch(batch)
+            yield batch
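Net effect of the batch.py changes: RowsOutputBatch is now a plain sequence of rows rather than a dataclass with a .rows attribute, and batching strategies take an optional id_col instead of an ids_only flag, yielding bare ids when it is given. A hedged sketch of the new contract, assuming Batch is constructed with a batch size as in this module; the in-memory execute stub is illustrative, not datachain's warehouse API:

import sqlalchemy as sa

from datachain.query.batch import Batch

tbl = sa.table("rows", sa.column("sys__id"), sa.column("value"))
query = sa.select(tbl.c.sys__id, tbl.c.value)

def execute(q, page_size=None):
    # Stand-in for warehouse execution; yields (sys__id, value) tuples.
    yield from [(1, "a"), (2, "b"), (3, "c")]

# With id_col given, each yielded batch is a plain list of ids
# (previously a RowsOutputBatch wrapping the rows).
for batch in Batch(2)(execute, query, id_col=tbl.c.sys__id):
    print(batch)  # [1, 2] then [3]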
@@ -42,15 +42,9 @@ from datachain.data_storage.schema import (
     partition_columns,
 )
 from datachain.dataset import DATASET_PREFIX, DatasetStatus, RowDict
-from datachain.error import (
-    DatasetNotFoundError,
-    QueryScriptCancelError,
-)
+from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
-from datachain.lib.listing import (
-    is_listing_dataset,
-    listing_dataset_expired,
-)
+from datachain.lib.listing import is_listing_dataset, listing_dataset_expired
 from datachain.lib.udf import UDFAdapter, _get_cache
 from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
 from datachain.query.schema import C, UDFParamSpec, normalize_param
@@ -420,41 +414,30 @@ class UDFStep(Step, ABC):
     """

     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
-        from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
-
-        rows_total = self.catalog.warehouse.query_count(query)
-        if rows_total == 0:
+        if (rows_total := self.catalog.warehouse.query_count(query)) == 0:
             return

+        from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
+        from datachain.catalog.loader import (
+            DISTRIBUTED_IMPORT_PATH,
+            get_udf_distributor_class,
+        )
+
         workers = determine_workers(self.workers, rows_total=rows_total)
         processes = determine_processes(self.parallel, rows_total=rows_total)

         use_partitioning = self.partition_by is not None
         batching = self.udf.get_batching(use_partitioning)
         udf_fields = [str(c.name) for c in query.selected_columns]
+        udf_distributor_class = get_udf_distributor_class()

         prefetch = self.udf.prefetch
         with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
             catalog = clone_catalog_with_cache(self.catalog, _cache)
-            try:
-                if workers:
-                    if catalog.in_memory:
-                        raise RuntimeError(
-                            "In-memory databases cannot be used with "
-                            "distributed processing."
-                        )
-
-                    from datachain.catalog.loader import (
-                        DISTRIBUTED_IMPORT_PATH,
-                        get_udf_distributor_class,
-                    )
-
-                    if not (udf_distributor_class := get_udf_distributor_class()):
-                        raise RuntimeError(
-                            f"{DISTRIBUTED_IMPORT_PATH} import path is required "
-                            "for distributed UDF processing."
-                        )

+            try:
+                if udf_distributor_class and not catalog.in_memory:
+                    # Use the UDF distributor if available (running in SaaS)
                     udf_distributor = udf_distributor_class(
                         catalog=catalog,
                         table=udf_table,
@@ -470,7 +453,20 @@ class UDFStep(Step, ABC):
                         min_task_size=self.min_task_size,
                     )
                     udf_distributor()
-            elif processes:
+                    return
+
+                if workers:
+                    if catalog.in_memory:
+                        raise RuntimeError(
+                            "In-memory databases cannot be used with "
+                            "distributed processing."
+                        )
+
+                    raise RuntimeError(
+                        f"{DISTRIBUTED_IMPORT_PATH} import path is required "
+                        "for distributed UDF processing."
+                    )
+                if processes:
                     # Parallel processing (faster for more CPU-heavy UDFs)
                     if catalog.in_memory:
                         raise RuntimeError(
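The refactor above flattens the dispatch in populate_udf_table: a UDF distributor, when importable and the database is not in-memory, runs and the method returns early; requesting workers without a distributor is now an explicit error; otherwise execution falls through to process-parallel or single-process runs. A simplified, illustrative sketch of that ordering (not the actual method):

def choose_udf_executor(udf_distributor_class, in_memory, workers, processes):
    # Distributor wins when available and the DB is not in-memory.
    if udf_distributor_class and not in_memory:
        return "distributed"
    # Asking for workers without a distributor is a configuration error.
    if workers:
        if in_memory:
            raise RuntimeError(
                "In-memory databases cannot be used with distributed processing."
            )
        raise RuntimeError(
            "a UDF distributor import path is required for distributed UDF processing."
        )
    if processes:
        return "parallel"
    return "single-process"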
@@ -504,7 +500,12 @@ class UDFStep(Step, ABC):
                     with subprocess.Popen(  # noqa: S603
                         cmd, env=envs, stdin=subprocess.PIPE
                     ) as process:
-                        process.communicate(process_data)
+                        try:
+                            process.communicate(process_data)
+                        except KeyboardInterrupt:
+                            raise QueryScriptCancelError(
+                                "UDF execution was canceled by the user."
+                            ) from None
                         if retval := process.poll():
                             raise RuntimeError(
                                 f"UDF Execution Failed! Exit code: {retval}"