datachain 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/catalog/catalog.py +17 -2
- datachain/data_storage/db_engine.py +0 -2
- datachain/data_storage/schema.py +10 -27
- datachain/data_storage/warehouse.py +1 -7
- datachain/lib/arrow.py +7 -13
- datachain/lib/clip.py +151 -0
- datachain/lib/dc.py +35 -57
- datachain/lib/feature_utils.py +1 -2
- datachain/lib/file.py +7 -0
- datachain/lib/image.py +37 -79
- datachain/lib/pytorch.py +4 -2
- datachain/lib/signal_schema.py +3 -4
- datachain/lib/text.py +18 -49
- datachain/lib/udf.py +58 -30
- datachain/lib/udf_signature.py +11 -10
- datachain/lib/utils.py +17 -0
- datachain/lib/webdataset.py +2 -2
- datachain/listing.py +0 -3
- datachain/query/dataset.py +63 -37
- datachain/query/dispatch.py +2 -2
- datachain/query/schema.py +1 -8
- datachain/query/udf.py +16 -18
- datachain/utils.py +28 -0
- {datachain-0.2.1.dist-info → datachain-0.2.2.dist-info}/METADATA +2 -1
- {datachain-0.2.1.dist-info → datachain-0.2.2.dist-info}/RECORD +29 -29
- {datachain-0.2.1.dist-info → datachain-0.2.2.dist-info}/WHEEL +1 -1
- datachain/lib/reader.py +0 -49
- {datachain-0.2.1.dist-info → datachain-0.2.2.dist-info}/LICENSE +0 -0
- {datachain-0.2.1.dist-info → datachain-0.2.2.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.1.dist-info → datachain-0.2.2.dist-info}/top_level.txt +0 -0
datachain/lib/pytorch.py
CHANGED
@@ -116,10 +116,12 @@ class PytorchDataset(IterableDataset):
                     self.transform = None
             if self.tokenizer:
                 for i, val in enumerate(row):
-                    if isinstance(val, str):
+                    if isinstance(val, str) or (
+                        isinstance(val, list) and isinstance(val[0], str)
+                    ):
                         row[i] = convert_text(
                             val, self.tokenizer, self.tokenizer_kwargs
-                        )
+                        ).squeeze(0)  # type: ignore[union-attr]
             yield row

     @staticmethod
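For context, tokenizers that return PyTorch tensors typically keep a leading batch dimension even for a single text, which is why the converted value is now squeezed. A minimal sketch (illustrative only, not part of the diff):

import torch

# A single tokenized text usually comes back with shape (1, seq_len);
# squeeze(0) drops the batch dimension so the row holds a flat tensor.
tokens = torch.tensor([[101, 7592, 2088, 102]])
assert tokens.squeeze(0).shape == (4,)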
datachain/lib/signal_schema.py
CHANGED
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Optional, Union, get_args, get_origin

 from pydantic import create_model

-from datachain.lib.arrow import Source
 from datachain.lib.feature import (
     DATACHAIN_TO_TYPE,
     DEFAULT_DELIMITER,
@@ -14,7 +13,7 @@ from datachain.lib.feature import (
     convert_type_to_datachain,
 )
 from datachain.lib.feature_registry import Registry
-from datachain.lib.file import File, TextFile
+from datachain.lib.file import File, IndexedFile, TextFile
 from datachain.lib.image import ImageFile
 from datachain.lib.utils import DataChainParamsError
 from datachain.lib.webdataset import TarStream, WDSAllFile, WDSBasic
@@ -36,7 +35,7 @@ NAMES_TO_TYPES = {
     "datetime": datetime,
     "WDSLaion": WDSLaion,
     "Laion": Laion,
-    "Source": Source,
+    "Source": IndexedFile,
     "File": File,
     "ImageFile": ImageFile,
     "TextFile": TextFile,
@@ -150,7 +149,7 @@ class SignalSchema:
         )

     def slice(self, keys: Sequence[str]) -> "SignalSchema":
-        return SignalSchema({k:
+        return SignalSchema({k: self.values[k] for k in keys if k in self.values})

     def row_to_features(self, row: Sequence, catalog: "Catalog") -> list[FeatureType]:
         res = []
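The rewritten slice() builds the new schema with a dict comprehension that silently skips keys missing from the schema. A standalone sketch of that filtering behavior (plain dicts, not the SignalSchema class itself):

values = {"file": str, "label": int, "score": float}
keys = ["label", "missing", "file"]
# Keys not present in the schema are dropped rather than raising KeyError.
sliced = {k: values[k] for k in keys if k in values}
assert sliced == {"label": int, "file": str}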
datachain/lib/text.py
CHANGED
@@ -1,19 +1,15 @@
-import inspect
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union

-from datachain.lib.file import TextFile
-from datachain.lib.reader import FeatureReader
-
 if TYPE_CHECKING:
-
+    import torch


 def convert_text(
     text: Union[str, list[str]],
     tokenizer: Optional[Callable] = None,
     tokenizer_kwargs: Optional[dict[str, Any]] = None,
-
-):
+    encoder: Optional[Callable] = None,
+) -> Union[str, list[str], "torch.Tensor"]:
     """
     Tokenize and otherwise transform text.

@@ -21,18 +17,8 @@ def convert_text(
         text (str): Text to convert.
         tokenizer (Callable): Tokenizer to use to tokenize objects.
         tokenizer_kwargs (dict): Additional kwargs to pass when calling tokenizer.
-
+        encoder (Callable): Encode text using model.
     """
-    if open_clip_model:
-        method_name = "encode_text"
-        if not (
-            hasattr(open_clip_model, method_name)
-            and inspect.ismethod(getattr(open_clip_model, method_name))
-        ):
-            raise ValueError(
-                f"TextColumn error: 'model' doesn't support '{method_name}()'"
-            )
-
     if not tokenizer:
         return text

@@ -43,38 +29,21 @@ def convert_text(
         res = tokenizer(text, **tokenizer_kwargs)
     else:
         res = tokenizer(text)
-
-
-    tokens = res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
-
-    if not open_clip_model:
-        return tokens.squeeze(0)
-
-    return open_clip_model.encode_text(tokens).squeeze(0)
+    try:
+        from transformers.tokenization_utils_base import PreTrainedTokenizerBase

+        tokens = (
+            res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
+        )
+    except ImportError:
+        tokens = res

-
-
-        self,
-        fr_class: "FeatureLike" = TextFile,
-        tokenizer: Optional[Callable] = None,
-        tokenizer_kwargs: Optional[dict[str, Any]] = None,
-        open_clip_model: Optional[Any] = None,
-    ):
-        """
-        Read and optionally transform a text column.
+    if not encoder:
+        return tokens

-
-
-
-
-        self.open_clip_model = open_clip_model
-        super().__init__(fr_class)
+    try:
+        import torch
+    except ImportError:
+        "Missing dependency 'torch' needed to encode text."

-
-    return convert_text(
-        value,
-        tokenizer=self.tokenizer,
-        tokenizer_kwargs=self.tokenizer_kwargs,
-        open_clip_model=self.open_clip_model,
-    )
+    return encoder(torch.tensor(tokens))
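The new convert_text() contract is: tokenize when a tokenizer is given, then optionally run an encoder over the token tensor. A minimal sketch with toy stand-ins for the tokenizer and encoder (both hypothetical; only their call shapes match what the function expects):

import torch

def toy_tokenizer(text: str) -> list[int]:
    # Hypothetical tokenizer: one integer id per whitespace-separated word.
    return [len(word) for word in text.split()]

def toy_encoder(tokens: torch.Tensor) -> torch.Tensor:
    # Hypothetical encoder: maps token ids to a float "embedding".
    return tokens.float() * 0.1

tokens = toy_tokenizer("hello text world")
embedding = toy_encoder(torch.tensor(tokens))
print(embedding)  # tensor([0.5000, 0.4000, 0.5000])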
datachain/lib/udf.py
CHANGED
@@ -1,11 +1,12 @@
 import inspect
 import sys
 import traceback
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Callable

 from datachain.lib.feature import Feature
 from datachain.lib.signal_schema import SignalSchema
-from datachain.lib.
+from datachain.lib.udf_signature import UdfSignature
+from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
 from datachain.query import udf

 if TYPE_CHECKING:
@@ -17,26 +18,68 @@ class UdfError(DataChainParamsError):
         super().__init__(f"UDF error: {msg}")


-class UDFBase:
+class UDFBase(AbstractUDF):
     is_input_batched = False
     is_output_batched = False
     is_input_grouped = False

-    def __init__(
-        self
-
-
-
-
+    def __init__(self):
+        self.params = None
+        self.output = None
+        self.params_spec = None
+        self.output_spec = None
+        self._contains_stream = None
+        self._catalog = None
+        self._func = None
+
+    def process(self, *args, **kwargs):
+        """Processing function that needs to be defined by user"""
+        if not self._func:
+            raise NotImplementedError("UDF processing is not implemented")
+        return self._func(*args, **kwargs)
+
+    def setup(self):
+        """Initialization process executed on each worker before processing begins.
+        This is needed for tasks like pre-loading ML models prior to scoring.
+        """
+
+    def teardown(self):
+        """Teardown process executed on each process/worker after processing ends.
+        This is needed for tasks like closing connections to end-points.
+        """
+
+    def _init(self, sign: UdfSignature, params: SignalSchema, func: Callable):
         self.params = params
-        self.output =
-        self._func = func
+        self.output = sign.output_schema

-        params_spec = params.to_udf_spec()
+        params_spec = self.params.to_udf_spec()
         self.params_spec = list(params_spec.keys())
-        self.output_spec = output.to_udf_spec()
+        self.output_spec = self.output.to_udf_spec()

-        self.
+        self._func = func
+
+    @classmethod
+    def _create(
+        cls,
+        target_class: type["UDFBase"],
+        sign: UdfSignature,
+        params: SignalSchema,
+        catalog,
+    ) -> "UDFBase":
+        if isinstance(sign.func, AbstractUDF):
+            if not isinstance(sign.func, target_class):  # type: ignore[unreachable]
+                raise UdfError(
+                    f"cannot create UDF: provided UDF '{sign.func.__name__}'"
+                    f" must be a child of target class '{target_class.__name__}'",
+                )
+            result = sign.func
+            func = None
+        else:
+            result = target_class()
+            func = sign.func
+
+        result._init(sign, params, func)
+        return result

     @property
     def name(self):
@@ -53,25 +96,10 @@ class UDFBase:
         udf_wrapper = udf(self.params_spec, self.output_spec, batch=batch)
         return udf_wrapper(self)

-    def bootstrap(self):
-        """Initialization process executed on each worker before processing begins.
-        This is needed for tasks like pre-loading ML models prior to scoring.
-        """
-
-    def teardown(self):
-        """Teardown process executed on each process/worker after processing ends.
-        This is needed for tasks like closing connections to end-points.
-        """
-
-    def process(self, *args, **kwargs):
-        if not self._func:
-            raise NotImplementedError("UDF processing is not implemented")
-        return self._func(*args, **kwargs)
-
     def validate_results(self, results, *args, **kwargs):
         return results

-    def __call__(self, *rows
+    def __call__(self, *rows):
         if self.is_input_grouped:
             objs = self._parse_grouped_rows(rows)
         else:
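With this restructuring a class-based UDF subclasses UDFBase and overrides process(), with setup()/teardown() as per-worker lifecycle hooks. A minimal sketch of such a subclass (the class and its logic are illustrative, not from the package):

from datachain.lib.udf import UDFBase  # module path as in this release

class WordCount(UDFBase):
    """Hypothetical UDF that counts words in a text signal."""

    def setup(self):
        # Stand-in for expensive per-worker initialization (e.g. loading a model).
        self.splitter = str.split

    def process(self, text: str) -> int:
        return len(self.splitter(text))

    def teardown(self):
        # Stand-in for releasing resources (e.g. closing connections).
        self.splitter = None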
datachain/lib/udf_signature.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Callable, Optional, Union, get_args, get_origin

 from datachain.lib.feature import Feature, FeatureType, FeatureTypeNames
 from datachain.lib.signal_schema import SignalSchema
-from datachain.lib.utils import DataChainParamsError
+from datachain.lib.utils import AbstractUDF, DataChainParamsError


 class UdfSignatureError(DataChainParamsError):
@@ -49,10 +49,13 @@ class UdfSignature:
         else:
             if func is None:
                 raise UdfSignatureError(chain, "user function is not defined")
+
             udf_func = func
             signal_name = None
+
         if not callable(udf_func):
-            raise UdfSignatureError(chain, f"
+            raise UdfSignatureError(chain, f"UDF '{udf_func}' is not callable")
+
         func_params_map_sign, func_outs_sign, is_iterator = (
             UdfSignature._func_signature(chain, udf_func)
         )
@@ -108,13 +111,6 @@ class UdfSignature:
         if isinstance(output, str):
             output = [output]
         if isinstance(output, Sequence):
-            if not func_outs_sign:
-                raise UdfSignatureError(
-                    chain,
-                    "output types are not specified. Specify types in 'output' as"
-                    " a dict or as function return value hint.",
-                )
-
             if len(func_outs_sign) != len(output):
                 raise UdfSignatureError(
                     chain,
@@ -158,8 +154,13 @@ class UdfSignature:

     @staticmethod
     def _func_signature(
-        chain: str,
+        chain: str, udf_func: Callable
     ) -> tuple[dict[str, type], Sequence[type], bool]:
+        if isinstance(udf_func, AbstractUDF):
+            func = udf_func.process  # type: ignore[unreachable]
+        else:
+            func = udf_func
+
         sign = inspect.signature(func)

         input_map = {prm.name: prm.annotation for prm in sign.parameters.values()}
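The signature parser now unwraps AbstractUDF instances and inspects their bound process method; inspect.signature on a bound method already omits self, which is what the parameter map relies on. A small standalone illustration:

import inspect

class Example:
    def process(self, name: str, size: int) -> str:
        return f"{name}:{size}"

params = inspect.signature(Example().process).parameters
assert list(params) == ["name", "size"]  # "self" is not part of the signature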
datachain/lib/utils.py
CHANGED
@@ -1,3 +1,20 @@
+from abc import ABC, abstractmethod
+
+
+class AbstractUDF(ABC):
+    @abstractmethod
+    def process(self, *args, **kwargs):
+        pass
+
+    @abstractmethod
+    def setup(self):
+        pass
+
+    @abstractmethod
+    def teardown(self):
+        pass
+
+
 class DataChainError(Exception):
     def __init__(self, message):
         super().__init__(message)
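Because AbstractUDF declares process(), setup() and teardown() as abstract, a subclass cannot be instantiated until all three are implemented. A quick illustration of that guarantee (the Incomplete class is hypothetical):

from datachain.lib.utils import AbstractUDF  # location per this diff

class Incomplete(AbstractUDF):
    def process(self, *args, **kwargs):
        return args
    # setup() and teardown() are intentionally missing.

try:
    Incomplete()
except TypeError as exc:
    print(exc)  # abstract methods setup/teardown must be overridden first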
datachain/lib/webdataset.py
CHANGED
@@ -2,6 +2,7 @@ import hashlib
 import json
 import tarfile
 from collections.abc import Iterator, Sequence
+from pathlib import Path
 from typing import (
     Any,
     Callable,
@@ -240,10 +241,9 @@ class TarStream(File):
 def get_tar_groups(stream, tar, core_extensions, spec, encoding="utf-8"):
     builder = Builder(stream, core_extensions, spec, tar, encoding)

-    for item in tar.getmembers():
+    for item in sorted(tar.getmembers(), key=lambda m: Path(m.name).stem):
         if not item.isfile():
             continue
-
         try:
             builder.add(item)
         except StopIteration:
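Sorting tar members by file stem keeps every file belonging to one WebDataset sample adjacent, which is what the grouping builder expects. A standalone illustration of the sort key:

from pathlib import Path

names = ["0002.json", "0001.jpg", "0001.json", "0002.jpg"]
print(sorted(names, key=lambda n: Path(n).stem))
# ['0001.jpg', '0001.json', '0002.json', '0002.jpg']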
datachain/listing.py
CHANGED
datachain/query/dataset.py
CHANGED
@@ -1,3 +1,4 @@
+import ast
 import contextlib
 import datetime
 import inspect
@@ -51,9 +52,10 @@ from datachain.data_storage.schema import (
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.progress import CombinedDownloadCallback
+from datachain.query.schema import DEFAULT_DELIMITER
 from datachain.sql.functions import rand
 from datachain.storage import Storage, StorageURI
-from datachain.utils import batched, determine_processes
+from datachain.utils import batched, determine_processes, inside_notebook

 from .batch import RowBatch
 from .metrics import metrics
@@ -62,7 +64,6 @@ from .session import Session
 from .udf import UDFBase, UDFClassWrapper, UDFFactory, UDFType

 if TYPE_CHECKING:
-    import pandas as pd
     from sqlalchemy.sql.elements import ClauseElement
     from sqlalchemy.sql.schema import Table
     from sqlalchemy.sql.selectable import GenerativeSelect
@@ -547,8 +548,9 @@ class UDF(Step, ABC):
         else:
             udf = self.udf

-        if hasattr(udf.func, "
-            udf.func.
+        if hasattr(udf.func, "setup") and callable(udf.func.setup):
+            udf.func.setup()
+
         warehouse = self.catalog.warehouse

         with contextlib.closing(
@@ -599,12 +601,15 @@ class UDF(Step, ABC):
         # Create a dynamic module with the generated name
         dynamic_module = types.ModuleType(feature_module_name)
         # Get the import lines for the necessary objects from the main module
-        import_lines = [
-            source.getimport(obj, alias=name)
-            for name, obj in inspect.getmembers(sys.modules["__main__"], _imports)
-            if not (name.startswith("__") and name.endswith("__"))
-        ]
         main_module = sys.modules["__main__"]
+        if getattr(main_module, "__file__", None):
+            import_lines = list(get_imports(main_module))
+        else:
+            import_lines = [
+                source.getimport(obj, alias=name)
+                for name, obj in main_module.__dict__.items()
+                if _imports(obj) and not (name.startswith("__") and name.endswith("__"))
+            ]

         # Get the feature classes from the main module
         feature_classes = {
@@ -612,6 +617,10 @@ class UDF(Step, ABC):
             for name, obj in main_module.__dict__.items()
             if _feature_predicate(obj)
         }
+        if not feature_classes:
+            yield None
+            return
+
         # Get the source code of the feature classes
         feature_sources = [source.getsource(cls) for _, cls in feature_classes.items()]
         # Set the module name for the feature classes to the generated name
@@ -621,7 +630,7 @@ class UDF(Step, ABC):
         # Add the dynamic module to the sys.modules dictionary
         sys.modules[feature_module_name] = dynamic_module
         # Combine the import lines and feature sources
-        feature_file = "".join(import_lines) + "\n".join(feature_sources)
+        feature_file = "\n".join(import_lines) + "\n" + "\n".join(feature_sources)

         # Write the module content to a .py file
         with open(f"{feature_module_name}.py", "w") as module_file:
@@ -1362,33 +1371,11 @@ class DatasetQuery:
         cols = result.columns
         return [dict(zip(cols, row)) for row in result]

-    @classmethod
-    def create_empty_record(
-        cls, name: Optional[str] = None, session: Optional[Session] = None
-    ) -> "DatasetRecord":
-        session = Session.get(session)
-        if name is None:
-            name = session.generate_temp_dataset_name()
-        columns = session.catalog.warehouse.dataset_row_cls.file_columns()
-        return session.catalog.create_dataset(name, columns=columns)
-
-    @classmethod
-    def insert_record(
-        cls,
-        dsr: "DatasetRecord",
-        record: dict[str, Any],
-        session: Optional[Session] = None,
-    ) -> None:
-        session = Session.get(session)
-        dr = session.catalog.warehouse.dataset_rows(dsr)
-        insert_q = dr.get_table().insert().values(**record)
-        session.catalog.warehouse.db.execute(insert_q)
-
     def to_pandas(self) -> "pd.DataFrame":
-        import pandas as pd
-
         records = self.to_records()
-
+        df = pd.DataFrame.from_records(records)
+        df.columns = [c.replace(DEFAULT_DELIMITER, ".") for c in df.columns]
+        return df

     def shuffle(self) -> "Self":
         # ToDo: implement shaffle based on seed and/or generating random column
@@ -1410,8 +1397,17 @@ class DatasetQuery:

     def show(self, limit=20) -> None:
         df = self.limit(limit).to_pandas()
-
-
+
+        options = ["display.max_colwidth", 50, "display.show_dimensions", False]
+        with pd.option_context(*options):
+            if inside_notebook():
+                from IPython.display import display
+
+                display(df)
+
+            else:
+                print(df.to_string())
+
         if len(df) == limit:
             print(f"[limited by {limit} objects]")

@@ -1692,6 +1688,15 @@ class DatasetQuery:
             storage.timestamp_str,
         )

+    def exec(self) -> "Self":
+        """Execute the query."""
+        try:
+            query = self.clone()
+            query.apply_steps()
+        finally:
+            self.cleanup()
+        return query
+
     def save(
         self,
         name: Optional[str] = None,
@@ -1878,3 +1883,24 @@ def _feature_predicate(obj):

 def _imports(obj):
     return not source.isfrommain(obj)
+
+
+def get_imports(m):
+    root = ast.parse(inspect.getsource(m))
+
+    for node in ast.iter_child_nodes(root):
+        if isinstance(node, ast.Import):
+            module = None
+        elif isinstance(node, ast.ImportFrom):
+            module = node.module
+        else:
+            continue
+
+        for n in node.names:
+            import_script = ""
+            if module:
+                import_script += f"from {module} "
+            import_script += f"import {n.name}"
+            if n.asname:
+                import_script += f" as {n.asname}"
+            yield import_script
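The new get_imports() helper walks the module AST and re-emits each import as a source line. The same traversal can be shown on a plain source string (standalone sketch mirroring the logic above):

import ast

src = "import os\nfrom pathlib import Path as P\n"
for node in ast.iter_child_nodes(ast.parse(src)):
    if isinstance(node, ast.Import):
        module = None
    elif isinstance(node, ast.ImportFrom):
        module = node.module
    else:
        continue
    for name in node.names:
        line = (f"from {module} " if module else "") + f"import {name.name}"
        if name.asname:
            line += f" as {name.asname}"
        print(line)
# import os
# from pathlib import Path as P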
datachain/query/dispatch.py
CHANGED
@@ -370,8 +370,8 @@ class UDFWorker:
         return WorkerCallback(self.done_queue)

     def run(self) -> None:
-        if hasattr(self.udf.func, "
-            self.udf.func.
+        if hasattr(self.udf.func, "setup") and callable(self.udf.func.setup):
+            self.udf.func.setup()
         while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
             n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
             udf_output = self.udf(
datachain/query/schema.py
CHANGED
@@ -3,14 +3,12 @@ import json
 from abc import ABC, abstractmethod
 from datetime import datetime, timezone
 from fnmatch import fnmatch
-from random import getrandbits
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union

 import attrs
 import sqlalchemy as sa
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback

-from datachain.data_storage.warehouse import RANDOM_BITS
 from datachain.sql.types import JSON, Boolean, DateTime, Int, Int64, SQLType, String

 if TYPE_CHECKING:
@@ -217,7 +215,7 @@ class DatasetRow:
         "source": String,
         "parent": String,
         "name": String,
-        "size":
+        "size": Int64,
         "location": JSON,
         "vtype": String,
         "dir_type": Int,
@@ -227,8 +225,6 @@ class DatasetRow:
         "last_modified": DateTime,
         "version": String,
         "etag": String,
-        # system column
-        "random": Int64,
     }

     @staticmethod
@@ -267,8 +263,6 @@ class DatasetRow:

         last_modified = last_modified or datetime.now(timezone.utc)

-        random = getrandbits(RANDOM_BITS)
-
         return (  # type: ignore [return-value]
             source,
             parent,
@@ -283,7 +277,6 @@ class DatasetRow:
             last_modified,
             version,
             etag,
-            random,
         )

     @staticmethod
datachain/query/udf.py
CHANGED
@@ -14,6 +14,7 @@ from typing import (
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback

 from datachain.dataset import RowDict
+from datachain.lib.utils import AbstractUDF

 from .batch import Batch, BatchingStrategy, NoBatching, Partition, RowBatch
 from .schema import (
@@ -58,14 +59,6 @@ class UDFProperties:
     def signal_names(self) -> Iterable[str]:
         return self.output.keys()

-    def parameter_parser(self) -> Callable:
-        """Generate a parameter list from a dataset row."""
-
-        def plist(catalog: "Catalog", row: "RowDict", **kwargs) -> list:
-            return [p.get_value(catalog, row, **kwargs) for p in self.params]
-
-        return plist
-

 def udf(
     params: Sequence[UDFParamSpec],
@@ -113,32 +106,37 @@ class UDFBase:
         self.func = func
         self.properties = properties
         self.signal_names = properties.signal_names()
-        self.parameter_parser = properties.parameter_parser()
         self.output = properties.output

     def __call__(
         self,
         catalog: "Catalog",
-
+        arg: "BatchingResult",
         is_generator: bool = False,
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterable[UDFResult]:
-        if isinstance(
+        if isinstance(self.func, AbstractUDF):
+            self.func._catalog = catalog  # type: ignore[unreachable]
+
+        if isinstance(arg, RowBatch):
             udf_inputs = [
-                self.
-                for row in
+                self.bind_parameters(catalog, row, cache=cache, cb=cb)
+                for row in arg.rows
             ]
             udf_outputs = self.func(udf_inputs)
-            return self._process_results(
-        if isinstance(
-            udf_inputs = self.
+            return self._process_results(arg.rows, udf_outputs, is_generator)
+        if isinstance(arg, RowDict):
+            udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
             udf_outputs = self.func(*udf_inputs)
             if not is_generator:
                 # udf_outputs is generator already if is_generator=True
                 udf_outputs = [udf_outputs]
-            return self._process_results([
-        raise ValueError(f"
+            return self._process_results([arg], udf_outputs, is_generator)
+        raise ValueError(f"Unexpected UDF argument: {arg}")
+
+    def bind_parameters(self, catalog: "Catalog", row: "RowDict", **kwargs) -> list:
+        return [p.get_value(catalog, row, **kwargs) for p in self.properties.params]

     def _process_results(
         self,
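bind_parameters() replaces the old parameter_parser() closure: for each parameter spec it pulls the corresponding value out of the row, producing the positional arguments for the wrapped function. The idea in isolation (names are illustrative, not the datachain API):

row = {"name": "cat.jpg", "size": 1024}
param_names = ["name", "size"]  # stands in for the UDFParamSpec objects
udf_inputs = [row[p] for p in param_names]
assert udf_inputs == ["cat.jpg", 1024]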