molcraft 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Click here for more details.

molcraft/features.py ADDED
@@ -0,0 +1,387 @@
1
+ import abc
2
+ import math
3
+ import keras
4
+ import warnings
5
+ import numpy as np
6
+
7
+ from molcraft import chem
8
+
9
+
10
@keras.saving.register_keras_serializable(package='molcraft')
class Feature(abc.ABC):
    """Base class for featurizers that map a ``chem.Mol`` to an array.

    Subclasses implement :meth:`call`, returning one raw value per atom (or
    per bond) of the molecule. Calling the feature object validates those
    values and converts them into an ``np.ndarray`` of shape
    ``(num_items, output_dim)``:

    * if a vocabulary is available, every value is one-hot encoded against it;
    * otherwise every value must be numeric and becomes a single column.

    Args:
        vocab: Categorical values to one-hot encode. A ``set`` is sorted
            (``None`` sorts like ``""``) to obtain a deterministic order; any
            other iterable keeps its order and is copied, so the caller's
            object is never modified. If falsy, the ``default_vocabulary``
            entry named after the class (if any) is used.
        allow_oov: Whether values missing from the vocabulary are tolerated
            instead of raising. Tolerated values without a dedicated OOV slot
            encode as an all-zero vector.
        encode_oov: Whether to append a dedicated ``'<oov>'`` slot to the
            vocabulary. OOV values are then encoded in that slot.
        dtype: NumPy dtype of the produced arrays.
    """

    def __init__(
        self,
        vocab: set[int | str] = None,
        allow_oov: bool = True,
        encode_oov: bool = False,
        dtype: str = 'float32'
    ) -> None:
        self.encode_oov = encode_oov
        self.allow_oov = allow_oov
        self.oov_token = '<oov>'
        self.dtype = dtype
        # Mapping from vocabulary entry to its one-hot row; stays empty for
        # floating-point (vocabulary-free) features.
        self.feature_to_onehot = {}
        if not vocab:
            vocab = default_vocabulary.get(self.name, None)
        if vocab:
            if isinstance(vocab, set):
                vocab = sorted(vocab, key=lambda x: x if x is not None else "")
            else:
                # Copy unconditionally: `vocab` may be the caller's list or a
                # shared `default_vocabulary` entry, and the OOV token appended
                # below must not leak into either of them.
                vocab = list(vocab)
            if self.encode_oov and self.oov_token not in vocab:
                vocab.append(self.oov_token)
            onehot_encodings = np.eye(len(vocab), dtype=self.dtype)
            self.feature_to_onehot = dict(zip(vocab, onehot_encodings))
        self.vocab = vocab

    @abc.abstractmethod
    def call(self, mol: chem.Mol) -> list[float | int | bool | str]:
        """Return one raw feature value per atom or per bond of ``mol``."""
        pass

    def __call__(self, mol: chem.Mol) -> np.ndarray:
        """Compute and encode the feature for ``mol``.

        Returns:
            Array of shape ``(len(self.call(mol)), self.output_dim)``.

        Raises:
            ValueError: If ``mol`` is not a ``chem.Mol``, or if :meth:`call`
                returns a number of values matching neither the atom count
                nor the bond count.
        """
        if not isinstance(mol, chem.Mol):
            raise ValueError(
                f'Input to {self.name} needs to be a `chem.Mol`, which '
                'implements two properties that should be iterated over '
                'to compute features: `atoms` and `bonds`.'
            )
        features = self.call(mol)
        if len(features) != mol.num_atoms and len(features) != mol.num_bonds:
            raise ValueError(
                f'The number of features computed by {self.name} does not '
                'match the number of atoms or bonds of the `chem.Mol` object. '
                'Make sure to iterate over `atoms` or `bonds` of `chem.Mol` '
                'when computing features.'
            )
        if len(features) == 0:
            # Edge case: no atoms or bonds in the molecule. `np.stack` would
            # fail on an empty list, so build the correctly shaped array.
            return np.zeros((0, self.output_dim), dtype=self.dtype)

        func = (
            self._featurize_categorical if self.vocab else
            self._featurize_floating
        )
        return np.stack([func(x) for x in features])

    def get_config(self) -> dict:
        """Return the constructor arguments needed to rebuild this feature."""
        config = {
            'vocab': self.vocab,
            'allow_oov': self.allow_oov,
            'encode_oov': self.encode_oov,
            'dtype': self.dtype
        }
        return config

    @classmethod
    def from_config(cls, config: dict) -> 'Feature':
        """Rebuild a feature from the output of :meth:`get_config`."""
        return cls(**config)

    @property
    def name(self) -> str:
        """Class name; also the key looked up in ``default_vocabulary``."""
        return self.__class__.__name__

    @property
    def output_dim(self) -> int:
        """Width of each encoded row: vocabulary size, or 1 if none."""
        return 1 if not self.vocab else len(self.vocab)

    def _featurize_categorical(self, feature: str | int) -> np.ndarray:
        """One-hot encode ``feature`` against the vocabulary.

        Out-of-vocabulary values map to the ``'<oov>'`` row when one exists,
        otherwise to an all-zero vector.

        Raises:
            ValueError: If ``feature`` is out of vocabulary and neither
                ``allow_oov`` nor ``encode_oov`` is set.
        """
        encoding = self.feature_to_onehot.get(feature, None)
        if encoding is not None:
            return encoding
        # Either flag opts into OOV handling, as the error message states.
        if not (self.allow_oov or self.encode_oov):
            raise ValueError(
                f'{feature} could not be encoded, as it was not found in `vocab`. '
                'To allow OOV features, set `allow_oov` or `encode_oov` to True.'
            )
        oov_encoding = self.feature_to_onehot.get(self.oov_token, None)
        if oov_encoding is None:
            oov_encoding = np.zeros([self.output_dim], dtype=self.dtype)
        return oov_encoding

    def _featurize_floating(
        self, value: float | int | bool | None
    ) -> np.ndarray:
        """Encode a numeric value as a length-1 array.

        ``None`` and non-finite values (NaN, +/-inf) are replaced by 0 with a
        ``UserWarning``. NumPy integer, floating and boolean scalars are
        accepted alongside the Python built-in numeric types.

        Raises:
            ValueError: If ``value`` is not numeric (e.g. a string produced by
                a feature that was not given a vocabulary).
        """
        numeric_types = (int, float, bool, np.integer, np.floating, np.bool_)
        if value is not None and not isinstance(value, numeric_types):
            raise ValueError(
                f'{self.name} produced a value of type {type(value)}. '
                'If it represents a categorical feature, please provide a `vocab` '
                'to the constructor. If if represents a floating point feature, '
                'please make sure its `call` method returns a list of values of '
                'type `float`, `int`, `bool` or `None`.'
            )
        if value is None:
            warn(
                f'Found value of {self.name} to be None. '
                'Converting it to a value of 0.'
            )
            value = 0.0
        elif not math.isfinite(value):
            warn(
                f'Found value of {self.name} to be non-finite. '
                f'Value received: {value}. Converting it to a value of 0.'
            )
            value = 0.0
        return np.asarray([value], dtype=self.dtype)
119
+
120
+
121
@keras.saving.register_keras_serializable(package='molcraft')
class EdgeFeature(Feature):
    """Feature computed for every ordered pair of atoms.

    :meth:`call` must return ``N * N`` values, where ``N`` is the number of
    atoms, ordered row-major over pairs:
    ``(0, 0), (0, 1), ..., (0, N-1), (1, 0), ..., (N-1, N-1)``.
    The result has shape ``(N * N, output_dim)``.
    """

    def __call__(self, mol: chem.Mol) -> np.ndarray:
        """Compute and encode the pairwise feature for ``mol``.

        Raises:
            ValueError: If ``mol`` is not a ``chem.Mol`` or :meth:`call` does
                not return exactly ``num_atoms ** 2`` values.
        """
        if not isinstance(mol, chem.Mol):
            raise ValueError(
                f'Input to {self.name} needs to be a `chem.Mol`, which '
                'implements two properties that should be iterated over '
                'to compute features: `atoms` and `bonds`.'
            )
        features = self.call(mol)
        if len(features) != int(mol.num_atoms**2):
            raise ValueError(
                f'The number of features computed by {self.name} does not '
                'match the number of node pairs in the `chem.Mol` object. '
                f'Make sure the list of items returned by {self.name}(input) '
                'correspond to node/atom pairs: '
                '[(0, 0), (0, 1), ..., (0, N), (1, 0), ... (N, N)], '
                'where N denotes the number of nodes/atoms.'
            )
        if len(features) == 0:
            # Molecule without atoms: keep the 2-D shape (0, output_dim), as
            # `Feature.__call__` does, instead of a flat (0,) array.
            return np.zeros((0, self.output_dim), dtype=self.dtype)
        func = (
            self._featurize_categorical if self.vocab else
            self._featurize_floating
        )
        return np.asarray([func(x) for x in features], dtype=self.dtype)
146
+
147
+
148
@keras.saving.register_keras_serializable(package='molcraft')
class Distance(EdgeFeature):
    """Pairwise atom distance from ``chem.get_distances``, one-hot encoded.

    Distances ``0..max_distance`` each get their own category. Larger values
    are out of vocabulary; with the default ``encode_oov=True`` they share the
    ``'<oov>'`` slot.

    Args:
        max_distance: Largest distance with its own category (default 20).
        allow_oov: Forwarded to :class:`Feature`.
        encode_oov: Forwarded to :class:`Feature`.
        **kwargs: Further :class:`Feature` arguments (e.g. ``dtype``). A
            ``vocab`` entry, as present in configs written by
            :meth:`Feature.get_config`, is accepted. It is only used to infer
            ``max_distance`` when that is not given, because the vocabulary is
            always derived from ``max_distance``.
    """

    def __init__(
        self,
        max_distance: int = None,
        allow_oov: bool = True,
        encode_oov: bool = True,
        **kwargs,
    ) -> None:
        # Configs produced by `Feature.get_config` contain `vocab`. Forwarding
        # it next to our own `vocab=` would raise "got multiple values for
        # keyword argument 'vocab'" in `from_config`.
        legacy_vocab = kwargs.pop('vocab', None)
        if max_distance is None and legacy_vocab:
            distances = [
                v for v in legacy_vocab
                if isinstance(v, int) and not isinstance(v, bool)
            ]
            if distances:
                max_distance = max(distances)
        if max_distance is None:
            max_distance = 20
        self.max_distance = max_distance
        vocab = list(range(max_distance + 1))
        super().__init__(
            vocab=vocab,
            allow_oov=allow_oov,
            encode_oov=encode_oov,
            **kwargs
        )

    def get_config(self) -> dict:
        """Serialize ``max_distance`` instead of the derived vocabulary."""
        config = super().get_config()
        config.pop('vocab', None)
        config['max_distance'] = self.max_distance
        return config

    def call(self, mol: chem.Mol) -> list[int]:
        """Return all pairwise distances as ints, flattened row-major."""
        # Assumes `chem.get_distances` returns an (N, N) array -- TODO confirm.
        return [int(x) for x in chem.get_distances(mol).reshape(-1)]
170
+
171
+
172
@keras.saving.register_keras_serializable(package='molcraft')
class AtomType(Feature):
    """Element symbol of each atom, as returned by ``atom.GetSymbol()``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return the symbol of every atom in ``mol.atoms`` order."""
        symbols = [a.GetSymbol() for a in mol.atoms]
        return symbols
176
+
177
+
178
@keras.saving.register_keras_serializable(package='molcraft')
class Degree(Feature):
    """Per-atom value of ``atom.GetDegree()``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return the degree of every atom in ``mol.atoms`` order."""
        degrees = [a.GetDegree() for a in mol.atoms]
        return degrees
182
+
183
+
184
@keras.saving.register_keras_serializable(package='molcraft')
class TotalNumHs(Feature):
    """Per-atom value of ``atom.GetTotalNumHs()``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return the hydrogen count of every atom in ``mol.atoms`` order."""
        hydrogen_counts = [a.GetTotalNumHs() for a in mol.atoms]
        return hydrogen_counts
188
+
189
+
190
@keras.saving.register_keras_serializable(package='molcraft')
class TotalValence(Feature):
    """Per-atom value of ``atom.GetTotalValence()``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return the total valence of every atom in ``mol.atoms`` order."""
        valences = [a.GetTotalValence() for a in mol.atoms]
        return valences
194
+
195
+
196
@keras.saving.register_keras_serializable(package='molcraft')
class AtomicWeight(Feature):
    """Atomic weight of each atom's element, looked up by symbol."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return the periodic-table weight for every atom's symbol."""
        table = chem.get_periodic_table()
        weights = [table.GetAtomicWeight(a.GetSymbol()) for a in mol.atoms]
        return weights
201
+
202
+
203
@keras.saving.register_keras_serializable(package='molcraft')
class Hybridization(Feature):
    """Lower-cased string form of ``atom.GetHybridization()``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return labels such as ``'sp3'`` for every atom."""
        labels = [str(a.GetHybridization()).lower() for a in mol.atoms]
        return labels
207
+
208
+
209
@keras.saving.register_keras_serializable(package='molcraft')
class CIPCode(Feature):
    """Value of each atom's ``_CIPCode`` property, or ``"None"`` if unset."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return the CIP label (string) of every atom."""
        def cip_label(atom):
            # Atoms without the property get the literal string "None",
            # which is itself a vocabulary entry.
            if not atom.HasProp("_CIPCode"):
                return "None"
            return atom.GetProp("_CIPCode")

        return [cip_label(a) for a in mol.atoms]
215
+
216
+
217
@keras.saving.register_keras_serializable(package='molcraft')
class IsChiralityPossible(Feature):
    """Whether each atom carries the ``_ChiralityPossible`` property."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return one boolean per atom."""
        flags = [a.HasProp("_ChiralityPossible") for a in mol.atoms]
        return flags
221
+
222
+
223
@keras.saving.register_keras_serializable(package='molcraft')
class FormalCharge(Feature):
    """Per-atom value of ``atom.GetFormalCharge()``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return the formal charge of every atom in ``mol.atoms`` order."""
        charges = [a.GetFormalCharge() for a in mol.atoms]
        return charges
227
+
228
+
229
@keras.saving.register_keras_serializable(package='molcraft')
class NumRadicalElectrons(Feature):
    """Per-atom value of ``atom.GetNumRadicalElectrons()``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return the radical-electron count of every atom."""
        radicals = [a.GetNumRadicalElectrons() for a in mol.atoms]
        return radicals
233
+
234
+
235
@keras.saving.register_keras_serializable(package='molcraft')
class IsAromatic(Feature):
    """Per-atom value of ``atom.GetIsAromatic()``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return one aromaticity flag per atom."""
        aromatic = [a.GetIsAromatic() for a in mol.atoms]
        return aromatic
239
+
240
+
241
@keras.saving.register_keras_serializable(package='molcraft')
class IsHetero(Feature):
    """Heteroatom indicator, delegated to ``chem.hetero_atoms``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return ``chem.hetero_atoms(mol)`` unchanged."""
        # Presumably one boolean per atom -- confirm in `chem`.
        hetero_flags = chem.hetero_atoms(mol)
        return hetero_flags
245
+
246
+
247
@keras.saving.register_keras_serializable(package='molcraft')
class IsHydrogenDonor(Feature):
    """Hydrogen-bond donor indicator, delegated to ``chem.hydrogen_donors``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return ``chem.hydrogen_donors(mol)`` unchanged."""
        # Presumably one boolean per atom -- confirm in `chem`.
        donor_flags = chem.hydrogen_donors(mol)
        return donor_flags
251
+
252
+
253
@keras.saving.register_keras_serializable(package='molcraft')
class IsHydrogenAcceptor(Feature):
    """Hydrogen-bond acceptor indicator from ``chem.hydrogen_acceptors``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return ``chem.hydrogen_acceptors(mol)`` unchanged."""
        # Presumably one boolean per atom -- confirm in `chem`.
        acceptor_flags = chem.hydrogen_acceptors(mol)
        return acceptor_flags
257
+
258
+
259
@keras.saving.register_keras_serializable(package='molcraft')
class RingSize(Feature):
    """Smallest ring size containing each atom; ``-1`` for acyclic atoms."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return one ring size per atom in ``mol.atoms`` order."""
        def smallest_ring_size(atom) -> int:
            if atom.IsInRing():
                # Probe sizes upward from 3; the first hit is the smallest.
                candidate = 3
                while True:
                    if atom.IsInRingSize(candidate):
                        return candidate
                    candidate += 1
            return -1

        return [smallest_ring_size(a) for a in mol.atoms]
270
+
271
+
272
@keras.saving.register_keras_serializable(package='molcraft')
class IsInRing(Feature):
    """Per-atom value of ``atom.IsInRing()``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return one ring-membership flag per atom."""
        in_ring = [a.IsInRing() for a in mol.atoms]
        return in_ring
276
+
277
+
278
@keras.saving.register_keras_serializable(package='molcraft')
class CrippenLogPContribution(Feature):
    """Per-atom logP contribution from ``chem.logp_contributions``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return ``chem.logp_contributions(mol)`` unchanged."""
        contributions = chem.logp_contributions(mol)
        return contributions
282
+
283
+
284
@keras.saving.register_keras_serializable(package='molcraft')
class CrippenMolarRefractivityContribution(Feature):
    """Per-atom molar refractivity from ``chem.molar_refractivity_contribution``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return ``chem.molar_refractivity_contribution(mol)`` unchanged."""
        contributions = chem.molar_refractivity_contribution(mol)
        return contributions
288
+
289
+
290
@keras.saving.register_keras_serializable(package='molcraft')
class TPSAContribution(Feature):
    """Per-atom TPSA contribution from ``chem.tpsa_contribution``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return ``chem.tpsa_contribution(mol)`` unchanged."""
        contributions = chem.tpsa_contribution(mol)
        return contributions
294
+
295
+
296
@keras.saving.register_keras_serializable(package='molcraft')
class LabuteASAContribution(Feature):
    """Per-atom Labute ASA contribution from ``chem.asa_contribution``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return ``chem.asa_contribution(mol)`` unchanged."""
        contributions = chem.asa_contribution(mol)
        return contributions
300
+
301
+
302
@keras.saving.register_keras_serializable(package='molcraft')
class GasteigerCharge(Feature):
    """Per-atom partial charge from ``chem.gasteiger_charges``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return ``chem.gasteiger_charges(mol)`` unchanged."""
        # Any non-finite charges are zeroed (with a warning) by
        # `Feature._featurize_floating`, not here.
        charges = chem.gasteiger_charges(mol)
        return charges
306
+
307
+
308
@keras.saving.register_keras_serializable(package='molcraft')
class BondType(Feature):
    """Lower-cased string form of ``bond.GetBondType()``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return labels such as ``'single'`` for every bond."""
        labels = [str(b.GetBondType()).lower() for b in mol.bonds]
        return labels
312
+
313
+
314
@keras.saving.register_keras_serializable(package='molcraft')
class Stereo(Feature):
    """Bond stereo label derived from ``bond.GetStereo()``.

    The ``'STEREO'`` prefix is removed and the remainder capitalized, e.g.
    ``STEREONONE`` becomes ``'None'``.
    """

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return one stereo label per bond in ``mol.bonds`` order."""
        def stereo_label(bond) -> str:
            raw = str(bond.GetStereo())
            stripped = raw.replace('STEREO', '')
            return stripped.capitalize()

        return [stereo_label(b) for b in mol.bonds]
321
+
322
+
323
@keras.saving.register_keras_serializable(package='molcraft')
class IsConjugated(Feature):
    """Per-bond value of ``bond.GetIsConjugated()``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return one conjugation flag per bond."""
        conjugated = [b.GetIsConjugated() for b in mol.bonds]
        return conjugated
327
+
328
+
329
@keras.saving.register_keras_serializable(package='molcraft')
class IsRotatable(Feature):
    """Rotatable-bond indicator, delegated to ``chem.rotatable_bonds``."""

    def call(self, mol: chem.Mol) -> list[int, float, str]:
        """Return ``chem.rotatable_bonds(mol)`` unchanged."""
        # Presumably one boolean per bond -- confirm in `chem`.
        rotatable = chem.rotatable_bonds(mol)
        return rotatable
333
+
334
+
335
# Default vocabularies keyed by feature class name (see `Feature.name`).
# Each value is a fresh list. Its order fixes the one-hot column order, so
# entries must not be reordered without retraining dependent models.
default_vocabulary = {
    # '*' (dummy atom) followed by elements in atomic-number order, so the
    # index of each symbol equals its atomic number.
    'AtomType': """
        * H He Li Be B C N O F Ne Na
        Mg Al Si P S Cl Ar K Ca Sc Ti V
        Cr Mn Fe Co Ni Cu Zn Ga Ge As Se
        Br Kr Rb Sr Y Zr Nb Mo Tc Ru Rh
        Pd Ag Cd In Sn Sb Te I Xe Cs Ba
        La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho
        Er Tm Yb Lu Hf Ta W Re Os Ir Pt
        Au Hg Tl Pb Bi Po At Rn Fr Ra Ac
        Th Pa U Np Pu Am Cm Bk Cf Es Fm
        Md No Lr Rf Db Sg Bh Hs Mt Ds Rg
        Cn Nh Fl Mc Lv Ts Og
    """.split(),
    'Degree': list(range(9)),
    'TotalNumHs': list(range(5)),
    'TotalValence': list(range(9)),
    'Hybridization': "s sp sp2 sp3 sp3d sp3d2 unspecified".split(),
    'CIPCode': ["R", "S", "None"],
    'FormalCharge': list(range(-3, 4)),
    'NumRadicalElectrons': list(range(5)),
    # -1 marks atoms outside any ring.
    'RingSize': [-1, *range(3, 9)],
    'BondType': "zero single double triple aromatic".split(),
    'Stereo': ["E", "Z", "Any", "None"],
}
380
+
381
+
382
def warn(message: str, stacklevel: int = 2) -> None:
    """Emit ``message`` as a ``UserWarning``.

    Args:
        message: Warning text.
        stacklevel: Forwarded to ``warnings.warn``. The default of 2
            attributes the warning to the code that called this helper. With
            1, every warning points at this helper's own ``warnings.warn``
            line, which hides where the problem was detected.
    """
    warnings.warn(
        message=message,
        category=UserWarning,
        stacklevel=stacklevel
    )