molcraft 0.1.0a1__py3-none-any.whl → 0.1.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Click here for more details.

molcraft/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = '0.1.0a1'
1
+ __version__ = '0.1.0a3'
2
2
 
3
3
  import os
4
4
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -14,3 +14,4 @@ from molcraft import ops
14
14
  from molcraft import records
15
15
  from molcraft import tensors
16
16
  from molcraft import callbacks
17
+ from molcraft import datasets
molcraft/datasets.py ADDED
@@ -0,0 +1,123 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+
5
+ def split(
6
+ data: pd.DataFrame | np.ndarray,
7
+ train_size: float | None = None,
8
+ validation_size: float | None = None,
9
+ test_size: float = 0.1,
10
+ shuffle: bool = False,
11
+ random_state: int | None = None,
12
+ ) -> pd.DataFrame | np.ndarray:
13
+ """Splits dataset into subsets.
14
+
15
+ Args:
16
+ data:
17
+ A pd.DataFrame or np.ndarray object.
18
+ train_size:
19
+ Optional train size, as a fraction (`float`) or size (`int`).
20
+ validation_size:
21
+ Optional validation size, as a fraction (`float`) or size (`int`).
22
+ test_size:
23
+ Required test size, as a fraction (`float`) or size (`int`).
24
+ shuffle:
25
+ Whether the dataset should be shuffled prior to splitting.
26
+ random_state:
27
+ The random state (or seed). Only applicable if shuffling.
28
+ """
29
+
30
+ if not isinstance(data, (pd.DataFrame, np.ndarray, list)):
31
+ raise ValueError(
32
+ '`data` needs to be a pd.DataFrame, np.ndarray or a list. '
33
+ f'Found {type(data)}.'
34
+ )
35
+
36
+ size = len(data)
37
+
38
+ if test_size is None:
39
+ raise ValueError('`test_size` is required.')
40
+ elif test_size <= 0:
41
+ raise ValueError(
42
+ f'Test size needs to be positive. Found: {test_size}. '
43
+ 'Either specify a positive `float` (fraction) or '
44
+ 'a positive `int` (size).'
45
+ )
46
+ if train_size is not None and train_size <= 0:
47
+ raise ValueError(
48
+ f'Train size needs to be None or positive. Found: {train_size}. '
49
+ 'Either specify `None`, a positive `float` (fraction) or '
50
+ 'a positive `int` (size).'
51
+ )
52
+ if validation_size is not None and validation_size <= 0:
53
+ raise ValueError(
54
+ f'Validation size needs to be None or positive. Found: {validation_size}. '
55
+ 'Either specify `None`, a positive `float` (fraction) or '
56
+ 'a positive `int` (size).'
57
+ )
58
+
59
+ if isinstance(test_size, float):
60
+ test_size = int(size * test_size)
61
+ if validation_size and isinstance(validation_size, float):
62
+ validation_size = int(size * validation_size)
63
+ elif not validation_size:
64
+ validation_size = 0
65
+
66
+ if train_size and isinstance(train_size, float):
67
+ train_size = int(size * train_size)
68
+ elif not train_size:
69
+ train_size = 0
70
+
71
+ if not train_size:
72
+ train_size = size - test_size
73
+ if not validation_size:
74
+ train_size -= validation_size
75
+
76
+ remainder = size - (train_size + validation_size + test_size)
77
+
78
+ if remainder < 0:
79
+ raise ValueError(
80
+ 'Sizes of data subsets add up to more than the size of the original data set: '
81
+ f'{size} < ({train_size} + {validation_size} + {test_size})'
82
+ )
83
+ if test_size <= 0:
84
+ raise ValueError(
85
+ f'Test size needs to be greater than 0. Found: {test_size}.'
86
+ )
87
+ if train_size <= 0:
88
+ raise ValueError(
89
+ f'Train size needs to be greater than 0. Found: {train_size}.'
90
+ )
91
+
92
+ train_size += remainder
93
+
94
+ if isinstance(data, pd.DataFrame):
95
+ if shuffle:
96
+ data = data.sample(
97
+ frac=1.0, replace=False, random_state=random_state
98
+ )
99
+ train_data = data.iloc[:train_size]
100
+ test_data = data.iloc[-test_size:]
101
+ if not validation_size:
102
+ return train_data, test_data
103
+ validation_data = data.iloc[train_size:-test_size]
104
+ return train_data, validation_data, test_data
105
+
106
+ if not isinstance(data, np.ndarray):
107
+ data = np.asarray(data)
108
+
109
+ np.random.seed(random_state)
110
+
111
+ random_indices = np.arange(size)
112
+ np.random.shuffle(random_indices)
113
+ data = data[random_indices]
114
+
115
+ train_data = data[:train_size]
116
+ test_data = data[-test_size:]
117
+ if not validation_size:
118
+ return train_data, test_data
119
+ validation_data = data[train_size:-test_size]
120
+ return train_data, validation_data, test_data
121
+
122
+
123
+
@@ -9,75 +9,36 @@ from molcraft import chem
9
9
  from molcraft import features
10
10
  from molcraft import featurizers
11
11
  from molcraft import tensors
12
+ from molcraft import descriptors
12
13
 
13
14
 
14
- class PeptideGraphFeaturizer(featurizers.MolGraphFeaturizer):
15
-
16
- def __init__(
17
- self,
18
- atom_features: list[features.Feature] | str | None = None,
19
- bond_features: list[features.Feature] | str | None = None,
20
- super_atom_feature: features.Feature | bool = None,
21
- radius: int | float | None = None,
22
- self_loops: bool = False,
23
- include_hs: bool = False,
24
- feature_dtype: str = 'float32',
25
- index_dtype: str = 'int32',
26
- ) -> None:
27
- if super_atom_feature is None:
28
- super_atom_feature = AminoAcidType()
29
- super().__init__(
30
- atom_features=atom_features,
31
- bond_features=bond_features,
32
- super_atom_feature=super_atom_feature,
33
- radius=radius,
34
- self_loops=self_loops,
35
- include_hs=include_hs,
36
- feature_dtype=feature_dtype,
37
- index_dtype=index_dtype
38
- )
15
+ def Graph(
16
+ inputs,
17
+ atom_features: list[features.Feature] | str | None = 'auto',
18
+ bond_features: list[features.Feature] | str | None = 'auto',
19
+ super_atom: bool = True,
20
+ radius: int | float | None = None,
21
+ self_loops: bool = False,
22
+ include_hs: bool = False,
23
+ **kwargs,
24
+ ):
25
+ featurizer = featurizers.MolGraphFeaturizer(
26
+ atom_features=atom_features,
27
+ bond_features=bond_features,
28
+ molecule_features=[AminoAcidType()],
29
+ super_atom=super_atom,
30
+ radius=radius,
31
+ self_loops=self_loops,
32
+ include_hs=include_hs,
33
+ **kwargs,
34
+ )
39
35
 
40
- def to_index(self, sequence: str):
41
- pass
42
-
43
- def static(self, inputs):
44
- # TODO: Make sure it is an ordered sequence
45
- inputs = [
46
- features.residues[x] for x in ['G'] + inputs
47
- ]
48
- mols = [
49
- chem.Mol.from_encoding(x, explicit_hs=self.include_hs) for x in inputs
50
- ]
51
- mols = [
52
- mol for mol in mols if mol is not None
53
- ]
54
- if not mols:
55
- return None
56
- tensor_list: list[tensors.GraphTensor] = [super().call(mol) for mol in mols]
57
- tensor: tensors.GraphTensor = tf.stack(tensor_list, axis=0)
58
- return tensor
59
-
60
- def call(self, inputs: str | tuple) -> tensors.GraphTensor:
61
- args = []
62
- if isinstance(inputs, (list, tuple, np.ndarray)):
63
- inputs, *args = inputs
64
- inputs = [
65
- features.residues[x] for x in chem.sequence_split(inputs)
66
- ]
67
- tensor_list: list[tensors.GraphTensor] = [super().call(x) for x in inputs]
68
- tensor: tensors.GraphTensor = tf.stack(tensor_list, axis=0)
69
- tensor = tensor._merge()
70
- context = {
71
- k: v for (k, v) in zip(['label', 'weight'], args)
72
- }
73
- tensor = tensor.update(
74
- {
75
- 'context': context
76
- }
77
- )
36
+ inputs = [
37
+ residues[x] for x in ['G'] + inputs
38
+ ]
39
+ tensor_list = [featurizer(x) for x in inputs]
40
+ return tf.stack(tensor_list, axis=0)
78
41
 
79
- return tensor
80
-
81
42
 
82
43
  def GraphLookup(graph: tensors.GraphTensor) -> 'GraphLookupLayer':
83
44
  lookup = GraphLookupLayer()
@@ -203,7 +164,7 @@ class Gather(keras.layers.Layer):
203
164
 
204
165
 
205
166
  @keras.saving.register_keras_serializable(package='molcraft')
206
- class AminoAcidType(features.Feature):
167
+ class AminoAcidType(descriptors.Descriptor):
207
168
 
208
169
  def __init__(self, vocab=None, **kwargs):
209
170
  vocab = [
@@ -217,7 +178,7 @@ class AminoAcidType(features.Feature):
217
178
  if not residue:
218
179
  raise KeyError(f'Could not find {mol.canonical_smiles} in `residues_reverse`.')
219
180
  mol = chem.remove_hs(mol)
220
- return [_extract_residue_type(residues_reverse[mol.canonical_smiles])]
181
+ return _extract_residue_type(residues_reverse[mol.canonical_smiles])
221
182
 
222
183
  def sequence_split(sequence: str):
223
184
  patterns = [
molcraft/features.py CHANGED
@@ -155,9 +155,11 @@ class Distance(EdgeFeature):
155
155
  encode_oov: bool = True,
156
156
  **kwargs,
157
157
  ) -> None:
158
- if max_distance is None:
159
- max_distance = 20
160
- vocab = list(range(max_distance + 1))
158
+ vocab = kwargs.pop('vocab', None)
159
+ if not vocab:
160
+ if max_distance is None:
161
+ max_distance = 20
162
+ vocab = list(range(max_distance + 1))
161
163
  super().__init__(
162
164
  vocab=vocab,
163
165
  allow_oov=allow_oov,
molcraft/featurizers.py CHANGED
@@ -200,12 +200,13 @@ class MolGraphFeaturizer(Featurizer):
200
200
  self.feature_dtype = 'float32'
201
201
  self.index_dtype = 'int32'
202
202
 
203
- def call(self, x: str | typing.Tuple) -> tensors.GraphTensor:
204
-
205
- if isinstance(x, (tuple, list, np.ndarray)):
206
- x, *args = x
203
+ def call(self, inputs: str | tuple) -> tensors.GraphTensor:
204
+ if isinstance(inputs, (tuple, list, np.ndarray)):
205
+ x, *context = inputs
206
+ if len(context) and isinstance(context[0], dict):
207
+ context = copy.deepcopy(context[0])
207
208
  else:
208
- args = []
209
+ x, context = inputs, None
209
210
 
210
211
  mol = chem.Mol.from_encoding(x, explicit_hs=self.include_hs)
211
212
 
@@ -220,14 +221,30 @@ class MolGraphFeaturizer(Featurizer):
220
221
  bond_feature = self.bond_features(mol)
221
222
  context_feature = self.context_feature(mol)
222
223
  molecule_size = self.num_atoms(mol)
223
-
224
- context, node, edge = {}, {}, {}
225
- for field, value in zip(['size', 'label', 'weight'], [molecule_size] + args):
226
- context[field] = value
224
+
225
+ if isinstance(context, dict):
226
+ if 'x' in context:
227
+ context['feature'] = context.pop('x')
228
+ if 'y' in context:
229
+ context['label'] = context.pop('y')
230
+ if 'sample_weight' in context:
231
+ context['weight'] = context.pop('sample_weight')
232
+ context = {
233
+ **{'size': molecule_size},
234
+ **context
235
+ }
236
+ elif isinstance(context, list):
237
+ context = {
238
+ **{'size': molecule_size},
239
+ **{key: value for (key, value) in zip(['label', 'weight'], context)}
240
+ }
241
+ else:
242
+ context = {'size': molecule_size}
227
243
 
228
244
  if context_feature is not None:
229
245
  context['feature'] = context_feature
230
246
 
247
+ node = {}
231
248
  node['feature'] = atom_feature
232
249
 
233
250
  if bond_feature is not None and (self.radius > 1 or self.self_loops):
@@ -239,6 +256,7 @@ class MolGraphFeaturizer(Featurizer):
239
256
  [bond_feature, zero_bond_feature], axis=0
240
257
  )
241
258
 
259
+ edge = {}
242
260
  if self.radius == 1:
243
261
  edge['source'], edge['target'] = mol.adjacency(
244
262
  fill='full', sparse=True, self_loops=self.self_loops, dtype=self.index_dtype
@@ -384,6 +402,7 @@ class MolGraphFeaturizer(Featurizer):
384
402
  return cls(**config)
385
403
 
386
404
 
405
+ @keras.saving.register_keras_serializable(package='molcraft')
387
406
  class MolGraphFeaturizer3D(MolGraphFeaturizer):
388
407
 
389
408
  """Molecular 3d-graph featurizer.
@@ -494,19 +513,25 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
494
513
  self.embed_conformer = self.conformer_generator is not None
495
514
  self.radius = float(radius) if radius else None
496
515
 
497
- def call(self, x: str | typing.Tuple) -> tensors.GraphTensor:
516
+ def call(self, inputs: str | tuple) -> tensors.GraphTensor:
498
517
 
499
- if isinstance(x, (tuple, list, np.ndarray)):
500
- x, *args = x
518
+ if isinstance(inputs, (tuple, list, np.ndarray)):
519
+ x, *context = inputs
520
+ if len(context) and isinstance(context[0], dict):
521
+ context = copy.deepcopy(context[0])
501
522
  else:
502
- args = []
523
+ x, context = inputs, None
503
524
 
504
525
  explicit_hs = (self.include_hs or self.embed_conformer)
505
526
  mol = chem.Mol.from_encoding(x, explicit_hs=explicit_hs)
506
-
527
+
507
528
  if mol is None:
529
+ warn(
530
+ f'Could not obtain `chem.Mol` from {x}. '
531
+ 'Proceeding without it.'
532
+ )
508
533
  return None
509
-
534
+
510
535
  if self.embed_conformer:
511
536
  mol = self.conformer_generator(mol)
512
537
  if not self.include_hs:
@@ -519,21 +544,38 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
519
544
  'of the `Featurizer` or input a 3D representation of the molecule. '
520
545
  )
521
546
 
522
- context, node, edge = {}, {}, {}
523
-
524
- context['size'] = self.num_atoms(mol) + int(self.super_atom)
525
- for field, value in zip(['label', 'weight'], args):
526
- context[field] = value
547
+ context_feature = self.context_feature(mol)
548
+ molecule_size = self.num_atoms(mol) + int(self.super_atom)
549
+
550
+ if isinstance(context, dict):
551
+ if 'x' in context:
552
+ context['feature'] = context.pop('x')
553
+ if 'y' in context:
554
+ context['label'] = context.pop('y')
555
+ if 'sample_weight' in context:
556
+ context['weight'] = context.pop('sample_weight')
557
+ context = {
558
+ **{'size': molecule_size},
559
+ **context
560
+ }
561
+ elif isinstance(context, list):
562
+ context = {
563
+ **{'size': molecule_size},
564
+ **{key: value for (key, value) in zip(['label', 'weight'], context)}
565
+ }
566
+ else:
567
+ context = {'size': molecule_size}
527
568
 
569
+ if context_feature is not None:
570
+ context['feature'] = context_feature
571
+
572
+ node = {}
528
573
  node['feature'] = self.atom_features(mol)
529
574
 
530
575
  if self._bond_features:
531
576
  edge_feature = self.bond_features(mol)
532
577
 
533
- context_feature = self.context_feature(mol)
534
- if context_feature is not None:
535
- context['feature'] = context_feature
536
-
578
+ edge = {}
537
579
  mols = chem._split_mol_by_confs(mol)
538
580
  tensor_list = []
539
581
  for i, mol in enumerate(mols):
@@ -563,11 +605,10 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
563
605
  node_conformer['coordinate'] = np.concatenate(
564
606
  [node_conformer['coordinate'], conformer.centroid[None]], axis=0
565
607
  )
566
-
567
608
  tensor_list.append(
568
609
  tensors.GraphTensor(context, node_conformer, edge_conformer)
569
610
  )
570
-
611
+
571
612
  return tensor_list
572
613
 
573
614
  def stack(self, outputs):
@@ -587,7 +628,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
587
628
  config['conformer_generator'] = keras.saving.deserialize_keras_object(
588
629
  config['conformer_generator']
589
630
  )
590
- return super().from_config(**config)
631
+ return super().from_config(config)
591
632
 
592
633
 
593
634
  def save_featurizer(