datachain 0.36.0__py3-none-any.whl → 0.36.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

datachain/data_storage/schema.py CHANGED

@@ -11,7 +11,6 @@ from datachain.sql.types import (
     JSON,
     Boolean,
     DateTime,
-    Int,
     Int64,
     SQLType,
     String,

@@ -269,7 +268,7 @@ class DataTable:
     @classmethod
     def sys_columns(cls):
         return [
-            sa.Column("sys__id", Int, primary_key=True),
+            sa.Column("sys__id", UInt64, primary_key=True),
            sa.Column(
                "sys__rand", UInt64, nullable=False, server_default=f.abs(f.random())
            ),
datachain/data_storage/sqlite.py CHANGED

@@ -868,11 +868,8 @@ class SQLiteWarehouse(AbstractWarehouse):
        if isinstance(c, BinaryExpression):
            right_left_join = add_left_rows_filter(c)

-        # Use CTE instead of subquery to force SQLite to materialize the result
-        # This breaks deep nesting and prevents parser stack overflow.
        union_cte = sqlalchemy.union(left_right_join, right_left_join).cte()
-
-        return self._regenerate_system_columns(union_cte)
+        return sqlalchemy.select(*union_cte.c).select_from(union_cte)

    def _system_row_number_expr(self):
        return func.row_number().over()
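The outer-join path now returns a plain SELECT over the union CTE instead of regenerating system columns here. A minimal sketch of the pattern, using an illustrative table that is not part of datachain; per the comment removed above, the CTE forces SQLite to materialize the union once, which avoids the deeply nested subqueries that can overflow SQLite's parser stack:

    import sqlalchemy as sa

    meta = sa.MetaData()
    t = sa.Table("t", meta, sa.Column("id", sa.Integer), sa.Column("v", sa.String))

    left = sa.select(t.c.id, t.c.v).where(t.c.id < 10)
    right = sa.select(t.c.id, t.c.v).where(t.c.id >= 10)

    # Materialize the union in a CTE, then select its columns back out.
    union_cte = sa.union(left, right).cte()
    query = sa.select(*union_cte.c).select_from(union_cte)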
@@ -884,11 +881,7 @@ class SQLiteWarehouse(AbstractWarehouse):
        """
        Create a temporary table from a query for use in a UDF.
        """
-        columns = [
-            sqlalchemy.Column(c.name, c.type)
-            for c in query.selected_columns
-            if c.name != "sys__id"
-        ]
+        columns = [sqlalchemy.Column(c.name, c.type) for c in query.selected_columns]
        table = self.create_udf_table(columns)

        with tqdm(desc="Preparing", unit=" rows", leave=False) as pbar:
datachain/data_storage/warehouse.py CHANGED

@@ -5,7 +5,7 @@ import random
 import string
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any, Union, cast
 from urllib.parse import urlparse

 import attrs

@@ -23,7 +23,7 @@ from datachain.node import DirType, DirTypeGroup, Node, NodeWithPath, get_path
 from datachain.query.batch import RowsOutput
 from datachain.query.schema import ColumnMeta
 from datachain.sql.functions import path as pathfunc
-from datachain.sql.types import Int, SQLType
+from datachain.sql.types import SQLType
 from datachain.utils import sql_escape_like

 if TYPE_CHECKING:

@@ -32,6 +32,7 @@ if TYPE_CHECKING:
        _FromClauseArgument,
        _OnClauseArgument,
    )
+    from sqlalchemy.sql.selectable import FromClause
    from sqlalchemy.types import TypeEngine

    from datachain.data_storage import schema

@@ -248,45 +249,56 @@ class AbstractWarehouse(ABC, Serializable):

    def _regenerate_system_columns(
        self,
-        selectable: sa.Select | sa.CTE,
+        selectable: sa.Select,
        keep_existing_columns: bool = False,
+        regenerate_columns: Iterable[str] | None = None,
    ) -> sa.Select:
        """
-        Return a SELECT that regenerates sys__id and sys__rand deterministically.
+        Return a SELECT that regenerates system columns deterministically.

-        If keep_existing_columns is True, existing sys__id and sys__rand columns
-        will be kept as-is if they exist in the input selectable.
-        """
-        base = selectable.subquery() if hasattr(selectable, "subquery") else selectable
-
-        result_columns: dict[str, sa.ColumnElement] = {}
-        for col in base.c:
-            if col.name in result_columns:
-                raise ValueError(f"Duplicate column name {col.name} in SELECT")
-            if col.name in ("sys__id", "sys__rand"):
-                if keep_existing_columns:
-                    result_columns[col.name] = col
-            else:
-                result_columns[col.name] = col
+        If keep_existing_columns is True, existing system columns will be kept as-is
+        even when they are listed in ``regenerate_columns``.

-        system_types: dict[str, sa.types.TypeEngine] = {
+        Args:
+            selectable: Base SELECT
+            keep_existing_columns: When True, reuse existing system columns even if
+                they are part of the regeneration set.
+            regenerate_columns: Names of system columns to regenerate. Defaults to
+                {"sys__id", "sys__rand"}. Columns not listed are left untouched.
+        """
+        system_columns = {
            sys_col.name: sys_col.type
            for sys_col in self.schema.dataset_row_cls.sys_columns()
        }
+        regenerate = set(regenerate_columns or system_columns)
+        generators = {
+            "sys__id": self._system_row_number_expr,
+            "sys__rand": self._system_random_expr,
+        }
+
+        base = cast("FromClause", selectable.subquery())
+
+        def build(name: str) -> sa.ColumnElement:
+            expr = generators[name]()
+            return sa.cast(expr, system_columns[name]).label(name)
+
+        columns: list[sa.ColumnElement] = []
+        present: set[str] = set()
+        changed = False
+
+        for col in base.c:
+            present.add(col.name)
+            regen = col.name in regenerate and not keep_existing_columns
+            columns.append(build(col.name) if regen else col)
+            changed |= regen
+
+        for name in regenerate - present:
+            columns.append(build(name))
+            changed = True
+
+        if not changed:
+            return selectable

-        # Add missing system columns if needed
-        if "sys__id" not in result_columns:
-            expr = self._system_row_number_expr()
-            expr = sa.cast(expr, system_types["sys__id"])
-            result_columns["sys__id"] = expr.label("sys__id")
-        if "sys__rand" not in result_columns:
-            expr = self._system_random_expr()
-            expr = sa.cast(expr, system_types["sys__rand"])
-            result_columns["sys__rand"] = expr.label("sys__rand")
-
-        # Wrap in subquery to materialize window functions, then wrap again in SELECT
-        # This ensures window functions are computed before INSERT...FROM SELECT
-        columns = list(result_columns.values())
        inner = sa.select(*columns).select_from(base).subquery()
        return sa.select(*inner.c).select_from(inner)
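The final double wrap survives from the old implementation: as the removed comment explained, window functions such as row_number() must be computed in an inner subquery so they are materialized before the result feeds an INSERT ... FROM SELECT. A standalone sketch of that pattern with illustrative names:

    import sqlalchemy as sa

    base = sa.select(sa.literal("a").label("val")).subquery()

    # Compute the window function in an inner subquery...
    sys_id = sa.cast(sa.func.row_number().over(), sa.BigInteger).label("sys__id")
    inner = sa.select(sys_id, *base.c).select_from(base).subquery()

    # ...then select its columns back out, so the row numbers are fixed
    # before any INSERT ... FROM SELECT consumes them.
    stmt = sa.select(*inner.c).select_from(inner)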
@@ -950,10 +962,15 @@ class AbstractWarehouse(ABC, Serializable):
        SQLite TEMPORARY tables cannot be directly used as they are process-specific,
        and UDFs are run in other processes when run in parallel.
        """
+        columns = [
+            c
+            for c in columns
+            if c.name not in [col.name for col in self.dataset_row_cls.sys_columns()]
+        ]
        tbl = sa.Table(
            name or self.udf_table_name(),
            sa.MetaData(),
-            sa.Column("sys__id", Int, primary_key=True),
+            *self.dataset_row_cls.sys_columns(),
            *columns,
        )
        self.db.create_table(tbl, if_not_exists=True)
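Every UDF table is now built with the full set of canonical system columns, and any system columns the caller passed in are filtered out first so they are never duplicated. A sketch of the resulting layout, with illustrative column names and stand-in types (the real types come from datachain.sql.types, e.g. UInt64):

    import sqlalchemy as sa

    # Stand-ins for dataset_row_cls.sys_columns()
    sys_columns = [
        sa.Column("sys__id", sa.BigInteger, primary_key=True),
        sa.Column("sys__rand", sa.BigInteger, nullable=False),
    ]
    udf_columns = [sa.Column("score", sa.Float)]  # hypothetical UDF output

    tbl = sa.Table("udf_example", sa.MetaData(), *sys_columns, *udf_columns)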
datachain/diff/__init__.py CHANGED

@@ -24,7 +24,7 @@ class CompareStatus(str, Enum):
     SAME = "S"


-def _compare(  # noqa: C901, PLR0912
+def _compare(  # noqa: C901
     left: "DataChain",
     right: "DataChain",
     on: str | Sequence[str],

@@ -151,11 +151,7 @@ def _compare(  # noqa: C901, PLR0912
    if status_col:
        cols_select.append(diff_col)

-    if not dc_diff._sys:
-        # TODO workaround when sys signal is not available in diff
-        dc_diff = dc_diff.settings(sys=True).select(*cols_select).settings(sys=False)
-    else:
-        dc_diff = dc_diff.select(*cols_select)
+    dc_diff = dc_diff.select(*cols_select)

    # final schema is schema from the left chain with status column added if needed
    dc_diff.signals_schema = (
datachain/lib/audio.py CHANGED
@@ -1,4 +1,5 @@
 import posixpath
+import re
 from typing import TYPE_CHECKING

 from datachain.lib.file import FileError

@@ -9,7 +10,7 @@ if TYPE_CHECKING:
    from datachain.lib.file import Audio, AudioFile, File

 try:
-    import torchaudio
+    import soundfile as sf
 except ImportError as exc:
    raise ImportError(
        "Missing dependencies for processing audio.\n"
@@ -26,18 +27,25 @@ def audio_info(file: "File | AudioFile") -> "Audio":

    try:
        with file.open() as f:
-            info = torchaudio.info(f)
+            info = sf.info(f)
+
+            sample_rate = int(info.samplerate)
+            channels = int(info.channels)
+            frames = int(info.frames)
+            duration = float(info.duration)

-        sample_rate = int(info.sample_rate)
-        channels = int(info.num_channels)
-        frames = int(info.num_frames)
-        duration = float(frames / sample_rate) if sample_rate > 0 else 0.0
+        # soundfile provides format and subtype
+        if info.format:
+            format_name = info.format.lower()
+        else:
+            format_name = file.get_file_ext().lower()

-        codec_name = getattr(info, "encoding", "")
-        file_ext = file.get_file_ext().lower()
-        format_name = _encoding_to_format(codec_name, file_ext)
+        if not format_name:
+            format_name = "unknown"
+        codec_name = info.subtype if info.subtype else ""

-        bits_per_sample = getattr(info, "bits_per_sample", 0)
+        # Calculate bit rate from subtype
+        bits_per_sample = _get_bits_per_sample(info.subtype)
        bit_rate = (
            bits_per_sample * sample_rate * channels if bits_per_sample > 0 else -1
        )
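The metadata used above comes straight from soundfile's info object. A minimal sketch, where "example.wav" is a placeholder path:

    import soundfile as sf

    info = sf.info("example.wav")
    print(info.samplerate, info.channels, info.frames, info.duration)
    print(info.format, info.subtype)  # e.g. "WAV", "PCM_16"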
@@ -58,44 +66,39 @@ def audio_info(file: "File | AudioFile") -> "Audio":
    )


-def _encoding_to_format(encoding: str, file_ext: str) -> str:
+def _get_bits_per_sample(subtype: str) -> int:
    """
-    Map torchaudio encoding to a format name.
+    Map soundfile subtype to bits per sample.

    Args:
-        encoding: The encoding string from torchaudio.info()
-        file_ext: The file extension as a fallback
+        subtype: The subtype string from soundfile

    Returns:
-        Format name as a string
+        Bits per sample, or 0 if unknown
    """
-    # Direct mapping for formats that match exactly
-    encoding_map = {
-        "FLAC": "flac",
-        "MP3": "mp3",
-        "VORBIS": "ogg",
-        "AMR_WB": "amr",
-        "AMR_NB": "amr",
-        "OPUS": "opus",
-        "GSM": "gsm",
+    if not subtype:
+        return 0
+
+    # Common PCM and floating-point subtypes
+    pcm_bits = {
+        "PCM_16": 16,
+        "PCM_24": 24,
+        "PCM_32": 32,
+        "PCM_S8": 8,
+        "PCM_U8": 8,
+        "FLOAT": 32,
+        "DOUBLE": 64,
    }

-    if encoding in encoding_map:
-        return encoding_map[encoding]
+    if subtype in pcm_bits:
+        return pcm_bits[subtype]

-    # For PCM variants, use file extension to determine format
-    if encoding.startswith("PCM_"):
-        # Common PCM formats by extension
-        pcm_formats = {
-            "wav": "wav",
-            "aiff": "aiff",
-            "au": "au",
-            "raw": "raw",
-        }
-        return pcm_formats.get(file_ext, "wav")  # Default to wav for PCM
+    # Handle variants such as PCM_S16LE, PCM_F32LE, etc.
+    match = re.search(r"PCM_(?:[A-Z]*?)(\d+)", subtype)
+    if match:
+        return int(match.group(1))

-    # Fallback to file extension if encoding is unknown
-    return file_ext if file_ext else "unknown"
+    return 0


 def audio_to_np(
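A quick check of the fallback regex against typical subtype spellings (in the real function, named subtypes hit the lookup table first; the regex is applied directly here for illustration):

    import re

    for subtype in ("PCM_24", "PCM_S16LE", "PCM_F32LE"):
        m = re.search(r"PCM_(?:[A-Z]*?)(\d+)", subtype)
        print(subtype, int(m.group(1)) if m else 0)
    # PCM_24 -> 24, PCM_S16LE -> 16, PCM_F32LE -> 32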
@@ -114,27 +117,27 @@ def audio_to_np(

    try:
        with audio.open() as f:
-            info = torchaudio.info(f)
-            sample_rate = info.sample_rate
+            info = sf.info(f)
+            sample_rate = info.samplerate

            frame_offset = int(start * sample_rate)
            num_frames = int(duration * sample_rate) if duration is not None else -1

            # Reset file pointer to the beginning
-            # This is important to ensure we read from the correct position later
            f.seek(0)

-            waveform, sr = torchaudio.load(
-                f, frame_offset=frame_offset, num_frames=num_frames
+            # Read audio data with offset and frame count
+            audio_np, sr = sf.read(
+                f,
+                start=frame_offset,
+                frames=num_frames,
+                always_2d=False,
+                dtype="float32",
            )

-            audio_np = waveform.numpy()
-
-            if audio_np.shape[0] > 1:
-                audio_np = audio_np.T
-            else:
-                audio_np = audio_np.squeeze()
-
+            # soundfile returns shape (frames,) for mono or
+            # (frames, channels) for multi-channel
+            # We keep this format as it matches expected output
            return audio_np, int(sr)
    except Exception as exc:
        raise FileError(
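sf.read accepts a start frame and a frame count directly, which is what replaces torchaudio's frame_offset/num_frames here. A hedged sketch against a placeholder file:

    import soundfile as sf

    with open("example.wav", "rb") as f:  # placeholder path
        sr = sf.info(f).samplerate
        f.seek(0)  # sf.info advances the file position
        data, sr = sf.read(
            f,
            start=int(1.0 * sr),    # skip the first second
            frames=int(2.0 * sr),   # read two seconds; -1 means "to the end"
            dtype="float32",
            always_2d=False,
        )
    # data.shape == (frames,) for mono, (frames, channels) otherwise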
@@ -152,11 +155,9 @@ def audio_to_bytes(

    If duration is None, converts from start to end of file.
    If start is 0 and duration is None, converts entire file."""
-    y, sr = audio_to_np(audio, start, duration)
-
    import io

-    import soundfile as sf
+    y, sr = audio_to_np(audio, start, duration)

    buffer = io.BytesIO()
    sf.write(buffer, y, sr, format=format)
datachain/lib/dc/datachain.py CHANGED

@@ -856,7 +856,9 @@ class DataChain:
                udf_obj.to_udf_wrapper(self._settings.batch_size),
                **self._settings.to_dict(),
            ),
-            signal_schema=self.signals_schema | udf_obj.output,
+            signal_schema=SignalSchema({"sys": Sys})
+            | self.signals_schema
+            | udf_obj.output,
        )

    def gen(

@@ -894,7 +896,7 @@ class DataChain:
                udf_obj.to_udf_wrapper(self._settings.batch_size),
                **self._settings.to_dict(),
            ),
-            signal_schema=udf_obj.output,
+            signal_schema=SignalSchema({"sys": Sys}) | udf_obj.output,
        )

    @delta_disabled

@@ -1031,7 +1033,7 @@ class DataChain:
                partition_by=processed_partition_by,
                **self._settings.to_dict(),
            ),
-            signal_schema=udf_obj.output,
+            signal_schema=SignalSchema({"sys": Sys}) | udf_obj.output,
        )

    def batch_map(

@@ -1097,11 +1099,7 @@ class DataChain:
        sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
        DataModel.register(list(sign.output_schema.values.values()))

-        signals_schema = self.signals_schema
-        if self._sys:
-            signals_schema = SignalSchema({"sys": Sys}) | signals_schema
-
-        params_schema = signals_schema.slice(
+        params_schema = self.signals_schema.slice(
            sign.params, self._setup, is_batch=is_batch
        )

@@ -1156,11 +1154,9 @@ class DataChain:
            )
        )

-    def select(self, *args: str, _sys: bool = True) -> "Self":
+    def select(self, *args: str) -> "Self":
        """Select only a specified set of signals."""
        new_schema = self.signals_schema.resolve(*args)
-        if self._sys and _sys:
-            new_schema = SignalSchema({"sys": Sys}) | new_schema
        columns = new_schema.db_signals()
        return self._evolve(
            query=self._query.select(*columns), signal_schema=new_schema

@@ -1710,9 +1706,11 @@ class DataChain:

        signals_schema = self.signals_schema.clone_without_sys_signals()
        right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
-        ds.signals_schema = SignalSchema({"sys": Sys}) | signals_schema.merge(
-            right_signals_schema, rname
-        )
+
+        ds.signals_schema = signals_schema.merge(right_signals_schema, rname)
+
+        if not full:
+            ds.signals_schema = SignalSchema({"sys": Sys}) | ds.signals_schema

        return ds

@@ -1723,6 +1721,7 @@ class DataChain:
        Parameters:
            other: chain whose rows will be added to `self`.
        """
+        self.signals_schema = self.signals_schema.clone_without_sys_signals()
        return self._evolve(query=self._query.union(other._query))

    def subtract(  # type: ignore[override]
datachain/query/dataset.py CHANGED

@@ -438,9 +438,6 @@ class UDFStep(Step, ABC):
    """

    def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
-        if "sys__id" not in query.selected_columns:
-            raise RuntimeError("Query must have sys__id column to run UDF")
-
        if (rows_total := self.catalog.warehouse.query_count(query)) == 0:
            return

@@ -634,12 +631,11 @@ class UDFStep(Step, ABC):

        # Apply partitioning if needed.
        if self.partition_by is not None:
-            if "sys__id" not in query.selected_columns:
-                _query = query = self.catalog.warehouse._regenerate_system_columns(
-                    query,
-                    keep_existing_columns=True,
-                )
-
+            _query = query = self.catalog.warehouse._regenerate_system_columns(
+                query_generator.select(),
+                keep_existing_columns=True,
+                regenerate_columns=["sys__id"],
+            )
            partition_tbl = self.create_partitions_table(query)
            temp_tables.append(partition_tbl.name)
            query = query.outerjoin(

@@ -960,28 +956,23 @@ class SQLUnion(Step):
        q2 = self.query2.apply_steps().select().subquery()
        temp_tables.extend(self.query2.temp_table_names)

-        columns1, columns2 = _order_columns(q1.columns, q2.columns)
-
-        union_select = sqlalchemy.select(*columns1).union_all(
-            sqlalchemy.select(*columns2)
-        )
-        union_cte = union_select.cte()
-        regenerated = self.query1.catalog.warehouse._regenerate_system_columns(
-            union_cte
-        )
-        result_columns = tuple(regenerated.selected_columns)
+        columns1 = _drop_system_columns(q1.columns)
+        columns2 = _drop_system_columns(q2.columns)
+        columns1, columns2 = _order_columns(columns1, columns2)

        def q(*columns):
-            if not columns:
-                return regenerated
+            selected_names = [c.name for c in columns]
+            col1 = [c for c in columns1 if c.name in selected_names]
+            col2 = [c for c in columns2 if c.name in selected_names]
+            union_query = sqlalchemy.select(*col1).union_all(sqlalchemy.select(*col2))

-            names = {c.name for c in columns}
-            selected = [c for c in result_columns if c.name in names]
-            return regenerated.with_only_columns(*selected)
+            union_cte = union_query.cte()
+            select_cols = [union_cte.c[name] for name in selected_names]
+            return sqlalchemy.select(*select_cols)

        return step_result(
            q,
-            result_columns,
+            columns1,
            dependencies=self.query1.dependencies | self.query2.dependencies,
        )
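The union now strips each branch's sys__* columns up front and builds the UNION ALL lazily inside q(); fresh system columns are presumably regenerated downstream by the _regenerate_system_columns machinery when the result is materialized. A compile-only sketch of the column handling, with illustrative tables:

    import sqlalchemy as sa

    meta = sa.MetaData()
    t1 = sa.Table("t1", meta, sa.Column("sys__id", sa.Integer), sa.Column("v", sa.String))
    t2 = sa.Table("t2", meta, sa.Column("sys__id", sa.Integer), sa.Column("v", sa.String))

    def drop_system_columns(columns):
        return [c for c in columns if not c.name.startswith("sys__")]

    c1, c2 = drop_system_columns(t1.c), drop_system_columns(t2.c)
    union_cte = sa.select(*c1).union_all(sa.select(*c2)).cte()
    stmt = sa.select(*[union_cte.c[c.name] for c in c1])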
@@ -1070,7 +1061,7 @@ class SQLJoin(Step):
        q1 = self.get_query(self.query1, temp_tables)
        q2 = self.get_query(self.query2, temp_tables)

-        q1_columns = list(q1.c)
+        q1_columns = _drop_system_columns(q1.c) if self.full else list(q1.c)
        q1_column_names = {c.name for c in q1_columns}

        q2_columns = []

@@ -1211,6 +1202,10 @@ def _order_columns(
    return [[d[n] for n in column_order] for d in column_dicts]


+def _drop_system_columns(columns: Iterable[ColumnElement]) -> list[ColumnElement]:
+    return [c for c in columns if not c.name.startswith("sys__")]
+
+
 @attrs.define
 class ResultIter:
    _row_iter: Iterable[Any]
datachain/query/dispatch.py CHANGED

@@ -2,12 +2,16 @@ import contextlib
 from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
+from queue import Empty
 from sys import stdin
+from time import monotonic, sleep
 from typing import TYPE_CHECKING, Literal

+import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
-from multiprocess import get_context
+from multiprocess.context import Process
+from multiprocess.queues import Queue as MultiprocessQueue

 from datachain.catalog import Catalog
 from datachain.catalog.catalog import clone_catalog_with_cache

@@ -25,7 +29,6 @@ from datachain.query.udf import UdfInfo
 from datachain.utils import batched, flatten, safe_closing

 if TYPE_CHECKING:
-    import multiprocess
    from sqlalchemy import Select, Table

    from datachain.data_storage import AbstractMetastore, AbstractWarehouse

@@ -101,8 +104,8 @@ def udf_worker_entrypoint(fd: int | None = None) -> int:

 class UDFDispatcher:
    _catalog: Catalog | None = None
-    task_queue: "multiprocess.Queue | None" = None
-    done_queue: "multiprocess.Queue | None" = None
+    task_queue: MultiprocessQueue | None = None
+    done_queue: MultiprocessQueue | None = None

    def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
        self.udf_data = udf_info["udf_data"]

@@ -121,7 +124,7 @@ class UDFDispatcher:
        self.buffer_size = buffer_size
        self.task_queue = None
        self.done_queue = None
-        self.ctx = get_context("spawn")
+        self.ctx = multiprocess.get_context("spawn")

    @property
    def catalog(self) -> "Catalog":
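The module-level import of multiprocess replaces the bare get_context import; with the "spawn" start method, queues and processes should come from the same context object. A minimal sketch:

    import multiprocess

    ctx = multiprocess.get_context("spawn")
    queue = ctx.Queue()  # same context as the workers that will use it
    proc = ctx.Process(target=print, args=("hello",))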
@@ -259,8 +262,6 @@ class UDFDispatcher:
        for p in pool:
            p.start()

-        # Will be set to True if all tasks complete normally
-        normal_completion = False
        try:
            # Will be set to True when the input is exhausted
            input_finished = False

@@ -283,10 +284,20 @@ class UDFDispatcher:

            # Process all tasks
            while n_workers > 0:
-                try:
-                    result = get_from_queue(self.done_queue)
-                except KeyboardInterrupt:
-                    break
+                while True:
+                    try:
+                        result = self.done_queue.get_nowait()
+                        break
+                    except Empty:
+                        for p in pool:
+                            exitcode = p.exitcode
+                            if exitcode not in (None, 0):
+                                message = (
+                                    f"Worker {p.name} exited unexpectedly with "
+                                    f"code {exitcode}"
+                                )
+                                raise RuntimeError(message) from None
+                        sleep(0.01)

                if bytes_downloaded := result.get("bytes_downloaded"):
                    download_cb.relative_update(bytes_downloaded)
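A blocking queue read can wait forever if a worker dies without posting a result; the new loop polls with get_nowait and fails fast on a non-zero worker exit code. A runnable toy illustrating the failure mode, using stdlib multiprocessing rather than datachain's dispatcher:

    import sys
    from multiprocessing import Process, Queue
    from queue import Empty
    from time import sleep

    def crash() -> None:
        sys.exit(3)  # dies without producing a result

    if __name__ == "__main__":
        done_queue = Queue()
        worker = Process(target=crash)
        worker.start()
        while True:
            try:
                result = done_queue.get_nowait()
                break
            except Empty:
                if worker.exitcode not in (None, 0):
                    raise RuntimeError(f"worker exited with code {worker.exitcode}")
                sleep(0.01)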
@@ -313,39 +324,50 @@ class UDFDispatcher:
                    put_into_queue(self.task_queue, next(input_data))
                except StopIteration:
                    input_finished = True
-
-            # Finished with all tasks normally
-            normal_completion = True
        finally:
-            if not normal_completion:
-                # Stop all workers if there is an unexpected exception
-                for _ in pool:
-                    put_into_queue(self.task_queue, STOP_SIGNAL)
-
-                # This allows workers (and this process) to exit without
-                # consuming any remaining data in the queues.
-                # (If they exit due to an exception.)
-                self.task_queue.close()
-                self.task_queue.join_thread()
-
-                # Flush all items from the done queue.
-                # This is needed if any workers are still running.
-                while n_workers > 0:
-                    result = get_from_queue(self.done_queue)
-                    status = result["status"]
-                    if status != OK_STATUS:
-                        n_workers -= 1
-
-                self.done_queue.close()
-                self.done_queue.join_thread()
+            self._shutdown_workers(pool)
+
+    def _shutdown_workers(self, pool: list[Process]) -> None:
+        self._terminate_pool(pool)
+        self._drain_queue(self.done_queue)
+        self._drain_queue(self.task_queue)
+        self._close_queue(self.done_queue)
+        self._close_queue(self.task_queue)
+
+    def _terminate_pool(self, pool: list[Process]) -> None:
+        for proc in pool:
+            if proc.is_alive():
+                proc.terminate()
+
+        deadline = monotonic() + 1.0
+        for proc in pool:
+            if not proc.is_alive():
+                continue
+            remaining = deadline - monotonic()
+            if remaining > 0:
+                proc.join(remaining)
+            if proc.is_alive():
+                proc.kill()
+                proc.join(timeout=0.2)
+
+    def _drain_queue(self, queue: MultiprocessQueue) -> None:
+        while True:
+            try:
+                queue.get_nowait()
+            except Empty:
+                return
+            except (OSError, ValueError):
+                return

-            # Wait for workers to stop
-            for p in pool:
-                p.join()
+    def _close_queue(self, queue: MultiprocessQueue) -> None:
+        with contextlib.suppress(OSError, ValueError):
+            queue.close()
+        with contextlib.suppress(RuntimeError, AssertionError, ValueError):
+            queue.join_thread()


 class DownloadCallback(Callback):
-    def __init__(self, queue: "multiprocess.Queue") -> None:
+    def __init__(self, queue: MultiprocessQueue) -> None:
        self.queue = queue
        super().__init__()
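Draining before closing matters because a multiprocessing-style queue flushes buffered items to its pipe from a background feeder thread, and join_thread() can hang if a full pipe has no reader. A sketch of the drain-then-close pattern under that assumption:

    import contextlib
    from queue import Empty

    def drain_and_close(q) -> None:
        # Empty the queue so the feeder thread can flush and exit...
        with contextlib.suppress(Empty, OSError, ValueError):
            while True:
                q.get_nowait()
        # ...then close and join it, ignoring "already closed" style errors.
        with contextlib.suppress(OSError, ValueError):
            q.close()
        with contextlib.suppress(RuntimeError, AssertionError, ValueError):
            q.join_thread()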
@@ -360,7 +382,7 @@ class ProcessedCallback(Callback):
    def __init__(
        self,
        name: Literal["processed", "generated"],
-        queue: "multiprocess.Queue",
+        queue: MultiprocessQueue,
    ) -> None:
        self.name = name
        self.queue = queue

@@ -375,8 +397,8 @@ class UDFWorker:
        self,
        catalog: "Catalog",
        udf: "UDFAdapter",
-        task_queue: "multiprocess.Queue",
-        done_queue: "multiprocess.Queue",
+        task_queue: MultiprocessQueue,
+        done_queue: MultiprocessQueue,
        query: "Select",
        table: "Table",
        cache: bool,
datachain/query/queue.py CHANGED
@@ -1,11 +1,12 @@
 import datetime
 from collections.abc import Iterable, Iterator
-from queue import Empty, Full, Queue
+from queue import Empty, Full
 from struct import pack, unpack
 from time import sleep
 from typing import Any

 import msgpack
+from multiprocess.queues import Queue

 from datachain.query.batch import RowsOutput
datachain-0.36.0.dist-info/METADATA → datachain-0.36.1.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: datachain
-Version: 0.36.0
+Version: 0.36.1
 Summary: Wrangle unstructured AI data at scale
 Author-email: Dmitry Petrov <support@dvc.org>
 License-Expression: Apache-2.0

@@ -64,7 +64,6 @@ Requires-Dist: torch>=2.1.0; extra == "torch"
 Requires-Dist: torchvision; extra == "torch"
 Requires-Dist: transformers>=4.36.0; extra == "torch"
 Provides-Extra: audio
-Requires-Dist: torchaudio; extra == "audio"
 Requires-Dist: soundfile; extra == "audio"
 Provides-Extra: remote
 Requires-Dist: lz4; extra == "remote"

@@ -76,6 +75,7 @@ Requires-Dist: numba>=0.60.0; extra == "hf"
 Requires-Dist: datasets[vision]>=4.0.0; extra == "hf"
 Requires-Dist: datasets[audio]>=4.0.0; (sys_platform == "linux" or sys_platform == "darwin") and extra == "hf"
 Requires-Dist: fsspec>=2024.12.0; extra == "hf"
+Requires-Dist: torch<2.9.0; extra == "hf"
 Provides-Extra: video
 Requires-Dist: ffmpeg-python; extra == "video"
 Requires-Dist: imageio[ffmpeg,pyav]>=2.37.0; extra == "video"

@@ -117,6 +117,7 @@ Requires-Dist: huggingface_hub[hf_transfer]; extra == "examples"
 Requires-Dist: ultralytics; extra == "examples"
 Requires-Dist: open_clip_torch; extra == "examples"
 Requires-Dist: openai; extra == "examples"
+Requires-Dist: torchaudio<2.9.0; extra == "examples"
 Dynamic: license-file

 ================
datachain-0.36.0.dist-info/RECORD → datachain-0.36.1.dist-info/RECORD

@@ -55,11 +55,11 @@ datachain/data_storage/__init__.py,sha256=9Wit-oe5P46V7CJQTD0BJ5MhOa2Y9h3ddJ4VWT
 datachain/data_storage/db_engine.py,sha256=MGbrckXk5kHOfpjnhHhGpyJpAsgaBCxMmfd33hB2SWI,3756
 datachain/data_storage/job.py,sha256=NGFhXg0C0zRFTaF6ccjXZJT4xI4_gUr1WcxTLK6WYDE,448
 datachain/data_storage/metastore.py,sha256=NLGYLErWFUNXjKbEoESFkKW222MQdMCBlpuqaYVugsE,63484
-datachain/data_storage/schema.py,sha256=4FZZFgPTI9e3gUFdlm1smPdES7FHctwXQNdNfY69tj8,9807
+datachain/data_storage/schema.py,sha256=3fAgiE11TIDYCW7EbTdiOm61SErRitvsLr7YPnUlVm0,9801
 datachain/data_storage/serializer.py,sha256=oL8i8smyAeVUyDepk8Xhf3lFOGOEHMoZjA5GdFzvfGI,3862
-datachain/data_storage/sqlite.py,sha256=xQZ944neP57K_25HSetIy35IakAcyA0cUKVe-xeIEgQ,31168
-datachain/data_storage/warehouse.py,sha256=rNz2wFlFA-pyBAuy14RL6lRIFhrNEnX02c9SgGs4v58,34994
-datachain/diff/__init__.py,sha256=pixXOnbOcoxfkBvbaiDNGPhJMEyTiHb9EIFxR7QqY5A,9533
+datachain/data_storage/sqlite.py,sha256=MgQ6bfJ7LGW91UiVHQtSkj_5HalRi1aeHCEW__5JEe8,30959
+datachain/data_storage/warehouse.py,sha256=nuGT27visvAi7jr7ZAZF-wmFe0ZEFD8qaTheINX_7RM,35269
+datachain/diff/__init__.py,sha256=Fo3xMnctKyA0YtvnsBXQ-P5gQeeEwed17Tn_i7vfLKs,9332
 datachain/fs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 datachain/fs/reference.py,sha256=A8McpXF0CqbXPqanXuvpKu50YLB3a2ZXA3YAPxtBXSM,914
 datachain/fs/utils.py,sha256=s-FkTOCGBk-b6TT3toQH51s9608pofoFjUSTc1yy7oE,825

@@ -76,7 +76,7 @@ datachain/func/string.py,sha256=kXkPHimtA__EVg_Th1yldGaLJpw4HYVhIeYtKy3DuyQ,7406
 datachain/func/window.py,sha256=ImyRpc1QI8QUSPO7KdD60e_DPVo7Ja0G5kcm6BlyMcw,1584
 datachain/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 datachain/lib/arrow.py,sha256=eCZtqbjAzkL4aemY74f_XkIJ_FWwXugJNjIFOwDa9w0,10815
-datachain/lib/audio.py,sha256=3QWQ7PHuRnen7al8EjgjWuKbRKe4SvrbWELJ1T_Cin0,7545
+datachain/lib/audio.py,sha256=hHG29vqrV389im152wCjh80d0xqXGGvFnUpUwkzZejQ,7385
 datachain/lib/clip.py,sha256=nF8-N6Uz0MbAsPJBY2iXEYa3DPLo80OOer5SRNAtcGM,6149
 datachain/lib/data_model.py,sha256=H-bagx24-cLlC7ngSP6Dby4mB6kSxxV7KDiHxQjzwlg,3798
 datachain/lib/dataset_info.py,sha256=Ym7yYcGpfUmPLrfdxueijCVRP2Go6KbyuLk_fmzYgDU,3273

@@ -109,7 +109,7 @@ datachain/lib/convert/values_to_tuples.py,sha256=Sxj0ojeMSpAwM_NNoXa1dMR_2L_cQ6X
 datachain/lib/dc/__init__.py,sha256=UrUzmDH6YyVl8fxM5iXTSFtl5DZTUzEYm1MaazK4vdQ,900
 datachain/lib/dc/csv.py,sha256=fIfj5-2Ix4z5D5yZueagd5WUWw86pusJ9JJKD-U3KGg,4407
 datachain/lib/dc/database.py,sha256=Wqob3dQc9Mol_0vagzVEXzteCKS9M0E3U5130KVmQKg,14629
-datachain/lib/dc/datachain.py,sha256=Q8iEmf0MT6o5ORjyoKAt2xEIelcJ6vzZoB2e7haT7V8,104189
+datachain/lib/dc/datachain.py,sha256=cVqgemBiPVLSnfEVDLU1YH0dtowS-N-YFOAxV1k7i6U,104178
 datachain/lib/dc/datasets.py,sha256=A4SW-b3dkQnm9Wi7ciCdlXqtrsquIeRfBQN_bJ_ulqY,15237
 datachain/lib/dc/hf.py,sha256=FeruEO176L2qQ1Mnx0QmK4kV0GuQ4xtj717N8fGJrBI,2849
 datachain/lib/dc/json.py,sha256=iJ6G0jwTKz8xtfh1eICShnWk_bAMWjF5bFnOXLHaTlw,2683

@@ -132,11 +132,11 @@ datachain/model/ultralytics/pose.py,sha256=pvoXrWWUSWT_UBaMwUb5MBHAY57Co2HFDPigF
 datachain/model/ultralytics/segment.py,sha256=v9_xDxd5zw_I8rXsbl7yQXgEdTs2T38zyY_Y4XGN8ok,3194
 datachain/query/__init__.py,sha256=7DhEIjAA8uZJfejruAVMZVcGFmvUpffuZJwgRqNwe-c,263
 datachain/query/batch.py,sha256=ugTlSFqh_kxMcG6vJ5XrEzG9jBXRdb7KRAEEsFWiPew,4190
-datachain/query/dataset.py,sha256=lv5Ta7FjFZWQRUTz9_97oeoT5OvD62unRoNLgEueWUU,67384
-datachain/query/dispatch.py,sha256=B0sxnyN6unU8VFc35eWa_pe_TX6JfHDDbzyIQtp8AoM,15665
+datachain/query/dataset.py,sha256=Pu8FC11VcIj8ewXJxe0mjJpr4HBr2-gvCtMk4GQCva0,67419
+datachain/query/dispatch.py,sha256=Tg73zB6vDnYYYAvtlS9l7BI3sI1EfRCbDjiasvNxz2s,16385
 datachain/query/metrics.py,sha256=qOMHiYPTMtVs2zI-mUSy8OPAVwrg4oJtVF85B9tdQyM,810
 datachain/query/params.py,sha256=JkVz6IKUIpF58JZRkUXFT8DAHX2yfaULbhVaGmHKFLc,826
-datachain/query/queue.py,sha256=v0UeK4ilmdiRoJ5OdjB5qpnHTYDxRP4vhVp5Iw_toaI,3512
+datachain/query/queue.py,sha256=kCetMG6y7_ynV_jJDAXkLsf8WsVZCEk1fAuQGd7yTOo,3543
 datachain/query/schema.py,sha256=Cn1keXjktptAbEDbHlxSzdoCu5H6h_Vzp_DtNpMSr5w,6697
 datachain/query/session.py,sha256=lbwMDvxjZ2BS2rA9qk7MVBRzlsSrwH92yJ_waP3uvDc,6781
 datachain/query/udf.py,sha256=SLLLNLz3QmtaM04ZVTu7K6jo58I-1j5Jf7Lb4ORv4tQ,1385

@@ -165,9 +165,9 @@ datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR
 datachain/toolkit/__init__.py,sha256=eQ58Q5Yf_Fgv1ZG0IO5dpB4jmP90rk8YxUWmPc1M2Bo,68
 datachain/toolkit/split.py,sha256=xQzzmvQRKsPteDKbpgOxd4r971BnFaK33mcOl0FuGeI,2883
 datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
-datachain-0.36.0.dist-info/licenses/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
-datachain-0.36.0.dist-info/METADATA,sha256=ZH1x0Zcl8YD035rT1qvKm3D_NnSRgGtnD0TP2FNlwgI,13606
-datachain-0.36.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-datachain-0.36.0.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
-datachain-0.36.0.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
-datachain-0.36.0.dist-info/RECORD,,
+datachain-0.36.1.dist-info/licenses/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
+datachain-0.36.1.dist-info/METADATA,sha256=BBaBx1Ail7RzpUlvEywlXKZtl_6Vn-KIEjm8OJdXrng,13657
+datachain-0.36.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+datachain-0.36.1.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
+datachain-0.36.1.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
+datachain-0.36.1.dist-info/RECORD,,