datachain 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datachain might be problematic.

@@ -2,11 +2,8 @@ import contextlib
 from collections.abc import Iterator, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
-from queue import Empty, Full, Queue
 from sys import stdin
-from time import sleep
-from types import GeneratorType
-from typing import Any, Optional
+from typing import Optional
 
 import attrs
 import multiprocess
@@ -22,7 +19,16 @@ from datachain.query.dataset import (
     get_processed_callback,
     process_udf_outputs,
 )
+from datachain.query.queue import (
+    get_from_queue,
+    marshal,
+    msgpack_pack,
+    msgpack_unpack,
+    put_into_queue,
+    unmarshal,
+)
 from datachain.query.udf import UDFBase, UDFFactory, UDFResult
+from datachain.utils import batched_it
 
 DEFAULT_BATCH_SIZE = 10000
 STOP_SIGNAL = "STOP"
@@ -44,44 +50,6 @@ def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:
     return n_workers
 
 
-# For more context on the get_from_queue and put_into_queue functions, see the
-# discussion here:
-# https://github.com/iterative/dvcx/pull/1297#issuecomment-2026308773
-# This problem is not exactly described by, but is also related to these Python issues:
-# https://github.com/python/cpython/issues/66587
-# https://github.com/python/cpython/issues/88628
-# https://github.com/python/cpython/issues/108645
-
-
-def get_from_queue(queue: Queue) -> Any:
-    """
-    Gets an item from a queue.
-    This is required to handle signals, such as KeyboardInterrupt exceptions
-    while waiting for items to be available, although only on certain installations.
-    (See the above comment for more context.)
-    """
-    while True:
-        try:
-            return queue.get_nowait()
-        except Empty:
-            sleep(0.01)
-
-
-def put_into_queue(queue: Queue, item: Any) -> None:
-    """
-    Puts an item into a queue.
-    This is required to handle signals, such as KeyboardInterrupt exceptions
-    while waiting for items to be queued, although only on certain installations.
-    (See the above comment for more context.)
-    """
-    while True:
-        try:
-            queue.put_nowait(item)
-            return
-        except Full:
-            sleep(0.01)
-
-
 def udf_entrypoint() -> int:
     # Load UDF info from stdin
     udf_info = load(stdin.buffer)
@@ -100,8 +68,9 @@ def udf_entrypoint() -> int:
         udf_info["id_generator_clone_params"],
         udf_info["metastore_clone_params"],
         udf_info["warehouse_clone_params"],
-        is_generator=udf_info.get("is_generator", False),
+        udf_fields=udf_info["udf_fields"],
         cache=udf_info["cache"],
+        is_generator=udf_info.get("is_generator", False),
     )
 
     query = udf_info["query"]
@@ -121,7 +90,7 @@ def udf_entrypoint() -> int:
     generated_cb = get_generated_callback(dispatch.is_generator)
     try:
         udf_results = dispatch.run_udf_parallel(
-            udf_inputs,
+            marshal(udf_inputs),
             n_workers=n_workers,
             processed_cb=processed_cb,
             download_cb=download_cb,
@@ -142,6 +111,9 @@ def udf_worker_entrypoint() -> int:
 
 
 class UDFDispatcher:
+    catalog: Optional[Catalog] = None
+    task_queue: Optional[multiprocess.Queue] = None
+    done_queue: Optional[multiprocess.Queue] = None
     _batch_size: Optional[int] = None
 
     def __init__(
@@ -151,9 +123,10 @@ class UDFDispatcher:
         id_generator_clone_params,
         metastore_clone_params,
         warehouse_clone_params,
-        cache,
-        is_generator=False,
-        buffer_size=DEFAULT_BATCH_SIZE,
+        udf_fields: "Sequence[str]",
+        cache: bool,
+        is_generator: bool = False,
+        buffer_size: int = DEFAULT_BATCH_SIZE,
     ):
         self.udf_data = udf_data
         self.catalog_init_params = catalog_init_params
@@ -172,12 +145,13 @@ class UDFDispatcher:
             self.warehouse_args,
             self.warehouse_kwargs,
         ) = warehouse_clone_params
-        self.is_generator = is_generator
+        self.udf_fields = udf_fields
         self.cache = cache
+        self.is_generator = is_generator
+        self.buffer_size = buffer_size
         self.catalog = None
         self.task_queue = None
         self.done_queue = None
-        self.buffer_size = buffer_size
         self.ctx = get_context("spawn")
 
     @property
@@ -226,6 +200,7 @@ class UDFDispatcher:
             self.done_queue,
             self.is_generator,
             self.cache,
+            self.udf_fields,
         )
 
     def _run_worker(self) -> None:
@@ -233,7 +208,11 @@ class UDFDispatcher:
             worker = self._create_worker()
             worker.run()
         except (Exception, KeyboardInterrupt) as e:
-            put_into_queue(self.done_queue, {"status": FAILED_STATUS, "exception": e})
+            if self.done_queue:
+                put_into_queue(
+                    self.done_queue,
+                    {"status": FAILED_STATUS, "exception": e},
+                )
             raise
 
     @staticmethod
@@ -249,7 +228,6 @@ class UDFDispatcher:
         self,
         input_rows,
         n_workers: Optional[int] = None,
-        cache: bool = False,
        input_queue=None,
         processed_cb: Callback = DEFAULT_CALLBACK,
         download_cb: Callback = DEFAULT_CALLBACK,
@@ -299,21 +277,24 @@ class UDFDispatcher:
             result = get_from_queue(self.done_queue)
             status = result["status"]
             if status == NOTIFY_STATUS:
-                download_cb.relative_update(result["downloaded"])
+                if downloaded := result.get("downloaded"):
+                    download_cb.relative_update(downloaded)
+                if processed := result.get("processed"):
+                    processed_cb.relative_update(processed)
             elif status == FINISHED_STATUS:
                 # Worker finished
                 n_workers -= 1
             elif status == OK_STATUS:
-                processed_cb.relative_update(result["processed"])
-                yield result["result"]
+                if processed := result.get("processed"):
+                    processed_cb.relative_update(processed)
+                yield msgpack_unpack(result["result"])
             else:  # Failed / error
                 n_workers -= 1
-                exc = result.get("exception")
-                if exc:
+                if exc := result.get("exception"):
                     raise exc
                 raise RuntimeError("Internal error: Parallel UDF execution failed")
 
-            if not streaming_mode and not input_finished:
+            if status == OK_STATUS and not streaming_mode and not input_finished:
                 try:
                     put_into_queue(self.task_queue, next(input_data))
                 except StopIteration:
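
The status loop above drains done_queue and branches on the "status" field of each message; the worker side (UDFWorker.run, further down) produces those messages. A minimal sketch of the message shapes the loop handles, assuming the new module shown later in this diff is importable as datachain.query.queue (as the import block above suggests); the payload values are illustrative only:

    from datachain.query.queue import msgpack_pack

    # Status strings correspond to OK_STATUS / NOTIFY_STATUS / FINISHED_STATUS / FAILED_STATUS.
    ok_msg = {"status": "OK", "result": msgpack_pack([[1, "a.txt"], [2, "b.txt"]])}
    notify_msg = {"status": "NOTIFY", "processed": 2, "downloaded": 1024}
    finished_msg = {"status": "FINISHED"}
    failed_msg = {"status": "FAILED", "exception": RuntimeError("worker error")}

OK results arrive msgpack-packed and are unpacked before being yielded; NOTIFY carries optional progress counters, which is why the loop now uses result.get(...) instead of indexing; and after this change only an OK message triggers queueing of the next input.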
@@ -348,7 +329,7 @@ class UDFDispatcher:
 
 
 class WorkerCallback(Callback):
-    def __init__(self, queue: multiprocess.Queue):
+    def __init__(self, queue: "multiprocess.Queue"):
         self.queue = queue
         super().__init__()
 
@@ -369,10 +350,11 @@ class ProcessedCallback(Callback):
 class UDFWorker:
     catalog: Catalog
     udf: UDFBase
-    task_queue: multiprocess.Queue
-    done_queue: multiprocess.Queue
+    task_queue: "multiprocess.Queue"
+    done_queue: "multiprocess.Queue"
     is_generator: bool
     cache: bool
+    udf_fields: Sequence[str]
     cb: Callback = attrs.field()
 
     @cb.default
@@ -382,7 +364,8 @@ class UDFWorker:
     def run(self) -> None:
         processed_cb = ProcessedCallback()
         udf_results = self.udf.run(
-            self.get_inputs(),
+            self.udf_fields,
+            unmarshal(self.get_inputs()),
             self.catalog,
             self.is_generator,
             self.cache,
@@ -390,15 +373,17 @@ class UDFWorker:
             processed_cb=processed_cb,
         )
         for udf_output in udf_results:
-            if isinstance(udf_output, GeneratorType):
-                udf_output = list(udf_output)  # can not pickle generator
+            for batch in batched_it(udf_output, DEFAULT_BATCH_SIZE):
+                put_into_queue(
+                    self.done_queue,
+                    {
+                        "status": OK_STATUS,
+                        "result": msgpack_pack(list(batch)),
+                    },
+                )
             put_into_queue(
                 self.done_queue,
-                {
-                    "status": OK_STATUS,
-                    "result": udf_output,
-                    "processed": processed_cb.processed_rows,
-                },
+                {"status": NOTIFY_STATUS, "processed": processed_cb.processed_rows},
             )
         put_into_queue(self.done_queue, {"status": FINISHED_STATUS})
 
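batched_it is imported from datachain.utils but its definition is not part of this diff; the worker uses it to split each UDF output into chunks of DEFAULT_BATCH_SIZE and msgpack-pack each chunk, instead of materializing and pickling a whole generator output at once. A rough sketch of what such a helper typically looks like, based on itertools.islice (an assumption about its behavior, not the actual datachain.utils implementation):

    from collections.abc import Iterable, Iterator
    from itertools import islice
    from typing import TypeVar

    T = TypeVar("T")

    def batched_it(iterable: Iterable[T], size: int) -> Iterator[Iterator[T]]:
        # Yield successive chunks of at most `size` items without materializing
        # the whole input; each chunk is itself an iterator.
        it = iter(iterable)
        while True:
            chunk = list(islice(it, size))
            if not chunk:
                return
            yield iter(chunk)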
 
@@ -0,0 +1,120 @@
+import datetime
+from collections.abc import Iterable, Iterator
+from queue import Empty, Full, Queue
+from struct import pack, unpack
+from time import sleep
+from typing import Any
+
+import msgpack
+
+from datachain.query.batch import RowsOutput, RowsOutputBatch
+
+DEFAULT_BATCH_SIZE = 10000
+STOP_SIGNAL = "STOP"
+OK_STATUS = "OK"
+FINISHED_STATUS = "FINISHED"
+FAILED_STATUS = "FAILED"
+NOTIFY_STATUS = "NOTIFY"
+
+
+# For more context on the get_from_queue and put_into_queue functions, see the
+# discussion here:
+# https://github.com/iterative/dvcx/pull/1297#issuecomment-2026308773
+# This problem is not exactly described by, but is also related to these Python issues:
+# https://github.com/python/cpython/issues/66587
+# https://github.com/python/cpython/issues/88628
+# https://github.com/python/cpython/issues/108645
+
+
+def get_from_queue(queue: Queue) -> Any:
+    """
+    Gets an item from a queue.
+    This is required to handle signals, such as KeyboardInterrupt exceptions
+    while waiting for items to be available, although only on certain installations.
+    (See the above comment for more context.)
+    """
+    while True:
+        try:
+            return queue.get_nowait()
+        except Empty:
+            sleep(0.01)
+
+
+def put_into_queue(queue: Queue, item: Any) -> None:
+    """
+    Puts an item into a queue.
+    This is required to handle signals, such as KeyboardInterrupt exceptions
+    while waiting for items to be queued, although only on certain installations.
+    (See the above comment for more context.)
+    """
+    while True:
+        try:
+            queue.put_nowait(item)
+            return
+        except Full:
+            sleep(0.01)
+
+
+MSGPACK_EXT_TYPE_DATETIME = 42
+MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH = 43
+
+
+def _msgpack_pack_extended_types(obj: Any) -> msgpack.ExtType:
+    if isinstance(obj, datetime.datetime):
+        # packing date object as 1 or 2 variables, depending if timezone info is present
+        # - timestamp
+        # - [OPTIONAL] timezone offset from utc in seconds if timezone info exists
+        if obj.tzinfo:
+            data = (obj.timestamp(), int(obj.utcoffset().total_seconds()))  # type: ignore  # noqa: PGH003
+            return msgpack.ExtType(MSGPACK_EXT_TYPE_DATETIME, pack("!dl", *data))
+        data = (obj.timestamp(),)  # type: ignore  # noqa: PGH003
+        return msgpack.ExtType(MSGPACK_EXT_TYPE_DATETIME, pack("!d", *data))
+
+    if isinstance(obj, RowsOutputBatch):
+        return msgpack.ExtType(
+            MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH,
+            msgpack_pack(obj.rows),
+        )
+
+    raise TypeError(f"Unknown type: {obj}")
+
+
+def msgpack_pack(obj: Any) -> bytes:
+    return msgpack.packb(obj, default=_msgpack_pack_extended_types)
+
+
+def _msgpack_unpack_extended_types(code: int, data: bytes) -> Any:
+    if code == MSGPACK_EXT_TYPE_DATETIME:
+        has_timezone = False
+        if len(data) == 8:
+            # we send only timestamp without timezone if data is 8 bytes
+            values = unpack("!d", data)
+        else:
+            has_timezone = True
+            values = unpack("!dl", data)
+
+        timestamp = values[0]
+        tz_info = None
+        if has_timezone:
+            timezone_offset = values[1]
+            tz_info = datetime.timezone(datetime.timedelta(seconds=timezone_offset))
+        return datetime.datetime.fromtimestamp(timestamp, tz=tz_info)
+
+    if code == MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH:
+        return RowsOutputBatch(msgpack_unpack(data))
+
+    return msgpack.ExtType(code, data)
+
+
+def msgpack_unpack(data: bytes) -> Any:
+    return msgpack.unpackb(data, ext_hook=_msgpack_unpack_extended_types)
+
+
+def marshal(obj: Iterator[RowsOutput]) -> Iterable[bytes]:
+    for row in obj:
+        yield msgpack_pack(row)
+
+
+def unmarshal(obj: Iterator[bytes]) -> Iterable[RowsOutput]:
+    for row in obj:
+        yield msgpack_unpack(row)
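
A short round-trip example for the serialization helpers above (assuming the new module is importable as datachain.query.queue, which is how the dispatch code imports it); the datetime and row values are illustrative:

    import datetime

    from datachain.query.queue import marshal, msgpack_pack, msgpack_unpack, unmarshal

    # Timezone-aware datetimes survive the custom ExtType round trip
    # (packed as a timestamp plus the UTC offset in seconds).
    tz = datetime.timezone(datetime.timedelta(hours=2))
    dt = datetime.datetime(2024, 7, 1, 12, 0, tzinfo=tz)
    assert msgpack_unpack(msgpack_pack(dt)) == dt

    # marshal/unmarshal wrap whole iterators of rows for the task and done queues.
    rows = [[1, "a.txt"], [2, "b.txt"]]
    packed = list(marshal(iter(rows)))        # each item is bytes
    restored = list(unmarshal(iter(packed)))  # [[1, "a.txt"], [2, "b.txt"]]
    assert restored == rows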
datachain/query/schema.py CHANGED
@@ -45,6 +45,10 @@ class Column(sa.ColumnClause, metaclass=ColumnMeta):
         """Search for matches using glob pattern matching."""
         return self.op("GLOB")(glob_str)
 
+    def regexp(self, regexp_str):
+        """Search for matches using regexp pattern matching."""
+        return self.op("REGEXP")(regexp_str)
+
 
 class UDFParameter(ABC):
     @abstractmethod
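
The new regexp method mirrors the glob helper directly above it, emitting a SQL REGEXP operator via SQLAlchemy's op(). A small usage sketch (the column name is hypothetical, and whether REGEXP is actually supported depends on the database backend):

    from datachain.query.schema import Column

    path = Column("path")              # hypothetical column name
    clause = path.regexp(r"\.jpe?g$")  # SQL expression: path REGEXP <pattern>
    # analogous to path.glob("*.jpg"), but with a regular expression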
datachain/query/udf.py CHANGED
@@ -15,7 +15,14 @@ from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 
 from datachain.dataset import RowDict
 
-from .batch import Batch, BatchingStrategy, NoBatching, Partition, RowBatch
+from .batch import (
+    Batch,
+    BatchingStrategy,
+    NoBatching,
+    Partition,
+    RowsOutputBatch,
+    UDFInputBatch,
+)
 from .schema import (
     UDFParameter,
     UDFParamSpec,
@@ -25,7 +32,7 @@ from .schema import (
 if TYPE_CHECKING:
     from datachain.catalog import Catalog
 
-    from .batch import BatchingResult
+    from .batch import RowsOutput, UDFInput
 
 ColumnType = Any
 
@@ -107,7 +114,8 @@ class UDFBase:
 
     def run(
         self,
-        udf_inputs: "Iterable[BatchingResult]",
+        udf_fields: "Sequence[str]",
+        udf_inputs: "Iterable[RowsOutput]",
         catalog: "Catalog",
         is_generator: bool,
         cache: bool,
@@ -115,15 +123,22 @@ class UDFBase:
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable["UDFResult"]]:
         for batch in udf_inputs:
-            n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
-            output = self.run_once(catalog, batch, is_generator, cache, cb=download_cb)
+            if isinstance(batch, RowsOutputBatch):
+                n_rows = len(batch.rows)
+                inputs: UDFInput = UDFInputBatch(
+                    [RowDict(zip(udf_fields, row)) for row in batch.rows]
+                )
+            else:
+                n_rows = 1
+                inputs = RowDict(zip(udf_fields, batch))
+            output = self.run_once(catalog, inputs, is_generator, cache, cb=download_cb)
             processed_cb.relative_update(n_rows)
             yield output
 
     def run_once(
         self,
         catalog: "Catalog",
-        arg: "BatchingResult",
+        arg: "UDFInput",
         is_generator: bool = False,
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
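
With this change, UDFBase.run receives the column names (udf_fields) separately from the packed row values and rebuilds each per-row mapping by zipping them together; a RowsOutputBatch becomes a UDFInputBatch of RowDict objects built the same way. A minimal sketch of that mapping using a plain dict and made-up field names:

    udf_fields = ["sys__id", "file__path", "file__size"]  # illustrative field names
    row = (42, "data/cats/1.jpg", 2048)                   # one row of packed values

    # Equivalent to RowDict(zip(udf_fields, row)), modulo the RowDict type itself:
    assert dict(zip(udf_fields, row)) == {
        "sys__id": 42,
        "file__path": "data/cats/1.jpg",
        "file__size": 2048,
    }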
@@ -199,12 +214,12 @@ class UDFWrapper(UDFBase):
     def run_once(
         self,
         catalog: "Catalog",
-        arg: "BatchingResult",
+        arg: "UDFInput",
         is_generator: bool = False,
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterable[UDFResult]:
-        if isinstance(arg, RowBatch):
+        if isinstance(arg, UDFInputBatch):
             udf_inputs = [
                 self.bind_parameters(catalog, row, cache=cache, cb=cb)
                 for row in arg.rows
@@ -1,8 +1,10 @@
 from datachain.sql.types import (
+    DBDefaults,
     TypeConverter,
     TypeDefaults,
     TypeReadConverter,
     register_backend_types,
+    register_db_defaults,
     register_type_defaults,
     register_type_read_converters,
 )
@@ -18,5 +20,6 @@ def setup() -> None:
     register_backend_types("default", TypeConverter())
     register_type_read_converters("default", TypeReadConverter())
     register_type_defaults("default", TypeDefaults())
+    register_db_defaults("default", DBDefaults())
 
     setup_is_complete = True
@@ -22,8 +22,10 @@ from datachain.sql.sqlite.types import (
     register_type_converters,
 )
 from datachain.sql.types import (
+    DBDefaults,
     TypeDefaults,
     register_backend_types,
+    register_db_defaults,
     register_type_defaults,
     register_type_read_converters,
 )
@@ -66,6 +68,7 @@ def setup():
     register_backend_types("sqlite", SQLiteTypeConverter())
     register_type_read_converters("sqlite", SQLiteTypeReadConverter())
     register_type_defaults("sqlite", TypeDefaults())
+    register_db_defaults("sqlite", DBDefaults())
 
     compiles(sql_path.parent, "sqlite")(compile_path_parent)
     compiles(sql_path.name, "sqlite")(compile_path_name)