hjxdl 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/__init__.py +0 -0
- hdl/_version.py +16 -0
- hdl/args/__init__.py +0 -0
- hdl/args/loss_args.py +5 -0
- hdl/controllers/__init__.py +0 -0
- hdl/controllers/al/__init__.py +0 -0
- hdl/controllers/al/al.py +0 -0
- hdl/controllers/al/dispatcher.py +0 -0
- hdl/controllers/al/feedback.py +0 -0
- hdl/controllers/explain/__init__.py +0 -0
- hdl/controllers/explain/shapley.py +293 -0
- hdl/controllers/explain/subgraphx.py +865 -0
- hdl/controllers/train/__init__.py +0 -0
- hdl/controllers/train/rxn_train.py +219 -0
- hdl/controllers/train/train.py +50 -0
- hdl/controllers/train/train_ginet.py +316 -0
- hdl/controllers/train/trainer_base.py +155 -0
- hdl/controllers/train/trainer_iterative.py +389 -0
- hdl/data/__init__.py +0 -0
- hdl/data/dataset/__init__.py +0 -0
- hdl/data/dataset/base_dataset.py +98 -0
- hdl/data/dataset/fp/__init__.py +0 -0
- hdl/data/dataset/fp/fp_dataset.py +122 -0
- hdl/data/dataset/graph/__init__.py +0 -0
- hdl/data/dataset/graph/chiral.py +62 -0
- hdl/data/dataset/graph/gin.py +255 -0
- hdl/data/dataset/graph/molnet.py +362 -0
- hdl/data/dataset/loaders/__init__.py +0 -0
- hdl/data/dataset/loaders/chiral_graph.py +71 -0
- hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
- hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
- hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
- hdl/data/dataset/loaders/general.py +23 -0
- hdl/data/dataset/loaders/spliter.py +86 -0
- hdl/data/dataset/samplers/__init__.py +0 -0
- hdl/data/dataset/samplers/chiral.py +19 -0
- hdl/data/dataset/seq/__init__.py +0 -0
- hdl/data/dataset/seq/rxn_dataset.py +61 -0
- hdl/data/dataset/utils.py +31 -0
- hdl/data/to_mols.py +0 -0
- hdl/features/__init__.py +0 -0
- hdl/features/fp/__init__.py +0 -0
- hdl/features/fp/features_generators.py +235 -0
- hdl/features/graph/__init__.py +0 -0
- hdl/features/graph/featurization.py +297 -0
- hdl/features/utils/__init__.py +0 -0
- hdl/features/utils/utils.py +111 -0
- hdl/layers/__init__.py +0 -0
- hdl/layers/general/__init__.py +0 -0
- hdl/layers/general/gp.py +14 -0
- hdl/layers/general/linear.py +641 -0
- hdl/layers/graph/__init__.py +0 -0
- hdl/layers/graph/chiral_graph.py +230 -0
- hdl/layers/graph/gcn.py +16 -0
- hdl/layers/graph/gin.py +45 -0
- hdl/layers/graph/tetra.py +158 -0
- hdl/layers/graph/transformer.py +188 -0
- hdl/layers/sequential/__init__.py +0 -0
- hdl/metric_loss/__init__.py +0 -0
- hdl/metric_loss/loss.py +79 -0
- hdl/metric_loss/metric.py +178 -0
- hdl/metric_loss/multi_label.py +42 -0
- hdl/metric_loss/nt_xent.py +65 -0
- hdl/models/__init__.py +0 -0
- hdl/models/chiral_gnn.py +176 -0
- hdl/models/fast_transformer.py +234 -0
- hdl/models/ginet.py +189 -0
- hdl/models/linear.py +137 -0
- hdl/models/model_dict.py +18 -0
- hdl/models/norm_flows.py +33 -0
- hdl/models/optim_dict.py +16 -0
- hdl/models/rxn.py +63 -0
- hdl/models/utils.py +83 -0
- hdl/ops/__init__.py +0 -0
- hdl/ops/utils.py +42 -0
- hdl/optims/__init__.py +0 -0
- hdl/optims/nadam.py +86 -0
- hdl/utils/__init__.py +0 -0
- hdl/utils/chemical_tools/__init__.py +2 -0
- hdl/utils/chemical_tools/query_info.py +149 -0
- hdl/utils/chemical_tools/sdf.py +20 -0
- hdl/utils/database_tools/__init__.py +0 -0
- hdl/utils/database_tools/connect.py +28 -0
- hdl/utils/general/__init__.py +0 -0
- hdl/utils/general/glob.py +21 -0
- hdl/utils/schedulers/__init__.py +0 -0
- hdl/utils/schedulers/norm_lr.py +108 -0
- hjxdl-0.0.1.dist-info/METADATA +19 -0
- hjxdl-0.0.1.dist-info/RECORD +91 -0
- hjxdl-0.0.1.dist-info/WHEEL +5 -0
- hjxdl-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,61 @@
|
|
1
|
+
import typing as t
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
import torch
|
5
|
+
import pkg_resources
|
6
|
+
from rxnfp.tokenization import (
|
7
|
+
SmilesTokenizer,
|
8
|
+
# convert_reaction_to_valid_features_batch,
|
9
|
+
convert_reaction_to_valid_features,
|
10
|
+
)
|
11
|
+
|
12
|
+
from ..base_dataset import CSVDataset
|
13
|
+
|
14
|
+
|
15
|
+
class RXNCSVDataset(CSVDataset):
    """CSV-backed dataset of reaction SMILES tokenized with the rxnfp tokenizer.

    Each item is ``X = [input_ids, input_mask, segment_ids]`` (int64 tensors
    built from ``convert_reaction_to_valid_features``). When target columns
    are configured, the item is ``(X, y)`` where ``y`` is a ``torch.LongTensor``
    of that row's target values.
    """

    def __init__(
        self,
        csv_file: str,
        max_len: int = 512,
        vocab_path: t.Optional[str] = None,
        splitter: str = ',',
        smiles_col: str = 'SMILES',
        target_cols: t.Optional[t.List] = None,
        **kwargs,
    ) -> None:
        """
        :param csv_file: Path of the CSV file, forwarded to ``CSVDataset``.
        :param max_len: Maximum token sequence length given to ``SmilesTokenizer``.
        :param vocab_path: Tokenizer vocabulary file. ``None`` selects the
            ``models/transformers/bert_ft/vocab.txt`` resource shipped with ``rxnfp``.
        :param splitter: Field delimiter, forwarded to ``CSVDataset``.
        :param smiles_col: Name of the column holding the reaction SMILES.
        :param target_cols: Label column names; ``None`` or empty means unlabeled.
        :param kwargs: Extra keyword arguments forwarded to ``CSVDataset``.
        """
        # Fresh list per instance: a shared ``[]`` default would be aliased by
        # every dataset created without explicit targets.
        if target_cols is None:
            target_cols = []
        super().__init__(
            csv_file,
            splitter=splitter,
            smiles_col=smiles_col,
            target_cols=target_cols,
            **kwargs
        )
        if vocab_path is None:
            vocab_path = pkg_resources.resource_filename(
                "rxnfp",
                "models/transformers/bert_ft/vocab.txt"
            )
        self.tokenizer = SmilesTokenizer(
            vocab_path, max_len=max_len
        )

    def __getitem__(self, index):
        """Tokenize the reaction at row label ``index``.

        :param index: Row label looked up with ``self.df.loc``.
        :return: ``X`` (list of three int64 tensors), or ``(X, y)`` when
            ``self.target_cols`` contains a truthy column name.
        """
        row = self.df.loc[index]  # single lookup instead of repeated chained .loc
        feats = convert_reaction_to_valid_features(
            row[self.smiles_col],
            self.tokenizer
        )
        X = [
            torch.tensor(getattr(feats, field).astype(np.int64))
            for field in ('input_ids', 'input_mask', 'segment_ids')
        ]
        if any(self.target_cols):
            y = torch.LongTensor(row[self.target_cols].tolist())
            return X, y
        return X
|
61
|
+
|
@@ -0,0 +1,31 @@
|
|
1
|
+
import typing as t
|
2
|
+
import csv
|
3
|
+
|
4
|
+
|
5
|
+
|
6
|
+
def read_smiles(
    data_path: str,
    file_type: str = 'smi',
    smi_col_names: t.Sequence[str] = (),
    y_col_name: t.Optional[str] = None,
):
    """Read SMILES strings (and optionally a label) from a ``.smi`` or CSV file.

    :param data_path: Path of the file to read.
    :param file_type: ``'smi'``: comma-delimited rows whose *last* field is the
        SMILES (blank lines are skipped). ``'csv'``: a headed CSV read by
        column name.
    :param smi_col_names: For ``'csv'``, the SMILES column names to extract, in
        order. If empty, nothing is read in ``'csv'`` mode.
    :param y_col_name: For ``'csv'``, an optional label column appended to each
        row's list. ``None`` (the default) means no label column.
    :return: For ``'smi'``, a list of SMILES strings. For ``'csv'``, a list of
        per-row lists ``[smi_1, ..., smi_k]`` (plus the label when requested).
        All values are strings as read by ``csv``. Any other ``file_type``, or
        ``'csv'`` without column names, yields ``[]``.
    """
    smiles_data = []
    if file_type == 'smi':
        # newline='' is the mode the csv module requires for correct line handling.
        with open(data_path, newline='') as smi_file:
            for row in csv.reader(smi_file, delimiter=','):
                if not row:  # csv.reader yields [] for blank lines
                    continue
                smiles_data.append(row[-1])
    elif file_type == 'csv' and any(smi_col_names):
        with open(data_path, 'r', newline='') as csv_file:
            for line in csv.DictReader(csv_file):
                entry = [line[name] for name in smi_col_names]
                if y_col_name is not None:
                    entry.append(line[y_col_name])
                smiles_data.append(entry)
    return smiles_data
|
hdl/data/to_mols.py
ADDED
File without changes
|
hdl/features/__init__.py
ADDED
File without changes
|
File without changes
|
@@ -0,0 +1,235 @@
|
|
1
|
+
from typing import Callable, List, Union
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
from rdkit import Chem, DataStructs
|
5
|
+
from rdkit.Chem import AllChem
|
6
|
+
from rdkit.Chem import MACCSkeys
|
7
|
+
|
8
|
+
|
9
|
+
# Accepted molecule inputs: a SMILES string or an already-parsed RDKit Mol.
Molecule = Union[str, Chem.Mol]
# Signature shared by all registered generators: molecule in, feature vector out.
FeaturesGenerator = Callable[[Molecule], np.ndarray]

# Name -> generator; populated by the register_features_generator decorator.
FEATURES_GENERATOR_REGISTRY: dict = dict()
|
14
|
+
|
15
|
+
|
16
|
+
# Output vector length for each named fingerprint / descriptor generator.
FP_BITS_DICT = dict(
    maccs=167,
    morgan=2048,
    morgan_count=2048,
    rdkit_2d_normalized=200,
    rdkit_2d=200,
)
|
23
|
+
|
24
|
+
|
25
|
+
def register_features_generator(features_generator_name: str) -> Callable[[FeaturesGenerator], FeaturesGenerator]:
    """
    Build a decorator that files a features generator in the global registry.

    :param features_generator_name: Key under which the decorated function is stored in
                                    FEATURES_GENERATOR_REGISTRY (an existing entry with the
                                    same key is replaced).
    :return: A decorator that registers the function it wraps and hands it back unchanged.
    """
    def _register(generator_fn: FeaturesGenerator) -> FeaturesGenerator:
        FEATURES_GENERATOR_REGISTRY.update({features_generator_name: generator_fn})
        return generator_fn

    return _register
|
37
|
+
|
38
|
+
|
39
|
+
def get_features_generator(features_generator_name: str) -> FeaturesGenerator:
    """
    Look up a registered features generator by name.

    :param features_generator_name: Registry key of the generator.
    :return: The generator function registered under that key.
    :raises ValueError: If no generator is registered under that key.
    """
    registry = FEATURES_GENERATOR_REGISTRY
    if features_generator_name in registry:
        return registry[features_generator_name]

    raise ValueError(f'Features generator "{features_generator_name}" could not be found. '
                     f'If this generator relies on rdkit features, you may need to install descriptastorus.')
|
51
|
+
|
52
|
+
|
53
|
+
def get_available_features_generators() -> List[str]:
    """Return the registered features-generator names, in registration order."""
    return [*FEATURES_GENERATOR_REGISTRY]
|
56
|
+
|
57
|
+
|
58
|
+
# Defaults for the Morgan generators: radius 2 and a 2048-bit vector.
MORGAN_RADIUS, MORGAN_NUM_BITS = 2, 2048
|
60
|
+
|
61
|
+
|
62
|
+
@register_features_generator('morgan')
def morgan_binary_features_generator(mol: Molecule,
                                     radius: int = MORGAN_RADIUS,
                                     num_bits: int = MORGAN_NUM_BITS) -> np.ndarray:
    """
    Generates a binary Morgan fingerprint for a molecule.

    :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
    :param radius: Morgan fingerprint radius.
    :param num_bits: Number of bits in Morgan fingerprint.
    :return: A 1D numpy array containing the binary Morgan fingerprint.
    :raises ValueError: If ``mol`` is a SMILES string that RDKit cannot parse.
    """
    # isinstance also accepts str subclasses (e.g. numpy.str_ from arrays/DataFrames).
    if isinstance(mol, str):
        parsed = Chem.MolFromSmiles(mol)
        if parsed is None:  # MolFromSmiles signals a parse failure with None
            raise ValueError(f'Could not parse SMILES: {mol!r}')
        mol = parsed
    features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=num_bits)
    # The 1-element placeholder is resized by DataStructs.ConvertToNumpyArray.
    features = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(features_vec, features)

    return features
|
80
|
+
|
81
|
+
|
82
|
+
@register_features_generator('morgan_count')
def morgan_counts_features_generator(mol: Molecule,
                                     radius: int = MORGAN_RADIUS,
                                     num_bits: int = MORGAN_NUM_BITS) -> np.ndarray:
    """
    Generates a counts-based (hashed) Morgan fingerprint for a molecule.

    :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
    :param radius: Morgan fingerprint radius.
    :param num_bits: Number of bits in Morgan fingerprint.
    :return: A 1D numpy array containing the counts-based Morgan fingerprint.
    :raises ValueError: If ``mol`` is a SMILES string that RDKit cannot parse.
    """
    # isinstance also accepts str subclasses (e.g. numpy.str_ from arrays/DataFrames).
    if isinstance(mol, str):
        parsed = Chem.MolFromSmiles(mol)
        if parsed is None:  # MolFromSmiles signals a parse failure with None
            raise ValueError(f'Could not parse SMILES: {mol!r}')
        mol = parsed
    features_vec = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=num_bits)
    # The 1-element placeholder is resized by DataStructs.ConvertToNumpyArray.
    features = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(features_vec, features)

    return features
|
100
|
+
|
101
|
+
|
102
|
+
@register_features_generator('maccs')
def macss_features_generator(
    mol
) -> np.ndarray:
    """
    Generates the MACCS structural-key fingerprint for a molecule.

    :param mol: A molecule (either a SMILES string or an RDKit molecule).
    :return: A float numpy array of length 167 holding 1.0 at every set key and 0.0 elsewhere.
    :raises ValueError: If ``mol`` is a SMILES string that RDKit cannot parse.
    """
    # isinstance also accepts str subclasses (e.g. numpy.str_ from arrays/DataFrames).
    if isinstance(mol, str):
        parsed = Chem.MolFromSmiles(mol)
        if parsed is None:  # MolFromSmiles signals a parse failure with None
            raise ValueError(f'Could not parse SMILES: {mol!r}')
        mol = parsed
    vec = MACCSkeys.GenMACCSKeys(mol)
    bv = list(vec.GetOnBits())
    arr = np.zeros(167)  # same length as FP_BITS_DICT['maccs']
    arr[bv] = 1
    return arr
|
112
|
+
|
113
|
+
|
114
|
+
try:
    from descriptastorus.descriptors import rdDescriptors, rdNormalizedDescriptors

    @register_features_generator('rdkit_2d')
    def rdkit_2d_features_generator(mol: Molecule) -> np.ndarray:
        """
        Generates RDKit 2D features for a molecule.

        :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
        :return: The output of descriptastorus ``RDKit2D().process`` without its first
                 entry (presumably the "calculated" flag; confirm against descriptastorus).
                 The value is returned as produced by descriptastorus, not converted to numpy.
        :raises ValueError: If descriptastorus returns no result for the molecule.
        """
        smiles = mol if isinstance(mol, str) else Chem.MolToSmiles(mol, isomericSmiles=True)
        generator = rdDescriptors.RDKit2D()
        result = generator.process(smiles)
        # Guard so a failed computation surfaces as a clear error instead of
        # "'NoneType' object is not subscriptable".
        if result is None:
            raise ValueError(f'RDKit 2D descriptors could not be computed for {smiles!r}')
        features = result[1:]

        return features

    @register_features_generator('rdkit_2d_normalized')
    def rdkit_2d_normalized_features_generator(mol: Molecule) -> np.ndarray:
        """
        Generates RDKit 2D normalized features for a molecule.

        :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
        :return: The output of descriptastorus ``RDKit2DNormalized().process`` without its
                 first entry (presumably the "calculated" flag; confirm against descriptastorus).
                 The value is returned as produced by descriptastorus, not converted to numpy.
        :raises ValueError: If descriptastorus returns no result for the molecule.
        """
        smiles = mol if isinstance(mol, str) else Chem.MolToSmiles(mol, isomericSmiles=True)
        generator = rdNormalizedDescriptors.RDKit2DNormalized()
        result = generator.process(smiles)
        if result is None:
            raise ValueError(f'RDKit 2D normalized descriptors could not be computed for {smiles!r}')
        features = result[1:]

        return features
except ImportError:
    @register_features_generator('rdkit_2d')
    def rdkit_2d_features_generator(mol: Molecule) -> np.ndarray:
        """Mock implementation raising an ImportError if descriptastorus cannot be imported."""
        raise ImportError('Failed to import descriptastorus. Please install descriptastorus '
                          '(https://github.com/bp-kelley/descriptastorus) to use RDKit 2D features.')

    @register_features_generator('rdkit_2d_normalized')
    def rdkit_2d_normalized_features_generator(mol: Molecule) -> np.ndarray:
        """Mock implementation raising an ImportError if descriptastorus cannot be imported."""
        raise ImportError('Failed to import descriptastorus. Please install descriptastorus '
                          '(https://github.com/bp-kelley/descriptastorus) to use RDKit 2D normalized features.')
|
156
|
+
|
157
|
+
|
158
|
+
@register_features_generator('e3fp')
def e3fp_features_generator(mol: Molecule) -> np.ndarray:
    """
    E3FP is a 3D molecular fingerprinting method inspired by Extended Connectivity FingerPrints (ECFP).

    [LINK](https://pubs.acs.org/doi/10.1021/acs.jmedchem.7b00696)
    Axen SD, Huang XP, Caceres EL, Gendelev L, Roth BL, Keiser MJ.
    A Simple Representation Of Three-Dimensional Molecular Structure.
    J. Med. Chem. 60 (17): 7393–7409 (2017).

    The source code: https://github.com/keiserlab/e3fp

    Placeholder: the name ``'e3fp'`` is registered, but no implementation exists yet.

    :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
    :return: Intended: a 1D numpy array containing the E3FP fingerprint.
    :raises NotImplementedError: Always, until the generator is implemented.
    """
    # Raise rather than return the ``NotImplemented`` singleton: that sentinel is meant
    # for binary dunder methods and would silently flow into feature pipelines.
    raise NotImplementedError("The 'e3fp' features generator is not implemented yet.")
|
174
|
+
|
175
|
+
|
176
|
+
@register_features_generator('whales')
def whales_features_generator(mol: Molecule) -> np.ndarray:
    """
    WHALES (Weighted Holistic Atom Localization and Entity Shape) molecular descriptors.

    [LINK](https://www.nature.com/articles/s42004-018-0043-x)
    Francesca Grisoni, Daniel Merk, Viviana Consonni, Jan A. Hiss,
    Sara Giani Tagliabue, Roberto Todeschini & Gisbert Schneider
    "Scaffold hopping from natural products to synthetic mimetics by
    holistic molecular similarity", Nature Communications Chemistry 1, 44, 2018.

    The source code: https://github.com/grisoniFr/whales_descriptors

    Placeholder: the name ``'whales'`` is registered, but no implementation exists yet.

    :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
    :return: Intended: a numpy array containing the WHALES descriptors.
    :raises NotImplementedError: Always, until the generator is implemented.
    """
    # Raise rather than return the ``NotImplemented`` singleton: that sentinel is meant
    # for binary dunder methods and would silently flow into feature pipelines.
    raise NotImplementedError("The 'whales' features generator is not implemented yet.")
|
194
|
+
|
195
|
+
|
196
|
+
@register_features_generator('selfies')
def selfies_features_generator(mol) -> np.ndarray:
    """
    Self-Referencing Embedded Strings (SELFIES): a 100% robust molecular string representation.

    A main objective is to use SELFIES as direct input into machine learning models,
    in particular in generative models, for the generation of molecular graphs
    which are syntactically and semantically valid.

    [LINK](https://iopscience.iop.org/article/10.1088/2632-2153/aba947)
    Mario Krenn et al 2020 Mach. Learn.: Sci. Technol. 1 045024

    The source code: https://github.com/aspuru-guzik-group/selfies

    Placeholder: the name ``'selfies'`` is registered, but no implementation exists yet.

    :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
    :return: Intended: a 1D numpy array of the molecule's SELFIES symbols.
    :raises NotImplementedError: Always, until the generator is implemented.
    """
    # Raise rather than return the ``NotImplemented`` singleton: that sentinel is meant
    # for binary dunder methods and would silently flow into feature pipelines.
    raise NotImplementedError("The 'selfies' features generator is not implemented yet.")
|
214
|
+
|
215
|
+
|
216
|
+
"""
Custom features generator template.

Note: the name used to register a features generator is the name you pass
on the command line with the --features_generator <name> flag.
Example: python train.py ... --features_generator custom ...

@register_features_generator('custom')
def custom_features_generator(mol: Molecule) -> np.ndarray:
    # If you want to use the SMILES string
    smiles = mol if isinstance(mol, str) else Chem.MolToSmiles(mol, isomericSmiles=True)

    # If you want to use the RDKit molecule
    mol = Chem.MolFromSmiles(mol) if isinstance(mol, str) else mol

    # Replace this with code which generates features from the molecule
    features = np.array([0, 0, 1])

    return features
"""
|
File without changes
|
@@ -0,0 +1,297 @@
|
|
1
|
+
from argparse import Namespace
|
2
|
+
from typing import List, Tuple, Union
|
3
|
+
|
4
|
+
from rdkit import Chem
|
5
|
+
from rdkit.Chem.rdchem import ChiralType
|
6
|
+
|
7
|
+
import torch
|
8
|
+
|
9
|
+
|
10
|
+
# Atom feature sizes
|
11
|
+
# Atom feature vocabularies. atom_features() one-hot encodes each value against
# its list, with one extra trailing slot for values that are not in the list.
ATOMIC_SYMBOLS = ['H', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'Br', 'I']
CIP_CHIRALITY = ['R', 'S']
ATOM_FEATURES = dict(
    atomic_num=ATOMIC_SYMBOLS,  # matched against atom.GetSymbol(), despite the key name
    degree=[0, 1, 2, 3, 4, 5],
    formal_charge=[-1, -2, 1, 2, 0],
    chiral_tag=[0, 1, 2, 3],
    global_chiral_tag=CIP_CHIRALITY,
    num_Hs=[0, 1, 2, 3, 4],
    hybridization=[
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
    ],
)

# Bond feature vocabularies, one-hot encoded the same way by bond_features().
BOND_FEATURES = dict(
    bondtype=[
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    ],
    bondstereo=[
        Chem.rdchem.BondStereo.STEREONONE,
        Chem.rdchem.BondStereo.STEREOANY,
        Chem.rdchem.BondStereo.STEREOZ,
        Chem.rdchem.BondStereo.STEREOE,
    ],
)

# Tetrahedral parity per RDKit chiral tag: clockwise +1, counter-clockwise -1, else 0.
CHIRALTAG_PARITY = {
    ChiralType.CHI_TETRAHEDRAL_CW: +1,
    ChiralType.CHI_TETRAHEDRAL_CCW: -1,
    ChiralType.CHI_UNSPECIFIED: 0,
    ChiralType.CHI_OTHER: 0,  # default
}

# Atom width: every one-hot group (len(choices) + 1 slots each) plus 3 scalar slots
# (IsAromatic, IsInRing, scaled mass). This counts ALL groups in ATOM_FEATURES,
# including chiral_tag / global_chiral_tag, which atom_features() only appends
# when the matching flag is set.
ATOM_FDIM = sum(len(choices) for choices in ATOM_FEATURES.values()) + len(ATOM_FEATURES) + 3
# Bond width: both one-hot groups plus 3 slots (leading "no bond" flag,
# IsConjugated, IsInRing).
BOND_FDIM = sum(len(choices) for choices in BOND_FEATURES.values()) + len(BOND_FEATURES) + 3
|
52
|
+
|
53
|
+
|
54
|
+
def get_atom_fdim(chiral_features: bool = True, global_chiral_features: bool = True) -> int:
    """
    Gets the dimensionality of atom features.

    ``ATOM_FDIM`` counts every group in ``ATOM_FEATURES``, including the optional
    chiral groups that ``atom_features`` only appends when asked. The defaults keep the
    historical return value (``ATOM_FDIM``). Pass the same flags given to
    ``atom_features`` / ``MolGraph`` to get the length those calls actually produce.

    :param chiral_features: Whether the local chiral-tag one-hot is included.
    :param global_chiral_features: Whether the CIP (R/S) one-hot is included.
    :return: Length of an atom feature vector.
    """
    fdim = ATOM_FDIM
    # Each one-hot group occupies len(choices) + 1 slots (see onek_encoding_unk).
    if not chiral_features:
        fdim -= len(ATOM_FEATURES['chiral_tag']) + 1
    if not global_chiral_features:
        fdim -= len(ATOM_FEATURES['global_chiral_tag']) + 1
    return fdim
|
60
|
+
|
61
|
+
|
62
|
+
def get_bond_fdim() -> int:
    """
    Return the length of the vector produced by ``bond_features``.

    :return: The module constant ``BOND_FDIM``.
    """
    fdim = BOND_FDIM
    return fdim
|
68
|
+
|
69
|
+
|
70
|
+
def onek_encoding_unk(value, choices: List) -> List[int]:
    """
    One-hot encode ``value`` against ``choices`` with a trailing "unknown" slot.

    :param value: The value to encode.
    :param choices: The known values, in slot order.
    :return: A list of ``len(choices) + 1`` ints with a single 1 at the position of
             ``value`` in ``choices``, or in the last slot when ``value`` is absent.
    """
    try:
        hot = choices.index(value)
    except ValueError:
        hot = len(choices)  # unknown values go to the extra final slot
    return [int(slot == hot) for slot in range(len(choices) + 1)]
|
83
|
+
|
84
|
+
|
85
|
+
def atom_features(
    atom: Chem.rdchem.Atom,
    chiral_features: bool = False,
    global_chiral_features: bool = False
) -> List[Union[bool, int, float]]:
    """
    Build the feature vector for one atom.

    Layout: one-hot groups for symbol, total degree, formal charge, total H count and
    hybridization; then IsAromatic, IsInRing and mass * 0.01; then the optional groups.

    :param atom: An RDKit atom.
    :param chiral_features: Append a one-hot of ``int(atom.GetChiralTag())``.
    :param global_chiral_features: Append a one-hot of the atom's ``_CIPCode`` property
                                   (the "unknown" slot when the property is absent).
    :return: A list containing the atom features.
    """
    encoded = []
    for key, value in (
        ('atomic_num', atom.GetSymbol()),
        ('degree', atom.GetTotalDegree()),
        ('formal_charge', atom.GetFormalCharge()),
        ('num_Hs', int(atom.GetTotalNumHs())),
        ('hybridization', int(atom.GetHybridization())),
    ):
        encoded.extend(onek_encoding_unk(value, ATOM_FEATURES[key]))
    encoded.append(1 if atom.GetIsAromatic() else 0)
    encoded.append(1 if atom.IsInRing() else 0)
    encoded.append(atom.GetMass() * 0.01)  # scaled to about the same range as other features

    if chiral_features:
        encoded.extend(onek_encoding_unk(int(atom.GetChiralTag()), ATOM_FEATURES['chiral_tag']))
    if global_chiral_features:
        cip_code = atom.GetProp('_CIPCode') if atom.HasProp('_CIPCode') else None
        encoded.extend(onek_encoding_unk(cip_code, ATOM_FEATURES['global_chiral_tag']))
    return encoded
|
111
|
+
|
112
|
+
|
113
|
+
def parity_features(atom: Chem.rdchem.Atom) -> int:
    """
    Returns the tetrahedral parity of an atom.

    :param atom: An RDKit atom.
    :return: +1 for ``CHI_TETRAHEDRAL_CW``, -1 for ``CHI_TETRAHEDRAL_CCW``, and 0 for
             every other chiral tag (unspecified, other, or any tag absent from
             ``CHIRALTAG_PARITY``).
    """
    # .get with default 0: RDKit may define ChiralType members beyond the four mapped
    # in CHIRALTAG_PARITY. Those mean "no tetrahedral parity" here, not a KeyError.
    return CHIRALTAG_PARITY.get(atom.GetChiralTag(), 0)
|
120
|
+
|
121
|
+
|
122
|
+
def bond_features(bond: Chem.rdchem.Bond) -> List[Union[bool, int, float]]:
    """
    Build the feature vector for one bond.

    Layout (length ``get_bond_fdim()``): a leading "no bond" flag, a one-hot of the
    bond type, a one-hot of the bond stereo, then IsConjugated and IsInRing.
    For ``bond is None`` only the leading flag is set.

    :param bond: An RDKit bond, or ``None`` for a placeholder vector.
    :return: A list containing the bond features.
    """
    width = get_bond_fdim()
    if bond is None:
        return [1] + [0] * (width - 1)

    return (
        [0]
        + onek_encoding_unk(bond.GetBondType(), BOND_FEATURES['bondtype'])
        + onek_encoding_unk(bond.GetStereo(), BOND_FEATURES['bondstereo'])
        + [bond.GetIsConjugated(), bond.IsInRing()]
    )
|
142
|
+
|
143
|
+
|
144
|
+
class MolGraph:
    """
    A MolGraph represents the graph structure and featurization of a single molecule.

    A MolGraph computes the following attributes:
    - smiles: Smiles string.
    - n_atoms: Number of atoms, after explicit H were added on tagged tetrahedral centers.
    - n_bonds: Number of directed bonds (each chemical bond contributes two).
    - f_atoms: A mapping from an atom index to a list of atom features.
    - f_bonds: A mapping from a directed bond index to a list of bond features.
    - a2b: A mapping from an atom index to a list of incoming bond indices.
    - b2a: A mapping from a bond index to the index of the atom the bond originates from.
    - b2revb: A mapping from a bond index to the index of the reverse bond.
    - parity_atoms: Per-atom tetrahedral parity (+1 CW, -1 CCW, 0 otherwise).
    - edge_index: Directed (source, target) atom pairs; entry i describes bond i.
    - parity_bond_index: For each atom with non-zero parity (in atom order), the indices of
      its outgoing edges in ``edge_index``; the first two are swapped for CCW atoms.
    - a_scope / b_scope: ``[(0, n_atoms)]`` / ``[(0, n_bonds)]``, i.e. (start, size) of
      this single molecule's atoms / bonds.
    """

    def __init__(
        self,
        smiles: str,
        chiral_features: bool = False,
        global_chiral_features: bool = False
    ):
        """
        Computes the graph structure and featurization of a molecule.

        :param smiles: A SMILES string.
        :param chiral_features: Forwarded to ``atom_features``.
        :param global_chiral_features: Forwarded to ``atom_features``.
        :raises ValueError: If RDKit cannot parse ``smiles``.
        """
        self.smiles = smiles
        self.n_atoms = 0  # number of atoms
        self.n_bonds = 0  # number of directed bonds
        self.f_atoms = []  # mapping from atom index to atom features
        self.f_bonds = []  # mapping from bond index to bond features
        self.a2b = []  # mapping from atom index to incoming bond indices
        self.b2a = []  # mapping from bond index to the index of the atom the bond is coming from
        self.b2revb = []  # mapping from bond index to the index of the reverse bond
        self.parity_atoms = []  # mapping from atom index to CW (+1), CCW (-1) or undefined tetra (0)
        self.edge_index = []  # directed (source, target) pairs, aligned with bond indices
        self.parity_bond_index = []
        # Caches filled lazily by get_b2b() / get_a2a(); they must exist for the
        # "is None" checks in those methods.
        self.b2b = None
        self.a2a = None

        mol = Chem.MolFromSmiles(smiles)
        if mol is None:  # MolFromSmiles signals a parse failure with None
            raise ValueError(f'RDKit could not parse SMILES: {smiles!r}')

        # Add explicit hydrogens only on atoms carrying a CW/CCW tetrahedral tag.
        H_ids = [a.GetIdx() for a in mol.GetAtoms()
                 if CHIRALTAG_PARITY.get(a.GetChiralTag(), 0) != 0]
        if H_ids:
            mol = Chem.AddHs(mol, onlyOnAtoms=H_ids)

        # Remove the stereo label from tagged atoms that still lack exactly 4 neighbors,
        # so every atom with non-zero parity below has four outgoing edges.
        for i in H_ids:
            a = mol.GetAtomWithIdx(i)
            if len(a.GetNeighbors()) != 4:
                a.SetChiralTag(ChiralType.CHI_UNSPECIFIED)

        self.n_atoms = mol.GetNumAtoms()

        # Atom features and parities
        for atom in mol.GetAtoms():
            self.f_atoms.append(atom_features(
                atom,
                chiral_features=chiral_features,
                global_chiral_features=global_chiral_features
            ))
            self.parity_atoms.append(parity_features(atom))

        self.a2b = [[] for _ in range(self.n_atoms)]

        # Bond features: iterate atom pairs in (a1, a2) order so that bond indices
        # (and thus edge_index / parity_bond_index ordering) are deterministic.
        for a1 in range(self.n_atoms):
            for a2 in range(a1 + 1, self.n_atoms):
                bond = mol.GetBondBetweenAtoms(a1, a2)
                if bond is None:
                    continue

                self.edge_index.extend([(a1, a2), (a2, a1)])

                f_bond = bond_features(bond)
                self.f_bonds.append(f_bond)
                self.f_bonds.append(f_bond)

                # Update index mappings
                b1 = self.n_bonds
                b2 = b1 + 1
                self.a2b[a2].append(b1)  # b1 = a1 --> a2
                self.b2a.append(a1)
                self.a2b[a1].append(b2)  # b2 = a2 --> a1
                self.b2a.append(a2)
                self.b2revb.append(b2)
                self.b2revb.append(b1)
                self.n_bonds += 2

        # Whole-molecule scopes, as (start, size) pairs.
        self.a_scope = [(0, self.n_atoms)]
        self.b_scope = [(0, self.n_bonds)]

        # Outgoing edge indices per atom, in ascending edge order (one pass instead of
        # rescanning edge_index for every stereocenter).
        out_edges = [[] for _ in range(self.n_atoms)]
        for ei, (src, _) in enumerate(self.edge_index):
            out_edges[src].append(ei)

        for ai, parity in enumerate(self.parity_atoms):
            if parity == 0:
                continue
            nei_idx = out_edges[ai]
            if parity == -1:
                # CCW: swap the first two neighbors to express the opposite handedness.
                nei_idx = [nei_idx[i] for i in [1, 0, 2, 3]]
            self.parity_bond_index.extend(nei_idx)

    def get_components(self) -> Tuple[List[List[Union[bool, int, float]]],
                                      List[List[Union[bool, int, float]]],
                                      List[List[int]], List[int], List[int],
                                      List[Tuple[int, int]], List[Tuple[int, int]],
                                      List[int]]:
        """
        Returns the components of the MolGraph.

        :return: ``(f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, parity_atoms)``
                 as Python lists. The scopes hold (start, size) pairs for this one molecule.
        """
        return (
            self.f_atoms,
            self.f_bonds,
            self.a2b,
            self.b2a,
            self.b2revb,
            self.a_scope,
            self.b_scope,
            self.parity_atoms
        )

    @staticmethod
    def _pad_rows(rows: List[List[int]]) -> torch.LongTensor:
        """Stack ragged index lists into a (len(rows), max_len) LongTensor, right-padded with -1."""
        width = max((len(row) for row in rows), default=0)
        padded = torch.full((len(rows), width), -1, dtype=torch.long)
        for i, row in enumerate(rows):
            if row:
                padded[i, :len(row)] = torch.tensor(row, dtype=torch.long)
        return padded

    def get_b2b(self) -> torch.LongTensor:
        """
        Computes (if necessary) and returns a mapping from each bond index to all the incoming bond indices.

        For bond ``b = a1 -> a2`` the row lists the bonds entering ``a1``, with the reverse
        bond ``a2 -> a1`` replaced by -1. Rows are right-padded with -1, since index 0 is a
        real bond here; callers must mask -1 entries before using them as indices.

        :return: A LongTensor of shape (n_bonds, max incoming bonds per atom).
        """
        if self.b2b is None:
            rows = [
                [-1 if b_in == self.b2revb[b] else b_in for b_in in self.a2b[self.b2a[b]]]
                for b in range(self.n_bonds)
            ]
            self.b2b = self._pad_rows(rows)

        return self.b2b

    def get_a2a(self) -> torch.LongTensor:
        """
        Computes (if necessary) and returns a mapping from each atom index to all neighboring atom indices.

        Row ``a`` holds ``b2a[b]`` for every bond ``b`` in ``a2b[a]``, i.e. the source atoms of
        the bonds entering ``a``. Rows are right-padded with -1.

        :return: A LongTensor of shape (n_atoms, max incoming bonds per atom).
        """
        if self.a2a is None:
            rows = [[self.b2a[b] for b in incoming] for incoming in self.a2b]
            self.a2a = self._pad_rows(rows)

        return self.a2a
|
293
|
+
|
294
|
+
|
295
|
+
|
296
|
+
|
297
|
+
|
File without changes
|