datachain-0.14.2-py3-none-any.whl → datachain-0.39.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137):
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/func/func.py CHANGED
@@ -1,12 +1,13 @@
1
1
  import inspect
2
- from collections.abc import Sequence
3
- from typing import TYPE_CHECKING, Any, Callable, Optional, Union
2
+ from collections.abc import Callable, Sequence
3
+ from typing import TYPE_CHECKING, Any, Union, get_args, get_origin
4
4
 
5
5
  from sqlalchemy import BindParameter, Case, ColumnElement, Integer, cast, desc
6
6
  from sqlalchemy.sql import func as sa_func
7
7
 
8
8
  from datachain.lib.convert.python_to_sql import python_to_sql
9
9
  from datachain.lib.convert.sql_to_python import sql_to_python
10
+ from datachain.lib.model_store import ModelStore
10
11
  from datachain.lib.utils import DataChainColumnError, DataChainParamsError
11
12
  from datachain.query.schema import Column, ColumnMeta
12
13
  from datachain.sql.functions import numeric
@@ -22,24 +23,29 @@ if TYPE_CHECKING:
22
23
  from .window import Window
23
24
 
24
25
 
25
- ColT = Union[str, ColumnElement, "Func", tuple]
26
+ ColT = Union[str, tuple, Column, ColumnElement, "Func"]
26
27
 
27
28
 
28
- class Func(Function):
29
+ class Func(Function): # noqa: PLW1641
29
30
  """Represents a function to be applied to a column in a SQL query."""
30
31
 
32
+ cols: Sequence[ColT]
33
+ args: Sequence[Any]
34
+
31
35
  def __init__(
32
36
  self,
33
37
  name: str,
34
38
  inner: Callable,
35
- cols: Optional[Sequence[ColT]] = None,
36
- args: Optional[Sequence[Any]] = None,
37
- kwargs: Optional[dict[str, Any]] = None,
38
- result_type: Optional["DataType"] = None,
39
+ cols: Sequence[ColT] | None = None,
40
+ args: Sequence[Any] | None = None,
41
+ kwargs: dict[str, Any] | None = None,
42
+ result_type: "DataType | None" = None,
43
+ type_from_args: Callable[..., "DataType"] | None = None,
39
44
  is_array: bool = False,
45
+ from_array: bool = False,
40
46
  is_window: bool = False,
41
- window: Optional["Window"] = None,
42
- label: Optional[str] = None,
47
+ window: "Window | None" = None,
48
+ label: str | None = None,
43
49
  ) -> None:
44
50
  self.name = name
45
51
  self.inner = inner
@@ -47,7 +53,9 @@ class Func(Function):
47
53
  self.args = args or []
48
54
  self.kwargs = kwargs or {}
49
55
  self.result_type = result_type
56
+ self.type_from_args = type_from_args
50
57
  self.is_array = is_array
58
+ self.from_array = from_array
51
59
  self.is_window = is_window
52
60
  self.window = window
53
61
  self.col_label = label
@@ -66,7 +74,9 @@ class Func(Function):
66
74
  self.args,
67
75
  self.kwargs,
68
76
  self.result_type,
77
+ self.type_from_args,
69
78
  self.is_array,
79
+ self.from_array,
70
80
  self.is_window,
71
81
  window,
72
82
  self.col_label,
@@ -89,7 +99,7 @@ class Func(Function):
89
99
  else []
90
100
  )
91
101
 
92
- def _db_col_type(self, signals_schema: "SignalSchema") -> Optional["DataType"]:
102
+ def _db_col_type(self, signals_schema: "SignalSchema") -> "DataType | None":
93
103
  if not self._db_cols:
94
104
  return None
95
105
 
@@ -101,53 +111,69 @@ class Func(Function):
101
111
  "Columns must have the same type to infer result type",
102
112
  )
103
113
 
114
+ if self.from_array:
115
+ if get_origin(col_type) is not list:
116
+ raise DataChainColumnError(
117
+ str(self),
118
+ "Array column must be of type list",
119
+ )
120
+ if self.is_array:
121
+ return col_type
122
+ col_args = get_args(col_type)
123
+ if len(col_args) != 1:
124
+ raise DataChainColumnError(
125
+ str(self),
126
+ "Array column must have a single type argument",
127
+ )
128
+ return col_args[0]
129
+
104
130
  return list[col_type] if self.is_array else col_type # type: ignore[valid-type]
105
131
 
106
- def __add__(self, other: Union[ColT, float]) -> "Func":
132
+ def __add__(self, other: ColT | float) -> "Func":
107
133
  if isinstance(other, (int, float)):
108
134
  return Func("add", lambda a: a + other, [self])
109
135
  return Func("add", lambda a1, a2: a1 + a2, [self, other])
110
136
 
111
- def __radd__(self, other: Union[ColT, float]) -> "Func":
137
+ def __radd__(self, other: ColT | float) -> "Func":
112
138
  if isinstance(other, (int, float)):
113
139
  return Func("add", lambda a: other + a, [self])
114
140
  return Func("add", lambda a1, a2: a1 + a2, [other, self])
115
141
 
116
- def __sub__(self, other: Union[ColT, float]) -> "Func":
142
+ def __sub__(self, other: ColT | float) -> "Func":
117
143
  if isinstance(other, (int, float)):
118
144
  return Func("sub", lambda a: a - other, [self])
119
145
  return Func("sub", lambda a1, a2: a1 - a2, [self, other])
120
146
 
121
- def __rsub__(self, other: Union[ColT, float]) -> "Func":
147
+ def __rsub__(self, other: ColT | float) -> "Func":
122
148
  if isinstance(other, (int, float)):
123
149
  return Func("sub", lambda a: other - a, [self])
124
150
  return Func("sub", lambda a1, a2: a1 - a2, [other, self])
125
151
 
126
- def __mul__(self, other: Union[ColT, float]) -> "Func":
152
+ def __mul__(self, other: ColT | float) -> "Func":
127
153
  if isinstance(other, (int, float)):
128
154
  return Func("mul", lambda a: a * other, [self])
129
155
  return Func("mul", lambda a1, a2: a1 * a2, [self, other])
130
156
 
131
- def __rmul__(self, other: Union[ColT, float]) -> "Func":
157
+ def __rmul__(self, other: ColT | float) -> "Func":
132
158
  if isinstance(other, (int, float)):
133
159
  return Func("mul", lambda a: other * a, [self])
134
160
  return Func("mul", lambda a1, a2: a1 * a2, [other, self])
135
161
 
136
- def __truediv__(self, other: Union[ColT, float]) -> "Func":
162
+ def __truediv__(self, other: ColT | float) -> "Func":
137
163
  if isinstance(other, (int, float)):
138
164
  return Func("div", lambda a: _truediv(a, other), [self], result_type=float)
139
165
  return Func(
140
166
  "div", lambda a1, a2: _truediv(a1, a2), [self, other], result_type=float
141
167
  )
142
168
 
143
- def __rtruediv__(self, other: Union[ColT, float]) -> "Func":
169
+ def __rtruediv__(self, other: ColT | float) -> "Func":
144
170
  if isinstance(other, (int, float)):
145
171
  return Func("div", lambda a: _truediv(other, a), [self], result_type=float)
146
172
  return Func(
147
173
  "div", lambda a1, a2: _truediv(a1, a2), [other, self], result_type=float
148
174
  )
149
175
 
150
- def __floordiv__(self, other: Union[ColT, float]) -> "Func":
176
+ def __floordiv__(self, other: ColT | float) -> "Func":
151
177
  if isinstance(other, (int, float)):
152
178
  return Func(
153
179
  "floordiv", lambda a: _floordiv(a, other), [self], result_type=int
@@ -156,7 +182,7 @@ class Func(Function):
156
182
  "floordiv", lambda a1, a2: _floordiv(a1, a2), [self, other], result_type=int
157
183
  )
158
184
 
159
- def __rfloordiv__(self, other: Union[ColT, float]) -> "Func":
185
+ def __rfloordiv__(self, other: ColT | float) -> "Func":
160
186
  if isinstance(other, (int, float)):
161
187
  return Func(
162
188
  "floordiv", lambda a: _floordiv(other, a), [self], result_type=int
@@ -165,17 +191,17 @@ class Func(Function):
165
191
  "floordiv", lambda a1, a2: _floordiv(a1, a2), [other, self], result_type=int
166
192
  )
167
193
 
168
- def __mod__(self, other: Union[ColT, float]) -> "Func":
194
+ def __mod__(self, other: ColT | float) -> "Func":
169
195
  if isinstance(other, (int, float)):
170
196
  return Func("mod", lambda a: a % other, [self], result_type=int)
171
197
  return Func("mod", lambda a1, a2: a1 % a2, [self, other], result_type=int)
172
198
 
173
- def __rmod__(self, other: Union[ColT, float]) -> "Func":
199
+ def __rmod__(self, other: ColT | float) -> "Func":
174
200
  if isinstance(other, (int, float)):
175
201
  return Func("mod", lambda a: other % a, [self], result_type=int)
176
202
  return Func("mod", lambda a1, a2: a1 % a2, [other, self], result_type=int)
177
203
 
178
- def __and__(self, other: Union[ColT, float]) -> "Func":
204
+ def __and__(self, other: ColT | float) -> "Func":
179
205
  if isinstance(other, (int, float)):
180
206
  return Func(
181
207
  "and", lambda a: numeric.bit_and(a, other), [self], result_type=int
@@ -187,7 +213,7 @@ class Func(Function):
187
213
  result_type=int,
188
214
  )
189
215
 
190
- def __rand__(self, other: Union[ColT, float]) -> "Func":
216
+ def __rand__(self, other: ColT | float) -> "Func":
191
217
  if isinstance(other, (int, float)):
192
218
  return Func(
193
219
  "and", lambda a: numeric.bit_and(other, a), [self], result_type=int
@@ -199,7 +225,7 @@ class Func(Function):
199
225
  result_type=int,
200
226
  )
201
227
 
202
- def __or__(self, other: Union[ColT, float]) -> "Func":
228
+ def __or__(self, other: ColT | float) -> "Func":
203
229
  if isinstance(other, (int, float)):
204
230
  return Func(
205
231
  "or", lambda a: numeric.bit_or(a, other), [self], result_type=int
@@ -208,7 +234,7 @@ class Func(Function):
208
234
  "or", lambda a1, a2: numeric.bit_or(a1, a2), [self, other], result_type=int
209
235
  )
210
236
 
211
- def __ror__(self, other: Union[ColT, float]) -> "Func":
237
+ def __ror__(self, other: ColT | float) -> "Func":
212
238
  if isinstance(other, (int, float)):
213
239
  return Func(
214
240
  "or", lambda a: numeric.bit_or(other, a), [self], result_type=int
@@ -217,7 +243,7 @@ class Func(Function):
217
243
  "or", lambda a1, a2: numeric.bit_or(a1, a2), [other, self], result_type=int
218
244
  )
219
245
 
220
- def __xor__(self, other: Union[ColT, float]) -> "Func":
246
+ def __xor__(self, other: ColT | float) -> "Func":
221
247
  if isinstance(other, (int, float)):
222
248
  return Func(
223
249
  "xor", lambda a: numeric.bit_xor(a, other), [self], result_type=int
@@ -229,7 +255,7 @@ class Func(Function):
229
255
  result_type=int,
230
256
  )
231
257
 
232
- def __rxor__(self, other: Union[ColT, float]) -> "Func":
258
+ def __rxor__(self, other: ColT | float) -> "Func":
233
259
  if isinstance(other, (int, float)):
234
260
  return Func(
235
261
  "xor", lambda a: numeric.bit_xor(other, a), [self], result_type=int
@@ -241,7 +267,7 @@ class Func(Function):
241
267
  result_type=int,
242
268
  )
243
269
 
244
- def __rshift__(self, other: Union[ColT, float]) -> "Func":
270
+ def __rshift__(self, other: ColT | float) -> "Func":
245
271
  if isinstance(other, (int, float)):
246
272
  return Func(
247
273
  "rshift",
@@ -256,7 +282,7 @@ class Func(Function):
256
282
  result_type=int,
257
283
  )
258
284
 
259
- def __rrshift__(self, other: Union[ColT, float]) -> "Func":
285
+ def __rrshift__(self, other: ColT | float) -> "Func":
260
286
  if isinstance(other, (int, float)):
261
287
  return Func(
262
288
  "rshift",
@@ -271,7 +297,7 @@ class Func(Function):
271
297
  result_type=int,
272
298
  )
273
299
 
274
- def __lshift__(self, other: Union[ColT, float]) -> "Func":
300
+ def __lshift__(self, other: ColT | float) -> "Func":
275
301
  if isinstance(other, (int, float)):
276
302
  return Func(
277
303
  "lshift",
@@ -286,7 +312,7 @@ class Func(Function):
286
312
  result_type=int,
287
313
  )
288
314
 
289
- def __rlshift__(self, other: Union[ColT, float]) -> "Func":
315
+ def __rlshift__(self, other: ColT | float) -> "Func":
290
316
  if isinstance(other, (int, float)):
291
317
  return Func(
292
318
  "lshift",
@@ -301,12 +327,12 @@ class Func(Function):
301
327
  result_type=int,
302
328
  )
303
329
 
304
- def __lt__(self, other: Union[ColT, float]) -> "Func":
330
+ def __lt__(self, other: ColT | float) -> "Func":
305
331
  if isinstance(other, (int, float)):
306
332
  return Func("lt", lambda a: a < other, [self], result_type=bool)
307
333
  return Func("lt", lambda a1, a2: a1 < a2, [self, other], result_type=bool)
308
334
 
309
- def __le__(self, other: Union[ColT, float]) -> "Func":
335
+ def __le__(self, other: ColT | float) -> "Func":
310
336
  if isinstance(other, (int, float)):
311
337
  return Func("le", lambda a: a <= other, [self], result_type=bool)
312
338
  return Func("le", lambda a1, a2: a1 <= a2, [self, other], result_type=bool)
@@ -321,12 +347,12 @@ class Func(Function):
321
347
  return Func("ne", lambda a: a != other, [self], result_type=bool)
322
348
  return Func("ne", lambda a1, a2: a1 != a2, [self, other], result_type=bool)
323
349
 
324
- def __gt__(self, other: Union[ColT, float]) -> "Func":
350
+ def __gt__(self, other: ColT | float) -> "Func":
325
351
  if isinstance(other, (int, float)):
326
352
  return Func("gt", lambda a: a > other, [self], result_type=bool)
327
353
  return Func("gt", lambda a1, a2: a1 > a2, [self, other], result_type=bool)
328
354
 
329
- def __ge__(self, other: Union[ColT, float]) -> "Func":
355
+ def __ge__(self, other: ColT | float) -> "Func":
330
356
  if isinstance(other, (int, float)):
331
357
  return Func("ge", lambda a: a >= other, [self], result_type=bool)
332
358
  return Func("ge", lambda a1, a2: a1 >= a2, [self, other], result_type=bool)
@@ -339,13 +365,15 @@ class Func(Function):
339
365
  self.args,
340
366
  self.kwargs,
341
367
  self.result_type,
368
+ self.type_from_args,
342
369
  self.is_array,
370
+ self.from_array,
343
371
  self.is_window,
344
372
  self.window,
345
373
  label,
346
374
  )
347
375
 
348
- def get_col_name(self, label: Optional[str] = None) -> str:
376
+ def get_col_name(self, label: str | None = None) -> str:
349
377
  if label:
350
378
  return label
351
379
  if self.col_label:
@@ -360,7 +388,7 @@ class Func(Function):
360
388
  return self.name
361
389
 
362
390
  def get_result_type(
363
- self, signals_schema: Optional["SignalSchema"] = None
391
+ self, signals_schema: "SignalSchema | None" = None
364
392
  ) -> "DataType":
365
393
  if self.result_type:
366
394
  return self.result_type
@@ -368,6 +396,15 @@ class Func(Function):
368
396
  if signals_schema and (col_type := self._db_col_type(signals_schema)):
369
397
  return col_type
370
398
 
399
+ if (
400
+ self.type_from_args
401
+ and (self.cols is None or self.cols == [])
402
+ and self.args is not None
403
+ and len(self.args) > 0
404
+ and (result_type := self.type_from_args(*self.args)) is not None
405
+ ):
406
+ return result_type
407
+
371
408
  raise DataChainColumnError(
372
409
  str(self),
373
410
  "Column name is required to infer result type",
@@ -375,10 +412,24 @@ class Func(Function):
375
412
 
376
413
  def get_column(
377
414
  self,
378
- signals_schema: Optional["SignalSchema"] = None,
379
- label: Optional[str] = None,
380
- table: Optional["TableClause"] = None,
415
+ signals_schema: "SignalSchema | None" = None,
416
+ label: str | None = None,
417
+ table: "TableClause | None" = None,
381
418
  ) -> Column:
419
+ # Guard against using complex (pydantic) object columns in SQL funcs
420
+ if signals_schema and self._db_cols:
421
+ for arg in self._db_cols:
422
+ # _db_cols normalizes known columns to strings; skip non-string args
423
+ if not isinstance(arg, str):
424
+ continue
425
+ t_with_sub = signals_schema.get_column_type(arg, with_subtree=True)
426
+ if ModelStore.is_pydantic(t_with_sub):
427
+ raise DataChainParamsError(
428
+ f"Function {self.name} doesn't support complex object "
429
+ f"columns like '{arg}'. Use a leaf field (e.g., "
430
+ f"'{arg}.path') or use UDFs to operate on complex objects."
431
+ )
432
+
382
433
  col_type = self.get_result_type(signals_schema)
383
434
  sql_type = python_to_sql(col_type)
384
435
 
@@ -398,6 +449,7 @@ class Func(Function):
398
449
  return col
399
450
 
400
451
  cols = [get_col(col) for col in self._db_cols]
452
+
401
453
  kwargs = {k: get_col(v, string_as_literal=True) for k, v in self.kwargs.items()}
402
454
  func_col = self.inner(*cols, *self.args, **kwargs)
403
455
 
@@ -434,9 +486,8 @@ def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType":
434
486
  if isinstance(col, ColumnElement) and not hasattr(col, "name"):
435
487
  return sql_to_python(col)
436
488
 
437
- return signals_schema.get_column_type(
438
- col.name if isinstance(col, ColumnElement) else col # type: ignore[arg-type]
439
- )
489
+ name = col.name if isinstance(col, ColumnElement) else col # type: ignore[assignment]
490
+ return signals_schema.get_column_type(name) # type: ignore[arg-type]
440
491
 
441
492
 
442
493
  def _truediv(a, b):
datachain/func/numeric.py CHANGED
@@ -1,31 +1,34 @@
1
- from typing import Union
2
-
1
+ from datachain.query.schema import Column
3
2
  from datachain.sql.functions import numeric
4
3
 
5
- from .func import ColT, Func
4
+ from .func import Func
6
5
 
7
6
 
8
- def bit_and(*args: Union[ColT, int]) -> Func:
7
+ def bit_and(*args: str | Column | Func | int) -> Func:
9
8
  """
10
- Computes the bitwise AND operation between two values.
9
+ Returns a function that computes the bitwise AND operation between two values.
11
10
 
12
11
  Args:
13
- args (str | int): Two values to compute the bitwise AND operation between.
14
- If a string is provided, it is assumed to be the name of the column vector.
12
+ args (str | Column | Func | int): Two values to compute
13
+ the bitwise AND operation between.
14
+ If a string is provided, it is assumed to be the name of the column.
15
+ If a Column is provided, it is assumed to be a column.
16
+ If a Func is provided, it is assumed to be a function returning an int.
15
17
  If an integer is provided, it is assumed to be a constant value.
16
18
 
17
19
  Returns:
18
- Func: A Func object that represents the bitwise AND function.
20
+ Func: A `Func` object that represents the bitwise AND function.
19
21
 
20
22
  Example:
21
23
  ```py
22
24
  dc.mutate(
23
- xor1=func.bit_and("signal.values", 0x0F),
25
+ and1=func.bit_and("signal.value", 0x0F),
26
+ and2=func.bit_and(dc.C("signal.value1"), "signal.value2"),
24
27
  )
25
28
  ```
26
29
 
27
30
  Notes:
28
- - Result column will always be of type int.
31
+ - The result column will always be of type int.
29
32
  """
30
33
  cols, func_args = [], []
31
34
  for arg in args:
@@ -46,27 +49,31 @@ def bit_and(*args: Union[ColT, int]) -> Func:
46
49
  )
47
50
 
48
51
 
49
- def bit_or(*args: Union[ColT, int]) -> Func:
52
+ def bit_or(*args: str | Column | Func | int) -> Func:
50
53
  """
51
- Computes the bitwise AND operation between two values.
54
+ Returns a function that computes the bitwise OR operation between two values.
52
55
 
53
56
  Args:
54
- args (str | int): Two values to compute the bitwise OR operation between.
55
- If a string is provided, it is assumed to be the name of the column vector.
57
+ args (str | Column | Func | int): Two values to compute
58
+ the bitwise OR operation between.
59
+ If a string is provided, it is assumed to be the name of the column.
60
+ If a Column is provided, it is assumed to be a column.
61
+ If a Func is provided, it is assumed to be a function returning an int.
56
62
  If an integer is provided, it is assumed to be a constant value.
57
63
 
58
64
  Returns:
59
- Func: A Func object that represents the bitwise OR function.
65
+ Func: A `Func` object that represents the bitwise OR function.
60
66
 
61
67
  Example:
62
68
  ```py
63
69
  dc.mutate(
64
- xor1=func.bit_or("signal.values", 0x0F),
70
+ or1=func.bit_or("signal.value", 0x0F),
71
+ or2=func.bit_or(dc.C("signal.value1"), "signal.value2"),
65
72
  )
66
73
  ```
67
74
 
68
75
  Notes:
69
- - Result column will always be of type int.
76
+ - The result column will always be of type int.
70
77
  """
71
78
  cols, func_args = [], []
72
79
  for arg in args:
@@ -87,27 +94,31 @@ def bit_or(*args: Union[ColT, int]) -> Func:
87
94
  )
88
95
 
89
96
 
90
- def bit_xor(*args: Union[ColT, int]) -> Func:
97
+ def bit_xor(*args: str | Column | Func | int) -> Func:
91
98
  """
92
- Computes the bitwise XOR operation between two values.
99
+ Returns a function that computes the bitwise XOR operation between two values.
93
100
 
94
101
  Args:
95
- args (str | int): Two values to compute the bitwise XOR operation between.
96
- If a string is provided, it is assumed to be the name of the column vector.
102
+ args (str | Column | Func | int): Two values to compute
103
+ the bitwise XOR operation between.
104
+ If a string is provided, it is assumed to be the name of the column.
105
+ If a Column is provided, it is assumed to be a column.
106
+ If a Func is provided, it is assumed to be a function returning an int.
97
107
  If an integer is provided, it is assumed to be a constant value.
98
108
 
99
109
  Returns:
100
- Func: A Func object that represents the bitwise XOR function.
110
+ Func: A `Func` object that represents the bitwise XOR function.
101
111
 
102
112
  Example:
103
113
  ```py
104
114
  dc.mutate(
105
- xor1=func.bit_xor("signal.values", 0x0F),
115
+ xor1=func.bit_xor("signal.value", 0x0F),
116
+ xor2=func.bit_xor(dc.C("signal.value1"), "signal.value2"),
106
117
  )
107
118
  ```
108
119
 
109
120
  Notes:
110
- - Result column will always be of type int.
121
+ - The result column will always be of type int.
111
122
  """
112
123
  cols, func_args = [], []
113
124
  for arg in args:
@@ -128,28 +139,30 @@ def bit_xor(*args: Union[ColT, int]) -> Func:
128
139
  )
129
140
 
130
141
 
131
- def int_hash_64(col: Union[ColT, int]) -> Func:
142
+ def int_hash_64(col: str | Column | Func | int) -> Func:
132
143
  """
133
- Returns the 64-bit hash of an integer.
144
+ Returns a function that computes the 64-bit hash of an integer.
134
145
 
135
146
  Args:
136
- col (str | int): String to compute the hash of.
147
+ col (str | Column | Func | int): Integer to compute the hash of.
137
148
  If a string is provided, it is assumed to be the name of the column.
138
- If a int is provided, it is assumed to be an int literal.
149
+ If a Column is provided, it is assumed to be a column.
139
150
  If a Func is provided, it is assumed to be a function returning an int.
151
+ If an int is provided, it is assumed to be an int literal.
140
152
 
141
153
  Returns:
142
- Func: A Func object that represents the 64-bit hash function.
154
+ Func: A `Func` object that represents the 64-bit hash function.
143
155
 
144
156
  Example:
145
157
  ```py
146
158
  dc.mutate(
147
159
  val_hash=func.int_hash_64("val"),
160
+ val_hash2=func.int_hash_64(dc.C("val2")),
148
161
  )
149
162
  ```
150
163
 
151
- Note:
152
- - Result column will always be of type int.
164
+ Notes:
165
+ - The result column will always be of type int.
153
166
  """
154
167
  cols, args = [], []
155
168
  if isinstance(col, int):
@@ -162,9 +175,9 @@ def int_hash_64(col: Union[ColT, int]) -> Func:
162
175
  )
163
176
 
164
177
 
165
- def bit_hamming_distance(*args: Union[ColT, int]) -> Func:
178
+ def bit_hamming_distance(*args: str | Column | Func | int) -> Func:
166
179
  """
167
- Computes the Hamming distance between the bit representations of two integer values.
180
+ Returns a function that computes the Hamming distance between two integers.
168
181
 
169
182
  The Hamming distance is the number of positions at which the corresponding bits
170
183
  are different. This function returns the dissimilarity between the integers,
@@ -172,22 +185,26 @@ def bit_hamming_distance(*args: Union[ColT, int]) -> Func:
172
185
  in the integer indicate higher dissimilarity.
173
186
 
174
187
  Args:
175
- args (str | int): Two integers to compute the Hamming distance between.
176
- If a str is provided, it is assumed to be the name of the column.
188
+ args (str | Column | Func | int): Two integers to compute
189
+ the Hamming distance between.
190
+ If a string is provided, it is assumed to be the name of the column.
191
+ If a Column is provided, it is assumed to be a column.
192
+ If a Func is provided, it is assumed to be a function returning an int.
177
193
  If an int is provided, it is assumed to be an integer literal.
178
194
 
179
195
  Returns:
180
- Func: A Func object that represents the Hamming distance function.
196
+ Func: A `Func` object that represents the Hamming distance function.
181
197
 
182
198
  Example:
183
199
  ```py
184
200
  dc.mutate(
185
- ham_dist=func.bit_hamming_distance("embed1", 123456),
201
+ hd1=func.bit_hamming_distance("signal.value1", "signal.value2"),
202
+ hd2=func.bit_hamming_distance(dc.C("signal.value1"), 0x0F),
186
203
  )
187
204
  ```
188
205
 
189
206
  Notes:
190
- - Result column will always be of type int.
207
+ - The result column will always be of type int.
191
208
  """
192
209
  cols, func_args = [], []
193
210
  for arg in args: