datachain 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic; see the package registry's advisory page for more details.
- datachain/__init__.py +0 -2
- datachain/catalog/catalog.py +12 -9
- datachain/cli.py +109 -9
- datachain/client/fsspec.py +9 -9
- datachain/data_storage/metastore.py +63 -11
- datachain/data_storage/schema.py +2 -2
- datachain/data_storage/sqlite.py +5 -4
- datachain/data_storage/warehouse.py +18 -18
- datachain/dataset.py +142 -14
- datachain/func/__init__.py +49 -0
- datachain/{lib/func → func}/aggregate.py +13 -11
- datachain/func/array.py +176 -0
- datachain/func/base.py +23 -0
- datachain/func/conditional.py +81 -0
- datachain/func/func.py +384 -0
- datachain/func/path.py +110 -0
- datachain/func/random.py +23 -0
- datachain/func/string.py +154 -0
- datachain/func/window.py +49 -0
- datachain/lib/arrow.py +24 -12
- datachain/lib/data_model.py +25 -9
- datachain/lib/dataset_info.py +9 -5
- datachain/lib/dc.py +94 -56
- datachain/lib/hf.py +1 -1
- datachain/lib/signal_schema.py +1 -1
- datachain/lib/utils.py +1 -0
- datachain/lib/webdataset_laion.py +5 -5
- datachain/model/bbox.py +2 -2
- datachain/model/pose.py +5 -5
- datachain/model/segment.py +2 -2
- datachain/nodes_fetcher.py +2 -2
- datachain/query/dataset.py +57 -34
- datachain/remote/studio.py +40 -8
- datachain/sql/__init__.py +0 -2
- datachain/sql/functions/__init__.py +0 -26
- datachain/sql/selectable.py +11 -5
- datachain/sql/sqlite/base.py +11 -2
- datachain/studio.py +29 -0
- {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/METADATA +2 -2
- {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/RECORD +44 -37
- datachain/lib/func/__init__.py +0 -32
- datachain/lib/func/func.py +0 -152
- {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/LICENSE +0 -0
- {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/WHEEL +0 -0
- {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/entry_points.txt +0 -0
- {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/top_level.txt +0 -0
datachain/dataset.py
CHANGED
|
@@ -15,7 +15,9 @@ from datachain.error import DatasetVersionNotFoundError
|
|
|
15
15
|
from datachain.sql.types import NAME_TYPES_MAPPING, SQLType
|
|
16
16
|
|
|
17
17
|
T = TypeVar("T", bound="DatasetRecord")
|
|
18
|
+
LT = TypeVar("LT", bound="DatasetListRecord")
|
|
18
19
|
V = TypeVar("V", bound="DatasetVersion")
|
|
20
|
+
LV = TypeVar("LV", bound="DatasetListVersion")
|
|
19
21
|
DD = TypeVar("DD", bound="DatasetDependency")
|
|
20
22
|
|
|
21
23
|
DATASET_PREFIX = "ds://"
|
|
@@ -264,6 +266,59 @@ class DatasetVersion:
|
|
|
264
266
|
return cls(**kwargs)
|
|
265
267
|
|
|
266
268
|
|
|
269
|
+
@dataclass
class DatasetListVersion:
    """A single dataset version entry, as stored in ``DatasetListRecord.versions``.

    Equality is the default dataclass field-by-field comparison. Hashing uses
    only ``dataset_id`` and ``version``, so instances can be put in sets.
    """

    id: int
    uuid: str
    dataset_id: int
    version: int
    status: int
    created_at: datetime
    finished_at: Optional[datetime]
    error_message: str
    error_stack: str
    num_objects: Optional[int]
    size: Optional[int]
    query_script: str = ""
    job_id: Optional[str] = None

    @classmethod
    def parse(
        cls: type[LV],
        id: int,
        uuid: str,
        dataset_id: int,
        version: int,
        status: int,
        created_at: datetime,
        finished_at: Optional[datetime],
        error_message: str,
        error_stack: str,
        num_objects: Optional[int],
        size: Optional[int],
        query_script: str = "",
        job_id: Optional[str] = None,
    ):
        """Alternate constructor taking one argument per field, in field order."""
        return cls(
            id=id,
            uuid=uuid,
            dataset_id=dataset_id,
            version=version,
            status=status,
            created_at=created_at,
            finished_at=finished_at,
            error_message=error_message,
            error_stack=error_stack,
            num_objects=num_objects,
            size=size,
            query_script=query_script,
            job_id=job_id,
        )

    def __hash__(self):
        # Key is "<dataset_id>_<version>"; other fields do not affect the hash.
        key = f"{self.dataset_id}_{self.version}"
        return hash(key)
|
|
320
|
+
|
|
321
|
+
|
|
267
322
|
@dataclass
|
|
268
323
|
class DatasetRecord:
|
|
269
324
|
id: int
|
|
@@ -447,20 +502,6 @@ class DatasetRecord:
|
|
|
447
502
|
identifier = self.identifier(version)
|
|
448
503
|
return f"{DATASET_PREFIX}{identifier}"
|
|
449
504
|
|
|
450
|
-
@property
|
|
451
|
-
def is_bucket_listing(self) -> bool:
|
|
452
|
-
"""
|
|
453
|
-
For bucket listing we implicitly create underlying dataset to hold data. This
|
|
454
|
-
method is checking if this is one of those datasets.
|
|
455
|
-
"""
|
|
456
|
-
from datachain.client import Client
|
|
457
|
-
|
|
458
|
-
# TODO refactor and maybe remove method in
|
|
459
|
-
# https://github.com/iterative/datachain/issues/318
|
|
460
|
-
return Client.is_data_source_uri(self.name) or self.name.startswith(
|
|
461
|
-
LISTING_PREFIX
|
|
462
|
-
)
|
|
463
|
-
|
|
464
505
|
@property
|
|
465
506
|
def versions_values(self) -> list[int]:
|
|
466
507
|
"""
|
|
@@ -499,5 +540,92 @@ class DatasetRecord:
|
|
|
499
540
|
return cls(**kwargs, versions=versions)
|
|
500
541
|
|
|
501
542
|
|
|
543
|
+
@dataclass
class DatasetListRecord:
    """Dataset entry carrying a list of ``DatasetListVersion`` objects.

    ``parse`` builds a record holding exactly one version. ``merge_versions``
    folds in the versions of another record with the same ``id``.
    """

    id: int
    name: str
    description: Optional[str]
    labels: list[str]
    versions: list[DatasetListVersion]
    created_at: Optional[datetime] = None

    @classmethod
    def parse(  # noqa: PLR0913
        cls: type[LT],
        id: int,
        name: str,
        description: Optional[str],
        labels: str,
        created_at: datetime,
        version_id: int,
        version_uuid: str,
        version_dataset_id: int,
        version: int,
        version_status: int,
        version_created_at: datetime,
        version_finished_at: Optional[datetime],
        version_error_message: str,
        version_error_stack: str,
        version_num_objects: Optional[int],
        version_size: Optional[int],
        version_query_script: Optional[str],
        version_job_id: Optional[str] = None,
    ) -> LT:
        """Build a single-version record from flat dataset and version values.

        Args:
            labels: JSON-encoded list of label strings. A falsy value (for
                example ``None`` or ``""``) gives an empty label list.
            version_*: Forwarded positionally, in order, to
                ``DatasetListVersion.parse``.

        Returns:
            An instance of ``cls`` whose ``versions`` list has one entry.
        """
        labels_lst: list[str] = json.loads(labels) if labels else []

        dataset_version = DatasetListVersion.parse(
            version_id,
            version_uuid,
            version_dataset_id,
            version,
            version_status,
            version_created_at,
            version_finished_at,
            version_error_message,
            version_error_stack,
            version_num_objects,
            version_size,
            version_query_script,  # type: ignore[arg-type]
            version_job_id,
        )

        return cls(
            id,
            name,
            description,
            labels_lst,
            [dataset_version],
            created_at,
        )

    def merge_versions(self, other: "DatasetListRecord") -> "DatasetListRecord":
        """Merge versions from another record of the same dataset into this one.

        Duplicates are removed using ``DatasetListVersion`` equality/hash. The
        result is then sorted by version number. Mutates and returns ``self``.

        Raises:
            RuntimeError: If ``other.id`` differs from ``self.id``.
        """
        if other.id != self.id:
            raise RuntimeError("Cannot merge versions of datasets with different ids")
        if not other.versions:
            # nothing to merge
            return self

        # dict.fromkeys removes duplicates exactly like set() but keeps the
        # first-seen order. The stable sort below then gives a deterministic
        # order even when two distinct entries share a version number. A set
        # would make that order hash-dependent.
        merged = list(dict.fromkeys((self.versions or []) + other.versions))
        merged.sort(key=lambda v: v.version)
        self.versions = merged
        return self

    @property
    def is_bucket_listing(self) -> bool:
        """
        For bucket listing we implicitly create underlying dataset to hold data. This
        method is checking if this is one of those datasets.
        """
        from datachain.client import Client

        # TODO refactor and maybe remove method in
        # https://github.com/iterative/datachain/issues/318
        return Client.is_data_source_uri(self.name) or self.name.startswith(
            LISTING_PREFIX
        )
|
|
628
|
+
|
|
629
|
+
|
|
502
630
|
class RowDict(dict):
|
|
503
631
|
pass
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from sqlalchemy import literal
|
|
2
|
+
|
|
3
|
+
from . import array, path, random, string
|
|
4
|
+
from .aggregate import (
|
|
5
|
+
any_value,
|
|
6
|
+
avg,
|
|
7
|
+
collect,
|
|
8
|
+
concat,
|
|
9
|
+
count,
|
|
10
|
+
dense_rank,
|
|
11
|
+
first,
|
|
12
|
+
max,
|
|
13
|
+
min,
|
|
14
|
+
rank,
|
|
15
|
+
row_number,
|
|
16
|
+
sum,
|
|
17
|
+
)
|
|
18
|
+
from .array import cosine_distance, euclidean_distance, length, sip_hash_64
|
|
19
|
+
from .conditional import greatest, least
|
|
20
|
+
from .random import rand
|
|
21
|
+
from .window import window
|
|
22
|
+
|
|
23
|
+
# Public names of this package, kept in alphabetical order. They are:
# - the aggregate/window helpers;
# - the array and conditional functions;
# - ``rand`` and ``window``;
# - the ``array``/``path``/``random``/``string`` submodules;
# - SQLAlchemy's ``literal``, re-exported.
__all__ = [
    "any_value", "array", "avg", "collect", "concat",
    "cosine_distance", "count", "dense_rank", "euclidean_distance", "first",
    "greatest", "least", "length", "literal", "max",
    "min", "path", "rand", "random", "rank",
    "row_number", "sip_hash_64", "string", "sum", "window",
]
|
|
@@ -2,7 +2,7 @@ from typing import Optional
|
|
|
2
2
|
|
|
3
3
|
from sqlalchemy import func as sa_func
|
|
4
4
|
|
|
5
|
-
from datachain.sql import
|
|
5
|
+
from datachain.sql.functions import aggregate
|
|
6
6
|
|
|
7
7
|
from .func import Func
|
|
8
8
|
|
|
@@ -31,7 +31,9 @@ def count(col: Optional[str] = None) -> Func:
|
|
|
31
31
|
Notes:
|
|
32
32
|
- Result column will always be of type int.
|
|
33
33
|
"""
|
|
34
|
-
return Func(
|
|
34
|
+
return Func(
|
|
35
|
+
"count", inner=sa_func.count, cols=[col] if col else None, result_type=int
|
|
36
|
+
)
|
|
35
37
|
|
|
36
38
|
|
|
37
39
|
def sum(col: str) -> Func:
|
|
@@ -59,7 +61,7 @@ def sum(col: str) -> Func:
|
|
|
59
61
|
- The `sum` function should be used on numeric columns.
|
|
60
62
|
- Result column type will be the same as the input column type.
|
|
61
63
|
"""
|
|
62
|
-
return Func("sum", inner=sa_func.sum,
|
|
64
|
+
return Func("sum", inner=sa_func.sum, cols=[col])
|
|
63
65
|
|
|
64
66
|
|
|
65
67
|
def avg(col: str) -> Func:
|
|
@@ -87,7 +89,7 @@ def avg(col: str) -> Func:
|
|
|
87
89
|
- The `avg` function should be used on numeric columns.
|
|
88
90
|
- Result column will always be of type float.
|
|
89
91
|
"""
|
|
90
|
-
return Func("avg", inner=
|
|
92
|
+
return Func("avg", inner=aggregate.avg, cols=[col], result_type=float)
|
|
91
93
|
|
|
92
94
|
|
|
93
95
|
def min(col: str) -> Func:
|
|
@@ -115,7 +117,7 @@ def min(col: str) -> Func:
|
|
|
115
117
|
- The `min` function can be used with numeric, date, and string columns.
|
|
116
118
|
- Result column will have the same type as the input column.
|
|
117
119
|
"""
|
|
118
|
-
return Func("min", inner=sa_func.min,
|
|
120
|
+
return Func("min", inner=sa_func.min, cols=[col])
|
|
119
121
|
|
|
120
122
|
|
|
121
123
|
def max(col: str) -> Func:
|
|
@@ -143,7 +145,7 @@ def max(col: str) -> Func:
|
|
|
143
145
|
- The `max` function can be used with numeric, date, and string columns.
|
|
144
146
|
- Result column will have the same type as the input column.
|
|
145
147
|
"""
|
|
146
|
-
return Func("max", inner=sa_func.max,
|
|
148
|
+
return Func("max", inner=sa_func.max, cols=[col])
|
|
147
149
|
|
|
148
150
|
|
|
149
151
|
def any_value(col: str) -> Func:
|
|
@@ -174,7 +176,7 @@ def any_value(col: str) -> Func:
|
|
|
174
176
|
- The result of `any_value` is non-deterministic,
|
|
175
177
|
meaning it may return different values for different executions.
|
|
176
178
|
"""
|
|
177
|
-
return Func("any_value", inner=
|
|
179
|
+
return Func("any_value", inner=aggregate.any_value, cols=[col])
|
|
178
180
|
|
|
179
181
|
|
|
180
182
|
def collect(col: str) -> Func:
|
|
@@ -203,7 +205,7 @@ def collect(col: str) -> Func:
|
|
|
203
205
|
- The `collect` function can be used with numeric and string columns.
|
|
204
206
|
- Result column will have an array type.
|
|
205
207
|
"""
|
|
206
|
-
return Func("collect", inner=
|
|
208
|
+
return Func("collect", inner=aggregate.collect, cols=[col], is_array=True)
|
|
207
209
|
|
|
208
210
|
|
|
209
211
|
def concat(col: str, separator="") -> Func:
|
|
@@ -236,9 +238,9 @@ def concat(col: str, separator="") -> Func:
|
|
|
236
238
|
"""
|
|
237
239
|
|
|
238
240
|
def inner(arg):
|
|
239
|
-
return
|
|
241
|
+
return aggregate.group_concat(arg, separator)
|
|
240
242
|
|
|
241
|
-
return Func("concat", inner=inner,
|
|
243
|
+
return Func("concat", inner=inner, cols=[col], result_type=str)
|
|
242
244
|
|
|
243
245
|
|
|
244
246
|
def row_number() -> Func:
|
|
@@ -350,4 +352,4 @@ def first(col: str) -> Func:
|
|
|
350
352
|
in the specified order.
|
|
351
353
|
- The result column will have the same type as the input column.
|
|
352
354
|
"""
|
|
353
|
-
return Func("first", inner=sa_func.first_value,
|
|
355
|
+
return Func("first", inner=sa_func.first_value, cols=[col], is_window=True)
|
datachain/func/array.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import Union
|
|
3
|
+
|
|
4
|
+
from datachain.sql.functions import array
|
|
5
|
+
|
|
6
|
+
from .func import Func
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def cosine_distance(*args: Union[str, Sequence]) -> Func:
    """
    Computes the cosine distance between two vectors.

    The cosine distance is derived from the cosine similarity, which measures the
    angle between two vectors. A result of 0 means the vectors point in the same
    direction; they need not be equal (for example, [1, 2] and [2, 4]). Larger
    values mean greater dissimilarity.

    Under the usual definition (1 - cosine similarity) the result ranges from 0
    to 2: 1 for orthogonal vectors and 2 for opposite ones. The exact formula
    comes from the backend implementation of ``array.cosine_distance``.

    Args:
        args (str | Sequence): Two vectors to compute the cosine distance between.
            If a string is provided, it is assumed to be the name of the column vector.
            If a sequence is provided, it is assumed to be a vector of values.

    Returns:
        Func: A Func object that represents the cosine_distance function.

    Raises:
        ValueError: If not exactly two arguments are given, or if both arguments
            are literal vectors of different lengths.

    Example:
        ```py
        target_embedding = [0.1, 0.2, 0.3]
        dc.mutate(
            cos_dist1=func.cosine_distance("embedding", target_embedding),
            cos_dist2=func.cosine_distance(target_embedding, [0.4, 0.5, 0.6]),
        )
        ```

    Notes:
        - Ensure both vectors have the same number of elements. This is only
          checked here when both arguments are literal vectors.
        - Result column will always be of type float.
    """
    cols, func_args = [], []
    for arg in args:
        if isinstance(arg, str):
            cols.append(arg)  # column name
        else:
            func_args.append(list(arg))  # literal vector, materialized as a list

    if len(cols) + len(func_args) != 2:
        raise ValueError("cosine_distance() requires exactly two arguments")
    # Lengths can only be compared up front when both vectors are literals.
    if not cols and len(func_args[0]) != len(func_args[1]):
        raise ValueError("cosine_distance() requires vectors of the same length")

    return Func(
        "cosine_distance",
        inner=array.cosine_distance,
        cols=cols,
        args=func_args,
        result_type=float,
    )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def euclidean_distance(*args: Union[str, Sequence]) -> Func:
    """
    Computes the Euclidean distance between two vectors.

    The Euclidean distance is the length of the straight line joining two points
    in Euclidean space. This function returns that distance for the two vectors.

    Args:
        args (str | Sequence): Exactly two vectors.
            A string is treated as the name of a vector column.
            Any other sequence is treated as a literal vector of values.

    Returns:
        Func: A Func object that represents the euclidean_distance function.

    Raises:
        ValueError: If not exactly two arguments are given, or if both arguments
            are literal vectors of different lengths.

    Example:
        ```py
        target_embedding = [0.1, 0.2, 0.3]
        dc.mutate(
            eu_dist1=func.euclidean_distance("embedding", target_embedding),
            eu_dist2=func.euclidean_distance(target_embedding, [0.4, 0.5, 0.6]),
        )
        ```

    Notes:
        - Ensure both vectors have the same number of elements.
        - Result column will always be of type float.
    """
    # Column names pass through unchanged; literal vectors are copied into lists.
    cols = [arg for arg in args if isinstance(arg, str)]
    func_args = [list(arg) for arg in args if not isinstance(arg, str)]

    if len(args) != 2:
        raise ValueError("euclidean_distance() requires exactly two arguments")

    literals_only = not cols
    if literals_only and len(func_args[0]) != len(func_args[1]):
        raise ValueError("euclidean_distance() requires vectors of the same length")

    return Func(
        "euclidean_distance",
        inner=array.euclidean_distance,
        cols=cols,
        args=func_args,
        result_type=float,
    )
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def length(arg: Union[str, Sequence, Func]) -> Func:
    """
    Returns the length of the array.

    Args:
        arg (str | Sequence | Func): The array whose length is computed.
            A string names an array column.
            A Func is another function whose result is an array.
            Any other sequence is passed as a literal array of values.

    Returns:
        Func: A Func object that represents the array length function.

    Example:
        ```py
        dc.mutate(
            len1=func.array.length("signal.values"),
            len2=func.array.length([1, 2, 3, 4, 5]),
        )
        ```

    Note:
        - Result column will always be of type int.
    """
    # Column names and Funcs go to `cols`; literal sequences go to `args`.
    is_column = isinstance(arg, (str, Func))
    return Func(
        "length",
        inner=array.length,
        cols=[arg] if is_column else None,
        args=None if is_column else [arg],
        result_type=int,
    )
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def sip_hash_64(arg: Union[str, Sequence]) -> Func:
    """
    Computes the SipHash-64 hash of the array.

    Args:
        arg (str | Sequence): The array to hash.
            A string names an array column.
            Any other sequence is passed as a literal array of values.

    Returns:
        Func: A Func object that represents the sip_hash_64 function.

    Example:
        ```py
        dc.mutate(
            hash1=func.sip_hash_64("signal.values"),
            hash2=func.sip_hash_64([1, 2, 3, 4, 5]),
        )
        ```

    Note:
        - This function is only available for the ClickHouse warehouse.
        - Result column will always be of type int.
    """
    if not isinstance(arg, str):
        # Literal array value.
        return Func(
            "sip_hash_64",
            inner=array.sip_hash_64,
            cols=None,
            args=[arg],
            result_type=int,
        )
    # Column reference.
    return Func(
        "sip_hash_64",
        inner=array.sip_hash_64,
        cols=[arg],
        args=None,
        result_type=int,
    )
|
datachain/func/base.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from abc import ABCMeta, abstractmethod
|
|
2
|
+
from typing import TYPE_CHECKING, Optional
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
from sqlalchemy import TableClause
|
|
6
|
+
|
|
7
|
+
from datachain.lib.signal_schema import SignalSchema
|
|
8
|
+
from datachain.query.schema import Column
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Function(metaclass=ABCMeta):
    """Abstract base for function objects that produce a ``Column``.

    Subclasses set ``name`` and must implement ``get_column``.

    The metaclass is passed with the Python 3 ``metaclass=`` keyword. A
    ``__metaclass__`` class attribute is ignored in Python 3, which left
    ``@abstractmethod`` unenforced.
    """

    name: str

    @abstractmethod
    def get_column(
        self,
        signals_schema: Optional["SignalSchema"] = None,
        label: Optional[str] = None,
        table: Optional["TableClause"] = None,
    ) -> "Column":
        """Return the ``Column`` representing this function.

        Args:
            signals_schema: Optional signal schema. Presumably used to resolve
                column types; confirm in the concrete implementation.
            label: Optional label, presumably the resulting column's name.
            table: Optional SQLAlchemy ``TableClause`` to bind columns to.

        Returns:
            The resulting ``Column``.
        """
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
from datachain.sql.functions import conditional
|
|
4
|
+
|
|
5
|
+
from .func import ColT, Func
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _split_args(
    args: tuple[Union[ColT, float], ...],
) -> tuple[list[ColT], list[float]]:
    """Split arguments into column references (names or Funcs) and literals."""
    cols: list[ColT] = []
    func_args: list[float] = []
    for arg in args:
        if isinstance(arg, (str, Func)):
            cols.append(arg)
        else:
            func_args.append(arg)
    return cols, func_args


def _conditional_func(name: str, inner, args: tuple[Union[ColT, float], ...]) -> Func:
    """Build a greatest/least Func with a result type matching its inputs.

    If every column argument is a column name, no ``result_type`` is passed.
    That is the same call shape the aggregate ``min``/``max`` helpers use to
    get "same type as the input column". Otherwise (no columns, or Func
    columns) the type comes from the literals: float if any literal is a
    float, else int. The old code always declared int, which mislabelled
    float inputs.
    """
    cols, func_args = _split_args(args)

    if cols and all(isinstance(col, str) for col in cols):
        return Func(name, inner=inner, cols=cols, args=func_args)

    result_type = float if any(isinstance(a, float) for a in func_args) else int
    return Func(
        name,
        inner=inner,
        cols=cols,
        args=func_args,
        result_type=result_type,
    )


def greatest(*args: Union[ColT, float]) -> Func:
    """
    Returns the greatest (largest) value from the given input values.

    Args:
        args (ColT | int | float): The values to compare.
            If a string is provided, it is assumed to be the name of the column.
            If a Func is provided, it is assumed to be a function returning a value.
            If an int or float is provided, it is assumed to be a literal.

    Returns:
        Func: A Func object that represents the greatest function.

    Example:
        ```py
        dc.mutate(
            greatest=func.greatest("signal.value", 0),
        )
        ```

    Note:
        - When every non-literal argument is a column name, the result type
          is inferred from the column, as for ``min``/``max``. Otherwise it is
          float if any literal is a float, and int if not.
    """
    return _conditional_func("greatest", conditional.greatest, args)


def least(*args: Union[ColT, float]) -> Func:
    """
    Returns the least (smallest) value from the given input values.

    Args:
        args (ColT | int | float): The values to compare.
            If a string is provided, it is assumed to be the name of the column.
            If a Func is provided, it is assumed to be a function returning a value.
            If an int or float is provided, it is assumed to be a literal.

    Returns:
        Func: A Func object that represents the least function.

    Example:
        ```py
        dc.mutate(
            least=func.least("signal.value", 0),
        )
        ```

    Note:
        - When every non-literal argument is a column name, the result type
          is inferred from the column, as for ``min``/``max``. Otherwise it is
          float if any literal is a float, and int if not.
    """
    return _conditional_func("least", conditional.least, args)
|