datachain 0.7.0__py3-none-any.whl → 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic.

Files changed (52)
  1. datachain/__init__.py +0 -3
  2. datachain/catalog/catalog.py +8 -6
  3. datachain/cli.py +1 -1
  4. datachain/client/fsspec.py +9 -9
  5. datachain/data_storage/schema.py +2 -2
  6. datachain/data_storage/sqlite.py +5 -4
  7. datachain/data_storage/warehouse.py +18 -18
  8. datachain/func/__init__.py +49 -0
  9. datachain/{lib/func → func}/aggregate.py +13 -11
  10. datachain/func/array.py +176 -0
  11. datachain/func/base.py +23 -0
  12. datachain/func/conditional.py +81 -0
  13. datachain/func/func.py +384 -0
  14. datachain/func/path.py +110 -0
  15. datachain/func/random.py +23 -0
  16. datachain/func/string.py +154 -0
  17. datachain/func/window.py +49 -0
  18. datachain/lib/arrow.py +24 -12
  19. datachain/lib/data_model.py +25 -9
  20. datachain/lib/dataset_info.py +2 -2
  21. datachain/lib/dc.py +94 -56
  22. datachain/lib/hf.py +1 -1
  23. datachain/lib/signal_schema.py +1 -1
  24. datachain/lib/utils.py +1 -0
  25. datachain/lib/webdataset_laion.py +5 -5
  26. datachain/model/__init__.py +6 -0
  27. datachain/model/bbox.py +102 -0
  28. datachain/model/pose.py +88 -0
  29. datachain/model/segment.py +47 -0
  30. datachain/model/ultralytics/__init__.py +27 -0
  31. datachain/model/ultralytics/bbox.py +147 -0
  32. datachain/model/ultralytics/pose.py +113 -0
  33. datachain/model/ultralytics/segment.py +91 -0
  34. datachain/nodes_fetcher.py +2 -2
  35. datachain/query/dataset.py +57 -34
  36. datachain/sql/__init__.py +0 -2
  37. datachain/sql/functions/__init__.py +0 -26
  38. datachain/sql/selectable.py +11 -5
  39. datachain/sql/sqlite/base.py +11 -2
  40. datachain/toolkit/split.py +6 -2
  41. {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/METADATA +72 -71
  42. {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/RECORD +46 -35
  43. {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/WHEEL +1 -1
  44. datachain/lib/func/__init__.py +0 -32
  45. datachain/lib/func/func.py +0 -152
  46. datachain/lib/models/__init__.py +0 -5
  47. datachain/lib/models/bbox.py +0 -45
  48. datachain/lib/models/pose.py +0 -37
  49. datachain/lib/models/yolo.py +0 -39
  50. {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/LICENSE +0 -0
  51. {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/entry_points.txt +0 -0
  52. {datachain-0.7.0.dist-info → datachain-0.7.2.dist-info}/top_level.txt +0 -0
datachain/__init__.py CHANGED
@@ -1,4 +1,3 @@
- from datachain.lib import func, models
  from datachain.lib.data_model import DataModel, DataType, is_chain_type
  from datachain.lib.dc import C, Column, DataChain, Sys
  from datachain.lib.file import (
@@ -35,9 +34,7 @@ __all__ = [
  "Sys",
  "TarVFile",
  "TextFile",
- "func",
  "is_chain_type",
  "metrics",
- "models",
  "param",
  ]
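
Together with the new `datachain/func/` and `datachain/model/` packages listed above, this change removes the `func` and `models` re-exports from the package root. A hedged sketch of the import style the new layout implies, using only names exported by the added `datachain/func/__init__.py` further down:

```py
# 0.7.0 style (removed): the helpers were re-exported from the package root.
# from datachain import func, models

# 0.7.2 style (assumed from the renamed packages): import the subpackage directly.
from datachain.func import avg, count, greatest, literal
```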
datachain/catalog/catalog.py CHANGED
@@ -54,7 +54,6 @@ from datachain.error import (
  QueryScriptCancelError,
  QueryScriptRunError,
  )
- from datachain.listing import Listing
  from datachain.node import DirType, Node, NodeWithPath
  from datachain.nodes_thread_pool import NodesThreadPool
  from datachain.remote.studio import StudioClient
@@ -76,6 +75,7 @@ if TYPE_CHECKING:
  from datachain.dataset import DatasetVersion
  from datachain.job import Job
  from datachain.lib.file import File
+ from datachain.listing import Listing

  logger = logging.getLogger("datachain")

@@ -236,7 +236,7 @@ class DatasetRowsFetcher(NodesThreadPool):
  class NodeGroup:
  """Class for a group of nodes from the same source"""

- listing: Listing
+ listing: "Listing"
  sources: list[DataSource]

  # The source path within the bucket
@@ -591,8 +591,9 @@ class Catalog:
  client_config=None,
  object_name="file",
  skip_indexing=False,
- ) -> tuple[Listing, str]:
+ ) -> tuple["Listing", str]:
  from datachain.lib.dc import DataChain
+ from datachain.listing import Listing

  DataChain.from_storage(
  source, session=self.session, update=update, object_name=object_name
@@ -660,7 +661,8 @@ class Catalog:
  no_glob: bool = False,
  client_config=None,
  ) -> list[NodeGroup]:
- from datachain.query import DatasetQuery
+ from datachain.listing import Listing
+ from datachain.query.dataset import DatasetQuery

  def _row_to_node(d: dict[str, Any]) -> Node:
  del d["file__source"]
@@ -876,7 +878,7 @@ class Catalog:
  def update_dataset_version_with_warehouse_info(
  self, dataset: DatasetRecord, version: int, rows_dropped=False, **kwargs
  ) -> None:
- from datachain.query import DatasetQuery
+ from datachain.query.dataset import DatasetQuery

  dataset_version = dataset.get_version(version)

@@ -1177,7 +1179,7 @@ class Catalog:
  def ls_dataset_rows(
  self, name: str, version: int, offset=None, limit=None
  ) -> list[dict]:
- from datachain.query import DatasetQuery
+ from datachain.query.dataset import DatasetQuery

  dataset = self.get_dataset(name)
datachain/cli.py CHANGED
@@ -957,7 +957,7 @@ def show(
  schema: bool = False,
  ) -> None:
  from datachain.lib.dc import DataChain
- from datachain.query import DatasetQuery
+ from datachain.query.dataset import DatasetQuery
  from datachain.utils import show_records

  dataset = catalog.get_dataset(name)
datachain/client/fsspec.py CHANGED
@@ -28,7 +28,6 @@ from tqdm import tqdm
  from datachain.cache import DataChainCache
  from datachain.client.fileslice import FileWrapper
  from datachain.error import ClientError as DataChainClientError
- from datachain.lib.file import File
  from datachain.nodes_fetcher import NodesFetcher
  from datachain.nodes_thread_pool import NodeChunk

@@ -36,6 +35,7 @@ if TYPE_CHECKING:
  from fsspec.spec import AbstractFileSystem

  from datachain.dataset import StorageURI
+ from datachain.lib.file import File


  logger = logging.getLogger("datachain")
@@ -45,7 +45,7 @@ DELIMITER = "/" # Path delimiter.

  DATA_SOURCE_URI_PATTERN = re.compile(r"^[\w]+:\/\/.*$")

- ResultQueue = asyncio.Queue[Optional[Sequence[File]]]
+ ResultQueue = asyncio.Queue[Optional[Sequence["File"]]]


  def _is_win_local_path(uri: str) -> bool:
@@ -212,7 +212,7 @@ class Client(ABC):

  async def scandir(
  self, start_prefix: str, method: str = "default"
- ) -> AsyncIterator[Sequence[File]]:
+ ) -> AsyncIterator[Sequence["File"]]:
  try:
  impl = getattr(self, f"_fetch_{method}")
  except AttributeError:
@@ -317,7 +317,7 @@ class Client(ABC):
  return f"{self.PREFIX}{self.name}/{rel_path}"

  @abstractmethod
- def info_to_file(self, v: dict[str, Any], parent: str) -> File: ...
+ def info_to_file(self, v: dict[str, Any], parent: str) -> "File": ...

  def fetch_nodes(
  self,
@@ -354,7 +354,7 @@ class Client(ABC):
  copy2(src, dst)

  def open_object(
- self, file: File, use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
+ self, file: "File", use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
  ) -> BinaryIO:
  """Open a file, including files in tar archives."""
  if use_cache and (cache_path := self.cache.get_path(file)):
@@ -362,19 +362,19 @@ class Client(ABC):
  assert not file.location
  return FileWrapper(self.fs.open(self.get_full_path(file.path)), cb) # type: ignore[return-value]

- def download(self, file: File, *, callback: Callback = DEFAULT_CALLBACK) -> None:
+ def download(self, file: "File", *, callback: Callback = DEFAULT_CALLBACK) -> None:
  sync(get_loop(), functools.partial(self._download, file, callback=callback))

- async def _download(self, file: File, *, callback: "Callback" = None) -> None:
+ async def _download(self, file: "File", *, callback: "Callback" = None) -> None:
  if self.cache.contains(file):
  # Already in cache, so there's nothing to do.
  return
  await self._put_in_cache(file, callback=callback)

- def put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
+ def put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None:
  sync(get_loop(), functools.partial(self._put_in_cache, file, callback=callback))

- async def _put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
+ async def _put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None:
  assert not file.location
  if file.etag:
  etag = await self.get_current_etag(file)
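
The client changes apply the same idea at the annotation level: `File` is now imported only under `TYPE_CHECKING`, so every annotation that mentions it becomes a string forward reference, including the `ResultQueue` alias that is evaluated at import time. A small sketch of why the quoting matters, with a hypothetical `File` class:

```py
import asyncio
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from myapp.file import File  # hypothetical; never imported at runtime

# Without the quotes this line would raise NameError at import time,
# because File only exists for the type checker.
ResultQueue = asyncio.Queue[Optional[Sequence["File"]]]


def download(file: "File") -> None:  # quoted for the same reason
    ...
```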
datachain/data_storage/schema.py CHANGED
@@ -12,7 +12,7 @@ import sqlalchemy as sa
  from sqlalchemy.sql import func as f
  from sqlalchemy.sql.expression import false, null, true

- from datachain.sql.functions import path
+ from datachain.sql.functions import path as pathfunc
  from datachain.sql.types import Int, SQLType, UInt64

  if TYPE_CHECKING:
@@ -130,7 +130,7 @@ class DirExpansion:

  def query(self, q):
  q = self.base_select(q).cte(recursive=True)
- parent = path.parent(self.c(q, "path"))
+ parent = pathfunc.parent(self.c(q, "path"))
  q = q.union_all(
  sa.select(
  sa.literal(-1).label("sys__id"),
datachain/data_storage/sqlite.py CHANGED
@@ -122,7 +122,9 @@ class SQLiteDatabaseEngine(DatabaseEngine):
  return cls(*cls._connect(db_file=db_file))

  @staticmethod
- def _connect(db_file: Optional[str] = None):
+ def _connect(
+ db_file: Optional[str] = None,
+ ) -> tuple["Engine", "MetaData", sqlite3.Connection, str]:
  try:
  if db_file == ":memory:":
  # Enable multithreaded usage of the same in-memory db
@@ -130,9 +132,8 @@ class SQLiteDatabaseEngine(DatabaseEngine):
  _get_in_memory_uri(), uri=True, detect_types=DETECT_TYPES
  )
  else:
- db = sqlite3.connect(
- db_file or DataChainDir.find().db, detect_types=DETECT_TYPES
- )
+ db_file = db_file or DataChainDir.find().db
+ db = sqlite3.connect(db_file, detect_types=DETECT_TYPES)
  create_user_defined_sql_functions(db)
  engine = sqlalchemy.create_engine(
  "sqlite+pysqlite:///", creator=lambda: db, future=True
datachain/data_storage/warehouse.py CHANGED
@@ -224,28 +224,28 @@ class AbstractWarehouse(ABC, Serializable):
  offset = 0
  num_yielded = 0

- while True:
- if limit is not None:
- limit -= num_yielded
- if limit == 0:
- break
- if limit < page_size:
- paginated_query = paginated_query.limit(None).limit(limit)
-
- # Ensure we're using a thread-local connection
- with self.clone() as wh:
+ # Ensure we're using a thread-local connection
+ with self.clone() as wh:
+ while True:
+ if limit is not None:
+ limit -= num_yielded
+ if limit == 0:
+ break
+ if limit < page_size:
+ paginated_query = paginated_query.limit(None).limit(limit)
+
  # Cursor results are not thread-safe, so we convert them to a list
  results = list(wh.dataset_rows_select(paginated_query.offset(offset)))

- processed = False
- for row in results:
- processed = True
- yield row
- num_yielded += 1
+ processed = False
+ for row in results:
+ processed = True
+ yield row
+ num_yielded += 1

- if not processed:
- break # no more results
- offset += page_size
+ if not processed:
+ break # no more results
+ offset += page_size

  #
  # Table Name Internal Functions
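
The point of this hunk is structural: the pagination loop now runs inside a single `with self.clone() as wh:` block instead of opening a fresh clone for every page, which is why several identical lines appear as both removed and added (only their indentation changes). A rough sketch of the new control flow, with `clone` and `select_page` as stand-ins for the warehouse's own methods:

```py
def paginate(clone, select_page, page_size=100):
    """Sketch: one thread-local connection reused for every page."""
    offset = 0
    with clone() as wh:
        while True:
            # Cursor results are not thread-safe, so materialize each page.
            rows = list(select_page(wh, offset))
            if not rows:
                break  # no more results
            yield from rows
            offset += page_size
```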
datachain/func/__init__.py ADDED
@@ -0,0 +1,49 @@
+ from sqlalchemy import literal
+
+ from . import array, path, random, string
+ from .aggregate import (
+ any_value,
+ avg,
+ collect,
+ concat,
+ count,
+ dense_rank,
+ first,
+ max,
+ min,
+ rank,
+ row_number,
+ sum,
+ )
+ from .array import cosine_distance, euclidean_distance, length, sip_hash_64
+ from .conditional import greatest, least
+ from .random import rand
+ from .window import window
+
+ __all__ = [
+ "any_value",
+ "array",
+ "avg",
+ "collect",
+ "concat",
+ "cosine_distance",
+ "count",
+ "dense_rank",
+ "euclidean_distance",
+ "first",
+ "greatest",
+ "least",
+ "length",
+ "literal",
+ "max",
+ "min",
+ "path",
+ "rand",
+ "random",
+ "rank",
+ "row_number",
+ "sip_hash_64",
+ "string",
+ "sum",
+ "window",
+ ]
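
The new top-level `datachain.func` package gathers the aggregate, array, conditional, path, random, string, and window helpers under one namespace. A small usage sketch based only on the names exported above; the dataset and column names are made up, and the `DataChain` methods are assumed from the rest of this release:

```py
from datachain import DataChain
from datachain.func import cosine_distance, count, greatest

chain = (
    DataChain.from_dataset("my_dataset")  # hypothetical dataset
    .mutate(
        dist=cosine_distance("embedding", [0.1, 0.2, 0.3]),
        clipped=greatest("signal.value", 0),
    )
    .group_by(cnt=count(), partition_by="file.path")
)
```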
datachain/{lib/func → func}/aggregate.py RENAMED
@@ -2,7 +2,7 @@ from typing import Optional

  from sqlalchemy import func as sa_func

- from datachain.sql import functions as dc_func
+ from datachain.sql.functions import aggregate

  from .func import Func

@@ -31,7 +31,9 @@ def count(col: Optional[str] = None) -> Func:
  Notes:
  - Result column will always be of type int.
  """
- return Func("count", inner=sa_func.count, col=col, result_type=int)
+ return Func(
+ "count", inner=sa_func.count, cols=[col] if col else None, result_type=int
+ )


  def sum(col: str) -> Func:
@@ -59,7 +61,7 @@ def sum(col: str) -> Func:
  - The `sum` function should be used on numeric columns.
  - Result column type will be the same as the input column type.
  """
- return Func("sum", inner=sa_func.sum, col=col)
+ return Func("sum", inner=sa_func.sum, cols=[col])


  def avg(col: str) -> Func:
@@ -87,7 +89,7 @@ def avg(col: str) -> Func:
  - The `avg` function should be used on numeric columns.
  - Result column will always be of type float.
  """
- return Func("avg", inner=dc_func.aggregate.avg, col=col, result_type=float)
+ return Func("avg", inner=aggregate.avg, cols=[col], result_type=float)


  def min(col: str) -> Func:
@@ -115,7 +117,7 @@ def min(col: str) -> Func:
  - The `min` function can be used with numeric, date, and string columns.
  - Result column will have the same type as the input column.
  """
- return Func("min", inner=sa_func.min, col=col)
+ return Func("min", inner=sa_func.min, cols=[col])


  def max(col: str) -> Func:
@@ -143,7 +145,7 @@ def max(col: str) -> Func:
  - The `max` function can be used with numeric, date, and string columns.
  - Result column will have the same type as the input column.
  """
- return Func("max", inner=sa_func.max, col=col)
+ return Func("max", inner=sa_func.max, cols=[col])


  def any_value(col: str) -> Func:
@@ -174,7 +176,7 @@ def any_value(col: str) -> Func:
  - The result of `any_value` is non-deterministic,
  meaning it may return different values for different executions.
  """
- return Func("any_value", inner=dc_func.aggregate.any_value, col=col)
+ return Func("any_value", inner=aggregate.any_value, cols=[col])


  def collect(col: str) -> Func:
@@ -203,7 +205,7 @@ def collect(col: str) -> Func:
  - The `collect` function can be used with numeric and string columns.
  - Result column will have an array type.
  """
- return Func("collect", inner=dc_func.aggregate.collect, col=col, is_array=True)
+ return Func("collect", inner=aggregate.collect, cols=[col], is_array=True)


  def concat(col: str, separator="") -> Func:
@@ -236,9 +238,9 @@ def concat(col: str, separator="") -> Func:
  """

  def inner(arg):
- return dc_func.aggregate.group_concat(arg, separator)
+ return aggregate.group_concat(arg, separator)

- return Func("concat", inner=inner, col=col, result_type=str)
+ return Func("concat", inner=inner, cols=[col], result_type=str)


  def row_number() -> Func:
@@ -350,4 +352,4 @@ def first(col: str) -> Func:
  in the specified order.
  - The result column will have the same type as the input column.
  """
- return Func("first", inner=sa_func.first_value, col=col, is_window=True)
+ return Func("first", inner=sa_func.first_value, cols=[col], is_window=True)
datachain/func/array.py ADDED
@@ -0,0 +1,176 @@
+ from collections.abc import Sequence
+ from typing import Union
+
+ from datachain.sql.functions import array
+
+ from .func import Func
+
+
+ def cosine_distance(*args: Union[str, Sequence]) -> Func:
+ """
+ Computes the cosine distance between two vectors.
+
+ The cosine distance is derived from the cosine similarity, which measures the angle
+ between two vectors. This function returns the dissimilarity between the vectors,
+ where 0 indicates identical vectors and values closer to 1
+ indicate higher dissimilarity.
+
+ Args:
+ args (str | Sequence): Two vectors to compute the cosine distance between.
+ If a string is provided, it is assumed to be the name of the column vector.
+ If a sequence is provided, it is assumed to be a vector of values.
+
+ Returns:
+ Func: A Func object that represents the cosine_distance function.
+
+ Example:
+ ```py
+ target_embedding = [0.1, 0.2, 0.3]
+ dc.mutate(
+ cos_dist1=func.cosine_distance("embedding", target_embedding),
+ cos_dist2=func.cosine_distance(target_embedding, [0.4, 0.5, 0.6]),
+ )
+ ```
+
+ Notes:
+ - Ensure both vectors have the same number of elements.
+ - Result column will always be of type float.
+ """
+ cols, func_args = [], []
+ for arg in args:
+ if isinstance(arg, str):
+ cols.append(arg)
+ else:
+ func_args.append(list(arg))
+
+ if len(cols) + len(func_args) != 2:
+ raise ValueError("cosine_distance() requires exactly two arguments")
+ if not cols and len(func_args[0]) != len(func_args[1]):
+ raise ValueError("cosine_distance() requires vectors of the same length")
+
+ return Func(
+ "cosine_distance",
+ inner=array.cosine_distance,
+ cols=cols,
+ args=func_args,
+ result_type=float,
+ )
+
+
+ def euclidean_distance(*args: Union[str, Sequence]) -> Func:
+ """
+ Computes the Euclidean distance between two vectors.
+
+ The Euclidean distance is the straight-line distance between two points
+ in Euclidean space. This function returns the distance between the two vectors.
+
+ Args:
+ args (str | Sequence): Two vectors to compute the Euclidean distance between.
+ If a string is provided, it is assumed to be the name of the column vector.
+ If a sequence is provided, it is assumed to be a vector of values.
+
+ Returns:
+ Func: A Func object that represents the euclidean_distance function.
+
+ Example:
+ ```py
+ target_embedding = [0.1, 0.2, 0.3]
+ dc.mutate(
+ eu_dist1=func.euclidean_distance("embedding", target_embedding),
+ eu_dist2=func.euclidean_distance(target_embedding, [0.4, 0.5, 0.6]),
+ )
+ ```
+
+ Notes:
+ - Ensure both vectors have the same number of elements.
+ - Result column will always be of type float.
+ """
+ cols, func_args = [], []
+ for arg in args:
+ if isinstance(arg, str):
+ cols.append(arg)
+ else:
+ func_args.append(list(arg))
+
+ if len(cols) + len(func_args) != 2:
+ raise ValueError("euclidean_distance() requires exactly two arguments")
+ if not cols and len(func_args[0]) != len(func_args[1]):
+ raise ValueError("euclidean_distance() requires vectors of the same length")
+
+ return Func(
+ "euclidean_distance",
+ inner=array.euclidean_distance,
+ cols=cols,
+ args=func_args,
+ result_type=float,
+ )
+
+
+ def length(arg: Union[str, Sequence, Func]) -> Func:
+ """
+ Returns the length of the array.
+
+ Args:
+ arg (str | Sequence | Func): Array to compute the length of.
+ If a string is provided, it is assumed to be the name of the array column.
+ If a sequence is provided, it is assumed to be an array of values.
+ If a Func is provided, it is assumed to be a function returning an array.
+
+ Returns:
+ Func: A Func object that represents the array length function.
+
+ Example:
+ ```py
+ dc.mutate(
+ len1=func.array.length("signal.values"),
+ len2=func.array.length([1, 2, 3, 4, 5]),
+ )
+ ```
+
+ Note:
+ - Result column will always be of type int.
+ """
+ if isinstance(arg, (str, Func)):
+ cols = [arg]
+ args = None
+ else:
+ cols = None
+ args = [arg]
+
+ return Func("length", inner=array.length, cols=cols, args=args, result_type=int)
+
+
+ def sip_hash_64(arg: Union[str, Sequence]) -> Func:
+ """
+ Computes the SipHash-64 hash of the array.
+
+ Args:
+ arg (str | Sequence): Array to compute the SipHash-64 hash of.
+ If a string is provided, it is assumed to be the name of the array column.
+ If a sequence is provided, it is assumed to be an array of values.
+
+ Returns:
+ Func: A Func object that represents the sip_hash_64 function.
+
+ Example:
+ ```py
+ dc.mutate(
+ hash1=func.sip_hash_64("signal.values"),
+ hash2=func.sip_hash_64([1, 2, 3, 4, 5]),
+ )
+ ```
+
+ Note:
+ - This function is only available for the ClickHouse warehouse.
+ - Result column will always be of type int.
+ """
+ if isinstance(arg, str):
+ cols = [arg]
+ args = None
+ else:
+ cols = None
+ args = [arg]
+
+ return Func(
+ "sip_hash_64", inner=array.sip_hash_64, cols=cols, args=args, result_type=int
+ )
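
Each of these helpers dispatches on argument type: strings (and `Func` objects, in `length`) are treated as column references and collected into `cols`, while sequences are passed through as literal vectors in `args`. A tiny illustration of the validation that implies, with an illustrative column name:

```py
from datachain.func import cosine_distance

ok = cosine_distance("embedding", [0.1, 0.2, 0.3])  # column vs. literal vector
also_ok = cosine_distance([1, 0, 0], [0, 1, 0])     # two literal vectors

# ValueError: literal vectors must have the same length.
# cosine_distance([1, 0], [0, 1, 0])

# ValueError: exactly two arguments are required.
# cosine_distance("embedding")
```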
datachain/func/base.py ADDED
@@ -0,0 +1,23 @@
+ from abc import ABCMeta, abstractmethod
+ from typing import TYPE_CHECKING, Optional
+
+ if TYPE_CHECKING:
+ from sqlalchemy import TableClause
+
+ from datachain.lib.signal_schema import SignalSchema
+ from datachain.query.schema import Column
+
+
+ class Function:
+ __metaclass__ = ABCMeta
+
+ name: str
+
+ @abstractmethod
+ def get_column(
+ self,
+ signals_schema: Optional["SignalSchema"] = None,
+ label: Optional[str] = None,
+ table: Optional["TableClause"] = None,
+ ) -> "Column":
+ pass
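
One detail of this new base class: `__metaclass__ = ABCMeta` is the Python 2 spelling and has no effect on Python 3, so `Function` is not actually abstract at runtime; subclasses are simply expected to implement `get_column`. A Python 3 sketch of the equivalent interface would look like this (not the shipped code):

```py
from abc import ABC, abstractmethod
from typing import Optional


class Function(ABC):
    name: str

    @abstractmethod
    def get_column(self, signals_schema=None, label: Optional[str] = None, table=None):
        """Return the column expression this function evaluates to."""
```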
datachain/func/conditional.py ADDED
@@ -0,0 +1,81 @@
+ from typing import Union
+
+ from datachain.sql.functions import conditional
+
+ from .func import ColT, Func
+
+
+ def greatest(*args: Union[ColT, float]) -> Func:
+ """
+ Returns the greatest (largest) value from the given input values.
+
+ Args:
+ args (ColT | str | int | float | Sequence): The values to compare.
+ If a string is provided, it is assumed to be the name of the column.
+ If a Func is provided, it is assumed to be a function returning a value.
+ If an int, float, or Sequence is provided, it is assumed to be a literal.
+
+ Returns:
+ Func: A Func object that represents the greatest function.
+
+ Example:
+ ```py
+ dc.mutate(
+ greatest=func.greatest("signal.value", 0),
+ )
+ ```
+
+ Note:
+ - Result column will always be of the same type as the input columns.
+ """
+ cols, func_args = [], []
+
+ for arg in args:
+ if isinstance(arg, (str, Func)):
+ cols.append(arg)
+ else:
+ func_args.append(arg)
+
+ return Func(
+ "greatest",
+ inner=conditional.greatest,
+ cols=cols,
+ args=func_args,
+ result_type=int,
+ )
+
+
+ def least(*args: Union[ColT, float]) -> Func:
+ """
+ Returns the least (smallest) value from the given input values.
+
+ Args:
+ args (ColT | str | int | float | Sequence): The values to compare.
+ If a string is provided, it is assumed to be the name of the column.
+ If a Func is provided, it is assumed to be a function returning a value.
+ If an int, float, or Sequence is provided, it is assumed to be a literal.
+
+ Returns:
+ Func: A Func object that represents the least function.
+
+ Example:
+ ```py
+ dc.mutate(
+ least=func.least("signal.value", 0),
+ )
+ ```
+
+ Note:
+ - Result column will always be of the same type as the input columns.
+ """
+ cols, func_args = [], []
+
+ for arg in args:
+ if isinstance(arg, (str, Func)):
+ cols.append(arg)
+ else:
+ func_args.append(arg)
+
+ return Func(
+ "least", inner=conditional.least, cols=cols, args=func_args, result_type=int
+ )
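
As with the array helpers, string and `Func` arguments are treated as column references and everything else as literals, so the two helpers can be nested. A short sketch that combines both to clamp a column into a range (the column name is illustrative, and the `mutate` call assumes the DataChain API used in the docstrings above):

```py
from datachain.func import greatest, least

# Clamp signal.value into [0, 1]: floor at 0 first, then cap at 1.
clamped = least(greatest("signal.value", 0), 1)

# dc.mutate(clamped=least(greatest("signal.value", 0), 1))
```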