molcraft 0.1.0a10__tar.gz → 0.1.0a12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Click here for more details.

Files changed (32)
  1. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/PKG-INFO +15 -10
  2. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/README.md +14 -9
  3. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft/__init__.py +1 -1
  4. molcraft-0.1.0a12/molcraft/apps/__init__.py +0 -0
  5. molcraft-0.1.0a12/molcraft/apps/peptides.py +429 -0
  6. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft/chem.py +10 -6
  7. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft/descriptors.py +1 -1
  8. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft/features.py +19 -19
  9. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft/featurizers.py +2 -1
  10. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft/layers.py +1 -1
  11. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft/losses.py +1 -1
  12. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft/models.py +27 -68
  13. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft.egg-info/PKG-INFO +15 -10
  14. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft.egg-info/SOURCES.txt +2 -0
  15. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/tests/test_featurizers.py +2 -2
  16. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/LICENSE +0 -0
  17. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft/callbacks.py +0 -0
  18. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft/conformers.py +0 -0
  19. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft/datasets.py +0 -0
  20. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft/ops.py +0 -0
  21. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft/records.py +0 -0
  22. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft/tensors.py +0 -0
  23. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft.egg-info/dependency_links.txt +0 -0
  24. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft.egg-info/requires.txt +0 -0
  25. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/molcraft.egg-info/top_level.txt +0 -0
  26. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/pyproject.toml +0 -0
  27. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/setup.cfg +0 -0
  28. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/tests/test_chem.py +0 -0
  29. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/tests/test_layers.py +0 -0
  30. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/tests/test_losses.py +0 -0
  31. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/tests/test_models.py +0 -0
  32. {molcraft-0.1.0a10 → molcraft-0.1.0a12}/tests/test_tensors.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: molcraft
3
- Version: 0.1.0a10
3
+ Version: 0.1.0a12
4
4
  Summary: Graph Neural Networks for Molecular Machine Learning
5
5
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
6
6
  License: MIT License
@@ -47,15 +47,20 @@ Dynamic: license-file
47
47
  **Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
48
48
 
49
49
  > [!NOTE]
50
- > In progress/Unfinished.
50
+ > In progress.
51
51
 
52
- ## Highlights
53
- - Compatible with **Keras 3**
54
- - Customizable and serializable **featurizers**
55
- - Customizable and serializable **layers** and **models**
56
- - Customizable **GraphTensor**
57
- - Fast and efficient featurization of molecular graphs
58
- - Fast and efficient input pipelines using TF **records**
52
+ ## Installation
53
+
54
+ For CPU users:
55
+
56
+ ```bash
57
+ pip install --pre molcraft
58
+ ```
59
+
60
+ For GPU users:
61
+ ```bash
62
+ pip install --pre molcraft[gpu]
63
+ ```
59
64
 
60
65
  ## Examples
61
66
 
@@ -70,7 +75,7 @@ import keras
70
75
  featurizer = featurizers.MolGraphFeaturizer(
71
76
  atom_features=[
72
77
  features.AtomType(),
73
- features.TotalNumHs(),
78
+ features.NumHydrogens(),
74
79
  features.Degree(),
75
80
  ],
76
81
  bond_features=[
@@ -3,15 +3,20 @@
3
3
  **Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
4
4
 
5
5
  > [!NOTE]
6
- > In progress/Unfinished.
6
+ > In progress.
7
7
 
8
- ## Highlights
9
- - Compatible with **Keras 3**
10
- - Customizable and serializable **featurizers**
11
- - Customizable and serializable **layers** and **models**
12
- - Customizable **GraphTensor**
13
- - Fast and efficient featurization of molecular graphs
14
- - Fast and efficient input pipelines using TF **records**
8
+ ## Installation
9
+
10
+ For CPU users:
11
+
12
+ ```bash
13
+ pip install --pre molcraft
14
+ ```
15
+
16
+ For GPU users:
17
+ ```bash
18
+ pip install --pre molcraft[gpu]
19
+ ```
15
20
 
16
21
  ## Examples
17
22
 
@@ -26,7 +31,7 @@ import keras
26
31
  featurizer = featurizers.MolGraphFeaturizer(
27
32
  atom_features=[
28
33
  features.AtomType(),
29
- features.TotalNumHs(),
34
+ features.NumHydrogens(),
30
35
  features.Degree(),
31
36
  ],
32
37
  bond_features=[
@@ -1,4 +1,4 @@
1
- __version__ = '0.1.0a10'
1
+ __version__ = '0.1.0a12'
2
2
 
3
3
  import os
4
4
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
File without changes
@@ -0,0 +1,429 @@
1
+ import re
2
+ import keras
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ import tensorflow_text as tf_text
6
+ from rdkit import Chem
7
+
8
+ from molcraft import ops
9
+ from molcraft import chem
10
+ from molcraft import features
11
+ from molcraft import featurizers
12
+ from molcraft import tensors
13
+ from molcraft import descriptors
14
+ from molcraft import layers
15
+ from molcraft import models
16
+
17
+
18
+
19
+ @keras.saving.register_keras_serializable(package='molcraft')
20
+ class SequenceSplitter(keras.layers.Layer):
21
+
22
+ _pattern = "|".join([
23
+ r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])', # N-term mod + mod
24
+ r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])', # C-term mod + mod
25
+ r'([A-Z]-\[[A-Za-z0-9]+\])', # C-term mod
26
+ r'(\[[A-Za-z0-9]+\]-[A-Z])', # N-term mod
27
+ r'([A-Z]\[[A-Za-z0-9]+\])', # Mod
28
+ r'([A-Z])', # No mod
29
+ ])
30
+
31
+ def call(self, inputs):
32
+ inputs = tf_text.regex_split(inputs, self._pattern, self._pattern)
33
+ inputs = keras.ops.concatenate([
34
+ tf.strings.join([inputs[:, :-1], '-[X]']),
35
+ inputs[:, -1:]
36
+ ], axis=1)
37
+ return inputs.to_tensor()
38
+
39
+ @keras.saving.register_keras_serializable(package='molcraft')
40
+ class Gather(keras.layers.Layer):
41
+
42
+ def __init__(
43
+ self,
44
+ padding: list[tuple[int]] | tuple[int] | int = 1,
45
+ mask_value: int = 0,
46
+ **kwargs
47
+ ) -> None:
48
+ super().__init__(**kwargs)
49
+ self._splitter = SequenceSplitter()
50
+ self.padding = padding
51
+ self.mask_value = mask_value
52
+ self.supports_masking = True
53
+
54
+ self._tags = list(sorted(residues.keys()))
55
+ self._mapping = tf.lookup.StaticHashTable(
56
+ tf.lookup.KeyValueTensorInitializer(
57
+ keys=self._tags,
58
+ values=range(len(self._tags)),
59
+ ),
60
+ default_value=-1,
61
+ )
62
+
63
+ def get_config(self):
64
+ config = super().get_config()
65
+ config['mask_value'] = self.mask_value
66
+ config['padding'] = self.padding
67
+ return config
68
+
69
+ def call(self, inputs) -> tf.Tensor:
70
+ embedding, sequence = inputs
71
+ sequence = self._splitter(sequence)
72
+ sequence = self._mapping.lookup(sequence)
73
+ readout = ops.gather(embedding, keras.ops.where(sequence == -1, 0, sequence))
74
+ readout = keras.ops.where(sequence[..., None] == -1, 0.0, readout)
75
+ return readout
76
+
77
+ def compute_mask(
78
+ self,
79
+ inputs: tensors.GraphTensor,
80
+ mask: bool | None = None
81
+ ) -> tf.Tensor | None:
82
+ # if self.mask_value is None:
83
+ # return None
84
+ _, sequence = inputs
85
+ sequence = self._splitter(sequence)
86
+ return keras.ops.not_equal(sequence, '')
87
+
88
+
89
+ @keras.saving.register_keras_serializable(package='molcraft')
90
+ class Embedding(keras.layers.Layer):
91
+
92
+ def __init__(self, **kwargs):
93
+ super().__init__(**kwargs)
94
+ tags = list(sorted(residues.keys()))
95
+ self.mapping = tf.lookup.StaticHashTable(
96
+ tf.lookup.KeyValueTensorInitializer(
97
+ keys=tags,
98
+ values=range(len(tags)),
99
+ ),
100
+ default_value=-1,
101
+ )
102
+ self.splitting = SequenceSplitter()
103
+ featurizer = featurizers.MolGraphFeaturizer(super_atom=True)
104
+ tensor_list = [featurizer(residues[tag]) for tag in tags]
105
+ graph = tf.stack(tensor_list, axis=0)
106
+ self._build_on_init(graph)
107
+ self.embedder = models.GraphModel.from_layers(
108
+ [
109
+ layers.Input(graph.spec),
110
+ layers.NodeEmbedding(128),
111
+ layers.EdgeEmbedding(128),
112
+ layers.GraphTransformer(128),
113
+ layers.Readout()
114
+ ]
115
+ )
116
+ self.embedding = tf.Variable(
117
+ initial_value=tf.zeros((114, 128)), trainable=True
118
+ )
119
+ self.new_state = tf.Variable(True, dtype=tf.bool, trainable=False)
120
+ self.gather = Gather()
121
+ self.update_state()
122
+
123
+ # Keep AA as is (most simple?), add positional embedding to distingusih N-, C- and non-terminal
124
+
125
+ def update_state(self, inputs=None):
126
+ graph = self._graph_tensor
127
+ graph = tensors.to_dict(graph)
128
+ embedding = self.embedder(graph)
129
+ self.embedding.assign(embedding)
130
+ tf.print("STATE UPDATED")
131
+ return embedding
132
+
133
+ def call(self, inputs=None, training=None) -> tensors.GraphTensor:
134
+ if training:
135
+ embedding = self.update_state()
136
+ self.new_state.assign(True)
137
+ return self.gather([embedding, inputs])
138
+ else:
139
+ embedding = tf.cond(
140
+ pred=self.new_state,
141
+ true_fn=lambda: self.update_state(),
142
+ false_fn=lambda: self.embedding
143
+ )
144
+ self.new_state.assign(False)
145
+ return self.gather([embedding, inputs])
146
+
147
+ def build(self, input_shape):
148
+ super().build(input_shape)
149
+
150
+ def _build_on_init(self, x):
151
+
152
+ if isinstance(x, tensors.GraphTensor):
153
+ tensor = tensors.to_dict(x)
154
+ self._spec = tf.nest.map_structure(
155
+ tf.type_spec_from_value, tensor
156
+ )
157
+ else:
158
+ self._spec = x
159
+
160
+ self._graph = tf.nest.map_structure(
161
+ lambda s: self.add_weight(
162
+ shape=s.shape,
163
+ dtype=s.dtype,
164
+ trainable=False,
165
+ initializer='zeros'
166
+ ),
167
+ self._spec
168
+ )
169
+
170
+ if isinstance(x, tensors.GraphTensor):
171
+ tf.nest.map_structure(
172
+ lambda v, x: v.assign(x),
173
+ self._graph, tensor
174
+ )
175
+
176
+ graph = tf.nest.map_structure(
177
+ keras.ops.convert_to_tensor, self._graph
178
+ )
179
+ self._graph_tensor = tensors.from_dict(graph)
180
+
181
+ # def get_config(self) -> dict:
182
+ # config = super().get_config()
183
+ # spec = keras.saving.serialize_keras_object(self._spec)
184
+ # config['spec'] = spec
185
+ # #config['layers'] = keras.saving.serialize_keras_object(self.embedding.layers)
186
+ # return config
187
+
188
+ # @classmethod
189
+ # def from_config(cls, config: dict) -> 'SequenceToGraph':
190
+ # spec = config.pop('spec')
191
+ # spec = keras.saving.deserialize_keras_object(spec)
192
+ # # config['layers'] = keras.saving.deserialize_keras_object(config['layers'])
193
+ # layer = cls(**config)
194
+ # layer._build_on_init(spec)
195
+ # return layer
196
+
197
+
198
+ @keras.saving.register_keras_serializable(package='molcraft')
199
+ class SequenceToGraph(keras.layers.Layer):
200
+
201
+ def __init__(
202
+ self,
203
+ atom_features: list[features.Feature] | str | None = 'auto',
204
+ bond_features: list[features.Feature] | str | None = 'auto',
205
+ molecule_features: list[descriptors.Descriptor] | str | None = 'auto',
206
+ super_atom: bool = True,
207
+ radius: int | float | None = None,
208
+ self_loops: bool = False,
209
+ include_hs: bool = False,
210
+ **kwargs,
211
+ ):
212
+ super().__init__(**kwargs)
213
+ self._splitter = SequenceSplitter()
214
+ featurizer = featurizers.MolGraphFeaturizer(
215
+ atom_features=atom_features,
216
+ bond_features=bond_features,
217
+ molecule_features=molecule_features,
218
+ super_atom=super_atom,
219
+ radius=radius,
220
+ self_loops=self_loops,
221
+ include_hs=include_hs,
222
+ **kwargs,
223
+ )
224
+ tensor_list: list[tensors.GraphTensor] = [
225
+ featurizer(residues[tag]).update({'context': {'tag': tag}}) for tag in residues
226
+ ]
227
+ graph = tf.stack(tensor_list, axis=0)
228
+ self._build_on_init(graph)
229
+
230
+ def call(self, sequence: tf.Tensor) -> tensors.GraphTensor:
231
+ sequence = self._splitter(sequence)
232
+ indices = self._tag_to_index.lookup(sequence)
233
+ indices = tf.sort(tf.unique(tf.reshape(indices, [-1]))[0])[1:]
234
+ graph = self._graph_tensor[indices]
235
+ return tensors.to_dict(graph)
236
+
237
+ def _build_on_init(self, x):
238
+
239
+ if isinstance(x, tensors.GraphTensor):
240
+ tensor = tensors.to_dict(x)
241
+ self._spec = tf.nest.map_structure(
242
+ tf.type_spec_from_value, tensor
243
+ )
244
+ else:
245
+ self._spec = x
246
+
247
+ self._graph = tf.nest.map_structure(
248
+ lambda s: self.add_weight(
249
+ shape=s.shape,
250
+ dtype=s.dtype,
251
+ trainable=False,
252
+ initializer='zeros'
253
+ ),
254
+ self._spec
255
+ )
256
+
257
+ if isinstance(x, tensors.GraphTensor):
258
+ tf.nest.map_structure(
259
+ lambda v, x: v.assign(x),
260
+ self._graph, tensor
261
+ )
262
+
263
+ graph = tf.nest.map_structure(
264
+ keras.ops.convert_to_tensor, self._graph
265
+ )
266
+ self._graph_tensor = tensors.from_dict(graph)
267
+
268
+ tags = self._graph_tensor.context['tag']
269
+
270
+ self._tag_to_index = tf.lookup.StaticHashTable(
271
+ tf.lookup.KeyValueTensorInitializer(
272
+ keys=tags,
273
+ values=range(len(tags)),
274
+ ),
275
+ default_value=-1,
276
+ )
277
+
278
+ def get_config(self) -> dict:
279
+ config = super().get_config()
280
+ spec = keras.saving.serialize_keras_object(self._spec)
281
+ config['spec'] = spec
282
+ return config
283
+
284
+ @classmethod
285
+ def from_config(cls, config: dict) -> 'SequenceToGraph':
286
+ spec = config.pop('spec')
287
+ spec = keras.saving.deserialize_keras_object(spec)
288
+ layer = cls(**config)
289
+ layer._build_on_init(spec)
290
+ return layer
291
+
292
+ # @property
293
+ # def graph(self) -> tensors.GraphTensor:
294
+ # return self._graph_tensor
295
+
296
+
297
+ @keras.saving.register_keras_serializable(package='molcraft')
298
+ class GraphToSequence(keras.layers.Layer):
299
+
300
+ def __init__(
301
+ self,
302
+ padding: list[tuple[int]] | tuple[int] | int = 1,
303
+ mask_value: int = 0,
304
+ **kwargs
305
+ ) -> None:
306
+ super().__init__(**kwargs)
307
+ self._splitter = SequenceSplitter()
308
+ self.padding = padding
309
+ self.mask_value = mask_value
310
+ self._readout_layer = layers.Readout(mode='mean')
311
+ self.supports_masking = True
312
+
313
+ def get_config(self):
314
+ config = super().get_config()
315
+ config['mask_value'] = self.mask_value
316
+ config['padding'] = self.padding
317
+ return config
318
+
319
+ def call(self, inputs) -> tf.Tensor:
320
+
321
+ graph, sequence = inputs
322
+ sequence = self._splitter(sequence)
323
+ tag = graph['context']['tag']
324
+ data = self._readout_layer(graph)
325
+
326
+ table = tf.lookup.experimental.MutableHashTable(
327
+ key_dtype=tf.string,
328
+ value_dtype=tf.int32,
329
+ default_value=-1
330
+ )
331
+
332
+ table.insert(tag, tf.range(tf.shape(tag)[0]))
333
+ sequence = table.lookup(sequence)
334
+
335
+ readout = ops.gather(data, keras.ops.where(sequence == -1, 0, sequence))
336
+ readout = keras.ops.where(sequence[..., None] == -1, 0.0, readout)
337
+ return readout
338
+
339
+ def compute_mask(
340
+ self,
341
+ inputs: tensors.GraphTensor,
342
+ mask: bool | None = None
343
+ ) -> tf.Tensor | None:
344
+ # if self.mask_value is None:
345
+ # return None
346
+ _, sequence = inputs
347
+ sequence = self._splitter(sequence)
348
+ return keras.ops.not_equal(sequence, '')
349
+
350
+
351
+ residues = {
352
+ "A": "N[C@@H](C)C(=O)O",
353
+ "C": "N[C@@H](CS)C(=O)O",
354
+ "C[Carbamidomethyl]": "N[C@@H](CSCC(=O)N)C(=O)O",
355
+ "D": "N[C@@H](CC(=O)O)C(=O)O",
356
+ "E": "N[C@@H](CCC(=O)O)C(=O)O",
357
+ "F": "N[C@@H](Cc1ccccc1)C(=O)O",
358
+ "G": "NCC(=O)O",
359
+ "H": "N[C@@H](CC1=CN=C-N1)C(=O)O",
360
+ "I": "N[C@@H](C(CC)C)C(=O)O",
361
+ "K": "N[C@@H](CCCCN)C(=O)O",
362
+ "K[Acetyl]": "N[C@@H](CCCCNC(=O)C)C(=O)O",
363
+ "K[Crotonyl]": "N[C@@H](CCCCNC(C=CC)=O)C(=O)O",
364
+ "K[Dimethyl]": "N[C@@H](CCCCN(C)C)C(=O)O",
365
+ "K[Formyl]": "N[C@@H](CCCCNC=O)C(=O)O",
366
+ "K[Malonyl]": "N[C@@H](CCCCNC(=O)CC(O)=O)C(=O)O",
367
+ "K[Methyl]": "N[C@@H](CCCCNC)C(=O)O",
368
+ "K[Propionyl]": "N[C@@H](CCCCNC(=O)CC)C(=O)O",
369
+ "K[Succinyl]": "N[C@@H](CCCCNC(CCC(O)=O)=O)C(=O)O",
370
+ "K[Trimethyl]": "N[C@@H](CCCC[N+](C)(C)C)C(=O)O",
371
+ "L": "N[C@@H](CC(C)C)C(=O)O",
372
+ "M": "N[C@@H](CCSC)C(=O)O",
373
+ "M[Oxidation]": "N[C@@H](CCS(=O)C)C(=O)O",
374
+ "N": "N[C@@H](CC(=O)N)C(=O)O",
375
+ "P": "N1[C@@H](CCC1)C(=O)O",
376
+ "P[Oxidation]": "N1CC(O)C[C@H]1C(=O)O",
377
+ "Q": "N[C@@H](CCC(=O)N)C(=O)O",
378
+ "R": "N[C@@H](CCCNC(=N)N)C(=O)O",
379
+ "R[Deamidated]": "N[C@@H](CCCNC(N)=O)C(=O)O",
380
+ "R[Dimethyl]": "N[C@@H](CCCNC(N(C)C)=N)C(=O)O",
381
+ "R[Methyl]": "N[C@@H](CCCNC(=N)NC)C(=O)O",
382
+ "S": "N[C@@H](CO)C(=O)O",
383
+ "T": "N[C@@H](C(O)C)C(=O)O",
384
+ "V": "N[C@@H](C(C)C)C(=O)O",
385
+ "W": "N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O",
386
+ "Y": "N[C@@H](Cc1ccc(O)cc1)C(=O)O",
387
+ "Y[Nitro]": "N[C@@H](Cc1ccc(O)c(N(=O)=O)c1)C(=O)O",
388
+ "Y[Phospho]": "N[C@@H](Cc1ccc(OP(O)(=O)O)cc1)C(=O)O",
389
+ "[Acetyl]-A": "N(C(C)=O)[C@@H](C)C(=O)O",
390
+ "[Acetyl]-C": "N(C(C)=O)[C@@H](CS)C(=O)O",
391
+ "[Acetyl]-D": "N(C(=O)C)[C@H](C(=O)O)CC(=O)O",
392
+ "[Acetyl]-E": "N(C(=O)C)[C@@H](CCC(O)=O)C(=O)O",
393
+ "[Acetyl]-F": "N(C(C)=O)[C@@H](Cc1ccccc1)C(=O)O",
394
+ "[Acetyl]-G": "N(C(=O)C)CC(=O)O",
395
+ "[Acetyl]-H": "N(C(=O)C)[C@@H](Cc1[nH]cnc1)C(=O)O",
396
+ "[Acetyl]-I": "N(C(=O)C)[C@@H]([C@H](CC)C)C(=O)O",
397
+ "[Acetyl]-K": "N(C(C)=O)[C@@H](CCCCN)C(=O)O",
398
+ "[Acetyl]-L": "N(C(=O)C)[C@@H](CC(C)C)C(=O)O",
399
+ "[Acetyl]-M": "N(C(=O)C)[C@@H](CCSC)C(=O)O",
400
+ "[Acetyl]-N": "N(C(C)=O)[C@@H](CC(=O)N)C(=O)O",
401
+ "[Acetyl]-P": "N1(C(=O)C)CCC[C@H]1C(=O)O",
402
+ "[Acetyl]-Q": "N(C(=O)C)[C@@H](CCC(=O)N)C(=O)O",
403
+ "[Acetyl]-R": "N(C(C)=O)[C@@H](CCCN=C(N)N)C(=O)O",
404
+ "[Acetyl]-S": "N(C(C)=O)[C@@H](CO)C(=O)O",
405
+ "[Acetyl]-T": "N(C(=O)C)[C@@H]([C@H](O)C)C(=O)O",
406
+ "[Acetyl]-V": "N(C(=O)C)[C@@H](C(C)C)C(=O)O",
407
+ "[Acetyl]-W": "N(C(C)=O)[C@@H](Cc1c2ccccc2[nH]c1)C(=O)O",
408
+ "[Acetyl]-Y": "N(C(C)=O)[C@@H](Cc1ccc(O)cc1)C(=O)O"
409
+ }
410
+
411
+ residues_reverse = {}
412
+ def register_peptide_residues(residues_: dict[str, str], canonicalize=True):
413
+ for residue, smiles in residues_.items():
414
+ if canonicalize:
415
+ smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
416
+ residues[residue] = smiles
417
+ residues_reverse[residues[residue]] = residue
418
+
419
+ register_peptide_residues(residues, canonicalize=False)
420
+
421
+ def _extract_residue_type(residue_tag: str) -> str:
422
+ pattern = r"(?<!\[)[A-Z](?![^\[]*\])"
423
+ return [match.group(0) for match in re.finditer(pattern, residue_tag)][0]
424
+
425
+ special_residues = {}
426
+ for key, value in residues.items():
427
+ special_residues[key + '-[X]'] = value.rstrip('O')
428
+
429
+ register_peptide_residues(special_residues, canonicalize=False)
@@ -426,10 +426,12 @@ def embed_conformers(
426
426
  success = rdDistGeom.EmbedMultipleConfs(
427
427
  mol, numConfs=num_conformers, params=embedding_method
428
428
  )
429
- if not len(success):
429
+ num_successes = len(success)
430
+ if num_successes < num_conformers:
430
431
  warnings.warn(
431
- f'Could not embed conformer(s) for {mol.canonical_smiles!r} using the '
432
- 'speified method. Giving it another try with more permissive methods.',
432
+ f'Could only embed {num_successes} out of {num_conformers} conformer(s) '
433
+ f'for {mol.canonical_smiles!r} using {method}. Embedding the remaining '
434
+ f'{num_conformers - num_successes} conformer(s) using different embedding methods.',
433
435
  stacklevel=2
434
436
  )
435
437
  max_attempts = (20 * mol.num_atoms) # increasing it from 10xN to 20xN
@@ -437,14 +439,16 @@ def embed_conformers(
437
439
  fallback_embedding_method = available_embedding_methods[fallback_method]
438
440
  fallback_embedding_method.useRandomCoords = True
439
441
  fallback_embedding_method.maxAttempts = max_attempts
442
+ fallback_embedding_method.clearConfs = False
440
443
  success = rdDistGeom.EmbedMultipleConfs(
441
- mol, numConfs=num_conformers, params=fallback_embedding_method
444
+ mol, numConfs=(num_conformers - num_successes), params=fallback_embedding_method
442
445
  )
443
- if len(success):
446
+ num_successes += len(success)
447
+ if num_successes == num_conformers:
444
448
  break
445
449
  else:
446
450
  raise RuntimeError(
447
- f'Could not embed conformer(s) for {mol.canonical_smiles!r}. '
451
+ f'Could not embed {num_conformers} conformer(s) for {mol.canonical_smiles!r}. '
448
452
  )
449
453
  return mol
450
454
 
@@ -61,7 +61,7 @@ class NumHeavyAtoms(Descriptor):
61
61
 
62
62
 
63
63
  @keras.saving.register_keras_serializable(package='molcraft')
64
- class NumHeteroAtoms(Descriptor):
64
+ class NumHeteroatoms(Descriptor):
65
65
  def call(self, mol: chem.Mol) -> np.ndarray:
66
66
  return rdMolDescriptors.CalcNumHeteroatoms(mol)
67
67
 
@@ -185,13 +185,13 @@ class Degree(Feature):
185
185
 
186
186
 
187
187
  @keras.saving.register_keras_serializable(package='molcraft')
188
- class TotalNumHs(Feature):
188
+ class NumHydrogens(Feature):
189
189
  def call(self, mol: chem.Mol) -> list[int, float, str]:
190
190
  return [atom.GetTotalNumHs() for atom in mol.atoms]
191
191
 
192
192
 
193
193
  @keras.saving.register_keras_serializable(package='molcraft')
194
- class TotalValence(Feature):
194
+ class Valence(Feature):
195
195
  def call(self, mol: chem.Mol) -> list[int, float, str]:
196
196
  return [atom.GetTotalValence() for atom in mol.atoms]
197
197
 
@@ -218,10 +218,17 @@ class CIPCode(Feature):
218
218
 
219
219
 
220
220
  @keras.saving.register_keras_serializable(package='molcraft')
221
- class IsChiralityPossible(Feature):
221
+ class RingSize(Feature):
222
222
  def call(self, mol: chem.Mol) -> list[int, float, str]:
223
- return [atom.HasProp("_ChiralityPossible") for atom in mol.atoms]
224
-
223
+ def ring_size(atom):
224
+ if not atom.IsInRing():
225
+ return -1
226
+ size = 3
227
+ while not atom.IsInRingSize(size):
228
+ size += 1
229
+ return size
230
+ return [ring_size(atom) for atom in mol.atoms]
231
+
225
232
 
226
233
  @keras.saving.register_keras_serializable(package='molcraft')
227
234
  class FormalCharge(Feature):
@@ -229,6 +236,12 @@ class FormalCharge(Feature):
229
236
  return [atom.GetFormalCharge() for atom in mol.atoms]
230
237
 
231
238
 
239
+ @keras.saving.register_keras_serializable(package='molcraft')
240
+ class IsChiralityPossible(Feature):
241
+ def call(self, mol: chem.Mol) -> list[int, float, str]:
242
+ return [atom.HasProp("_ChiralityPossible") for atom in mol.atoms]
243
+
244
+
232
245
  @keras.saving.register_keras_serializable(package='molcraft')
233
246
  class NumRadicalElectrons(Feature):
234
247
  def call(self, mol: chem.Mol) -> list[int, float, str]:
@@ -242,7 +255,7 @@ class IsAromatic(Feature):
242
255
 
243
256
 
244
257
  @keras.saving.register_keras_serializable(package='molcraft')
245
- class IsHetero(Feature):
258
+ class IsHeteroatom(Feature):
246
259
  def call(self, mol: chem.Mol) -> list[int, float, str]:
247
260
  return chem.hetero_atoms(mol)
248
261
 
@@ -259,19 +272,6 @@ class IsHydrogenAcceptor(Feature):
259
272
  return chem.hydrogen_acceptors(mol)
260
273
 
261
274
 
262
- @keras.saving.register_keras_serializable(package='molcraft')
263
- class RingSize(Feature):
264
- def call(self, mol: chem.Mol) -> list[int, float, str]:
265
- def ring_size(atom):
266
- if not atom.IsInRing():
267
- return -1
268
- size = 3
269
- while not atom.IsInRingSize(size):
270
- size += 1
271
- return size
272
- return [ring_size(atom) for atom in mol.atoms]
273
-
274
-
275
275
  @keras.saving.register_keras_serializable(package='molcraft')
276
276
  class IsInRing(Feature):
277
277
  def call(self, mol: chem.Mol) -> list[int, float, str]:
@@ -196,7 +196,7 @@ class MolGraphFeaturizer(Featurizer):
196
196
  descriptors.CrippenLogP(),
197
197
  descriptors.CrippenMolarRefractivity(),
198
198
  descriptors.NumHeavyAtoms(),
199
- descriptors.NumHeteroAtoms(),
199
+ descriptors.NumHeteroatoms(),
200
200
  descriptors.NumHydrogenDonors(),
201
201
  descriptors.NumHydrogenAcceptors(),
202
202
  descriptors.NumRotatableBonds(),
@@ -556,6 +556,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
556
556
 
557
557
  molecule_feature = self.molecule_feature(mol)
558
558
  molecule_size = self.num_atoms(mol) + int(self.super_atom)
559
+ molecule_size = molecule_size.astype(self.index_dtype)
559
560
 
560
561
  if isinstance(context, dict):
561
562
  if 'x' in context:
@@ -661,7 +661,7 @@ class GIConv(GraphConv):
661
661
  return config
662
662
 
663
663
 
664
- @keras.saving.register_keras_serializable(package='molgraphx')
664
+ @keras.saving.register_keras_serializable(package='molcraft')
665
665
  class GAConv(GraphConv):
666
666
 
667
667
  """Graph attention network layer.
@@ -2,7 +2,7 @@ import keras
2
2
  import numpy as np
3
3
 
4
4
 
5
- @keras.saving.register_keras_serializable(package='molgraph')
5
+ @keras.saving.register_keras_serializable(package='molcraft')
6
6
  class GaussianNegativeLogLikelihood(keras.losses.Loss):
7
7
 
8
8
  def __init__(
@@ -114,6 +114,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
114
114
  return typing.cast(GraphModel, super().__new__(cls))
115
115
 
116
116
  def __init__(self, *args, **kwargs):
117
+ self._model_layers = kwargs.pop('model_layers', None)
117
118
  super().__init__(*args, **kwargs)
118
119
  self.jit_compile = False
119
120
 
@@ -135,10 +136,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
135
136
  `molcraft.layers.Input(spec)`.
136
137
  """
137
138
  if not tensors.is_graph(graph_layers[0]):
138
- # TODO: Allow this. E.g.: return cls(layers=graph_layers)
139
- raise ValueError(
140
- 'Graph input not found. Make sure to add `Input`.'
141
- )
139
+ return cls(model_layers=graph_layers)
142
140
  inputs: dict = graph_layers.pop(0)
143
141
  x = inputs
144
142
  for layer in graph_layers:
@@ -148,6 +146,31 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
148
146
  outputs = x
149
147
  return cls(inputs=inputs, outputs=outputs, **kwargs)
150
148
 
149
+ def propagate(self, graph: tensors.GraphTensor) -> tensors.GraphTensor:
150
+ if self._model_layers is None:
151
+ return super().propagate(graph)
152
+ for layer in self._model_layers:
153
+ graph = layer(graph)
154
+ return graph
155
+
156
+ def get_config(self):
157
+ config = super().get_config()
158
+ if hasattr(self, '_model_layers') and self._model_layers is not None:
159
+ config['model_layers'] = [
160
+ keras.saving.serialize_keras_object(l)
161
+ for l in self._model_layers
162
+ ]
163
+ return config
164
+
165
+ @classmethod
166
+ def from_config(cls, config: dict):
167
+ if 'model_layers' in config:
168
+ config['model_layers'] = [
169
+ keras.saving.deserialize_keras_object(l)
170
+ for l in config['model_layers']
171
+ ]
172
+ return super().from_config(config)
173
+
151
174
  def compile(
152
175
  self,
153
176
  optimizer: keras.optimizers.Optimizer | str | None = 'rmsprop',
@@ -416,7 +439,6 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
416
439
  return self(tensor, training=False)
417
440
 
418
441
  def compute_loss(self, x, y, y_pred, sample_weight=None):
419
- y, y_pred, sample_weight = _maybe_reshape(y, y_pred, sample_weight)
420
442
  return super().compute_loss(x, y, y_pred, sample_weight)
421
443
 
422
444
  def compute_metrics(self, x, y, y_pred, sample_weight=None) -> dict[str, float]:
@@ -531,58 +553,6 @@ def saliency(
531
553
  }
532
554
  }
533
555
  )
534
-
535
- def predict(
536
- model: GraphModel,
537
- x: tensors.GraphTensor | tf.data.Dataset,
538
- repeats: int | None = 16,
539
- batch_size: int = 256,
540
- verbose: int = 0,
541
- **kwargs,
542
- ) -> tuple[tf.Tensor | np.ndarray, tf.Tensor | np.ndarray]:
543
- """Predict with model.
544
-
545
- By default performs monte-carlo predictions. Namely, it performs
546
- `repeats` number of predictions for each example with `training = True`,
547
- and subsequently computes mean and standard deviations of the predictions.
548
-
549
- Args:
550
- x:
551
- A `GraphTensor` instance.
552
- repeats:
553
- Number of predictions per example.
554
- batch_size:
555
- Number of samples per batch of computation.
556
- kwargs:
557
- See `Model.predict` in Keras documentation.
558
- May or may not apply here.
559
- """
560
- if not repeats:
561
- return model.predict(
562
- x, batch_size=batch_size, verbose=verbose, **kwargs
563
- )
564
- if isinstance(x, tensors.GraphTensor):
565
- ds = tf.data.Dataset.from_tensor_slices(x)
566
- ds = ds.repeat(repeats)
567
- ds = ds.batch(batch_size)
568
- elif isinstance(x, tf.data.Dataset):
569
- ds = x.repeat(repeats)
570
- else:
571
- raise ValueError(
572
- 'Input `x` needs to be a `tensors.GraphTensor` instance '
573
- 'or a `tf.data.Dataset` instance constructed from `tensors.GraphTensor`.'
574
- )
575
- ds = ds.prefetch(-1)
576
- y_pred = keras.ops.concatenate([
577
- model(x, training=True) for x in ds])
578
- global_batch_size = len(y_pred) // repeats
579
- y_pred = np.reshape(y_pred, (repeats, global_batch_size, -1))
580
- y_pred_loc = keras.ops.mean(y_pred, axis=0)
581
- y_pred_scale = keras.ops.std(y_pred, axis=0)
582
- if tf.executing_eagerly():
583
- y_pred_loc = y_pred_loc.numpy()
584
- y_pred_scale = y_pred_scale.numpy()
585
- return (y_pred_loc, y_pred_scale)
586
556
 
587
557
  def _functional_init_arguments(args, kwargs):
588
558
  return (
@@ -597,14 +567,3 @@ def _make_dataset(x: tensors.GraphTensor, batch_size: int):
597
567
  .batch(batch_size)
598
568
  .prefetch(-1)
599
569
  )
600
-
601
- def _maybe_reshape(y, y_pred, sample_weight):
602
- if (
603
- sample_weight is not None and
604
- len(keras.ops.shape(sample_weight)) == 2 and
605
- sample_weight.shape == y_pred.shape
606
- ):
607
- y = keras.ops.reshape(y, [-1])
608
- y_pred = keras.ops.reshape(y_pred, [-1])
609
- sample_weight = keras.ops.reshape(sample_weight, [-1])
610
- return y, y_pred, sample_weight
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: molcraft
3
- Version: 0.1.0a10
3
+ Version: 0.1.0a12
4
4
  Summary: Graph Neural Networks for Molecular Machine Learning
5
5
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
6
6
  License: MIT License
@@ -47,15 +47,20 @@ Dynamic: license-file
47
47
  **Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
48
48
 
49
49
  > [!NOTE]
50
- > In progress/Unfinished.
50
+ > In progress.
51
51
 
52
- ## Highlights
53
- - Compatible with **Keras 3**
54
- - Customizable and serializable **featurizers**
55
- - Customizable and serializable **layers** and **models**
56
- - Customizable **GraphTensor**
57
- - Fast and efficient featurization of molecular graphs
58
- - Fast and efficient input pipelines using TF **records**
52
+ ## Installation
53
+
54
+ For CPU users:
55
+
56
+ ```bash
57
+ pip install --pre molcraft
58
+ ```
59
+
60
+ For GPU users:
61
+ ```bash
62
+ pip install --pre molcraft[gpu]
63
+ ```
59
64
 
60
65
  ## Examples
61
66
 
@@ -70,7 +75,7 @@ import keras
70
75
  featurizer = featurizers.MolGraphFeaturizer(
71
76
  atom_features=[
72
77
  features.AtomType(),
73
- features.TotalNumHs(),
78
+ features.NumHydrogens(),
74
79
  features.Degree(),
75
80
  ],
76
81
  bond_features=[
@@ -20,6 +20,8 @@ molcraft.egg-info/SOURCES.txt
20
20
  molcraft.egg-info/dependency_links.txt
21
21
  molcraft.egg-info/requires.txt
22
22
  molcraft.egg-info/top_level.txt
23
+ molcraft/apps/__init__.py
24
+ molcraft/apps/peptides.py
23
25
  tests/test_chem.py
24
26
  tests/test_featurizers.py
25
27
  tests/test_layers.py
@@ -31,7 +31,7 @@ class TestFeaturizer(unittest.TestCase):
31
31
  featurizer = featurizers.MolFeaturizer(
32
32
  atom_features=[
33
33
  features.AtomType({'C', 'N', 'O', 'H'}),
34
- features.TotalNumHs({0, 1, 2, 3, 4})
34
+ features.NumHydrogens({0, 1, 2, 3, 4})
35
35
  ],
36
36
  bond_features=[
37
37
  features.BondType({'single', 'double', 'aromatic'}),
@@ -119,7 +119,7 @@ class TestFeaturizer(unittest.TestCase):
119
119
  featurizer = featurizers.MolFeaturizer3D(
120
120
  atom_features=[
121
121
  features.AtomType({'C', 'N', 'O', 'H'}, encode_oov=True),
122
- features.TotalNumHs({0, 1, 2, 3, 4})
122
+ features.NumHydrogens({0, 1, 2, 3, 4})
123
123
  ],
124
124
  bond_features=[
125
125
  features.Distance(max_distance=20)
File without changes
File without changes
File without changes
File without changes