datachain 0.37.7__py3-none-any.whl → 0.37.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic. Click here for more details.
- datachain/data_storage/warehouse.py +31 -5
- datachain/lib/convert/values_to_tuples.py +139 -43
- datachain/lib/data_model.py +3 -0
- datachain/lib/dc/datachain.py +19 -3
- datachain/lib/signal_schema.py +72 -6
- datachain/query/dataset.py +22 -5
- datachain/toolkit/split.py +30 -8
- {datachain-0.37.7.dist-info → datachain-0.37.9.dist-info}/METADATA +1 -1
- {datachain-0.37.7.dist-info → datachain-0.37.9.dist-info}/RECORD +13 -13
- {datachain-0.37.7.dist-info → datachain-0.37.9.dist-info}/WHEEL +0 -0
- {datachain-0.37.7.dist-info → datachain-0.37.9.dist-info}/entry_points.txt +0 -0
- {datachain-0.37.7.dist-info → datachain-0.37.9.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.37.7.dist-info → datachain-0.37.9.dist-info}/top_level.txt +0 -0
|
@@ -18,6 +18,7 @@ from datachain.data_storage.schema import convert_rows_custom_column_types
|
|
|
18
18
|
from datachain.data_storage.serializer import Serializable
|
|
19
19
|
from datachain.dataset import DatasetRecord, StorageURI
|
|
20
20
|
from datachain.lib.file import File
|
|
21
|
+
from datachain.lib.model_store import ModelStore
|
|
21
22
|
from datachain.lib.signal_schema import SignalSchema
|
|
22
23
|
from datachain.node import DirType, DirTypeGroup, Node, NodeWithPath, get_path
|
|
23
24
|
from datachain.query.batch import RowsOutput
|
|
@@ -76,6 +77,29 @@ class AbstractWarehouse(ABC, Serializable):
|
|
|
76
77
|
def cleanup_for_tests(self):
|
|
77
78
|
"""Cleanup for tests."""
|
|
78
79
|
|
|
80
|
+
def _to_jsonable(self, obj: Any) -> Any:
    """Recursively convert Python/Pydantic structures into JSON-serializable
    objects.

    Args:
        obj: Any value; Pydantic models, dicts, lists, tuples, sets and
            frozensets are walked recursively, everything else is returned
            unchanged.

    Returns:
        A structure made of dicts with string keys and lists, plus whatever
        leaf values were passed in. Leaves are not validated here; a leaf
        that ``json.dumps`` cannot encode will still fail at dump time.
    """
    if ModelStore.is_pydantic(type(obj)):
        # model_dump() runs in python mode, so the dumped dict may still hold
        # sets, tuples or non-string keys. Walk it again instead of handing it
        # to json.dumps as-is.
        return self._to_jsonable(obj.model_dump())

    if isinstance(obj, dict):
        # String keys are kept verbatim; any other key is JSON-encoded so it
        # can be told apart (and decoded) when the value is read back.
        return {
            (
                key
                if isinstance(key, str)
                else json.dumps(self._to_jsonable(key), ensure_ascii=False)
            ): self._to_jsonable(value)
            for key, value in obj.items()
        }

    if isinstance(obj, (list, tuple, set, frozenset)):
        return [self._to_jsonable(item) for item in obj]

    return obj
|
|
102
|
+
|
|
79
103
|
def convert_type( # noqa: PLR0911
|
|
80
104
|
self,
|
|
81
105
|
val: Any,
|
|
@@ -122,11 +146,13 @@ class AbstractWarehouse(ABC, Serializable):
|
|
|
122
146
|
if col_python_type is dict or col_type_name == "JSON":
|
|
123
147
|
if value_type is str:
|
|
124
148
|
return val
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
149
|
+
try:
|
|
150
|
+
json_ready = self._to_jsonable(val)
|
|
151
|
+
return json.dumps(json_ready, ensure_ascii=False)
|
|
152
|
+
except Exception as e:
|
|
153
|
+
raise ValueError(
|
|
154
|
+
f"Cannot convert value {val!r} with type {value_type} to JSON"
|
|
155
|
+
) from e
|
|
130
156
|
|
|
131
157
|
if isinstance(val, col_python_type):
|
|
132
158
|
return val
|
|
@@ -13,41 +13,153 @@ class ValuesToTupleError(DataChainParamsError):
|
|
|
13
13
|
super().__init__(f"Cannot convert signals for dataset{ds_name}: {msg}")
|
|
14
14
|
|
|
15
15
|
|
|
16
|
-
def
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
16
|
+
def _find_first_non_none(sequence: Sequence[Any]) -> Any | None:
|
|
17
|
+
"""Find the first non-None element in a sequence."""
|
|
18
|
+
try:
|
|
19
|
+
return next(itertools.dropwhile(lambda i: i is None, sequence))
|
|
20
|
+
except StopIteration:
|
|
21
|
+
return None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _infer_list_item_type(lst: list) -> type:
|
|
25
|
+
"""Infer the item type of a list, handling None values and nested lists."""
|
|
26
|
+
if len(lst) == 0:
|
|
27
|
+
# Default to str when list is empty to avoid generic list
|
|
28
|
+
return str
|
|
29
|
+
|
|
30
|
+
first_item = _find_first_non_none(lst)
|
|
31
|
+
if first_item is None:
|
|
32
|
+
# Default to str when all items are None
|
|
33
|
+
return str
|
|
34
|
+
|
|
35
|
+
item_type = type(first_item)
|
|
36
|
+
|
|
37
|
+
# Handle nested lists one level deep
|
|
38
|
+
if isinstance(first_item, list) and len(first_item) > 0:
|
|
39
|
+
nested_item = _find_first_non_none(first_item)
|
|
40
|
+
if nested_item is not None:
|
|
41
|
+
return list[type(nested_item)] # type: ignore[misc, return-value]
|
|
42
|
+
# Default to str for nested lists with all None
|
|
43
|
+
return list[str] # type: ignore[return-value]
|
|
44
|
+
|
|
45
|
+
return item_type
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _infer_dict_value_type(dct: dict) -> type:
|
|
49
|
+
"""Infer the value type of a dict, handling None values and list values."""
|
|
50
|
+
if len(dct) == 0:
|
|
51
|
+
# Default to str when dict is empty to avoid generic dict values
|
|
52
|
+
return str
|
|
53
|
+
|
|
54
|
+
# Find first non-None value
|
|
55
|
+
first_value = None
|
|
56
|
+
for val in dct.values():
|
|
57
|
+
if val is not None:
|
|
58
|
+
first_value = val
|
|
59
|
+
break
|
|
60
|
+
|
|
61
|
+
if first_value is None:
|
|
62
|
+
# Default to str when all values are None
|
|
63
|
+
return str
|
|
64
|
+
|
|
65
|
+
# Handle list values
|
|
66
|
+
if isinstance(first_value, list) and len(first_value) > 0:
|
|
67
|
+
list_item = _find_first_non_none(first_value)
|
|
68
|
+
if list_item is not None:
|
|
69
|
+
return list[type(list_item)] # type: ignore[misc, return-value]
|
|
70
|
+
# Default to str for lists with all None
|
|
71
|
+
return list[str] # type: ignore[return-value]
|
|
72
|
+
|
|
73
|
+
return type(first_value)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _infer_type_from_sequence(
    sequence: Sequence[DataValue], signal_name: str, ds_name: str
) -> type:
    """Work out the column type for one signal from its values.

    The first non-None value decides the type; later values are not checked.

    Args:
        sequence: Values supplied for the signal.
        signal_name: Signal name, used only in the error message.
        ds_name: Dataset name, passed to the raised error.

    Returns:
        ``str`` when there is no non-None value; ``list[...]`` or
        ``dict[..., ...]`` with inferred parameters for list and dict values;
        otherwise the plain type of the first non-None value.

    Raises:
        ValuesToTupleError: If ``is_chain_type`` rejects the detected type.
    """
    sample = _find_first_non_none(sequence)
    if sample is None:
        # Default to str if column is empty or all values are None
        return str

    typ = type(sample)
    if not is_chain_type(typ):
        raise ValuesToTupleError(
            ds_name,
            f"signal '{signal_name}' has unsupported type '{typ.__name__}'."
            f" Please use DataModel types: {DataTypeNames}",
        )

    if isinstance(sample, dict):
        if not sample:
            # If the first dict is empty, use str as default key/value types
            return dict[str, str]  # type: ignore[return-value]
        # Key type comes from the first key only; value type from the first
        # non-None value.
        key_type = type(next(iter(sample)))
        value_type = _infer_dict_value_type(sample)
        return dict[key_type, value_type]  # type: ignore[misc, return-value]

    if isinstance(sample, list):
        item_type = _infer_list_item_type(sample)
        return list[item_type]  # type: ignore[valid-type, return-value]

    return typ
|
|
33
113
|
|
|
34
|
-
key: str = next(iter(fr_map.keys()))
|
|
35
|
-
output = {key: output} # type: ignore[dict-item]
|
|
36
114
|
|
|
37
|
-
|
|
115
|
+
def _validate_and_normalize_output(
    output: DataType | Sequence[str] | dict[str, DataType] | None,
    fr_map: dict[str, Sequence[DataValue]],
    ds_name: str,
) -> dict[str, DataType] | None:
    """Check the ``output`` argument and bring it to ``{signal: type}`` form.

    Args:
        output: Falsy for "not specified", a dict of signal name to type, or
            a single type when exactly one signal is given.
        fr_map: Signal name to values mapping the output must describe.
        ds_name: Dataset name, passed to any raised error.

    Returns:
        ``None`` when ``output`` is falsy, otherwise a dict with one entry
        per signal. A dict argument is returned as-is (same object).

    Raises:
        ValuesToTupleError: If ``output`` is a non-dict sequence or string,
            if a dict has a different number of entries than ``fr_map``, or
            if a single output is given for several signals or is not a type.
    """
    if not output:
        return None

    if isinstance(output, dict):
        if len(output) != len(fr_map):
            raise ValuesToTupleError(
                ds_name,
                f"number of outputs '{len(output)}' should match"
                f" number of signals '{len(fr_map)}'",
            )
        return output  # type: ignore[return-value]

    if isinstance(output, (Sequence, str)):
        # Lists, tuples and strings of names are not accepted here.
        raise ValuesToTupleError(
            ds_name,
            "output type must be dict[str, DataType] while "
            f"'{type(output).__name__}' is given",
        )

    # Anything else is treated as one bare type for a single signal.
    if len(fr_map) != 1:
        raise ValuesToTupleError(
            ds_name,
            f"only one output type was specified, {len(fr_map)} expected",
        )
    if not isinstance(output, type):
        raise ValuesToTupleError(
            ds_name,
            f"output must specify a type while '{output}' was given",
        )

    only_signal: str = next(iter(fr_map))
    return {only_signal: output}  # type: ignore[dict-item]
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def values_to_tuples(
|
|
157
|
+
ds_name: str = "",
|
|
158
|
+
output: DataType | Sequence[str] | dict[str, DataType] | None = None,
|
|
159
|
+
**fr_map: Sequence[DataValue],
|
|
160
|
+
) -> tuple[Any, Any, Any]:
|
|
161
|
+
output = _validate_and_normalize_output(output, fr_map, ds_name)
|
|
162
|
+
|
|
51
163
|
types_map: dict[str, type] = {}
|
|
52
164
|
length = -1
|
|
53
165
|
for k, v in fr_map.items():
|
|
@@ -65,23 +177,7 @@ def values_to_tuples( # noqa: C901, PLR0912
|
|
|
65
177
|
# FIXME: Stops as soon as it finds the first non-None value.
|
|
66
178
|
# If a non-None value appears early, it won't check the remaining items for
|
|
67
179
|
# `None` values.
|
|
68
|
-
|
|
69
|
-
first_not_none_element = next(
|
|
70
|
-
itertools.dropwhile(lambda i: i is None, v)
|
|
71
|
-
)
|
|
72
|
-
except StopIteration:
|
|
73
|
-
# set default type to `str` if column is empty or all values are `None`
|
|
74
|
-
typ = str
|
|
75
|
-
else:
|
|
76
|
-
typ = type(first_not_none_element) # type: ignore[assignment]
|
|
77
|
-
if not is_chain_type(typ):
|
|
78
|
-
raise ValuesToTupleError(
|
|
79
|
-
ds_name,
|
|
80
|
-
f"signal '{k}' has unsupported type '{typ.__name__}'."
|
|
81
|
-
f" Please use DataModel types: {DataTypeNames}",
|
|
82
|
-
)
|
|
83
|
-
if isinstance(first_not_none_element, list):
|
|
84
|
-
typ = list[type(first_not_none_element[0])] # type: ignore[assignment, misc]
|
|
180
|
+
typ = _infer_type_from_sequence(v, k, ds_name)
|
|
85
181
|
types_map[k] = typ
|
|
86
182
|
|
|
87
183
|
if length < 0:
|
datachain/lib/data_model.py
CHANGED
|
@@ -64,6 +64,9 @@ def is_chain_type(t: type) -> bool:
|
|
|
64
64
|
if orig is list and len(args) == 1:
|
|
65
65
|
return is_chain_type(get_args(t)[0])
|
|
66
66
|
|
|
67
|
+
if orig is dict and len(args) == 2:
|
|
68
|
+
return is_chain_type(args[0]) and is_chain_type(args[1])
|
|
69
|
+
|
|
67
70
|
if orig in (Union, types.UnionType) and len(args) == 2 and (type(None) in args):
|
|
68
71
|
return is_chain_type(args[0] if args[1] is type(None) else args[1])
|
|
69
72
|
|
datachain/lib/dc/datachain.py
CHANGED
|
@@ -52,7 +52,11 @@ from datachain.lib.udf_signature import UdfSignature
|
|
|
52
52
|
from datachain.lib.utils import DataChainColumnError, DataChainParamsError
|
|
53
53
|
from datachain.project import Project
|
|
54
54
|
from datachain.query import Session
|
|
55
|
-
from datachain.query.dataset import
|
|
55
|
+
from datachain.query.dataset import (
|
|
56
|
+
DatasetQuery,
|
|
57
|
+
PartitionByType,
|
|
58
|
+
RegenerateSystemColumns,
|
|
59
|
+
)
|
|
56
60
|
from datachain.query.schema import DEFAULT_DELIMITER, Column
|
|
57
61
|
from datachain.sql.functions import path as pathfunc
|
|
58
62
|
from datachain.utils import batched_it, env2bool, inside_notebook, row_to_nested_dict
|
|
@@ -2740,8 +2744,20 @@ class DataChain:
|
|
|
2740
2744
|
)
|
|
2741
2745
|
|
|
2742
2746
|
def shuffle(self) -> "Self":
    """Shuffle rows with a best-effort deterministic ordering.

    A ``RegenerateSystemColumns`` step is appended to a clone of the current
    query, the ``sys`` signal is put in front of the existing schema, and the
    result is ordered by ``sys.rand``.

    This produces repeatable shuffles. Merge and union operations can
    lead to non-deterministic results. Use order by or save a dataset
    afterward to guarantee the same result.
    """
    regenerated = self._query.clone(new_table=False)
    regenerated.steps.append(RegenerateSystemColumns(self._query.catalog))

    schema_with_sys = SignalSchema({"sys": Sys}) | self.signals_schema
    shuffled = self._evolve(query=regenerated, signal_schema=schema_with_sys)
    return shuffled.order_by("sys.rand")
|
|
2745
2761
|
|
|
2746
2762
|
def sample(self, n: int) -> "Self":
|
|
2747
2763
|
"""Return a random sample from the chain.
|
datachain/lib/signal_schema.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import copy
|
|
2
2
|
import hashlib
|
|
3
|
-
import json
|
|
4
3
|
import logging
|
|
5
4
|
import math
|
|
6
5
|
import types
|
|
@@ -14,9 +13,7 @@ from typing import (
|
|
|
14
13
|
TYPE_CHECKING,
|
|
15
14
|
Annotated,
|
|
16
15
|
Any,
|
|
17
|
-
Dict, # type: ignore[UP035]
|
|
18
16
|
Final,
|
|
19
|
-
List, # type: ignore[UP035]
|
|
20
17
|
Literal,
|
|
21
18
|
Optional,
|
|
22
19
|
Union,
|
|
@@ -24,6 +21,7 @@ from typing import (
|
|
|
24
21
|
get_origin,
|
|
25
22
|
)
|
|
26
23
|
|
|
24
|
+
import ujson as json
|
|
27
25
|
from pydantic import BaseModel, Field, ValidationError, create_model
|
|
28
26
|
from sqlalchemy import ColumnElement
|
|
29
27
|
from typing_extensions import Literal as LiteralEx
|
|
@@ -569,8 +567,10 @@ class SignalSchema:
|
|
|
569
567
|
pos = 0
|
|
570
568
|
for fr_cls in self.values.values():
|
|
571
569
|
if (fr := ModelStore.to_pydantic(fr_cls)) is None:
|
|
572
|
-
|
|
570
|
+
value = row[pos]
|
|
573
571
|
pos += 1
|
|
572
|
+
converted = self._convert_feature_value(fr_cls, value, catalog, cache)
|
|
573
|
+
res.append(converted)
|
|
574
574
|
else:
|
|
575
575
|
json, pos = unflatten_to_json_pos(fr, row, pos) # type: ignore[union-attr]
|
|
576
576
|
try:
|
|
@@ -585,6 +585,72 @@ class SignalSchema:
|
|
|
585
585
|
res.append(obj)
|
|
586
586
|
return res
|
|
587
587
|
|
|
588
|
+
def _convert_feature_value(
    self,
    annotation: DataType,
    value: Any,
    catalog: "Catalog",
    cache: bool,
) -> Any:
    """Turn a raw stored value into the shape its annotation declares.

    Args:
        annotation: Declared type of the signal (or of a nested item).
        value: Raw value to convert.
        catalog: Passed to ``_set_file_stream`` for model values.
        cache: Passed to ``_set_file_stream`` for model values.

    Returns:
        ``None`` for ``None``; a model instance for Pydantic annotations
        (built from a mapping when needed); a new list or dict with converted
        members for ``list[...]`` / ``dict[..., ...]`` annotations; the value
        itself in every other case, including unions with more than one
        non-None member.
    """
    if value is None:
        return None

    origin = get_origin(annotation)

    if origin in (Union, types.UnionType):
        # Only "X or None" unions are unwrapped; wider unions pass through.
        members = [arg for arg in get_args(annotation) if arg is not type(None)]
        if len(members) != 1:
            return value
        annotation = members[0]
        origin = get_origin(annotation)

    if ModelStore.is_pydantic(annotation):
        if isinstance(value, annotation):
            obj = value
        elif isinstance(value, Mapping):
            obj = annotation(**value)
        else:
            return value
        assert isinstance(obj, BaseModel)
        SignalSchema._set_file_stream(obj, catalog, cache)
        return obj

    if origin is list:
        type_args = get_args(annotation)
        if not (type_args and isinstance(value, (list, tuple))):
            return value
        item_type = type_args[0]
        return [
            None
            if item is None
            else self._convert_feature_value(item_type, item, catalog, cache)
            for item in value
        ]

    if origin is dict:
        type_args = get_args(annotation)
        if len(type_args) != 2 or not isinstance(value, dict):
            return value
        key_type, val_type = type_args
        converted: dict = {}
        for raw_key, raw_val in value.items():
            if key_type is str:
                new_key = raw_key
            else:
                # Non-string keys are expected to arrive JSON-encoded.
                decoded_key = json.loads(raw_key)
                new_key = self._convert_feature_value(
                    key_type, decoded_key, catalog, cache
                )
            if val_type is Any:
                new_val = raw_val
            else:
                new_val = self._convert_feature_value(
                    val_type, raw_val, catalog, cache
                )
            converted[new_key] = new_val
        return converted

    return value
|
|
653
|
+
|
|
588
654
|
@staticmethod
|
|
589
655
|
def _set_file_stream(
|
|
590
656
|
obj: BaseModel, catalog: "Catalog", cache: bool = False
|
|
@@ -898,13 +964,13 @@ class SignalSchema:
|
|
|
898
964
|
args = get_args(type_)
|
|
899
965
|
type_str = SignalSchema._type_to_str(args[0], subtypes)
|
|
900
966
|
return f"Optional[{type_str}]"
|
|
901
|
-
if origin
|
|
967
|
+
if origin is list:
|
|
902
968
|
args = get_args(type_)
|
|
903
969
|
if len(args) == 0:
|
|
904
970
|
return "list"
|
|
905
971
|
type_str = SignalSchema._type_to_str(args[0], subtypes)
|
|
906
972
|
return f"list[{type_str}]"
|
|
907
|
-
if origin
|
|
973
|
+
if origin is dict:
|
|
908
974
|
args = get_args(type_)
|
|
909
975
|
if len(args) == 0:
|
|
910
976
|
return "dict"
|
datachain/query/dataset.py
CHANGED
|
@@ -786,10 +786,31 @@ class SQLClause(Step, ABC):
|
|
|
786
786
|
return tuple(c.get_column() if isinstance(c, Function) else c for c in cols)
|
|
787
787
|
|
|
788
788
|
@abstractmethod
|
|
789
|
-
def apply_sql_clause(self, query):
|
|
789
|
+
def apply_sql_clause(self, query: Any) -> Any:
|
|
790
790
|
pass
|
|
791
791
|
|
|
792
792
|
|
|
793
|
+
@frozen
class RegenerateSystemColumns(Step):
    """Query step that runs the incoming query through the warehouse's
    ``_regenerate_system_columns`` with ``keep_existing_columns=True``.
    """

    # Catalog whose warehouse performs the regeneration.
    catalog: "Catalog"

    def hash_inputs(self) -> str:
        """Return the SHA-256 hex digest of a fixed step name."""
        digest = hashlib.sha256()
        digest.update(b"regenerate_system_columns")
        return digest.hexdigest()

    def apply(
        self, query_generator: QueryGenerator, temp_tables: list[str]
    ) -> StepResult:
        """Build the step result from the regenerated query.

        ``temp_tables`` is not used by this step.
        """
        base_query = query_generator.select()
        regenerated = self.catalog.warehouse._regenerate_system_columns(
            base_query, keep_existing_columns=True
        )
        return step_result(
            lambda *columns: regenerated.with_only_columns(*columns),
            regenerated.selected_columns,
        )
|
|
812
|
+
|
|
813
|
+
|
|
793
814
|
@frozen
|
|
794
815
|
class SQLSelect(SQLClause):
|
|
795
816
|
args: tuple[Function | ColumnElement, ...]
|
|
@@ -1488,10 +1509,6 @@ class DatasetQuery:
|
|
|
1488
1509
|
finally:
|
|
1489
1510
|
self.cleanup()
|
|
1490
1511
|
|
|
1491
|
-
def shuffle(self) -> "Self":
|
|
1492
|
-
# ToDo: implement shaffle based on seed and/or generating random column
|
|
1493
|
-
return self.order_by(C.sys__rand)
|
|
1494
|
-
|
|
1495
1512
|
def sample(self, n) -> "Self":
|
|
1496
1513
|
"""
|
|
1497
1514
|
Return a random sample from the dataset.
|
datachain/toolkit/split.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import random
|
|
2
2
|
|
|
3
3
|
from datachain import C, DataChain
|
|
4
|
+
from datachain.lib.signal_schema import SignalResolvingError
|
|
4
5
|
|
|
5
6
|
RESOLUTION = 2**31 - 1 # Maximum positive value for a 32-bit signed integer.
|
|
6
7
|
|
|
@@ -59,7 +60,10 @@ def train_test_split(
|
|
|
59
60
|
```
|
|
60
61
|
|
|
61
62
|
Note:
|
|
62
|
-
|
|
63
|
+
Splits reuse the same best-effort shuffle used by `DataChain.shuffle`. Results
|
|
64
|
+
are typically repeatable, but earlier operations such as `merge`, `union`, or
|
|
65
|
+
custom SQL that reshuffle rows can change the outcome between runs. Add order by
|
|
66
|
+
stable keys first when you need strict reproducibility.
|
|
63
67
|
"""
|
|
64
68
|
if len(weights) < 2:
|
|
65
69
|
raise ValueError("Weights should have at least two elements")
|
|
@@ -68,16 +72,34 @@ def train_test_split(
|
|
|
68
72
|
|
|
69
73
|
weights_normalized = [weight / sum(weights) for weight in weights]
|
|
70
74
|
|
|
75
|
+
try:
|
|
76
|
+
dc.signals_schema.resolve("sys.rand")
|
|
77
|
+
except SignalResolvingError:
|
|
78
|
+
dc = dc.persist()
|
|
79
|
+
|
|
71
80
|
rand_col = C("sys.rand")
|
|
72
81
|
if seed is not None:
|
|
73
82
|
uniform_seed = random.Random(seed).randrange(1, RESOLUTION) # noqa: S311
|
|
74
83
|
rand_col = (rand_col % RESOLUTION) * uniform_seed # type: ignore[assignment]
|
|
75
84
|
rand_col = rand_col % RESOLUTION # type: ignore[assignment]
|
|
76
85
|
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
)
|
|
82
|
-
|
|
83
|
-
|
|
86
|
+
boundaries: list[int] = [0]
|
|
87
|
+
cumulative = 0.0
|
|
88
|
+
for weight in weights_normalized[:-1]:
|
|
89
|
+
cumulative += weight
|
|
90
|
+
boundary = round(cumulative * RESOLUTION)
|
|
91
|
+
boundaries.append(min(boundary, RESOLUTION))
|
|
92
|
+
boundaries.append(RESOLUTION)
|
|
93
|
+
|
|
94
|
+
splits: list[DataChain] = []
|
|
95
|
+
last_index = len(weights_normalized) - 1
|
|
96
|
+
for index in range(len(weights_normalized)):
|
|
97
|
+
lower = boundaries[index]
|
|
98
|
+
if index == last_index:
|
|
99
|
+
condition = rand_col >= lower
|
|
100
|
+
else:
|
|
101
|
+
upper = boundaries[index + 1]
|
|
102
|
+
condition = (rand_col >= lower) & (rand_col < upper)
|
|
103
|
+
splits.append(dc.filter(condition))
|
|
104
|
+
|
|
105
|
+
return splits
|
|
@@ -58,7 +58,7 @@ datachain/data_storage/metastore.py,sha256=DFyTkKLJN5-nFXXc7ln_rGj-FLctj0nrhXJxu
|
|
|
58
58
|
datachain/data_storage/schema.py,sha256=3fAgiE11TIDYCW7EbTdiOm61SErRitvsLr7YPnUlVm0,9801
|
|
59
59
|
datachain/data_storage/serializer.py,sha256=oL8i8smyAeVUyDepk8Xhf3lFOGOEHMoZjA5GdFzvfGI,3862
|
|
60
60
|
datachain/data_storage/sqlite.py,sha256=o9TR6N27JB52M9rRXdM9uwdBektGucWtJi9UnmLGh0A,29669
|
|
61
|
-
datachain/data_storage/warehouse.py,sha256=
|
|
61
|
+
datachain/data_storage/warehouse.py,sha256=_TGfMOtpltHA-G1KgoeIc_FFUomSmpAr94p-9AWNYIE,35642
|
|
62
62
|
datachain/diff/__init__.py,sha256=lGrygGzdWSSYJ1DgX4h2q_ko5QINEW8PKfxOwE9ZFnI,9394
|
|
63
63
|
datachain/fs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
64
64
|
datachain/fs/reference.py,sha256=A8McpXF0CqbXPqanXuvpKu50YLB3a2ZXA3YAPxtBXSM,914
|
|
@@ -78,7 +78,7 @@ datachain/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
|
78
78
|
datachain/lib/arrow.py,sha256=eCZtqbjAzkL4aemY74f_XkIJ_FWwXugJNjIFOwDa9w0,10815
|
|
79
79
|
datachain/lib/audio.py,sha256=hHG29vqrV389im152wCjh80d0xqXGGvFnUpUwkzZejQ,7385
|
|
80
80
|
datachain/lib/clip.py,sha256=nF8-N6Uz0MbAsPJBY2iXEYa3DPLo80OOer5SRNAtcGM,6149
|
|
81
|
-
datachain/lib/data_model.py,sha256=
|
|
81
|
+
datachain/lib/data_model.py,sha256=srz0pfFohSXwFnt5OMi1fNjSbKkFq8vzkcO0n4PHxlQ,3904
|
|
82
82
|
datachain/lib/dataset_info.py,sha256=Ym7yYcGpfUmPLrfdxueijCVRP2Go6KbyuLk_fmzYgDU,3273
|
|
83
83
|
datachain/lib/file.py,sha256=YO4QUaZVZ0TVW9fahERZ3HJXPNXjB4oYzvLQntQYT9s,47501
|
|
84
84
|
datachain/lib/hf.py,sha256=jmyqRDXdksojUJCiU_2XFSIoMzzDJAZQs9xr-sEwEJc,7281
|
|
@@ -91,7 +91,7 @@ datachain/lib/namespaces.py,sha256=d4Zt2mYdGFctkA20SkB1woUxrNI4JwSxruxUGKwfauc,3
|
|
|
91
91
|
datachain/lib/projects.py,sha256=FfBfGoWvy1SccCQW2ITKdDA6V03FbnRCusOeHdPHr6Y,4059
|
|
92
92
|
datachain/lib/pytorch.py,sha256=gDJiUGoSaraW3JDPr5JW2a3SqT7KwgIMMpDTAC0L1_Y,7792
|
|
93
93
|
datachain/lib/settings.py,sha256=maMtywOUetJvEApDiMVfTTq-oaRNvUIfDCrqZwFL2GE,7559
|
|
94
|
-
datachain/lib/signal_schema.py,sha256=
|
|
94
|
+
datachain/lib/signal_schema.py,sha256=k43MncD1eew3zS6h_OYujg3jbvR6WH4Sj2mbrGvvvhc,43554
|
|
95
95
|
datachain/lib/tar.py,sha256=MLcVjzIgBqRuJacCNpZ6kwSZNq1i2tLyROc8PVprHsA,999
|
|
96
96
|
datachain/lib/text.py,sha256=uZom8qXfrv9QYvuDrvd0PuvPmj6qCsjVUwZSNr60BI4,1242
|
|
97
97
|
datachain/lib/udf.py,sha256=51qgPO5s5MA5ccwl7IIPxbkEZ4IKZe4tzihcpZ8ufX0,18618
|
|
@@ -105,11 +105,11 @@ datachain/lib/convert/flatten.py,sha256=_5rjGFnN6t1KCX5ftL5rG7tiiNat7j0SdNqajO15
|
|
|
105
105
|
datachain/lib/convert/python_to_sql.py,sha256=wfnqJ2vRL5UydNPQHshd82hUONsDBa4XyobCSTGqcEo,3187
|
|
106
106
|
datachain/lib/convert/sql_to_python.py,sha256=Gxc4FylWC_Pvvuawuc2MKZIiuAWI7wje8pyeN1MxRrU,670
|
|
107
107
|
datachain/lib/convert/unflatten.py,sha256=ysMkstwJzPMWUlnxn-Z-tXJR3wmhjHeSN_P-sDcLS6s,2010
|
|
108
|
-
datachain/lib/convert/values_to_tuples.py,sha256=
|
|
108
|
+
datachain/lib/convert/values_to_tuples.py,sha256=nOn7dkzScYERZH-2vgUxkQawRQ1KgdIuSDIicvqZkc0,7171
|
|
109
109
|
datachain/lib/dc/__init__.py,sha256=UrUzmDH6YyVl8fxM5iXTSFtl5DZTUzEYm1MaazK4vdQ,900
|
|
110
110
|
datachain/lib/dc/csv.py,sha256=fIfj5-2Ix4z5D5yZueagd5WUWw86pusJ9JJKD-U3KGg,4407
|
|
111
111
|
datachain/lib/dc/database.py,sha256=Wqob3dQc9Mol_0vagzVEXzteCKS9M0E3U5130KVmQKg,14629
|
|
112
|
-
datachain/lib/dc/datachain.py,sha256=
|
|
112
|
+
datachain/lib/dc/datachain.py,sha256=XHr3gbdpLwzHhhIzPQXL5uZJQMFZ1AypCENdRlWWxoM,104671
|
|
113
113
|
datachain/lib/dc/datasets.py,sha256=oY1t8QBAaZdhjwR439zZT74hMOspewVCrgdwy6juXng,15321
|
|
114
114
|
datachain/lib/dc/hf.py,sha256=FeruEO176L2qQ1Mnx0QmK4kV0GuQ4xtj717N8fGJrBI,2849
|
|
115
115
|
datachain/lib/dc/json.py,sha256=iJ6G0jwTKz8xtfh1eICShnWk_bAMWjF5bFnOXLHaTlw,2683
|
|
@@ -132,7 +132,7 @@ datachain/model/ultralytics/pose.py,sha256=pvoXrWWUSWT_UBaMwUb5MBHAY57Co2HFDPigF
|
|
|
132
132
|
datachain/model/ultralytics/segment.py,sha256=v9_xDxd5zw_I8rXsbl7yQXgEdTs2T38zyY_Y4XGN8ok,3194
|
|
133
133
|
datachain/query/__init__.py,sha256=7DhEIjAA8uZJfejruAVMZVcGFmvUpffuZJwgRqNwe-c,263
|
|
134
134
|
datachain/query/batch.py,sha256=ugTlSFqh_kxMcG6vJ5XrEzG9jBXRdb7KRAEEsFWiPew,4190
|
|
135
|
-
datachain/query/dataset.py,sha256=
|
|
135
|
+
datachain/query/dataset.py,sha256=9Ky0LZ7wMpfJbIZyXjnensrDQJvGg1pysZs96AYZqIY,67576
|
|
136
136
|
datachain/query/dispatch.py,sha256=Tg73zB6vDnYYYAvtlS9l7BI3sI1EfRCbDjiasvNxz2s,16385
|
|
137
137
|
datachain/query/metrics.py,sha256=qOMHiYPTMtVs2zI-mUSy8OPAVwrg4oJtVF85B9tdQyM,810
|
|
138
138
|
datachain/query/params.py,sha256=JkVz6IKUIpF58JZRkUXFT8DAHX2yfaULbhVaGmHKFLc,826
|
|
@@ -163,11 +163,11 @@ datachain/sql/sqlite/base.py,sha256=T4G46GggBRMZaDCRnfBWDv_-P2aLisqJ947xMnkB3Pk,
|
|
|
163
163
|
datachain/sql/sqlite/types.py,sha256=DCK7q-Zdc_m1o1T33xrKjYX1zRg1231gw3o3ACO_qho,1815
|
|
164
164
|
datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR0,469
|
|
165
165
|
datachain/toolkit/__init__.py,sha256=eQ58Q5Yf_Fgv1ZG0IO5dpB4jmP90rk8YxUWmPc1M2Bo,68
|
|
166
|
-
datachain/toolkit/split.py,sha256=
|
|
166
|
+
datachain/toolkit/split.py,sha256=9HHZl0fGs5Zj8b9l2L3IKf0AiiVNL9SnWbc2rfDiXRA,3710
|
|
167
167
|
datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
|
|
168
|
-
datachain-0.37.
|
|
169
|
-
datachain-0.37.
|
|
170
|
-
datachain-0.37.
|
|
171
|
-
datachain-0.37.
|
|
172
|
-
datachain-0.37.
|
|
173
|
-
datachain-0.37.
|
|
168
|
+
datachain-0.37.9.dist-info/licenses/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
|
|
169
|
+
datachain-0.37.9.dist-info/METADATA,sha256=iZmFzvJMHOE2j4t9zGX2eliujOaRIcD0E39Cx1IXSXg,13763
|
|
170
|
+
datachain-0.37.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
171
|
+
datachain-0.37.9.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
|
|
172
|
+
datachain-0.37.9.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
|
|
173
|
+
datachain-0.37.9.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|