datachain 0.8.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


datachain/query/dataset.py CHANGED
@@ -43,8 +43,9 @@ from datachain.data_storage.schema import (
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
-from datachain.lib.udf import UDFAdapter
 from datachain.progress import CombinedDownloadCallback
+from datachain.query.schema import C, UDFParamSpec, normalize_param
+from datachain.query.session import Session
 from datachain.sql.functions.random import rand
 from datachain.utils import (
     batched,
@@ -53,9 +54,6 @@ from datachain.utils import (
     get_datachain_executable,
 )

-from .schema import C, UDFParamSpec, normalize_param
-from .session import Session
-
 if TYPE_CHECKING:
     from sqlalchemy.sql.elements import ClauseElement
     from sqlalchemy.sql.schema import Table
@@ -65,7 +63,8 @@ if TYPE_CHECKING:
     from datachain.catalog import Catalog
     from datachain.data_storage import AbstractWarehouse
     from datachain.dataset import DatasetRecord
-    from datachain.lib.udf import UDFResult
+    from datachain.lib.udf import UDFAdapter, UDFResult
+    from datachain.query.udf import UdfInfo

 P = ParamSpec("P")

@@ -301,7 +300,7 @@ def adjust_outputs(
     return row


-def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFAdapter) -> list[tuple]:
+def get_udf_col_types(warehouse: "AbstractWarehouse", udf: "UDFAdapter") -> list[tuple]:
     """Optimization: Precompute UDF column types so these don't have to be computed
     in the convert_type function for each row in a loop."""
     dialect = warehouse.db.dialect
@@ -322,7 +321,7 @@ def process_udf_outputs(
     warehouse: "AbstractWarehouse",
     udf_table: "Table",
     udf_results: Iterator[Iterable["UDFResult"]],
-    udf: UDFAdapter,
+    udf: "UDFAdapter",
     batch_size: int = INSERT_BATCH_SIZE,
     cb: Callback = DEFAULT_CALLBACK,
 ) -> None:
@@ -347,6 +346,8 @@ def process_udf_outputs(
     for row_chunk in batched(rows, batch_size):
         warehouse.insert_rows(udf_table, row_chunk)

+    warehouse.insert_rows_done(udf_table)
+

 def get_download_callback() -> Callback:
     return CombinedDownloadCallback(
@@ -366,7 +367,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:

 @frozen
 class UDFStep(Step, ABC):
-    udf: UDFAdapter
+    udf: "UDFAdapter"
     catalog: "Catalog"
     partition_by: Optional[PartitionByType] = None
     parallel: Optional[int] = None
@@ -440,7 +441,7 @@ class UDFStep(Step, ABC):
                 raise RuntimeError(
                     "In-memory databases cannot be used with parallel processing."
                 )
-            udf_info = {
+            udf_info: UdfInfo = {
                 "udf_data": filtered_cloudpickle_dumps(self.udf),
                 "catalog_init": self.catalog.get_init_params(),
                 "metastore_clone_params": self.catalog.metastore.clone_params(),
@@ -464,8 +465,8 @@ class UDFStep(Step, ABC):

             with subprocess.Popen(cmd, env=envs, stdin=subprocess.PIPE) as process:  # noqa: S603
                 process.communicate(process_data)
-                if process.poll():
-                    raise RuntimeError("UDF Execution Failed!")
+                if retval := process.poll():
+                    raise RuntimeError(f"UDF Execution Failed! Exit code: {retval}")
         else:
             # Otherwise process single-threaded (faster for smaller UDFs)
             warehouse = self.catalog.warehouse
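For context, subprocess.Popen.poll() returns None while the child process is still running and its exit code once it has finished, so the new walrus-operator check above only raises when the UDF subprocess exits with a nonzero status. A minimal standalone illustration (not part of the package):

    import subprocess
    import sys

    # Spawn a child that exits with status 3, wait for it, then inspect the exit code.
    with subprocess.Popen([sys.executable, "-c", "raise SystemExit(3)"]) as process:
        process.communicate()
        if retval := process.poll():
            print(f"UDF Execution Failed! Exit code: {retval}")  # prints exit code 3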
@@ -479,7 +480,6 @@ class UDFStep(Step, ABC):
                         udf_fields,
                         udf_inputs,
                         self.catalog,
-                        self.is_generator,
                         self.cache,
                         download_cb,
                         processed_cb,
@@ -496,8 +496,6 @@ class UDFStep(Step, ABC):
                     processed_cb.close()
                     generated_cb.close()

-                warehouse.insert_rows_done(udf_table)
-
         except QueryScriptCancelError:
             self.catalog.warehouse.close()
             sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE)
@@ -1491,7 +1489,7 @@ class DatasetQuery:
     @detach
     def add_signals(
         self,
-        udf: UDFAdapter,
+        udf: "UDFAdapter",
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
@@ -1535,7 +1533,7 @@ class DatasetQuery:
     @detach
     def generate(
         self,
-        udf: UDFAdapter,
+        udf: "UDFAdapter",
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
@@ -1617,7 +1615,9 @@ class DatasetQuery:
         )
         version = version or dataset.latest_version

-        self.session.add_dataset_version(dataset=dataset, version=version)
+        self.session.add_dataset_version(
+            dataset=dataset, version=version, listing=kwargs.get("listing", False)
+        )

         dr = self.catalog.warehouse.dataset_rows(dataset)

datachain/query/dispatch.py CHANGED
@@ -1,34 +1,37 @@
 import contextlib
-from collections.abc import Iterator, Sequence
+from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
 from sys import stdin
-from typing import Optional
+from threading import Timer
+from typing import TYPE_CHECKING, Optional

 import attrs
 import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from multiprocess import get_context
+from sqlalchemy.sql import func

 from datachain.catalog import Catalog
 from datachain.catalog.loader import get_distributed_class
-from datachain.lib.udf import UDFAdapter, UDFResult
+from datachain.query.batch import RowsOutput, RowsOutputBatch
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
     get_processed_callback,
     process_udf_outputs,
 )
-from datachain.query.queue import (
-    get_from_queue,
-    marshal,
-    msgpack_pack,
-    msgpack_unpack,
-    put_into_queue,
-    unmarshal,
-)
-from datachain.utils import batched_it
+from datachain.query.queue import get_from_queue, put_into_queue
+from datachain.query.udf import UdfInfo
+from datachain.query.utils import get_query_id_column
+from datachain.utils import batched, flatten
+
+if TYPE_CHECKING:
+    from sqlalchemy import Select, Table
+
+    from datachain.data_storage import AbstractMetastore, AbstractWarehouse
+    from datachain.lib.udf import UDFAdapter

 DEFAULT_BATCH_SIZE = 10000
 STOP_SIGNAL = "STOP"
@@ -38,10 +41,6 @@ FAILED_STATUS = "FAILED"
 NOTIFY_STATUS = "NOTIFY"


-def full_module_type_path(typ: type) -> str:
-    return f"{typ.__module__}.{typ.__qualname__}"
-
-
 def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:
     if not n_workers:
         return cpu_count()
@@ -52,55 +51,42 @@ def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:

 def udf_entrypoint() -> int:
     # Load UDF info from stdin
-    udf_info = load(stdin.buffer)
-
-    (
-        warehouse_class,
-        warehouse_args,
-        warehouse_kwargs,
-    ) = udf_info["warehouse_clone_params"]
-    warehouse = warehouse_class(*warehouse_args, **warehouse_kwargs)
+    udf_info: UdfInfo = load(stdin.buffer)

     # Parallel processing (faster for more CPU-heavy UDFs)
-    dispatch = UDFDispatcher(
-        udf_info["udf_data"],
-        udf_info["catalog_init"],
-        udf_info["metastore_clone_params"],
-        udf_info["warehouse_clone_params"],
-        udf_fields=udf_info["udf_fields"],
-        cache=udf_info["cache"],
-        is_generator=udf_info.get("is_generator", False),
-    )
+    dispatch = UDFDispatcher(udf_info)

     query = udf_info["query"]
     batching = udf_info["batching"]
-    table = udf_info["table"]
     n_workers = udf_info["processes"]
-    udf = loads(udf_info["udf_data"])
     if n_workers is True:
-        # Use default number of CPUs (cores)
-        n_workers = None
+        n_workers = None  # Use default number of CPUs (cores)
+
+    wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"]
+    warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs)
+
+    total_rows = next(
+        warehouse.db.execute(
+            query.with_only_columns(func.count(query.c.sys__id)).order_by(None)
+        )
+    )[0]

     with contextlib.closing(
-        batching(warehouse.dataset_select_paginated, query)
+        batching(warehouse.dataset_select_paginated, query, ids_only=True)
     ) as udf_inputs:
         download_cb = get_download_callback()
         processed_cb = get_processed_callback()
-        generated_cb = get_generated_callback(dispatch.is_generator)
         try:
-            udf_results = dispatch.run_udf_parallel(
-                marshal(udf_inputs),
+            dispatch.run_udf_parallel(
+                udf_inputs,
+                total_rows=total_rows,
                 n_workers=n_workers,
                 processed_cb=processed_cb,
                 download_cb=download_cb,
             )
-            process_udf_outputs(warehouse, table, udf_results, udf, cb=generated_cb)
         finally:
             download_cb.close()
             processed_cb.close()
-            generated_cb.close()
-
-        warehouse.insert_rows_done(table)

     return 0

@@ -114,32 +100,17 @@ class UDFDispatcher:
     task_queue: Optional[multiprocess.Queue] = None
     done_queue: Optional[multiprocess.Queue] = None

-    def __init__(
-        self,
-        udf_data,
-        catalog_init_params,
-        metastore_clone_params,
-        warehouse_clone_params,
-        udf_fields: "Sequence[str]",
-        cache: bool,
-        is_generator: bool = False,
-        buffer_size: int = DEFAULT_BATCH_SIZE,
-    ):
-        self.udf_data = udf_data
-        self.catalog_init_params = catalog_init_params
-        (
-            self.metastore_class,
-            self.metastore_args,
-            self.metastore_kwargs,
-        ) = metastore_clone_params
-        (
-            self.warehouse_class,
-            self.warehouse_args,
-            self.warehouse_kwargs,
-        ) = warehouse_clone_params
-        self.udf_fields = udf_fields
-        self.cache = cache
-        self.is_generator = is_generator
+    def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
+        self.udf_data = udf_info["udf_data"]
+        self.catalog_init_params = udf_info["catalog_init"]
+        self.metastore_clone_params = udf_info["metastore_clone_params"]
+        self.warehouse_clone_params = udf_info["warehouse_clone_params"]
+        self.query = udf_info["query"]
+        self.table = udf_info["table"]
+        self.udf_fields = udf_info["udf_fields"]
+        self.cache = udf_info["cache"]
+        self.is_generator = udf_info["is_generator"]
+        self.is_batching = udf_info["batching"].is_batching
         self.buffer_size = buffer_size
         self.catalog = None
         self.task_queue = None
@@ -148,12 +119,10 @@ class UDFDispatcher:

     def _create_worker(self) -> "UDFWorker":
         if not self.catalog:
-            metastore = self.metastore_class(
-                *self.metastore_args, **self.metastore_kwargs
-            )
-            warehouse = self.warehouse_class(
-                *self.warehouse_args, **self.warehouse_kwargs
-            )
+            ms_cls, ms_args, ms_kwargs = self.metastore_clone_params
+            metastore: AbstractMetastore = ms_cls(*ms_args, **ms_kwargs)
+            ws_cls, ws_args, ws_kwargs = self.warehouse_clone_params
+            warehouse: AbstractWarehouse = ws_cls(*ws_args, **ws_kwargs)
             self.catalog = Catalog(metastore, warehouse, **self.catalog_init_params)
         self.udf = loads(self.udf_data)
         return UDFWorker(
@@ -161,7 +130,10 @@ class UDFDispatcher:
             self.udf,
             self.task_queue,
             self.done_queue,
+            self.query,
+            self.table,
             self.is_generator,
+            self.is_batching,
             self.cache,
             self.udf_fields,
         )
@@ -189,26 +161,27 @@

     def run_udf_parallel(  # noqa: C901, PLR0912
         self,
-        input_rows,
+        input_rows: Iterable[RowsOutput],
+        total_rows: int,
         n_workers: Optional[int] = None,
-        input_queue=None,
         processed_cb: Callback = DEFAULT_CALLBACK,
         download_cb: Callback = DEFAULT_CALLBACK,
-    ) -> Iterator[Sequence[UDFResult]]:
+    ) -> None:
         n_workers = get_n_workers_from_arg(n_workers)

+        input_batch_size = total_rows // n_workers
+        if input_batch_size == 0:
+            input_batch_size = 1
+        elif input_batch_size > DEFAULT_BATCH_SIZE:
+            input_batch_size = DEFAULT_BATCH_SIZE
+
         if self.buffer_size < n_workers:
             raise RuntimeError(
                 "Parallel run error: buffer size is smaller than "
                 f"number of workers: {self.buffer_size} < {n_workers}"
             )

-        if input_queue:
-            streaming_mode = True
-            self.task_queue = input_queue
-        else:
-            streaming_mode = False
-            self.task_queue = self.ctx.Queue()
+        self.task_queue = self.ctx.Queue()
         self.done_queue = self.ctx.Queue()
         pool = [
             self.ctx.Process(name=f"Worker-UDF-{i}", target=self._run_worker)
@@ -223,41 +196,41 @@
         # Will be set to True when the input is exhausted
         input_finished = False

-        if not streaming_mode:
-            # Stop all workers after the input rows have finished processing
-            input_data = chain(input_rows, [STOP_SIGNAL] * n_workers)
+        if not self.is_batching:
+            input_rows = batched(flatten(input_rows), input_batch_size)

-            # Add initial buffer of tasks
-            for _ in range(self.buffer_size):
-                try:
-                    put_into_queue(self.task_queue, next(input_data))
-                except StopIteration:
-                    input_finished = True
-                    break
+        # Stop all workers after the input rows have finished processing
+        input_data = chain(input_rows, [STOP_SIGNAL] * n_workers)
+
+        # Add initial buffer of tasks
+        for _ in range(self.buffer_size):
+            try:
+                put_into_queue(self.task_queue, next(input_data))
+            except StopIteration:
+                input_finished = True
+                break

         # Process all tasks
         while n_workers > 0:
             result = get_from_queue(self.done_queue)
+
+            if downloaded := result.get("downloaded"):
+                download_cb.relative_update(downloaded)
+            if processed := result.get("processed"):
+                processed_cb.relative_update(processed)
+
             status = result["status"]
-            if status == NOTIFY_STATUS:
-                if downloaded := result.get("downloaded"):
-                    download_cb.relative_update(downloaded)
-                if processed := result.get("processed"):
-                    processed_cb.relative_update(processed)
+            if status in (OK_STATUS, NOTIFY_STATUS):
+                pass  # Do nothing here
             elif status == FINISHED_STATUS:
-                # Worker finished
-                n_workers -= 1
-            elif status == OK_STATUS:
-                if processed := result.get("processed"):
-                    processed_cb.relative_update(processed)
-                yield msgpack_unpack(result["result"])
+                n_workers -= 1  # Worker finished
             else:  # Failed / error
                 n_workers -= 1
                 if exc := result.get("exception"):
                     raise exc
                 raise RuntimeError("Internal error: Parallel UDF execution failed")

-            if status == OK_STATUS and not streaming_mode and not input_finished:
+            if status == OK_STATUS and not input_finished:
                 try:
                     put_into_queue(self.task_queue, next(input_data))
                 except StopIteration:
@@ -311,11 +284,14 @@ class ProcessedCallback(Callback):

 @attrs.define
 class UDFWorker:
-    catalog: Catalog
-    udf: UDFAdapter
+    catalog: "Catalog"
+    udf: "UDFAdapter"
     task_queue: "multiprocess.Queue"
     done_queue: "multiprocess.Queue"
+    query: "Select"
+    table: "Table"
     is_generator: bool
+    is_batching: bool
     cache: bool
     udf_fields: Sequence[str]
     cb: Callback = attrs.field()
@@ -326,30 +302,54 @@

     def run(self) -> None:
         processed_cb = ProcessedCallback()
+        generated_cb = get_generated_callback(self.is_generator)
+
         udf_results = self.udf.run(
             self.udf_fields,
-            unmarshal(self.get_inputs()),
+            self.get_inputs(),
             self.catalog,
-            self.is_generator,
             self.cache,
             download_cb=self.cb,
             processed_cb=processed_cb,
         )
-        for udf_output in udf_results:
-            for batch in batched_it(udf_output, DEFAULT_BATCH_SIZE):
-                put_into_queue(
-                    self.done_queue,
-                    {
-                        "status": OK_STATUS,
-                        "result": msgpack_pack(list(batch)),
-                    },
-                )
+        process_udf_outputs(
+            self.catalog.warehouse,
+            self.table,
+            self.notify_and_process(udf_results, processed_cb),
+            self.udf,
+            cb=generated_cb,
+        )
+
+        put_into_queue(
+            self.done_queue,
+            {"status": FINISHED_STATUS, "processed": processed_cb.processed_rows},
+        )
+
+    def notify_and_process(self, udf_results, processed_cb):
+        for row in udf_results:
             put_into_queue(
                 self.done_queue,
-                {"status": NOTIFY_STATUS, "processed": processed_cb.processed_rows},
+                {"status": OK_STATUS, "processed": processed_cb.processed_rows},
             )
-        put_into_queue(self.done_queue, {"status": FINISHED_STATUS})
+            yield row

     def get_inputs(self):
-        while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
-            yield batch
+        warehouse = self.catalog.warehouse.clone()
+        col_id = get_query_id_column(self.query)
+
+        if self.is_batching:
+            while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+                ids = [row[0] for row in batch.rows]
+                rows = warehouse.dataset_rows_select(self.query.where(col_id.in_(ids)))
+                yield RowsOutputBatch(list(rows))
+        else:
+            while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+                yield from warehouse.dataset_rows_select(
+                    self.query.where(col_id.in_(batch))
+                )
+
+
+class RepeatTimer(Timer):
+    def run(self):
+        while not self.finished.wait(self.interval):
+            self.function(*self.args, **self.kwargs)
datachain/query/session.py CHANGED
@@ -69,7 +69,7 @@ class Session:
         self.catalog = catalog or get_catalog(
             client_config=client_config, in_memory=in_memory
         )
-        self.dataset_versions: list[tuple[DatasetRecord, int]] = []
+        self.dataset_versions: list[tuple[DatasetRecord, int, bool]] = []

     def __enter__(self):
         # Push the current context onto the stack
@@ -89,8 +89,10 @@
         if Session.SESSION_CONTEXTS:
             Session.SESSION_CONTEXTS.pop()

-    def add_dataset_version(self, dataset: "DatasetRecord", version: int) -> None:
-        self.dataset_versions.append((dataset, version))
+    def add_dataset_version(
+        self, dataset: "DatasetRecord", version: int, listing: bool = False
+    ) -> None:
+        self.dataset_versions.append((dataset, version, listing))

     def generate_temp_dataset_name(self) -> str:
         return self.get_temp_prefix() + uuid4().hex[: self.TEMP_TABLE_UUID_LEN]
@@ -111,8 +113,9 @@
         if not self.dataset_versions:
             return

-        for dataset, version in self.dataset_versions:
-            self.catalog.remove_dataset_version(dataset, version)
+        for dataset, version, listing in self.dataset_versions:
+            if not listing:
+                self.catalog.remove_dataset_version(dataset, version)

         self.dataset_versions.clear()

datachain/query/udf.py ADDED
@@ -0,0 +1,20 @@
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict
+
+if TYPE_CHECKING:
+    from sqlalchemy import Select, Table
+
+    from datachain.query.batch import BatchingStrategy
+
+
+class UdfInfo(TypedDict):
+    udf_data: bytes
+    catalog_init: dict[str, Any]
+    metastore_clone_params: tuple[Callable[..., Any], list[Any], dict[str, Any]]
+    warehouse_clone_params: tuple[Callable[..., Any], list[Any], dict[str, Any]]
+    table: "Table"
+    query: "Select"
+    udf_fields: list[str]
+    batching: "BatchingStrategy"
+    processes: Optional[int]
+    is_generator: bool
+    cache: bool
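The parent process serializes a mapping shaped like UdfInfo and writes it to the child's stdin, where udf_entrypoint reads it back with load(stdin.buffer). A minimal sketch of that cloudpickle round trip, using an in-memory buffer and a simplified payload in place of a real UdfInfo:

    import io

    from cloudpickle import dumps, load

    # Simplified stand-in payload; a real UdfInfo also carries clone params,
    # the table/query objects, and the pickled UDF bytes.
    payload = {"udf_fields": ["sys__id"], "processes": None, "cache": False}
    buffer = io.BytesIO(dumps(payload))

    restored = load(buffer)
    assert restored == payload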
datachain/query/utils.py ADDED
@@ -0,0 +1,42 @@
+from typing import TYPE_CHECKING, Optional, Union
+
+from sqlalchemy import Column
+
+if TYPE_CHECKING:
+    from sqlalchemy import ColumnElement, Select, TextClause
+
+
+ColT = Union[Column, "ColumnElement", "TextClause"]
+
+
+def column_name(col: ColT) -> str:
+    """Returns column name from column element."""
+    return col.name if isinstance(col, Column) else str(col)
+
+
+def get_query_column(query: "Select", name: str) -> Optional[ColT]:
+    """Returns column element from query by name or None if column not found."""
+    return next((col for col in query.inner_columns if column_name(col) == name), None)
+
+
+def get_query_id_column(query: "Select") -> ColT:
+    """Returns ID column element from query or None if column not found."""
+    col = get_query_column(query, "sys__id")
+    if col is None:
+        raise RuntimeError("sys__id column not found in query")
+    return col
+
+
+def select_only_columns(query: "Select", *names: str) -> "Select":
+    """Returns query selecting defined columns only."""
+    if not names:
+        return query
+
+    cols: list[ColT] = []
+    for name in names:
+        col = get_query_column(query, name)
+        if col is None:
+            raise ValueError(f"Column '{name}' not found in query")
+        cols.append(col)
+
+    return query.with_only_columns(*cols)
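These helpers work on any SQLAlchemy Select. A small usage sketch, assuming datachain 0.8.1 and SQLAlchemy are installed (the table and extra column here are illustrative; only sys__id is meaningful to datachain):

    import sqlalchemy as sa

    from datachain.query.utils import get_query_id_column, select_only_columns

    metadata = sa.MetaData()
    rows = sa.Table(
        "rows",
        metadata,
        sa.Column("sys__id", sa.Integer, primary_key=True),
        sa.Column("file__path", sa.String),
    )

    query = sa.select(rows)
    id_col = get_query_id_column(query)               # Column element for sys__id
    narrowed = select_only_columns(query, "sys__id")  # Select limited to sys__id
    print(id_col.name, [c.name for c in narrowed.selected_columns])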
datachain/utils.py CHANGED
@@ -263,7 +263,7 @@ def batched_it(iterable: Iterable[_T_co], n: int) -> Iterator[Iterator[_T_co]]:

 def flatten(items):
     for item in items:
-        if isinstance(item, list):
+        if isinstance(item, (list, tuple)):
             yield from item
         else:
             yield item
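The practical effect of the widened isinstance check: tuples are now expanded one level just like lists, while other items still pass through unchanged. A quick check, assuming datachain 0.8.1 is installed:

    from datachain.utils import flatten

    # Lists and tuples are both flattened one level; scalars are yielded as-is.
    print(list(flatten([(1, 2), [3, 4], 5])))  # -> [1, 2, 3, 4, 5]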
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: datachain
-Version: 0.8.0
+Version: 0.8.1
 Summary: Wrangle unstructured AI data at scale
 Author-email: Dmitry Petrov <support@dvc.org>
 License: Apache-2.0
@@ -84,7 +84,7 @@ Requires-Dist: requests-mock; extra == "tests"
 Requires-Dist: scipy; extra == "tests"
 Provides-Extra: dev
 Requires-Dist: datachain[docs,tests]; extra == "dev"
-Requires-Dist: mypy==1.13.0; extra == "dev"
+Requires-Dist: mypy==1.14.0; extra == "dev"
 Requires-Dist: types-python-dateutil; extra == "dev"
 Requires-Dist: types-pytz; extra == "dev"
 Requires-Dist: types-PyYAML; extra == "dev"
@@ -99,7 +99,7 @@ Requires-Dist: unstructured[pdf]; extra == "examples"
 Requires-Dist: pdfplumber==0.11.4; extra == "examples"
 Requires-Dist: huggingface_hub[hf_transfer]; extra == "examples"
 Requires-Dist: onnx==1.16.1; extra == "examples"
-Requires-Dist: ultralytics==8.3.50; extra == "examples"
+Requires-Dist: ultralytics==8.3.53; extra == "examples"

 ================
 |logo| DataChain