datachain 0.6.1__py3-none-any.whl → 0.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic; see the package's registry page for more details.
- datachain/catalog/catalog.py +61 -219
- datachain/cli.py +136 -22
- datachain/client/fsspec.py +9 -0
- datachain/client/local.py +11 -32
- datachain/config.py +126 -51
- datachain/data_storage/schema.py +66 -33
- datachain/data_storage/sqlite.py +4 -4
- datachain/data_storage/warehouse.py +101 -125
- datachain/lib/arrow.py +2 -15
- datachain/lib/data_model.py +10 -2
- datachain/lib/dc.py +211 -52
- datachain/lib/func/__init__.py +20 -2
- datachain/lib/func/aggregate.py +319 -8
- datachain/lib/func/func.py +97 -9
- datachain/lib/listing.py +6 -21
- datachain/lib/listing_info.py +4 -0
- datachain/lib/signal_schema.py +8 -5
- datachain/lib/udf.py +3 -3
- datachain/lib/utils.py +30 -0
- datachain/listing.py +22 -48
- datachain/query/dataset.py +11 -3
- datachain/remote/studio.py +63 -14
- datachain/studio.py +129 -0
- datachain/utils.py +58 -0
- {datachain-0.6.1.dist-info → datachain-0.6.3.dist-info}/METADATA +7 -6
- {datachain-0.6.1.dist-info → datachain-0.6.3.dist-info}/RECORD +30 -29
- {datachain-0.6.1.dist-info → datachain-0.6.3.dist-info}/WHEEL +1 -1
- {datachain-0.6.1.dist-info → datachain-0.6.3.dist-info}/LICENSE +0 -0
- {datachain-0.6.1.dist-info → datachain-0.6.3.dist-info}/entry_points.txt +0 -0
- {datachain-0.6.1.dist-info → datachain-0.6.3.dist-info}/top_level.txt +0 -0
datachain/lib/func/aggregate.py
CHANGED
|
@@ -8,35 +8,346 @@ from .func import Func
|
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
def count(col: Optional[str] = None) -> Func:
    """
    Build a COUNT aggregate, optionally bound to a column.

    With no column, every row of each group is counted. With a column, the
    column is handed to SQL COUNT (standard SQL COUNT(col) skips NULLs).

    Args:
        col (str, optional): Column whose values are counted. When omitted,
            all rows are counted.

    Returns:
        Func: Function object describing the COUNT aggregate.

    Example:
        ```py
        dc.group_by(
            count=func.count(),
            partition_by="signal.category",
        )
        ```

    Notes:
        - The result column is always of type int.
    """
    return Func(name="count", inner=sa_func.count, col=col, result_type=int)
|
|
12
35
|
|
|
13
36
|
|
|
14
37
|
def sum(col: str) -> Func:
    """
    Build a SUM aggregate over a numeric column.

    Adds up every value of `col` within each group.

    Args:
        col (str): Name of the column to total.

    Returns:
        Func: Function object describing the SUM aggregate.

    Example:
        ```py
        dc.group_by(
            files_size=func.sum("file.size"),
            partition_by="signal.category",
        )
        ```

    Notes:
        - Intended for numeric columns.
        - No explicit result type is set, so the result column takes the
          type of the input column.
    """
    return Func(name="sum", inner=sa_func.sum, col=col)
|
|
16
63
|
|
|
17
64
|
|
|
18
65
|
def avg(col: str) -> Func:
    """
    Build an AVG aggregate (arithmetic mean) over a numeric column.

    Args:
        col (str): Name of the column to average.

    Returns:
        Func: Function object describing the AVG aggregate.

    Example:
        ```py
        dc.group_by(
            avg_file_size=func.avg("file.size"),
            partition_by="signal.category",
        )
        ```

    Notes:
        - Intended for numeric columns.
        - The result column is always of type float.
    """
    return Func(
        name="avg",
        inner=dc_func.aggregate.avg,
        col=col,
        result_type=float,
    )
|
|
20
91
|
|
|
21
92
|
|
|
22
93
|
def min(col: str) -> Func:
    """
    Build a MIN aggregate that picks the smallest value of a column.

    Args:
        col (str): Name of the column to search for its minimum.

    Returns:
        Func: Function object describing the MIN aggregate.

    Example:
        ```py
        dc.group_by(
            smallest_file=func.min("file.size"),
            partition_by="signal.category",
        )
        ```

    Notes:
        - Works with numeric, date and string columns.
        - No explicit result type is set, so the result column takes the
          type of the input column.
    """
    return Func(name="min", inner=sa_func.min, col=col)
|
|
24
119
|
|
|
25
120
|
|
|
26
121
|
def max(col: str) -> Func:
    """
    Returns the MAX aggregate SQL function for the given column name.

    The MAX function returns the largest value in the specified column.
    It can be used on both numeric and non-numeric columns to find the
    maximum value.

    Args:
        col (str): The name of the column for which to find the maximum value.

    Returns:
        Func: A Func object that represents the MAX aggregate function.

    Example:
        ```py
        dc.group_by(
            largest_file=func.max("file.size"),
            partition_by="signal.category",
        )
        ```

    Notes:
        - The `max` function can be used with numeric, date, and string columns.
        - Result column will have the same type as the input column.
    """
    return Func("max", inner=sa_func.max, col=col)
|
|
28
147
|
|
|
29
148
|
|
|
30
149
|
def any_value(col: str) -> Func:
    """
    Build an ANY_VALUE aggregate that yields one arbitrary value per group.

    Handy when any representative row value is acceptable and you do not
    care which one is chosen.

    Args:
        col (str): Column to take a representative value from.

    Returns:
        Func: Function object describing the ANY_VALUE aggregate.

    Example:
        ```py
        dc.group_by(
            file_example=func.any_value("file.name"),
            partition_by="signal.category",
        )
        ```

    Notes:
        - Usable with columns of any type.
        - The result column keeps the type of the input column.
        - Which value is returned is not guaranteed and may change between
          executions.
    """
    return Func(name="any_value", inner=dc_func.aggregate.any_value, col=col)
|
|
32
178
|
|
|
33
179
|
|
|
34
180
|
def collect(col: str) -> Func:
    """
    Build a COLLECT aggregate that gathers a column's values into an array.

    Useful for bundling every value of a group together for later
    processing.

    Args:
        col (str): Column whose values are collected.

    Returns:
        Func: Function object describing the COLLECT aggregate.

    Example:
        ```py
        dc.group_by(
            signals=func.collect("signal"),
            partition_by="signal.category",
        )
        ```

    Notes:
        - Usable with numeric and string columns.
        - The result column is an array (list) of the input column's type.
    """
    return Func(
        name="collect",
        inner=dc_func.aggregate.collect,
        col=col,
        is_array=True,
    )
|
|
36
207
|
|
|
37
208
|
|
|
38
209
|
def concat(col: str, separator="") -> Func:
    """
    Build a CONCAT aggregate that joins a column's values into one string.

    The values of each group are merged via the group-concat aggregate,
    with `separator` placed between consecutive values.

    Args:
        col (str): Column whose values are joined.
        separator (str, optional): Text inserted between values.
            Defaults to an empty string.

    Returns:
        Func: Function object describing the CONCAT aggregate.

    Example:
        ```py
        dc.group_by(
            files=func.concat("file.name", separator=", "),
            partition_by="signal.category",
        )
        ```

    Notes:
        - Intended for string columns.
        - The result column is always of type str.
    """

    def _group_concat(arg):
        # Func calls `inner` with only the column, so bind the separator here.
        return dc_func.aggregate.group_concat(arg, separator)

    return Func(name="concat", inner=_group_concat, col=col, result_type=str)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def row_number() -> Func:
    """
    Build a ROW_NUMBER window function.

    Numbers the rows of each window partition sequentially, starting at 1,
    following the window's ordering.

    Returns:
        Func: Window-function object; bind it with `.over(window)`.

    Example:
        ```py
        window = func.window(partition_by="signal.category", order_by="created_at")
        dc.mutate(
            row_number=func.row_number().over(window),
        )
        ```

    Notes:
        - The result column is always of type int.
        - Building the column without `.over(...)` raises DataChainParamsError.
    """
    return Func(
        name="row_number",
        inner=sa_func.row_number,
        result_type=int,
        is_window=True,
    )
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def rank() -> Func:
    """
    Build a RANK window function.

    Rows that tie in the ordering share a rank, and the following rank is
    skipped accordingly (two rows ranked 1 are followed by rank 3).

    Returns:
        Func: Window-function object; bind it with `.over(window)`.

    Example:
        ```py
        window = func.window(partition_by="signal.category", order_by="created_at")
        dc.mutate(
            rank=func.rank().over(window),
        )
        ```

    Notes:
        - The result column is always of type int.
        - Unlike ROW_NUMBER, tied rows receive the same rank.
    """
    return Func(name="rank", inner=sa_func.rank, result_type=int, is_window=True)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def dense_rank() -> Func:
    """
    Build a DENSE_RANK window function.

    Tied rows share a rank, but the next rank follows consecutively with
    no gap (two rows ranked 1 are followed by rank 2).

    Returns:
        Func: Window-function object; bind it with `.over(window)`.

    Example:
        ```py
        window = func.window(partition_by="signal.category", order_by="created_at")
        dc.mutate(
            dense_rank=func.dense_rank().over(window),
        )
        ```

    Notes:
        - The result column is always of type int.
        - Unlike RANK, no ranks are skipped after ties.
    """
    return Func(
        name="dense_rank",
        inner=sa_func.dense_rank,
        result_type=int,
        is_window=True,
    )
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def first(col: str) -> Func:
    """
    Build a FIRST_VALUE window function over a column.

    Yields the value of `col` from the first row of each window partition,
    according to the window's ordering.

    Args:
        col (str): Column to read the leading value from.

    Returns:
        Func: Window-function object; bind it with `.over(window)`.

    Example:
        ```py
        window = func.window(partition_by="signal.category", order_by="created_at")
        dc.mutate(
            first_file=func.first("file.name").over(window),
        )
        ```

    Notes:
        - The result of `first` reflects the first row in the window order.
        - The result column keeps the type of the input column.
    """
    return Func(name="first", inner=sa_func.first_value, col=col, is_window=True)
|
datachain/lib/func/func.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
1
2
|
from typing import TYPE_CHECKING, Callable, Optional
|
|
2
3
|
|
|
4
|
+
from sqlalchemy import desc
|
|
5
|
+
|
|
3
6
|
from datachain.lib.convert.python_to_sql import python_to_sql
|
|
4
|
-
from datachain.lib.utils import DataChainColumnError
|
|
7
|
+
from datachain.lib.utils import DataChainColumnError, DataChainParamsError
|
|
5
8
|
from datachain.query.schema import Column, ColumnMeta
|
|
6
9
|
|
|
7
10
|
if TYPE_CHECKING:
|
|
@@ -9,18 +12,89 @@ if TYPE_CHECKING:
|
|
|
9
12
|
from datachain.lib.signal_schema import SignalSchema
|
|
10
13
|
|
|
11
14
|
|
|
15
|
+
@dataclass
class Window:
    """Window specification (partitioning and ordering) for SQL window functions."""

    # Column the rows are partitioned by.
    partition_by: str
    # Column the rows are ordered by within each partition.
    order_by: str
    # When True, `order_by` is sorted in descending order.
    desc: bool = False
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def window(partition_by: str, order_by: str, desc: bool = False) -> Window:
    """
    Create a window specification for SQL window functions.

    The specification decides how rows are grouped (partitioned) and in
    which order they are processed by the window function it is bound to.

    Args:
        partition_by (str): Column whose equal values form one partition.
        order_by (str): Column that orders rows inside each partition.
        desc (bool, optional): Sort `order_by` descending when True.
            Defaults to False (ascending).

    Returns:
        Window: The window specification, with both column names converted
            via `ColumnMeta.to_db_name`.

    Example:
        ```py
        window = func.window(partition_by="signal.category", order_by="created_at")
        dc.mutate(
            row_number=func.row_number().over(window),
        )
        ```
    """
    db_partition = ColumnMeta.to_db_name(partition_by)
    db_order = ColumnMeta.to_db_name(order_by)
    return Window(partition_by=db_partition, order_by=db_order, desc=desc)
|
|
59
|
+
|
|
60
|
+
|
|
12
61
|
class Func:
|
|
62
|
+
"""Represents a function to be applied to a column in a SQL query."""
|
|
63
|
+
|
|
13
64
|
def __init__(
    self,
    name: str,
    inner: Callable,
    col: Optional[str] = None,
    result_type: Optional["DataType"] = None,
    is_array: bool = False,
    is_window: bool = False,
    window: Optional[Window] = None,
) -> None:
    """
    Describe a SQL function applied to an optional column.

    Args:
        name: Display name used by ``__str__`` (e.g. "count").
        inner: Callable that builds the SQL expression; it is called with the
            column when ``col`` is set, otherwise with no arguments.
        col: Input column name, or None for column-less functions.
        result_type: Explicit Python result type; when None the type is
            inferred from ``col``.
        is_array: Whether the result is a list of the column's type.
        is_window: Whether this is a window function requiring ``over()``.
        window: Window specification attached by ``over()``.
    """
    # What the function is and how it is built.
    self.name, self.inner = name, inner
    # Input column and result typing.
    self.col = col
    self.result_type = result_type
    self.is_array = is_array
    # Window-function configuration.
    self.is_window, self.window = is_window, window
|
|
81
|
+
|
|
82
|
+
def __str__(self) -> str:
    """Render as a call-like label, e.g. ``count()``."""
    return f"{self.name}()"
|
|
84
|
+
|
|
85
|
+
def over(self, window: Window) -> "Func":
    """
    Return a copy of this window function bound to a window specification.

    Args:
        window: Partitioning/ordering spec, typically from ``func.window()``.

    Returns:
        Func: A new Func with the same configuration plus ``window``.

    Raises:
        DataChainParamsError: If this function is not a window function.
    """
    if not self.is_window:
        raise DataChainParamsError(f"{self} doesn't support window (over())")

    # Keep the original name so str() and later error messages still identify
    # the underlying function (e.g. "row_number()") instead of "over()".
    return Func(
        self.name,
        self.inner,
        col=self.col,
        result_type=self.result_type,
        is_array=self.is_array,
        is_window=self.is_window,
        window=window,
    )
|
|
24
98
|
|
|
25
99
|
@property
|
|
26
100
|
def db_col(self) -> Optional[str]:
|
|
@@ -33,31 +107,45 @@ class Func:
|
|
|
33
107
|
return list[col_type] if self.is_array else col_type # type: ignore[valid-type]
|
|
34
108
|
|
|
35
109
|
def get_result_type(self, signals_schema: "SignalSchema") -> "DataType":
    """
    Resolve the Python type of the column this function produces.

    An explicit ``result_type`` takes precedence; otherwise the type is
    inferred from the input column via ``db_col_type``.

    Raises:
        DataChainColumnError: If no explicit result type is set and no
            column type can be inferred.
    """
    explicit_type = self.result_type
    if explicit_type:
        return explicit_type

    inferred_type = self.db_col_type(signals_schema)
    if not inferred_type:
        raise DataChainColumnError(
            str(self),
            "Column name is required to infer result type",
        )
    return inferred_type
|
|
48
120
|
|
|
49
121
|
def get_column(
|
|
50
122
|
self, signals_schema: "SignalSchema", label: Optional[str] = None
|
|
51
123
|
) -> Column:
|
|
124
|
+
col_type = self.get_result_type(signals_schema)
|
|
125
|
+
sql_type = python_to_sql(col_type)
|
|
126
|
+
|
|
52
127
|
if self.col:
|
|
53
|
-
|
|
54
|
-
print(label)
|
|
55
|
-
col_type = self.get_result_type(signals_schema)
|
|
56
|
-
col = Column(self.db_col, python_to_sql(col_type))
|
|
128
|
+
col = Column(self.db_col, sql_type)
|
|
57
129
|
func_col = self.inner(col)
|
|
58
130
|
else:
|
|
59
131
|
func_col = self.inner()
|
|
60
132
|
|
|
133
|
+
if self.is_window:
|
|
134
|
+
if not self.window:
|
|
135
|
+
raise DataChainParamsError(
|
|
136
|
+
f"Window function {self} requires over() clause with a window spec",
|
|
137
|
+
)
|
|
138
|
+
func_col = func_col.over(
|
|
139
|
+
partition_by=self.window.partition_by,
|
|
140
|
+
order_by=(
|
|
141
|
+
desc(self.window.order_by)
|
|
142
|
+
if self.window.desc
|
|
143
|
+
else self.window.order_by
|
|
144
|
+
),
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
func_col.type = sql_type
|
|
148
|
+
|
|
61
149
|
if label:
|
|
62
150
|
func_col = func_col.label(label)
|
|
63
151
|
|
datachain/lib/listing.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import posixpath
|
|
2
2
|
from collections.abc import Iterator
|
|
3
|
-
from datetime import datetime, timedelta, timezone
|
|
4
3
|
from typing import TYPE_CHECKING, Callable, Optional, TypeVar
|
|
5
4
|
|
|
6
5
|
from fsspec.asyn import get_loop
|
|
@@ -85,12 +84,13 @@ def parse_listing_uri(uri: str, cache, client_config) -> tuple[str, str, str]:
|
|
|
85
84
|
storage_uri, path = Client.parse_url(uri)
|
|
86
85
|
telemetry.log_param("client", client.PREFIX)
|
|
87
86
|
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
87
|
+
if uses_glob(path) or client.fs.isfile(uri):
|
|
88
|
+
lst_uri_path = posixpath.dirname(path)
|
|
89
|
+
else:
|
|
90
|
+
storage_uri, path = Client.parse_url(f'{uri.rstrip("/")}/')
|
|
91
|
+
lst_uri_path = path
|
|
92
92
|
|
|
93
|
-
lst_uri = f
|
|
93
|
+
lst_uri = f'{storage_uri}/{lst_uri_path.lstrip("/")}'
|
|
94
94
|
ds_name = (
|
|
95
95
|
f"{LISTING_PREFIX}{storage_uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}"
|
|
96
96
|
)
|
|
@@ -108,18 +108,3 @@ def listing_uri_from_name(dataset_name: str) -> str:
|
|
|
108
108
|
if not is_listing_dataset(dataset_name):
|
|
109
109
|
raise ValueError(f"Dataset {dataset_name} is not a listing")
|
|
110
110
|
return dataset_name.removeprefix(LISTING_PREFIX)
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
def is_listing_expired(created_at: datetime) -> bool:
|
|
114
|
-
"""Checks if listing has expired based on it's creation date"""
|
|
115
|
-
return datetime.now(timezone.utc) > created_at + timedelta(seconds=LISTING_TTL)
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
def is_listing_subset(ds1_name: str, ds2_name: str) -> bool:
|
|
119
|
-
"""
|
|
120
|
-
Checks if one listing contains another one by comparing corresponding dataset names
|
|
121
|
-
"""
|
|
122
|
-
assert ds1_name.endswith("/")
|
|
123
|
-
assert ds2_name.endswith("/")
|
|
124
|
-
|
|
125
|
-
return ds2_name.startswith(ds1_name)
|
datachain/lib/listing_info.py
CHANGED
|
@@ -30,3 +30,7 @@ class ListingInfo(DatasetInfo):
|
|
|
30
30
|
def last_inserted_at(self):
    """
    Placeholder for when data was last inserted into this listing.

    Raises:
        NotImplementedError: Always; the needed information is not stored yet.
    """
    # TODO we need to add updated_at to dataset version or explicit last_inserted_at
    raise NotImplementedError()
|
|
33
|
+
|
|
34
|
+
def contains(self, other_name: str) -> bool:
    """
    Check whether this listing covers the listing named ``other_name``.

    Listing dataset names are directory-like and normally end with "/". The
    comparison is done against the name terminated by "/". This prevents a
    name such as ``lst__s3://b/dir`` from matching the sibling
    ``lst__s3://b/dir2/`` by plain string prefix. An exactly equal name is
    always contained.
    """
    prefix = self.name if self.name.endswith("/") else f"{self.name}/"
    return other_name == self.name or other_name.startswith(prefix)
|
datachain/lib/signal_schema.py
CHANGED
|
@@ -20,6 +20,7 @@ from typing import ( # noqa: UP035
|
|
|
20
20
|
)
|
|
21
21
|
|
|
22
22
|
from pydantic import BaseModel, create_model
|
|
23
|
+
from sqlalchemy import ColumnElement
|
|
23
24
|
from typing_extensions import Literal as LiteralEx
|
|
24
25
|
|
|
25
26
|
from datachain.lib.convert.python_to_sql import python_to_sql
|
|
@@ -27,6 +28,7 @@ from datachain.lib.convert.sql_to_python import sql_to_python
|
|
|
27
28
|
from datachain.lib.convert.unflatten import unflatten_to_json_pos
|
|
28
29
|
from datachain.lib.data_model import DataModel, DataType, DataValue
|
|
29
30
|
from datachain.lib.file import File
|
|
31
|
+
from datachain.lib.func import Func
|
|
30
32
|
from datachain.lib.model_store import ModelStore
|
|
31
33
|
from datachain.lib.utils import DataChainParamsError
|
|
32
34
|
from datachain.query.schema import DEFAULT_DELIMITER, Column
|
|
@@ -490,13 +492,14 @@ class SignalSchema:
|
|
|
490
492
|
# renaming existing signal
|
|
491
493
|
del new_values[value.name]
|
|
492
494
|
new_values[name] = self.values[value.name]
|
|
493
|
-
elif
|
|
494
|
-
#
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
else:
|
|
495
|
+
elif isinstance(value, Func):
|
|
496
|
+
# adding new signal with function
|
|
497
|
+
new_values[name] = value.get_result_type(self)
|
|
498
|
+
elif isinstance(value, ColumnElement):
|
|
498
499
|
# adding new signal
|
|
499
500
|
new_values[name] = sql_to_python(value)
|
|
501
|
+
else:
|
|
502
|
+
new_values[name] = value
|
|
500
503
|
|
|
501
504
|
return SignalSchema(new_values)
|
|
502
505
|
|
datachain/lib/udf.py
CHANGED
|
@@ -11,7 +11,6 @@ from datachain.dataset import RowDict
|
|
|
11
11
|
from datachain.lib.convert.flatten import flatten
|
|
12
12
|
from datachain.lib.data_model import DataValue
|
|
13
13
|
from datachain.lib.file import File
|
|
14
|
-
from datachain.lib.signal_schema import SignalSchema
|
|
15
14
|
from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
|
|
16
15
|
from datachain.query.batch import (
|
|
17
16
|
Batch,
|
|
@@ -25,6 +24,7 @@ if TYPE_CHECKING:
|
|
|
25
24
|
from typing_extensions import Self
|
|
26
25
|
|
|
27
26
|
from datachain.catalog import Catalog
|
|
27
|
+
from datachain.lib.signal_schema import SignalSchema
|
|
28
28
|
from datachain.lib.udf_signature import UdfSignature
|
|
29
29
|
from datachain.query.batch import RowsOutput
|
|
30
30
|
|
|
@@ -172,7 +172,7 @@ class UDFBase(AbstractUDF):
|
|
|
172
172
|
def _init(
|
|
173
173
|
self,
|
|
174
174
|
sign: "UdfSignature",
|
|
175
|
-
params: SignalSchema,
|
|
175
|
+
params: "SignalSchema",
|
|
176
176
|
func: Optional[Callable],
|
|
177
177
|
):
|
|
178
178
|
self.params = params
|
|
@@ -183,7 +183,7 @@ class UDFBase(AbstractUDF):
|
|
|
183
183
|
def _create(
|
|
184
184
|
cls,
|
|
185
185
|
sign: "UdfSignature",
|
|
186
|
-
params: SignalSchema,
|
|
186
|
+
params: "SignalSchema",
|
|
187
187
|
) -> "Self":
|
|
188
188
|
if isinstance(sign.func, AbstractUDF):
|
|
189
189
|
if not isinstance(sign.func, cls): # type: ignore[unreachable]
|
datachain/lib/utils.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
|
+
import re
|
|
1
2
|
from abc import ABC, abstractmethod
|
|
3
|
+
from collections.abc import Sequence
|
|
2
4
|
|
|
3
5
|
|
|
4
6
|
class AbstractUDF(ABC):
|
|
@@ -28,3 +30,31 @@ class DataChainParamsError(DataChainError):
|
|
|
28
30
|
class DataChainColumnError(DataChainParamsError):
    """Parameter error tied to one column; the column name is put in the message."""

    def __init__(self, col_name, msg):
        message = f"Error for column {col_name}: {msg}"
        super().__init__(message)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
    """
    Map identifier-safe column names to the original column names.

    Each name is lowercased, and every run of characters other than ASCII
    lowercase letters or digits is replaced by "_". Leading and trailing
    underscores are then stripped.

    A numbered prefix is tried when the result:
    - is not a valid Python identifier,
    - repeats an already generated name, or
    - clashes with a *different* original column name.

    The prefix is ``c{N}_<name>``, or just ``c{N}`` when nothing is left. N is
    a single counter shared by all columns, so it keeps increasing.

    Returns:
        dict mapping generated name -> original name, in input order.
    """
    originals = set(col_names)
    result: dict[str, str] = {}
    counter = 0

    def _is_taken(candidate: str, source: str) -> bool:
        return (
            not candidate.isidentifier()
            or candidate in result
            or (candidate != source and candidate in originals)
        )

    for source in col_names:
        base = re.sub("[^0-9a-z]+", "_", source.lower()).strip("_")
        candidate = base
        while _is_taken(candidate, source):
            candidate = f"c{counter}_{base}" if base else f"c{counter}"
            counter += 1
        result[candidate] = source

    return result
|