hjxdl 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. hdl/__init__.py +0 -0
  2. hdl/_version.py +16 -0
  3. hdl/args/__init__.py +0 -0
  4. hdl/args/loss_args.py +5 -0
  5. hdl/controllers/__init__.py +0 -0
  6. hdl/controllers/al/__init__.py +0 -0
  7. hdl/controllers/al/al.py +0 -0
  8. hdl/controllers/al/dispatcher.py +0 -0
  9. hdl/controllers/al/feedback.py +0 -0
  10. hdl/controllers/explain/__init__.py +0 -0
  11. hdl/controllers/explain/shapley.py +293 -0
  12. hdl/controllers/explain/subgraphx.py +865 -0
  13. hdl/controllers/train/__init__.py +0 -0
  14. hdl/controllers/train/rxn_train.py +219 -0
  15. hdl/controllers/train/train.py +50 -0
  16. hdl/controllers/train/train_ginet.py +316 -0
  17. hdl/controllers/train/trainer_base.py +155 -0
  18. hdl/controllers/train/trainer_iterative.py +389 -0
  19. hdl/data/__init__.py +0 -0
  20. hdl/data/dataset/__init__.py +0 -0
  21. hdl/data/dataset/base_dataset.py +98 -0
  22. hdl/data/dataset/fp/__init__.py +0 -0
  23. hdl/data/dataset/fp/fp_dataset.py +122 -0
  24. hdl/data/dataset/graph/__init__.py +0 -0
  25. hdl/data/dataset/graph/chiral.py +62 -0
  26. hdl/data/dataset/graph/gin.py +255 -0
  27. hdl/data/dataset/graph/molnet.py +362 -0
  28. hdl/data/dataset/loaders/__init__.py +0 -0
  29. hdl/data/dataset/loaders/chiral_graph.py +71 -0
  30. hdl/data/dataset/loaders/collate_funcs/__init__.py +0 -0
  31. hdl/data/dataset/loaders/collate_funcs/fp.py +56 -0
  32. hdl/data/dataset/loaders/collate_funcs/rxn.py +40 -0
  33. hdl/data/dataset/loaders/general.py +23 -0
  34. hdl/data/dataset/loaders/spliter.py +86 -0
  35. hdl/data/dataset/samplers/__init__.py +0 -0
  36. hdl/data/dataset/samplers/chiral.py +19 -0
  37. hdl/data/dataset/seq/__init__.py +0 -0
  38. hdl/data/dataset/seq/rxn_dataset.py +61 -0
  39. hdl/data/dataset/utils.py +31 -0
  40. hdl/data/to_mols.py +0 -0
  41. hdl/features/__init__.py +0 -0
  42. hdl/features/fp/__init__.py +0 -0
  43. hdl/features/fp/features_generators.py +235 -0
  44. hdl/features/graph/__init__.py +0 -0
  45. hdl/features/graph/featurization.py +297 -0
  46. hdl/features/utils/__init__.py +0 -0
  47. hdl/features/utils/utils.py +111 -0
  48. hdl/layers/__init__.py +0 -0
  49. hdl/layers/general/__init__.py +0 -0
  50. hdl/layers/general/gp.py +14 -0
  51. hdl/layers/general/linear.py +641 -0
  52. hdl/layers/graph/__init__.py +0 -0
  53. hdl/layers/graph/chiral_graph.py +230 -0
  54. hdl/layers/graph/gcn.py +16 -0
  55. hdl/layers/graph/gin.py +45 -0
  56. hdl/layers/graph/tetra.py +158 -0
  57. hdl/layers/graph/transformer.py +188 -0
  58. hdl/layers/sequential/__init__.py +0 -0
  59. hdl/metric_loss/__init__.py +0 -0
  60. hdl/metric_loss/loss.py +79 -0
  61. hdl/metric_loss/metric.py +178 -0
  62. hdl/metric_loss/multi_label.py +42 -0
  63. hdl/metric_loss/nt_xent.py +65 -0
  64. hdl/models/__init__.py +0 -0
  65. hdl/models/chiral_gnn.py +176 -0
  66. hdl/models/fast_transformer.py +234 -0
  67. hdl/models/ginet.py +189 -0
  68. hdl/models/linear.py +137 -0
  69. hdl/models/model_dict.py +18 -0
  70. hdl/models/norm_flows.py +33 -0
  71. hdl/models/optim_dict.py +16 -0
  72. hdl/models/rxn.py +63 -0
  73. hdl/models/utils.py +83 -0
  74. hdl/ops/__init__.py +0 -0
  75. hdl/ops/utils.py +42 -0
  76. hdl/optims/__init__.py +0 -0
  77. hdl/optims/nadam.py +86 -0
  78. hdl/utils/__init__.py +0 -0
  79. hdl/utils/chemical_tools/__init__.py +2 -0
  80. hdl/utils/chemical_tools/query_info.py +149 -0
  81. hdl/utils/chemical_tools/sdf.py +20 -0
  82. hdl/utils/database_tools/__init__.py +0 -0
  83. hdl/utils/database_tools/connect.py +28 -0
  84. hdl/utils/general/__init__.py +0 -0
  85. hdl/utils/general/glob.py +21 -0
  86. hdl/utils/schedulers/__init__.py +0 -0
  87. hdl/utils/schedulers/norm_lr.py +108 -0
  88. hjxdl-0.0.1.dist-info/METADATA +19 -0
  89. hjxdl-0.0.1.dist-info/RECORD +91 -0
  90. hjxdl-0.0.1.dist-info/WHEEL +5 -0
  91. hjxdl-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,61 @@
1
+ import typing as t
2
+
3
+ import numpy as np
4
+ import torch
5
+ import pkg_resources
6
+ from rxnfp.tokenization import (
7
+ SmilesTokenizer,
8
+ # convert_reaction_to_valid_features_batch,
9
+ convert_reaction_to_valid_features,
10
+ )
11
+
12
+ from ..base_dataset import CSVDataset
13
+
14
+
15
class RXNCSVDataset(CSVDataset):
    """
    CSV dataset of reaction SMILES tokenized with rxnfp's ``SmilesTokenizer``.

    Each item is ``X = [input_ids, input_mask, segment_ids]`` (int64 tensors
    built from ``convert_reaction_to_valid_features``). When ``target_cols``
    is non-empty, the item is ``(X, y)`` with ``y`` a ``LongTensor`` of labels.
    """

    def __init__(
        self,
        csv_file: str,
        max_len: int = 512,
        vocab_path: t.Optional[str] = None,
        splitter: str = ',',
        smiles_col: str = 'SMILES',
        target_cols: t.Optional[t.List] = None,
        **kwargs,
    ) -> None:
        """
        :param csv_file: Path of the CSV file; forwarded to ``CSVDataset``.
        :param max_len: ``max_len`` passed to ``SmilesTokenizer``.
        :param vocab_path: Tokenizer vocabulary file; ``None`` selects
            ``models/transformers/bert_ft/vocab.txt`` shipped with rxnfp.
        :param splitter: Field delimiter; forwarded to ``CSVDataset``.
        :param smiles_col: Column holding the reaction SMILES.
        :param target_cols: Label columns; ``None`` or empty means unlabeled items.
        :param kwargs: Forwarded unchanged to ``CSVDataset``.
        """
        # Fresh list per call instead of one mutable ``[]`` shared by all calls.
        if target_cols is None:
            target_cols = []
        super().__init__(
            csv_file,
            splitter=splitter,
            smiles_col=smiles_col,
            target_cols=target_cols,
            **kwargs
        )
        if vocab_path is None:
            vocab_path = pkg_resources.resource_filename(
                "rxnfp",
                "models/transformers/bert_ft/vocab.txt"
            )
        self.tokenizer = SmilesTokenizer(vocab_path, max_len=max_len)

    def __getitem__(self, index):
        """
        Tokenize the reaction stored in row ``index``.

        :param index: Row label, looked up with ``self.df.loc``.
        :return: ``X`` or, when target columns are configured, ``(X, y)``.
        """
        row = self.df.loc[index]  # look the row up once, not once per field
        feats = convert_reaction_to_valid_features(row[self.smiles_col], self.tokenizer)
        X = [
            torch.tensor(feats.input_ids.astype(np.int64)),
            torch.tensor(feats.input_mask.astype(np.int64)),
            torch.tensor(feats.segment_ids.astype(np.int64)),
        ]
        # ``len`` rather than ``any``: ``any`` is False when every column
        # label is falsy (e.g. ``[0]``), which would silently drop the labels.
        if len(self.target_cols) > 0:
            y = torch.LongTensor(row[self.target_cols].tolist())
            return X, y
        return X
61
+
@@ -0,0 +1,31 @@
1
+ import typing as t
2
+ import csv
3
+
4
+
5
+
6
def read_smiles(
    data_path: str,
    file_type: str = 'smi',
    smi_col_names: t.Optional[t.List] = None,
    y_col_name: t.Optional[str] = None,
):
    """
    Read SMILES (and optionally one target column) from a delimited text file.

    :param data_path: Path of the input file.
    :param file_type: ``'smi'`` reads comma-separated rows without a header
        and keeps the last field of each row. ``'csv'`` reads a headed CSV
        with ``csv.DictReader``. Any other value reads nothing.
    :param smi_col_names: Columns to extract in ``'csv'`` mode. When ``None``
        or containing no truthy name, ``'csv'`` mode reads nothing.
    :param y_col_name: Target column appended to each row in ``'csv'`` mode;
        ``None`` (the default) means no target column.
    :return: ``'smi'``: list of SMILES strings. ``'csv'``: list of rows
        ``[smi_1, ..., smi_k]`` or ``[smi_1, ..., smi_k, y]``. Otherwise ``[]``.
    """
    # ``None`` default avoids a shared mutable ``[]`` default.
    if smi_col_names is None:
        smi_col_names = []
    smiles_data = []
    if file_type == 'smi':
        # The csv module expects files opened with newline=''.
        with open(data_path, newline='') as csv_file:
            for row in csv.reader(csv_file, delimiter=','):
                if not row:  # blank line: csv.reader yields [], which has no row[-1]
                    continue
                smiles_data.append(row[-1])
    elif file_type == 'csv' and any(smi_col_names):
        with open(data_path, 'r', newline='') as the_file:
            for line in csv.DictReader(the_file):
                smiles_data_i = [line[name] for name in smi_col_names]
                # Previously the default was the *string* 'None', which always
                # passed this check and raised KeyError('None').
                if y_col_name is not None:
                    smiles_data_i.append(line[y_col_name])
                smiles_data.append(smiles_data_i)
    return smiles_data
hdl/data/to_mols.py ADDED
File without changes
File without changes
File without changes
@@ -0,0 +1,235 @@
1
+ from typing import Callable, List, Union
2
+
3
+ import numpy as np
4
+ from rdkit import Chem, DataStructs
5
+ from rdkit.Chem import AllChem
6
+ from rdkit.Chem import MACCSkeys
7
+
8
+
9
# A molecule is accepted either as a SMILES string or as an RDKit ``Mol``.
Molecule = Union[str, Chem.Mol]
# A features generator maps one molecule to a numpy feature array.
FeaturesGenerator = Callable[[Molecule], np.ndarray]

# Generator name -> generator function, populated by ``register_features_generator``.
FEATURES_GENERATOR_REGISTRY = dict()

# Feature-vector length per generator name (the morgan entries match the
# default ``MORGAN_NUM_BITS``).
FP_BITS_DICT = dict(
    maccs=167,
    morgan=2048,
    morgan_count=2048,
    rdkit_2d_normalized=200,
    rdkit_2d=200,
)
23
+
24
+
25
def register_features_generator(features_generator_name: str) -> Callable[[FeaturesGenerator], FeaturesGenerator]:
    """
    Build a decorator that records a features generator under a name.

    The decorated function is returned unchanged, so it stays directly
    callable. It is also stored in ``FEATURES_GENERATOR_REGISTRY`` under
    ``features_generator_name``, replacing any earlier entry with that name.

    :param features_generator_name: Key under which the generator is stored.
    :return: The registering decorator.
    """
    def _register(generator: FeaturesGenerator) -> FeaturesGenerator:
        FEATURES_GENERATOR_REGISTRY.update({features_generator_name: generator})
        return generator

    return _register
37
+
38
+
39
def get_features_generator(features_generator_name: str) -> FeaturesGenerator:
    """
    Look up a registered features generator by name.

    :param features_generator_name: Name the generator was registered under.
    :return: The registered features generator.
    :raises ValueError: If nothing is registered under that name.
    """
    registry = FEATURES_GENERATOR_REGISTRY
    if features_generator_name in registry:
        return registry[features_generator_name]
    raise ValueError(f'Features generator "{features_generator_name}" could not be found. '
                     f'If this generator relies on rdkit features, you may need to install descriptastorus.')
51
+
52
+
53
def get_available_features_generators() -> List[str]:
    """Return the registered features-generator names, in registration order."""
    return [*FEATURES_GENERATOR_REGISTRY]
56
+
57
+
58
# Defaults shared by the Morgan generators: fingerprint radius and vector length.
MORGAN_RADIUS, MORGAN_NUM_BITS = 2, 2048
60
+
61
+
62
@register_features_generator('morgan')
def morgan_binary_features_generator(mol: Molecule,
                                     radius: int = MORGAN_RADIUS,
                                     num_bits: int = MORGAN_NUM_BITS) -> np.ndarray:
    """
    Generates a binary Morgan fingerprint for a molecule.

    :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
    :param radius: Morgan fingerprint radius.
    :param num_bits: Number of bits in Morgan fingerprint.
    :return: A 1D numpy array containing the binary Morgan fingerprint.
    """
    # isinstance (not ``type(mol) == str``) so str subclasses such as
    # numpy.str_ (common when SMILES come from numpy/pandas) are parsed too.
    mol = Chem.MolFromSmiles(mol) if isinstance(mol, str) else mol
    features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=num_bits)
    features = np.zeros((1,))
    # ConvertToNumpyArray resizes ``features`` to the fingerprint length.
    DataStructs.ConvertToNumpyArray(features_vec, features)

    return features
80
+
81
+
82
@register_features_generator('morgan_count')
def morgan_counts_features_generator(mol: Molecule,
                                     radius: int = MORGAN_RADIUS,
                                     num_bits: int = MORGAN_NUM_BITS) -> np.ndarray:
    """
    Generates a counts-based Morgan fingerprint for a molecule.

    :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
    :param radius: Morgan fingerprint radius.
    :param num_bits: Number of bits in Morgan fingerprint.
    :return: A 1D numpy array containing the counts-based Morgan fingerprint.
    """
    # isinstance so str subclasses (e.g. numpy.str_) are treated as SMILES.
    mol = Chem.MolFromSmiles(mol) if isinstance(mol, str) else mol
    features_vec = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=num_bits)
    features = np.zeros((1,))
    # ConvertToNumpyArray resizes ``features`` to the fingerprint length.
    DataStructs.ConvertToNumpyArray(features_vec, features)

    return features
100
+
101
+
102
@register_features_generator('maccs')
def macss_features_generator(
    mol: Molecule
) -> np.ndarray:
    """
    Generates the MACCS keys fingerprint of a molecule.

    :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
    :return: A 1D float numpy array of length 167 holding 1 at every set key
        and 0 elsewhere.
    """
    # isinstance so str subclasses (e.g. numpy.str_) are treated as SMILES.
    mol = Chem.MolFromSmiles(mol) if isinstance(mol, str) else mol
    on_bits = list(MACCSkeys.GenMACCSKeys(mol).GetOnBits())
    arr = np.zeros(167)  # matches FP_BITS_DICT['maccs']
    arr[on_bits] = 1
    return arr
112
+
113
+
114
try:
    from descriptastorus.descriptors import rdDescriptors, rdNormalizedDescriptors

    @register_features_generator('rdkit_2d')
    def rdkit_2d_features_generator(mol: Molecule) -> np.ndarray:
        """
        Generates RDKit 2D features for a molecule via descriptastorus.

        :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
        :return: The output of ``RDKit2D().process`` without its first element
            (presumably a status flag — TODO confirm against descriptastorus).
        """
        # isinstance so str subclasses (e.g. numpy.str_) are used as SMILES
        # directly instead of being handed to MolToSmiles.
        smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if not isinstance(mol, str) else mol
        generator = rdDescriptors.RDKit2D()
        features = generator.process(smiles)[1:]

        return features

    @register_features_generator('rdkit_2d_normalized')
    def rdkit_2d_normalized_features_generator(mol: Molecule) -> np.ndarray:
        """
        Generates RDKit 2D normalized features for a molecule via descriptastorus.

        :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
        :return: The output of ``RDKit2DNormalized().process`` without its
            first element (presumably a status flag — TODO confirm).
        """
        smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if not isinstance(mol, str) else mol
        generator = rdNormalizedDescriptors.RDKit2DNormalized()
        features = generator.process(smiles)[1:]

        return features
except ImportError:
    # Register placeholders under the same names so that looking them up
    # succeeds and the failure explains what to install.
    @register_features_generator('rdkit_2d')
    def rdkit_2d_features_generator(mol: Molecule) -> np.ndarray:
        """Mock implementation raising an ImportError if descriptastorus cannot be imported."""
        raise ImportError('Failed to import descriptastorus. Please install descriptastorus '
                          '(https://github.com/bp-kelley/descriptastorus) to use RDKit 2D features.')

    @register_features_generator('rdkit_2d_normalized')
    def rdkit_2d_normalized_features_generator(mol: Molecule) -> np.ndarray:
        """Mock implementation raising an ImportError if descriptastorus cannot be imported."""
        raise ImportError('Failed to import descriptastorus. Please install descriptastorus '
                          '(https://github.com/bp-kelley/descriptastorus) to use RDKit 2D normalized features.')
156
+
157
+
158
@register_features_generator('e3fp')
def e3fp_features_generator(mol: Molecule) -> np.ndarray:
    """
    E3FP, a 3D molecular fingerprint inspired by Extended Connectivity
    FingerPrints (ECFP). Not implemented yet.

    [LINK](https://pubs.acs.org/doi/10.1021/acs.jmedchem.7b00696)
    Axen SD, Huang XP, Caceres EL, Gendelev L, Roth BL, Keiser MJ.
    A Simple Representation Of Three-Dimensional Molecular Structure.
    J. Med. Chem. 60 (17): 7393–7409 (2017).

    The source code: https://github.com/keiserlab/e3fp

    :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
    :return: Intended to return a 1D numpy array containing the E3FP fingerprint.
    :raises NotImplementedError: Always; this generator is a placeholder.
    """
    # ``NotImplemented`` is a sentinel for binary dunder methods. Returning it
    # would hand callers a bogus "feature vector" instead of failing loudly.
    raise NotImplementedError('The "e3fp" features generator is not implemented yet.')
174
+
175
+
176
@register_features_generator('whales')
def whales_features_generator(mol: Molecule) -> np.ndarray:
    """
    WHALES (Weighted Holistic Atom Localization and Entity Shape) descriptors
    of a molecule. Not implemented yet.

    [LINK](https://www.nature.com/articles/s42004-018-0043-x)
    Francesca Grisoni, Daniel Merk, Viviana Consonni, Jan A. Hiss,
    Sara Giani Tagliabue, Roberto Todeschini & Gisbert Schneider
    "Scaffold hopping from natural products to synthetic mimetics by
    holistic molecular similarity", Communications Chemistry 1, 44 (2018).

    The source code: https://github.com/grisoniFr/whales_descriptors

    :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
    :return: Intended to return the WHALES descriptors as a numpy array.
    :raises NotImplementedError: Always; this generator is a placeholder.
    """
    # Raise instead of returning the ``NotImplemented`` sentinel, which callers
    # would otherwise treat as a feature vector.
    raise NotImplementedError('The "whales" features generator is not implemented yet.')
194
+
195
+
196
@register_features_generator('selfies')
def selfies_features_generator(mol) -> np.ndarray:
    """
    Self-Referencing Embedded Strings (SELFIES): a 100% robust molecular
    string representation. Not implemented yet.

    A main objective is to use SELFIES as direct input into machine learning
    models, in particular generative models, for the generation of molecular
    graphs which are syntactically and semantically valid.

    [LINK](https://iopscience.iop.org/article/10.1088/2632-2153/aba947)
    Mario Krenn et al 2020 Mach. Learn.: Sci. Technol. 1 045024

    The source code: https://github.com/aspuru-guzik-group/selfies

    :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
    :return: Intended to return a 1D numpy array of the molecule's SELFIES symbols.
    :raises NotImplementedError: Always; this generator is a placeholder.
    """
    # Raise instead of returning the ``NotImplemented`` sentinel, which callers
    # would otherwise treat as a feature vector.
    raise NotImplementedError('The "selfies" features generator is not implemented yet.')
214
+
215
+
216
+ """
217
+ Custom features generator template.
218
+
219
+ Note: The name you use to register the features generator is the name
220
+ you will specify on the command line when using the --features_generator <name> flag.
221
+ Example: python train.py ... --features_generator custom ...
222
+
223
+ @register_features_generator('custom')
224
+ def custom_features_generator(mol: Molecule) -> np.ndarray:
225
+ # If you want to use the SMILES string
226
+ smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if not isinstance(mol, str) else mol
227
+
228
+ # If you want to use the RDKit molecule
229
+ mol = Chem.MolFromSmiles(mol) if isinstance(mol, str) else mol
230
+
231
+ # Replace this with code which generates features from the molecule
232
+ features = np.array([0, 0, 1])
233
+
234
+ return features
235
+ """
File without changes
@@ -0,0 +1,297 @@
1
+ from argparse import Namespace
2
+ from typing import List, Tuple, Union
3
+
4
+ from rdkit import Chem
5
+ from rdkit.Chem.rdchem import ChiralType
6
+
7
+ import torch
8
+
9
+
10
# Atom feature vocabularies. ``onek_encoding_unk`` one-hot encodes each list
# with one extra trailing slot for values that are not in the list.
ATOMIC_SYMBOLS = ['H', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'Br', 'I']
CIP_CHIRALITY = ['R', 'S']
ATOM_FEATURES = {
    'atomic_num': ATOMIC_SYMBOLS,
    'degree': [0, 1, 2, 3, 4, 5],
    'formal_charge': [-1, -2, 1, 2, 0],
    'chiral_tag': [0, 1, 2, 3],
    'global_chiral_tag': CIP_CHIRALITY,
    'num_Hs': [0, 1, 2, 3, 4],
    'hybridization': [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
    ],
}
# Bond feature vocabularies, encoded the same way.
BOND_FEATURES = {
    'bondtype': [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    ],
    'bondstereo': [
        Chem.rdchem.BondStereo.STEREONONE,
        Chem.rdchem.BondStereo.STEREOANY,
        Chem.rdchem.BondStereo.STEREOZ,
        Chem.rdchem.BondStereo.STEREOE,
    ],
}
# Tetrahedral parity per RDKit chiral tag: +1 CW, -1 CCW, 0 otherwise.
CHIRALTAG_PARITY = {
    ChiralType.CHI_TETRAHEDRAL_CW: +1,
    ChiralType.CHI_TETRAHEDRAL_CCW: -1,
    ChiralType.CHI_UNSPECIFIED: 0,
    ChiralType.CHI_OTHER: 0,  # default
}

# Each vocabulary contributes len(choices) + 1 slots (the +1 is "unknown").
# ATOM_FDIM adds 3 scalar atom features: IsAromatic, IsInRing and scaled mass.
# NOTE: ATOM_FDIM also counts the 'chiral_tag' and 'global_chiral_tag' blocks,
# which ``atom_features`` only emits when the matching flags are enabled.
ATOM_FDIM = sum(len(choices) + 1 for choices in ATOM_FEATURES.values()) + 3
# BOND_FDIM adds 3 scalars: the leading "no bond" flag, IsConjugated, IsInRing.
BOND_FDIM = sum(len(choices) + 1 for choices in BOND_FEATURES.values()) + 3
52
+
53
+
54
def get_atom_fdim(chiral_features: bool = True, global_chiral_features: bool = True) -> int:
    """
    Return the length of the atom feature vector built by ``atom_features``.

    The defaults reproduce ``ATOM_FDIM``, which includes both optional chiral
    blocks. Note that ``atom_features`` itself defaults both flags to
    ``False``. Pass the same flags used there to get the length it actually
    produces.

    :param chiral_features: Whether the one-hot chiral-tag block is included.
    :param global_chiral_features: Whether the one-hot CIP (R/S) block is included.
    :return: The atom feature dimensionality.
    """
    fdim = ATOM_FDIM
    if not chiral_features:
        fdim -= len(ATOM_FEATURES['chiral_tag']) + 1
    if not global_chiral_features:
        fdim -= len(ATOM_FEATURES['global_chiral_tag']) + 1
    return fdim
60
+
61
+
62
def get_bond_fdim() -> int:
    """
    Return the length of the bond feature vector built by ``bond_features``.

    :return: ``BOND_FDIM``.
    """
    return BOND_FDIM
68
+
69
+
70
def onek_encoding_unk(value, choices: List) -> List[int]:
    """
    One-hot encode ``value`` against ``choices`` plus a trailing "unknown" slot.

    :param value: The value to encode.
    :param choices: Known values, in encoding order.
    :return: ``len(choices) + 1`` ints containing a single 1. The 1 sits at the
        position of the first match of ``value`` in ``choices``, or in the last
        slot when ``value`` is not among ``choices``.
    """
    try:
        hot = choices.index(value)
    except ValueError:
        hot = len(choices)  # the "unknown" slot
    return [int(pos == hot) for pos in range(len(choices) + 1)]
83
+
84
+
85
def atom_features(
    atom: Chem.rdchem.Atom,
    chiral_features: bool = False,
    global_chiral_features: bool = False
) -> List[Union[bool, int, float]]:
    """
    Build the feature vector of an atom.

    The vector contains, in order:

    - one-hot encodings (each with an "unknown" slot) of the element symbol,
      total degree, formal charge, total H count and hybridization;
    - IsAromatic, IsInRing and mass * 0.01;
    - optionally, the one-hot chiral tag and the one-hot CIP label.

    :param atom: An RDKit atom.
    :param chiral_features: Append the one-hot ``atom.GetChiralTag()`` block.
    :param global_chiral_features: Append the one-hot ``_CIPCode`` (R/S) block;
        atoms without a ``_CIPCode`` property get the "unknown" slot.
    :return: The atom feature list.
    """
    # RDKit enum members compare equal to their int values, so the int()
    # casts match the enum entries in ATOM_FEATURES.
    features = (
        onek_encoding_unk(atom.GetSymbol(), ATOM_FEATURES['atomic_num'])
        + onek_encoding_unk(atom.GetTotalDegree(), ATOM_FEATURES['degree'])
        + onek_encoding_unk(atom.GetFormalCharge(), ATOM_FEATURES['formal_charge'])
        + onek_encoding_unk(int(atom.GetTotalNumHs()), ATOM_FEATURES['num_Hs'])
        + onek_encoding_unk(int(atom.GetHybridization()), ATOM_FEATURES['hybridization'])
        + [1 if atom.GetIsAromatic() else 0]
        + [1 if atom.IsInRing() else 0]
        + [atom.GetMass() * 0.01]  # scaled to about the same range as other features
    )
    if chiral_features:
        features += onek_encoding_unk(int(atom.GetChiralTag()), ATOM_FEATURES['chiral_tag'])
    if global_chiral_features:
        cip_code = atom.GetProp('_CIPCode') if atom.HasProp('_CIPCode') else None
        features += onek_encoding_unk(cip_code, ATOM_FEATURES['global_chiral_tag'])
    return features
111
+
112
+
113
def parity_features(atom: Chem.rdchem.Atom) -> int:
    """
    Return the tetrahedral parity of an atom.

    :param atom: An RDKit atom.
    :return: +1 if tagged CW, -1 if tagged CCW, 0 for any other chiral tag.
    """
    # .get: newer RDKit releases define extra ChiralType members (square
    # planar, trigonal bipyramidal, octahedral, allene) that are not keys of
    # CHIRALTAG_PARITY; plain indexing raised KeyError for them.
    return CHIRALTAG_PARITY.get(atom.GetChiralTag(), 0)
120
+
121
+
122
def bond_features(bond: Chem.rdchem.Bond) -> List[Union[bool, int, float]]:
    """
    Build the feature vector of a bond, or of a missing bond.

    The vector contains, in order:

    - a leading "no bond" flag;
    - the one-hot bond type;
    - the one-hot stereo;
    - IsConjugated and IsInRing.

    It has ``get_bond_fdim()`` entries.

    :param bond: An RDKit bond, or ``None`` for "no bond".
    :return: The bond feature list; for ``None`` only the leading flag is 1.
    """
    bond_fdim = get_bond_fdim()
    if bond is None:
        return [1] + [0] * (bond_fdim - 1)

    bond_type = bond.GetBondType()
    # Guard kept from the original; in practice RDKit always reports a type.
    has_type = bond_type is not None
    return (
        [0]
        + onek_encoding_unk(bond_type, BOND_FEATURES['bondtype'])
        + onek_encoding_unk(bond.GetStereo(), BOND_FEATURES['bondstereo'])
        + [
            bond.GetIsConjugated() if has_type else 0,
            bond.IsInRing() if has_type else 0,
        ]
    )
142
+
143
+
144
class MolGraph:
    """
    Graph structure and featurization of a single molecule.

    Attributes filled in by ``__init__``:

    - smiles: The input SMILES string.
    - n_atoms: Number of atoms, after explicit H were added on tetrahedral centers.
    - n_bonds: Number of directed bonds (two per chemical bond).
    - f_atoms: Atom index -> atom feature list (see ``atom_features``).
    - f_bonds: Directed bond index -> bond feature list (see ``bond_features``).
    - a2b: Atom index -> list of incoming directed bond indices.
    - b2a: Directed bond index -> index of the atom the bond originates from.
    - b2revb: Directed bond index -> index of the reverse bond.
    - parity_atoms: Atom index -> +1 (CW), -1 (CCW) or 0 (no tetrahedral tag).
    - edge_index: Directed bond index -> (source atom, target atom).
    - parity_bond_index: For each tetrahedral center, in atom order, the
      indices of its four outgoing edges. The first two are swapped for CCW
      centers.
    - a_scope / b_scope: ``[(0, n_atoms)]`` / ``[(0, n_bonds)]``.
    """

    def __init__(
        self,
        smiles: str,
        chiral_features: bool = False,
        global_chiral_features: bool = False
    ):
        """
        Compute the graph structure and featurization of a molecule.

        :param smiles: A SMILES string.
        :param chiral_features: Forwarded to ``atom_features``.
        :param global_chiral_features: Forwarded to ``atom_features``.
        :raises ValueError: If RDKit cannot parse ``smiles``.
        """
        self.smiles = smiles
        self.n_atoms = 0  # number of atoms
        self.n_bonds = 0  # number of directed bonds
        self.f_atoms = []  # atom index -> atom features
        self.f_bonds = []  # bond index -> bond features
        self.a2b = []  # atom index -> incoming bond indices
        self.b2a = []  # bond index -> index of the atom the bond comes from
        self.b2revb = []  # bond index -> index of the reverse bond
        self.parity_atoms = []  # atom index -> CW (+1), CCW (-1) or undefined (0)
        self.edge_index = []  # bond index -> (source, target) atom pair
        self.parity_bond_index = []
        # Built lazily by get_b2b / get_a2a.
        self.b2b = None
        self.a2a = None

        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Previously this surfaced as an opaque AttributeError on None.
            raise ValueError(f'Could not parse SMILES: {smiles!r}')

        # Make the implicit H on tetrahedral centers explicit so that every
        # such center can have four neighbours to order.
        h_ids = [a.GetIdx() for a in mol.GetAtoms()
                 if CHIRALTAG_PARITY.get(a.GetChiralTag(), 0) != 0]
        if h_ids:
            mol = Chem.AddHs(mol, onlyOnAtoms=h_ids)

        # Remove the stereo label from centers with fewer/more than 4 neighbours.
        for idx in h_ids:
            atom = mol.GetAtomWithIdx(idx)
            if len(atom.GetNeighbors()) != 4:
                atom.SetChiralTag(ChiralType.CHI_UNSPECIFIED)

        self.n_atoms = mol.GetNumAtoms()

        for atom in mol.GetAtoms():
            self.f_atoms.append(atom_features(
                atom,
                chiral_features=chiral_features,
                global_chiral_features=global_chiral_features
            ))
            self.parity_atoms.append(parity_features(atom))

        self.a2b = [[] for _ in range(self.n_atoms)]

        # Bonds are visited in (a1, a2) order with a1 < a2. Each chemical bond
        # becomes two directed bonds: b1 = a1 --> a2 and b2 = a2 --> a1.
        for a1 in range(self.n_atoms):
            for a2 in range(a1 + 1, self.n_atoms):
                bond = mol.GetBondBetweenAtoms(a1, a2)
                if bond is None:
                    continue

                self.edge_index.extend([(a1, a2), (a2, a1)])

                f_bond = bond_features(bond)
                self.f_bonds.append(f_bond)
                self.f_bonds.append(f_bond)

                b1 = self.n_bonds
                b2 = b1 + 1
                self.a2b[a2].append(b1)  # b1 = a1 --> a2
                self.b2a.append(a1)
                self.a2b[a1].append(b2)  # b2 = a2 --> a1
                self.b2a.append(a2)
                self.b2revb.append(b2)
                self.b2revb.append(b1)
                self.n_bonds += 2

        # Outgoing edge indices per atom, in edge_index order. One pass here
        # replaces a rescan of every edge for each tetrahedral center.
        out_edges = [[] for _ in range(self.n_atoms)]
        for ei, (src, _) in enumerate(self.edge_index):
            out_edges[src].append(ei)
        for ai, parity in enumerate(self.parity_atoms):
            if parity == 0:
                continue
            nei_idx = out_edges[ai]
            if parity == -1:
                # One transposition flips the ordering's handedness; presumably
                # so every center is stored in a common (CW) convention.
                nei_idx = [nei_idx[i] for i in [1, 0, 2, 3]]
            self.parity_bond_index.extend(nei_idx)

        # get_components previously read these without them ever being set.
        # (start, count) pairs for this single molecule; NOTE(review): confirm
        # this matches what batch-level consumers expect.
        self.a_scope = [(0, self.n_atoms)]
        self.b_scope = [(0, self.n_bonds)]

    def get_components(self) -> Tuple[List[List[Union[bool, int, float]]],
                                      List[List[Union[bool, int, float]]],
                                      List[List[int]], List[int], List[int],
                                      List[Tuple[int, int]], List[Tuple[int, int]],
                                      List[int]]:
        """
        Return the components of this MolGraph.

        :return: ``(f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope,
            parity_atoms)``, all as the Python lists built in ``__init__``.
        """
        return (
            self.f_atoms,
            self.f_bonds,
            self.a2b,
            self.b2a,
            self.b2revb,
            self.a_scope,
            self.b_scope,
            self.parity_atoms
        )

    @staticmethod
    def _pad_rows(rows: List[List[int]], pad: int = -1) -> torch.LongTensor:
        """Stack ragged integer rows into a 2D LongTensor, right-padding with ``pad``."""
        width = max((len(r) for r in rows), default=0)
        padded = [r + [pad] * (width - len(r)) for r in rows]
        # reshape keeps the result 2D even when there are no rows.
        return torch.tensor(padded, dtype=torch.long).reshape(len(rows), width)

    def get_b2b(self) -> torch.LongTensor:
        """
        Compute (once) and return, for each directed bond, the bonds entering
        its source atom, excluding the bond's own reverse.

        :return: A ``n_bonds x width`` LongTensor. Rows are right-padded with
            -1 up to the longest row.
        """
        if self.b2b is None:
            rows = [
                [b_in for b_in in self.a2b[self.b2a[b]] if b_in != self.b2revb[b]]
                for b in range(self.n_bonds)
            ]
            self.b2b = self._pad_rows(rows)
        return self.b2b

    def get_a2a(self) -> torch.LongTensor:
        """
        Compute (once) and return, for each atom, its neighbouring atoms.

        :return: A ``n_atoms x width`` LongTensor. Rows are right-padded with
            -1 up to the largest neighbour count.
        """
        if self.a2a is None:
            # b = a1 --> a2: a2b[a2] holds b and b2a[b] is a1, so mapping each
            # incoming bond through b2a yields the neighbours of a2.
            rows = [[self.b2a[b] for b in bonds] for bonds in self.a2b]
            self.a2a = self._pad_rows(rows)
        return self.a2a
293
+
294
+
295
+
296
+
297
+
File without changes