datachain-0.37.6-py3-none-any.whl → datachain-0.37.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic. See the registry's advisory for this release for more details.

@@ -52,7 +52,11 @@ from datachain.lib.udf_signature import UdfSignature
52
52
  from datachain.lib.utils import DataChainColumnError, DataChainParamsError
53
53
  from datachain.project import Project
54
54
  from datachain.query import Session
55
- from datachain.query.dataset import DatasetQuery, PartitionByType
55
+ from datachain.query.dataset import (
56
+ DatasetQuery,
57
+ PartitionByType,
58
+ RegenerateSystemColumns,
59
+ )
56
60
  from datachain.query.schema import DEFAULT_DELIMITER, Column
57
61
  from datachain.sql.functions import path as pathfunc
58
62
  from datachain.utils import batched_it, env2bool, inside_notebook, row_to_nested_dict
@@ -577,7 +581,8 @@ class DataChain:
577
581
  create=True,
578
582
  )
579
583
  return self._evolve(
580
- query=self._query.save(project=project, feature_schema=schema)
584
+ query=self._query.save(project=project, feature_schema=schema),
585
+ signal_schema=self.signals_schema | SignalSchema({"sys": Sys}),
581
586
  )
582
587
 
583
588
  def _calculate_job_hash(self, job_id: str) -> str:
@@ -2739,8 +2744,20 @@ class DataChain:
2739
2744
  )
2740
2745
 
2741
2746
  def shuffle(self) -> "Self":
2742
- """Shuffle the rows of the chain deterministically."""
2743
- return self.order_by("sys.rand")
2747
+ """Shuffle rows with a best-effort deterministic ordering.
2748
+
2749
+ This produces repeatable shuffles. Merge and union operations can
2750
+ lead to non-deterministic results. Use order by or save a dataset
2751
+ afterward to guarantee the same result.
2752
+ """
2753
+ query = self._query.clone(new_table=False)
2754
+ query.steps.append(RegenerateSystemColumns(self._query.catalog))
2755
+
2756
+ chain = self._evolve(
2757
+ query=query,
2758
+ signal_schema=SignalSchema({"sys": Sys}) | self.signals_schema,
2759
+ )
2760
+ return chain.order_by("sys.rand")
2744
2761
 
2745
2762
  def sample(self, n: int) -> "Self":
2746
2763
  """Return a random sample from the chain.
@@ -786,10 +786,31 @@ class SQLClause(Step, ABC):
786
786
  return tuple(c.get_column() if isinstance(c, Function) else c for c in cols)
787
787
 
788
788
  @abstractmethod
789
- def apply_sql_clause(self, query):
789
+ def apply_sql_clause(self, query: Any) -> Any:
790
790
  pass
791
791
 
792
792
 
793
+ @frozen
794
+ class RegenerateSystemColumns(Step):
795
+ catalog: "Catalog"
796
+
797
+ def hash_inputs(self) -> str:
798
+ return hashlib.sha256(b"regenerate_system_columns").hexdigest()
799
+
800
+ def apply(
801
+ self, query_generator: QueryGenerator, temp_tables: list[str]
802
+ ) -> StepResult:
803
+ query = query_generator.select()
804
+ new_query = self.catalog.warehouse._regenerate_system_columns(
805
+ query, keep_existing_columns=True
806
+ )
807
+
808
+ def q(*columns):
809
+ return new_query.with_only_columns(*columns)
810
+
811
+ return step_result(q, new_query.selected_columns)
812
+
813
+
793
814
  @frozen
794
815
  class SQLSelect(SQLClause):
795
816
  args: tuple[Function | ColumnElement, ...]
@@ -1488,10 +1509,6 @@ class DatasetQuery:
1488
1509
  finally:
1489
1510
  self.cleanup()
1490
1511
 
1491
- def shuffle(self) -> "Self":
1492
- # ToDo: implement shaffle based on seed and/or generating random column
1493
- return self.order_by(C.sys__rand)
1494
-
1495
1512
  def sample(self, n) -> "Self":
1496
1513
  """
1497
1514
  Return a random sample from the dataset.
@@ -1,6 +1,7 @@
1
1
  import random
2
2
 
3
3
  from datachain import C, DataChain
4
+ from datachain.lib.signal_schema import SignalResolvingError
4
5
 
5
6
  RESOLUTION = 2**31 - 1 # Maximum positive value for a 32-bit signed integer.
6
7
 
@@ -59,7 +60,10 @@ def train_test_split(
59
60
  ```
60
61
 
61
62
  Note:
62
- The splits are random but deterministic, based on Dataset `sys__rand` field.
63
+ Splits reuse the same best-effort shuffle used by `DataChain.shuffle`. Results
64
+ are typically repeatable, but earlier operations such as `merge`, `union`, or
65
+ custom SQL that reshuffle rows can change the outcome between runs. Add order by
66
+ stable keys first when you need strict reproducibility.
63
67
  """
64
68
  if len(weights) < 2:
65
69
  raise ValueError("Weights should have at least two elements")
@@ -68,16 +72,34 @@ def train_test_split(
68
72
 
69
73
  weights_normalized = [weight / sum(weights) for weight in weights]
70
74
 
75
+ try:
76
+ dc.signals_schema.resolve("sys.rand")
77
+ except SignalResolvingError:
78
+ dc = dc.persist()
79
+
71
80
  rand_col = C("sys.rand")
72
81
  if seed is not None:
73
82
  uniform_seed = random.Random(seed).randrange(1, RESOLUTION) # noqa: S311
74
83
  rand_col = (rand_col % RESOLUTION) * uniform_seed # type: ignore[assignment]
75
84
  rand_col = rand_col % RESOLUTION # type: ignore[assignment]
76
85
 
77
- return [
78
- dc.filter(
79
- rand_col >= round(sum(weights_normalized[:index]) * (RESOLUTION - 1)),
80
- rand_col < round(sum(weights_normalized[: index + 1]) * (RESOLUTION - 1)),
81
- )
82
- for index, _ in enumerate(weights_normalized)
83
- ]
86
+ boundaries: list[int] = [0]
87
+ cumulative = 0.0
88
+ for weight in weights_normalized[:-1]:
89
+ cumulative += weight
90
+ boundary = round(cumulative * RESOLUTION)
91
+ boundaries.append(min(boundary, RESOLUTION))
92
+ boundaries.append(RESOLUTION)
93
+
94
+ splits: list[DataChain] = []
95
+ last_index = len(weights_normalized) - 1
96
+ for index in range(len(weights_normalized)):
97
+ lower = boundaries[index]
98
+ if index == last_index:
99
+ condition = rand_col >= lower
100
+ else:
101
+ upper = boundaries[index + 1]
102
+ condition = (rand_col >= lower) & (rand_col < upper)
103
+ splits.append(dc.filter(condition))
104
+
105
+ return splits
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: datachain
3
- Version: 0.37.6
3
+ Version: 0.37.8
4
4
  Summary: Wrangle unstructured AI data at scale
5
5
  Author-email: Dmitry Petrov <support@dvc.org>
6
6
  License-Expression: Apache-2.0
@@ -55,9 +55,9 @@ Provides-Extra: docs
55
55
  Requires-Dist: mkdocs>=1.5.2; extra == "docs"
56
56
  Requires-Dist: mkdocs-gen-files>=0.5.0; extra == "docs"
57
57
  Requires-Dist: mkdocs-material==9.5.22; extra == "docs"
58
- Requires-Dist: mkdocs-section-index>=0.3.6; extra == "docs"
59
58
  Requires-Dist: mkdocstrings-python>=1.6.3; extra == "docs"
60
59
  Requires-Dist: mkdocs-literate-nav>=0.6.1; extra == "docs"
60
+ Requires-Dist: mkdocs-section-index>=0.3.10; extra == "docs"
61
61
  Requires-Dist: eval-type-backport; extra == "docs"
62
62
  Provides-Extra: torch
63
63
  Requires-Dist: torch>=2.1.0; extra == "torch"
@@ -109,7 +109,7 @@ datachain/lib/convert/values_to_tuples.py,sha256=Sxj0ojeMSpAwM_NNoXa1dMR_2L_cQ6X
109
109
  datachain/lib/dc/__init__.py,sha256=UrUzmDH6YyVl8fxM5iXTSFtl5DZTUzEYm1MaazK4vdQ,900
110
110
  datachain/lib/dc/csv.py,sha256=fIfj5-2Ix4z5D5yZueagd5WUWw86pusJ9JJKD-U3KGg,4407
111
111
  datachain/lib/dc/database.py,sha256=Wqob3dQc9Mol_0vagzVEXzteCKS9M0E3U5130KVmQKg,14629
112
- datachain/lib/dc/datachain.py,sha256=RYhinLQ6CMU3tudLpiJGh-vfCL24KDKbKM3Q1EsWoAE,104072
112
+ datachain/lib/dc/datachain.py,sha256=XHr3gbdpLwzHhhIzPQXL5uZJQMFZ1AypCENdRlWWxoM,104671
113
113
  datachain/lib/dc/datasets.py,sha256=oY1t8QBAaZdhjwR439zZT74hMOspewVCrgdwy6juXng,15321
114
114
  datachain/lib/dc/hf.py,sha256=FeruEO176L2qQ1Mnx0QmK4kV0GuQ4xtj717N8fGJrBI,2849
115
115
  datachain/lib/dc/json.py,sha256=iJ6G0jwTKz8xtfh1eICShnWk_bAMWjF5bFnOXLHaTlw,2683
@@ -132,7 +132,7 @@ datachain/model/ultralytics/pose.py,sha256=pvoXrWWUSWT_UBaMwUb5MBHAY57Co2HFDPigF
132
132
  datachain/model/ultralytics/segment.py,sha256=v9_xDxd5zw_I8rXsbl7yQXgEdTs2T38zyY_Y4XGN8ok,3194
133
133
  datachain/query/__init__.py,sha256=7DhEIjAA8uZJfejruAVMZVcGFmvUpffuZJwgRqNwe-c,263
134
134
  datachain/query/batch.py,sha256=ugTlSFqh_kxMcG6vJ5XrEzG9jBXRdb7KRAEEsFWiPew,4190
135
- datachain/query/dataset.py,sha256=kfNh6B6pYSz3batUpwW_6vJ7XRLwLfC08hKOZUMjf3o,67126
135
+ datachain/query/dataset.py,sha256=9Ky0LZ7wMpfJbIZyXjnensrDQJvGg1pysZs96AYZqIY,67576
136
136
  datachain/query/dispatch.py,sha256=Tg73zB6vDnYYYAvtlS9l7BI3sI1EfRCbDjiasvNxz2s,16385
137
137
  datachain/query/metrics.py,sha256=qOMHiYPTMtVs2zI-mUSy8OPAVwrg4oJtVF85B9tdQyM,810
138
138
  datachain/query/params.py,sha256=JkVz6IKUIpF58JZRkUXFT8DAHX2yfaULbhVaGmHKFLc,826
@@ -163,11 +163,11 @@ datachain/sql/sqlite/base.py,sha256=T4G46GggBRMZaDCRnfBWDv_-P2aLisqJ947xMnkB3Pk,
163
163
  datachain/sql/sqlite/types.py,sha256=DCK7q-Zdc_m1o1T33xrKjYX1zRg1231gw3o3ACO_qho,1815
164
164
  datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR0,469
165
165
  datachain/toolkit/__init__.py,sha256=eQ58Q5Yf_Fgv1ZG0IO5dpB4jmP90rk8YxUWmPc1M2Bo,68
166
- datachain/toolkit/split.py,sha256=xQzzmvQRKsPteDKbpgOxd4r971BnFaK33mcOl0FuGeI,2883
166
+ datachain/toolkit/split.py,sha256=9HHZl0fGs5Zj8b9l2L3IKf0AiiVNL9SnWbc2rfDiXRA,3710
167
167
  datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
168
- datachain-0.37.6.dist-info/licenses/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
169
- datachain-0.37.6.dist-info/METADATA,sha256=zBPCt_CUJzcP3rNzpykwH9v9A388r273Huo6Hp_f0Jk,13762
170
- datachain-0.37.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
171
- datachain-0.37.6.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
172
- datachain-0.37.6.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
173
- datachain-0.37.6.dist-info/RECORD,,
168
+ datachain-0.37.8.dist-info/licenses/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
169
+ datachain-0.37.8.dist-info/METADATA,sha256=6MLsgOSmSsxKXzbiOqTs9yQXaPhFu1QwgSqN_OmuQQM,13763
170
+ datachain-0.37.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
171
+ datachain-0.37.8.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
172
+ datachain-0.37.8.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
173
+ datachain-0.37.8.dist-info/RECORD,,