datachain 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of datachain might be problematic.
- datachain/__init__.py +0 -4
- datachain/catalog/catalog.py +17 -2
- datachain/cli.py +8 -1
- datachain/data_storage/db_engine.py +0 -2
- datachain/data_storage/schema.py +15 -26
- datachain/data_storage/sqlite.py +3 -0
- datachain/data_storage/warehouse.py +1 -7
- datachain/lib/arrow.py +7 -13
- datachain/lib/cached_stream.py +3 -85
- datachain/lib/clip.py +151 -0
- datachain/lib/dc.py +41 -59
- datachain/lib/feature.py +5 -1
- datachain/lib/feature_registry.py +3 -2
- datachain/lib/feature_utils.py +1 -2
- datachain/lib/file.py +17 -24
- datachain/lib/image.py +37 -79
- datachain/lib/pytorch.py +4 -2
- datachain/lib/signal_schema.py +3 -4
- datachain/lib/text.py +18 -49
- datachain/lib/udf.py +64 -55
- datachain/lib/udf_signature.py +11 -10
- datachain/lib/utils.py +17 -0
- datachain/lib/webdataset.py +2 -2
- datachain/listing.py +0 -3
- datachain/query/dataset.py +66 -46
- datachain/query/dispatch.py +2 -2
- datachain/query/schema.py +1 -8
- datachain/query/udf.py +16 -18
- datachain/sql/sqlite/base.py +34 -2
- datachain/sql/sqlite/vector.py +13 -5
- datachain/utils.py +28 -0
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/METADATA +3 -2
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/RECORD +37 -38
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/WHEEL +1 -1
- datachain/_version.py +0 -16
- datachain/lib/reader.py +0 -49
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/LICENSE +0 -0
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/top_level.txt +0 -0
datachain/lib/udf.py
CHANGED
@@ -1,15 +1,16 @@
 import inspect
 import sys
 import traceback
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Callable

 from datachain.lib.feature import Feature
 from datachain.lib.signal_schema import SignalSchema
-from datachain.lib.
-from datachain.
+from datachain.lib.udf_signature import UdfSignature
+from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
+from datachain.query import udf

 if TYPE_CHECKING:
-    from
+    from datachain.query.udf import UDFWrapper


 class UdfError(DataChainParamsError):
@@ -17,31 +18,68 @@ class UdfError(DataChainParamsError):
         super().__init__(f"UDF error: {msg}")


-class UDFBase:
+class UDFBase(AbstractUDF):
     is_input_batched = False
     is_output_batched = False
     is_input_grouped = False

-    def __init__(
-        self
-
-
-
-
+    def __init__(self):
+        self.params = None
+        self.output = None
+        self.params_spec = None
+        self.output_spec = None
+        self._contains_stream = None
+        self._catalog = None
+        self._func = None
+
+    def process(self, *args, **kwargs):
+        """Processing function that needs to be defined by user"""
+        if not self._func:
+            raise NotImplementedError("UDF processing is not implemented")
+        return self._func(*args, **kwargs)
+
+    def setup(self):
+        """Initialization process executed on each worker before processing begins.
+        This is needed for tasks like pre-loading ML models prior to scoring.
+        """
+
+    def teardown(self):
+        """Teardown process executed on each process/worker after processing ends.
+        This is needed for tasks like closing connections to end-points.
+        """
+
+    def _init(self, sign: UdfSignature, params: SignalSchema, func: Callable):
         self.params = params
-        self.output =
-        self._func = func
+        self.output = sign.output_schema

-        params_spec = params.to_udf_spec()
+        params_spec = self.params.to_udf_spec()
         self.params_spec = list(params_spec.keys())
-        self.
-        if params.contains_file():
-            self.params_spec.insert(0, Stream())  # type: ignore[arg-type]
-            self._contains_stream = True
+        self.output_spec = self.output.to_udf_spec()

-        self.
+        self._func = func

-
+    @classmethod
+    def _create(
+        cls,
+        target_class: type["UDFBase"],
+        sign: UdfSignature,
+        params: SignalSchema,
+        catalog,
+    ) -> "UDFBase":
+        if isinstance(sign.func, AbstractUDF):
+            if not isinstance(sign.func, target_class):  # type: ignore[unreachable]
+                raise UdfError(
+                    f"cannot create UDF: provided UDF '{sign.func.__name__}'"
+                    f" must be a child of target class '{target_class.__name__}'",
+                )
+            result = sign.func
+            func = None
+        else:
+            result = target_class()
+            func = sign.func
+
+        result._init(sign, params, func)
+        return result

     @property
     def name(self):
@@ -58,25 +96,10 @@ class UDFBase:
         udf_wrapper = udf(self.params_spec, self.output_spec, batch=batch)
         return udf_wrapper(self)

-    def bootstrap(self):
-        """Initialization process executed on each worker before processing begins.
-        This is needed for tasks like pre-loading ML models prior to scoring.
-        """
-
-    def teardown(self):
-        """Teardown process executed on each process/worker after processing ends.
-        This is needed for tasks like closing connections to end-points.
-        """
-
-    def process(self, *args, **kwargs):
-        if not self._func:
-            raise NotImplementedError("UDF processing is not implemented")
-        return self._func(*args, **kwargs)
-
     def validate_results(self, results, *args, **kwargs):
         return results

-    def __call__(self, *rows
+    def __call__(self, *rows):
         if self.is_input_grouped:
             objs = self._parse_grouped_rows(rows)
         else:
@@ -122,18 +145,10 @@ class UDFBase:
             rows = [rows]
         objs = []
         for row in rows:
-            if self._contains_stream:
-                stream, *row = row
-            else:
-                stream = None
-
             obj_row = self.params.row_to_objs(row)
-
-
-
-            if isinstance(obj, Feature):
-                obj._set_stream(self._catalog, stream, True)
-
+            for obj in obj_row:
+                if isinstance(obj, Feature):
+                    obj._set_stream(self._catalog, caching_enabled=True)
             objs.append(obj_row)
         return objs

@@ -150,13 +165,7 @@ class UDFBase:
             output_map[name] = []

         for flat_obj in group:
-
-                position = 1
-                stream = flat_obj[0]
-            else:
-                position = 0
-                stream = None
-
+            position = 0
             for signal, (cls, length) in spec_map.items():
                 slice = flat_obj[position : position + length]
                 position += length
@@ -167,7 +176,7 @@ class UDFBase:
                     obj = slice[0]

                 if isinstance(obj, Feature):
-                    obj._set_stream(self._catalog
+                    obj._set_stream(self._catalog)
                 output_map[signal].append(obj)

         return list(output_map.values())
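Example (a minimal sketch, not code from the package): this release moves the setup/process/teardown lifecycle onto UDFBase itself, so a class-based UDF can be written roughly like this; the LengthUDF class, its attributes, and its parameter are illustrative only.

    from datachain.lib.udf import UDFBase

    class LengthUDF(UDFBase):
        """Illustrative UDF using the new lifecycle hooks."""

        def setup(self):
            # runs on each worker before processing begins
            # (the place to pre-load models, open connections, etc.)
            self.prefix = "len="

        def process(self, file):
            # per-row processing; replaces the old bootstrap/process split
            return self.prefix + str(len(file.read()))

        def teardown(self):
            # runs on each worker after processing ends
            self.prefix = None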
datachain/lib/udf_signature.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Callable, Optional, Union, get_args, get_origin

 from datachain.lib.feature import Feature, FeatureType, FeatureTypeNames
 from datachain.lib.signal_schema import SignalSchema
-from datachain.lib.utils import DataChainParamsError
+from datachain.lib.utils import AbstractUDF, DataChainParamsError


 class UdfSignatureError(DataChainParamsError):
@@ -49,10 +49,13 @@ class UdfSignature:
         else:
             if func is None:
                 raise UdfSignatureError(chain, "user function is not defined")
+
             udf_func = func
             signal_name = None
+
         if not callable(udf_func):
-            raise UdfSignatureError(chain, f"
+            raise UdfSignatureError(chain, f"UDF '{udf_func}' is not callable")
+
         func_params_map_sign, func_outs_sign, is_iterator = (
             UdfSignature._func_signature(chain, udf_func)
         )
@@ -108,13 +111,6 @@ class UdfSignature:
         if isinstance(output, str):
             output = [output]
         if isinstance(output, Sequence):
-            if not func_outs_sign:
-                raise UdfSignatureError(
-                    chain,
-                    "output types are not specified. Specify types in 'output' as"
-                    " a dict or as function return value hint.",
-                )
-
             if len(func_outs_sign) != len(output):
                 raise UdfSignatureError(
                     chain,
@@ -158,8 +154,13 @@ class UdfSignature:

     @staticmethod
     def _func_signature(
-        chain: str,
+        chain: str, udf_func: Callable
     ) -> tuple[dict[str, type], Sequence[type], bool]:
+        if isinstance(udf_func, AbstractUDF):
+            func = udf_func.process  # type: ignore[unreachable]
+        else:
+            func = udf_func
+
         sign = inspect.signature(func)

         input_map = {prm.name: prm.annotation for prm in sign.parameters.values()}
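Example (standalone illustration of the mechanism): for AbstractUDF instances, _func_signature now inspects the bound process method, so self does not appear among the inferred parameters.

    import inspect

    class MyUDF:
        def process(self, name: str, size: int) -> str:
            return f"{name}:{size}"

    # inspecting the bound method, as _func_signature does for class-based UDFs
    sig = inspect.signature(MyUDF().process)
    print(list(sig.parameters))   # ['name', 'size']
    print(sig.return_annotation)  # <class 'str'>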
datachain/lib/utils.py
CHANGED
@@ -1,3 +1,20 @@
+from abc import ABC, abstractmethod
+
+
+class AbstractUDF(ABC):
+    @abstractmethod
+    def process(self, *args, **kwargs):
+        pass
+
+    @abstractmethod
+    def setup(self):
+        pass
+
+    @abstractmethod
+    def teardown(self):
+        pass
+
+
 class DataChainError(Exception):
     def __init__(self, message):
         super().__init__(message)
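Example (standalone illustration): because AbstractUDF is an ABC with three abstract methods, a concrete UDF class must implement process, setup, and teardown before it can be instantiated; the NoopUDF class is illustrative only.

    from datachain.lib.utils import AbstractUDF

    class NoopUDF(AbstractUDF):
        def process(self, *args, **kwargs):
            return args

        def setup(self):
            pass

        def teardown(self):
            pass

    NoopUDF()        # ok: all three abstract methods are implemented
    # AbstractUDF()  # TypeError: can't instantiate abstract class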
datachain/lib/webdataset.py
CHANGED
@@ -2,6 +2,7 @@ import hashlib
 import json
 import tarfile
 from collections.abc import Iterator, Sequence
+from pathlib import Path
 from typing import (
     Any,
     Callable,
@@ -240,10 +241,9 @@ class TarStream(File):
 def get_tar_groups(stream, tar, core_extensions, spec, encoding="utf-8"):
     builder = Builder(stream, core_extensions, spec, tar, encoding)

-    for item in tar.getmembers():
+    for item in sorted(tar.getmembers(), key=lambda m: Path(m.name).stem):
         if not item.isfile():
             continue
-
         try:
             builder.add(item)
         except StopIteration:
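Example (standalone illustration): sorting members by Path(name).stem groups the files that make up one webdataset sample (same stem, different extensions) next to each other regardless of their order in the archive.

    from pathlib import Path

    names = ["0002.jpg", "0001.json", "0002.json", "0001.jpg"]
    print(sorted(names, key=lambda n: Path(n).stem))
    # ['0001.json', '0001.jpg', '0002.jpg', '0002.json']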
datachain/listing.py
CHANGED
datachain/query/dataset.py
CHANGED
@@ -1,3 +1,4 @@
+import ast
 import contextlib
 import datetime
 import inspect
@@ -51,9 +52,10 @@ from datachain.data_storage.schema import (
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.progress import CombinedDownloadCallback
+from datachain.query.schema import DEFAULT_DELIMITER
 from datachain.sql.functions import rand
 from datachain.storage import Storage, StorageURI
-from datachain.utils import batched, determine_processes
+from datachain.utils import batched, determine_processes, inside_notebook

 from .batch import RowBatch
 from .metrics import metrics
@@ -62,7 +64,6 @@ from .session import Session
 from .udf import UDFBase, UDFClassWrapper, UDFFactory, UDFType

 if TYPE_CHECKING:
-    import pandas as pd
     from sqlalchemy.sql.elements import ClauseElement
     from sqlalchemy.sql.schema import Table
     from sqlalchemy.sql.selectable import GenerativeSelect
@@ -547,8 +548,9 @@ class UDF(Step, ABC):
         else:
             udf = self.udf

-        if hasattr(udf.func, "
-            udf.func.
+        if hasattr(udf.func, "setup") and callable(udf.func.setup):
+            udf.func.setup()
+
         warehouse = self.catalog.warehouse

         with contextlib.closing(
@@ -599,12 +601,15 @@ class UDF(Step, ABC):
         # Create a dynamic module with the generated name
         dynamic_module = types.ModuleType(feature_module_name)
         # Get the import lines for the necessary objects from the main module
-        import_lines = [
-            source.getimport(obj, alias=name)
-            for name, obj in inspect.getmembers(sys.modules["__main__"], _imports)
-            if not (name.startswith("__") and name.endswith("__"))
-        ]
         main_module = sys.modules["__main__"]
+        if getattr(main_module, "__file__", None):
+            import_lines = list(get_imports(main_module))
+        else:
+            import_lines = [
+                source.getimport(obj, alias=name)
+                for name, obj in main_module.__dict__.items()
+                if _imports(obj) and not (name.startswith("__") and name.endswith("__"))
+            ]

         # Get the feature classes from the main module
         feature_classes = {
@@ -612,6 +617,10 @@ class UDF(Step, ABC):
             for name, obj in main_module.__dict__.items()
             if _feature_predicate(obj)
         }
+        if not feature_classes:
+            yield None
+            return
+
         # Get the source code of the feature classes
         feature_sources = [source.getsource(cls) for _, cls in feature_classes.items()]
         # Set the module name for the feature classes to the generated name
@@ -621,7 +630,7 @@ class UDF(Step, ABC):
         # Add the dynamic module to the sys.modules dictionary
         sys.modules[feature_module_name] = dynamic_module
         # Combine the import lines and feature sources
-        feature_file = "".join(import_lines) + "\n".join(feature_sources)
+        feature_file = "\n".join(import_lines) + "\n" + "\n".join(feature_sources)

         # Write the module content to a .py file
         with open(f"{feature_module_name}.py", "w") as module_file:
@@ -1362,33 +1371,11 @@ class DatasetQuery:
         cols = result.columns
         return [dict(zip(cols, row)) for row in result]

-    @classmethod
-    def create_empty_record(
-        cls, name: Optional[str] = None, session: Optional[Session] = None
-    ) -> "DatasetRecord":
-        session = Session.get(session)
-        if name is None:
-            name = session.generate_temp_dataset_name()
-        columns = session.catalog.warehouse.dataset_row_cls.file_columns()
-        return session.catalog.create_dataset(name, columns=columns)
-
-    @classmethod
-    def insert_record(
-        cls,
-        dsr: "DatasetRecord",
-        record: dict[str, Any],
-        session: Optional[Session] = None,
-    ) -> None:
-        session = Session.get(session)
-        dr = session.catalog.warehouse.dataset_rows(dsr)
-        insert_q = dr.get_table().insert().values(**record)
-        session.catalog.warehouse.db.execute(insert_q)
-
     def to_pandas(self) -> "pd.DataFrame":
-        import pandas as pd
-
         records = self.to_records()
-
+        df = pd.DataFrame.from_records(records)
+        df.columns = [c.replace(DEFAULT_DELIMITER, ".") for c in df.columns]
+        return df

     def shuffle(self) -> "Self":
         # ToDo: implement shaffle based on seed and/or generating random column
@@ -1410,8 +1397,17 @@ class DatasetQuery:

     def show(self, limit=20) -> None:
         df = self.limit(limit).to_pandas()
-
-
+
+        options = ["display.max_colwidth", 50, "display.show_dimensions", False]
+        with pd.option_context(*options):
+            if inside_notebook():
+                from IPython.display import display
+
+                display(df)
+
+            else:
+                print(df.to_string())
+
         if len(df) == limit:
             print(f"[limited by {limit} objects]")

@@ -1692,6 +1688,15 @@ class DatasetQuery:
             storage.timestamp_str,
         )

+    def exec(self) -> "Self":
+        """Execute the query."""
+        try:
+            query = self.clone()
+            query.apply_steps()
+        finally:
+            self.cleanup()
+        return query
+
     def save(
         self,
         name: Optional[str] = None,
@@ -1737,22 +1742,16 @@ class DatasetQuery:

         # Exclude the id column and let the db create it to avoid unique
         # constraint violations.
-        cols = [col.name for col in dr.get_table().c if col.name != "id"]
-        assert cols
         q = query.exclude(("id",))
-
         if q._order_by_clauses:
             # ensuring we have id sorted by order by clause if it exists in a query
             q = q.add_columns(
                 f.row_number().over(order_by=q._order_by_clauses).label("id")
             )
-            cols.append("id")
-
-        self.catalog.warehouse.db.execute(
-            sqlalchemy.insert(dr.get_table()).from_select(cols, q),
-            **kwargs,
-        )

+        cols = tuple(c.name for c in q.columns)
+        insert_q = sqlalchemy.insert(dr.get_table()).from_select(cols, q)
+        self.catalog.warehouse.db.execute(insert_q, **kwargs)
         self.catalog.metastore.update_dataset_status(
             dataset, DatasetStatus.COMPLETE, version=version
         )
@@ -1884,3 +1883,24 @@ def _feature_predicate(obj):

 def _imports(obj):
     return not source.isfrommain(obj)
+
+
+def get_imports(m):
+    root = ast.parse(inspect.getsource(m))
+
+    for node in ast.iter_child_nodes(root):
+        if isinstance(node, ast.Import):
+            module = None
+        elif isinstance(node, ast.ImportFrom):
+            module = node.module
+        else:
+            continue
+
+        for n in node.names:
+            import_script = ""
+            if module:
+                import_script += f"from {module} "
+            import_script += f"import {n.name}"
+            if n.asname:
+                import_script += f" as {n.asname}"
+            yield import_script
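Example (standalone illustration of the technique behind get_imports): parse a module's source with ast and re-emit its top-level import statements.

    import ast

    src = "import os\nfrom pathlib import Path as P\n"
    for node in ast.iter_child_nodes(ast.parse(src)):
        if isinstance(node, ast.Import):
            module = None
        elif isinstance(node, ast.ImportFrom):
            module = node.module
        else:
            continue
        for n in node.names:
            line = f"from {module} " if module else ""
            line += f"import {n.name}" + (f" as {n.asname}" if n.asname else "")
            print(line)
    # import os
    # from pathlib import Path as P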
datachain/query/dispatch.py
CHANGED
@@ -370,8 +370,8 @@ class UDFWorker:
         return WorkerCallback(self.done_queue)

     def run(self) -> None:
-        if hasattr(self.udf.func, "
-            self.udf.func.
+        if hasattr(self.udf.func, "setup") and callable(self.udf.func.setup):
+            self.udf.func.setup()
         while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
             n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
             udf_output = self.udf(
datachain/query/schema.py
CHANGED
@@ -3,14 +3,12 @@ import json
 from abc import ABC, abstractmethod
 from datetime import datetime, timezone
 from fnmatch import fnmatch
-from random import getrandbits
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union

 import attrs
 import sqlalchemy as sa
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback

-from datachain.data_storage.warehouse import RANDOM_BITS
 from datachain.sql.types import JSON, Boolean, DateTime, Int, Int64, SQLType, String

 if TYPE_CHECKING:
@@ -217,7 +215,7 @@ class DatasetRow:
         "source": String,
         "parent": String,
         "name": String,
-        "size":
+        "size": Int64,
         "location": JSON,
         "vtype": String,
         "dir_type": Int,
@@ -227,8 +225,6 @@ class DatasetRow:
         "last_modified": DateTime,
         "version": String,
         "etag": String,
-        # system column
-        "random": Int64,
     }

     @staticmethod
@@ -267,8 +263,6 @@ class DatasetRow:

         last_modified = last_modified or datetime.now(timezone.utc)

-        random = getrandbits(RANDOM_BITS)
-
         return (  # type: ignore [return-value]
             source,
             parent,
@@ -283,7 +277,6 @@ class DatasetRow:
             last_modified,
             version,
             etag,
-            random,
         )

     @staticmethod
datachain/query/udf.py
CHANGED
@@ -14,6 +14,7 @@ from typing import (
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback

 from datachain.dataset import RowDict
+from datachain.lib.utils import AbstractUDF

 from .batch import Batch, BatchingStrategy, NoBatching, Partition, RowBatch
 from .schema import (
@@ -58,14 +59,6 @@ class UDFProperties:
     def signal_names(self) -> Iterable[str]:
         return self.output.keys()

-    def parameter_parser(self) -> Callable:
-        """Generate a parameter list from a dataset row."""
-
-        def plist(catalog: "Catalog", row: "RowDict", **kwargs) -> list:
-            return [p.get_value(catalog, row, **kwargs) for p in self.params]
-
-        return plist
-

 def udf(
     params: Sequence[UDFParamSpec],
@@ -113,32 +106,37 @@ class UDFBase:
         self.func = func
         self.properties = properties
         self.signal_names = properties.signal_names()
-        self.parameter_parser = properties.parameter_parser()
         self.output = properties.output

     def __call__(
         self,
         catalog: "Catalog",
-
+        arg: "BatchingResult",
         is_generator: bool = False,
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterable[UDFResult]:
-        if isinstance(
+        if isinstance(self.func, AbstractUDF):
+            self.func._catalog = catalog  # type: ignore[unreachable]
+
+        if isinstance(arg, RowBatch):
             udf_inputs = [
-                self.
-                for row in
+                self.bind_parameters(catalog, row, cache=cache, cb=cb)
+                for row in arg.rows
             ]
             udf_outputs = self.func(udf_inputs)
-            return self._process_results(
-        if isinstance(
-            udf_inputs = self.
+            return self._process_results(arg.rows, udf_outputs, is_generator)
+        if isinstance(arg, RowDict):
+            udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
             udf_outputs = self.func(*udf_inputs)
             if not is_generator:
                 # udf_outputs is generator already if is_generator=True
                 udf_outputs = [udf_outputs]
-            return self._process_results([
-        raise ValueError(f"
+            return self._process_results([arg], udf_outputs, is_generator)
+        raise ValueError(f"Unexpected UDF argument: {arg}")
+
+    def bind_parameters(self, catalog: "Catalog", row: "RowDict", **kwargs) -> list:
+        return [p.get_value(catalog, row, **kwargs) for p in self.properties.params]

     def _process_results(
         self,
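Example (a simplified sketch, not the package's classes): __call__ now takes a single argument and dispatches on its type, binding parameters through a bind_parameters method instead of a stored closure; RowBatch and RowDict are stubbed here with stand-ins.

    from dataclasses import dataclass

    @dataclass
    class FakeBatch:  # stand-in for RowBatch
        rows: list

    class Wrapper:
        def bind_parameters(self, row):
            return [row["name"]]

        def __call__(self, arg):
            if isinstance(arg, FakeBatch):
                return [self.bind_parameters(r) for r in arg.rows]
            if isinstance(arg, dict):  # stand-in for RowDict
                return self.bind_parameters(arg)
            raise ValueError(f"Unexpected UDF argument: {arg}")

    w = Wrapper()
    print(w(FakeBatch(rows=[{"name": "a"}, {"name": "b"}])))  # [['a'], ['b']]
    print(w({"name": "c"}))                                   # ['c']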
datachain/sql/sqlite/base.py
CHANGED
@@ -71,8 +71,6 @@ def setup():
     compiles(sql_path.name, "sqlite")(compile_path_name)
     compiles(sql_path.file_stem, "sqlite")(compile_path_file_stem)
     compiles(sql_path.file_ext, "sqlite")(compile_path_file_ext)
-    compiles(array.cosine_distance, "sqlite")(compile_cosine_distance)
-    compiles(array.euclidean_distance, "sqlite")(compile_euclidean_distance)
     compiles(array.length, "sqlite")(compile_array_length)
     compiles(string.length, "sqlite")(compile_string_length)
     compiles(string.split, "sqlite")(compile_string_split)
@@ -81,6 +79,13 @@ def setup():
     compiles(Values, "sqlite")(compile_values)
     compiles(random.rand, "sqlite")(compile_rand)

+    if load_usearch_extension(sqlite3.connect(":memory:")):
+        compiles(array.cosine_distance, "sqlite")(compile_cosine_distance_ext)
+        compiles(array.euclidean_distance, "sqlite")(compile_euclidean_distance_ext)
+    else:
+        compiles(array.cosine_distance, "sqlite")(compile_cosine_distance)
+        compiles(array.euclidean_distance, "sqlite")(compile_euclidean_distance)
+
     register_user_defined_sql_functions()
     setup_is_complete = True

@@ -246,11 +251,23 @@ def compile_path_file_ext(element, compiler, **kwargs):
     return compiler.process(path_file_ext(*element.clauses.clauses), **kwargs)


+def compile_cosine_distance_ext(element, compiler, **kwargs):
+    run_compiler_hook("cosine_distance")
+    return f"distance_cosine_f32({compiler.process(element.clauses, **kwargs)})"
+
+
 def compile_cosine_distance(element, compiler, **kwargs):
     run_compiler_hook("cosine_distance")
     return f"cosine_distance({compiler.process(element.clauses, **kwargs)})"


+def compile_euclidean_distance_ext(element, compiler, **kwargs):
+    run_compiler_hook("euclidean_distance")
+    return (
+        f"sqrt(distance_sqeuclidean_f32({compiler.process(element.clauses, **kwargs)}))"
+    )
+
+
 def compile_euclidean_distance(element, compiler, **kwargs):
     run_compiler_hook("euclidean_distance")
     return f"euclidean_distance({compiler.process(element.clauses, **kwargs)})"
@@ -330,3 +347,18 @@ def compile_values(element, compiler, **kwargs):

 def compile_rand(element, compiler, **kwargs):
     return compiler.process(func.random(), **kwargs)
+
+
+def load_usearch_extension(conn) -> bool:
+    try:
+        # usearch is part of the vector optional dependencies
+        # we use the extension's cosine and euclidean distance functions
+        from usearch import sqlite_path
+
+        conn.enable_load_extension(True)
+        conn.load_extension(sqlite_path())
+        conn.enable_load_extension(False)
+        return True
+
+    except Exception:  # noqa: BLE001
+        return False
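Example (hedged sketch): setup() now probes once for the usearch SQLite extension and registers the extension-backed compilers (distance_cosine_f32 / distance_sqeuclidean_f32) only when loading succeeds, keeping the existing fallbacks otherwise. A standalone probe along the same lines; has_usearch_extension is an illustrative name, not a datachain function.

    import sqlite3

    def has_usearch_extension() -> bool:
        try:
            from usearch import sqlite_path  # optional "vector" dependency
        except ImportError:
            return False
        conn = sqlite3.connect(":memory:")
        try:
            conn.enable_load_extension(True)
            conn.load_extension(sqlite_path())
            return True
        except Exception:
            return False
        finally:
            conn.close()

    print("usearch SQLite extension available:", has_usearch_extension())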