datachain 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic. Click here for more details.

Files changed (38) hide show
  1. datachain/__init__.py +2 -0
  2. datachain/catalog/catalog.py +62 -228
  3. datachain/cli.py +136 -22
  4. datachain/client/fsspec.py +9 -0
  5. datachain/client/local.py +11 -32
  6. datachain/config.py +126 -51
  7. datachain/data_storage/schema.py +66 -33
  8. datachain/data_storage/sqlite.py +12 -4
  9. datachain/data_storage/warehouse.py +101 -129
  10. datachain/lib/convert/sql_to_python.py +8 -12
  11. datachain/lib/dc.py +275 -80
  12. datachain/lib/func/__init__.py +32 -0
  13. datachain/lib/func/aggregate.py +353 -0
  14. datachain/lib/func/func.py +152 -0
  15. datachain/lib/listing.py +6 -21
  16. datachain/lib/listing_info.py +4 -0
  17. datachain/lib/signal_schema.py +17 -8
  18. datachain/lib/udf.py +3 -3
  19. datachain/lib/utils.py +5 -0
  20. datachain/listing.py +22 -48
  21. datachain/query/__init__.py +1 -2
  22. datachain/query/batch.py +0 -1
  23. datachain/query/dataset.py +33 -46
  24. datachain/query/schema.py +1 -61
  25. datachain/query/session.py +33 -25
  26. datachain/remote/studio.py +63 -14
  27. datachain/sql/functions/__init__.py +1 -1
  28. datachain/sql/functions/aggregate.py +47 -0
  29. datachain/sql/functions/array.py +0 -8
  30. datachain/sql/sqlite/base.py +20 -2
  31. datachain/studio.py +129 -0
  32. datachain/utils.py +58 -0
  33. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/METADATA +7 -6
  34. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/RECORD +38 -33
  35. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/WHEEL +1 -1
  36. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/LICENSE +0 -0
  37. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/entry_points.txt +0 -0
  38. {datachain-0.6.0.dist-info → datachain-0.6.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,353 @@
1
+ from typing import Optional
2
+
3
+ from sqlalchemy import func as sa_func
4
+
5
+ from datachain.sql import functions as dc_func
6
+
7
+ from .func import Func
8
+
9
+
10
def count(col: Optional[str] = None) -> Func:
    """
    Build the COUNT aggregate function.

    COUNT returns the number of rows in a group. When `col` is omitted,
    every row is counted.

    Args:
        col (str, optional): Column whose rows are counted. Defaults to
            counting all rows.

    Returns:
        Func: A Func object describing the COUNT aggregate.

    Example:
        ```py
        dc.group_by(
            count=func.count(),
            partition_by="signal.category",
        )
        ```

    Notes:
        - The result column is always of type int.
    """
    return Func(name="count", inner=sa_func.count, col=col, result_type=int)
35
+
36
+
37
def sum(col: str) -> Func:
    """
    Build the SUM aggregate function.

    SUM adds up all values of a numeric column within each group.

    Args:
        col (str): Column to total.

    Returns:
        Func: A Func object describing the SUM aggregate.

    Example:
        ```py
        dc.group_by(
            files_size=func.sum("file.size"),
            partition_by="signal.category",
        )
        ```

    Notes:
        - Intended for numeric columns only.
        - The result column keeps the same type as the input column.
    """
    return Func(name="sum", inner=sa_func.sum, col=col)
63
+
64
+
65
def avg(col: str) -> Func:
    """
    Build the AVG aggregate function.

    AVG computes the arithmetic mean of a numeric column within each group.

    Args:
        col (str): Column to average.

    Returns:
        Func: A Func object describing the AVG aggregate.

    Example:
        ```py
        dc.group_by(
            avg_file_size=func.avg("file.size"),
            partition_by="signal.category",
        )
        ```

    Notes:
        - Intended for numeric columns only.
        - The result column is always of type float.
    """
    return Func(name="avg", inner=dc_func.aggregate.avg, col=col, result_type=float)
91
+
92
+
93
def min(col: str) -> Func:
    """
    Build the MIN aggregate function.

    MIN returns the smallest value found in the specified column; it works
    on numeric and non-numeric columns alike.

    Args:
        col (str): Column to scan for the minimum value.

    Returns:
        Func: A Func object describing the MIN aggregate.

    Example:
        ```py
        dc.group_by(
            smallest_file=func.min("file.size"),
            partition_by="signal.category",
        )
        ```

    Notes:
        - Usable with numeric, date, and string columns.
        - The result column keeps the same type as the input column.
    """
    return Func(name="min", inner=sa_func.min, col=col)
119
+
120
+
121
def max(col: str) -> Func:
    """
    Returns the MAX aggregate SQL function for the given column name.

    The MAX function returns the largest value in the specified column.
    It can be used on both numeric and non-numeric columns to find the maximum value.

    Args:
        col (str): The name of the column for which to find the maximum value.

    Returns:
        Func: A Func object that represents the MAX aggregate function.

    Example:
        ```py
        dc.group_by(
            largest_file=func.max("file.size"),
            partition_by="signal.category",
        )
        ```

    Notes:
        - The `max` function can be used with numeric, date, and string columns.
        - Result column will have the same type as the input column.
    """
    return Func("max", inner=sa_func.max, col=col)
147
+
148
+
149
def any_value(col: str) -> Func:
    """
    Build the ANY_VALUE aggregate function.

    ANY_VALUE picks an arbitrary value from the column within each group —
    useful when any representative value will do.

    Args:
        col (str): Column from which an arbitrary value is taken.

    Returns:
        Func: A Func object describing the ANY_VALUE aggregate.

    Example:
        ```py
        dc.group_by(
            file_example=func.any_value("file.name"),
            partition_by="signal.category",
        )
        ```

    Notes:
        - Works with columns of any type.
        - The result column keeps the same type as the input column.
        - The choice of row is non-deterministic and may differ between runs.
    """
    return Func(name="any_value", inner=dc_func.aggregate.any_value, col=col)
178
+
179
+
180
def collect(col: str) -> Func:
    """
    Build the COLLECT aggregate function.

    COLLECT gathers every value of the column within a group into an
    array-like collection for further processing or aggregation.

    Args:
        col (str): Column whose values are collected.

    Returns:
        Func: A Func object describing the COLLECT aggregate.

    Example:
        ```py
        dc.group_by(
            signals=func.collect("signal"),
            partition_by="signal.category",
        )
        ```

    Notes:
        - Usable with numeric and string columns.
        - The result column has an array type.
    """
    return Func(name="collect", inner=dc_func.aggregate.collect, col=col, is_array=True)
207
+
208
+
209
def concat(col: str, separator="") -> Func:
    """
    Build the CONCAT aggregate function.

    CONCAT joins all values of the column within a group into one string,
    inserting `separator` between consecutive values.

    Args:
        col (str): Column whose values are concatenated.
        separator (str, optional): String placed between values.
            Defaults to an empty string.

    Returns:
        Func: A Func object describing the CONCAT aggregate.

    Example:
        ```py
        dc.group_by(
            files=func.concat("file.name", separator=", "),
            partition_by="signal.category",
        )
        ```

    Notes:
        - Intended for string columns.
        - The result column is of string type.
    """
    # The separator is captured in the closure so the inner callable keeps
    # the single-argument shape Func expects.
    return Func(
        name="concat",
        inner=lambda arg: dc_func.aggregate.group_concat(arg, separator),
        col=col,
        result_type=str,
    )
242
+
243
+
244
def row_number() -> Func:
    """
    Build the ROW_NUMBER window function.

    ROW_NUMBER assigns each row a unique sequential integer within its
    partition, starting at 1 in the order given by the window spec.

    Returns:
        Func: A Func object describing the ROW_NUMBER window function.

    Example:
        ```py
        window = func.window(partition_by="signal.category", order_by="created_at")
        dc.mutate(
            row_number=func.row_number().over(window),
        )
        ```

    Note:
        - The result column is always of type int.
    """
    return Func(
        name="row_number",
        inner=sa_func.row_number,
        result_type=int,
        is_window=True,
    )
268
+
269
+
270
def rank() -> Func:
    """
    Build the RANK window function.

    RANK numbers rows within a partition, giving tied rows the same rank
    and leaving gaps afterwards (two rows ranked 1 are followed by rank 3).

    Returns:
        Func: A Func object describing the RANK window function.

    Example:
        ```py
        window = func.window(partition_by="signal.category", order_by="created_at")
        dc.mutate(
            rank=func.rank().over(window),
        )
        ```

    Notes:
        - The result column is always of type int.
        - Unlike ROW_NUMBER, rows equal under the ordering share one rank.
    """
    return Func(name="rank", inner=sa_func.rank, result_type=int, is_window=True)
296
+
297
+
298
def dense_rank() -> Func:
    """
    Build the DENSE_RANK window function.

    DENSE_RANK numbers rows within a partition like RANK, but without gaps:
    after two rows ranked 1, the next row is ranked 2.

    Returns:
        Func: A Func object describing the DENSE_RANK window function.

    Example:
        ```py
        window = func.window(partition_by="signal.category", order_by="created_at")
        dc.mutate(
            dense_rank=func.dense_rank().over(window),
        )
        ```

    Notes:
        - The result column is always of type int.
        - Differs from RANK only in that tied values leave no gaps.
    """
    return Func(
        name="dense_rank",
        inner=sa_func.dense_rank,
        result_type=int,
        is_window=True,
    )
324
+
325
+
326
def first(col: str) -> Func:
    """
    Build the FIRST_VALUE window function.

    FIRST_VALUE yields the value of the first row of each partition under
    the ordering given by the window spec.

    Args:
        col (str): Column from which the first value is taken.

    Returns:
        Func: A Func object describing the FIRST_VALUE window function.

    Example:
        ```py
        window = func.window(partition_by="signal.category", order_by="created_at")
        dc.mutate(
            first_file=func.first("file.name").over(window),
        )
        ```

    Note:
        - The result always reflects the first row in the specified order.
        - The result column keeps the same type as the input column.
    """
    return Func(name="first", inner=sa_func.first_value, col=col, is_window=True)
@@ -0,0 +1,152 @@
1
+ from dataclasses import dataclass
2
+ from typing import TYPE_CHECKING, Callable, Optional
3
+
4
+ from sqlalchemy import desc
5
+
6
+ from datachain.lib.convert.python_to_sql import python_to_sql
7
+ from datachain.lib.utils import DataChainColumnError, DataChainParamsError
8
+ from datachain.query.schema import Column, ColumnMeta
9
+
10
+ if TYPE_CHECKING:
11
+ from datachain import DataType
12
+ from datachain.lib.signal_schema import SignalSchema
13
+
14
+
15
@dataclass
class Window:
    """Represents a window specification for SQL window functions."""

    # Column (already converted to DB name form by the window() factory)
    # that groups rows into partitions.
    partition_by: str
    # Column (DB name form) that orders rows within each partition.
    order_by: str
    # When True, rows are ordered by `order_by` in descending order.
    desc: bool = False
22
+
23
+
24
def window(partition_by: str, order_by: str, desc: bool = False) -> Window:
    """
    Create a window specification for SQL window functions.

    The spec controls how rows are partitioned and ordered for the window
    function it is attached to via `over()`.

    Args:
        partition_by (str): Column whose equal values are grouped into one
            partition.
        order_by (str): Column that orders rows within each partition,
            determining the sequence the window function sees.
        desc (bool, optional): Order rows descending when True.
            Defaults to False (ascending).

    Returns:
        Window: The window specification.

    Example:
        ```py
        window = func.window(partition_by="signal.category", order_by="created_at")
        dc.mutate(
            row_number=func.row_number().over(window),
        )
        ```
    """
    # Normalize user-facing dotted names to their database column form.
    partition_col = ColumnMeta.to_db_name(partition_by)
    order_col = ColumnMeta.to_db_name(order_by)
    return Window(partition_col, order_col, desc)
59
+
60
+
61
class Func:
    """Represents a function to be applied to a column in a SQL query."""

    def __init__(
        self,
        name: str,
        inner: Callable,
        col: Optional[str] = None,
        result_type: Optional["DataType"] = None,
        is_array: bool = False,
        is_window: bool = False,
        window: Optional[Window] = None,
    ) -> None:
        """
        Args:
            name: Display name, used by __str__ and in error messages.
            inner: Callable building the underlying SQL expression; called
                with the column when `col` is set, otherwise with no args.
            col: Optional source column name (user-facing, dotted form).
            result_type: Explicit result type; when None the type is
                inferred from the column via the signal schema.
            is_array: True when the result is a list of the column's type
                (e.g. collect()).
            is_window: True for window functions that require over().
            window: Window spec, normally attached via over().
        """
        self.name = name
        self.inner = inner
        self.col = col
        self.result_type = result_type
        self.is_array = is_array
        self.is_window = is_window
        self.window = window

    def __str__(self) -> str:
        return self.name + "()"

    def over(self, window: Window) -> "Func":
        """Return a copy of this function with a window spec attached.

        Raises:
            DataChainParamsError: if this is not a window function.
        """
        if not self.is_window:
            raise DataChainParamsError(f"{self} doesn't support window (over())")

        # NOTE(review): the copy is named "over", so subsequent error
        # messages read "over()" rather than e.g. "row_number()" — confirm
        # this renaming is intentional.
        return Func(
            "over",
            self.inner,
            self.col,
            self.result_type,
            self.is_array,
            self.is_window,
            window,
        )

    @property
    def db_col(self) -> Optional[str]:
        # Column name converted to its database form; None when no column.
        return ColumnMeta.to_db_name(self.col) if self.col else None

    def db_col_type(self, signals_schema: "SignalSchema") -> Optional["DataType"]:
        """Type of the source column per the schema, wrapped in list[...]
        for array-producing functions; None when no column is set."""
        if not self.db_col:
            return None
        col_type: type = signals_schema.get_column_type(self.db_col)
        return list[col_type] if self.is_array else col_type  # type: ignore[valid-type]

    def get_result_type(self, signals_schema: "SignalSchema") -> "DataType":
        """Return the explicit result type, or infer it from the column.

        Raises:
            DataChainColumnError: when neither an explicit result type nor
                a source column is available.
        """
        if self.result_type:
            return self.result_type

        if col_type := self.db_col_type(signals_schema):
            return col_type

        raise DataChainColumnError(
            str(self),
            "Column name is required to infer result type",
        )

    def get_column(
        self, signals_schema: "SignalSchema", label: Optional[str] = None
    ) -> Column:
        """Build the SQL column expression for this function.

        Applies `inner` to the column (or calls it with no arguments),
        attaches the OVER clause for window functions, stamps the SQL type
        onto the expression, and optionally labels it.

        Raises:
            DataChainParamsError: for a window function with no window
                spec attached via over().
        """
        col_type = self.get_result_type(signals_schema)
        sql_type = python_to_sql(col_type)

        if self.col:
            col = Column(self.db_col, sql_type)
            func_col = self.inner(col)
        else:
            func_col = self.inner()

        if self.is_window:
            if not self.window:
                raise DataChainParamsError(
                    f"Window function {self} requires over() clause with a window spec",
                )
            func_col = func_col.over(
                partition_by=self.window.partition_by,
                order_by=(
                    desc(self.window.order_by)
                    if self.window.desc
                    else self.window.order_by
                ),
            )

        # Force the expression's reported type to match the inferred
        # result type (inner functions may not set it themselves).
        func_col.type = sql_type

        if label:
            func_col = func_col.label(label)

        return func_col
datachain/lib/listing.py CHANGED
@@ -1,6 +1,5 @@
1
1
  import posixpath
2
2
  from collections.abc import Iterator
3
- from datetime import datetime, timedelta, timezone
4
3
  from typing import TYPE_CHECKING, Callable, Optional, TypeVar
5
4
 
6
5
  from fsspec.asyn import get_loop
@@ -85,12 +84,13 @@ def parse_listing_uri(uri: str, cache, client_config) -> tuple[str, str, str]:
85
84
  storage_uri, path = Client.parse_url(uri)
86
85
  telemetry.log_param("client", client.PREFIX)
87
86
 
88
- # clean path without globs
89
- lst_uri_path = (
90
- posixpath.dirname(path) if uses_glob(path) or client.fs.isfile(uri) else path
91
- )
87
+ if uses_glob(path) or client.fs.isfile(uri):
88
+ lst_uri_path = posixpath.dirname(path)
89
+ else:
90
+ storage_uri, path = Client.parse_url(f'{uri.rstrip("/")}/')
91
+ lst_uri_path = path
92
92
 
93
- lst_uri = f"{storage_uri}/{lst_uri_path.lstrip('/')}"
93
+ lst_uri = f'{storage_uri}/{lst_uri_path.lstrip("/")}'
94
94
  ds_name = (
95
95
  f"{LISTING_PREFIX}{storage_uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}"
96
96
  )
@@ -108,18 +108,3 @@ def listing_uri_from_name(dataset_name: str) -> str:
108
108
  if not is_listing_dataset(dataset_name):
109
109
  raise ValueError(f"Dataset {dataset_name} is not a listing")
110
110
  return dataset_name.removeprefix(LISTING_PREFIX)
111
-
112
-
113
- def is_listing_expired(created_at: datetime) -> bool:
114
- """Checks if listing has expired based on it's creation date"""
115
- return datetime.now(timezone.utc) > created_at + timedelta(seconds=LISTING_TTL)
116
-
117
-
118
- def is_listing_subset(ds1_name: str, ds2_name: str) -> bool:
119
- """
120
- Checks if one listing contains another one by comparing corresponding dataset names
121
- """
122
- assert ds1_name.endswith("/")
123
- assert ds2_name.endswith("/")
124
-
125
- return ds2_name.startswith(ds1_name)
@@ -30,3 +30,7 @@ class ListingInfo(DatasetInfo):
30
30
  def last_inserted_at(self):
31
31
  # TODO we need to add updated_at to dataset version or explicit last_inserted_at
32
32
  raise NotImplementedError
33
+
34
    def contains(self, other_name: str) -> bool:
        """Checks if this listing contains another one"""
        # NOTE(review): plain prefix match — assumes both listing names end
        # with "/" (the helper this replaced asserted trailing slashes),
        # otherwise "ds1" would appear to contain "ds10"; confirm callers
        # always pass full listing dataset names.
        return other_name.startswith(self.name)
@@ -20,6 +20,7 @@ from typing import ( # noqa: UP035
20
20
  )
21
21
 
22
22
  from pydantic import BaseModel, create_model
23
+ from sqlalchemy import ColumnElement
23
24
  from typing_extensions import Literal as LiteralEx
24
25
 
25
26
  from datachain.lib.convert.python_to_sql import python_to_sql
@@ -27,6 +28,7 @@ from datachain.lib.convert.sql_to_python import sql_to_python
27
28
  from datachain.lib.convert.unflatten import unflatten_to_json_pos
28
29
  from datachain.lib.data_model import DataModel, DataType, DataValue
29
30
  from datachain.lib.file import File
31
+ from datachain.lib.func import Func
30
32
  from datachain.lib.model_store import ModelStore
31
33
  from datachain.lib.utils import DataChainParamsError
32
34
  from datachain.query.schema import DEFAULT_DELIMITER, Column
@@ -400,6 +402,12 @@ class SignalSchema:
400
402
  if ModelStore.is_pydantic(finfo.annotation):
401
403
  SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)
402
404
 
405
+ def get_column_type(self, col_name: str) -> DataType:
406
+ for path, _type, has_subtree, _ in self.get_flat_tree():
407
+ if not has_subtree and DEFAULT_DELIMITER.join(path) == col_name:
408
+ return _type
409
+ raise SignalResolvingError([col_name], "is not found")
410
+
403
411
  def db_signals(
404
412
  self, name: Optional[str] = None, as_columns=False
405
413
  ) -> Union[list[str], list[Column]]:
@@ -484,13 +492,14 @@ class SignalSchema:
484
492
  # renaming existing signal
485
493
  del new_values[value.name]
486
494
  new_values[name] = self.values[value.name]
487
- elif name in self.values:
488
- # changing the type of existing signal, e.g File -> ImageFile
489
- del new_values[name]
490
- new_values[name] = args_map[name]
491
- else:
495
+ elif isinstance(value, Func):
496
+ # adding new signal with function
497
+ new_values[name] = value.get_result_type(self)
498
+ elif isinstance(value, ColumnElement):
492
499
  # adding new signal
493
- new_values.update(sql_to_python({name: value}))
500
+ new_values[name] = sql_to_python(value)
501
+ else:
502
+ new_values[name] = value
494
503
 
495
504
  return SignalSchema(new_values)
496
505
 
@@ -534,12 +543,12 @@ class SignalSchema:
534
543
  for name, val in values.items()
535
544
  }
536
545
 
537
- def get_flat_tree(self) -> Iterator[tuple[list[str], type, bool, int]]:
546
+ def get_flat_tree(self) -> Iterator[tuple[list[str], DataType, bool, int]]:
538
547
  yield from self._get_flat_tree(self.tree, [], 0)
539
548
 
540
549
  def _get_flat_tree(
541
550
  self, tree: dict, prefix: list[str], depth: int
542
- ) -> Iterator[tuple[list[str], type, bool, int]]:
551
+ ) -> Iterator[tuple[list[str], DataType, bool, int]]:
543
552
  for name, (type_, substree) in tree.items():
544
553
  suffix = name.split(".")
545
554
  new_prefix = prefix + suffix
datachain/lib/udf.py CHANGED
@@ -11,7 +11,6 @@ from datachain.dataset import RowDict
11
11
  from datachain.lib.convert.flatten import flatten
12
12
  from datachain.lib.data_model import DataValue
13
13
  from datachain.lib.file import File
14
- from datachain.lib.signal_schema import SignalSchema
15
14
  from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
16
15
  from datachain.query.batch import (
17
16
  Batch,
@@ -25,6 +24,7 @@ if TYPE_CHECKING:
25
24
  from typing_extensions import Self
26
25
 
27
26
  from datachain.catalog import Catalog
27
+ from datachain.lib.signal_schema import SignalSchema
28
28
  from datachain.lib.udf_signature import UdfSignature
29
29
  from datachain.query.batch import RowsOutput
30
30
 
@@ -172,7 +172,7 @@ class UDFBase(AbstractUDF):
172
172
  def _init(
173
173
  self,
174
174
  sign: "UdfSignature",
175
- params: SignalSchema,
175
+ params: "SignalSchema",
176
176
  func: Optional[Callable],
177
177
  ):
178
178
  self.params = params
@@ -183,7 +183,7 @@ class UDFBase(AbstractUDF):
183
183
  def _create(
184
184
  cls,
185
185
  sign: "UdfSignature",
186
- params: SignalSchema,
186
+ params: "SignalSchema",
187
187
  ) -> "Self":
188
188
  if isinstance(sign.func, AbstractUDF):
189
189
  if not isinstance(sign.func, cls): # type: ignore[unreachable]
datachain/lib/utils.py CHANGED
@@ -23,3 +23,8 @@ class DataChainError(Exception):
23
23
  class DataChainParamsError(DataChainError):
24
24
  def __init__(self, message):
25
25
  super().__init__(message)
26
+
27
+
28
class DataChainColumnError(DataChainParamsError):
    """Raised for errors tied to a specific column."""

    def __init__(self, col_name, msg):
        message = f"Error for column {col_name}: {msg}"
        super().__init__(message)