datachain 0.25.2__py3-none-any.whl → 0.26.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -15,6 +15,7 @@ from typing import (
     Optional,
     TypeVar,
     Union,
+    cast,
     overload,
 )
 
@@ -32,21 +33,28 @@ from datachain.func import literal
 from datachain.func.base import Function
 from datachain.func.func import Func
 from datachain.lib.convert.python_to_sql import python_to_sql
-from datachain.lib.data_model import DataModel, DataType, DataValue, dict_to_data_model
+from datachain.lib.data_model import (
+    DataModel,
+    DataType,
+    DataValue,
+    StandardType,
+    dict_to_data_model,
+)
 from datachain.lib.file import (
     EXPORT_FILES_MAX_THREADS,
     ArrowRow,
     FileExporter,
 )
 from datachain.lib.file import ExportPlacement as FileExportPlacement
+from datachain.lib.model_store import ModelStore
 from datachain.lib.settings import Settings
-from datachain.lib.signal_schema import SignalSchema
+from datachain.lib.signal_schema import SignalResolvingError, SignalSchema
 from datachain.lib.udf import Aggregator, BatchMapper, Generator, Mapper, UDFBase
 from datachain.lib.udf_signature import UdfSignature
 from datachain.lib.utils import DataChainColumnError, DataChainParamsError
 from datachain.query import Session
 from datachain.query.dataset import DatasetQuery, PartitionByType
-from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
+from datachain.query.schema import DEFAULT_DELIMITER, Column
 from datachain.sql.functions import path as pathfunc
 from datachain.utils import batched_it, inside_notebook, row_to_nested_dict
 
@@ -358,14 +366,6 @@ class DataChain:
         self._settings = settings if settings else Settings()
         return self
 
-    def reset_schema(self, signals_schema: SignalSchema) -> "Self":
-        self.signals_schema = signals_schema
-        return self
-
-    def add_schema(self, signals_schema: SignalSchema) -> "Self":
-        self.signals_schema |= signals_schema
-        return self
-
     @classmethod
     def from_storage(
         cls,
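Note: `reset_schema` and `add_schema` are dropped in 0.26.1 with no replacement visible in this diff. A hedged upgrade check, assuming `chain` is any `DataChain` built with this release (`read_values` is the library's public constructor):

```py
import datachain as dc

chain = dc.read_values(num=[1, 2, 3])

# Both schema-mutation helpers are gone from the public surface in 0.26.1;
# callers that relied on them must be updated.
assert not hasattr(chain, "reset_schema")
assert not hasattr(chain, "add_schema")
```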
@@ -758,11 +758,12 @@ class DataChain:
     @delta_disabled
     def agg(
         self,
+        /,
         func: Optional[Callable] = None,
         partition_by: Optional[PartitionByType] = None,
         params: Union[None, str, Sequence[str]] = None,
         output: OutputType = None,
-        **signal_map,
+        **signal_map: Callable,
     ) -> "Self":
         """Aggregate rows using `partition_by` statement and apply a function to the
         groups of aggregated rows. The function needs to return new objects for each
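The new `/` makes `self` positional-only, so any keyword argument, even one literally named `self`, can be captured by `**signal_map: Callable` as an output signal name. A minimal self-contained sketch of that pattern in isolation (the `Demo` class is hypothetical, not datachain code):

```py
from typing import Callable


class Demo:
    def agg(self, /, **signal_map: Callable) -> dict:
        # With positional-only `self`, a keyword named "self" lands in
        # signal_map instead of clashing with the bound instance.
        return dict(signal_map)


d = Demo()
print(d.agg(self=lambda xs: sum(xs)))  # {'self': <function <lambda> ...>}
```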
@@ -772,12 +773,28 @@ class DataChain:
 
         This method bears similarity to `gen()` and `map()`, employing a comparable set
         of parameters, yet differs in two crucial aspects:
+
         1. The `partition_by` parameter: This specifies the column name or a list of
            column names that determine the grouping criteria for aggregation.
         2. Group-based UDF function input: Instead of individual rows, the function
-           receives a list all rows within each group defined by `partition_by`.
+           receives a list of all rows within each group defined by `partition_by`.
+
+        If `partition_by` is not set or is an empty list, all rows will be placed
+        into a single group.
+
+        Parameters:
+            func: Function applied to each group of rows.
+            partition_by: Column name(s) to group by. If None, all rows go
+                into one group.
+            params: List of column names used as input for the function. Default is
+                taken from function signature.
+            output: Dictionary defining new signals and their corresponding types.
+                Default type is taken from function signature.
+            **signal_map: kwargs can be used to define `func` together with its return
+                signal name in format of `agg(result_column=my_func)`.
 
         Examples:
+            Basic aggregation with lambda function:
             ```py
             chain = chain.agg(
                 total=lambda category, amount: [sum(amount)],
@@ -788,7 +805,6 @@ class DataChain:
             ```
 
             An alternative syntax, when you need to specify a more complex function:
-
             ```py
             # It automatically resolves which columns to pass to the function
             # by looking at the function signature.
@@ -806,10 +822,43 @@ class DataChain:
             )
             chain.save("new_dataset")
             ```
+
+            Using complex signals for partitioning (`File` or any Pydantic `BaseModel`):
+            ```py
+            def my_agg(files: list[File]) -> Iterator[tuple[File, int]]:
+                yield files[0], sum(f.size for f in files)
+
+            chain = chain.agg(
+                my_agg,
+                params=("file",),
+                output={"file": File, "total": int},
+                partition_by="file",  # Column referring to all sub-columns of File
+            )
+            chain.save("new_dataset")
+            ```
+
+            Aggregating all rows into a single group (when `partition_by` is not set):
+            ```py
+            chain = chain.agg(
+                total_size=lambda file, size: [sum(size)],
+                output=int,
+                # No partition_by specified - all rows go into one group
+            )
+            chain.save("new_dataset")
+            ```
+
+            Multiple partition columns:
+            ```py
+            chain = chain.agg(
+                total=lambda category, subcategory, amount: [sum(amount)],
+                output=float,
+                partition_by=["category", "subcategory"],
+            )
+            chain.save("new_dataset")
+            ```
         """
-        # Convert string partition_by parameters to Column objects
-        processed_partition_by = partition_by
         if partition_by is not None:
+            # Convert string partition_by parameters to Column objects
             if isinstance(partition_by, (str, Function, ColumnElement)):
                 list_partition_by = [partition_by]
             else:
@@ -818,10 +867,10 @@ class DataChain:
             processed_partition_columns: list[ColumnElement] = []
             for col in list_partition_by:
                 if isinstance(col, str):
-                    col_db_name = ColumnMeta.to_db_name(col)
-                    col_type = self.signals_schema.get_column_type(col_db_name)
-                    column = Column(col_db_name, python_to_sql(col_type))
-                    processed_partition_columns.append(column)
+                    columns = self.signals_schema.db_signals(name=col, as_columns=True)
+                    if not columns:
+                        raise SignalResolvingError([col], "is not found")
+                    processed_partition_columns.extend(cast("list[Column]", columns))
                 elif isinstance(col, Function):
                     column = col.get_column(self.signals_schema)
                     processed_partition_columns.append(column)
@@ -830,6 +879,8 @@ class DataChain:
                     processed_partition_columns.append(col)
 
             processed_partition_by = processed_partition_columns
+        else:
+            processed_partition_by = []
 
         udf_obj = self._udf_to_obj(Aggregator, func, params, output, signal_map)
         return self._evolve(
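String partition columns are now resolved through `SignalSchema.db_signals(name=..., as_columns=True)`, so a name missing from the schema fails fast with `SignalResolvingError` rather than a lower-level lookup error. A hedged sketch, assuming `chain` has an `amount` signal and the illustrative bad name below:

```py
from datachain.lib.signal_schema import SignalResolvingError

try:
    chain = chain.agg(
        total=lambda amount: [sum(amount)],
        output=float,
        partition_by="no_such_signal",  # not present in the schema
    )
except SignalResolvingError as exc:
    print(f"partitioning failed: {exc}")
```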
@@ -905,7 +956,7 @@ class DataChain:
         query_func = getattr(self._query, method_name)
 
         new_schema = self.signals_schema.resolve(*args)
-        columns = [C(col) for col in new_schema.db_signals()]
+        columns = new_schema.db_signals(as_columns=True)
         return query_func(*columns, **kwargs)
 
     @resolve_columns
@@ -969,7 +1020,7 @@ class DataChain:
         )
 
     @delta_disabled  # type: ignore[arg-type]
-    def group_by(
+    def group_by(  # noqa: C901, PLR0912
         self,
         *,
         partition_by: Optional[Union[str, Func, Sequence[Union[str, Func]]]] = None,
@@ -988,6 +1039,15 @@ class DataChain:
                 partition_by=("file_source", "file_ext"),
             )
             ```
+
+            Using complex signals:
+            ```py
+            chain = chain.group_by(
+                total_size=func.sum("file.size"),
+                count=func.count(),
+                partition_by="file",  # Uses column name, expands to File's unique keys
+            )
+            ```
         """
         if partition_by is None:
             partition_by = []
@@ -998,20 +1058,61 @@ class DataChain:
         signal_columns: list[Column] = []
         schema_fields: dict[str, DataType] = {}
         keep_columns: list[str] = []
+        partial_fields: list[str] = []  # Track specific fields for partial creation
+        schema_partition_by: list[str] = []
 
-        # validate partition_by columns and add them to the schema
         for col in partition_by:
             if isinstance(col, str):
-                col_db_name = ColumnMeta.to_db_name(col)
-                col_type = self.signals_schema.get_column_type(col_db_name)
-                column = Column(col_db_name, python_to_sql(col_type))
-                if col not in keep_columns:
-                    keep_columns.append(col)
+                columns = self.signals_schema.db_signals(name=col, as_columns=True)
+                if not columns:
+                    raise SignalResolvingError([col], "is not found")
+                partition_by_columns.extend(cast("list[Column]", columns))
+
+                # For nested field references (e.g., "nested.level1.name"),
+                # we need to distinguish between:
+                # 1. References to fields within a complex signal (create partials)
+                # 2. Deep nested references that should be flattened
+                if "." in col:
+                    # Split the column reference to analyze it
+                    parts = col.split(".")
+                    parent_signal = parts[0]
+                    parent_type = self.signals_schema.values.get(parent_signal)
+
+                    if ModelStore.is_partial(parent_type):
+                        if parent_signal not in keep_columns:
+                            keep_columns.append(parent_signal)
+                        partial_fields.append(col)
+                        schema_partition_by.append(col)
+                    else:
+                        # BaseModel or other - add flattened columns directly
+                        for column in cast("list[Column]", columns):
+                            col_type = self.signals_schema.get_column_type(column.name)
+                            schema_fields[column.name] = col_type
+                        schema_partition_by.append(col)
+                else:
+                    # simple signal - but we need to check if it's a complex signal
+                    # complex signal - only include the columns used for partitioning
+                    col_type = self.signals_schema.get_column_type(
+                        col, with_subtree=True
+                    )
+                    if isinstance(col_type, type) and issubclass(col_type, BaseModel):
+                        # Complex signal - add only the partitioning columns
+                        for column in cast("list[Column]", columns):
+                            col_type = self.signals_schema.get_column_type(column.name)
+                            schema_fields[column.name] = col_type
+                        schema_partition_by.append(col)
+                    # Simple signal - keep the entire signal
+                    else:
+                        if col not in keep_columns:
+                            keep_columns.append(col)
+                        schema_partition_by.append(col)
             elif isinstance(col, Function):
                 column = col.get_column(self.signals_schema)
                 col_db_name = column.name
                 col_type = column.type.python_type
                 schema_fields[col_db_name] = col_type
+                partition_by_columns.append(column)
+                signal_columns.append(column)
             else:
                 raise DataChainColumnError(
                     col,
@@ -1020,9 +1121,7 @@ class DataChain:
                         " but expected str or Function"
                     ),
                 )
-            partition_by_columns.append(column)
 
-        # validate signal columns and add them to the schema
         if not kwargs:
             raise ValueError("At least one column should be provided for group_by")
         for col_name, func in kwargs.items():
@@ -1035,9 +1134,9 @@ class DataChain:
             signal_columns.append(column)
             schema_fields[col_name] = func.get_result_type(self.signals_schema)
 
-        signal_schema = SignalSchema(schema_fields)
-        if keep_columns:
-            signal_schema |= self.signals_schema.to_partial(*keep_columns)
+        signal_schema = self.signals_schema.group_by(
+            schema_partition_by, signal_columns
+        )
 
         return self._evolve(
             query=self._query.group_by(signal_columns, partition_by_columns),
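The result schema is now produced by `SignalSchema.group_by(...)` from the collected `schema_partition_by` names and aggregate columns, instead of being assembled ad hoc from partials. A hedged usage sketch, assuming a chain with a `File` signal named `file` and the `func` import path used in the library's documented examples:

```py
from datachain import func

grouped = chain.group_by(
    total_size=func.sum("file.size"),
    count=func.count(),
    partition_by="file.path",  # nested member of a complex signal
)
grouped.show()
```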
@@ -1166,6 +1265,7 @@ class DataChain:
         db_signals = self._effective_signals_schema.db_signals(
             include_hidden=include_hidden
         )
+
         with self._query.ordered_select(*db_signals).as_iterable() as rows:
             if row_factory:
                 rows = (row_factory(db_signals, r) for r in rows)  # type: ignore[assignment]
@@ -1343,10 +1443,6 @@ class DataChain:
             remove_prefetched=remove_prefetched,
         )
 
-    def remove_file_signals(self) -> "Self":
-        schema = self.signals_schema.clone_without_file_signals()
-        return self.select(*schema.values.keys())
-
     @delta_disabled
     def merge(
         self,
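`remove_file_signals` is removed as well. A hedged approximation of its effect using the surviving API; this assumes only top-level `File`-typed signals matter, whereas the old `clone_without_file_signals` may have covered more cases:

```py
from datachain.lib.file import File

# Keep every top-level signal whose type is not a File subclass.
keep = [
    name
    for name, typ in chain.signals_schema.values.items()
    if not (isinstance(typ, type) and issubclass(typ, File))
]
chain = chain.select(*keep)
```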
@@ -1701,12 +1797,19 @@ class DataChain:
         )
         return read_pandas(*args, **kwargs)
 
-    def to_pandas(self, flatten=False, include_hidden=True) -> "pd.DataFrame":
+    def to_pandas(
+        self,
+        flatten: bool = False,
+        include_hidden: bool = True,
+    ) -> "pd.DataFrame":
         """Return a pandas DataFrame from the chain.
 
         Parameters:
-            flatten : Whether to use a multiindex or flatten column names.
-            include_hidden : Whether to include hidden columns.
+            flatten: Whether to use a multiindex or flatten column names.
+            include_hidden: Whether to include hidden columns.
+
+        Returns:
+            pd.DataFrame: A pandas DataFrame representation of the chain.
         """
         import pandas as pd
 
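Only annotations and docstring formatting change here; behavior is the same. For reference:

```py
df = chain.to_pandas(flatten=True)  # single-level column names
df_mi = chain.to_pandas()           # MultiIndex columns (the default)
```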
@@ -1724,19 +1827,19 @@ class DataChain:
     def show(
         self,
         limit: int = 20,
-        flatten=False,
-        transpose=False,
-        truncate=True,
-        include_hidden=False,
+        flatten: bool = False,
+        transpose: bool = False,
+        truncate: bool = True,
+        include_hidden: bool = False,
     ) -> None:
         """Show a preview of the chain results.
 
         Parameters:
-            limit : How many rows to show.
-            flatten : Whether to use a multiindex or flatten column names.
-            transpose : Whether to transpose rows and columns.
-            truncate : Whether or not to truncate the contents of columns.
-            include_hidden : Whether to include hidden columns.
+            limit: How many rows to show.
+            flatten: Whether to use a multiindex or flatten column names.
+            transpose: Whether to transpose rows and columns.
+            truncate: Whether or not to truncate the contents of columns.
+            include_hidden: Whether to include hidden columns.
         """
         import pandas as pd
 
@@ -2166,21 +2269,73 @@ class DataChain:
         )
         return read_records(*args, **kwargs)
 
-    def sum(self, fr: DataType):  # type: ignore[override]
-        """Compute the sum of a column."""
-        return self._extend_to_data_model("sum", fr)
+    def sum(self, col: str) -> StandardType:  # type: ignore[override]
+        """Compute the sum of a column.
+
+        Parameters:
+            col: The column to compute the sum for.
+
+        Returns:
+            The sum of the column values.
+
+        Example:
+            ```py
+            total_size = chain.sum("file.size")
+            print(f"Total size: {total_size}")
+            ```
+        """
+        return self._extend_to_data_model("sum", col)
 
-    def avg(self, fr: DataType):  # type: ignore[override]
-        """Compute the average of a column."""
-        return self._extend_to_data_model("avg", fr)
+    def avg(self, col: str) -> StandardType:  # type: ignore[override]
+        """Compute the average of a column.
+
+        Parameters:
+            col: The column to compute the average for.
+
+        Returns:
+            The average of the column values.
+
+        Example:
+            ```py
+            average_size = chain.avg("file.size")
+            print(f"Average size: {average_size}")
+            ```
+        """
+        return self._extend_to_data_model("avg", col)
 
-    def min(self, fr: DataType):  # type: ignore[override]
-        """Compute the minimum of a column."""
-        return self._extend_to_data_model("min", fr)
+    def min(self, col: str) -> StandardType:  # type: ignore[override]
+        """Compute the minimum of a column.
+
+        Parameters:
+            col: The column to compute the minimum for.
+
+        Returns:
+            The minimum value in the column.
+
+        Example:
+            ```py
+            min_size = chain.min("file.size")
+            print(f"Minimum size: {min_size}")
+            ```
+        """
+        return self._extend_to_data_model("min", col)
 
-    def max(self, fr: DataType):  # type: ignore[override]
-        """Compute the maximum of a column."""
-        return self._extend_to_data_model("max", fr)
+    def max(self, col: str) -> StandardType:  # type: ignore[override]
+        """Compute the maximum of a column.
+
+        Parameters:
+            col: The column to compute the maximum for.
+
+        Returns:
+            The maximum value in the column.
+
+        Example:
+            ```py
+            max_size = chain.max("file.size")
+            print(f"Maximum size: {max_size}")
+            ```
+        """
+        return self._extend_to_data_model("max", col)
 
     def setup(self, **kwargs) -> "Self":
         """Setup variables to pass to UDF functions.
@@ -2291,14 +2446,15 @@ class DataChain:
         """Shuffle the rows of the chain deterministically."""
         return self.order_by("sys.rand")
 
-    def sample(self, n) -> "Self":
+    def sample(self, n: int) -> "Self":
         """Return a random sample from the chain.
 
         Parameters:
-            n (int): Number of samples to draw.
+            n: Number of samples to draw.
 
-        NOTE: Samples are not deterministic, and streamed/paginated queries or
-        multiple workers will draw samples with replacement.
+        Note:
+            Samples are not deterministic, and streamed/paginated queries or
+            multiple workers will draw samples with replacement.
         """
         return self._evolve(query=self._query.sample(n))
 
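Behavior is unchanged; the caveat simply moves into a `Note:` section. For reference:

```py
subset = chain.sample(100)  # 100 rows, non-deterministic selection
```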
@@ -2405,6 +2561,10 @@ class DataChain:
     def chunk(self, index: int, total: int) -> "Self":
         """Split a chain into smaller chunks for e.g. parallelization.
 
+        Parameters:
+            index: The index of the chunk (0-indexed).
+            total: The total number of chunks.
+
         Example:
             ```py
             import datachain as dc
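The new `Parameters:` block pins down the 0-indexed contract. A hedged parallelization sketch (worker id and count are illustrative):

```py
worker_id, n_workers = 0, 4  # e.g., one of four workers
part = chain.chunk(worker_id, n_workers)
part.show()
```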
@@ -2424,7 +2584,7 @@ class DataChain:
         """Returns a list of rows of values, optionally limited to the specified
         columns.
 
-        Args:
+        Parameters:
             *cols: Limit to the specified columns. By default, all columns are selected.
 
         Returns:
@@ -2454,7 +2614,7 @@ class DataChain:
     def to_values(self, col: str) -> list[DataValue]:
         """Returns a flat list of values from a single column.
 
-        Args:
+        Parameters:
             col: The name of the column to extract values from.
 
         Returns:
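For completeness, `to_values` (named in the hunk above) in use; the column name is illustrative:

```py
sizes = chain.to_values("file.size")  # flat list of one column's values
print(len(sizes), sum(sizes))
```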