datachain 0.18.4__py3-none-any.whl → 0.18.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

This release has been flagged as potentially problematic.


This version of datachain might be problematic; see the release's advisory page for more details.

@@ -911,11 +911,7 @@ class Catalog:
911
911
  values["num_objects"] = None
912
912
  values["size"] = None
913
913
  values["preview"] = None
914
- self.metastore.update_dataset_version(
915
- dataset,
916
- version,
917
- **values,
918
- )
914
+ self.metastore.update_dataset_version(dataset, version, **values)
919
915
  return
920
916
 
921
917
  if not dataset_version.num_objects:
@@ -935,11 +931,7 @@ class Catalog:
935
931
  if not values:
936
932
  return
937
933
 
938
- self.metastore.update_dataset_version(
939
- dataset,
940
- version,
941
- **values,
942
- )
934
+ self.metastore.update_dataset_version(dataset, version, **values)
943
935
 
944
936
  def update_dataset(
945
937
  self, dataset: DatasetRecord, conn=None, **kwargs
datachain/client/azure.py CHANGED
@@ -65,7 +65,7 @@ class AzureClient(Client):
65
65
  if entries:
66
66
  await result_queue.put(entries)
67
67
  pbar.update(len(entries))
68
- if not found:
68
+ if not found and prefix:
69
69
  raise FileNotFoundError(
70
70
  f"Unable to resolve remote path: {prefix}"
71
71
  )
datachain/client/gcs.py CHANGED
@@ -74,7 +74,7 @@ class GCSClient(Client):
74
74
  try:
75
75
  await self._get_pages(prefix, page_queue)
76
76
  found = await consumer
77
- if not found:
77
+ if not found and prefix:
78
78
  raise FileNotFoundError(f"Unable to resolve remote path: {prefix}")
79
79
  finally:
80
80
  consumer.cancel() # In case _get_pages() raised
datachain/client/s3.py CHANGED
@@ -80,7 +80,7 @@ class ClientS3(Client):
80
80
  finally:
81
81
  await page_queue.put(None)
82
82
 
83
- async def process_pages(page_queue, result_queue):
83
+ async def process_pages(page_queue, result_queue, prefix):
84
84
  found = False
85
85
  with tqdm(desc=f"Listing {self.uri}", unit=" objects", leave=False) as pbar:
86
86
  while (res := await page_queue.get()) is not None:
@@ -94,7 +94,7 @@ class ClientS3(Client):
94
94
  if entries:
95
95
  await result_queue.put(entries)
96
96
  pbar.update(len(entries))
97
- if not found:
97
+ if not found and prefix:
98
98
  raise FileNotFoundError(f"Unable to resolve remote path: {prefix}")
99
99
 
100
100
  try:
@@ -118,7 +118,9 @@ class ClientS3(Client):
118
118
  Delimiter="",
119
119
  )
120
120
  page_queue: asyncio.Queue[list] = asyncio.Queue(2)
121
- consumer = asyncio.create_task(process_pages(page_queue, result_queue))
121
+ consumer = asyncio.create_task(
122
+ process_pages(page_queue, result_queue, prefix)
123
+ )
122
124
  try:
123
125
  await get_pages(it, page_queue)
124
126
  await consumer
@@ -36,6 +36,7 @@ from datachain.dataset import (
36
36
  )
37
37
  from datachain.error import (
38
38
  DatasetNotFoundError,
39
+ DatasetVersionNotFoundError,
39
40
  TableMissingError,
40
41
  )
41
42
  from datachain.job import Job
@@ -273,7 +274,6 @@ class AbstractMetastore(ABC, Serializable):
273
274
  self,
274
275
  job_id: str,
275
276
  status: Optional[JobStatus] = None,
276
- exit_code: Optional[int] = None,
277
277
  error_message: Optional[str] = None,
278
278
  error_stack: Optional[str] = None,
279
279
  finished_at: Optional[datetime] = None,
@@ -620,22 +620,36 @@ class AbstractDBMetastore(AbstractMetastore):
620
620
  self, dataset: DatasetRecord, conn=None, **kwargs
621
621
  ) -> DatasetRecord:
622
622
  """Updates dataset fields."""
623
- values = {}
624
- dataset_values = {}
623
+ values: dict[str, Any] = {}
624
+ dataset_values: dict[str, Any] = {}
625
625
  for field, value in kwargs.items():
626
- if field in self._dataset_fields[1:]:
627
- if field in ["attrs", "schema"]:
628
- values[field] = json.dumps(value) if value else None
626
+ if field in ("id", "created_at") or field not in self._dataset_fields:
627
+ continue # these fields are read-only or not applicable
628
+
629
+ if value is None and field in ("name", "status", "sources", "query_script"):
630
+ raise ValueError(f"Field {field} cannot be None")
631
+ if field == "name" and not value:
632
+ raise ValueError("name cannot be empty")
633
+
634
+ if field == "attrs":
635
+ if value is None:
636
+ values[field] = None
629
637
  else:
630
- values[field] = value
631
- if field == "schema":
632
- dataset_values[field] = DatasetRecord.parse_schema(value)
638
+ values[field] = json.dumps(value)
639
+ dataset_values[field] = value
640
+ elif field == "schema":
641
+ if value is None:
642
+ values[field] = None
643
+ dataset_values[field] = None
633
644
  else:
634
- dataset_values[field] = value
645
+ values[field] = json.dumps(value)
646
+ dataset_values[field] = DatasetRecord.parse_schema(value)
647
+ else:
648
+ values[field] = value
649
+ dataset_values[field] = value
635
650
 
636
651
  if not values:
637
- # Nothing to update
638
- return dataset
652
+ return dataset # nothing to update
639
653
 
640
654
  d = self._datasets
641
655
  self.db.execute(
@@ -651,36 +665,70 @@ class AbstractDBMetastore(AbstractMetastore):
651
665
  self, dataset: DatasetRecord, version: str, conn=None, **kwargs
652
666
  ) -> DatasetVersion:
653
667
  """Updates dataset fields."""
654
- dataset_version = dataset.get_version(version)
655
-
656
- values = {}
657
- version_values: dict = {}
668
+ values: dict[str, Any] = {}
669
+ version_values: dict[str, Any] = {}
658
670
  for field, value in kwargs.items():
659
- if field in self._dataset_version_fields[1:]:
660
- if field == "schema":
661
- values[field] = json.dumps(value) if value else None
662
- version_values[field] = DatasetRecord.parse_schema(value)
663
- elif field == "feature_schema":
664
- values[field] = json.dumps(value) if value else None
665
- version_values[field] = value
666
- elif field == "preview" and isinstance(value, list):
667
- values[field] = json.dumps(value, cls=JSONSerialize)
668
- version_values[field] = value
671
+ if (
672
+ field in ("id", "created_at")
673
+ or field not in self._dataset_version_fields
674
+ ):
675
+ continue # these fields are read-only or not applicable
676
+
677
+ if value is None and field in (
678
+ "status",
679
+ "sources",
680
+ "query_script",
681
+ "error_message",
682
+ "error_stack",
683
+ "script_output",
684
+ "uuid",
685
+ ):
686
+ raise ValueError(f"Field {field} cannot be None")
687
+
688
+ if field == "schema":
689
+ values[field] = json.dumps(value) if value else None
690
+ version_values[field] = (
691
+ DatasetRecord.parse_schema(value) if value else None
692
+ )
693
+ elif field == "feature_schema":
694
+ if value is None:
695
+ values[field] = None
696
+ else:
697
+ values[field] = json.dumps(value)
698
+ version_values[field] = value
699
+ elif field == "preview":
700
+ if value is None:
701
+ values[field] = None
702
+ elif not isinstance(value, list):
703
+ raise ValueError(
704
+ f"Field '{field}' must be a list, got {type(value).__name__}"
705
+ )
669
706
  else:
670
- values[field] = value
671
- version_values[field] = value
707
+ values[field] = json.dumps(value, cls=JSONSerialize)
708
+ version_values["_preview_data"] = value
709
+ else:
710
+ values[field] = value
711
+ version_values[field] = value
672
712
 
673
- if values:
674
- dv = self._datasets_versions
675
- self.db.execute(
676
- self._datasets_versions_update()
677
- .where(dv.c.dataset_id == dataset.id, dv.c.version == version)
678
- .values(values),
679
- conn=conn,
680
- ) # type: ignore [attr-defined]
681
- dataset_version.update(**version_values)
713
+ if not values:
714
+ return dataset.get_version(version)
715
+
716
+ dv = self._datasets_versions
717
+ self.db.execute(
718
+ self._datasets_versions_update()
719
+ .where(dv.c.dataset_id == dataset.id, dv.c.version == version)
720
+ .values(values),
721
+ conn=conn,
722
+ ) # type: ignore [attr-defined]
723
+
724
+ for v in dataset.versions:
725
+ if v.version == version:
726
+ v.update(**version_values)
727
+ return v
682
728
 
683
- return dataset_version
729
+ raise DatasetVersionNotFoundError(
730
+ f"Dataset {dataset.name} does not have version {version}"
731
+ )
684
732
 
685
733
  def _parse_dataset(self, rows) -> Optional[DatasetRecord]:
686
734
  versions = [self.dataset_class.parse(*r) for r in rows]
@@ -812,7 +860,7 @@ class AbstractDBMetastore(AbstractMetastore):
812
860
  update_data["error_message"] = error_message
813
861
  update_data["error_stack"] = error_stack
814
862
 
815
- self.update_dataset(dataset, conn=conn, **update_data)
863
+ dataset = self.update_dataset(dataset, conn=conn, **update_data)
816
864
 
817
865
  if version:
818
866
  self.update_dataset_version(dataset, version, conn=conn, **update_data)
@@ -1064,7 +1112,6 @@ class AbstractDBMetastore(AbstractMetastore):
1064
1112
  self,
1065
1113
  job_id: str,
1066
1114
  status: Optional[JobStatus] = None,
1067
- exit_code: Optional[int] = None,
1068
1115
  error_message: Optional[str] = None,
1069
1116
  error_stack: Optional[str] = None,
1070
1117
  finished_at: Optional[datetime] = None,
@@ -1075,8 +1122,6 @@ class AbstractDBMetastore(AbstractMetastore):
1075
1122
  values: dict = {}
1076
1123
  if status is not None:
1077
1124
  values["status"] = status
1078
- if exit_code is not None:
1079
- values["exit_code"] = exit_code
1080
1125
  if error_message is not None:
1081
1126
  values["error_message"] = error_message
1082
1127
  if error_stack is not None:
@@ -1,78 +1,89 @@
1
- from typing import Optional
1
+ from typing import Optional, Union
2
2
 
3
3
  from sqlalchemy import func as sa_func
4
4
 
5
+ from datachain.query.schema import Column
5
6
  from datachain.sql.functions import aggregate
6
7
 
7
8
  from .func import Func
8
9
 
9
10
 
10
- def count(col: Optional[str] = None) -> Func:
11
+ def count(col: Optional[Union[str, Column]] = None) -> Func:
11
12
  """
12
- Returns the COUNT aggregate SQL function for the given column name.
13
+ Returns a COUNT aggregate SQL function for the specified column.
13
14
 
14
- The COUNT function returns the number of rows in a table.
15
+ The COUNT function returns the number of rows, optionally filtered
16
+ by a specific column.
15
17
 
16
18
  Args:
17
- col (str, optional): The name of the column for which to count rows.
18
- If not provided, it defaults to counting all rows.
19
+ col (str | Column, optional): The column to count.
20
+ If omitted, counts all rows.
21
+ The column can be specified as a string or a `Column` object.
19
22
 
20
23
  Returns:
21
- Func: A Func object that represents the COUNT aggregate function.
24
+ Func: A `Func` object representing the COUNT aggregate function.
22
25
 
23
26
  Example:
24
27
  ```py
25
28
  dc.group_by(
26
- count=func.count(),
29
+ count1=func.count(),
30
+ count2=func.count("signal.id"),
31
+ count3=func.count(dc.C("signal.category")),
27
32
  partition_by="signal.category",
28
33
  )
29
34
  ```
30
35
 
31
36
  Notes:
32
- - Result column will always be of type int.
37
+ - The result column will always have an integer type.
33
38
  """
34
39
  return Func(
35
- "count", inner=sa_func.count, cols=[col] if col else None, result_type=int
40
+ "count",
41
+ inner=sa_func.count,
42
+ cols=[col] if col is not None else None,
43
+ result_type=int,
36
44
  )
37
45
 
38
46
 
39
- def sum(col: str) -> Func:
47
+ def sum(col: Union[str, Column]) -> Func:
40
48
  """
41
- Returns the SUM aggregate SQL function for the given column name.
49
+ Returns the SUM aggregate SQL function for the specified column.
42
50
 
43
51
  The SUM function returns the total sum of a numeric column in a table.
44
52
  It sums up all the values for the specified column.
45
53
 
46
54
  Args:
47
- col (str): The name of the column for which to calculate the sum.
55
+ col (str | Column): The name of the column for which to calculate the sum.
56
+ The column can be specified as a string or a `Column` object.
48
57
 
49
58
  Returns:
50
- Func: A Func object that represents the SUM aggregate function.
59
+ Func: A `Func` object that represents the SUM aggregate function.
51
60
 
52
61
  Example:
53
62
  ```py
54
63
  dc.group_by(
55
64
  files_size=func.sum("file.size"),
65
+ total_size=func.sum(dc.C("size")),
56
66
  partition_by="signal.category",
57
67
  )
58
68
  ```
59
69
 
60
70
  Notes:
61
71
  - The `sum` function should be used on numeric columns.
62
- - Result column type will be the same as the input column type.
72
+ - The result column type will be the same as the input column type.
63
73
  """
64
74
  return Func("sum", inner=sa_func.sum, cols=[col])
65
75
 
66
76
 
67
- def avg(col: str) -> Func:
77
+ def avg(col: Union[str, Column]) -> Func:
68
78
  """
69
- Returns the AVG aggregate SQL function for the given column name.
79
+ Returns the AVG aggregate SQL function for the specified column.
70
80
 
71
81
  The AVG function returns the average of a numeric column in a table.
72
82
  It calculates the mean of all values in the specified column.
73
83
 
74
84
  Args:
75
- col (str): The name of the column for which to calculate the average.
85
+ col (str | Column): The name of the column for which to calculate the average.
86
+ Column can be specified as a string or a `Column` object.
76
87
 
77
88
  Returns:
78
89
  Func: A Func object that represents the AVG aggregate function.
@@ -81,26 +92,28 @@ def avg(col: str) -> Func:
81
92
  ```py
82
93
  dc.group_by(
83
94
  avg_file_size=func.avg("file.size"),
95
+ avg_signal_value=func.avg(dc.C("signal.value")),
84
96
  partition_by="signal.category",
85
97
  )
86
98
  ```
87
99
 
88
100
  Notes:
89
101
  - The `avg` function should be used on numeric columns.
90
- - Result column will always be of type float.
102
+ - The result column will always be of type float.
91
103
  """
92
104
  return Func("avg", inner=aggregate.avg, cols=[col], result_type=float)
93
105
 
94
106
 
95
- def min(col: str) -> Func:
107
+ def min(col: Union[str, Column]) -> Func:
96
108
  """
97
- Returns the MIN aggregate SQL function for the given column name.
109
+ Returns the MIN aggregate SQL function for the specified column.
98
110
 
99
111
  The MIN function returns the smallest value in the specified column.
100
112
  It can be used on both numeric and non-numeric columns to find the minimum value.
101
113
 
102
114
  Args:
103
- col (str): The name of the column for which to find the minimum value.
115
+ col (str | Column): The name of the column for which to find the minimum value.
116
+ Column can be specified as a string or a `Column` object.
104
117
 
105
118
  Returns:
106
119
  Func: A Func object that represents the MIN aggregate function.
@@ -109,18 +122,19 @@ def min(col: str) -> Func:
109
122
  ```py
110
123
  dc.group_by(
111
124
  smallest_file=func.min("file.size"),
125
+ min_signal=func.min(dc.C("signal")),
112
126
  partition_by="signal.category",
113
127
  )
114
128
  ```
115
129
 
116
130
  Notes:
117
131
  - The `min` function can be used with numeric, date, and string columns.
118
- - Result column will have the same type as the input column.
132
+ - The result column will have the same type as the input column.
119
133
  """
120
134
  return Func("min", inner=sa_func.min, cols=[col])
121
135
 
122
136
 
123
- def max(col: str) -> Func:
137
+ def max(col: Union[str, Column]) -> Func:
124
138
  """
125
139
  Returns the MAX aggregate SQL function for the given column name.
126
140
 
@@ -128,7 +142,8 @@ def max(col: str) -> Func:
128
142
  It can be used on both numeric and non-numeric columns to find the maximum value.
129
143
 
130
144
  Args:
131
- col (str): The name of the column for which to find the maximum value.
145
+ col (str | Column): The name of the column for which to find the maximum value.
146
+ Column can be specified as a string or a `Column` object.
132
147
 
133
148
  Returns:
134
149
  Func: A Func object that represents the MAX aggregate function.
@@ -137,18 +152,19 @@ def max(col: str) -> Func:
137
152
  ```py
138
153
  dc.group_by(
139
154
  largest_file=func.max("file.size"),
155
+ max_signal=func.max(dc.C("signal")),
140
156
  partition_by="signal.category",
141
157
  )
142
158
  ```
143
159
 
144
160
  Notes:
145
161
  - The `max` function can be used with numeric, date, and string columns.
146
- - Result column will have the same type as the input column.
162
+ - The result column will have the same type as the input column.
147
163
  """
148
164
  return Func("max", inner=sa_func.max, cols=[col])
149
165
 
150
166
 
151
- def any_value(col: str) -> Func:
167
+ def any_value(col: Union[str, Column]) -> Func:
152
168
  """
153
169
  Returns the ANY_VALUE aggregate SQL function for the given column name.
154
170
 
@@ -157,7 +173,9 @@ def any_value(col: str) -> Func:
157
173
  as long as it comes from one of the rows in the group.
158
174
 
159
175
  Args:
160
- col (str): The name of the column from which to return an arbitrary value.
176
+ col (str | Column): The name of the column from which to return
177
+ an arbitrary value.
178
+ Column can be specified as a string or a `Column` object.
161
179
 
162
180
  Returns:
163
181
  Func: A Func object that represents the ANY_VALUE aggregate function.
@@ -166,20 +184,21 @@ def any_value(col: str) -> Func:
166
184
  ```py
167
185
  dc.group_by(
168
186
  file_example=func.any_value("file.path"),
187
+ signal_example=func.any_value(dc.C("signal.value")),
169
188
  partition_by="signal.category",
170
189
  )
171
190
  ```
172
191
 
173
192
  Notes:
174
193
  - The `any_value` function can be used with any type of column.
175
- - Result column will have the same type as the input column.
194
+ - The result column will have the same type as the input column.
176
195
  - The result of `any_value` is non-deterministic,
177
196
  meaning it may return different values for different executions.
178
197
  """
179
198
  return Func("any_value", inner=aggregate.any_value, cols=[col])
180
199
 
181
200
 
182
- def collect(col: str) -> Func:
201
+ def collect(col: Union[str, Column]) -> Func:
183
202
  """
184
203
  Returns the COLLECT aggregate SQL function for the given column name.
185
204
 
@@ -188,7 +207,8 @@ def collect(col: str) -> Func:
188
207
  into a collection, often for further processing or aggregation.
189
208
 
190
209
  Args:
191
- col (str): The name of the column from which to collect values.
210
+ col (str | Column): The name of the column from which to collect values.
211
+ Column can be specified as a string or a `Column` object.
192
212
 
193
213
  Returns:
194
214
  Func: A Func object that represents the COLLECT aggregate function.
@@ -197,18 +217,19 @@ def collect(col: str) -> Func:
197
217
  ```py
198
218
  dc.group_by(
199
219
  signals=func.collect("signal"),
220
+ file_paths=func.collect(dc.C("file.path")),
200
221
  partition_by="signal.category",
201
222
  )
202
223
  ```
203
224
 
204
225
  Notes:
205
226
  - The `collect` function can be used with numeric and string columns.
206
- - Result column will have an array type.
227
+ - The result column will have an array type.
207
228
  """
208
229
  return Func("collect", inner=aggregate.collect, cols=[col], is_array=True)
209
230
 
210
231
 
211
- def concat(col: str, separator="") -> Func:
232
+ def concat(col: Union[str, Column], separator="") -> Func:
212
233
  """
213
234
  Returns the CONCAT aggregate SQL function for the given column name.
214
235
 
@@ -217,9 +238,10 @@ def concat(col: str, separator="") -> Func:
217
238
  into a single combined value.
218
239
 
219
240
  Args:
220
- col (str): The name of the column from which to concatenate values.
241
+ col (str | Column): The name of the column from which to concatenate values.
242
+ Column can be specified as a string or a `Column` object.
221
243
  separator (str, optional): The separator to use between concatenated values.
222
- Defaults to an empty string.
244
+ Defaults to an empty string.
223
245
 
224
246
  Returns:
225
247
  Func: A Func object that represents the CONCAT aggregate function.
@@ -228,13 +250,14 @@ def concat(col: str, separator="") -> Func:
228
250
  ```py
229
251
  dc.group_by(
230
252
  files=func.concat("file.path", separator=", "),
253
+ signals=func.concat(dc.C("signal.name"), separator=" | "),
231
254
  partition_by="signal.category",
232
255
  )
233
256
  ```
234
257
 
235
258
  Notes:
236
259
  - The `concat` function can be used with string columns.
237
- - Result column will have a string type.
260
+ - The result column will have a string type.
238
261
  """
239
262
 
240
263
  def inner(arg):
@@ -325,7 +348,7 @@ def dense_rank() -> Func:
325
348
  return Func("dense_rank", inner=sa_func.dense_rank, result_type=int, is_window=True)
326
349
 
327
350
 
328
- def first(col: str) -> Func:
351
+ def first(col: Union[str, Column]) -> Func:
329
352
  """
330
353
  Returns the FIRST_VALUE window function for SQL queries.
331
354
 
@@ -334,7 +357,9 @@ def first(col: str) -> Func:
334
357
  and can be useful for retrieving the leading value in a group of rows.
335
358
 
336
359
  Args:
337
- col (str): The name of the column from which to retrieve the first value.
360
+ col (str | Column): The name of the column from which to retrieve
361
+ the first value.
362
+ Column can be specified as a string or a `Column` object.
338
363
 
339
364
  Returns:
340
365
  Func: A Func object that represents the FIRST_VALUE window function.
@@ -344,6 +369,7 @@ def first(col: str) -> Func:
344
369
  window = func.window(partition_by="signal.category", order_by="created_at")
345
370
  dc.mutate(
346
371
  first_file=func.first("file.path").over(window),
372
+ first_signal=func.first(dc.C("signal.value")).over(window),
347
373
  )
348
374
  ```
349
375