molcraft 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Click here for more details.

molcraft/conformers.py ADDED
@@ -0,0 +1,155 @@
1
+ import keras
2
+
3
+ from molcraft import chem
4
+
5
+
6
@keras.saving.register_keras_serializable(package="molcraft")
class ConformerProcessor:
    """Base class for callables that take a molecule and return a molecule.

    Subclasses implement `__call__`; `get_config`/`from_config` provide the
    configuration round trip used for Keras serialization.
    """

    def get_config(self) -> dict:
        """Return the constructor arguments; the base class has none."""
        return dict()

    @classmethod
    def from_config(cls, config: dict):
        """Build an instance by passing `config` as keyword arguments."""
        instance = cls(**config)
        return instance

    def __call__(self, mol: chem.Mol) -> chem.Mol:
        """Process `mol`; must be overridden by subclasses."""
        raise NotImplementedError
18
+
19
+
20
@keras.saving.register_keras_serializable(package="molcraft")
class ConformerEmbedder(ConformerProcessor):
    """Processing step that delegates to `chem.embed_conformers`.

    Args:
        method: Forwarded as `method` to `chem.embed_conformers`.
        num_conformers: Forwarded as `num_conformers`.
        force: Forwarded as `force`.
        **kwargs: Extra keyword arguments forwarded unchanged.
    """

    def __init__(
        self,
        method: str = 'ETKDGv3',
        num_conformers: int = 10,
        force: bool = True,
        **kwargs,
    ) -> None:
        self.kwargs = kwargs
        self.force = force
        self.num_conformers = num_conformers
        self.method = method

    def get_config(self) -> dict:
        """Return the named arguments followed by the extra keyword arguments."""
        return {
            'method': self.method,
            'num_conformers': self.num_conformers,
            'force': self.force,
            **self.kwargs,
        }

    def __call__(self, mol: chem.Mol) -> chem.Mol:
        """Return the result of `chem.embed_conformers` for `mol`."""
        options = dict(
            method=self.method,
            num_conformers=self.num_conformers,
            force=self.force,
        )
        return chem.embed_conformers(mol, **options, **self.kwargs)
54
+
55
+
56
@keras.saving.register_keras_serializable(package="molcraft")
class ConformerOptimizer(ConformerProcessor):
    """Processing step that delegates to `chem.optimize_conformers`.

    Args:
        method: Forwarded as `method` to `chem.optimize_conformers`.
        max_iter: Forwarded as `max_iter`.
        ignore_interfragment_interactions: Forwarded under the same name.
        vdw_threshold: Forwarded as `vdw_threshold`.
        **kwargs: Extra keyword arguments forwarded unchanged.
    """

    def __init__(
        self,
        method: str = 'UFF',
        max_iter: int = 200,
        ignore_interfragment_interactions: bool = True,
        vdw_threshold: float = 10.0,
        **kwargs,
    ) -> None:
        self.kwargs = kwargs
        self.vdw_threshold = vdw_threshold
        self.ignore_interfragment_interactions = ignore_interfragment_interactions
        self.max_iter = max_iter
        self.method = method

    def get_config(self) -> dict:
        """Return the named arguments followed by the extra keyword arguments."""
        return {
            'method': self.method,
            'max_iter': self.max_iter,
            'ignore_interfragment_interactions': self.ignore_interfragment_interactions,
            'vdw_threshold': self.vdw_threshold,
            **self.kwargs,
        }

    def __call__(self, mol: chem.Mol) -> chem.Mol:
        """Return the result of `chem.optimize_conformers` for `mol`."""
        options = dict(
            method=self.method,
            max_iter=self.max_iter,
            ignore_interfragment_interactions=self.ignore_interfragment_interactions,
            vdw_threshold=self.vdw_threshold,
        )
        return chem.optimize_conformers(mol, **options, **self.kwargs)
94
+
95
+
96
@keras.saving.register_keras_serializable(package="molcraft")
class ConformerPruner(ConformerProcessor):
    """Processing step that delegates to `chem.prune_conformers`.

    Args:
        keep: Forwarded as `keep` to `chem.prune_conformers`.
        threshold: Forwarded as `threshold`.
        energy_force_field: Forwarded as `energy_force_field`.
        **kwargs: Extra keyword arguments forwarded unchanged.
    """

    def __init__(
        self,
        keep: int = 1,
        threshold: float = 0.0,
        energy_force_field: str = 'UFF',
        **kwargs,
    ) -> None:
        self.kwargs = kwargs
        self.energy_force_field = energy_force_field
        self.threshold = threshold
        self.keep = keep

    def get_config(self) -> dict:
        """Return the named arguments followed by the extra keyword arguments."""
        return {
            'keep': self.keep,
            'threshold': self.threshold,
            'energy_force_field': self.energy_force_field,
            **self.kwargs,
        }

    def __call__(self, mol: chem.Mol) -> chem.Mol:
        """Return the result of `chem.prune_conformers` for `mol`."""
        options = dict(
            keep=self.keep,
            threshold=self.threshold,
            energy_force_field=self.energy_force_field,
        )
        return chem.prune_conformers(mol, **options, **self.kwargs)
129
+
130
+
131
@keras.saving.register_keras_serializable(package='molcraft')
class ConformerGenerator(ConformerProcessor):
    """Pipeline that applies a list of `ConformerProcessor` steps in order."""

    def __init__(self, steps: list[ConformerProcessor]) -> None:
        self.steps = steps

    def get_config(self) -> dict:
        """Return the steps, each serialized with Keras object serialization."""
        serialized_steps = []
        for step in self.steps:
            serialized_steps.append(keras.saving.serialize_keras_object(step))
        return {"steps": serialized_steps}

    @classmethod
    def from_config(cls, config: dict) -> 'ConformerGenerator':
        """Rebuild the pipeline by deserializing every entry of `config["steps"]`."""
        restored_steps = []
        for serialized in config["steps"]:
            restored_steps.append(
                keras.saving.deserialize_keras_object(serialized)
            )
        return cls(restored_steps)

    def __call__(self, mol: chem.Mol) -> chem.Mol:
        """Feed `mol` through each step, passing every result to the next step."""
        current = mol
        for process in self.steps:
            current = process(current)
        return current
@@ -0,0 +1,90 @@
1
+ import keras
2
+ import numpy as np
3
+ from rdkit.Chem import Descriptors
4
+
5
+ from molcraft import chem
6
+ from molcraft import features
7
+
8
+
9
@keras.saving.register_keras_serializable(package='molcraft')
class Descriptor(features.Feature):
    """Molecule-level feature computed by a subclass-provided `call`.

    Args:
        scale: Optional divisor applied to each computed value. It is only
            used when the feature has no vocabulary.
        **kwargs: Forwarded to `features.Feature`.
    """

    def __init__(self, scale: float | None = None, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def __call__(self, mol: chem.Mol) -> np.ndarray:
        """Compute, optionally scale, encode and concatenate the descriptor.

        Raises:
            ValueError: If `mol` is not a `chem.Mol`.
        """
        if not isinstance(mol, chem.Mol):
            raise ValueError(
                f'Input to {self.name} needs to be a `chem.Mol`, which '
                'implements two properties that should be iterated over '
                'to compute features: `atoms` and `bonds`.'
            )
        raw = self.call(mol)
        # With a vocabulary the values are encoded categorically; otherwise
        # they are encoded as floating-point values.
        if self.vocab:
            encode = self._featurize_categorical
        else:
            encode = self._featurize_floating
        apply_scale = self.scale and not self.vocab
        # A single value is wrapped so that it can be handled like a sequence.
        if isinstance(raw, (tuple, list, np.ndarray)):
            values = raw
        else:
            values = [raw]

        encoded = []
        for value in values:
            if apply_scale:
                value /= self.scale
            encoded.append(encode(value))
        return np.concatenate(encoded)

    def get_config(self):
        """Return the base configuration extended with `scale`."""
        config = super().get_config()
        config.update({'scale': self.scale})
        return config
42
+
43
+
44
@keras.saving.register_keras_serializable(package='molcraft')
class MolWeight(Descriptor):
    """Descriptor backed by RDKit's `Descriptors.MolWt`."""

    def call(self, mol: chem.Mol) -> np.ndarray:
        """Return `Descriptors.MolWt(mol)`."""
        mol_weight = Descriptors.MolWt(mol)
        return mol_weight
48
+
49
+
50
@keras.saving.register_keras_serializable(package='molcraft')
class MolTPSA(Descriptor):
    """Descriptor backed by RDKit's `Descriptors.TPSA`."""

    def call(self, mol: chem.Mol) -> np.ndarray:
        """Return `Descriptors.TPSA(mol)`."""
        tpsa = Descriptors.TPSA(mol)
        return tpsa
54
+
55
+
56
@keras.saving.register_keras_serializable(package='molcraft')
class MolLogP(Descriptor):
    """Descriptor backed by RDKit's `Descriptors.MolLogP`."""

    def call(self, mol: chem.Mol) -> np.ndarray:
        """Return `Descriptors.MolLogP(mol)`."""
        log_p = Descriptors.MolLogP(mol)
        return log_p
60
+
61
+
62
@keras.saving.register_keras_serializable(package='molcraft')
class NumHeavyAtoms(Descriptor):
    """Descriptor backed by RDKit's `Descriptors.HeavyAtomCount`."""

    def call(self, mol: chem.Mol) -> np.ndarray:
        """Return `Descriptors.HeavyAtomCount(mol)`."""
        heavy_atom_count = Descriptors.HeavyAtomCount(mol)
        return heavy_atom_count
66
+
67
+
68
@keras.saving.register_keras_serializable(package='molcraft')
class NumHydrogenDonors(Descriptor):
    """Descriptor backed by RDKit's `Descriptors.NumHDonors`."""

    def call(self, mol: chem.Mol) -> np.ndarray:
        """Return `Descriptors.NumHDonors(mol)`."""
        donor_count = Descriptors.NumHDonors(mol)
        return donor_count
72
+
73
+
74
@keras.saving.register_keras_serializable(package='molcraft')
class NumHydrogenAcceptors(Descriptor):
    """Descriptor backed by RDKit's `Descriptors.NumHAcceptors`."""

    def call(self, mol: chem.Mol) -> np.ndarray:
        """Return `Descriptors.NumHAcceptors(mol)`."""
        acceptor_count = Descriptors.NumHAcceptors(mol)
        return acceptor_count
78
+
79
+
80
@keras.saving.register_keras_serializable(package='molcraft')
class NumRotatableBonds(Descriptor):
    """Descriptor backed by RDKit's `Descriptors.NumRotatableBonds`."""

    def call(self, mol: chem.Mol) -> np.ndarray:
        """Return `Descriptors.NumRotatableBonds(mol)`."""
        rotatable_bond_count = Descriptors.NumRotatableBonds(mol)
        return rotatable_bond_count
84
+
85
+
86
@keras.saving.register_keras_serializable(package='molcraft')
class NumRings(Descriptor):
    """Descriptor backed by RDKit's `Descriptors.RingCount`."""

    def call(self, mol: chem.Mol) -> np.ndarray:
        """Return `Descriptors.RingCount(mol)`."""
        ring_count = Descriptors.RingCount(mol)
        return ring_count
90
+
@@ -0,0 +1 @@
1
+ from molcraft.experimental import peptides
@@ -0,0 +1,303 @@
1
+ import re
2
+ import keras
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from rdkit import Chem
6
+
7
+ from molcraft import ops
8
+ from molcraft import chem
9
+ from molcraft import features
10
+ from molcraft import featurizers
11
+ from molcraft import tensors
12
+
13
+
14
class PeptideGraphFeaturizer(featurizers.MolGraphFeaturizer):
    """Featurizer that turns a peptide sequence into a graph tensor.

    Each residue of the sequence is mapped to an encoding through
    `features.residues`, featurized with the parent `MolGraphFeaturizer.call`
    and the per-residue results are stacked.

    Args:
        atom_features: Forwarded to `MolGraphFeaturizer`.
        bond_features: Forwarded to `MolGraphFeaturizer`.
        super_atom_feature: Forwarded to `MolGraphFeaturizer`; when `None`,
            an `AminoAcidType` feature is used.
        radius: Forwarded to `MolGraphFeaturizer`.
        self_loops: Forwarded to `MolGraphFeaturizer`.
        include_hs: Forwarded to `MolGraphFeaturizer`.
        feature_dtype: Forwarded to `MolGraphFeaturizer`.
        index_dtype: Forwarded to `MolGraphFeaturizer`.
    """

    def __init__(
        self,
        atom_features: list[features.Feature] | str | None = None,
        bond_features: list[features.Feature] | str | None = None,
        super_atom_feature: features.Feature | bool | None = None,
        radius: int | float | None = None,
        self_loops: bool = False,
        include_hs: bool = False,
        feature_dtype: str = 'float32',
        index_dtype: str = 'int32',
    ) -> None:
        if super_atom_feature is None:
            super_atom_feature = AminoAcidType()
        super().__init__(
            atom_features=atom_features,
            bond_features=bond_features,
            super_atom_feature=super_atom_feature,
            radius=radius,
            self_loops=self_loops,
            include_hs=include_hs,
            feature_dtype=feature_dtype,
            index_dtype=index_dtype,
        )

    def to_index(self, sequence: str):
        """Placeholder that is not implemented yet; currently returns `None`."""
        pass

    def static(self, inputs):
        """Featurize a list of residue tags, with 'G' prepended, and stack them.

        Returns `None` when none of the encodings could be converted to a
        molecule.
        """
        # TODO: Make sure it is an ordered sequence
        # NOTE(review): this module defines its own `residues` table; confirm
        # that `features.residues` exists and is the intended lookup.
        inputs = [features.residues[x] for x in ['G'] + inputs]
        mols = [
            chem.Mol.from_encoding(x, explicit_hs=self.include_hs)
            for x in inputs
        ]
        mols = [mol for mol in mols if mol is not None]
        if not mols:
            return None
        # Bind the parent method outside the comprehension: zero-argument
        # `super()` inside a comprehension raises `TypeError` before Python
        # 3.12, because the comprehension runs in its own function scope.
        featurize = super().call
        tensor_list: list[tensors.GraphTensor] = [
            featurize(mol) for mol in mols
        ]
        tensor: tensors.GraphTensor = tf.stack(tensor_list, axis=0)
        return tensor

    def call(self, inputs: str | tuple) -> tensors.GraphTensor:
        """Featurize one peptide sequence.

        Args:
            inputs: Either the sequence itself, or a list/tuple/array whose
                first element is the sequence and whose following elements
                are stored in the context as 'label' and 'weight', in that
                order.

        Returns:
            The merged graph tensor with its 'context' updated.
        """
        args = []
        if isinstance(inputs, (list, tuple, np.ndarray)):
            inputs, *args = inputs
        # NOTE(review): this module defines its own `sequence_split` and
        # `residues`; confirm that `chem.sequence_split` and
        # `features.residues` exist and are the intended ones.
        inputs = [
            features.residues[x] for x in chem.sequence_split(inputs)
        ]
        # See `static` for why `super().call` is bound before the comprehension.
        featurize = super().call
        tensor_list: list[tensors.GraphTensor] = [
            featurize(x) for x in inputs
        ]
        tensor: tensors.GraphTensor = tf.stack(tensor_list, axis=0)
        tensor = tensor._merge()
        # `zip` stops at the shorter input, so missing extras are simply absent.
        context = dict(zip(['label', 'weight'], args))
        tensor = tensor.update({'context': context})
        return tensor
80
+
81
+
82
def GraphLookup(graph: tensors.GraphTensor) -> 'GraphLookupLayer':
    """Create a `GraphLookupLayer` and build it from `graph`."""
    layer = GraphLookupLayer()
    layer._build(graph)
    return layer
86
+
87
+
88
@keras.saving.register_keras_serializable(package='molcraft')
class GraphLookupLayer(keras.layers.Layer):
    """Layer that stores a graph as non-trainable weights and indexes into it.

    The layer has to be populated through `_build` (see `GraphLookup` and
    `from_config`) before `call`, `graph` or `get_config` can be used.
    """

    def call(self, indices: tf.Tensor) -> tensors.GraphTensor:
        """Look up the stored graph at the sorted, unique values of `indices`.

        NOTE(review): the value returned is `tensors.to_dict(graph)`, while
        the annotation says `tensors.GraphTensor` -- confirm which is intended.
        """
        flat_indices = tf.reshape(indices, [-1])
        unique_indices, _ = tf.unique(flat_indices)
        indices = tf.sort(unique_indices)
        graph = self.graph[indices]
        selected_sizes = graph.context['size']
        max_index = keras.ops.max(indices)
        # Scatter the selected sizes into a zero vector of length
        # `max_index + 1`, at the positions given by `indices`.
        blank_sizes = tf.zeros([max_index + 1], dtype=indices.dtype)
        sizes = tf.tensor_scatter_nd_update(
            tensor=blank_sizes,
            indices=indices[:, None],
            updates=selected_sizes
        )
        graph = graph.update({'context': {'size': sizes}})
        return tensors.to_dict(graph)

    def _build(self, x):
        """Create the weights from a `GraphTensor` or from a nested spec.

        When `x` is a `GraphTensor`, its values are also assigned to the
        newly created weights; otherwise the weights stay zero-initialized.
        """
        from_graph = isinstance(x, tensors.GraphTensor)

        if from_graph:
            tensor = tensors.to_dict(x)
            self._spec = tf.nest.map_structure(
                tf.type_spec_from_value, tensor
            )
        else:
            self._spec = x

        def _create_weight(spec):
            # One zero-initialized, non-trainable weight per spec entry.
            return self.add_weight(
                shape=spec.shape,
                dtype=spec.dtype,
                trainable=False,
                initializer='zeros'
            )

        self._graph = tf.nest.map_structure(_create_weight, self._spec)

        if from_graph:
            tf.nest.map_structure(
                lambda variable, value: variable.assign(value),
                self._graph, tensor
            )

        converted = tf.nest.map_structure(
            keras.ops.convert_to_tensor, self._graph
        )
        self._graph_tensor = tensors.from_dict(converted)

    def get_config(self):
        """Return the base configuration extended with the serialized spec."""
        config = super().get_config()
        config['spec'] = keras.saving.serialize_keras_object(self._spec)
        return config

    @classmethod
    def from_config(cls, config: dict) -> 'GraphLookupLayer':
        """Rebuild the layer from `config`; 'spec' is popped from `config`."""
        serialized_spec = config.pop('spec')
        spec = keras.saving.deserialize_keras_object(serialized_spec)
        layer = cls(**config)
        layer._build(spec)
        return layer

    @property
    def graph(self) -> tensors.GraphTensor:
        """The graph tensor created by `_build`."""
        return self._graph_tensor
158
+
159
+
160
@keras.saving.register_keras_serializable(package='molcraft')
class Gather(keras.layers.Layer):
    """Layer that gathers entries of `data` at `indices` via `ops.gather`.

    Args:
        padding: Stored and serialized, but not applied by `call` at present.
        mask_value: Index value that `compute_mask` marks as masked.
        **kwargs: Forwarded to `keras.layers.Layer`.
    """

    def __init__(
        self,
        padding: list[tuple[int]] | tuple[int] | int = 1,
        mask_value: int = 0,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.padding = padding
        self.mask_value = mask_value
        self.supports_masking = True

    def get_config(self):
        """Return the base configuration extended with mask value and padding."""
        config = super().get_config()
        config.update(
            {
                'mask_value': self.mask_value,
                'padding': self.padding,
            }
        )
        return config

    def call(self, inputs: tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
        """Gather from the first element of `inputs` using the second."""
        # NOTE: `self.padding` is not applied to the data here.
        source, positions = inputs
        return ops.gather(source, positions)

    def compute_mask(
        self,
        inputs: tuple[tf.Tensor, tf.Tensor],
        mask: bool | None = None
    ) -> tf.Tensor | None:
        """Return a mask that is true where the indices differ from `mask_value`."""
        # NOTE: a `mask_value` of `None` is not special-cased.
        positions = inputs[1] if len(inputs) == 2 else None
        _, positions = inputs
        return keras.ops.not_equal(positions, self.mask_value)
203
+
204
+
205
@keras.saving.register_keras_serializable(package='molcraft')
class AminoAcidType(features.Feature):
    """Feature giving the one-letter amino acid type of a residue molecule.

    Args:
        vocab: Accepted for signature compatibility but ignored; the
            vocabulary is always the 20 one-letter codes listed below.
        **kwargs: Forwarded to `features.Feature`.
    """

    def __init__(self, vocab=None, **kwargs):
        # The vocabulary is fixed, so any `vocab` passed in is overridden.
        vocab = [
            "A", "C", "D", "E", "F", "G", "H", "I", "K", "L",
            "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y",
        ]
        super().__init__(vocab=vocab, **kwargs)

    def call(self, mol: chem.Mol) -> list[str]:
        """Return the residue letter of `mol` as a one-element list.

        Raises:
            KeyError: If the canonical SMILES of the hydrogen-stripped
                molecule is not registered in `residues_reverse`.
        """
        # Remove hydrogens before the lookup so that the validation and the
        # value actually used come from the same SMILES string.
        mol = chem.remove_hs(mol)
        smiles = mol.canonical_smiles
        residue = residues_reverse.get(smiles)
        if not residue:
            raise KeyError(f'Could not find {smiles} in `residues_reverse`.')
        return [_extract_residue_type(residue)]
221
+
222
def sequence_split(sequence: str):
    """Split a peptide sequence string into its residue tags.

    A tag is a single capital letter, optionally followed by a bracketed
    modification and optionally carrying an N-terminal (`[Mod]-X`) or
    C-terminal (`X-[Mod]`) modification. Characters that match none of the
    alternatives are skipped.
    """
    mod = r'\[[A-Za-z0-9]+\]'
    # Longer alternatives come first so that they win over the shorter ones.
    alternatives = (
        '(' + mod + r'-[A-Z]' + mod + ')',   # N-term mod + mod
        r'([A-Z]' + mod + '-' + mod + ')',   # C-term mod + mod
        r'([A-Z]-' + mod + ')',              # C-term mod
        '(' + mod + r'-[A-Z])',              # N-term mod
        r'([A-Z]' + mod + ')',               # Mod
        r'([A-Z])',                          # No mod
    )
    tag_pattern = "|".join(alternatives)
    tags = []
    for found in re.finditer(tag_pattern, sequence):
        tags.append(found.group(0))
    return tags
232
+
233
# Lookup table from residue tag to SMILES string. A tag is a one-letter code,
# optionally with a side-chain modification (`X[Mod]`) or an N-terminal
# modification (`[Mod]-X`). The values are rewritten in place with RDKit's
# canonical SMILES by the `register_peptide_residues(residues)` call below.
residues = {
    "A": "N[C@@H](C)C(=O)O",
    "C": "N[C@@H](CS)C(=O)O",
    "C[Carbamidomethyl]": "N[C@@H](CSCC(=O)N)C(=O)O",
    "D": "N[C@@H](CC(=O)O)C(=O)O",
    "E": "N[C@@H](CCC(=O)O)C(=O)O",
    "F": "N[C@@H](Cc1ccccc1)C(=O)O",
    "G": "NCC(=O)O",
    "H": "N[C@@H](CC1=CN=C-N1)C(=O)O",
    "I": "N[C@@H](C(CC)C)C(=O)O",
    "K": "N[C@@H](CCCCN)C(=O)O",
    "K[Acetyl]": "N[C@@H](CCCCNC(=O)C)C(=O)O",
    "K[Crotonyl]": "N[C@@H](CCCCNC(C=CC)=O)C(=O)O",
    "K[Dimethyl]": "N[C@@H](CCCCN(C)C)C(=O)O",
    "K[Formyl]": "N[C@@H](CCCCNC=O)C(=O)O",
    "K[Malonyl]": "N[C@@H](CCCCNC(=O)CC(O)=O)C(=O)O",
    "K[Methyl]": "N[C@@H](CCCCNC)C(=O)O",
    "K[Propionyl]": "N[C@@H](CCCCNC(=O)CC)C(=O)O",
    "K[Succinyl]": "N[C@@H](CCCCNC(CCC(O)=O)=O)C(=O)O",
    "K[Trimethyl]": "N[C@@H](CCCC[N+](C)(C)C)C(=O)O",
    "L": "N[C@@H](CC(C)C)C(=O)O",
    "M": "N[C@@H](CCSC)C(=O)O",
    "M[Oxidation]": "N[C@@H](CCS(=O)C)C(=O)O",
    "N": "N[C@@H](CC(=O)N)C(=O)O",
    "P": "N1[C@@H](CCC1)C(=O)O",
    "P[Oxidation]": "N1CC(O)C[C@H]1C(=O)O",
    "Q": "N[C@@H](CCC(=O)N)C(=O)O",
    "R": "N[C@@H](CCCNC(=N)N)C(=O)O",
    "R[Deamidated]": "N[C@@H](CCCNC(N)=O)C(=O)O",
    "R[Dimethyl]": "N[C@@H](CCCNC(N(C)C)=N)C(=O)O",
    "R[Methyl]": "N[C@@H](CCCNC(=N)NC)C(=O)O",
    "S": "N[C@@H](CO)C(=O)O",
    "T": "N[C@@H](C(O)C)C(=O)O",
    "V": "N[C@@H](C(C)C)C(=O)O",
    "W": "N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O",
    "Y": "N[C@@H](Cc1ccc(O)cc1)C(=O)O",
    "Y[Nitro]": "N[C@@H](Cc1ccc(O)c(N(=O)=O)c1)C(=O)O",
    "Y[Phospho]": "N[C@@H](Cc1ccc(OP(O)(=O)O)cc1)C(=O)O",
    "[Acetyl]-A": "N(C(C)=O)[C@@H](C)C(=O)O",
    "[Acetyl]-C": "N(C(C)=O)[C@@H](CS)C(=O)O",
    "[Acetyl]-D": "N(C(=O)C)[C@H](C(=O)O)CC(=O)O",
    "[Acetyl]-E": "N(C(=O)C)[C@@H](CCC(O)=O)C(=O)O",
    "[Acetyl]-F": "N(C(C)=O)[C@@H](Cc1ccccc1)C(=O)O",
    "[Acetyl]-G": "N(C(=O)C)CC(=O)O",
    "[Acetyl]-H": "N(C(=O)C)[C@@H](Cc1[nH]cnc1)C(=O)O",
    "[Acetyl]-I": "N(C(=O)C)[C@@H]([C@H](CC)C)C(=O)O",
    "[Acetyl]-K": "N(C(C)=O)[C@@H](CCCCN)C(=O)O",
    "[Acetyl]-L": "N(C(=O)C)[C@@H](CC(C)C)C(=O)O",
    "[Acetyl]-M": "N(C(=O)C)[C@@H](CCSC)C(=O)O",
    "[Acetyl]-N": "N(C(C)=O)[C@@H](CC(=O)N)C(=O)O",
    "[Acetyl]-P": "N1(C(=O)C)CCC[C@H]1C(=O)O",
    "[Acetyl]-Q": "N(C(=O)C)[C@@H](CCC(=O)N)C(=O)O",
    "[Acetyl]-R": "N(C(C)=O)[C@@H](CCCN=C(N)N)C(=O)O",
    "[Acetyl]-S": "N(C(C)=O)[C@@H](CO)C(=O)O",
    "[Acetyl]-T": "N(C(=O)C)[C@@H]([C@H](O)C)C(=O)O",
    "[Acetyl]-V": "N(C(=O)C)[C@@H](C(C)C)C(=O)O",
    "[Acetyl]-W": "N(C(C)=O)[C@@H](Cc1c2ccccc2[nH]c1)C(=O)O",
    "[Acetyl]-Y": "N(C(C)=O)[C@@H](Cc1ccc(O)cc1)C(=O)O"
}
292
+
293
# Reverse lookup from canonical SMILES to residue tag, filled by
# `register_peptide_residues`.
residues_reverse = {}


def register_peptide_residues(residues: dict[str, str]):
    """Canonicalize the SMILES of `residues` and register the reverse lookup.

    Every value of `residues` is replaced in place by its RDKit canonical
    SMILES, and `residues_reverse` gains an entry mapping that canonical
    SMILES back to the residue tag.

    Args:
        residues: Mapping from residue tag to SMILES string; modified in place.

    Raises:
        ValueError: If a SMILES string cannot be parsed by RDKit. Entries
            processed before the failing one remain registered.
    """
    for residue, smiles in residues.items():
        mol = Chem.MolFromSmiles(smiles)
        # `MolFromSmiles` returns `None` on a parse failure; fail here with
        # the offending entry instead of inside `MolToSmiles`.
        if mol is None:
            raise ValueError(
                f'Could not parse SMILES {smiles!r} of residue {residue!r}.'
            )
        canonical_smiles = Chem.MolToSmiles(mol)
        # Only existing keys are reassigned, which is safe while iterating.
        residues[residue] = canonical_smiles
        residues_reverse[canonical_smiles] = residue


register_peptide_residues(residues)
300
+
301
+ def _extract_residue_type(residue_tag: str) -> str:
302
+ pattern = r"(?<!\[)[A-Z](?![\w-])"
303
+ return [match.group(0) for match in re.finditer(pattern, residue_tag)][0]