molcraft 0.1.0rc9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic; see the package's registry page for more details.

molcraft/datasets.py ADDED
@@ -0,0 +1,132 @@
1
+ import warnings
2
+ import numpy as np
3
+ import pandas as pd
4
+ import typing
5
+
6
+
7
+ def split(
8
+ data: pd.DataFrame | np.ndarray,
9
+ *,
10
+ train_size: float | None = None,
11
+ validation_size: float | None = None,
12
+ test_size: float | None = None,
13
+ groups: str | np.ndarray = None,
14
+ shuffle: bool = False,
15
+ random_seed: int | None = None,
16
+ ) -> tuple[np.ndarray | pd.DataFrame, ...]:
17
+ """Splits the dataset into subsets.
18
+
19
+ Args:
20
+ data:
21
+ A pd.DataFrame or np.ndarray object.
22
+ train_size:
23
+ The size of the train set.
24
+ validation_size:
25
+ The size of the validation set.
26
+ test_size:
27
+ The size of the test set.
28
+ groups:
29
+ The groups to perform the splitting on.
30
+ shuffle:
31
+ Whether the dataset should be shuffled prior to splitting.
32
+ random_seed:
33
+ The random state/seed. Only applicable if shuffling.
34
+ """
35
+ if not isinstance(data, (pd.DataFrame, np.ndarray)):
36
+ raise ValueError(f'Unsupported `data` type ({type(data)}).')
37
+
38
+ if isinstance(groups, str):
39
+ groups = data[groups].values
40
+ elif groups is None:
41
+ groups = np.arange(len(data))
42
+
43
+ indices = np.unique(groups)
44
+ size = len(indices)
45
+
46
+ if not train_size and not test_size:
47
+ raise ValueError(
48
+ f'Found both `train_size` and `test_size` to be `None`, '
49
+ f'specify at least one of them.'
50
+ )
51
+ if isinstance(test_size, float):
52
+ test_size = int(size * test_size)
53
+ if isinstance(train_size, float):
54
+ train_size = int(size * train_size)
55
+ if isinstance(validation_size, float):
56
+ validation_size = int(size * validation_size)
57
+ elif not validation_size:
58
+ validation_size = 0
59
+
60
+ if not train_size:
61
+ train_size = (size - test_size - validation_size)
62
+ if not test_size:
63
+ test_size = (size - train_size - validation_size)
64
+
65
+ remainder = size - (train_size + validation_size + test_size)
66
+ if remainder < 0:
67
+ raise ValueError(
68
+ f'subset sizes added up to more than the data size.'
69
+ )
70
+ train_size += remainder
71
+
72
+ if shuffle:
73
+ np.random.seed(random_seed)
74
+ np.random.shuffle(indices)
75
+
76
+ train_mask = np.isin(groups, indices[:train_size])
77
+ test_mask = np.isin(groups, indices[-test_size:])
78
+ if not validation_size:
79
+ return data[train_mask], data[test_mask]
80
+ validation_mask = np.isin(groups, indices[train_size:-test_size])
81
+ return data[train_mask], data[validation_mask], data[test_mask]
82
+
83
+ def cv_split(
84
+ data: pd.DataFrame | np.ndarray,
85
+ num_splits: int = 10,
86
+ groups: str | np.ndarray = None,
87
+ shuffle: bool = False,
88
+ random_seed: int | None = None,
89
+ ) -> typing.Iterator[
90
+ tuple[np.ndarray | pd.DataFrame, np.ndarray | pd.DataFrame]
91
+ ]:
92
+ """Splits the dataset into cross-validation folds.
93
+
94
+ Args:
95
+ data:
96
+ A pd.DataFrame or np.ndarray object.
97
+ num_splits:
98
+ The number of cross-validation folds.
99
+ groups:
100
+ The groups to perform the splitting on.
101
+ shuffle:
102
+ Whether the dataset should be shuffled prior to splitting.
103
+ random_seed:
104
+ The random state/seed. Only applicable if shuffling.
105
+ """
106
+ if not isinstance(data, (pd.DataFrame, np.ndarray)):
107
+ raise ValueError(f'Unsupported `data` type ({type(data)}).')
108
+
109
+ if isinstance(groups, str):
110
+ groups = data[groups].values
111
+ elif groups is None:
112
+ groups = np.arange(len(data))
113
+
114
+ indices = np.unique(groups)
115
+ size = len(indices)
116
+
117
+ if num_splits > size:
118
+ raise ValueError(
119
+ f'`num_splits` ({num_splits}) must not be greater than'
120
+ f'the data size or the number of groups ({size}).'
121
+ )
122
+ if shuffle:
123
+ np.random.seed(random_seed)
124
+ np.random.shuffle(indices)
125
+
126
+ indices_splits = np.array_split(indices, num_splits)
127
+
128
+ for k in range(num_splits):
129
+ test_indices = indices_splits[k]
130
+ test_mask = np.isin(groups, test_indices)
131
+ train_mask = ~test_mask
132
+ yield data[train_mask], data[test_mask]
@@ -0,0 +1,149 @@
1
+ import warnings
2
+ import keras
3
+ import numpy as np
4
+
5
+ from rdkit.Chem import rdMolDescriptors
6
+
7
+ from molcraft import chem
8
+ from molcraft import features
9
+
10
+
11
@keras.saving.register_keras_serializable(package='molcraft')
class Descriptor(features.Feature):
    """Base class for molecule-level features.

    A descriptor's `call` returns either a single value or a sequence (tuple,
    list or np.ndarray) of values for the whole molecule. Each value is
    encoded with the inherited categorical (when a `vocab` is set) or
    floating point encoder, and the encodings are concatenated into a single
    flat array.
    """

    def __call__(self, mol: chem.Mol) -> np.ndarray:
        if not isinstance(mol, chem.Mol):
            raise ValueError(
                f'Input to {self.name} must be a `chem.Mol` object.'
            )
        raw = self.call(mol)
        if self.vocab:
            encode = self._featurize_categorical
        else:
            encode = self._featurize_floating
        # Scalars are wrapped so both cases share one encoding path.
        is_sequence = isinstance(raw, (tuple, list, np.ndarray))
        values = raw if is_sequence else [raw]
        return np.concatenate([encode(item) for item in values])
31
+
32
+
33
@keras.saving.register_keras_serializable(package='molcraft')
class Descriptor3D(Descriptor):
    """Descriptor that requires an embedded conformer on the input molecule.

    Validates the input, then delegates encoding to `Descriptor.__call__`.
    """

    def __call__(self, mol: chem.Mol) -> np.ndarray:
        is_mol = isinstance(mol, chem.Mol)
        if not is_mol:
            raise ValueError(
                f'Input to {self.name} must be a `chem.Mol` object.'
            )
        has_conformer = mol.num_conformers != 0
        if not has_conformer:
            raise ValueError(
                f'The inputted `chem.Mol` to {self.name} must embed a conformer. '
                f'It is recommended that {self.name} is used as a molecule feature '
                'for `MolGraphFeaturizer3D`, which by default embeds a conformer.'
            )
        return super().__call__(mol)
48
+
49
+
50
@keras.saving.register_keras_serializable(package='molcraft')
class MolWeight(Descriptor):
    """Exact molecular weight, computed by RDKit `CalcExactMolWt`."""

    def call(self, mol: chem.Mol) -> float:
        return rdMolDescriptors.CalcExactMolWt(mol)
54
+
55
+
56
@keras.saving.register_keras_serializable(package='molcraft')
class TotalPolarSurfaceArea(Descriptor):
    """Topological polar surface area, computed by RDKit `CalcTPSA`."""

    def call(self, mol: chem.Mol) -> float:
        return rdMolDescriptors.CalcTPSA(mol)
60
+
61
+
62
@keras.saving.register_keras_serializable(package='molcraft')
class LogP(Descriptor):
    """Crippen logP (first element of RDKit `CalcCrippenDescriptors`)."""

    def call(self, mol: chem.Mol) -> float:
        return rdMolDescriptors.CalcCrippenDescriptors(mol)[0]
67
+
68
+
69
@keras.saving.register_keras_serializable(package='molcraft')
class MolarRefractivity(Descriptor):
    """Crippen molar refractivity (second element of RDKit
    `CalcCrippenDescriptors`)."""

    def call(self, mol: chem.Mol) -> float:
        return rdMolDescriptors.CalcCrippenDescriptors(mol)[1]
74
+
75
+
76
@keras.saving.register_keras_serializable(package='molcraft')
class NumAtoms(Descriptor):
    """Number of atoms, as reported by RDKit `CalcNumAtoms`."""

    def call(self, mol: chem.Mol) -> int:
        return rdMolDescriptors.CalcNumAtoms(mol)
80
+
81
+
82
@keras.saving.register_keras_serializable(package='molcraft')
class NumHeavyAtoms(Descriptor):
    """Number of heavy atoms, as reported by RDKit `CalcNumHeavyAtoms`."""

    def call(self, mol: chem.Mol) -> int:
        return rdMolDescriptors.CalcNumHeavyAtoms(mol)
86
+
87
+
88
@keras.saving.register_keras_serializable(package='molcraft')
class NumHeteroatoms(Descriptor):
    """Number of heteroatoms, as reported by RDKit `CalcNumHeteroatoms`."""

    def call(self, mol: chem.Mol) -> int:
        return rdMolDescriptors.CalcNumHeteroatoms(mol)
92
+
93
+
94
@keras.saving.register_keras_serializable(package='molcraft')
class NumHydrogenDonors(Descriptor):
    """Number of hydrogen bond donors, as reported by RDKit `CalcNumHBD`."""

    def call(self, mol: chem.Mol) -> int:
        return rdMolDescriptors.CalcNumHBD(mol)
98
+
99
+
100
@keras.saving.register_keras_serializable(package='molcraft')
class NumHydrogenAcceptors(Descriptor):
    """Number of hydrogen bond acceptors, as reported by RDKit `CalcNumHBA`."""

    def call(self, mol: chem.Mol) -> int:
        return rdMolDescriptors.CalcNumHBA(mol)
104
+
105
+
106
@keras.saving.register_keras_serializable(package='molcraft')
class NumRotatableBonds(Descriptor):
    """Number of rotatable bonds, as reported by RDKit
    `CalcNumRotatableBonds`."""

    def call(self, mol: chem.Mol) -> int:
        return rdMolDescriptors.CalcNumRotatableBonds(mol)
110
+
111
+
112
@keras.saving.register_keras_serializable(package='molcraft')
class NumRings(Descriptor):
    """Number of rings, as reported by RDKit `CalcNumRings`."""

    def call(self, mol: chem.Mol) -> int:
        return rdMolDescriptors.CalcNumRings(mol)
116
+
117
+
118
@keras.saving.register_keras_serializable(package='molcraft')
class NumAromaticRings(Descriptor):
    """Number of aromatic rings, as reported by RDKit
    `CalcNumAromaticRings`."""

    def call(self, mol: chem.Mol) -> int:
        return rdMolDescriptors.CalcNumAromaticRings(mol)
122
+
123
+
124
@keras.saving.register_keras_serializable(package='molcraft')
class AtomCount(Descriptor):
    """Number of atoms whose element symbol equals `atom_type`.

    Args:
        atom_type:
            Symbol compared against `atom.GetSymbol()` for every atom in
            `mol.atoms`.
        **kwargs:
            Forwarded to the base feature constructor (`vocab`, `allow_oov`,
            `encode_oov`, `dtype`).
    """

    def __init__(self, atom_type: str, **kwargs) -> None:
        super().__init__(**kwargs)
        self.atom_type = atom_type

    def call(self, mol: chem.Mol) -> int:
        return sum(
            1 for atom in mol.atoms if atom.GetSymbol() == self.atom_type
        )

    def get_config(self) -> dict:
        config = super().get_config()
        # Needed so `from_config` can rebuild the instance via `cls(**config)`.
        config['atom_type'] = self.atom_type
        return config
142
+
143
+
144
@keras.saving.register_keras_serializable(package='molcraft')
class ForceFieldEnergy(Descriptor3D):
    """Universal Force Field (UFF) energy.

    Delegates to `chem.conformer_energies(mol, method="UFF")`. Presumably one
    value per embedded conformer; TODO confirm against `chem`.
    """

    def call(self, mol: chem.Mol) -> np.ndarray:
        energies = chem.conformer_energies(mol, method="UFF")
        return energies
149
+
molcraft/features.py ADDED
@@ -0,0 +1,379 @@
1
+ import warnings
2
+ import abc
3
+ import math
4
+ import keras
5
+ import numpy as np
6
+
7
+ from molcraft import chem
8
+
9
+
10
@keras.saving.register_keras_serializable(package='molcraft')
class Feature(abc.ABC):
    """Base class for atom- and bond-level features of a `chem.Mol`.

    Subclasses implement `call`, which returns one raw value per atom or per
    bond. `__call__` encodes every raw value and stacks the encodings into an
    array of shape `(num_items, output_dim)`:

    - with a vocabulary, each value is one-hot encoded against it;
    - without one, each value must be numeric, bool or None and is encoded
      as a single floating point number.

    Args:
        vocab:
            Categorical values to one-hot encode against. If empty or None,
            the `default_vocabulary` entry keyed by the class name is used,
            if any. Sets are sorted for a deterministic order; other
            iterables keep their order. The given object is never mutated.
        allow_oov:
            Whether values missing from the vocabulary are allowed. If False,
            encoding such a value raises a ValueError.
        encode_oov:
            Whether to append a dedicated `<oov>` token to the vocabulary.
            Without it, allowed out-of-vocabulary values encode as all zeros.
        dtype:
            The numpy dtype of the produced encodings.
    """

    def __init__(
        self,
        vocab: set[int | str] | list[int | str] | None = None,
        allow_oov: bool = True,
        encode_oov: bool = False,
        dtype: str = 'float32'
    ) -> None:
        self.encode_oov = encode_oov
        self.allow_oov = allow_oov
        self.oov_token = '<oov>'
        self.dtype = dtype
        if not vocab:
            vocab = default_vocabulary.get(self.name, None)
        if vocab:
            if isinstance(vocab, set):
                vocab = sorted(vocab, key=lambda x: x if x is not None else "")
            else:
                # Always copy: the OOV token may be appended below, and
                # `vocab` may be the caller's list or a shared entry of the
                # module-level `default_vocabulary`, which must not change.
                vocab = list(vocab)
            if self.encode_oov and self.oov_token not in vocab:
                vocab.append(self.oov_token)
            onehot_encodings = np.eye(len(vocab), dtype=self.dtype)
            self.feature_to_onehot = dict(zip(vocab, onehot_encodings))
        self.vocab = vocab

    @abc.abstractmethod
    def call(self, mol: chem.Mol) -> list[float | int | bool | str]:
        """Returns one raw feature value per atom or per bond of `mol`."""

    def __call__(self, mol: chem.Mol) -> np.ndarray:
        """Computes and encodes the feature for every atom or bond of `mol`.

        Returns:
            An array of shape `(len(self.call(mol)), self.output_dim)`.

        Raises:
            TypeError: If `mol` is not a `chem.Mol`.
            ValueError: If `call` returns neither one value per atom nor one
                value per bond, or if a value cannot be encoded.
        """
        if not isinstance(mol, chem.Mol):
            raise TypeError(f'Input to {self.name} must be a `chem.Mol` object.')
        features = self.call(mol)
        if len(features) != mol.num_atoms and len(features) != mol.num_bonds:
            raise ValueError(
                f'The number of features computed by {self.name} does not '
                'match the number of atoms or bonds of the `chem.Mol` object. '
                'Make sure to iterate over `atoms` or `bonds` of the `chem.Mol` '
                'object when computing features.'
            )
        if len(features) == 0:
            # Edge case: no atoms or bonds in the molecule.
            return np.zeros((0, self.output_dim), dtype=self.dtype)

        func = (
            self._featurize_categorical if self.vocab else
            self._featurize_floating
        )
        return np.stack([func(x) for x in features])

    def get_config(self) -> dict:
        config = {
            'vocab': self.vocab,
            'allow_oov': self.allow_oov,
            'encode_oov': self.encode_oov,
            'dtype': self.dtype
        }
        return config

    @classmethod
    def from_config(cls, config: dict) -> 'Feature':
        return cls(**config)

    @property
    def name(self) -> str:
        # Also the lookup key into `default_vocabulary`.
        return self.__class__.__name__

    @property
    def output_dim(self) -> int:
        return 1 if not self.vocab else len(self.vocab)

    def _featurize_categorical(self, feature: str | int) -> np.ndarray:
        """One-hot encodes `feature`; OOV values map to the `<oov>` slot if
        present, otherwise to all zeros (or raise if OOV is not allowed)."""
        encoding = self.feature_to_onehot.get(feature, None)
        if encoding is not None:
            return encoding
        if not self.allow_oov:
            raise ValueError(
                f'{feature} could not be encoded, as it was not found in `vocab`. '
                'To allow OOV features, set `allow_oov` or `encode_oov` to True.'
            )
        oov_encoding = self.feature_to_onehot.get(self.oov_token, None)
        if oov_encoding is None:
            oov_encoding = np.zeros([self.output_dim], dtype=self.dtype)
        return oov_encoding

    def _featurize_floating(self, value: float | int | bool | None) -> np.ndarray:
        """Encodes one numeric value as a 1-element array.

        `None` and non-finite values are replaced by 0 with a warning. NumPy
        scalar types are accepted alongside Python numbers.
        """
        if value is None:
            warnings.warn(
                f'Found value of {self.name} to be `None`. '
                'Converting it to a value of 0.',
            )
            return np.zeros([1], dtype=self.dtype)
        if not isinstance(
            value, (int, float, bool, np.integer, np.floating, np.bool_)
        ):
            raise ValueError(
                f'{self.name} produced a value of type {type(value)}. '
                'If it represents a categorical feature, please provide a `vocab` '
                'to the constructor. If it represents a floating point feature, '
                'please make sure its `call` method returns a list of values of '
                'type `float`, `int`, `bool` or `None`.'
            )
        if not math.isfinite(value):
            warnings.warn(
                f'Found value of {self.name} to be non-finite. '
                f'Value received: {value}. Converting it to a value of 0.',
            )
            value = 0.0
        return np.asarray([value], dtype=self.dtype)
115
+
116
+
117
@keras.saving.register_keras_serializable(package='molcraft')
class AtomType(Feature):
    """Element symbol of each atom (`atom.GetSymbol()`)."""

    def call(self, mol: chem.Mol) -> list[str]:
        return [atom.GetSymbol() for atom in mol.atoms]
121
+
122
+
123
@keras.saving.register_keras_serializable(package='molcraft')
class Degree(Feature):
    """Degree of each atom (`atom.GetDegree()`)."""

    def call(self, mol: chem.Mol) -> list[int]:
        return [atom.GetDegree() for atom in mol.atoms]
127
+
128
+
129
@keras.saving.register_keras_serializable(package='molcraft')
class NumHydrogens(Feature):
    """Total number of hydrogens on each atom (`atom.GetTotalNumHs()`)."""

    def call(self, mol: chem.Mol) -> list[int]:
        return [atom.GetTotalNumHs() for atom in mol.atoms]
133
+
134
+
135
@keras.saving.register_keras_serializable(package='molcraft')
class Valence(Feature):
    """Total valence of each atom (`atom.GetTotalValence()`)."""

    def call(self, mol: chem.Mol) -> list[int]:
        return [atom.GetTotalValence() for atom in mol.atoms]
139
+
140
+
141
@keras.saving.register_keras_serializable(package='molcraft')
class AtomicWeight(Feature):
    """Atomic weight of each atom, looked up in the periodic table by symbol."""

    def call(self, mol: chem.Mol) -> list[float]:
        pt = chem.get_periodic_table()
        return [pt.GetAtomicWeight(atom.GetSymbol()) for atom in mol.atoms]
146
+
147
+
148
@keras.saving.register_keras_serializable(package='molcraft')
class Hybridization(Feature):
    """Lower-cased hybridization name of each atom, e.g. 'sp3'."""

    def call(self, mol: chem.Mol) -> list[str]:
        return [str(atom.GetHybridization()).lower() for atom in mol.atoms]
152
+
153
+
154
@keras.saving.register_keras_serializable(package='molcraft')
class CIPCode(Feature):
    """CIP label of each atom (`_CIPCode` property), or the string "None"
    when the property is not set."""

    def call(self, mol: chem.Mol) -> list[str]:
        return [
            atom.GetProp("_CIPCode") if atom.HasProp("_CIPCode") else "None"
            for atom in mol.atoms
        ]
160
+
161
+
162
@keras.saving.register_keras_serializable(package='molcraft')
class RingSize(Feature):
    """Smallest ring size each atom belongs to, or -1 for acyclic atoms."""

    def call(self, mol: chem.Mol) -> list[int]:
        def ring_size(atom) -> int:
            if not atom.IsInRing():
                return -1
            # Probe increasing sizes; the first hit is the smallest ring.
            size = 3
            while not atom.IsInRingSize(size):
                size += 1
            return size
        return [ring_size(atom) for atom in mol.atoms]
173
+
174
+
175
@keras.saving.register_keras_serializable(package='molcraft')
class FormalCharge(Feature):
    """Formal charge of each atom (`atom.GetFormalCharge()`)."""

    def call(self, mol: chem.Mol) -> list[int]:
        return [atom.GetFormalCharge() for atom in mol.atoms]
179
+
180
+
181
@keras.saving.register_keras_serializable(package='molcraft')
class IsChiralityPossible(Feature):
    """Whether each atom carries the `_ChiralityPossible` property."""

    def call(self, mol: chem.Mol) -> list[bool | int]:
        return [atom.HasProp("_ChiralityPossible") for atom in mol.atoms]
185
+
186
+
187
@keras.saving.register_keras_serializable(package='molcraft')
class NumRadicalElectrons(Feature):
    """Number of radical electrons on each atom."""

    def call(self, mol: chem.Mol) -> list[int]:
        return [atom.GetNumRadicalElectrons() for atom in mol.atoms]
191
+
192
+
193
@keras.saving.register_keras_serializable(package='molcraft')
class IsAromatic(Feature):
    """Aromaticity flag of each atom (`atom.GetIsAromatic()`)."""

    def call(self, mol: chem.Mol) -> list[bool]:
        return [atom.GetIsAromatic() for atom in mol.atoms]
197
+
198
+
199
@keras.saving.register_keras_serializable(package='molcraft')
class IsHeteroatom(Feature):
    """Per-atom heteroatom values, as computed by `chem.hetero_atoms`."""

    def call(self, mol: chem.Mol) -> list[float | int | bool | str]:
        return chem.hetero_atoms(mol)
203
+
204
+
205
@keras.saving.register_keras_serializable(package='molcraft')
class IsHydrogenDonor(Feature):
    """Per-atom hydrogen donor values, as computed by `chem.hydrogen_donors`."""

    def call(self, mol: chem.Mol) -> list[float | int | bool | str]:
        return chem.hydrogen_donors(mol)
209
+
210
+
211
@keras.saving.register_keras_serializable(package='molcraft')
class IsHydrogenAcceptor(Feature):
    """Per-atom hydrogen acceptor values, as computed by
    `chem.hydrogen_acceptors`."""

    def call(self, mol: chem.Mol) -> list[float | int | bool | str]:
        return chem.hydrogen_acceptors(mol)
215
+
216
+
217
@keras.saving.register_keras_serializable(package='molcraft')
class IsInRing(Feature):
    """Whether each atom is in a ring (`atom.IsInRing()`)."""

    def call(self, mol: chem.Mol) -> list[bool]:
        return [atom.IsInRing() for atom in mol.atoms]
221
+
222
+
223
@keras.saving.register_keras_serializable(package='molcraft')
class PartialCharge(Feature):
    """Gasteiger partial charge of each atom (via `chem.partial_charges`)."""

    def call(self, mol: chem.Mol) -> list[float | int | bool | str]:
        return chem.partial_charges(mol)
228
+
229
+
230
@keras.saving.register_keras_serializable(package='molcraft')
class TotalPolarSurfaceAreaContribution(Feature):
    """Total polar surface area (TPSA) contribution of each atom (via
    `chem.total_polar_surface_area_contributions`)."""

    def call(self, mol: chem.Mol) -> list[float | int | bool | str]:
        return chem.total_polar_surface_area_contributions(mol)
235
+
236
+
237
@keras.saving.register_keras_serializable(package='molcraft')
class AccessibleSurfaceAreaContribution(Feature):
    """Labute accessible surface area (ASA) contribution of each atom (via
    `chem.accessible_surface_area_contributions`)."""

    def call(self, mol: chem.Mol) -> list[float | int | bool | str]:
        return chem.accessible_surface_area_contributions(mol)
242
+
243
+
244
@keras.saving.register_keras_serializable(package='molcraft')
class LogPContribution(Feature):
    """Crippen logP contribution of each atom (via `chem.logp_contributions`)."""

    def call(self, mol: chem.Mol) -> list[float | int | bool | str]:
        return chem.logp_contributions(mol)
249
+
250
+
251
@keras.saving.register_keras_serializable(package='molcraft')
class MolarRefractivityContribution(Feature):
    """Crippen molar refractivity contribution of each atom (via
    `chem.molar_refractivity_contributions`)."""

    def call(self, mol: chem.Mol) -> list[float | int | bool | str]:
        return chem.molar_refractivity_contributions(mol)
256
+
257
+
258
@keras.saving.register_keras_serializable(package='molcraft')
class BondType(Feature):
    """Lower-cased bond type name of each bond, e.g. 'single'."""

    def call(self, mol: chem.Mol) -> list[str]:
        return [str(bond.GetBondType()).lower() for bond in mol.bonds]
262
+
263
+
264
@keras.saving.register_keras_serializable(package='molcraft')
class Stereo(Feature):
    """Stereo label of each bond: `str(bond.GetStereo())` with the 'STEREO'
    prefix removed and capitalized (e.g. 'E', 'Z', 'Any', 'None')."""

    def call(self, mol: chem.Mol) -> list[str]:
        return [
            str(bond.GetStereo()).replace('STEREO', '').capitalize()
            for bond in mol.bonds
        ]
271
+
272
+
273
@keras.saving.register_keras_serializable(package='molcraft')
class IsConjugated(Feature):
    """Conjugation flag of each bond (`bond.GetIsConjugated()`)."""

    def call(self, mol: chem.Mol) -> list[bool]:
        return [bond.GetIsConjugated() for bond in mol.bonds]
277
+
278
+
279
@keras.saving.register_keras_serializable(package='molcraft')
class IsRotatable(Feature):
    """Per-bond rotatability values, as computed by `chem.rotatable_bonds`."""

    def call(self, mol: chem.Mol) -> list[float | int | bool | str]:
        return chem.rotatable_bonds(mol)
283
+
284
+
285
@keras.saving.register_keras_serializable(package='molcraft')
class PairFeature(Feature):
    """Base class for features defined on ordered atom pairs.

    `call` must return one raw value per ordered pair of atoms, i.e.
    `mol.num_atoms ** 2` values ordered (0, 0), (0, 1), ..., (N-1, N-1).
    Values are encoded like regular `Feature` values, giving an array of
    shape `(num_atoms ** 2, output_dim)`.
    """

    def __call__(self, mol: chem.Mol) -> np.ndarray:
        if not isinstance(mol, chem.Mol):
            raise TypeError(f'Input to {self.name} must be a `chem.Mol` instance.')
        features = self.call(mol)
        if len(features) != mol.num_atoms ** 2:
            raise ValueError(
                f'The number of features computed by {self.name} does not '
                'match the number of node/atom pairs in the `chem.Mol` object. '
                f'Make sure the list of items returned by {self.name}(input) '
                'correspond to node/atom pairs: '
                '[(0, 0), (0, 1), ..., (0, N-1), (1, 0), ... (N-1, N-1)], '
                'where N denotes the number of nodes/atoms.'
            )
        if len(features) == 0:
            # Molecule without atoms: keep the 2D `(0, output_dim)` shape,
            # as `Feature.__call__` does, instead of a 1D `(0,)` array.
            return np.zeros((0, self.output_dim), dtype=self.dtype)
        func = (
            self._featurize_categorical if self.vocab else
            self._featurize_floating
        )
        return np.asarray([func(x) for x in features], dtype=self.dtype)
306
+
307
+
308
@keras.saving.register_keras_serializable(package='molcraft')
class PairDistance(PairFeature):
    """Distance between every ordered pair of atoms, one-hot encoded.

    Distances come from `chem.get_distances(mol)`, flattened with
    `reshape(-1)` and converted with `int`.

    Args:
        max_distance:
            Used only when no `vocab` is given: the vocabulary becomes
            `0..max_distance` (inclusive). Defaults to 10.
        allow_oov:
            Whether distances outside the vocabulary are allowed.
        encode_oov:
            Whether such distances get a dedicated `<oov>` slot (default)
            rather than an all-zero encoding.
        **kwargs:
            Forwarded to the base constructor (e.g. `vocab`, `dtype`).
    """

    def __init__(
        self,
        max_distance: int | None = None,
        allow_oov: bool = True,
        encode_oov: bool = True,
        **kwargs,
    ) -> None:
        # `vocab` arrives via kwargs when restoring from `get_config`.
        vocab = kwargs.pop('vocab', None)
        if not vocab:
            if max_distance is None:
                max_distance = 10
            vocab = list(range(max_distance + 1))
        super().__init__(
            vocab=vocab,
            allow_oov=allow_oov,
            encode_oov=encode_oov,
            **kwargs
        )

    def call(self, mol: chem.Mol) -> list[int]:
        return [int(x) for x in chem.get_distances(mol).reshape(-1)]
332
+
333
+
334
# Default vocabularies for categorical features, keyed by Feature class name:
# `Feature.__init__` looks up `default_vocabulary.get(self.name)` when no
# `vocab` is passed; features without an entry are encoded as floats.
# NOTE(review): the keys 'TotalNumHs' and 'TotalValence' do not match the
# `NumHydrogens` / `Valence` class names in this module, so those features
# currently get no default vocabulary. Confirm whether that is intended.
default_vocabulary = dict(
    # '*' (dummy atom) followed by the elements in atomic-number order, so
    # each symbol's index equals its atomic number.
    AtomType="""
        * H He Li Be B C N O F Ne Na
        Mg Al Si P S Cl Ar K Ca Sc Ti V
        Cr Mn Fe Co Ni Cu Zn Ga Ge As Se
        Br Kr Rb Sr Y Zr Nb Mo Tc Ru Rh
        Pd Ag Cd In Sn Sb Te I Xe Cs Ba
        La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho
        Er Tm Yb Lu Hf Ta W Re Os Ir Pt
        Au Hg Tl Pb Bi Po At Rn Fr Ra Ac
        Th Pa U Np Pu Am Cm Bk Cf Es Fm
        Md No Lr Rf Db Sg Bh Hs Mt Ds Rg
        Cn Nh Fl Mc Lv Ts Og
    """.split(),
    Degree=list(range(9)),
    TotalNumHs=list(range(5)),
    TotalValence=list(range(9)),
    Hybridization=["s", "sp", "sp2", "sp3", "sp3d", "sp3d2", "unspecified"],
    CIPCode=["R", "S", "None"],
    FormalCharge=list(range(-3, 4)),
    NumRadicalElectrons=list(range(5)),
    # -1 is what `RingSize` reports for atoms outside any ring.
    RingSize=[-1, *range(3, 9)],
    BondType=["zero", "single", "double", "triple", "aromatic"],
    Stereo=["E", "Z", "Any", "None"],
)
379
+