datachain 0.3.17__py3-none-any.whl → 0.3.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain might be problematic.
- datachain/__init__.py +5 -2
- datachain/cache.py +14 -55
- datachain/catalog/catalog.py +17 -97
- datachain/cli.py +7 -2
- datachain/client/fsspec.py +29 -63
- datachain/client/local.py +2 -3
- datachain/dataset.py +7 -2
- datachain/error.py +6 -4
- datachain/lib/arrow.py +10 -4
- datachain/lib/dc.py +6 -2
- datachain/lib/file.py +64 -28
- datachain/lib/listing.py +2 -0
- datachain/listing.py +4 -4
- datachain/node.py +6 -6
- datachain/nodes_fetcher.py +12 -5
- datachain/nodes_thread_pool.py +1 -1
- datachain/progress.py +2 -12
- datachain/query/dataset.py +6 -40
- datachain/query/dispatch.py +2 -15
- datachain/query/schema.py +25 -24
- datachain/query/udf.py +0 -106
- datachain/sql/types.py +4 -2
- datachain/telemetry.py +37 -0
- datachain/utils.py +11 -0
- {datachain-0.3.17.dist-info → datachain-0.3.19.dist-info}/METADATA +5 -4
- {datachain-0.3.17.dist-info → datachain-0.3.19.dist-info}/RECORD +30 -29
- {datachain-0.3.17.dist-info → datachain-0.3.19.dist-info}/LICENSE +0 -0
- {datachain-0.3.17.dist-info → datachain-0.3.19.dist-info}/WHEEL +0 -0
- {datachain-0.3.17.dist-info → datachain-0.3.19.dist-info}/entry_points.txt +0 -0
- {datachain-0.3.17.dist-info → datachain-0.3.19.dist-info}/top_level.txt +0 -0
datachain/lib/dc.py
CHANGED
@@ -26,8 +26,8 @@ from datachain.lib.convert.python_to_sql import python_to_sql
 from datachain.lib.convert.values_to_tuples import values_to_tuples
 from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
 from datachain.lib.dataset_info import DatasetInfo
+from datachain.lib.file import ArrowRow, File, get_file_type
 from datachain.lib.file import ExportPlacement as FileExportPlacement
-from datachain.lib.file import File, IndexedFile, get_file_type
 from datachain.lib.listing import (
     is_listing_dataset,
     is_listing_expired,
@@ -58,6 +58,7 @@ from datachain.query.dataset import (
 )
 from datachain.query.schema import DEFAULT_DELIMITER, Column, DatasetRow
 from datachain.sql.functions import path as pathfunc
+from datachain.telemetry import telemetry
 from datachain.utils import inside_notebook

 if TYPE_CHECKING:
@@ -246,6 +247,9 @@ class DataChain(DatasetQuery):
             **kwargs,
             indexing_column_types=File._datachain_column_types,
         )
+
+        telemetry.send_event_once("class", "datachain_init", **kwargs)
+
         if settings:
             self._settings = Settings(**settings)
         else:
@@ -1610,7 +1614,7 @@ class DataChain(DatasetQuery):
             for name, info in output.model_fields.items()
         }
         if source:
-            output = {"source": …
+            output = {"source": ArrowRow} | output  # type: ignore[assignment,operator]
         return self.gen(
             ArrowGenerator(schema, model, source, nrows, **kwargs), output=output
         )
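Taken together, the dc.py changes report a telemetry event when a DataChain is constructed and swap the removed IndexedFile model for ArrowRow as the "source" signal produced by parse_tabular. A minimal sketch of how the new source signal is expected to surface in user code; the storage URI is a placeholder and the exact from_storage/parse_tabular/collect usage is an assumption based on the 0.3.x public API, not taken from this diff:

from datachain.lib.dc import DataChain

# Each parsed record is expected to carry an ArrowRow under "source",
# pointing back at the originating File plus the row index (assumption).
chain = (
    DataChain.from_storage("s3://example-bucket/table.parquet")  # hypothetical URI
    .parse_tabular(source=True)
)
for src in chain.limit(1).collect("source"):
    print(src.file.path, src.index)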
datachain/lib/file.py
CHANGED
@@ -1,11 +1,14 @@
+import hashlib
 import io
 import json
 import logging
 import os
 import posixpath
 from abc import ABC, abstractmethod
+from collections.abc import Iterator
 from contextlib import contextmanager
 from datetime import datetime
+from functools import partial
 from io import BytesIO
 from pathlib import Path, PurePosixPath
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union
@@ -14,12 +17,12 @@ from urllib.request import url2pathname

 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from PIL import Image
+from pyarrow.dataset import dataset
 from pydantic import Field, field_validator

 if TYPE_CHECKING:
     from typing_extensions import Self

-from datachain.cache import UniqueId
 from datachain.client.fileslice import FileSlice
 from datachain.lib.data_model import DataModel
 from datachain.lib.utils import DataChainError
@@ -27,7 +30,13 @@ from datachain.sql.types import JSON, Boolean, DateTime, Int, String
 from datachain.utils import TIME_ZERO

 if TYPE_CHECKING:
+    from typing_extensions import Self
+
     from datachain.catalog import Catalog
+    from datachain.client.fsspec import Client
+    from datachain.dataset import RowDict
+
+sha256 = partial(hashlib.sha256, usedforsecurity=False)

 logger = logging.getLogger("datachain")

@@ -38,7 +47,7 @@ ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]
 class VFileError(DataChainError):
     def __init__(self, file: "File", message: str, vtype: str = ""):
         type_ = f" of vtype '{vtype}'" if vtype else ""
-        super().__init__(f"Error in v-file '{file.…
+        super().__init__(f"Error in v-file '{file.path}'{type_}: {message}")


 class FileError(DataChainError):
@@ -85,9 +94,8 @@ class TarVFile(VFile):
         tar_file = File(**parent)
         tar_file._set_stream(file._catalog)

-…
-…
-        fd = client.open_object(tar_file_uid, use_cache=file._caching_enabled)
+        client = file._catalog.get_client(tar_file.source)
+        fd = client.open_object(tar_file, use_cache=file._caching_enabled)
         return FileSlice(fd, offset, size, file.name)


@@ -181,7 +189,11 @@ class File(DataModel):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self._catalog = None
-        self._caching_enabled = False
+        self._caching_enabled: bool = False
+
+    @classmethod
+    def _from_row(cls, row: "RowDict") -> "Self":
+        return cls(**{key: row[key] for key in cls._datachain_column_types})

     @property
     def name(self):
@@ -192,19 +204,18 @@ class File(DataModel):
         return str(PurePosixPath(self.path).parent)

     @contextmanager
-    def open(self, mode: Literal["rb", "r"] = "rb"):
+    def open(self, mode: Literal["rb", "r"] = "rb") -> Iterator[Any]:
         """Open the file and return a file object."""
         if self.location:
             with VFileRegistry.resolve(self, self.location) as f:  # type: ignore[arg-type]
                 yield f

         else:
-            uid = self.get_uid()
-            client = self._catalog.get_client(self.source)
             if self._caching_enabled:
-…
+                self.ensure_cached()
+            client: Client = self._catalog.get_client(self.source)
             with client.open_object(
-…
+                self, use_cache=self._caching_enabled, cb=self._download_cb
             ) as f:
                 yield io.TextIOWrapper(f) if mode == "r" else f

@@ -252,23 +263,25 @@ class File(DataModel):
         self._caching_enabled = caching_enabled
         self._download_cb = download_cb

-    def …
-…
-…
-…
+    def ensure_cached(self) -> None:
+        if self._catalog is None:
+            raise RuntimeError(
+                "cannot download file to cache because catalog is not setup"
+            )
+        client = self._catalog.get_client(self.source)
+        client.download(self, callback=self._download_cb)

-    def get_local_path(self…
-        """…
-…
+    def get_local_path(self) -> Optional[str]:
+        """Return path to a file in a local cache.
+
+        Returns None if file is not cached.
+        Raises an exception if cache is not setup.
+        """
         if self._catalog is None:
             raise RuntimeError(
                 "cannot resolve local file path because catalog is not setup"
             )
-…
-        if download:
-            client = self._catalog.get_client(self.source)
-            client.download(uid, callback=self._download_cb)
-        return self._catalog.cache.get_path(uid)
+        return self._catalog.cache.get_path(self)

     def get_file_suffix(self):
         """Returns last part of file name with `.`."""
@@ -323,6 +336,12 @@ class File(DataModel):
         """Returns `fsspec` filesystem for the file."""
         return self._catalog.get_client(self.source).fs

+    def get_hash(self) -> str:
+        fingerprint = f"{self.source}/{self.path}/{self.version}/{self.etag}"
+        if self.location:
+            fingerprint += f"/{self.location}"
+        return sha256(fingerprint.encode()).hexdigest()
+
     def resolve(self) -> "Self":
         """
         Resolve a File object by checking its existence and updating its metadata.
@@ -421,14 +440,31 @@ class ImageFile(File):
         self.read().save(destination)


-class …
-    """
-…
-    Includes `file` and `index` signals.
-    """
+class ArrowRow(DataModel):
+    """`DataModel` for reading row from Arrow-supported file."""

     file: File
     index: int
+    kwargs: dict
+
+    @contextmanager
+    def open(self):
+        """Stream row contents from indexed file."""
+        if self.file._caching_enabled:
+            self.file.ensure_cached()
+            path = self.file.get_local_path()
+            ds = dataset(path, **self.kwargs)
+
+        else:
+            path = self.file.get_path()
+            ds = dataset(path, filesystem=self.file.get_fs(), **self.kwargs)
+
+        return ds.take([self.index]).to_reader()
+
+    def read(self):
+        """Returns row contents as dict."""
+        with self.open() as record_batch:
+            return record_batch.to_pylist()[0]


 def get_file_type(type_: Literal["binary", "text", "image"] = "binary") -> type[File]:
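The File changes above replace the old UniqueId/get_uid plumbing with methods that operate on the File object itself (ensure_cached, get_local_path, get_hash). A rough sketch of the new surface, assuming a File that already has a catalog attached (for example one yielded by a chain); this is an illustration, not code from the package:

from typing import Optional

from datachain.lib.file import File


def cached_path(file: File) -> Optional[str]:
    # Download into the local cache (raises RuntimeError if no catalog is set up),
    # then resolve the cached path; get_local_path() returns None when not cached.
    file.ensure_cached()
    print(file.get_hash())  # sha256 over source/path/version/etag (+ location)
    return file.get_local_path()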
datachain/lib/listing.py
CHANGED
@@ -11,6 +11,7 @@ from datachain.client import Client
 from datachain.lib.file import File
 from datachain.query.schema import Column
 from datachain.sql.functions import path as pathfunc
+from datachain.telemetry import telemetry
 from datachain.utils import uses_glob

 if TYPE_CHECKING:
@@ -80,6 +81,7 @@ def parse_listing_uri(uri: str, cache, client_config) -> tuple[str, str, str]:
     client_config = client_config or {}
     client = Client.get_client(uri, cache, **client_config)
     storage_uri, path = Client.parse_url(uri)
+    telemetry.log_param("client", client.PREFIX)

     # clean path without globs
     lst_uri_path = (
datachain/listing.py
CHANGED
@@ -156,12 +156,12 @@ class Listing:

     def instantiate_nodes(
         self,
-        all_nodes,
+        all_nodes: Iterable[NodeWithPath],
         output,
         total_files=None,
         force=False,
         shared_progress_bar=None,
-    ):
+    ) -> None:
         progress_bar = shared_progress_bar or tqdm(
             desc=f"Instantiating '{output}'",
             unit=" files",
@@ -175,8 +175,8 @@ class Listing:
             dst = os.path.join(output, *node.path)
             dst_dir = os.path.dirname(dst)
             os.makedirs(dst_dir, exist_ok=True)
-…
-            self.client.instantiate_object(…
+            file = node.n.to_file(self.client.uri)
+            self.client.instantiate_object(file, dst, progress_bar, force)
             counter += 1
             if counter > 1000:
                 progress_bar.update(counter)
datachain/node.py
CHANGED
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Optional

 import attrs

-from datachain.…
+from datachain.lib.file import File
 from datachain.storage import StorageURI
 from datachain.utils import TIME_ZERO, time_to_str

@@ -99,11 +99,11 @@ class Node:
             return self.path + "/"
         return self.path

-    def …
-        if …
-…
-        return …
-…
+    def to_file(self, source: Optional[StorageURI] = None) -> File:
+        if source is None:
+            source = self.source
+        return File(
+            source=source,
             path=self.path,
             size=self.size,
             version=self.version or "",
datachain/nodes_fetcher.py
CHANGED
@@ -1,12 +1,19 @@
 import logging
+from collections.abc import Iterable
+from typing import TYPE_CHECKING

+from datachain.node import Node
 from datachain.nodes_thread_pool import NodesThreadPool

+if TYPE_CHECKING:
+    from datachain.cache import DataChainCache
+    from datachain.client.fsspec import Client
+
 logger = logging.getLogger("datachain")


 class NodesFetcher(NodesThreadPool):
-    def __init__(self, client, max_threads, cache):
+    def __init__(self, client: "Client", max_threads: int, cache: "DataChainCache"):
         super().__init__(max_threads)
         self.client = client
         self.cache = cache
@@ -15,7 +22,7 @@ class NodesFetcher(NodesThreadPool):
         for task in done:
             task.result()

-    def do_task(self, chunk):
+    def do_task(self, chunk: Iterable[Node]) -> None:
         from fsspec import Callback

         class _CB(Callback):
@@ -23,8 +30,8 @@ class NodesFetcher(NodesThreadPool):
                 self.increase_counter(inc)

         for node in chunk:
-…
-            if self.cache.contains(…
+            file = node.to_file(self.client.uri)
+            if self.cache.contains(file):
                 self.increase_counter(node.size)
             else:
-                self.client.put_in_cache(…
+                self.client.put_in_cache(file, callback=_CB())
datachain/nodes_thread_pool.py
CHANGED
@@ -20,7 +20,7 @@ class NodeChunk:
     def next_downloadable(self):
         node = next(self.nodes, None)
         while node and (
-            not node.is_downloadable or self.cache.contains(node.…
+            not node.is_downloadable or self.cache.contains(node.to_file(self.storage))
         ):
             node = next(self.nodes, None)
         return node
datachain/progress.py
CHANGED
@@ -1,8 +1,6 @@
 """Manages progress bars."""

 import logging
-import os
-import re
 import sys
 from threading import RLock
 from typing import Any, ClassVar
@@ -10,20 +8,12 @@ from typing import Any, ClassVar
 from fsspec.callbacks import TqdmCallback
 from tqdm import tqdm

+from datachain.utils import env2bool
+
 logger = logging.getLogger(__name__)
 tqdm.set_lock(RLock())


-def env2bool(var, undefined=False):
-    """
-    undefined: return value if env var is unset
-    """
-    var = os.getenv(var, None)
-    if var is None:
-        return undefined
-    return bool(re.search("1|y|yes|true", var, flags=re.IGNORECASE))
-
-
 class Tqdm(tqdm):
     """
     maximum-compatibility tqdm-based progressbars
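progress.py drops its private env2bool helper in favour of the shared one in datachain.utils (reflected in the +11 lines added to utils.py in the file list). A small usage sketch based on the removed implementation; the environment variable name is purely illustrative:

from datachain.utils import env2bool

# Truthy values are matched case-insensitively against 1/y/yes/true;
# `undefined` is what gets returned when the variable is not set at all.
if env2bool("EXAMPLE_DISABLE_PBAR", undefined=False):
    print("progress bars disabled")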
datachain/query/dataset.py
CHANGED
@@ -53,7 +53,7 @@ from datachain.utils import (

 from .schema import C, UDFParamSpec, normalize_param
 from .session import Session
-from .udf import UDFBase…
+from .udf import UDFBase

 if TYPE_CHECKING:
     from sqlalchemy.sql.elements import ClauseElement
@@ -364,7 +364,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:

 @frozen
 class UDFStep(Step, ABC):
-    udf: …
+    udf: UDFBase
     catalog: "Catalog"
     partition_by: Optional[PartitionByType] = None
     parallel: Optional[int] = None
@@ -470,12 +470,6 @@ class UDFStep(Step, ABC):

         else:
             # Otherwise process single-threaded (faster for smaller UDFs)
-            # Optionally instantiate the UDF instance if a class is provided.
-            if isinstance(self.udf, UDFFactory):
-                udf: UDFBase = self.udf()
-            else:
-                udf = self.udf
-
             warehouse = self.catalog.warehouse

             with contextlib.closing(
@@ -485,7 +479,7 @@ class UDFStep(Step, ABC):
             processed_cb = get_processed_callback()
             generated_cb = get_generated_callback(self.is_generator)
             try:
-                udf_results = udf.run(
+                udf_results = self.udf.run(
                     udf_fields,
                     udf_inputs,
                     self.catalog,
@@ -498,7 +492,7 @@ class UDFStep(Step, ABC):
                     warehouse,
                     udf_table,
                     udf_results,
-                    udf,
+                    self.udf,
                     cb=generated_cb,
                 )
             finally:
@@ -1471,7 +1465,7 @@ class DatasetQuery:
     @detach
     def add_signals(
         self,
-        udf: …
+        udf: UDFBase,
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
@@ -1492,9 +1486,6 @@ class DatasetQuery:
         at least that minimum number of rows to each distributed worker, mostly useful
         if there are a very large number of small tasks to process.
         """
-        if isinstance(udf, UDFClassWrapper):  # type: ignore[unreachable]
-            # This is a bare decorated class, "instantiate" it now.
-            udf = udf()  # type: ignore[unreachable]
         query = self.clone()
         query.steps.append(
             UDFSignal(
@@ -1518,16 +1509,13 @@ class DatasetQuery:
     @detach
     def generate(
         self,
-        udf: …
+        udf: UDFBase,
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
         partition_by: Optional[PartitionByType] = None,
         cache: bool = False,
     ) -> "Self":
-        if isinstance(udf, UDFClassWrapper):  # type: ignore[unreachable]
-            # This is a bare decorated class, "instantiate" it now.
-            udf = udf()  # type: ignore[unreachable]
         query = self.clone()
         steps = query.steps
         steps.append(
@@ -1616,25 +1604,3 @@ class DatasetQuery:
         finally:
             self.cleanup()
         return self.__class__(name=name, version=version, catalog=self.catalog)
-
-
-def query_wrapper(dataset_query: Any) -> Any:
-    """
-    Wrapper function that wraps the last statement of user query script.
-    Last statement MUST be instance of DatasetQuery, otherwise script exits with
-    error code 10
-    """
-    if not isinstance(dataset_query, DatasetQuery):
-        return dataset_query
-
-    catalog = dataset_query.catalog
-    save = bool(os.getenv("DATACHAIN_QUERY_SAVE"))
-
-    is_session_temp_dataset = dataset_query.name and dataset_query.name.startswith(
-        dataset_query.session.get_temp_prefix()
-    )
-
-    if save and (is_session_temp_dataset or not dataset_query.attached):
-        name = catalog.generate_query_dataset_name()
-        dataset_query = dataset_query.save(name)
-    return dataset_query
datachain/query/dispatch.py
CHANGED
@@ -27,7 +27,7 @@ from datachain.query.queue import (
     put_into_queue,
     unmarshal,
 )
-from datachain.query.udf import UDFBase,…
+from datachain.query.udf import UDFBase, UDFResult
 from datachain.utils import batched_it

 DEFAULT_BATCH_SIZE = 10000
@@ -156,8 +156,6 @@ class UDFDispatcher:

     @property
     def batch_size(self):
-        if not self.udf:
-            self.udf = self.udf_factory()
         if self._batch_size is None:
             if hasattr(self.udf, "properties") and hasattr(
                 self.udf.properties, "batch"
@@ -181,18 +179,7 @@ class UDFDispatcher:
         self.catalog = Catalog(
             id_generator, metastore, warehouse, **self.catalog_init_params
         )
-        udf = loads(self.udf_data)
-        # isinstance cannot be used here, as cloudpickle packages the entire class
-        # definition, and so these two types are not considered exactly equal,
-        # even if they have the same import path.
-        if full_module_type_path(type(udf)) != full_module_type_path(UDFFactory):
-            self.udf = udf
-        else:
-            self.udf = None
-            self.udf_factory = udf
-        if not self.udf:
-            self.udf = self.udf_factory()
-
+        self.udf = loads(self.udf_data)
         return UDFWorker(
             self.catalog,
             self.udf,
datachain/query/schema.py
CHANGED
@@ -9,6 +9,7 @@ import attrs
 import sqlalchemy as sa
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback

+from datachain.lib.file import File
 from datachain.sql.types import JSON, Boolean, DateTime, Int64, SQLType, String

 if TYPE_CHECKING:
@@ -97,11 +98,11 @@ class Object(UDFParameter):
         cb: Callback = DEFAULT_CALLBACK,
         **kwargs,
     ) -> Any:
-…
-…
+        file = File._from_row(file_signals(row))
+        client = catalog.get_client(file.source)
         if cache:
-            client.download(…
-        with client.open_object(…
+            client.download(file, callback=cb)
+        with client.open_object(file, use_cache=cache, cb=cb) as f:
             return self.reader(f)

     async def get_value_async(
@@ -114,12 +115,12 @@ class Object(UDFParameter):
         cb: Callback = DEFAULT_CALLBACK,
         **kwargs,
     ) -> Any:
-…
-…
+        file = File._from_row(file_signals(row))
+        client = catalog.get_client(file.source)
         if cache:
-            await client._download(…
+            await client._download(file, callback=cb)
         obj = await mapper.to_thread(
-            functools.partial(client.open_object,…
+            functools.partial(client.open_object, file, use_cache=cache, cb=cb)
         )
         with obj:
             return await mapper.to_thread(self.reader, obj)
@@ -140,11 +141,11 @@ class Stream(UDFParameter):
         cb: Callback = DEFAULT_CALLBACK,
         **kwargs,
     ) -> Any:
-…
-…
+        file = File._from_row(file_signals(row))
+        client = catalog.get_client(file.source)
         if cache:
-            client.download(…
-        return client.open_object(…
+            client.download(file, callback=cb)
+        return client.open_object(file, use_cache=cache, cb=cb)

     async def get_value_async(
         self,
@@ -156,12 +157,12 @@ class Stream(UDFParameter):
         cb: Callback = DEFAULT_CALLBACK,
         **kwargs,
     ) -> Any:
-…
-…
+        file = File._from_row(file_signals(row))
+        client = catalog.get_client(file.source)
         if cache:
-            await client._download(…
+            await client._download(file, callback=cb)
         return await mapper.to_thread(
-            functools.partial(client.open_object,…
+            functools.partial(client.open_object, file, use_cache=cache, cb=cb)
         )


@@ -189,10 +190,10 @@ class LocalFilename(UDFParameter):
             # If the glob pattern is specified and the row filename
             # does not match it, then return None
             return None
-…
-…
-        client.download(…
-        return client.cache.get_path(…
+        file = File._from_row(file_signals(row))
+        client = catalog.get_client(file.source)
+        client.download(file, callback=cb)
+        return client.cache.get_path(file)

     async def get_value_async(
         self,
@@ -208,10 +209,10 @@ class LocalFilename(UDFParameter):
             # If the glob pattern is specified and the row filename
             # does not match it, then return None
             return None
-…
-…
-        await client._download(…
-        return client.cache.get_path(…
+        file = File._from_row(file_signals(row))
+        client = catalog.get_client(file.source)
+        await client._download(file, callback=cb)
+        return client.cache.get_path(file)


 UDFParamSpec = Union[str, Column, UDFParameter]
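The schema.py parameters (Object, Stream, LocalFilename) now rebuild a File from the row's file signals with File._from_row before calling into the client, instead of constructing a UniqueId. A condensed sketch of the shared pattern; the import location of file_signals is assumed (it is the helper these classes already call), and everything else mirrors the hunks above:

from fsspec.callbacks import DEFAULT_CALLBACK, Callback

from datachain.lib.file import File
from datachain.query.schema import file_signals  # assumed helper location


def read_row_object(catalog, row, reader, cache: bool = False, cb: Callback = DEFAULT_CALLBACK):
    # Mirrors Object.get_value after the change: File._from_row() turns the row's
    # file columns back into a File, which the client now accepts directly.
    file = File._from_row(file_signals(row))
    client = catalog.get_client(file.source)
    if cache:
        client.download(file, callback=cb)
    with client.open_object(file, use_cache=cache, cb=cb) as f:
        return reader(f)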