datachain 0.7.0-py3-none-any.whl → 0.7.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic. Click here for more details.
- datachain/__init__.py +0 -3
- datachain/catalog/catalog.py +8 -6
- datachain/cli.py +1 -1
- datachain/client/fsspec.py +9 -9
- datachain/data_storage/schema.py +2 -2
- datachain/data_storage/sqlite.py +5 -4
- datachain/data_storage/warehouse.py +18 -18
- datachain/func/__init__.py +49 -0
- datachain/{lib/func → func}/aggregate.py +13 -11
- datachain/func/array.py +176 -0
- datachain/func/base.py +23 -0
- datachain/func/conditional.py +81 -0
- datachain/func/func.py +384 -0
- datachain/func/path.py +110 -0
- datachain/func/random.py +23 -0
- datachain/func/string.py +154 -0
- datachain/func/window.py +49 -0
- datachain/lib/arrow.py +24 -12
- datachain/lib/data_model.py +25 -9
- datachain/lib/dataset_info.py +2 -2
- datachain/lib/dc.py +94 -56
- datachain/lib/hf.py +1 -1
- datachain/lib/signal_schema.py +1 -1
- datachain/lib/utils.py +1 -0
- datachain/lib/webdataset_laion.py +5 -5
- datachain/model/__init__.py +6 -0
- datachain/model/bbox.py +102 -0
- datachain/model/pose.py +88 -0
- datachain/model/segment.py +47 -0
- datachain/model/ultralytics/__init__.py +27 -0
- datachain/model/ultralytics/bbox.py +147 -0
- datachain/model/ultralytics/pose.py +113 -0
- datachain/model/ultralytics/segment.py +91 -0
- datachain/nodes_fetcher.py +2 -2
- datachain/query/dataset.py +57 -34
- datachain/sql/__init__.py +0 -2
- datachain/sql/functions/__init__.py +0 -26
- datachain/sql/selectable.py +11 -5
- datachain/sql/sqlite/base.py +11 -2
- datachain/toolkit/split.py +6 -2
- {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/METADATA +72 -71
- {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/RECORD +46 -35
- {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/WHEEL +1 -1
- datachain/lib/func/__init__.py +0 -32
- datachain/lib/func/func.py +0 -152
- datachain/lib/models/__init__.py +0 -5
- datachain/lib/models/bbox.py +0 -45
- datachain/lib/models/pose.py +0 -37
- datachain/lib/models/yolo.py +0 -39
- {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/LICENSE +0 -0
- {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/entry_points.txt +0 -0
- {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/top_level.txt +0 -0
datachain/__init__.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
from datachain.lib import func, models
|
|
2
1
|
from datachain.lib.data_model import DataModel, DataType, is_chain_type
|
|
3
2
|
from datachain.lib.dc import C, Column, DataChain, Sys
|
|
4
3
|
from datachain.lib.file import (
|
|
@@ -35,9 +34,7 @@ __all__ = [
|
|
|
35
34
|
"Sys",
|
|
36
35
|
"TarVFile",
|
|
37
36
|
"TextFile",
|
|
38
|
-
"func",
|
|
39
37
|
"is_chain_type",
|
|
40
38
|
"metrics",
|
|
41
|
-
"models",
|
|
42
39
|
"param",
|
|
43
40
|
]
|
datachain/catalog/catalog.py
CHANGED
|
@@ -54,7 +54,6 @@ from datachain.error import (
|
|
|
54
54
|
QueryScriptCancelError,
|
|
55
55
|
QueryScriptRunError,
|
|
56
56
|
)
|
|
57
|
-
from datachain.listing import Listing
|
|
58
57
|
from datachain.node import DirType, Node, NodeWithPath
|
|
59
58
|
from datachain.nodes_thread_pool import NodesThreadPool
|
|
60
59
|
from datachain.remote.studio import StudioClient
|
|
@@ -76,6 +75,7 @@ if TYPE_CHECKING:
|
|
|
76
75
|
from datachain.dataset import DatasetVersion
|
|
77
76
|
from datachain.job import Job
|
|
78
77
|
from datachain.lib.file import File
|
|
78
|
+
from datachain.listing import Listing
|
|
79
79
|
|
|
80
80
|
logger = logging.getLogger("datachain")
|
|
81
81
|
|
|
@@ -236,7 +236,7 @@ class DatasetRowsFetcher(NodesThreadPool):
|
|
|
236
236
|
class NodeGroup:
|
|
237
237
|
"""Class for a group of nodes from the same source"""
|
|
238
238
|
|
|
239
|
-
listing: Listing
|
|
239
|
+
listing: "Listing"
|
|
240
240
|
sources: list[DataSource]
|
|
241
241
|
|
|
242
242
|
# The source path within the bucket
|
|
@@ -591,8 +591,9 @@ class Catalog:
|
|
|
591
591
|
client_config=None,
|
|
592
592
|
object_name="file",
|
|
593
593
|
skip_indexing=False,
|
|
594
|
-
) -> tuple[Listing, str]:
|
|
594
|
+
) -> tuple["Listing", str]:
|
|
595
595
|
from datachain.lib.dc import DataChain
|
|
596
|
+
from datachain.listing import Listing
|
|
596
597
|
|
|
597
598
|
DataChain.from_storage(
|
|
598
599
|
source, session=self.session, update=update, object_name=object_name
|
|
@@ -660,7 +661,8 @@ class Catalog:
|
|
|
660
661
|
no_glob: bool = False,
|
|
661
662
|
client_config=None,
|
|
662
663
|
) -> list[NodeGroup]:
|
|
663
|
-
from datachain.
|
|
664
|
+
from datachain.listing import Listing
|
|
665
|
+
from datachain.query.dataset import DatasetQuery
|
|
664
666
|
|
|
665
667
|
def _row_to_node(d: dict[str, Any]) -> Node:
|
|
666
668
|
del d["file__source"]
|
|
@@ -876,7 +878,7 @@ class Catalog:
|
|
|
876
878
|
def update_dataset_version_with_warehouse_info(
|
|
877
879
|
self, dataset: DatasetRecord, version: int, rows_dropped=False, **kwargs
|
|
878
880
|
) -> None:
|
|
879
|
-
from datachain.query import DatasetQuery
|
|
881
|
+
from datachain.query.dataset import DatasetQuery
|
|
880
882
|
|
|
881
883
|
dataset_version = dataset.get_version(version)
|
|
882
884
|
|
|
@@ -1177,7 +1179,7 @@ class Catalog:
|
|
|
1177
1179
|
def ls_dataset_rows(
|
|
1178
1180
|
self, name: str, version: int, offset=None, limit=None
|
|
1179
1181
|
) -> list[dict]:
|
|
1180
|
-
from datachain.query import DatasetQuery
|
|
1182
|
+
from datachain.query.dataset import DatasetQuery
|
|
1181
1183
|
|
|
1182
1184
|
dataset = self.get_dataset(name)
|
|
1183
1185
|
|
datachain/cli.py
CHANGED
|
@@ -957,7 +957,7 @@ def show(
|
|
|
957
957
|
schema: bool = False,
|
|
958
958
|
) -> None:
|
|
959
959
|
from datachain.lib.dc import DataChain
|
|
960
|
-
from datachain.query import DatasetQuery
|
|
960
|
+
from datachain.query.dataset import DatasetQuery
|
|
961
961
|
from datachain.utils import show_records
|
|
962
962
|
|
|
963
963
|
dataset = catalog.get_dataset(name)
|
datachain/client/fsspec.py
CHANGED
|
@@ -28,7 +28,6 @@ from tqdm import tqdm
|
|
|
28
28
|
from datachain.cache import DataChainCache
|
|
29
29
|
from datachain.client.fileslice import FileWrapper
|
|
30
30
|
from datachain.error import ClientError as DataChainClientError
|
|
31
|
-
from datachain.lib.file import File
|
|
32
31
|
from datachain.nodes_fetcher import NodesFetcher
|
|
33
32
|
from datachain.nodes_thread_pool import NodeChunk
|
|
34
33
|
|
|
@@ -36,6 +35,7 @@ if TYPE_CHECKING:
|
|
|
36
35
|
from fsspec.spec import AbstractFileSystem
|
|
37
36
|
|
|
38
37
|
from datachain.dataset import StorageURI
|
|
38
|
+
from datachain.lib.file import File
|
|
39
39
|
|
|
40
40
|
|
|
41
41
|
logger = logging.getLogger("datachain")
|
|
@@ -45,7 +45,7 @@ DELIMITER = "/" # Path delimiter.
|
|
|
45
45
|
|
|
46
46
|
DATA_SOURCE_URI_PATTERN = re.compile(r"^[\w]+:\/\/.*$")
|
|
47
47
|
|
|
48
|
-
ResultQueue = asyncio.Queue[Optional[Sequence[File]]]
|
|
48
|
+
ResultQueue = asyncio.Queue[Optional[Sequence["File"]]]
|
|
49
49
|
|
|
50
50
|
|
|
51
51
|
def _is_win_local_path(uri: str) -> bool:
|
|
@@ -212,7 +212,7 @@ class Client(ABC):
|
|
|
212
212
|
|
|
213
213
|
async def scandir(
|
|
214
214
|
self, start_prefix: str, method: str = "default"
|
|
215
|
-
) -> AsyncIterator[Sequence[File]]:
|
|
215
|
+
) -> AsyncIterator[Sequence["File"]]:
|
|
216
216
|
try:
|
|
217
217
|
impl = getattr(self, f"_fetch_{method}")
|
|
218
218
|
except AttributeError:
|
|
@@ -317,7 +317,7 @@ class Client(ABC):
|
|
|
317
317
|
return f"{self.PREFIX}{self.name}/{rel_path}"
|
|
318
318
|
|
|
319
319
|
@abstractmethod
|
|
320
|
-
def info_to_file(self, v: dict[str, Any], parent: str) -> File: ...
|
|
320
|
+
def info_to_file(self, v: dict[str, Any], parent: str) -> "File": ...
|
|
321
321
|
|
|
322
322
|
def fetch_nodes(
|
|
323
323
|
self,
|
|
@@ -354,7 +354,7 @@ class Client(ABC):
|
|
|
354
354
|
copy2(src, dst)
|
|
355
355
|
|
|
356
356
|
def open_object(
|
|
357
|
-
self, file: File, use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
|
|
357
|
+
self, file: "File", use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
|
|
358
358
|
) -> BinaryIO:
|
|
359
359
|
"""Open a file, including files in tar archives."""
|
|
360
360
|
if use_cache and (cache_path := self.cache.get_path(file)):
|
|
@@ -362,19 +362,19 @@ class Client(ABC):
|
|
|
362
362
|
assert not file.location
|
|
363
363
|
return FileWrapper(self.fs.open(self.get_full_path(file.path)), cb) # type: ignore[return-value]
|
|
364
364
|
|
|
365
|
-
def download(self, file: File, *, callback: Callback = DEFAULT_CALLBACK) -> None:
|
|
365
|
+
def download(self, file: "File", *, callback: Callback = DEFAULT_CALLBACK) -> None:
|
|
366
366
|
sync(get_loop(), functools.partial(self._download, file, callback=callback))
|
|
367
367
|
|
|
368
|
-
async def _download(self, file: File, *, callback: "Callback" = None) -> None:
|
|
368
|
+
async def _download(self, file: "File", *, callback: "Callback" = None) -> None:
|
|
369
369
|
if self.cache.contains(file):
|
|
370
370
|
# Already in cache, so there's nothing to do.
|
|
371
371
|
return
|
|
372
372
|
await self._put_in_cache(file, callback=callback)
|
|
373
373
|
|
|
374
|
-
def put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
|
|
374
|
+
def put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None:
|
|
375
375
|
sync(get_loop(), functools.partial(self._put_in_cache, file, callback=callback))
|
|
376
376
|
|
|
377
|
-
async def _put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
|
|
377
|
+
async def _put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None:
|
|
378
378
|
assert not file.location
|
|
379
379
|
if file.etag:
|
|
380
380
|
etag = await self.get_current_etag(file)
|
datachain/data_storage/schema.py
CHANGED
|
@@ -12,7 +12,7 @@ import sqlalchemy as sa
|
|
|
12
12
|
from sqlalchemy.sql import func as f
|
|
13
13
|
from sqlalchemy.sql.expression import false, null, true
|
|
14
14
|
|
|
15
|
-
from datachain.sql.functions import path
|
|
15
|
+
from datachain.sql.functions import path as pathfunc
|
|
16
16
|
from datachain.sql.types import Int, SQLType, UInt64
|
|
17
17
|
|
|
18
18
|
if TYPE_CHECKING:
|
|
@@ -130,7 +130,7 @@ class DirExpansion:
|
|
|
130
130
|
|
|
131
131
|
def query(self, q):
|
|
132
132
|
q = self.base_select(q).cte(recursive=True)
|
|
133
|
-
parent = path.parent(self.c(q, "path"))
|
|
133
|
+
parent = pathfunc.parent(self.c(q, "path"))
|
|
134
134
|
q = q.union_all(
|
|
135
135
|
sa.select(
|
|
136
136
|
sa.literal(-1).label("sys__id"),
|
datachain/data_storage/sqlite.py
CHANGED
|
@@ -122,7 +122,9 @@ class SQLiteDatabaseEngine(DatabaseEngine):
|
|
|
122
122
|
return cls(*cls._connect(db_file=db_file))
|
|
123
123
|
|
|
124
124
|
@staticmethod
|
|
125
|
-
def _connect(
|
|
125
|
+
def _connect(
|
|
126
|
+
db_file: Optional[str] = None,
|
|
127
|
+
) -> tuple["Engine", "MetaData", sqlite3.Connection, str]:
|
|
126
128
|
try:
|
|
127
129
|
if db_file == ":memory:":
|
|
128
130
|
# Enable multithreaded usage of the same in-memory db
|
|
@@ -130,9 +132,8 @@ class SQLiteDatabaseEngine(DatabaseEngine):
|
|
|
130
132
|
_get_in_memory_uri(), uri=True, detect_types=DETECT_TYPES
|
|
131
133
|
)
|
|
132
134
|
else:
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
)
|
|
135
|
+
db_file = db_file or DataChainDir.find().db
|
|
136
|
+
db = sqlite3.connect(db_file, detect_types=DETECT_TYPES)
|
|
136
137
|
create_user_defined_sql_functions(db)
|
|
137
138
|
engine = sqlalchemy.create_engine(
|
|
138
139
|
"sqlite+pysqlite:///", creator=lambda: db, future=True
|
datachain/data_storage/warehouse.py
CHANGED
|
@@ -224,28 +224,28 @@ class AbstractWarehouse(ABC, Serializable):
|
|
|
224
224
|
offset = 0
|
|
225
225
|
num_yielded = 0
|
|
226
226
|
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
if limit
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
227
|
+
# Ensure we're using a thread-local connection
|
|
228
|
+
with self.clone() as wh:
|
|
229
|
+
while True:
|
|
230
|
+
if limit is not None:
|
|
231
|
+
limit -= num_yielded
|
|
232
|
+
if limit == 0:
|
|
233
|
+
break
|
|
234
|
+
if limit < page_size:
|
|
235
|
+
paginated_query = paginated_query.limit(None).limit(limit)
|
|
236
|
+
|
|
237
237
|
# Cursor results are not thread-safe, so we convert them to a list
|
|
238
238
|
results = list(wh.dataset_rows_select(paginated_query.offset(offset)))
|
|
239
239
|
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
240
|
+
processed = False
|
|
241
|
+
for row in results:
|
|
242
|
+
processed = True
|
|
243
|
+
yield row
|
|
244
|
+
num_yielded += 1
|
|
245
245
|
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
246
|
+
if not processed:
|
|
247
|
+
break # no more results
|
|
248
|
+
offset += page_size
|
|
249
249
|
|
|
250
250
|
#
|
|
251
251
|
# Table Name Internal Functions
|
datachain/func/__init__.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from sqlalchemy import literal
|
|
2
|
+
|
|
3
|
+
from . import array, path, random, string
|
|
4
|
+
from .aggregate import (
|
|
5
|
+
any_value,
|
|
6
|
+
avg,
|
|
7
|
+
collect,
|
|
8
|
+
concat,
|
|
9
|
+
count,
|
|
10
|
+
dense_rank,
|
|
11
|
+
first,
|
|
12
|
+
max,
|
|
13
|
+
min,
|
|
14
|
+
rank,
|
|
15
|
+
row_number,
|
|
16
|
+
sum,
|
|
17
|
+
)
|
|
18
|
+
from .array import cosine_distance, euclidean_distance, length, sip_hash_64
|
|
19
|
+
from .conditional import greatest, least
|
|
20
|
+
from .random import rand
|
|
21
|
+
from .window import window
|
|
22
|
+
|
|
23
|
+
__all__ = [
|
|
24
|
+
"any_value",
|
|
25
|
+
"array",
|
|
26
|
+
"avg",
|
|
27
|
+
"collect",
|
|
28
|
+
"concat",
|
|
29
|
+
"cosine_distance",
|
|
30
|
+
"count",
|
|
31
|
+
"dense_rank",
|
|
32
|
+
"euclidean_distance",
|
|
33
|
+
"first",
|
|
34
|
+
"greatest",
|
|
35
|
+
"least",
|
|
36
|
+
"length",
|
|
37
|
+
"literal",
|
|
38
|
+
"max",
|
|
39
|
+
"min",
|
|
40
|
+
"path",
|
|
41
|
+
"rand",
|
|
42
|
+
"random",
|
|
43
|
+
"rank",
|
|
44
|
+
"row_number",
|
|
45
|
+
"sip_hash_64",
|
|
46
|
+
"string",
|
|
47
|
+
"sum",
|
|
48
|
+
"window",
|
|
49
|
+
]
|
datachain/{lib/func → func}/aggregate.py
CHANGED
|
@@ -2,7 +2,7 @@ from typing import Optional
|
|
|
2
2
|
|
|
3
3
|
from sqlalchemy import func as sa_func
|
|
4
4
|
|
|
5
|
-
from datachain.sql import
|
|
5
|
+
from datachain.sql.functions import aggregate
|
|
6
6
|
|
|
7
7
|
from .func import Func
|
|
8
8
|
|
|
@@ -31,7 +31,9 @@ def count(col: Optional[str] = None) -> Func:
|
|
|
31
31
|
Notes:
|
|
32
32
|
- Result column will always be of type int.
|
|
33
33
|
"""
|
|
34
|
-
return Func(
|
|
34
|
+
return Func(
|
|
35
|
+
"count", inner=sa_func.count, cols=[col] if col else None, result_type=int
|
|
36
|
+
)
|
|
35
37
|
|
|
36
38
|
|
|
37
39
|
def sum(col: str) -> Func:
|
|
@@ -59,7 +61,7 @@ def sum(col: str) -> Func:
|
|
|
59
61
|
- The `sum` function should be used on numeric columns.
|
|
60
62
|
- Result column type will be the same as the input column type.
|
|
61
63
|
"""
|
|
62
|
-
return Func("sum", inner=sa_func.sum,
|
|
64
|
+
return Func("sum", inner=sa_func.sum, cols=[col])
|
|
63
65
|
|
|
64
66
|
|
|
65
67
|
def avg(col: str) -> Func:
|
|
@@ -87,7 +89,7 @@ def avg(col: str) -> Func:
|
|
|
87
89
|
- The `avg` function should be used on numeric columns.
|
|
88
90
|
- Result column will always be of type float.
|
|
89
91
|
"""
|
|
90
|
-
return Func("avg", inner=
|
|
92
|
+
return Func("avg", inner=aggregate.avg, cols=[col], result_type=float)
|
|
91
93
|
|
|
92
94
|
|
|
93
95
|
def min(col: str) -> Func:
|
|
@@ -115,7 +117,7 @@ def min(col: str) -> Func:
|
|
|
115
117
|
- The `min` function can be used with numeric, date, and string columns.
|
|
116
118
|
- Result column will have the same type as the input column.
|
|
117
119
|
"""
|
|
118
|
-
return Func("min", inner=sa_func.min,
|
|
120
|
+
return Func("min", inner=sa_func.min, cols=[col])
|
|
119
121
|
|
|
120
122
|
|
|
121
123
|
def max(col: str) -> Func:
|
|
@@ -143,7 +145,7 @@ def max(col: str) -> Func:
|
|
|
143
145
|
- The `max` function can be used with numeric, date, and string columns.
|
|
144
146
|
- Result column will have the same type as the input column.
|
|
145
147
|
"""
|
|
146
|
-
return Func("max", inner=sa_func.max,
|
|
148
|
+
return Func("max", inner=sa_func.max, cols=[col])
|
|
147
149
|
|
|
148
150
|
|
|
149
151
|
def any_value(col: str) -> Func:
|
|
@@ -174,7 +176,7 @@ def any_value(col: str) -> Func:
|
|
|
174
176
|
- The result of `any_value` is non-deterministic,
|
|
175
177
|
meaning it may return different values for different executions.
|
|
176
178
|
"""
|
|
177
|
-
return Func("any_value", inner=
|
|
179
|
+
return Func("any_value", inner=aggregate.any_value, cols=[col])
|
|
178
180
|
|
|
179
181
|
|
|
180
182
|
def collect(col: str) -> Func:
|
|
@@ -203,7 +205,7 @@ def collect(col: str) -> Func:
|
|
|
203
205
|
- The `collect` function can be used with numeric and string columns.
|
|
204
206
|
- Result column will have an array type.
|
|
205
207
|
"""
|
|
206
|
-
return Func("collect", inner=
|
|
208
|
+
return Func("collect", inner=aggregate.collect, cols=[col], is_array=True)
|
|
207
209
|
|
|
208
210
|
|
|
209
211
|
def concat(col: str, separator="") -> Func:
|
|
@@ -236,9 +238,9 @@ def concat(col: str, separator="") -> Func:
|
|
|
236
238
|
"""
|
|
237
239
|
|
|
238
240
|
def inner(arg):
|
|
239
|
-
return
|
|
241
|
+
return aggregate.group_concat(arg, separator)
|
|
240
242
|
|
|
241
|
-
return Func("concat", inner=inner,
|
|
243
|
+
return Func("concat", inner=inner, cols=[col], result_type=str)
|
|
242
244
|
|
|
243
245
|
|
|
244
246
|
def row_number() -> Func:
|
|
@@ -350,4 +352,4 @@ def first(col: str) -> Func:
|
|
|
350
352
|
in the specified order.
|
|
351
353
|
- The result column will have the same type as the input column.
|
|
352
354
|
"""
|
|
353
|
-
return Func("first", inner=sa_func.first_value,
|
|
355
|
+
return Func("first", inner=sa_func.first_value, cols=[col], is_window=True)
|
datachain/func/array.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import Union
|
|
3
|
+
|
|
4
|
+
from datachain.sql.functions import array
|
|
5
|
+
|
|
6
|
+
from .func import Func
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def cosine_distance(*args: Union[str, Sequence]) -> Func:
|
|
10
|
+
"""
|
|
11
|
+
Computes the cosine distance between two vectors.
|
|
12
|
+
|
|
13
|
+
The cosine distance is derived from the cosine similarity, which measures the angle
|
|
14
|
+
between two vectors. This function returns the dissimilarity between the vectors,
|
|
15
|
+
where 0 indicates identical vectors and values closer to 1
|
|
16
|
+
indicate higher dissimilarity.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
args (str | Sequence): Two vectors to compute the cosine distance between.
|
|
20
|
+
If a string is provided, it is assumed to be the name of the column vector.
|
|
21
|
+
If a sequence is provided, it is assumed to be a vector of values.
|
|
22
|
+
|
|
23
|
+
Returns:
|
|
24
|
+
Func: A Func object that represents the cosine_distance function.
|
|
25
|
+
|
|
26
|
+
Example:
|
|
27
|
+
```py
|
|
28
|
+
target_embedding = [0.1, 0.2, 0.3]
|
|
29
|
+
dc.mutate(
|
|
30
|
+
cos_dist1=func.cosine_distance("embedding", target_embedding),
|
|
31
|
+
cos_dist2=func.cosine_distance(target_embedding, [0.4, 0.5, 0.6]),
|
|
32
|
+
)
|
|
33
|
+
```
|
|
34
|
+
|
|
35
|
+
Notes:
|
|
36
|
+
- Ensure both vectors have the same number of elements.
|
|
37
|
+
- Result column will always be of type float.
|
|
38
|
+
"""
|
|
39
|
+
cols, func_args = [], []
|
|
40
|
+
for arg in args:
|
|
41
|
+
if isinstance(arg, str):
|
|
42
|
+
cols.append(arg)
|
|
43
|
+
else:
|
|
44
|
+
func_args.append(list(arg))
|
|
45
|
+
|
|
46
|
+
if len(cols) + len(func_args) != 2:
|
|
47
|
+
raise ValueError("cosine_distance() requires exactly two arguments")
|
|
48
|
+
if not cols and len(func_args[0]) != len(func_args[1]):
|
|
49
|
+
raise ValueError("cosine_distance() requires vectors of the same length")
|
|
50
|
+
|
|
51
|
+
return Func(
|
|
52
|
+
"cosine_distance",
|
|
53
|
+
inner=array.cosine_distance,
|
|
54
|
+
cols=cols,
|
|
55
|
+
args=func_args,
|
|
56
|
+
result_type=float,
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def euclidean_distance(*args: Union[str, Sequence]) -> Func:
|
|
61
|
+
"""
|
|
62
|
+
Computes the Euclidean distance between two vectors.
|
|
63
|
+
|
|
64
|
+
The Euclidean distance is the straight-line distance between two points
|
|
65
|
+
in Euclidean space. This function returns the distance between the two vectors.
|
|
66
|
+
|
|
67
|
+
Args:
|
|
68
|
+
args (str | Sequence): Two vectors to compute the Euclidean distance between.
|
|
69
|
+
If a string is provided, it is assumed to be the name of the column vector.
|
|
70
|
+
If a sequence is provided, it is assumed to be a vector of values.
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
Func: A Func object that represents the euclidean_distance function.
|
|
74
|
+
|
|
75
|
+
Example:
|
|
76
|
+
```py
|
|
77
|
+
target_embedding = [0.1, 0.2, 0.3]
|
|
78
|
+
dc.mutate(
|
|
79
|
+
eu_dist1=func.euclidean_distance("embedding", target_embedding),
|
|
80
|
+
eu_dist2=func.euclidean_distance(target_embedding, [0.4, 0.5, 0.6]),
|
|
81
|
+
)
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
Notes:
|
|
85
|
+
- Ensure both vectors have the same number of elements.
|
|
86
|
+
- Result column will always be of type float.
|
|
87
|
+
"""
|
|
88
|
+
cols, func_args = [], []
|
|
89
|
+
for arg in args:
|
|
90
|
+
if isinstance(arg, str):
|
|
91
|
+
cols.append(arg)
|
|
92
|
+
else:
|
|
93
|
+
func_args.append(list(arg))
|
|
94
|
+
|
|
95
|
+
if len(cols) + len(func_args) != 2:
|
|
96
|
+
raise ValueError("euclidean_distance() requires exactly two arguments")
|
|
97
|
+
if not cols and len(func_args[0]) != len(func_args[1]):
|
|
98
|
+
raise ValueError("euclidean_distance() requires vectors of the same length")
|
|
99
|
+
|
|
100
|
+
return Func(
|
|
101
|
+
"euclidean_distance",
|
|
102
|
+
inner=array.euclidean_distance,
|
|
103
|
+
cols=cols,
|
|
104
|
+
args=func_args,
|
|
105
|
+
result_type=float,
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def length(arg: Union[str, Sequence, Func]) -> Func:
|
|
110
|
+
"""
|
|
111
|
+
Returns the length of the array.
|
|
112
|
+
|
|
113
|
+
Args:
|
|
114
|
+
arg (str | Sequence | Func): Array to compute the length of.
|
|
115
|
+
If a string is provided, it is assumed to be the name of the array column.
|
|
116
|
+
If a sequence is provided, it is assumed to be an array of values.
|
|
117
|
+
If a Func is provided, it is assumed to be a function returning an array.
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
Func: A Func object that represents the array length function.
|
|
121
|
+
|
|
122
|
+
Example:
|
|
123
|
+
```py
|
|
124
|
+
dc.mutate(
|
|
125
|
+
len1=func.array.length("signal.values"),
|
|
126
|
+
len2=func.array.length([1, 2, 3, 4, 5]),
|
|
127
|
+
)
|
|
128
|
+
```
|
|
129
|
+
|
|
130
|
+
Note:
|
|
131
|
+
- Result column will always be of type int.
|
|
132
|
+
"""
|
|
133
|
+
if isinstance(arg, (str, Func)):
|
|
134
|
+
cols = [arg]
|
|
135
|
+
args = None
|
|
136
|
+
else:
|
|
137
|
+
cols = None
|
|
138
|
+
args = [arg]
|
|
139
|
+
|
|
140
|
+
return Func("length", inner=array.length, cols=cols, args=args, result_type=int)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def sip_hash_64(arg: Union[str, Sequence]) -> Func:
|
|
144
|
+
"""
|
|
145
|
+
Computes the SipHash-64 hash of the array.
|
|
146
|
+
|
|
147
|
+
Args:
|
|
148
|
+
arg (str | Sequence): Array to compute the SipHash-64 hash of.
|
|
149
|
+
If a string is provided, it is assumed to be the name of the array column.
|
|
150
|
+
If a sequence is provided, it is assumed to be an array of values.
|
|
151
|
+
|
|
152
|
+
Returns:
|
|
153
|
+
Func: A Func object that represents the sip_hash_64 function.
|
|
154
|
+
|
|
155
|
+
Example:
|
|
156
|
+
```py
|
|
157
|
+
dc.mutate(
|
|
158
|
+
hash1=func.sip_hash_64("signal.values"),
|
|
159
|
+
hash2=func.sip_hash_64([1, 2, 3, 4, 5]),
|
|
160
|
+
)
|
|
161
|
+
```
|
|
162
|
+
|
|
163
|
+
Note:
|
|
164
|
+
- This function is only available for the ClickHouse warehouse.
|
|
165
|
+
- Result column will always be of type int.
|
|
166
|
+
"""
|
|
167
|
+
if isinstance(arg, str):
|
|
168
|
+
cols = [arg]
|
|
169
|
+
args = None
|
|
170
|
+
else:
|
|
171
|
+
cols = None
|
|
172
|
+
args = [arg]
|
|
173
|
+
|
|
174
|
+
return Func(
|
|
175
|
+
"sip_hash_64", inner=array.sip_hash_64, cols=cols, args=args, result_type=int
|
|
176
|
+
)
|
datachain/func/base.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from abc import ABCMeta, abstractmethod
|
|
2
|
+
from typing import TYPE_CHECKING, Optional
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
from sqlalchemy import TableClause
|
|
6
|
+
|
|
7
|
+
from datachain.lib.signal_schema import SignalSchema
|
|
8
|
+
from datachain.query.schema import Column
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Function:
|
|
12
|
+
__metaclass__ = ABCMeta
|
|
13
|
+
|
|
14
|
+
name: str
|
|
15
|
+
|
|
16
|
+
@abstractmethod
|
|
17
|
+
def get_column(
|
|
18
|
+
self,
|
|
19
|
+
signals_schema: Optional["SignalSchema"] = None,
|
|
20
|
+
label: Optional[str] = None,
|
|
21
|
+
table: Optional["TableClause"] = None,
|
|
22
|
+
) -> "Column":
|
|
23
|
+
pass
|
datachain/func/conditional.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
from datachain.sql.functions import conditional
|
|
4
|
+
|
|
5
|
+
from .func import ColT, Func
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def greatest(*args: Union[ColT, float]) -> Func:
|
|
9
|
+
"""
|
|
10
|
+
Returns the greatest (largest) value from the given input values.
|
|
11
|
+
|
|
12
|
+
Args:
|
|
13
|
+
args (ColT | str | int | float | Sequence): The values to compare.
|
|
14
|
+
If a string is provided, it is assumed to be the name of the column.
|
|
15
|
+
If a Func is provided, it is assumed to be a function returning a value.
|
|
16
|
+
If an int, float, or Sequence is provided, it is assumed to be a literal.
|
|
17
|
+
|
|
18
|
+
Returns:
|
|
19
|
+
Func: A Func object that represents the greatest function.
|
|
20
|
+
|
|
21
|
+
Example:
|
|
22
|
+
```py
|
|
23
|
+
dc.mutate(
|
|
24
|
+
greatest=func.greatest("signal.value", 0),
|
|
25
|
+
)
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
Note:
|
|
29
|
+
- Result column will always be of the same type as the input columns.
|
|
30
|
+
"""
|
|
31
|
+
cols, func_args = [], []
|
|
32
|
+
|
|
33
|
+
for arg in args:
|
|
34
|
+
if isinstance(arg, (str, Func)):
|
|
35
|
+
cols.append(arg)
|
|
36
|
+
else:
|
|
37
|
+
func_args.append(arg)
|
|
38
|
+
|
|
39
|
+
return Func(
|
|
40
|
+
"greatest",
|
|
41
|
+
inner=conditional.greatest,
|
|
42
|
+
cols=cols,
|
|
43
|
+
args=func_args,
|
|
44
|
+
result_type=int,
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def least(*args: Union[ColT, float]) -> Func:
|
|
49
|
+
"""
|
|
50
|
+
Returns the least (smallest) value from the given input values.
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
args (ColT | str | int | float | Sequence): The values to compare.
|
|
54
|
+
If a string is provided, it is assumed to be the name of the column.
|
|
55
|
+
If a Func is provided, it is assumed to be a function returning a value.
|
|
56
|
+
If an int, float, or Sequence is provided, it is assumed to be a literal.
|
|
57
|
+
|
|
58
|
+
Returns:
|
|
59
|
+
Func: A Func object that represents the least function.
|
|
60
|
+
|
|
61
|
+
Example:
|
|
62
|
+
```py
|
|
63
|
+
dc.mutate(
|
|
64
|
+
least=func.least("signal.value", 0),
|
|
65
|
+
)
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
Note:
|
|
69
|
+
- Result column will always be of the same type as the input columns.
|
|
70
|
+
"""
|
|
71
|
+
cols, func_args = [], []
|
|
72
|
+
|
|
73
|
+
for arg in args:
|
|
74
|
+
if isinstance(arg, (str, Func)):
|
|
75
|
+
cols.append(arg)
|
|
76
|
+
else:
|
|
77
|
+
func_args.append(arg)
|
|
78
|
+
|
|
79
|
+
return Func(
|
|
80
|
+
"least", inner=conditional.least, cols=cols, args=func_args, result_type=int
|
|
81
|
+
)
|