molcraft-0.1.0a9-py3-none-any.whl → molcraft-0.1.0a11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. See this release's page on the package registry for more details.

molcraft/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = '0.1.0a9'
1
+ __version__ = '0.1.0a11'
2
2
 
3
3
  import os
4
4
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
molcraft/apps/__init__.py ADDED
File without changes
molcraft/apps/peptides.py ADDED
@@ -0,0 +1,429 @@
1
+ import re
2
+ import keras
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ import tensorflow_text as tf_text
6
+ from rdkit import Chem
7
+
8
+ from molcraft import ops
9
+ from molcraft import chem
10
+ from molcraft import features
11
+ from molcraft import featurizers
12
+ from molcraft import tensors
13
+ from molcraft import descriptors
14
+ from molcraft import layers
15
+ from molcraft import models
16
+
17
+
18
+
19
+ @keras.saving.register_keras_serializable(package='molcraft')
20
+ class SequenceSplitter(keras.layers.Layer):
21
+
22
+ _pattern = "|".join([
23
+ r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])', # N-term mod + mod
24
+ r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])', # C-term mod + mod
25
+ r'([A-Z]-\[[A-Za-z0-9]+\])', # C-term mod
26
+ r'(\[[A-Za-z0-9]+\]-[A-Z])', # N-term mod
27
+ r'([A-Z]\[[A-Za-z0-9]+\])', # Mod
28
+ r'([A-Z])', # No mod
29
+ ])
30
+
31
+ def call(self, inputs):
32
+ inputs = tf_text.regex_split(inputs, self._pattern, self._pattern)
33
+ inputs = keras.ops.concatenate([
34
+ tf.strings.join([inputs[:, :-1], '-[X]']),
35
+ inputs[:, -1:]
36
+ ], axis=1)
37
+ return inputs.to_tensor()
38
+
39
+ @keras.saving.register_keras_serializable(package='molcraft')
40
+ class Gather(keras.layers.Layer):
41
+
42
+ def __init__(
43
+ self,
44
+ padding: list[tuple[int]] | tuple[int] | int = 1,
45
+ mask_value: int = 0,
46
+ **kwargs
47
+ ) -> None:
48
+ super().__init__(**kwargs)
49
+ self._splitter = SequenceSplitter()
50
+ self.padding = padding
51
+ self.mask_value = mask_value
52
+ self.supports_masking = True
53
+
54
+ self._tags = list(sorted(residues.keys()))
55
+ self._mapping = tf.lookup.StaticHashTable(
56
+ tf.lookup.KeyValueTensorInitializer(
57
+ keys=self._tags,
58
+ values=range(len(self._tags)),
59
+ ),
60
+ default_value=-1,
61
+ )
62
+
63
+ def get_config(self):
64
+ config = super().get_config()
65
+ config['mask_value'] = self.mask_value
66
+ config['padding'] = self.padding
67
+ return config
68
+
69
+ def call(self, inputs) -> tf.Tensor:
70
+ embedding, sequence = inputs
71
+ sequence = self._splitter(sequence)
72
+ sequence = self._mapping.lookup(sequence)
73
+ readout = ops.gather(embedding, keras.ops.where(sequence == -1, 0, sequence))
74
+ readout = keras.ops.where(sequence[..., None] == -1, 0.0, readout)
75
+ return readout
76
+
77
+ def compute_mask(
78
+ self,
79
+ inputs: tensors.GraphTensor,
80
+ mask: bool | None = None
81
+ ) -> tf.Tensor | None:
82
+ # if self.mask_value is None:
83
+ # return None
84
+ _, sequence = inputs
85
+ sequence = self._splitter(sequence)
86
+ return keras.ops.not_equal(sequence, '')
87
+
88
+
89
+ @keras.saving.register_keras_serializable(package='molcraft')
90
+ class Embedding(keras.layers.Layer):
91
+
92
+ def __init__(self, **kwargs):
93
+ super().__init__(**kwargs)
94
+ tags = list(sorted(residues.keys()))
95
+ self.mapping = tf.lookup.StaticHashTable(
96
+ tf.lookup.KeyValueTensorInitializer(
97
+ keys=tags,
98
+ values=range(len(tags)),
99
+ ),
100
+ default_value=-1,
101
+ )
102
+ self.splitting = SequenceSplitter()
103
+ featurizer = featurizers.MolGraphFeaturizer(super_atom=True)
104
+ tensor_list = [featurizer(residues[tag]) for tag in tags]
105
+ graph = tf.stack(tensor_list, axis=0)
106
+ self._build_on_init(graph)
107
+ self.embedder = models.GraphModel.from_layers(
108
+ [
109
+ layers.Input(graph.spec),
110
+ layers.NodeEmbedding(128),
111
+ layers.EdgeEmbedding(128),
112
+ layers.GraphTransformer(128),
113
+ layers.Readout()
114
+ ]
115
+ )
116
+ self.embedding = tf.Variable(
117
+ initial_value=tf.zeros((114, 128)), trainable=True
118
+ )
119
+ self.new_state = tf.Variable(True, dtype=tf.bool, trainable=False)
120
+ self.gather = Gather()
121
+ self.update_state()
122
+
123
+ # Keep AA as is (most simple?), add positional embedding to distingusih N-, C- and non-terminal
124
+
125
+ def update_state(self, inputs=None):
126
+ graph = self._graph_tensor
127
+ graph = tensors.to_dict(graph)
128
+ embedding = self.embedder(graph)
129
+ self.embedding.assign(embedding)
130
+ tf.print("STATE UPDATED")
131
+ return embedding
132
+
133
+ def call(self, inputs=None, training=None) -> tensors.GraphTensor:
134
+ if training:
135
+ embedding = self.update_state()
136
+ self.new_state.assign(True)
137
+ return self.gather([embedding, inputs])
138
+ else:
139
+ embedding = tf.cond(
140
+ pred=self.new_state,
141
+ true_fn=lambda: self.update_state(),
142
+ false_fn=lambda: self.embedding
143
+ )
144
+ self.new_state.assign(False)
145
+ return self.gather([embedding, inputs])
146
+
147
+ def build(self, input_shape):
148
+ super().build(input_shape)
149
+
150
+ def _build_on_init(self, x):
151
+
152
+ if isinstance(x, tensors.GraphTensor):
153
+ tensor = tensors.to_dict(x)
154
+ self._spec = tf.nest.map_structure(
155
+ tf.type_spec_from_value, tensor
156
+ )
157
+ else:
158
+ self._spec = x
159
+
160
+ self._graph = tf.nest.map_structure(
161
+ lambda s: self.add_weight(
162
+ shape=s.shape,
163
+ dtype=s.dtype,
164
+ trainable=False,
165
+ initializer='zeros'
166
+ ),
167
+ self._spec
168
+ )
169
+
170
+ if isinstance(x, tensors.GraphTensor):
171
+ tf.nest.map_structure(
172
+ lambda v, x: v.assign(x),
173
+ self._graph, tensor
174
+ )
175
+
176
+ graph = tf.nest.map_structure(
177
+ keras.ops.convert_to_tensor, self._graph
178
+ )
179
+ self._graph_tensor = tensors.from_dict(graph)
180
+
181
+ # def get_config(self) -> dict:
182
+ # config = super().get_config()
183
+ # spec = keras.saving.serialize_keras_object(self._spec)
184
+ # config['spec'] = spec
185
+ # #config['layers'] = keras.saving.serialize_keras_object(self.embedding.layers)
186
+ # return config
187
+
188
+ # @classmethod
189
+ # def from_config(cls, config: dict) -> 'SequenceToGraph':
190
+ # spec = config.pop('spec')
191
+ # spec = keras.saving.deserialize_keras_object(spec)
192
+ # # config['layers'] = keras.saving.deserialize_keras_object(config['layers'])
193
+ # layer = cls(**config)
194
+ # layer._build_on_init(spec)
195
+ # return layer
196
+
197
+
198
+ @keras.saving.register_keras_serializable(package='molcraft')
199
+ class SequenceToGraph(keras.layers.Layer):
200
+
201
+ def __init__(
202
+ self,
203
+ atom_features: list[features.Feature] | str | None = 'auto',
204
+ bond_features: list[features.Feature] | str | None = 'auto',
205
+ molecule_features: list[descriptors.Descriptor] | str | None = 'auto',
206
+ super_atom: bool = True,
207
+ radius: int | float | None = None,
208
+ self_loops: bool = False,
209
+ include_hs: bool = False,
210
+ **kwargs,
211
+ ):
212
+ super().__init__(**kwargs)
213
+ self._splitter = SequenceSplitter()
214
+ featurizer = featurizers.MolGraphFeaturizer(
215
+ atom_features=atom_features,
216
+ bond_features=bond_features,
217
+ molecule_features=molecule_features,
218
+ super_atom=super_atom,
219
+ radius=radius,
220
+ self_loops=self_loops,
221
+ include_hs=include_hs,
222
+ **kwargs,
223
+ )
224
+ tensor_list: list[tensors.GraphTensor] = [
225
+ featurizer(residues[tag]).update({'context': {'tag': tag}}) for tag in residues
226
+ ]
227
+ graph = tf.stack(tensor_list, axis=0)
228
+ self._build_on_init(graph)
229
+
230
+ def call(self, sequence: tf.Tensor) -> tensors.GraphTensor:
231
+ sequence = self._splitter(sequence)
232
+ indices = self._tag_to_index.lookup(sequence)
233
+ indices = tf.sort(tf.unique(tf.reshape(indices, [-1]))[0])[1:]
234
+ graph = self._graph_tensor[indices]
235
+ return tensors.to_dict(graph)
236
+
237
+ def _build_on_init(self, x):
238
+
239
+ if isinstance(x, tensors.GraphTensor):
240
+ tensor = tensors.to_dict(x)
241
+ self._spec = tf.nest.map_structure(
242
+ tf.type_spec_from_value, tensor
243
+ )
244
+ else:
245
+ self._spec = x
246
+
247
+ self._graph = tf.nest.map_structure(
248
+ lambda s: self.add_weight(
249
+ shape=s.shape,
250
+ dtype=s.dtype,
251
+ trainable=False,
252
+ initializer='zeros'
253
+ ),
254
+ self._spec
255
+ )
256
+
257
+ if isinstance(x, tensors.GraphTensor):
258
+ tf.nest.map_structure(
259
+ lambda v, x: v.assign(x),
260
+ self._graph, tensor
261
+ )
262
+
263
+ graph = tf.nest.map_structure(
264
+ keras.ops.convert_to_tensor, self._graph
265
+ )
266
+ self._graph_tensor = tensors.from_dict(graph)
267
+
268
+ tags = self._graph_tensor.context['tag']
269
+
270
+ self._tag_to_index = tf.lookup.StaticHashTable(
271
+ tf.lookup.KeyValueTensorInitializer(
272
+ keys=tags,
273
+ values=range(len(tags)),
274
+ ),
275
+ default_value=-1,
276
+ )
277
+
278
+ def get_config(self) -> dict:
279
+ config = super().get_config()
280
+ spec = keras.saving.serialize_keras_object(self._spec)
281
+ config['spec'] = spec
282
+ return config
283
+
284
+ @classmethod
285
+ def from_config(cls, config: dict) -> 'SequenceToGraph':
286
+ spec = config.pop('spec')
287
+ spec = keras.saving.deserialize_keras_object(spec)
288
+ layer = cls(**config)
289
+ layer._build_on_init(spec)
290
+ return layer
291
+
292
+ # @property
293
+ # def graph(self) -> tensors.GraphTensor:
294
+ # return self._graph_tensor
295
+
296
+
297
+ @keras.saving.register_keras_serializable(package='molcraft')
298
+ class GraphToSequence(keras.layers.Layer):
299
+
300
+ def __init__(
301
+ self,
302
+ padding: list[tuple[int]] | tuple[int] | int = 1,
303
+ mask_value: int = 0,
304
+ **kwargs
305
+ ) -> None:
306
+ super().__init__(**kwargs)
307
+ self._splitter = SequenceSplitter()
308
+ self.padding = padding
309
+ self.mask_value = mask_value
310
+ self._readout_layer = layers.Readout(mode='mean')
311
+ self.supports_masking = True
312
+
313
+ def get_config(self):
314
+ config = super().get_config()
315
+ config['mask_value'] = self.mask_value
316
+ config['padding'] = self.padding
317
+ return config
318
+
319
+ def call(self, inputs) -> tf.Tensor:
320
+
321
+ graph, sequence = inputs
322
+ sequence = self._splitter(sequence)
323
+ tag = graph['context']['tag']
324
+ data = self._readout_layer(graph)
325
+
326
+ table = tf.lookup.experimental.MutableHashTable(
327
+ key_dtype=tf.string,
328
+ value_dtype=tf.int32,
329
+ default_value=-1
330
+ )
331
+
332
+ table.insert(tag, tf.range(tf.shape(tag)[0]))
333
+ sequence = table.lookup(sequence)
334
+
335
+ readout = ops.gather(data, keras.ops.where(sequence == -1, 0, sequence))
336
+ readout = keras.ops.where(sequence[..., None] == -1, 0.0, readout)
337
+ return readout
338
+
339
+ def compute_mask(
340
+ self,
341
+ inputs: tensors.GraphTensor,
342
+ mask: bool | None = None
343
+ ) -> tf.Tensor | None:
344
+ # if self.mask_value is None:
345
+ # return None
346
+ _, sequence = inputs
347
+ sequence = self._splitter(sequence)
348
+ return keras.ops.not_equal(sequence, '')
349
+
350
+
351
+ residues = {
352
+ "A": "N[C@@H](C)C(=O)O",
353
+ "C": "N[C@@H](CS)C(=O)O",
354
+ "C[Carbamidomethyl]": "N[C@@H](CSCC(=O)N)C(=O)O",
355
+ "D": "N[C@@H](CC(=O)O)C(=O)O",
356
+ "E": "N[C@@H](CCC(=O)O)C(=O)O",
357
+ "F": "N[C@@H](Cc1ccccc1)C(=O)O",
358
+ "G": "NCC(=O)O",
359
+ "H": "N[C@@H](CC1=CN=C-N1)C(=O)O",
360
+ "I": "N[C@@H](C(CC)C)C(=O)O",
361
+ "K": "N[C@@H](CCCCN)C(=O)O",
362
+ "K[Acetyl]": "N[C@@H](CCCCNC(=O)C)C(=O)O",
363
+ "K[Crotonyl]": "N[C@@H](CCCCNC(C=CC)=O)C(=O)O",
364
+ "K[Dimethyl]": "N[C@@H](CCCCN(C)C)C(=O)O",
365
+ "K[Formyl]": "N[C@@H](CCCCNC=O)C(=O)O",
366
+ "K[Malonyl]": "N[C@@H](CCCCNC(=O)CC(O)=O)C(=O)O",
367
+ "K[Methyl]": "N[C@@H](CCCCNC)C(=O)O",
368
+ "K[Propionyl]": "N[C@@H](CCCCNC(=O)CC)C(=O)O",
369
+ "K[Succinyl]": "N[C@@H](CCCCNC(CCC(O)=O)=O)C(=O)O",
370
+ "K[Trimethyl]": "N[C@@H](CCCC[N+](C)(C)C)C(=O)O",
371
+ "L": "N[C@@H](CC(C)C)C(=O)O",
372
+ "M": "N[C@@H](CCSC)C(=O)O",
373
+ "M[Oxidation]": "N[C@@H](CCS(=O)C)C(=O)O",
374
+ "N": "N[C@@H](CC(=O)N)C(=O)O",
375
+ "P": "N1[C@@H](CCC1)C(=O)O",
376
+ "P[Oxidation]": "N1CC(O)C[C@H]1C(=O)O",
377
+ "Q": "N[C@@H](CCC(=O)N)C(=O)O",
378
+ "R": "N[C@@H](CCCNC(=N)N)C(=O)O",
379
+ "R[Deamidated]": "N[C@@H](CCCNC(N)=O)C(=O)O",
380
+ "R[Dimethyl]": "N[C@@H](CCCNC(N(C)C)=N)C(=O)O",
381
+ "R[Methyl]": "N[C@@H](CCCNC(=N)NC)C(=O)O",
382
+ "S": "N[C@@H](CO)C(=O)O",
383
+ "T": "N[C@@H](C(O)C)C(=O)O",
384
+ "V": "N[C@@H](C(C)C)C(=O)O",
385
+ "W": "N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O",
386
+ "Y": "N[C@@H](Cc1ccc(O)cc1)C(=O)O",
387
+ "Y[Nitro]": "N[C@@H](Cc1ccc(O)c(N(=O)=O)c1)C(=O)O",
388
+ "Y[Phospho]": "N[C@@H](Cc1ccc(OP(O)(=O)O)cc1)C(=O)O",
389
+ "[Acetyl]-A": "N(C(C)=O)[C@@H](C)C(=O)O",
390
+ "[Acetyl]-C": "N(C(C)=O)[C@@H](CS)C(=O)O",
391
+ "[Acetyl]-D": "N(C(=O)C)[C@H](C(=O)O)CC(=O)O",
392
+ "[Acetyl]-E": "N(C(=O)C)[C@@H](CCC(O)=O)C(=O)O",
393
+ "[Acetyl]-F": "N(C(C)=O)[C@@H](Cc1ccccc1)C(=O)O",
394
+ "[Acetyl]-G": "N(C(=O)C)CC(=O)O",
395
+ "[Acetyl]-H": "N(C(=O)C)[C@@H](Cc1[nH]cnc1)C(=O)O",
396
+ "[Acetyl]-I": "N(C(=O)C)[C@@H]([C@H](CC)C)C(=O)O",
397
+ "[Acetyl]-K": "N(C(C)=O)[C@@H](CCCCN)C(=O)O",
398
+ "[Acetyl]-L": "N(C(=O)C)[C@@H](CC(C)C)C(=O)O",
399
+ "[Acetyl]-M": "N(C(=O)C)[C@@H](CCSC)C(=O)O",
400
+ "[Acetyl]-N": "N(C(C)=O)[C@@H](CC(=O)N)C(=O)O",
401
+ "[Acetyl]-P": "N1(C(=O)C)CCC[C@H]1C(=O)O",
402
+ "[Acetyl]-Q": "N(C(=O)C)[C@@H](CCC(=O)N)C(=O)O",
403
+ "[Acetyl]-R": "N(C(C)=O)[C@@H](CCCN=C(N)N)C(=O)O",
404
+ "[Acetyl]-S": "N(C(C)=O)[C@@H](CO)C(=O)O",
405
+ "[Acetyl]-T": "N(C(=O)C)[C@@H]([C@H](O)C)C(=O)O",
406
+ "[Acetyl]-V": "N(C(=O)C)[C@@H](C(C)C)C(=O)O",
407
+ "[Acetyl]-W": "N(C(C)=O)[C@@H](Cc1c2ccccc2[nH]c1)C(=O)O",
408
+ "[Acetyl]-Y": "N(C(C)=O)[C@@H](Cc1ccc(O)cc1)C(=O)O"
409
+ }
410
+
411
+ residues_reverse = {}
412
+ def register_peptide_residues(residues_: dict[str, str], canonicalize=True):
413
+ for residue, smiles in residues_.items():
414
+ if canonicalize:
415
+ smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
416
+ residues[residue] = smiles
417
+ residues_reverse[residues[residue]] = residue
418
+
419
+ register_peptide_residues(residues, canonicalize=False)
420
+
421
+ def _extract_residue_type(residue_tag: str) -> str:
422
+ pattern = r"(?<!\[)[A-Z](?![^\[]*\])"
423
+ return [match.group(0) for match in re.finditer(pattern, residue_tag)][0]
424
+
425
+ special_residues = {}
426
+ for key, value in residues.items():
427
+ special_residues[key + '-[X]'] = value.rstrip('O')
428
+
429
+ register_peptide_residues(special_residues, canonicalize=False)
molcraft/chem.py CHANGED
@@ -426,10 +426,12 @@ def embed_conformers(
426
426
  success = rdDistGeom.EmbedMultipleConfs(
427
427
  mol, numConfs=num_conformers, params=embedding_method
428
428
  )
429
- if not len(success):
429
+ num_successes = len(success)
430
+ if num_successes < num_conformers:
430
431
  warnings.warn(
431
- f'Could not embed conformer(s) for {mol.canonical_smiles!r} using the '
432
- 'speified method. Giving it another try with more permissive methods.',
432
+ f'Could only embed {num_successes} out of {num_conformers} conformer(s) '
433
+ f'for {mol.canonical_smiles!r} using {method}. Embedding the remaining '
434
+ f'{num_conformers - num_successes} conformer(s) using different embedding methods.',
433
435
  stacklevel=2
434
436
  )
435
437
  max_attempts = (20 * mol.num_atoms) # increasing it from 10xN to 20xN
@@ -437,14 +439,16 @@ def embed_conformers(
437
439
  fallback_embedding_method = available_embedding_methods[fallback_method]
438
440
  fallback_embedding_method.useRandomCoords = True
439
441
  fallback_embedding_method.maxAttempts = max_attempts
442
+ fallback_embedding_method.clearConfs = False
440
443
  success = rdDistGeom.EmbedMultipleConfs(
441
- mol, numConfs=num_conformers, params=fallback_embedding_method
444
+ mol, numConfs=(num_conformers - num_successes), params=fallback_embedding_method
442
445
  )
443
- if len(success):
446
+ num_successes += len(success)
447
+ if num_successes == num_conformers:
444
448
  break
445
449
  else:
446
450
  raise RuntimeError(
447
- f'Could not embed conformer(s) for {mol.canonical_smiles!r}. '
451
+ f'Could not embed {num_conformers} conformer(s) for {mol.canonical_smiles!r}. '
448
452
  )
449
453
  return mol
450
454
 
molcraft/descriptors.py CHANGED
@@ -61,7 +61,7 @@ class NumHeavyAtoms(Descriptor):
61
61
 
62
62
 
63
63
  @keras.saving.register_keras_serializable(package='molcraft')
64
- class NumHeteroAtoms(Descriptor):
64
+ class NumHeteroatoms(Descriptor):
65
65
  def call(self, mol: chem.Mol) -> np.ndarray:
66
66
  return rdMolDescriptors.CalcNumHeteroatoms(mol)
67
67
 
molcraft/features.py CHANGED
@@ -185,13 +185,13 @@ class Degree(Feature):
185
185
 
186
186
 
187
187
  @keras.saving.register_keras_serializable(package='molcraft')
188
- class TotalNumHs(Feature):
188
+ class NumHydrogens(Feature):
189
189
  def call(self, mol: chem.Mol) -> list[int, float, str]:
190
190
  return [atom.GetTotalNumHs() for atom in mol.atoms]
191
191
 
192
192
 
193
193
  @keras.saving.register_keras_serializable(package='molcraft')
194
- class TotalValence(Feature):
194
+ class Valence(Feature):
195
195
  def call(self, mol: chem.Mol) -> list[int, float, str]:
196
196
  return [atom.GetTotalValence() for atom in mol.atoms]
197
197
 
@@ -218,10 +218,17 @@ class CIPCode(Feature):
218
218
 
219
219
 
220
220
  @keras.saving.register_keras_serializable(package='molcraft')
221
- class IsChiralityPossible(Feature):
221
+ class RingSize(Feature):
222
222
  def call(self, mol: chem.Mol) -> list[int, float, str]:
223
- return [atom.HasProp("_ChiralityPossible") for atom in mol.atoms]
224
-
223
+ def ring_size(atom):
224
+ if not atom.IsInRing():
225
+ return -1
226
+ size = 3
227
+ while not atom.IsInRingSize(size):
228
+ size += 1
229
+ return size
230
+ return [ring_size(atom) for atom in mol.atoms]
231
+
225
232
 
226
233
  @keras.saving.register_keras_serializable(package='molcraft')
227
234
  class FormalCharge(Feature):
@@ -229,6 +236,12 @@ class FormalCharge(Feature):
229
236
  return [atom.GetFormalCharge() for atom in mol.atoms]
230
237
 
231
238
 
239
+ @keras.saving.register_keras_serializable(package='molcraft')
240
+ class IsChiralityPossible(Feature):
241
+ def call(self, mol: chem.Mol) -> list[int, float, str]:
242
+ return [atom.HasProp("_ChiralityPossible") for atom in mol.atoms]
243
+
244
+
232
245
  @keras.saving.register_keras_serializable(package='molcraft')
233
246
  class NumRadicalElectrons(Feature):
234
247
  def call(self, mol: chem.Mol) -> list[int, float, str]:
@@ -242,7 +255,7 @@ class IsAromatic(Feature):
242
255
 
243
256
 
244
257
  @keras.saving.register_keras_serializable(package='molcraft')
245
- class IsHetero(Feature):
258
+ class IsHeteroatom(Feature):
246
259
  def call(self, mol: chem.Mol) -> list[int, float, str]:
247
260
  return chem.hetero_atoms(mol)
248
261
 
@@ -259,19 +272,6 @@ class IsHydrogenAcceptor(Feature):
259
272
  return chem.hydrogen_acceptors(mol)
260
273
 
261
274
 
262
- @keras.saving.register_keras_serializable(package='molcraft')
263
- class RingSize(Feature):
264
- def call(self, mol: chem.Mol) -> list[int, float, str]:
265
- def ring_size(atom):
266
- if not atom.IsInRing():
267
- return -1
268
- size = 3
269
- while not atom.IsInRingSize(size):
270
- size += 1
271
- return size
272
- return [ring_size(atom) for atom in mol.atoms]
273
-
274
-
275
275
  @keras.saving.register_keras_serializable(package='molcraft')
276
276
  class IsInRing(Feature):
277
277
  def call(self, mol: chem.Mol) -> list[int, float, str]:
molcraft/featurizers.py CHANGED
@@ -196,7 +196,7 @@ class MolGraphFeaturizer(Featurizer):
196
196
  descriptors.CrippenLogP(),
197
197
  descriptors.CrippenMolarRefractivity(),
198
198
  descriptors.NumHeavyAtoms(),
199
- descriptors.NumHeteroAtoms(),
199
+ descriptors.NumHeteroatoms(),
200
200
  descriptors.NumHydrogenDonors(),
201
201
  descriptors.NumHydrogenAcceptors(),
202
202
  descriptors.NumRotatableBonds(),
molcraft/layers.py CHANGED
@@ -350,7 +350,7 @@ class GraphConv(GraphLayer):
350
350
  )
351
351
  if self._project_residual:
352
352
  warnings.warn(
353
- '`skip_connect` is set to `True`, but found incompatible dim '
353
+ '`skip_connect` is set to `True`, but found incompatible dim '
354
354
  'between input (node feature dim) and output (`self.units`). '
355
355
  'Automatically applying a projection layer to residual to '
356
356
  'match input and output. ',
@@ -369,7 +369,7 @@ class GraphConv(GraphLayer):
369
369
  self._message_intermediate_activation = self.activation
370
370
  self._message_final_dense = self.get_dense(self.units)
371
371
 
372
- has_overridden_aggregate = self.__class__.message != GraphConv.aggregate
372
+ has_overridden_aggregate = self.__class__.message != GraphConv.aggregate
373
373
  if not has_overridden_aggregate:
374
374
  pass
375
375
 
@@ -401,13 +401,15 @@ class GraphConv(GraphLayer):
401
401
  residual = self._residual_dense(residual)
402
402
 
403
403
  message = self.message(tensor)
404
- if not isinstance(message, tensors.GraphTensor):
404
+ add_message = not isinstance(message, tensors.GraphTensor)
405
+ if add_message:
405
406
  message = tensor.update({'edge': {'message': message}})
406
407
  elif not 'message' in message.edge:
407
408
  raise ValueError('Could not find `message` in `edge` output.')
408
-
409
+
409
410
  aggregate = self.aggregate(message)
410
- if not isinstance(aggregate, tensors.GraphTensor):
411
+ add_aggregate = not isinstance(aggregate, tensors.GraphTensor)
412
+ if add_aggregate:
411
413
  aggregate = tensor.update({'node': {'aggregate': aggregate}})
412
414
  elif not 'aggregate' in aggregate.node:
413
415
  raise ValueError('Could not find `aggregate` in `node` output.')
@@ -421,6 +423,16 @@ class GraphConv(GraphLayer):
421
423
  if update.node['feature'].shape[-1] != self.units:
422
424
  raise ValueError('Updated node `feature` is not equal to `self.units`.')
423
425
 
426
+ if add_message and add_aggregate:
427
+ update = update.update({'node': {'aggregate': None}, 'edge': {'message': None}})
428
+ elif add_message:
429
+ update = update.update({'edge': {'message': None}})
430
+ elif add_aggregate:
431
+ update = update.update({'node': {'aggregate': None}})
432
+
433
+ if not self._skip_connect and not self._normalize:
434
+ return update
435
+
424
436
  feature = update.node['feature']
425
437
 
426
438
  if self._skip_connect:
@@ -649,7 +661,7 @@ class GIConv(GraphConv):
649
661
  return config
650
662
 
651
663
 
652
- @keras.saving.register_keras_serializable(package='molgraphx')
664
+ @keras.saving.register_keras_serializable(package='molcraft')
653
665
  class GAConv(GraphConv):
654
666
 
655
667
  """Graph attention network layer.
molcraft/losses.py CHANGED
@@ -2,7 +2,7 @@ import keras
2
2
  import numpy as np
3
3
 
4
4
 
5
- @keras.saving.register_keras_serializable(package='molgraph')
5
+ @keras.saving.register_keras_serializable(package='molcraft')
6
6
  class GaussianNegativeLogLikelihood(keras.losses.Loss):
7
7
 
8
8
  def __init__(
molcraft/models.py CHANGED
@@ -114,6 +114,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
114
114
  return typing.cast(GraphModel, super().__new__(cls))
115
115
 
116
116
  def __init__(self, *args, **kwargs):
117
+ self._model_layers = kwargs.pop('model_layers', None)
117
118
  super().__init__(*args, **kwargs)
118
119
  self.jit_compile = False
119
120
 
@@ -135,10 +136,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
135
136
  `molcraft.layers.Input(spec)`.
136
137
  """
137
138
  if not tensors.is_graph(graph_layers[0]):
138
- # TODO: Allow this. E.g.: return cls(layers=graph_layers)
139
- raise ValueError(
140
- 'Graph input not found. Make sure to add `Input`.'
141
- )
139
+ return cls(model_layers=graph_layers)
142
140
  inputs: dict = graph_layers.pop(0)
143
141
  x = inputs
144
142
  for layer in graph_layers:
@@ -148,6 +146,31 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
148
146
  outputs = x
149
147
  return cls(inputs=inputs, outputs=outputs, **kwargs)
150
148
 
149
+ def propagate(self, graph: tensors.GraphTensor) -> tensors.GraphTensor:
150
+ if self._model_layers is None:
151
+ return super().propagate(graph)
152
+ for layer in self._model_layers:
153
+ graph = layer(graph)
154
+ return graph
155
+
156
+ def get_config(self):
157
+ config = super().get_config()
158
+ if hasattr(self, '_model_layers') and self._model_layers is not None:
159
+ config['model_layers'] = [
160
+ keras.saving.serialize_keras_object(l)
161
+ for l in self._model_layers
162
+ ]
163
+ return config
164
+
165
+ @classmethod
166
+ def from_config(cls, config: dict):
167
+ if 'model_layers' in config:
168
+ config['model_layers'] = [
169
+ keras.saving.deserialize_keras_object(l)
170
+ for l in config['model_layers']
171
+ ]
172
+ return super().from_config(config)
173
+
151
174
  def compile(
152
175
  self,
153
176
  optimizer: keras.optimizers.Optimizer | str | None = 'rmsprop',
@@ -416,7 +439,6 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
416
439
  return self(tensor, training=False)
417
440
 
418
441
  def compute_loss(self, x, y, y_pred, sample_weight=None):
419
- y, y_pred, sample_weight = _maybe_reshape(y, y_pred, sample_weight)
420
442
  return super().compute_loss(x, y, y_pred, sample_weight)
421
443
 
422
444
  def compute_metrics(self, x, y, y_pred, sample_weight=None) -> dict[str, float]:
@@ -531,58 +553,6 @@ def saliency(
531
553
  }
532
554
  }
533
555
  )
534
-
535
- def predict(
536
- model: GraphModel,
537
- x: tensors.GraphTensor | tf.data.Dataset,
538
- repeats: int | None = 16,
539
- batch_size: int = 256,
540
- verbose: int = 0,
541
- **kwargs,
542
- ) -> tuple[tf.Tensor | np.ndarray, tf.Tensor | np.ndarray]:
543
- """Predict with model.
544
-
545
- By default performs monte-carlo predictions. Namely, it performs
546
- `repeats` number of predictions for each example with `training = True`,
547
- and subsequently computes mean and standard deviations of the predictions.
548
-
549
- Args:
550
- x:
551
- A `GraphTensor` instance.
552
- repeats:
553
- Number of predictions per example.
554
- batch_size:
555
- Number of samples per batch of computation.
556
- kwargs:
557
- See `Model.predict` in Keras documentation.
558
- May or may not apply here.
559
- """
560
- if not repeats:
561
- return model.predict(
562
- x, batch_size=batch_size, verbose=verbose, **kwargs
563
- )
564
- if isinstance(x, tensors.GraphTensor):
565
- ds = tf.data.Dataset.from_tensor_slices(x)
566
- ds = ds.repeat(repeats)
567
- ds = ds.batch(batch_size)
568
- elif isinstance(x, tf.data.Dataset):
569
- ds = x.repeat(repeats)
570
- else:
571
- raise ValueError(
572
- 'Input `x` needs to be a `tensors.GraphTensor` instance '
573
- 'or a `tf.data.Dataset` instance constructed from `tensors.GraphTensor`.'
574
- )
575
- ds = ds.prefetch(-1)
576
- y_pred = keras.ops.concatenate([
577
- model(x, training=True) for x in ds])
578
- global_batch_size = len(y_pred) // repeats
579
- y_pred = np.reshape(y_pred, (repeats, global_batch_size, -1))
580
- y_pred_loc = keras.ops.mean(y_pred, axis=0)
581
- y_pred_scale = keras.ops.std(y_pred, axis=0)
582
- if tf.executing_eagerly():
583
- y_pred_loc = y_pred_loc.numpy()
584
- y_pred_scale = y_pred_scale.numpy()
585
- return (y_pred_loc, y_pred_scale)
586
556
 
587
557
  def _functional_init_arguments(args, kwargs):
588
558
  return (
@@ -597,14 +567,3 @@ def _make_dataset(x: tensors.GraphTensor, batch_size: int):
597
567
  .batch(batch_size)
598
568
  .prefetch(-1)
599
569
  )
600
-
601
- def _maybe_reshape(y, y_pred, sample_weight):
602
- if (
603
- sample_weight is not None and
604
- len(keras.ops.shape(sample_weight)) == 2 and
605
- sample_weight.shape == y_pred.shape
606
- ):
607
- y = keras.ops.reshape(y, [-1])
608
- y_pred = keras.ops.reshape(y_pred, [-1])
609
- sample_weight = keras.ops.reshape(sample_weight, [-1])
610
- return y, y_pred, sample_weight
molcraft/ops.py CHANGED
@@ -105,7 +105,11 @@ def segment_mean(
105
105
  lambda: 0
106
106
  )
107
107
  if backend.backend() == 'tensorflow':
108
- return tf.math.unsorted_segment_mean(
108
+ segment_mean_fn = (
109
+ tf.math.unsorted_segment_mean if not sorted else
110
+ tf.math.segment_mean
111
+ )
112
+ return segment_mean_fn(
109
113
  data=data,
110
114
  segment_ids=segment_ids,
111
115
  num_segments=num_segments
molcraft/records.py CHANGED
@@ -51,19 +51,24 @@ def write(
51
51
  if num_files is None:
52
52
  num_files = min(len(inputs), max(1, math.ceil(len(inputs) / 1_000)))
53
53
 
54
- chunk_size = math.ceil(len(inputs) / num_files)
55
- num_files = math.ceil(len(inputs) / chunk_size)
54
+ num_examples = len(inputs)
55
+ chunk_sizes = [0] * num_files
56
+ for i in range(num_examples):
57
+ chunk_sizes[i % num_files] += 1
58
+
59
+ input_chunks = []
60
+ current_index = 0
61
+ for size in chunk_sizes:
62
+ input_chunks.append(inputs[current_index: current_index + size])
63
+ current_index += size
64
+
65
+ assert current_index == num_examples
56
66
 
57
67
  paths = [
58
68
  os.path.join(path, f'tfrecord-{i:04d}.tfrecord')
59
69
  for i in range(num_files)
60
70
  ]
61
71
 
62
- input_chunks = [
63
- inputs[i * chunk_size: (i + 1) * chunk_size]
64
- for i in range(num_files)
65
- ]
66
-
67
72
  if not multiprocessing:
68
73
  for path, input_chunk in zip(paths, input_chunks):
69
74
  _write_tfrecord(input_chunk, path, featurizer)
molcraft-0.1.0a9.dist-info/METADATA → molcraft-0.1.0a11.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: molcraft
3
- Version: 0.1.0a9
3
+ Version: 0.1.0a11
4
4
  Summary: Graph Neural Networks for Molecular Machine Learning
5
5
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
6
6
  License: MIT License
@@ -25,7 +25,7 @@ License: MIT License
25
25
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26
26
  SOFTWARE.
27
27
 
28
- Project-URL: Homepage, https://github.com/akensert/molcraft
28
+ Project-URL: Homepage, https://github.com/compomics/molcraft
29
29
  Keywords: python,machine-learning,deep-learning,graph-neural-networks,molecular-machine-learning,molecular-graphs,computational-chemistry,computational-biology
30
30
  Classifier: Programming Language :: Python :: 3
31
31
  Classifier: Intended Audience :: Science/Research
@@ -47,15 +47,20 @@ Dynamic: license-file
47
47
  **Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
48
48
 
49
49
  > [!NOTE]
50
- > In progress/Unfinished.
50
+ > In progress.
51
51
 
52
- ## Highlights
53
- - Compatible with **Keras 3**
54
- - Customizable and serializable **featurizers**
55
- - Customizable and serializable **layers** and **models**
56
- - Customizable **GraphTensor**
57
- - Fast and efficient featurization of molecular graphs
58
- - Fast and efficient input pipelines using TF **records**
52
+ ## Installation
53
+
54
+ For CPU users:
55
+
56
+ ```bash
57
+ pip install --pre molcraft
58
+ ```
59
+
60
+ For GPU users:
61
+ ```bash
62
+ pip install --pre molcraft[gpu]
63
+ ```
59
64
 
60
65
  ## Examples
61
66
 
@@ -70,7 +75,7 @@ import keras
70
75
  featurizer = featurizers.MolGraphFeaturizer(
71
76
  atom_features=[
72
77
  features.AtomType(),
73
- features.TotalNumHs(),
78
+ features.NumHydrogens(),
74
79
  features.Degree(),
75
80
  ],
76
81
  bond_features=[
molcraft-0.1.0a11.dist-info/RECORD ADDED
@@ -0,0 +1,21 @@
1
+ molcraft/__init__.py,sha256=Huk8xSj59YLku1q0poDWWsKArf7_HULYSFbA9Jpn8u0,464
2
+ molcraft/callbacks.py,sha256=x5HnkZhqcFRrW6xdApt_jZ4X08A-0fxcnFKfdmRKa0c,3571
3
+ molcraft/chem.py,sha256=JARpv4IgFBtuNia0FLW_VF_DdmaA6e-_eZgH9dFAykA,21796
4
+ molcraft/conformers.py,sha256=K6ZtiSUNDN_fwqGP9JrPcwALLFFvlMlF_XejEJH3Sr4,4205
5
+ molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
6
+ molcraft/descriptors.py,sha256=W8GLuDpc38RtwmreNsPOcn-PRvMjTfVng9ksJwcrVyM,3032
7
+ molcraft/features.py,sha256=FpvT_9zk9EiOhvrk6OA5eEvUAYalquF7V6IvpiEJCns,13559
8
+ molcraft/featurizers.py,sha256=A_0wJfvz9JuPEZINi2iZoFNhhHgid608XJTTuVO_jwo,27063
9
+ molcraft/layers.py,sha256=cUpo9dqqNEnc7rNf-Dze8adFhOkTV5F9IhHOKs13OUI,60134
10
+ molcraft/losses.py,sha256=qnS2yC5g-O3n_zVea9MR6TNiFraW2yqRgePOisoUP4A,1065
11
+ molcraft/models.py,sha256=0x74B4WsaZgmUrHmpX9YNr9QXqd1rR3QF_ygyegHoXU,21770
12
+ molcraft/ops.py,sha256=PVxKfY_XbWCyntiSnmpyeBb-coFGT_VNNP9QzmeUwC0,4870
13
+ molcraft/records.py,sha256=MbvYkcCunbAmpy_MWXmQ9WBGi2WvwxFUlwQSPKPvSSk,5534
14
+ molcraft/tensors.py,sha256=EOUKx496KUZsjA1zA2ABc7tU_TW3Jv7AXDsug_QsLbA,22407
15
+ molcraft/apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
+ molcraft/apps/peptides.py,sha256=N5wJDGDIDRbmOmxin_dTY-odLqb0avAX9FU22U6x6c0,14576
17
+ molcraft-0.1.0a11.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
18
+ molcraft-0.1.0a11.dist-info/METADATA,sha256=jIcab-EvRqLHqM13ftx_eWNz5WjPZTkdmdNM8VttMYA,3893
19
+ molcraft-0.1.0a11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
20
+ molcraft-0.1.0a11.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
21
+ molcraft-0.1.0a11.dist-info/RECORD,,
molcraft-0.1.0a9.dist-info/WHEEL → molcraft-0.1.0a11.dist-info/WHEEL RENAMED
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.7.1)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
molcraft-0.1.0a9.dist-info/RECORD REMOVED
@@ -1,19 +0,0 @@
1
- molcraft/__init__.py,sha256=8f1z8Lhuhh8TxB-QGHI5w4a3M_ZZNH8EWGD4Y6pB578,463
2
- molcraft/callbacks.py,sha256=x5HnkZhqcFRrW6xdApt_jZ4X08A-0fxcnFKfdmRKa0c,3571
3
- molcraft/chem.py,sha256=zHH7iX0ZJ7QmP-YqR_IXCpylTwCXHXptWf1DsblnZR4,21496
4
- molcraft/conformers.py,sha256=K6ZtiSUNDN_fwqGP9JrPcwALLFFvlMlF_XejEJH3Sr4,4205
5
- molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
6
- molcraft/descriptors.py,sha256=gKqlJ3BqJLTeR2ft8isftSEaJDC8cv64eTq5IYhy4XM,3032
7
- molcraft/features.py,sha256=aBYxDfQqQsVuyjKaPUlwEgvCjbNZ-FJhuKo2Cg5ajrA,13554
8
- molcraft/featurizers.py,sha256=ybJ1djH747cgsftztWHxAX2iTq6k03MYr17btQ2Gtcs,27063
9
- molcraft/layers.py,sha256=r6hEAyJxO_Yrw5hD1r2v8yb_UxLRK9S4FMjDCUQedH8,59655
10
- molcraft/losses.py,sha256=JEKZEX2f8vDgky_fUocsF8vZjy9VMzRjZUBa20Uf9Qw,1065
11
- molcraft/models.py,sha256=FLXpO3OUmRxLmxG3MjBK4ZwcVFlea1gqEgs1ibKly2w,23263
12
- molcraft/ops.py,sha256=dLIUq-KG8nOzEcphJqNbF_f82VZRDNrB1UKrcPt5JNM,4752
13
- molcraft/records.py,sha256=0sjOdcr266ZER4F-aTBQ3AVPNAwflKWNiNJVsSc1-PQ,5370
14
- molcraft/tensors.py,sha256=EOUKx496KUZsjA1zA2ABc7tU_TW3Jv7AXDsug_QsLbA,22407
15
- molcraft-0.1.0a9.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
16
- molcraft-0.1.0a9.dist-info/METADATA,sha256=HiwS2wmntCA7m_YpgSWKiJTP0BFpl4GWWz4a77w1XBw,4062
17
- molcraft-0.1.0a9.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
18
- molcraft-0.1.0a9.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
19
- molcraft-0.1.0a9.dist-info/RECORD,,