molcraft 0.1.0a13__tar.gz → 0.1.0a15__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/PKG-INFO +1 -1
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft/__init__.py +1 -1
- molcraft-0.1.0a15/molcraft/apps/qsrr.py +47 -0
- molcraft-0.1.0a15/molcraft/datasets.py +131 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft/featurizers.py +1 -1
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft/models.py +7 -8
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft/ops.py +11 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft.egg-info/PKG-INFO +1 -1
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft.egg-info/SOURCES.txt +1 -0
- molcraft-0.1.0a13/molcraft/datasets.py +0 -123
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/LICENSE +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/README.md +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft/apps/__init__.py +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft/apps/peptides.py +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft/callbacks.py +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft/chem.py +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft/conformers.py +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft/descriptors.py +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft/features.py +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft/layers.py +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft/losses.py +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft/records.py +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft/tensors.py +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft.egg-info/dependency_links.txt +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft.egg-info/requires.txt +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/molcraft.egg-info/top_level.txt +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/pyproject.toml +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/setup.cfg +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/tests/test_chem.py +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/tests/test_featurizers.py +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/tests/test_layers.py +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/tests/test_losses.py +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/tests/test_models.py +0 -0
- {molcraft-0.1.0a13 → molcraft-0.1.0a15}/tests/test_tensors.py +0 -0
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import molcraft
|
|
2
|
+
import keras
|
|
3
|
+
|
|
4
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class AuxiliaryFeatureInjection(molcraft.layers.GraphLayer):
    """Adds a projected graph-level (context) feature onto node features.

    The feature stored at ``tensor.context[field]`` is passed through
    ``depth`` dense layers. Each layer outputs as many units as the node
    feature dimension. The result is scatter-added onto
    ``tensor.node['feature']`` at ``tensor.node['super']``.

    Args:
        field:
            Name of the context field holding the auxiliary feature.
        depth:
            Number of stacked dense layers applied to the auxiliary feature.
        drop:
            Whether to update ``field`` in the context to ``None`` after it
            has been read (presumably removing it from the context).
        activation:
            Activation used by every dense layer (anything accepted by
            ``keras.activations.get``).
    """

    def __init__(
        self,
        field: str = 'auxiliary_feature',
        depth: int = 2,
        drop: bool = True,
        activation: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.field = field
        self.depth = depth
        self.drop = drop
        self.activation = keras.activations.get(activation)

    def build(self, spec: molcraft.tensors.GraphTensor.Spec) -> None:
        # Project to the node feature width so the output can be added
        # directly onto the node features in `propagate`.
        units = spec.node['feature'].shape[1]
        for i in range(self.depth):
            # Each layer is registered as attribute `dense_{i}` and looked
            # up under that name in `propagate`.
            setattr(
                self, f'dense_{i}', self.get_dense(units, activation=self.activation)
            )

    def propagate(
        self, tensor: molcraft.tensors.GraphTensor
    ) -> molcraft.tensors.GraphTensor:
        """Returns ``tensor`` with the projected auxiliary feature injected.

        NOTE(review): assumes ``tensor.node['super']`` selects one (super)
        node per row of the auxiliary feature; confirm against the featurizer.
        """
        x = tensor.context[self.field]
        if len(x.shape) == 1:
            # One scalar per graph: Dense layers need an explicit feature axis.
            x = keras.ops.expand_dims(x, axis=-1)
        if self.drop:
            tensor = tensor.update({'context': {self.field: None}})
        for i in range(self.depth):
            x = getattr(self, f'dense_{i}')(x)
        node_feature = molcraft.ops.scatter_add(
            tensor.node['feature'], tensor.node['super'], x
        )
        return tensor.update({'node': {'feature': node_feature}})

    def get_config(self) -> dict:
        config = super().get_config()
        config.update({
            'field': self.field,
            'depth': self.depth,
            'drop': self.drop,
            'activation': keras.activations.serialize(self.activation)
        })
        return config
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import typing
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def split(
|
|
7
|
+
data: pd.DataFrame | np.ndarray,
|
|
8
|
+
*,
|
|
9
|
+
train_size: float | None = None,
|
|
10
|
+
validation_size: float | None = None,
|
|
11
|
+
test_size: float | None = None,
|
|
12
|
+
groups: str | np.ndarray = None,
|
|
13
|
+
shuffle: bool = False,
|
|
14
|
+
random_state: int | None = None,
|
|
15
|
+
) -> tuple[np.ndarray | pd.DataFrame, ...]:
|
|
16
|
+
"""Splits the dataset into subsets.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
data:
|
|
20
|
+
A pd.DataFrame or np.ndarray object.
|
|
21
|
+
train_size:
|
|
22
|
+
The size of the train set.
|
|
23
|
+
validation_size:
|
|
24
|
+
The size of the validation set.
|
|
25
|
+
test_size:
|
|
26
|
+
The size of the test set.
|
|
27
|
+
groups:
|
|
28
|
+
The groups to perform the splitting on.
|
|
29
|
+
shuffle:
|
|
30
|
+
Whether the dataset should be shuffled prior to splitting.
|
|
31
|
+
random_state:
|
|
32
|
+
The random state/seed. Only applicable if shuffling.
|
|
33
|
+
"""
|
|
34
|
+
if not isinstance(data, (pd.DataFrame, np.ndarray)):
|
|
35
|
+
raise ValueError(f'Unsupported `data` type ({type(data)}).')
|
|
36
|
+
|
|
37
|
+
if isinstance(groups, str):
|
|
38
|
+
groups = data[groups].values
|
|
39
|
+
elif groups is None:
|
|
40
|
+
groups = np.arange(len(data))
|
|
41
|
+
|
|
42
|
+
indices = np.unique(groups)
|
|
43
|
+
size = len(indices)
|
|
44
|
+
|
|
45
|
+
if not train_size and not test_size:
|
|
46
|
+
raise ValueError(
|
|
47
|
+
f'Found both `train_size` and `test_size` to be `None`, '
|
|
48
|
+
f'specify at least one of them.'
|
|
49
|
+
)
|
|
50
|
+
if isinstance(test_size, float):
|
|
51
|
+
test_size = int(size * test_size)
|
|
52
|
+
if isinstance(train_size, float):
|
|
53
|
+
train_size = int(size * train_size)
|
|
54
|
+
if isinstance(validation_size, float):
|
|
55
|
+
validation_size = int(size * validation_size)
|
|
56
|
+
elif not validation_size:
|
|
57
|
+
validation_size = 0
|
|
58
|
+
|
|
59
|
+
if not train_size:
|
|
60
|
+
train_size = (size - test_size - validation_size)
|
|
61
|
+
if not test_size:
|
|
62
|
+
test_size = (size - train_size - validation_size)
|
|
63
|
+
|
|
64
|
+
remainder = size - (train_size + validation_size + test_size)
|
|
65
|
+
if remainder < 0:
|
|
66
|
+
raise ValueError(
|
|
67
|
+
f'subset sizes added up to more than the data size.'
|
|
68
|
+
)
|
|
69
|
+
train_size += remainder
|
|
70
|
+
|
|
71
|
+
if shuffle:
|
|
72
|
+
np.random.seed(random_state)
|
|
73
|
+
np.random.shuffle(indices)
|
|
74
|
+
|
|
75
|
+
train_mask = np.isin(groups, indices[:train_size])
|
|
76
|
+
test_mask = np.isin(groups, indices[-test_size:])
|
|
77
|
+
if not validation_size:
|
|
78
|
+
return data[train_mask], data[test_mask]
|
|
79
|
+
validation_mask = np.isin(groups, indices[train_size:-test_size])
|
|
80
|
+
return data[train_mask], data[validation_mask], data[test_mask]
|
|
81
|
+
|
|
82
|
+
def cv_split(
|
|
83
|
+
data: pd.DataFrame | np.ndarray,
|
|
84
|
+
num_splits: int = 10,
|
|
85
|
+
groups: str | np.ndarray = None,
|
|
86
|
+
shuffle: bool = False,
|
|
87
|
+
random_state: int | None = None,
|
|
88
|
+
) -> typing.Iterator[
|
|
89
|
+
tuple[np.ndarray | pd.DataFrame, np.ndarray | pd.DataFrame]
|
|
90
|
+
]:
|
|
91
|
+
"""Splits the dataset into cross-validation folds.
|
|
92
|
+
|
|
93
|
+
Args:
|
|
94
|
+
data:
|
|
95
|
+
A pd.DataFrame or np.ndarray object.
|
|
96
|
+
num_splits:
|
|
97
|
+
The number of cross-validation folds.
|
|
98
|
+
groups:
|
|
99
|
+
The groups to perform the splitting on.
|
|
100
|
+
shuffle:
|
|
101
|
+
Whether the dataset should be shuffled prior to splitting.
|
|
102
|
+
random_state:
|
|
103
|
+
The random state/seed. Only applicable if shuffling.
|
|
104
|
+
"""
|
|
105
|
+
if not isinstance(data, (pd.DataFrame, np.ndarray)):
|
|
106
|
+
raise ValueError(f'Unsupported `data` type ({type(data)}).')
|
|
107
|
+
|
|
108
|
+
if isinstance(groups, str):
|
|
109
|
+
groups = data[groups].values
|
|
110
|
+
elif groups is None:
|
|
111
|
+
groups = np.arange(len(data))
|
|
112
|
+
|
|
113
|
+
indices = np.unique(groups)
|
|
114
|
+
size = len(indices)
|
|
115
|
+
|
|
116
|
+
if num_splits > size:
|
|
117
|
+
raise ValueError(
|
|
118
|
+
f'`num_splits` ({num_splits}) must not be greater than'
|
|
119
|
+
f'the data size or the number of groups ({size}).'
|
|
120
|
+
)
|
|
121
|
+
if shuffle:
|
|
122
|
+
np.random.seed(random_state)
|
|
123
|
+
np.random.shuffle(indices)
|
|
124
|
+
|
|
125
|
+
indices_splits = np.array_split(indices, num_splits)
|
|
126
|
+
|
|
127
|
+
for k in range(num_splits):
|
|
128
|
+
test_indices = indices_splits[k]
|
|
129
|
+
test_mask = np.isin(groups, test_indices)
|
|
130
|
+
train_mask = ~test_mask
|
|
131
|
+
yield data[train_mask], data[test_mask]
|
|
@@ -169,7 +169,7 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
169
169
|
if default_atom_features:
|
|
170
170
|
atom_features = [features.AtomType()]
|
|
171
171
|
if not self.include_hs:
|
|
172
|
-
atom_features.append(features.
|
|
172
|
+
atom_features.append(features.NumHydrogens())
|
|
173
173
|
atom_features.append(features.Degree())
|
|
174
174
|
if not isinstance(self, MolGraphFeaturizer3D):
|
|
175
175
|
default_bond_features = (
|
|
@@ -250,7 +250,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
|
|
|
250
250
|
val_size = int(val_split * x.num_subgraphs)
|
|
251
251
|
x_val = _make_dataset(x[-val_size:], batch_size)
|
|
252
252
|
x = x[:-val_size]
|
|
253
|
-
x = _make_dataset(x, batch_size)
|
|
253
|
+
x = _make_dataset(x, batch_size, shuffle=kwargs.get('shuffle', True))
|
|
254
254
|
return super().fit(x, validation_data=x_val, **kwargs)
|
|
255
255
|
|
|
256
256
|
def evaluate(self, x: tensors.GraphTensor | tf.data.Dataset, **kwargs):
|
|
@@ -397,7 +397,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
|
|
|
397
397
|
raise ValueError(
|
|
398
398
|
'Could not extract output. `Readout` layer not found.'
|
|
399
399
|
)
|
|
400
|
-
return self.__class__(inputs, outputs, name=f'{self.name}
|
|
400
|
+
return self.__class__(inputs, outputs, name=f'{self.name}_backbone')
|
|
401
401
|
|
|
402
402
|
def head(self) -> functional.Functional:
|
|
403
403
|
if not isinstance(self, FunctionalGraphModel):
|
|
@@ -561,9 +561,8 @@ def _functional_init_arguments(args, kwargs):
|
|
|
561
561
|
or ("inputs" in kwargs and "outputs" in kwargs)
|
|
562
562
|
)
|
|
563
563
|
|
|
564
|
-
def _make_dataset(x: tensors.GraphTensor, batch_size: int):
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
.
|
|
568
|
-
|
|
569
|
-
)
|
|
564
|
+
def _make_dataset(x: tensors.GraphTensor, batch_size: int, shuffle: bool = False):
    """Wraps `x` in a batched, prefetched `tf.data.Dataset`.

    Args:
        x: Graph tensor that is sliced into individual examples.
        batch_size: Number of examples per batch.
        shuffle: If True, examples are shuffled before batching, using a
            buffer as large as the dataset's cardinality.

    Returns:
        The resulting `tf.data.Dataset`.
    """
    dataset = tf.data.Dataset.from_tensor_slices(x)
    dataset = (
        dataset.shuffle(buffer_size=dataset.cardinality())
        if shuffle
        else dataset
    )
    # -1 is the value of tf.data.AUTOTUNE (runtime-tuned prefetch depth).
    return dataset.batch(batch_size).prefetch(-1)
|
|
@@ -4,6 +4,7 @@ import tensorflow as tf
|
|
|
4
4
|
from keras import backend
|
|
5
5
|
|
|
6
6
|
|
|
7
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
7
8
|
def gather(
|
|
8
9
|
node_feature: tf.Tensor,
|
|
9
10
|
edge: tf.Tensor
|
|
@@ -16,6 +17,7 @@ def gather(
|
|
|
16
17
|
edge = keras.ops.expand_dims(edge, axis=-1)
|
|
17
18
|
return keras.ops.take_along_axis(node_feature, edge, axis=0)
|
|
18
19
|
|
|
20
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
19
21
|
def aggregate(
|
|
20
22
|
node_feature: tf.Tensor,
|
|
21
23
|
edge: tf.Tensor,
|
|
@@ -30,6 +32,7 @@ def aggregate(
|
|
|
30
32
|
node_feature, edge, num_nodes, sorted=False
|
|
31
33
|
)
|
|
32
34
|
|
|
35
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
33
36
|
def propagate(
|
|
34
37
|
node_feature: tf.Tensor,
|
|
35
38
|
edge_source: tf.Tensor,
|
|
@@ -49,6 +52,7 @@ def propagate(
|
|
|
49
52
|
|
|
50
53
|
return aggregate(node_feature, edge_target, num_nodes)
|
|
51
54
|
|
|
55
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
52
56
|
def scatter_update(
|
|
53
57
|
inputs: tf.Tensor,
|
|
54
58
|
indices: tf.Tensor,
|
|
@@ -62,6 +66,7 @@ def scatter_update(
|
|
|
62
66
|
indices = keras.ops.expand_dims(indices, axis=-1)
|
|
63
67
|
return keras.ops.scatter_update(inputs, indices, updates)
|
|
64
68
|
|
|
69
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
65
70
|
def scatter_add(
|
|
66
71
|
inputs: tf.Tensor,
|
|
67
72
|
indices: tf.Tensor,
|
|
@@ -78,6 +83,7 @@ def scatter_add(
|
|
|
78
83
|
updates = scatter_update(keras.ops.zeros_like(inputs), indices, updates)
|
|
79
84
|
return inputs + updates
|
|
80
85
|
|
|
86
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
81
87
|
def edge_softmax(
|
|
82
88
|
score: tf.Tensor,
|
|
83
89
|
edge_target: tf.Tensor
|
|
@@ -98,6 +104,7 @@ def edge_softmax(
|
|
|
98
104
|
denominator = gather(denominator, edge_target)
|
|
99
105
|
return numerator / denominator
|
|
100
106
|
|
|
107
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
101
108
|
def edge_weight(
|
|
102
109
|
edge: tf.Tensor,
|
|
103
110
|
edge_weight: tf.Tensor,
|
|
@@ -108,6 +115,7 @@ def edge_weight(
|
|
|
108
115
|
edge_weight = keras.ops.expand_dims(edge_weight, axis=-1)
|
|
109
116
|
return edge * edge_weight
|
|
110
117
|
|
|
118
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
111
119
|
def segment_mean(
|
|
112
120
|
data: tf.Tensor,
|
|
113
121
|
segment_ids: tf.Tensor,
|
|
@@ -142,6 +150,7 @@ def segment_mean(
|
|
|
142
150
|
)
|
|
143
151
|
return x / sizes[:, None]
|
|
144
152
|
|
|
153
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
145
154
|
def gaussian(
|
|
146
155
|
x: tf.Tensor,
|
|
147
156
|
mean: tf.Tensor,
|
|
@@ -155,6 +164,7 @@ def gaussian(
|
|
|
155
164
|
a = (2 * np.pi) ** 0.5
|
|
156
165
|
return keras.ops.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)
|
|
157
166
|
|
|
167
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
158
168
|
def euclidean_distance(
|
|
159
169
|
x1: tf.Tensor,
|
|
160
170
|
x2: tf.Tensor,
|
|
@@ -169,6 +179,7 @@ def euclidean_distance(
|
|
|
169
179
|
)
|
|
170
180
|
)
|
|
171
181
|
|
|
182
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
172
183
|
def displacement(
|
|
173
184
|
x1: tf.Tensor,
|
|
174
185
|
x2: tf.Tensor,
|
|
@@ -1,123 +0,0 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
import pandas as pd
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
def split(
|
|
6
|
-
data: pd.DataFrame | np.ndarray,
|
|
7
|
-
train_size: float | None = None,
|
|
8
|
-
validation_size: float | None = None,
|
|
9
|
-
test_size: float = 0.1,
|
|
10
|
-
shuffle: bool = False,
|
|
11
|
-
random_state: int | None = None,
|
|
12
|
-
) -> pd.DataFrame | np.ndarray:
|
|
13
|
-
"""Splits dataset into subsets.
|
|
14
|
-
|
|
15
|
-
Args:
|
|
16
|
-
data:
|
|
17
|
-
A pd.DataFrame or np.ndarray object.
|
|
18
|
-
train_size:
|
|
19
|
-
Optional train size, as a fraction (`float`) or size (`int`).
|
|
20
|
-
validation_size:
|
|
21
|
-
Optional validation size, as a fraction (`float`) or size (`int`).
|
|
22
|
-
test_size:
|
|
23
|
-
Required test size, as a fraction (`float`) or size (`int`).
|
|
24
|
-
shuffle:
|
|
25
|
-
Whether the dataset should be shuffled prior to splitting.
|
|
26
|
-
random_state:
|
|
27
|
-
The random state (or seed). Only applicable if shuffling.
|
|
28
|
-
"""
|
|
29
|
-
|
|
30
|
-
if not isinstance(data, (pd.DataFrame, np.ndarray, list)):
|
|
31
|
-
raise ValueError(
|
|
32
|
-
'`data` needs to be a pd.DataFrame, np.ndarray or a list. '
|
|
33
|
-
f'Found {type(data)}.'
|
|
34
|
-
)
|
|
35
|
-
|
|
36
|
-
size = len(data)
|
|
37
|
-
|
|
38
|
-
if test_size is None:
|
|
39
|
-
raise ValueError('`test_size` is required.')
|
|
40
|
-
elif test_size <= 0:
|
|
41
|
-
raise ValueError(
|
|
42
|
-
f'Test size needs to be positive. Found: {test_size}. '
|
|
43
|
-
'Either specify a positive `float` (fraction) or '
|
|
44
|
-
'a positive `int` (size).'
|
|
45
|
-
)
|
|
46
|
-
if train_size is not None and train_size <= 0:
|
|
47
|
-
raise ValueError(
|
|
48
|
-
f'Train size needs to be None or positive. Found: {train_size}. '
|
|
49
|
-
'Either specify `None`, a positive `float` (fraction) or '
|
|
50
|
-
'a positive `int` (size).'
|
|
51
|
-
)
|
|
52
|
-
if validation_size is not None and validation_size <= 0:
|
|
53
|
-
raise ValueError(
|
|
54
|
-
f'Validation size needs to be None or positive. Found: {validation_size}. '
|
|
55
|
-
'Either specify `None`, a positive `float` (fraction) or '
|
|
56
|
-
'a positive `int` (size).'
|
|
57
|
-
)
|
|
58
|
-
|
|
59
|
-
if isinstance(test_size, float):
|
|
60
|
-
test_size = int(size * test_size)
|
|
61
|
-
if validation_size and isinstance(validation_size, float):
|
|
62
|
-
validation_size = int(size * validation_size)
|
|
63
|
-
elif not validation_size:
|
|
64
|
-
validation_size = 0
|
|
65
|
-
|
|
66
|
-
if train_size and isinstance(train_size, float):
|
|
67
|
-
train_size = int(size * train_size)
|
|
68
|
-
elif not train_size:
|
|
69
|
-
train_size = 0
|
|
70
|
-
|
|
71
|
-
if not train_size:
|
|
72
|
-
train_size = size - test_size
|
|
73
|
-
if not validation_size:
|
|
74
|
-
train_size -= validation_size
|
|
75
|
-
|
|
76
|
-
remainder = size - (train_size + validation_size + test_size)
|
|
77
|
-
|
|
78
|
-
if remainder < 0:
|
|
79
|
-
raise ValueError(
|
|
80
|
-
'Sizes of data subsets add up to more than the size of the original data set: '
|
|
81
|
-
f'{size} < ({train_size} + {validation_size} + {test_size})'
|
|
82
|
-
)
|
|
83
|
-
if test_size <= 0:
|
|
84
|
-
raise ValueError(
|
|
85
|
-
f'Test size needs to be greater than 0. Found: {test_size}.'
|
|
86
|
-
)
|
|
87
|
-
if train_size <= 0:
|
|
88
|
-
raise ValueError(
|
|
89
|
-
f'Train size needs to be greater than 0. Found: {train_size}.'
|
|
90
|
-
)
|
|
91
|
-
|
|
92
|
-
train_size += remainder
|
|
93
|
-
|
|
94
|
-
if isinstance(data, pd.DataFrame):
|
|
95
|
-
if shuffle:
|
|
96
|
-
data = data.sample(
|
|
97
|
-
frac=1.0, replace=False, random_state=random_state
|
|
98
|
-
)
|
|
99
|
-
train_data = data.iloc[:train_size]
|
|
100
|
-
test_data = data.iloc[-test_size:]
|
|
101
|
-
if not validation_size:
|
|
102
|
-
return train_data, test_data
|
|
103
|
-
validation_data = data.iloc[train_size:-test_size]
|
|
104
|
-
return train_data, validation_data, test_data
|
|
105
|
-
|
|
106
|
-
if not isinstance(data, np.ndarray):
|
|
107
|
-
data = np.asarray(data)
|
|
108
|
-
|
|
109
|
-
np.random.seed(random_state)
|
|
110
|
-
|
|
111
|
-
random_indices = np.arange(size)
|
|
112
|
-
np.random.shuffle(random_indices)
|
|
113
|
-
data = data[random_indices]
|
|
114
|
-
|
|
115
|
-
train_data = data[:train_size]
|
|
116
|
-
test_data = data[-test_size:]
|
|
117
|
-
if not validation_size:
|
|
118
|
-
return train_data, test_data
|
|
119
|
-
validation_data = data[train_size:-test_size]
|
|
120
|
-
return train_data, validation_data, test_data
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|