datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
Files changed (119)
  1. datachain/__init__.py +4 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +5 -5
  4. datachain/catalog/__init__.py +0 -2
  5. datachain/catalog/catalog.py +276 -354
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +8 -3
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +10 -17
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +42 -27
  12. datachain/cli/commands/ls.py +15 -15
  13. datachain/cli/commands/show.py +2 -2
  14. datachain/cli/parser/__init__.py +3 -43
  15. datachain/cli/parser/job.py +1 -1
  16. datachain/cli/parser/utils.py +1 -2
  17. datachain/cli/utils.py +2 -15
  18. datachain/client/azure.py +2 -2
  19. datachain/client/fsspec.py +34 -23
  20. datachain/client/gcs.py +3 -3
  21. datachain/client/http.py +157 -0
  22. datachain/client/local.py +11 -7
  23. datachain/client/s3.py +3 -3
  24. datachain/config.py +4 -8
  25. datachain/data_storage/db_engine.py +12 -6
  26. datachain/data_storage/job.py +2 -0
  27. datachain/data_storage/metastore.py +716 -137
  28. datachain/data_storage/schema.py +20 -27
  29. datachain/data_storage/serializer.py +105 -15
  30. datachain/data_storage/sqlite.py +114 -114
  31. datachain/data_storage/warehouse.py +140 -48
  32. datachain/dataset.py +109 -89
  33. datachain/delta.py +117 -42
  34. datachain/diff/__init__.py +25 -33
  35. datachain/error.py +24 -0
  36. datachain/func/aggregate.py +9 -11
  37. datachain/func/array.py +12 -12
  38. datachain/func/base.py +7 -4
  39. datachain/func/conditional.py +9 -13
  40. datachain/func/func.py +63 -45
  41. datachain/func/numeric.py +5 -7
  42. datachain/func/string.py +2 -2
  43. datachain/hash_utils.py +123 -0
  44. datachain/job.py +11 -7
  45. datachain/json.py +138 -0
  46. datachain/lib/arrow.py +18 -15
  47. datachain/lib/audio.py +60 -59
  48. datachain/lib/clip.py +14 -13
  49. datachain/lib/convert/python_to_sql.py +6 -10
  50. datachain/lib/convert/values_to_tuples.py +151 -53
  51. datachain/lib/data_model.py +23 -19
  52. datachain/lib/dataset_info.py +7 -7
  53. datachain/lib/dc/__init__.py +2 -1
  54. datachain/lib/dc/csv.py +22 -26
  55. datachain/lib/dc/database.py +37 -34
  56. datachain/lib/dc/datachain.py +518 -324
  57. datachain/lib/dc/datasets.py +38 -30
  58. datachain/lib/dc/hf.py +16 -20
  59. datachain/lib/dc/json.py +17 -18
  60. datachain/lib/dc/listings.py +5 -8
  61. datachain/lib/dc/pandas.py +3 -6
  62. datachain/lib/dc/parquet.py +33 -21
  63. datachain/lib/dc/records.py +9 -13
  64. datachain/lib/dc/storage.py +103 -65
  65. datachain/lib/dc/storage_pattern.py +251 -0
  66. datachain/lib/dc/utils.py +17 -14
  67. datachain/lib/dc/values.py +3 -6
  68. datachain/lib/file.py +187 -50
  69. datachain/lib/hf.py +7 -5
  70. datachain/lib/image.py +13 -13
  71. datachain/lib/listing.py +5 -5
  72. datachain/lib/listing_info.py +1 -2
  73. datachain/lib/meta_formats.py +2 -3
  74. datachain/lib/model_store.py +20 -8
  75. datachain/lib/namespaces.py +59 -7
  76. datachain/lib/projects.py +51 -9
  77. datachain/lib/pytorch.py +31 -23
  78. datachain/lib/settings.py +188 -85
  79. datachain/lib/signal_schema.py +302 -64
  80. datachain/lib/text.py +8 -7
  81. datachain/lib/udf.py +103 -63
  82. datachain/lib/udf_signature.py +59 -34
  83. datachain/lib/utils.py +20 -0
  84. datachain/lib/video.py +3 -4
  85. datachain/lib/webdataset.py +31 -36
  86. datachain/lib/webdataset_laion.py +15 -16
  87. datachain/listing.py +12 -5
  88. datachain/model/bbox.py +3 -1
  89. datachain/namespace.py +22 -3
  90. datachain/node.py +6 -6
  91. datachain/nodes_thread_pool.py +0 -1
  92. datachain/plugins.py +24 -0
  93. datachain/project.py +4 -4
  94. datachain/query/batch.py +10 -12
  95. datachain/query/dataset.py +376 -194
  96. datachain/query/dispatch.py +112 -84
  97. datachain/query/metrics.py +3 -4
  98. datachain/query/params.py +2 -3
  99. datachain/query/queue.py +2 -1
  100. datachain/query/schema.py +7 -6
  101. datachain/query/session.py +190 -33
  102. datachain/query/udf.py +9 -6
  103. datachain/remote/studio.py +90 -53
  104. datachain/script_meta.py +12 -12
  105. datachain/sql/sqlite/base.py +37 -25
  106. datachain/sql/sqlite/types.py +1 -1
  107. datachain/sql/types.py +36 -5
  108. datachain/studio.py +49 -40
  109. datachain/toolkit/split.py +31 -10
  110. datachain/utils.py +39 -48
  111. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
  112. datachain-0.39.0.dist-info/RECORD +173 -0
  113. datachain/cli/commands/query.py +0 -54
  114. datachain/query/utils.py +0 -36
  115. datachain-0.30.5.dist-info/RECORD +0 -168
  116. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
  117. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  118. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  119. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/delta.py CHANGED
@@ -1,17 +1,22 @@
 from collections.abc import Sequence
 from copy import copy
 from functools import wraps
-from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, TypeVar
 
 import datachain
-from datachain.dataset import DatasetDependency
-from datachain.error import DatasetNotFoundError
+from datachain.dataset import DatasetDependency, DatasetRecord
+from datachain.error import DatasetNotFoundError, SchemaDriftError
 from datachain.project import Project
+from datachain.query.dataset import UnionSchemaMismatchError
 
 if TYPE_CHECKING:
-    from typing_extensions import Concatenate, ParamSpec
+    from collections.abc import Callable
+    from typing import Concatenate
+
+    from typing_extensions import ParamSpec
 
     from datachain.lib.dc import DataChain
+    from datachain.lib.signal_schema import SignalSchema
 
 P = ParamSpec("P")
 
@@ -30,9 +35,10 @@ def delta_disabled(
 
     @wraps(method)
     def _inner(self: T, *args: "P.args", **kwargs: "P.kwargs") -> T:
-        if self.delta:
+        if self.delta and not self._delta_unsafe:
             raise NotImplementedError(
-                f"Delta update cannot be used with {method.__name__}"
+                f"Cannot use {method.__name__} with delta datasets - may cause"
+                " inconsistency. Use delta_unsafe flag to allow this operation."
             )
         return method(self, *args, **kwargs)
 
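The guard in `delta_disabled` now names an escape hatch instead of refusing outright. A minimal sketch of what callers see, assuming `delta_unsafe` is accepted as a keyword next to `delta` in `read_storage` and that `union` is one of the decorated operations (both are assumptions; the bucket URI is a placeholder):

```python
import datachain as dc

events = dc.read_storage("s3://my-bucket/events/", delta=True, delta_on="file.path")
extra = dc.read_values(note=["backfill"])

try:
    events = events.union(extra)  # guarded by @delta_disabled
except NotImplementedError as exc:
    # "Cannot use union with delta datasets - may cause inconsistency.
    #  Use delta_unsafe flag to allow this operation."
    print(exc)

# Opting in sets the unsafe flag on the chain, which disarms the guard:
events = dc.read_storage(
    "s3://my-bucket/events/", delta=True, delta_on="file.path", delta_unsafe=True
)
```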
@@ -49,13 +55,55 @@ def _append_steps(dc: "DataChain", other: "DataChain"):
     return dc
 
 
+def _format_schema_drift_message(
+    context: str,
+    existing_schema: "SignalSchema",
+    updated_schema: "SignalSchema",
+) -> tuple[str, bool]:
+    missing_cols, new_cols = existing_schema.compare_signals(updated_schema)
+
+    if not new_cols and not missing_cols:
+        return "", False
+
+    parts: list[str] = []
+    if new_cols:
+        parts.append("new columns detected: " + ", ".join(sorted(new_cols)))
+    if missing_cols:
+        parts.append(
+            "columns missing in updated data: " + ", ".join(sorted(missing_cols))
+        )
+
+    details = "; ".join(parts)
+    message = f"Delta update failed: schema drift detected while {context}: {details}."
+
+    return message, True
+
+
+def _safe_union(
+    left: "DataChain",
+    right: "DataChain",
+    context: str,
+) -> "DataChain":
+    try:
+        return left.union(right)
+    except UnionSchemaMismatchError as exc:
+        message, has_drift = _format_schema_drift_message(
+            context,
+            left.signals_schema,
+            right.signals_schema,
+        )
+        if has_drift:
+            raise SchemaDriftError(message) from exc
+        raise
+
+
 def _get_delta_chain(
     source_ds_name: str,
     source_ds_project: Project,
     source_ds_version: str,
     source_ds_latest_version: str,
-    on: Union[str, Sequence[str]],
-    compare: Optional[Union[str, Sequence[str]]] = None,
+    on: str | Sequence[str],
+    compare: str | Sequence[str] | None = None,
 ) -> "DataChain":
     """Get delta chain for processing changes between versions."""
     source_dc = datachain.read_dataset(
@@ -83,11 +131,11 @@ def _get_retry_chain(
     source_ds_name: str,
     source_ds_project: Project,
     source_ds_version: str,
-    on: Union[str, Sequence[str]],
-    right_on: Optional[Union[str, Sequence[str]]],
-    delta_retry: Optional[Union[bool, str]],
+    on: str | Sequence[str],
+    right_on: str | Sequence[str] | None,
+    delta_retry: bool | str | None,
     diff_chain: "DataChain",
-) -> Optional["DataChain"]:
+) -> "DataChain | None":
     """Get retry chain for processing error records and missing records."""
     # Import here to avoid circular import
     from datachain.lib.dc import C
@@ -113,7 +161,9 @@ def _get_retry_chain(
         error_records = result_dataset.filter(C(delta_retry) != "")
         error_source_records = source_dc.merge(
             error_records, on=on, right_on=right_on, inner=True
-        ).select(*list(source_dc.signals_schema.values))
+        ).select(
+            *list(source_dc.signals_schema.clone_without_sys_signals().values.keys())
+        )
         retry_chain = error_source_records
 
     # Handle missing records if delta_retry is True
@@ -124,21 +174,30 @@
     # Subtract also diff chain since some items might be picked
     # up by `delta=True` itself (e.g. records got modified AND are missing in the
     # result dataset atm)
-    return retry_chain.subtract(diff_chain, on=on) if retry_chain else None
+    on = [on] if isinstance(on, str) else on
+
+    return (
+        retry_chain.diff(
+            diff_chain, on=on, added=True, same=True, modified=False, deleted=False
+        ).distinct(*on)
+        if retry_chain
+        else None
+    )
 
 
 def _get_source_info(
+    source_ds: DatasetRecord,
     name: str,
     namespace_name: str,
     project_name: str,
     latest_version: str,
    catalog,
 ) -> tuple[
-    Optional[str],
-    Optional[Project],
-    Optional[str],
-    Optional[str],
-    Optional[list[DatasetDependency]],
+    str | None,
+    Project | None,
+    str | None,
+    str | None,
+    list[DatasetDependency] | None,
 ]:
     """Get source dataset information and dependencies.
 
@@ -154,25 +213,25 @@
         indirect=False,
     )
 
-    dep = dependencies[0]
-    if not dep:
+    source_ds_dep = next(
+        (d for d in dependencies if d and d.name == source_ds.name), None
+    )
+    if not source_ds_dep:
         # Starting dataset was removed, back off to normal dataset creation
         return None, None, None, None, None
 
-    source_ds_project = catalog.metastore.get_project(dep.project, dep.namespace)
-    source_ds_name = dep.name
-    source_ds_version = dep.version
-    source_ds_latest_version = catalog.get_dataset(
-        source_ds_name,
-        namespace_name=source_ds_project.namespace.name,
-        project_name=source_ds_project.name,
-    ).latest_version
+    # Refresh starting dataset to have new versions if they are created
+    source_ds = catalog.get_dataset(
+        source_ds.name,
+        namespace_name=source_ds.project.namespace.name,
+        project_name=source_ds.project.name,
+    )
 
     return (
-        source_ds_name,
-        source_ds_project,
-        source_ds_version,
-        source_ds_latest_version,
+        source_ds.name,
+        source_ds.project,
+        source_ds_dep.version,
+        source_ds.latest_version,
         dependencies,
     )
 
@@ -182,11 +241,11 @@ def delta_retry_update(
     namespace_name: str,
     project_name: str,
     name: str,
-    on: Union[str, Sequence[str]],
-    right_on: Optional[Union[str, Sequence[str]]] = None,
-    compare: Optional[Union[str, Sequence[str]]] = None,
-    delta_retry: Optional[Union[bool, str]] = None,
-) -> tuple[Optional["DataChain"], Optional[list[DatasetDependency]], bool]:
+    on: str | Sequence[str],
+    right_on: str | Sequence[str] | None = None,
+    compare: str | Sequence[str] | None = None,
+    delta_retry: bool | str | None = None,
+) -> tuple["DataChain | None", list[DatasetDependency] | None, bool]:
     """
     Creates new chain that consists of the last version of current delta dataset
     plus diff from the source with all needed modifications.
@@ -244,7 +303,14 @@
         source_ds_version,
         source_ds_latest_version,
         dependencies,
-    ) = _get_source_info(name, namespace_name, project_name, latest_version, catalog)
+    ) = _get_source_info(
+        dc._query.starting_step.dataset,  # type: ignore[union-attr]
+        name,
+        namespace_name,
+        project_name,
+        latest_version,
+        catalog,
+    )
 
     # If source_ds_name is None, starting dataset was removed
     if source_ds_name is None:
@@ -267,8 +333,9 @@
     if dependencies:
         dependencies = copy(dependencies)
         dependencies = [d for d in dependencies if d is not None]
+        source_ds_dep = next(d for d in dependencies if d.name == source_ds_name)
         # Update to latest version
-        dependencies[0].version = source_ds_latest_version  # type: ignore[union-attr]
+        source_ds_dep.version = source_ds_latest_version  # type: ignore[union-attr]
 
     # Handle retry functionality if enabled
     if delta_retry:
@@ -288,7 +355,11 @@
 
     # Combine delta and retry chains
     if retry_chain is not None:
-        processing_chain = diff_chain.union(retry_chain)
+        processing_chain = _safe_union(
+            diff_chain,
+            retry_chain,
+            context="combining retry records with delta changes",
+        )
     else:
         processing_chain = diff_chain
 
@@ -312,5 +383,9 @@
         modified=False,
         deleted=False,
     )
-    result_chain = compared_chain.union(processing_chain)
+    result_chain = _safe_union(
+        compared_chain,
+        processing_chain,
+        context="merging the delta output with the existing dataset version",
+    )
     return result_chain, dependencies, True
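Both union sites in `delta_retry_update` now route through `_safe_union`, so a schema mismatch surfaces as a `SchemaDriftError` with the offending columns spelled out, rather than a bare `UnionSchemaMismatchError`. A hedged sketch of catching it; the `delta`/`delta_on` keywords, bucket URI, and UDF are illustrative placeholders:

```python
import datachain as dc
from datachain.error import SchemaDriftError

def describe(file: dc.File) -> str:  # placeholder UDF
    return file.path

try:
    (
        dc.read_storage("s3://my-bucket/data/", delta=True, delta_on="file.path")
        .map(description=describe)  # renaming this output between runs drifts the schema
        .save("my_dataset")
    )
except SchemaDriftError as exc:
    # e.g. "Delta update failed: schema drift detected while merging the
    # delta output with the existing dataset version: new columns
    # detected: description; columns missing in updated data: summary."
    print(exc)
```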
datachain/diff/__init__.py CHANGED
@@ -1,8 +1,6 @@
-import random
-import string
 from collections.abc import Sequence
 from enum import Enum
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING
 
 from datachain.func import case, ifelse, isnone, or_
 from datachain.lib.signal_schema import SignalSchema
@@ -11,16 +9,12 @@ from datachain.query.schema import Column
 if TYPE_CHECKING:
     from datachain.lib.dc import DataChain
 
-
 C = Column
 
 
-def get_status_col_name() -> str:
-    """Returns new unique status col name"""
-    return "diff_" + "".join(
-        random.choice(string.ascii_letters)  # noqa: S311
-        for _ in range(10)
-    )
+STATUS_COL_NAME = "diff_7aeed3aa17ba4d50b8d1c368c76e16a6"
+LEFT_DIFF_COL_NAME = "diff_95f95344064a4b819c8625cd1a5cfc2b"
+RIGHT_DIFF_COL_NAME = "diff_5808838a49b54849aa461d7387376d34"
 
 
 class CompareStatus(str, Enum):
@@ -30,25 +24,25 @@ class CompareStatus(str, Enum):
     SAME = "S"
 
 
-def _compare(  # noqa: C901, PLR0912
+def _compare(  # noqa: C901
     left: "DataChain",
     right: "DataChain",
-    on: Union[str, Sequence[str]],
-    right_on: Optional[Union[str, Sequence[str]]] = None,
-    compare: Optional[Union[str, Sequence[str]]] = None,
-    right_compare: Optional[Union[str, Sequence[str]]] = None,
+    on: str | Sequence[str],
+    right_on: str | Sequence[str] | None = None,
+    compare: str | Sequence[str] | None = None,
+    right_compare: str | Sequence[str] | None = None,
     added: bool = True,
     deleted: bool = True,
     modified: bool = True,
     same: bool = True,
-    status_col: Optional[str] = None,
+    status_col: str | None = None,
 ) -> "DataChain":
     """Comparing two chains by identifying rows that are added, deleted, modified
     or same"""
     rname = "right_"
     schema = left.signals_schema  # final chain must have schema from left chain
 
-    def _to_list(obj: Optional[Union[str, Sequence[str]]]) -> Optional[list[str]]:
+    def _to_list(obj: str | Sequence[str] | None) -> list[str] | None:
         if obj is None:
             return None
         return [obj] if isinstance(obj, str) else list(obj)
@@ -101,21 +95,23 @@ def _compare(  # noqa: C901, PLR0912
         compare = right_compare = [c for c in cols if c in right_cols and c not in on]  # type: ignore[misc]
 
     # get diff column names
-    diff_col = status_col or get_status_col_name()
-    ldiff_col = get_status_col_name()
-    rdiff_col = get_status_col_name()
+    diff_col = status_col or STATUS_COL_NAME
+    ldiff_col = LEFT_DIFF_COL_NAME
+    rdiff_col = RIGHT_DIFF_COL_NAME
 
     # adding helper diff columns, which will be removed after
     left = left.mutate(**{ldiff_col: 1})
     right = right.mutate(**{rdiff_col: 1})
 
-    if not compare:
+    if compare is None:
         modified_cond = True
+    elif len(compare) == 0:
+        modified_cond = False
     else:
         modified_cond = or_(  # type: ignore[assignment]
             *[
                 C(c) != (C(f"{rname}{rc}") if c == rc else C(rc))
-                for c, rc in zip(compare, right_compare)  # type: ignore[arg-type]
+                for c, rc in zip(compare, right_compare, strict=False)  # type: ignore[arg-type]
             ]
         )
 
@@ -139,7 +135,7 @@ def _compare(  # noqa: C901, PLR0912
                     C(f"{rname + l_on if on == right_on else r_on}"),
                     C(l_on),
                 )
-                for l_on, r_on in zip(on, right_on)  # type: ignore[arg-type]
+                for l_on, r_on in zip(on, right_on, strict=False)  # type: ignore[arg-type]
             }
         )
         .select_except(ldiff_col, rdiff_col)
@@ -157,11 +153,7 @@ def _compare(  # noqa: C901, PLR0912
     if status_col:
         cols_select.append(diff_col)
 
-    if not dc_diff._sys:
-        # TODO workaround when sys signal is not available in diff
-        dc_diff = dc_diff.settings(sys=True).select(*cols_select).settings(sys=False)
-    else:
-        dc_diff = dc_diff.select(*cols_select)
+    dc_diff = dc_diff.select(*cols_select)
 
     # final schema is schema from the left chain with status column added if needed
     dc_diff.signals_schema = (
@@ -174,10 +166,10 @@
 def compare_and_split(
     left: "DataChain",
     right: "DataChain",
-    on: Union[str, Sequence[str]],
-    right_on: Optional[Union[str, Sequence[str]]] = None,
-    compare: Optional[Union[str, Sequence[str]]] = None,
-    right_compare: Optional[Union[str, Sequence[str]]] = None,
+    on: str | Sequence[str],
+    right_on: str | Sequence[str] | None = None,
+    compare: str | Sequence[str] | None = None,
+    right_compare: str | Sequence[str] | None = None,
     added: bool = True,
     deleted: bool = True,
     modified: bool = True,
@@ -227,7 +219,7 @@ def compare_and_split(
     )
     ```
     """
-    status_col = get_status_col_name()
+    status_col = STATUS_COL_NAME
 
     res = _compare(
         left,
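Two behavioral notes fall out of these hunks: helper/status columns now use fixed, collision-resistant names instead of per-run random ones (so repeated runs produce identical intermediate schemas), and `compare=[]` is now distinguished from `compare=None` instead of both falling into the old falsy branch. A small sketch via the public `DataChain.compare` wrapper (wrapper name and defaults assumed from the surrounding code):

```python
import datachain as dc

left = dc.read_values(id=[1, 2, 3], value=["a", "b", "c"])
right = dc.read_values(id=[1, 2, 4], value=["a", "B", "d"])

# compare=None (default): any shared non-key column may mark a row modified.
# compare=["value"]: only `value` is inspected.
# compare=[]: modified_cond is now False, so key-matched rows are never "M".
diff = left.compare(right, on="id", compare=["value"], status_col="diff")
```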
datachain/error.py CHANGED
@@ -2,6 +2,10 @@ class DataChainError(RuntimeError):
     pass
 
 
+class SchemaDriftError(DataChainError):
+    pass
+
+
 class InvalidDatasetNameError(RuntimeError):
     pass
 
@@ -34,6 +38,14 @@ class ProjectCreateNotAllowedError(NotAllowedError):
     pass
 
 
+class ProjectDeleteNotAllowedError(NotAllowedError):
+    pass
+
+
+class NamespaceDeleteNotAllowedError(NotAllowedError):
+    pass
+
+
 class ProjectNotFoundError(NotFoundError):
     pass
 
@@ -89,3 +101,15 @@
 
 class OutdatedDatabaseSchemaError(DataChainError):
     pass
+
+
+class CheckpointNotFoundError(NotFoundError):
+    pass
+
+
+class JobNotFoundError(NotFoundError):
+    pass
+
+
+class JobAncestryDepthExceededError(DataChainError):
+    pass
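All six new exception classes slot into the existing hierarchy rather than adding new base classes, so handlers written against `DataChainError`, `NotFoundError`, or `NotAllowedError` keep working. A quick check, runnable against 0.39.0:

```python
from datachain import error

assert issubclass(error.SchemaDriftError, error.DataChainError)
assert issubclass(error.CheckpointNotFoundError, error.NotFoundError)
assert issubclass(error.JobNotFoundError, error.NotFoundError)
assert issubclass(error.JobAncestryDepthExceededError, error.DataChainError)
assert issubclass(error.ProjectDeleteNotAllowedError, error.NotAllowedError)
assert issubclass(error.NamespaceDeleteNotAllowedError, error.NotAllowedError)
```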
datachain/func/aggregate.py CHANGED
@@ -1,5 +1,3 @@
-from typing import Optional, Union
-
 from sqlalchemy import func as sa_func
 
 from datachain.query.schema import Column
@@ -8,7 +6,7 @@ from datachain.sql.functions import aggregate
 from .func import Func
 
 
-def count(col: Optional[Union[str, Column]] = None) -> Func:
+def count(col: str | Column | None = None) -> Func:
     """
     Returns a COUNT aggregate SQL function for the specified column.
 
@@ -44,7 +42,7 @@ def count(col: Optional[Union[str, Column]] = None) -> Func:
     )
 
 
-def sum(col: Union[str, Column]) -> Func:
+def sum(col: str | Column) -> Func:
     """
     Returns the SUM aggregate SQL function for the specified column.
 
@@ -74,7 +72,7 @@ def sum(col: Union[str, Column]) -> Func:
     return Func("sum", inner=sa_func.sum, cols=[col])
 
 
-def avg(col: Union[str, Column]) -> Func:
+def avg(col: str | Column) -> Func:
     """
     Returns the AVG aggregate SQL function for the specified column.
 
@@ -104,7 +102,7 @@ def avg(col: Union[str, Column]) -> Func:
     return Func("avg", inner=aggregate.avg, cols=[col], result_type=float)
 
 
-def min(col: Union[str, Column]) -> Func:
+def min(col: str | Column) -> Func:
     """
     Returns the MIN aggregate SQL function for the specified column.
 
@@ -134,7 +132,7 @@ def min(col: Union[str, Column]) -> Func:
     return Func("min", inner=sa_func.min, cols=[col])
 
 
-def max(col: Union[str, Column]) -> Func:
+def max(col: str | Column) -> Func:
     """
     Returns the MAX aggregate SQL function for the given column name.
 
@@ -164,7 +162,7 @@ def max(col: Union[str, Column]) -> Func:
     return Func("max", inner=sa_func.max, cols=[col])
 
 
-def any_value(col: Union[str, Column]) -> Func:
+def any_value(col: str | Column) -> Func:
     """
     Returns the ANY_VALUE aggregate SQL function for the given column name.
 
@@ -198,7 +196,7 @@ def any_value(col: Union[str, Column]) -> Func:
     return Func("any_value", inner=aggregate.any_value, cols=[col])
 
 
-def collect(col: Union[str, Column]) -> Func:
+def collect(col: str | Column) -> Func:
     """
     Returns the COLLECT aggregate SQL function for the given column name.
 
@@ -229,7 +227,7 @@ def collect(col: Union[str, Column]) -> Func:
     return Func("collect", inner=aggregate.collect, cols=[col], is_array=True)
 
 
-def concat(col: Union[str, Column], separator="") -> Func:
+def concat(col: str | Column, separator="") -> Func:
     """
     Returns the CONCAT aggregate SQL function for the given column name.
 
@@ -348,7 +346,7 @@ def dense_rank() -> Func:
     return Func("dense_rank", inner=sa_func.dense_rank, result_type=int, is_window=True)
 
 
-def first(col: Union[str, Column]) -> Func:
+def first(col: str | Column) -> Func:
     """
     Returns the FIRST_VALUE window function for SQL queries.
 
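The aggregate signatures move from `Optional[Union[...]]` to PEP 604 `str | Column` with no runtime change; plain strings and `Column` objects remain interchangeable. A quick sketch:

```python
import datachain as dc
from datachain import C
from datachain.func import avg, count

chain = dc.read_values(cat=["a", "a", "b"], n=[1, 2, 3])

# Both spellings of a column argument still work:
per_cat = chain.group_by(cnt=count(), mean=avg(C("n")), partition_by="cat")
```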
datachain/func/array.py CHANGED
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from typing import Any, Optional, Union
+from typing import Any
 
 from datachain.query.schema import Column
 from datachain.sql.functions import array
@@ -7,7 +7,7 @@ from datachain.sql.functions import array
 from .func import Func
 
 
-def cosine_distance(*args: Union[str, Column, Func, Sequence]) -> Func:
+def cosine_distance(*args: str | Column | Func | Sequence) -> Func:
     """
     Returns the cosine distance between two vectors.
 
@@ -62,7 +62,7 @@ def cosine_distance(*args: Union[str, Column, Func, Sequence]) -> Func:
     )
 
 
-def euclidean_distance(*args: Union[str, Column, Func, Sequence]) -> Func:
+def euclidean_distance(*args: str | Column | Func | Sequence) -> Func:
     """
     Returns the Euclidean distance between two vectors.
 
@@ -115,7 +115,7 @@ def euclidean_distance(*args: Union[str, Column, Func, Sequence]) -> Func:
     )
 
 
-def length(arg: Union[str, Column, Func, Sequence]) -> Func:
+def length(arg: str | Column | Func | Sequence) -> Func:
     """
     Returns the length of the array.
 
@@ -151,7 +151,7 @@ def length(arg: Union[str, Column, Func, Sequence]) -> Func:
     return Func("length", inner=array.length, cols=cols, args=args, result_type=int)
 
 
-def contains(arr: Union[str, Column, Func, Sequence], elem: Any) -> Func:
+def contains(arr: str | Column | Func | Sequence, elem: Any) -> Func:
     """
     Checks whether the array contains the specified element.
 
@@ -196,9 +196,9 @@ def contains(arr: Union[str, Column, Func, Sequence], elem: Any) -> Func:
 
 
 def slice(
-    arr: Union[str, Column, Func, Sequence],
+    arr: str | Column | Func | Sequence,
     offset: int,
-    length: Optional[int] = None,
+    length: int | None = None,
 ) -> Func:
     """
     Returns a slice of the array starting from the specified offset.
@@ -272,7 +272,7 @@
 
 
 def join(
-    arr: Union[str, Column, Func, Sequence],
+    arr: str | Column | Func | Sequence,
     sep: str = "",
 ) -> Func:
     """
@@ -322,7 +322,7 @@
     )
 
 
-def get_element(arg: Union[str, Column, Func, Sequence], index: int) -> Func:
+def get_element(arg: str | Column | Func | Sequence, index: int) -> Func:
     """
     Returns the element at the given index from the array.
     If the index is out of bounds, it returns None or columns default value.
@@ -359,8 +359,8 @@ def get_element(arg: Union[str, Column, Func, Sequence], index: int) -> Func:
             return str  # if the array is empty, return str as default type
         return None
 
-    cols: Optional[Union[str, Column, Func, Sequence]]
-    args: Union[str, Column, Func, Sequence, int]
+    cols: str | Column | Func | Sequence | None
+    args: str | Column | Func | Sequence | int
 
     if isinstance(arg, (str, Column, Func)):
         cols = [arg]
@@ -379,7 +379,7 @@
     )
 
 
-def sip_hash_64(arg: Union[str, Column, Func, Sequence]) -> Func:
+def sip_hash_64(arg: str | Column | Func | Sequence) -> Func:
     """
     Returns the SipHash-64 hash of the array.
 
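Same story for the array helpers: annotation-only changes, no new behavior. A short sketch exercising a few of them through `mutate`, with argument handling per the signatures above:

```python
import datachain as dc
from datachain.func import array

chain = dc.read_values(arr=[[1, 2, 3], [4, 5]])

stats = chain.mutate(
    n=array.length("arr"),             # array length: 3, 2
    has_two=array.contains("arr", 2),  # membership as 0/1: 1, 0
    head=array.get_element("arr", 0),  # first element: 1, 4
)
```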
datachain/func/base.py CHANGED
@@ -1,5 +1,6 @@
 from abc import ABCMeta, abstractmethod
-from typing import TYPE_CHECKING, Optional
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     from sqlalchemy import TableClause
@@ -12,12 +13,14 @@ class Function:
     __metaclass__ = ABCMeta
 
     name: str
+    cols: Sequence
+    args: Sequence
 
     @abstractmethod
     def get_column(
         self,
-        signals_schema: Optional["SignalSchema"] = None,
-        label: Optional[str] = None,
-        table: Optional["TableClause"] = None,
+        signals_schema: "SignalSchema | None" = None,
+        label: str | None = None,
+        table: "TableClause | None" = None,
     ) -> "Column":
         pass