boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/data/mol.py
ADDED
@@ -0,0 +1,900 @@
|
|
1
|
+
import itertools
|
2
|
+
import pickle
|
3
|
+
import random
|
4
|
+
from pathlib import Path
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
from rdkit.Chem import Mol
|
9
|
+
from tqdm import tqdm
|
10
|
+
|
11
|
+
from boltz.data import const
|
12
|
+
from boltz.data.pad import pad_dim
|
13
|
+
from boltz.model.loss.confidence import lddt_dist
|
14
|
+
|
15
|
+
|
16
|
+
def load_molecules(moldir: str, molecules: list[str]) -> dict[str, Mol]:
    """Read the pickled molecules with the given names from a directory.

    Parameters
    ----------
    moldir : str
        Directory holding one ``<name>.pkl`` file per molecule.
    molecules : list[str]
        Names of the molecules to load.

    Returns
    -------
    dict[str, Mol]
        The unpickled objects, keyed by molecule name.

    Raises
    ------
    ValueError
        If no ``<name>.pkl`` file exists for one of the requested names.

    """
    result = {}
    for name in molecules:
        pkl_file = Path(moldir) / f"{name}.pkl"
        if pkl_file.exists():
            # NOTE: unpickling executes arbitrary code; the directory must
            # only contain trusted files.
            with pkl_file.open("rb") as handle:
                result[name] = pickle.load(handle)  # noqa: S301
            continue
        msg = f"CCD component {name} not found!"
        raise ValueError(msg)
    return result
|
40
|
+
|
41
|
+
|
42
|
+
def load_canonicals(moldir: str) -> dict[str, Mol]:
    """Load the molecules named in ``const.canonical_tokens``.

    Parameters
    ----------
    moldir : str
        The path to the molecules directory.

    Returns
    -------
    dict[str, Mol]
        The loaded molecules, keyed by name.

    """
    canonical_names = const.canonical_tokens
    return load_molecules(moldir, canonical_names)
|
57
|
+
|
58
|
+
|
59
|
+
def load_all_molecules(moldir: str) -> dict[str, Mol]:
    """Load every ``*.pkl`` file found directly inside a directory.

    Parameters
    ----------
    moldir : str
        The path to the molecules directory.

    Returns
    -------
    dict[str, Mol]
        The unpickled objects, keyed by file stem.

    """
    pkl_files = list(Path(moldir).glob("*.pkl"))
    progress = tqdm(
        pkl_files, total=len(pkl_files), desc="Loading molecules", leave=False
    )
    molecules = {}
    for pkl_file in progress:
        # NOTE: unpickling executes arbitrary code; the directory must only
        # contain trusted files.
        with pkl_file.open("rb") as handle:
            molecules[pkl_file.stem] = pickle.load(handle)  # noqa: S301
    return molecules
|
82
|
+
|
83
|
+
|
84
|
+
def _load_pickled_prop(mol: Mol, name: str):  # noqa: ANN202
    """Decode the hex-encoded pickle stored in the molecule property ``name``.

    Errors raised by ``mol.GetProp`` or by decoding propagate to the caller.
    """
    # NOTE(security): unpickling executes arbitrary code, so the molecules
    # these properties come from must be trusted.
    return pickle.loads(bytes.fromhex(mol.GetProp(name)))  # noqa: S301


def get_symmetries(mols: dict[str, Mol]) -> dict:  # noqa: PLR0912
    """Create a dictionary for the ligand symmetries.

    For every molecule the hex-encoded, pickled properties are decoded.
    Optional property groups that are absent are replaced by empty arrays.

    Parameters
    ----------
    mols : dict[str, Mol]
        The molecules to process, keyed by name.

    Returns
    -------
    dict
        Maps each molecule key to a 16-tuple, in this order: symmetries, atom
        names, ``pb_edge_index``, ``pb_lower_bounds``, ``pb_upper_bounds``,
        ``pb_bond_mask``, ``pb_angle_mask``, ``chiral_atom_index``,
        ``chiral_check_mask``, ``chiral_atom_orientations``,
        ``stereo_bond_index``, ``stereo_check_mask``,
        ``stereo_bond_orientations``, ``aromatic_5_ring_index``,
        ``aromatic_6_ring_index``, ``planar_double_bond_index``.
        Molecules for which any step fails are omitted.

    """
    symmetries = {}
    for key, mol in mols.items():
        try:
            sym = _load_pickled_prop(mol, "symmetries")

            # All "pb_*" properties are read together, gated on the presence
            # of "pb_edge_index" only.
            if mol.HasProp("pb_edge_index"):
                edge_index = _load_pickled_prop(mol, "pb_edge_index").astype(
                    np.int64
                )
                lower_bounds = _load_pickled_prop(mol, "pb_lower_bounds")
                upper_bounds = _load_pickled_prop(mol, "pb_upper_bounds")
                bond_mask = _load_pickled_prop(mol, "pb_bond_mask")
                angle_mask = _load_pickled_prop(mol, "pb_angle_mask")
            else:
                edge_index = np.empty((2, 0), dtype=np.int64)
                lower_bounds = np.array([], dtype=np.float32)
                upper_bounds = np.array([], dtype=np.float32)
                bond_mask = np.array([], dtype=np.float32)
                angle_mask = np.array([], dtype=np.float32)

            # Chirality data, gated on "chiral_atom_index".
            if mol.HasProp("chiral_atom_index"):
                chiral_atom_index = _load_pickled_prop(
                    mol, "chiral_atom_index"
                ).astype(np.int64)
                chiral_check_mask = _load_pickled_prop(
                    mol, "chiral_check_mask"
                ).astype(np.int64)
                chiral_atom_orientations = _load_pickled_prop(
                    mol, "chiral_atom_orientations"
                )
            else:
                chiral_atom_index = np.empty((4, 0), dtype=np.int64)
                chiral_check_mask = np.array([], dtype=bool)
                chiral_atom_orientations = np.array([], dtype=bool)

            # Stereo-bond data, gated on "stereo_bond_index".
            if mol.HasProp("stereo_bond_index"):
                stereo_bond_index = _load_pickled_prop(
                    mol, "stereo_bond_index"
                ).astype(np.int64)
                stereo_check_mask = _load_pickled_prop(
                    mol, "stereo_check_mask"
                ).astype(np.int64)
                stereo_bond_orientations = _load_pickled_prop(
                    mol, "stereo_bond_orientations"
                )
            else:
                stereo_bond_index = np.empty((4, 0), dtype=np.int64)
                stereo_check_mask = np.array([], dtype=bool)
                stereo_bond_orientations = np.array([], dtype=bool)

            # Planarity data; each property is optional on its own.
            if mol.HasProp("aromatic_5_ring_index"):
                aromatic_5_ring_index = _load_pickled_prop(
                    mol, "aromatic_5_ring_index"
                ).astype(np.int64)
            else:
                aromatic_5_ring_index = np.empty((5, 0), dtype=np.int64)
            if mol.HasProp("aromatic_6_ring_index"):
                aromatic_6_ring_index = _load_pickled_prop(
                    mol, "aromatic_6_ring_index"
                ).astype(np.int64)
            else:
                aromatic_6_ring_index = np.empty((6, 0), dtype=np.int64)
            if mol.HasProp("planar_double_bond_index"):
                planar_double_bond_index = _load_pickled_prop(
                    mol, "planar_double_bond_index"
                ).astype(np.int64)
            else:
                planar_double_bond_index = np.empty((6, 0), dtype=np.int64)

            atom_names = [atom.GetProp("name") for atom in mol.GetAtoms()]
            symmetries[key] = (
                sym,
                atom_names,
                edge_index,
                lower_bounds,
                upper_bounds,
                bond_mask,
                angle_mask,
                chiral_atom_index,
                chiral_check_mask,
                chiral_atom_orientations,
                stereo_bond_index,
                stereo_check_mask,
                stereo_bond_orientations,
                aromatic_5_ring_index,
                aromatic_6_ring_index,
                planar_double_bond_index,
            )
        except Exception:  # noqa: BLE001, PERF203, S112
            # Deliberate best effort: a molecule whose properties are missing
            # or cannot be decoded is left out instead of aborting the run.
            continue

    return symmetries
|
194
|
+
|
195
|
+
|
196
|
+
def compute_symmetry_idx_dictionary(data):
    """Flatten all atom coordinates and record start offsets on the structure.

    As a side effect, ``chain.start_idx`` is set to the flat index of the
    chain's first atom, and ``token.start_idx`` to the offset of the token's
    first atom relative to its chain.

    Parameters
    ----------
    data
        Object exposing ``chains``; each chain exposes ``tokens`` and each
        token exposes ``atoms`` whose ``coords`` have ``x``, ``y``, ``z``.

    Returns
    -------
    list
        One ``[x, y, z]`` list per atom, in chain/token/atom order.

    """
    offset = 0
    flat_coords = []
    for chain in data.chains:
        chain.start_idx = offset
        for token in chain.tokens:
            token.start_idx = offset - chain.start_idx
            for atom in token.atoms:
                flat_coords.append([atom.coords.x, atom.coords.y, atom.coords.z])
            offset += len(token.atoms)
    return flat_coords
|
209
|
+
|
210
|
+
|
211
|
+
def get_current_idx_list(data):
    """Return the flat atom indices of all atoms that are inside the crop.

    An atom is included when both its chain and its token have ``in_crop``
    set. The indices are built from the ``start_idx`` attributes of the chain
    and the token, which must already be present on ``data``.

    Parameters
    ----------
    data
        Object exposing ``chains``; chains expose ``tokens``, ``in_crop`` and
        ``start_idx``; tokens expose ``atoms``, ``in_crop`` and ``start_idx``.

    Returns
    -------
    list
        The flat atom indices, in chain/token/atom order.

    """
    selected = []
    for chain in data.chains:
        if not chain.in_crop:
            continue
        for token in chain.tokens:
            if not token.in_crop:
                continue
            n_atoms = len(token.atoms)
            selected.extend(
                [chain.start_idx + token.start_idx + k for k in range(n_atoms)]
            )
    return selected
|
224
|
+
|
225
|
+
|
226
|
+
def all_different_after_swap(l):
    """Tell whether the last entries of all items in ``l`` are pairwise distinct.

    Parameters
    ----------
    l
        Iterable of indexable items; only ``item[-1]`` of each is inspected.

    Returns
    -------
    bool
        True when no two items share the same last entry.

    """
    last_entries = [item[-1] for item in l]
    distinct = set(last_entries)
    return len(distinct) == len(last_entries)
|
229
|
+
|
230
|
+
|
231
|
+
def minimum_lddt_symmetry_coords(
    coords: torch.Tensor,
    feats: dict,
    index_batch: int,
):
    """Pick the symmetry-equivalent ground truth that best matches a prediction.

    Every chain-swap combination in ``feats["chain_swaps"]`` is scored with
    ``lddt_dist`` against the predicted distance matrix and the best one is
    kept. Then, group by group, the atom swaps listed in
    ``feats["amino_acids_symmetries"]`` and ``feats["ligand_symmetries"]``
    are tried, and the swap with the largest positive lDDT improvement
    (restricted to the atoms of that group) is applied greedily.

    Parameters
    ----------
    coords : torch.Tensor
        Predicted coordinates, indexed as ``coords[:, atom]``
        (presumably shape ``(1, n_atoms, 3)`` -- TODO confirm with caller).
    feats : dict
        Must provide ``all_coords``, ``all_resolved_mask``,
        ``crop_to_all_atom_map``, ``chain_swaps``,
        ``amino_acids_symmetries`` and ``ligand_symmetries``, each indexable
        by ``index_batch``.
    index_batch : int
        The batch element to process.

    Returns
    -------
    tuple
        The selected true coordinates padded along dim 1 to
        ``coords.shape[1]``, and the matching resolved mask, padded to the
        same length and with a leading dimension added.

    """
    all_coords = feats["all_coords"][index_batch].unsqueeze(0).to(coords)
    all_resolved_mask = (
        feats["all_resolved_mask"][index_batch].to(coords).to(torch.bool)
    )
    crop_to_all_atom_map = (
        feats["crop_to_all_atom_map"][index_batch].to(coords).to(torch.long)
    )
    chain_symmetries = feats["chain_swaps"][index_batch]
    amino_acids_symmetries = feats["amino_acids_symmetries"][index_batch]
    ligand_symmetries = feats["ligand_symmetries"][index_batch]

    # Only the first len(crop_to_all_atom_map) predicted atoms are compared.
    pred_coords = coords[:, : len(crop_to_all_atom_map)]
    dmat_predicted = torch.cdist(pred_coords, pred_coords)

    # Check best symmetry on chain swap
    best_true_coords = all_coords[:, crop_to_all_atom_map].clone()
    best_true_resolved_mask = all_resolved_mask[crop_to_all_atom_map].clone()
    best_lddt = -1.0
    for c in chain_symmetries:
        true_all_coords = all_coords.clone()
        true_all_resolved_mask = all_resolved_mask.clone()
        # Each entry copies the atom range of chain 2 into the range of chain 1.
        for start1, end1, start2, end2, _, _ in c:
            true_all_coords[:, start1:end1] = all_coords[:, start2:end2]
            true_all_resolved_mask[start1:end1] = all_resolved_mask[start2:end2]
        true_coords = true_all_coords[:, crop_to_all_atom_map]
        true_resolved_mask = true_all_resolved_mask[crop_to_all_atom_map]
        dmat_true = torch.cdist(true_coords, true_coords)
        # Pairs of resolved atoms, excluding each atom paired with itself.
        pair_mask = (
            true_resolved_mask[:, None]
            * true_resolved_mask[None, :]
            * (1 - torch.eye(len(true_resolved_mask))).to(true_resolved_mask)
        )

        lddt = lddt_dist(
            dmat_predicted, dmat_true, pair_mask, cutoff=15.0, per_atom=False
        )[0]
        lddt = lddt.item()

        # Candidates with three or fewer resolved atoms are never selected.
        if lddt > best_lddt and torch.sum(true_resolved_mask) > 3:
            best_lddt = lddt
            best_true_coords = true_coords
            best_true_resolved_mask = true_resolved_mask

    # atom symmetries (nucleic acid and protein residues), resolved greedily
    # without recomputing alignment
    true_coords = best_true_coords.clone()
    true_resolved_mask = best_true_resolved_mask.clone()
    for symmetric_amino_or_lig in amino_acids_symmetries + ligand_symmetries:
        if len(symmetric_amino_or_lig) == 0:
            # No candidate swap in this group, so nothing can change.
            continue
        best_lddt_improvement = 0.0

        # All atoms that are the destination of some swap in this group.
        indices = sorted({i for c in symmetric_amino_or_lig for i, _ in c})
        indices = torch.from_numpy(np.asarray(indices)).to(true_coords.device).long()
        sub_dmat_pred = torch.cdist(pred_coords, pred_coords[:, indices])

        # The score of the current assignment does not depend on the candidate
        # swap, so it is computed once per group instead of once per candidate.
        sub_dmat_true = torch.cdist(true_coords, true_coords[:, indices])
        sub_true_pair_lddt = (
            true_resolved_mask[:, None] * true_resolved_mask[None, indices]
        )
        off_diagonal = (1 - torch.eye(len(indices))).to(sub_true_pair_lddt).bool()
        sub_true_pair_lddt[indices] = sub_true_pair_lddt[indices] * off_diagonal
        lddt, _ = lddt_dist(
            sub_dmat_pred,
            sub_dmat_true,
            sub_true_pair_lddt,
            cutoff=15.0,
            per_atom=False,
        )

        for c in symmetric_amino_or_lig:
            # starting from greedy best, try to swap the atoms
            new_true_coords = true_coords.clone()
            new_true_resolved_mask = true_resolved_mask.clone()
            for i, j in c:
                new_true_coords[:, i] = true_coords[:, j]
                new_true_resolved_mask[i] = true_resolved_mask[j]

            sub_dmat_new_true = torch.cdist(
                new_true_coords, new_true_coords[:, indices]
            )
            sub_new_true_pair_lddt = (
                new_true_resolved_mask[:, None] * new_true_resolved_mask[None, indices]
            )
            sub_new_true_pair_lddt[indices] = (
                sub_new_true_pair_lddt[indices] * off_diagonal
            )

            new_lddt, _ = lddt_dist(
                sub_dmat_pred,
                sub_dmat_new_true,
                sub_new_true_pair_lddt,
                cutoff=15.0,
                per_atom=False,
            )

            lddt_improvement = new_lddt - lddt

            if lddt_improvement > best_lddt_improvement:
                best_true_coords = new_true_coords
                best_true_resolved_mask = new_true_resolved_mask
                best_lddt_improvement = lddt_improvement

        # greedily update best coordinates after each amino acid
        true_coords = best_true_coords.clone()
        true_resolved_mask = best_true_resolved_mask.clone()

    # Pad back to the number of predicted atoms.
    true_coords = pad_dim(true_coords, 1, coords.shape[1] - true_coords.shape[1])
    true_resolved_mask = pad_dim(
        true_resolved_mask,
        0,
        coords.shape[1] - true_resolved_mask.shape[0],
    )

    return true_coords, true_resolved_mask.unsqueeze(0)
|
362
|
+
|
363
|
+
|
364
|
+
def compute_single_distogram_loss(pred, target, mask):
    """Compute a masked cross-entropy between predicted and target distograms.

    The cross-entropy over the last dimension is multiplied by ``mask``,
    summed over the last remaining dimension, divided by the mask total
    (plus ``1e-5``), summed again and finally averaged.

    Parameters
    ----------
    pred
        Logits; the softmax is taken over the last dimension.
    target
        Target distribution with the same layout as ``pred``.
    mask
        Weights with the layout of ``pred`` minus its last dimension.

    Returns
    -------
    torch.Tensor
        The scalar loss.

    """
    log_probs = torch.nn.functional.log_softmax(pred, dim=-1)
    cross_entropy = -1 * torch.sum(target * log_probs, dim=-1)
    normalizer = 1e-5 + torch.sum(mask, dim=(-1, -2))
    row_loss = torch.sum(cross_entropy * mask, dim=-1) / normalizer[..., None]
    return torch.mean(torch.sum(row_loss, dim=-1))
|
377
|
+
|
378
|
+
|
379
|
+
def minimum_lddt_symmetry_dist(
    pred_distogram: torch.Tensor,
    feats: dict,
    index_batch: int,
):
    """Resolve ligand symmetries against a predicted distogram, in place.

    For each symmetry group, every candidate swap is applied to a copy of the
    target distogram (columns first, then rows). The swap that lowers the
    distogram loss on the affected token columns the most is kept, and the
    same atom permutation is applied to the coordinates. The results are
    written back into ``feats["disto_target"]`` and ``feats["coords"]``.

    Parameters
    ----------
    pred_distogram : torch.Tensor
        The predicted distogram, indexed as ``pred_distogram[:, token]``.
    feats : dict
        Must provide ``disto_target``, ``token_disto_mask``, ``coords``,
        ``ligand_symmetries`` and ``atom_to_token``, each indexable by
        ``index_batch``.
    index_batch : int
        The batch element to process.

    """
    # Note: for now only ligand symmetries are resolved

    target = feats["disto_target"][index_batch]
    token_mask = feats["token_disto_mask"][index_batch]
    pair_mask = token_mask[None, :] * token_mask[:, None]
    # Drop the diagonal (a token paired with itself).
    pair_mask = pair_mask * (1 - torch.eye(pair_mask.shape[1])).to(target)

    coords = feats["coords"][index_batch]

    ligand_symmetries = feats["ligand_symmetries"][index_batch]
    atom_token = feats["atom_to_token"][index_batch].argmax(dim=-1)

    # atom symmetries, resolved greedily without recomputing alignment
    for group in ligand_symmetries:
        chosen_swap, chosen_target, top_gain = None, None, 0.0
        for swap in group:
            # starting from greedy best, try to swap the atoms
            candidate = target.clone()
            touched = []

            # fix the distogram by replacing first the columns ...
            snapshot = candidate.clone()
            for dst, src in swap:
                candidate[:, atom_token[dst]] = snapshot[:, atom_token[src]]
                touched.append(atom_token[dst].item())
            # ... then the rows, reading from the column-swapped copy
            snapshot = candidate.clone()
            for dst, src in swap:
                candidate[atom_token[dst], :] = snapshot[atom_token[src], :]

            cols = torch.from_numpy(np.asarray(touched)).to(target.device).long()

            pred_cols = pred_distogram[:, cols]
            mask_cols = pair_mask[:, cols]

            current_loss = compute_single_distogram_loss(
                pred_cols, target[:, cols], mask_cols
            )
            swapped_loss = compute_single_distogram_loss(
                pred_cols, candidate[:, cols], mask_cols
            )
            gain = (current_loss - swapped_loss) * len(cols)

            if gain > top_gain:
                chosen_swap = swap
                chosen_target = candidate
                top_gain = gain

        # greedily update best coordinates after each ligand
        if top_gain > 0:
            target = chosen_target.clone()
            previous = coords.clone()
            for dst, src in chosen_swap:
                coords[:, dst] = previous[:, src]

    # update features to be used in diffusion and in distogram loss
    feats["disto_target"][index_batch] = target
    feats["coords"][index_batch] = coords
|
450
|
+
|
451
|
+
|
452
|
+
def compute_all_coords_mask(structure):
    """Flatten atom coordinates and per-atom crop / resolved flags.

    As a side effect, ``chain.start_idx`` is set to the flat index of the
    chain's first atom, and ``token.start_idx`` to the offset of the token's
    first atom relative to its chain.

    Parameters
    ----------
    structure
        Object exposing ``chains``; each chain exposes ``tokens``; each token
        exposes ``atoms`` (whose ``coords`` have ``x``, ``y``, ``z``),
        ``in_crop`` and ``is_present``.

    Returns
    -------
    tuple[list, list, list]
        The ``[x, y, z]`` coordinates per atom, the token's ``in_crop`` flag
        repeated per atom, and the token's ``is_present`` flag repeated per
        atom. All three lists have one entry per atom.

    """
    total_count = 0
    all_coords = []
    all_coords_crop_mask = []
    all_resolved_mask = []
    for chain in structure.chains:
        chain.start_idx = total_count
        for token in chain.tokens:
            token.start_idx = total_count - chain.start_idx
            n_atoms = len(token.atoms)
            all_coords.extend(
                [[atom.coords.x, atom.coords.y, atom.coords.z] for atom in token.atoms]
            )
            # The flags live on the token, so they are repeated for each atom.
            all_coords_crop_mask.extend([token.in_crop] * n_atoms)
            all_resolved_mask.extend([token.is_present] * n_atoms)
            total_count += n_atoms
    return all_coords, all_coords_crop_mask, all_resolved_mask
|
475
|
+
|
476
|
+
|
477
|
+
def get_chain_symmetries(cropped, max_n_symmetries=100):
    """Build the chain-level symmetry features for a cropped structure.

    Parameters
    ----------
    cropped
        Object exposing ``structure`` (with ``chains``, ``atoms`` and
        ``bonds``) and ``tokens``.
    max_n_symmetries : int
        Upper bound on the number of chain-swap combinations returned. At most
        ``10 * max_n_symmetries`` combinations are enumerated before filtering.

    Returns
    -------
    dict
        ``all_coords``: coordinates of all chains, concatenated.
        ``all_resolved_mask``: the ``is_present`` flags, concatenated.
        ``crop_to_all_atom_map``: for each cropped atom, its index in the
        concatenated arrays.
        ``chain_symmetries``: groups of ``(chain index, start, end, in_crop,
        mol_type)`` for chains with equal ``entity_id`` and atom count; groups
        without any chain in the crop are dropped.
        ``connections_edge_index``: int64 ``(2, n)`` array of bonds that are
        not within a single residue, mapped through the crop index map.
        ``chain_swaps``: the sampled swap combinations, each a sequence of
        ``(start, end, start2, end2, i, j)`` tuples.

    """
    structure = cropped.structure

    # A chain counts as "in crop" when at least one cropped token carries its
    # asym_id; collecting the ids once avoids rescanning the tokens per chain.
    cropped_asym_ids = {token["asym_id"] for token in cropped.tokens}

    # get all coordinates and resolved mask
    all_coords = []
    all_resolved_mask = []
    original_atom_idx = []
    chain_atom_idx = []
    chain_atom_num = []
    chain_in_crop = []
    chain_asym_id = []
    new_atom_idx = 0

    for chain in structure.chains:
        atom_idx = chain["atom_idx"]  # Global index of first atom in the chain
        atom_num = chain["atom_num"]  # Number of atoms in the chain

        # Whether each atom in the chain is actually resolved
        resolved_mask = structure.atoms["is_present"][atom_idx : atom_idx + atom_num]
        coords = structure.atoms["coords"][atom_idx : atom_idx + atom_num]

        all_coords.append(coords)
        all_resolved_mask.append(resolved_mask)
        original_atom_idx.append(atom_idx)
        chain_atom_idx.append(new_atom_idx)
        chain_atom_num.append(atom_num)
        chain_in_crop.append(chain["asym_id"] in cropped_asym_ids)
        chain_asym_id.append(chain["asym_id"])

        new_atom_idx += atom_num

    all_coords = np.concatenate(all_coords, axis=0)

    # Compute backmapping from token to all coords
    crop_to_all_atom_map = []
    for token in cropped.tokens:
        chain_idx = chain_asym_id.index(token["asym_id"])
        # Translate the token's global atom index into the concatenated layout.
        start = (
            chain_atom_idx[chain_idx] - original_atom_idx[chain_idx] + token["atom_idx"]
        )
        crop_to_all_atom_map.append(np.arange(start, start + token["atom_num"]))
    crop_to_all_atom_map = np.concatenate(crop_to_all_atom_map, axis=0)

    # Compute the connections edge index for covalent bonds
    all_atom_to_crop_map = np.zeros(all_coords.shape[0], dtype=np.int64)
    all_atom_to_crop_map[crop_to_all_atom_map.astype(np.int64)] = np.arange(
        crop_to_all_atom_map.shape[0]
    )
    connections_edge_index = []
    for connection in structure.bonds:
        # Bonds inside a single residue are skipped.
        if (connection["chain_1"] == connection["chain_2"]) and (
            connection["res_1"] == connection["res_2"]
        ):
            continue
        connections_edge_index.append([connection["atom_1"], connection["atom_2"]])
    if len(connections_edge_index) > 0:
        connections_edge_index = np.array(connections_edge_index, dtype=np.int64).T
        connections_edge_index = all_atom_to_crop_map[connections_edge_index]
    else:
        # Same integer dtype as the non-empty branch; np.empty would otherwise
        # default to float64 and yield a float tensor.
        connections_edge_index = np.empty((2, 0), dtype=np.int64)

    # Compute the symmetries between chains
    symmetries = []
    swaps = []
    for i, chain in enumerate(structure.chains):
        start = chain_atom_idx[i]
        end = start + chain_atom_num[i]

        if chain_in_crop[i]:
            # Every chain with the same entity and atom count (including the
            # chain itself) is a candidate source for this chain.
            possible_swaps = []
            for j, chain2 in enumerate(structure.chains):
                start2 = chain_atom_idx[j]
                end2 = start2 + chain_atom_num[j]
                if (
                    chain["entity_id"] == chain2["entity_id"]
                    and end - start == end2 - start2
                ):
                    possible_swaps.append((start, end, start2, end2, i, j))
            swaps.append(possible_swaps)

        # Add the chain to the group whose first member matches it, or open a
        # new group.
        found = False
        for symmetry_idx, symmetry in enumerate(symmetries):
            j = symmetry[0][0]
            chain2 = structure.chains[j]
            start2 = chain_atom_idx[j]
            end2 = start2 + chain_atom_num[j]
            if (
                chain["entity_id"] == chain2["entity_id"]
                and end - start == end2 - start2
            ):
                symmetries[symmetry_idx].append(
                    (i, start, end, chain_in_crop[i], chain["mol_type"])
                )
                found = True
        if not found:
            symmetries.append([(i, start, end, chain_in_crop[i], chain["mol_type"])])

    combinations = itertools.product(*swaps)
    # to avoid combinatorial explosion, bound the number of combinations even considered
    combinations = list(itertools.islice(combinations, max_n_symmetries * 10))
    # filter for all chains getting a different assignment
    combinations = [c for c in combinations if all_different_after_swap(c)]

    if len(combinations) > max_n_symmetries:
        combinations = random.sample(combinations, max_n_symmetries)

    if len(combinations) == 0:
        combinations.append([])

    # Keep only the groups that contain at least one chain in the crop.
    symmetries = [
        group for group in symmetries if any(member[3] for member in group)
    ]

    features = {}
    features["all_coords"] = torch.Tensor(all_coords)  # axis=1 with ensemble

    features["all_resolved_mask"] = torch.Tensor(
        np.concatenate(all_resolved_mask, axis=0)
    )
    features["crop_to_all_atom_map"] = torch.Tensor(crop_to_all_atom_map)
    features["chain_symmetries"] = symmetries
    features["connections_edge_index"] = torch.tensor(connections_edge_index)
    features["chain_swaps"] = combinations

    return features
|
616
|
+
|
617
|
+
|
618
|
+
def get_amino_acids_symmetries(cropped):
    """Compute the atom swaps of the standard residues in the crop.

    For each cropped token, the reference symmetries of its residue type are
    looked up in ``const.ref_symmetries`` and shifted by the token's atom
    offset inside the crop. Tokens without reference symmetries add nothing.

    Parameters
    ----------
    cropped
        Object exposing ``tokens``; each token provides ``res_type`` and
        ``atom_num``.

    Returns
    -------
    dict
        ``{"amino_acids_symmetries": swaps}`` where ``swaps`` holds, per
        symmetric token, a list of symmetries, each a list of index pairs.

    """
    swaps = []
    offset = 0
    for token in cropped.tokens:
        residue_name = const.tokens[token["res_type"]]
        ref_syms = const.ref_symmetries.get(residue_name, [])
        if len(ref_syms) > 0:
            swaps.append(
                [[(a + offset, b + offset) for a, b in ref_sym] for ref_sym in ref_syms]
            )
        offset += token["atom_num"]

    return {"amino_acids_symmetries": swaps}
|
636
|
+
|
637
|
+
|
638
|
+
def slice_valid_index(index, ccd_to_valid_id_array, args=None):
    """Remap an index array and drop the columns that contain NaN.

    Parameters
    ----------
    index
        Integer index array; its entries are looked up in
        ``ccd_to_valid_id_array``.
    ccd_to_valid_id_array
        Lookup array; NaN marks an entry without a valid id.
    args : iterable, optional
        Arrays aligned with the columns of ``index``; each is filtered with
        the same column mask.

    Returns
    -------
    The remapped ``index`` restricted to the columns without NaN. When
    ``args`` is given, a pair ``(index, filtered_args)`` is returned instead,
    where ``filtered_args`` is a tuple in the order of ``args``.

    """
    index = ccd_to_valid_id_array[index]
    # A column is kept only if none of its rows is NaN.
    valid_index_mask = (~np.isnan(index)).all(axis=0)
    index = index[:, valid_index_mask]
    if args is None:
        return index
    # Materialize eagerly: a generator could be consumed only once and would
    # defer indexing errors to the caller.
    args = tuple(arg[valid_index_mask] for arg in args)
    return index, args
|
646
|
+
|
647
|
+
|
648
|
+
def get_ligand_symmetries(cropped, symmetries, return_physical_metrics=False):
    """Collect symmetry swaps for molecules not covered by ``const.ref_symmetries``.

    Tokens of the crop are grouped into molecules keyed by
    ``(asym_id, res_idx)``. For every molecule whose residue name is absent
    from ``const.ref_symmetries`` but present in ``symmetries``, the
    reference permutations are re-expressed in the atom order of the
    structure residue and shifted by the molecule's first atom position in
    the crop.

    Args:
        cropped: object exposing ``tokens`` and ``structure``; the structure
            must provide ``chains``, ``residues`` and ``atoms``.
        symmetries: mapping from residue name to a 16-item sequence unpacked
            as (permutations, reference atom names, edge index, lower
            bounds, upper bounds, bond mask, angle mask, chiral atom index,
            chiral check mask, chiral atom orientations, stereo bond index,
            stereo check mask, stereo bond orientations, aromatic 5-ring
            index, aromatic 6-ring index, planar double-bond index). The
            reference atom names must support ``.index``.
        return_physical_metrics: when true, the per-molecule geometry arrays
            are also remapped with ``slice_valid_index``, shifted, and
            returned as tensors.

    Returns:
        A dict that always contains ``"ligand_symmetries"``: one entry per
        molecule that has at least one non-identity permutation, each entry
        being a list of permutations expressed as ``(i, j)`` atom pairs.
        With ``return_physical_metrics`` the dict additionally holds the
        ``ligand_*`` tensors. ``"ligand_planar_double_bond_index"`` is
        currently always an empty ``(6, 0)`` tensor (see the TODO below).

    Raises:
        Exception: if a remapped permutation does not have as many entries
            as atoms were counted for that molecule in the crop.
    """
    structure = cropped.structure

    # Pass 1: group tokens into molecules and remember where each one starts.
    added_molecules = {}
    pending_molecules = []
    atoms_seen = 0
    for token in cropped.tokens:
        token_atoms = token["atom_num"]
        atoms_seen += token_atoms
        # A molecule is identified through its asym_id and res_idx.
        mol_id = (token["asym_id"], token["res_idx"])
        if mol_id in added_molecules:
            added_molecules[mol_id] += token_atoms
            continue
        added_molecules[mol_id] = token_atoms

        # Resolve the residue record once, then read its name and atom names.
        residue = structure.residues[
            token["res_idx"] + structure.chains[token["asym_id"]]["res_idx"]
        ]
        residue_name = residue["name"]
        first_atom = residue["atom_idx"]
        residue_atom_names = structure.atoms[
            first_atom : first_atom + residue["atom_num"]
        ]["name"]
        if residue_name not in const.ref_symmetries:
            pending_molecules.append(
                (residue_name, atoms_seen - token_atoms, mol_id, residue_atom_names)
            )

    # Pass 2: per-molecule symmetries and, optionally, geometry constraints.
    molecule_symmetries = []
    all_edge_index = []
    all_lower_bounds = []
    all_upper_bounds = []
    all_bond_mask = []
    all_angle_mask = []
    all_chiral_atom_index = []
    all_chiral_check_mask = []
    all_chiral_atom_orientations = []
    all_stereo_bond_index = []
    all_stereo_check_mask = []
    all_stereo_bond_orientations = []
    all_aromatic_5_ring_index = []
    all_aromatic_6_ring_index = []
    all_planar_double_bond_index = []

    for mol_name, start_mol, mol_id, mol_atom_names in pending_molecules:
        if mol_name not in symmetries:
            continue

        (
            reference_syms,
            reference_atom_names,
            edge_index,
            lower_bounds,
            upper_bounds,
            bond_mask,
            angle_mask,
            chiral_atom_index,
            chiral_check_mask,
            chiral_atom_orientations,
            stereo_bond_index,
            stereo_check_mask,
            stereo_bond_orientations,
            aromatic_5_ring_index,
            aromatic_6_ring_index,
            planar_double_bond_index,
        ) = symmetries[mol_name]

        # Map positions in the reference atom list to positions in the
        # structure residue; reference atoms without a match become NaN.
        reference_to_local = {
            reference_atom_names.index(name): local_idx
            for local_idx, name in enumerate(mol_atom_names)
        }
        ccd_to_valid_id_array = np.array(
            [
                reference_to_local.get(ref_idx, float("nan"))
                for ref_idx in range(len(reference_atom_names))
            ]
        )
        present = set(reference_to_local)

        # Keep only permutations that map present atoms onto present atoms.
        local_syms = []
        for reference_sym in reference_syms:
            remapped = {}
            for source, target in enumerate(reference_sym):
                if source not in present:
                    continue
                if target not in present:
                    break
                remapped[reference_to_local[source]] = reference_to_local[target]
            else:
                local_syms.append([remapped[k] for k in range(len(present))])

        swaps = []
        for sym in local_syms:
            if len(sym) != added_molecules[mol_id]:
                raise Exception(
                    f"Symmetry length mismatch {len(sym)} {added_molecules[mol_id]}"
                )
            # Record only the atoms that actually move, in crop coordinates.
            moved_pairs = [
                (i + start_mol, int(j) + start_mol)
                for i, j in enumerate(sym)
                if i != int(j)
            ]
            if moved_pairs:
                swaps.append(moved_pairs)

        if swaps:
            molecule_symmetries.append(swaps)

        if not return_physical_metrics:
            continue

        edge_index, edge_extras = slice_valid_index(
            edge_index,
            ccd_to_valid_id_array,
            (lower_bounds, upper_bounds, bond_mask, angle_mask),
        )
        lower_bounds, upper_bounds, bond_mask, angle_mask = edge_extras
        all_edge_index.append(edge_index + start_mol)
        all_lower_bounds.append(lower_bounds)
        all_upper_bounds.append(upper_bounds)
        all_bond_mask.append(bond_mask)
        all_angle_mask.append(angle_mask)

        chiral_atom_index, chiral_extras = slice_valid_index(
            chiral_atom_index,
            ccd_to_valid_id_array,
            (chiral_check_mask, chiral_atom_orientations),
        )
        chiral_check_mask, chiral_atom_orientations = chiral_extras
        all_chiral_atom_index.append(chiral_atom_index + start_mol)
        all_chiral_check_mask.append(chiral_check_mask)
        all_chiral_atom_orientations.append(chiral_atom_orientations)

        stereo_bond_index, stereo_extras = slice_valid_index(
            stereo_bond_index,
            ccd_to_valid_id_array,
            (stereo_check_mask, stereo_bond_orientations),
        )
        stereo_check_mask, stereo_bond_orientations = stereo_extras
        all_stereo_bond_index.append(stereo_bond_index + start_mol)
        all_stereo_check_mask.append(stereo_check_mask)
        all_stereo_bond_orientations.append(stereo_bond_orientations)

        aromatic_5_ring_index = slice_valid_index(
            aromatic_5_ring_index, ccd_to_valid_id_array
        )
        aromatic_6_ring_index = slice_valid_index(
            aromatic_6_ring_index, ccd_to_valid_id_array
        )
        planar_double_bond_index = slice_valid_index(
            planar_double_bond_index, ccd_to_valid_id_array
        )
        all_aromatic_5_ring_index.append(aromatic_5_ring_index + start_mol)
        all_aromatic_6_ring_index.append(aromatic_6_ring_index + start_mol)
        all_planar_double_bond_index.append(planar_double_bond_index + start_mol)

    if not return_physical_metrics:
        return {"ligand_symmetries": molecule_symmetries}

    if len(all_edge_index) > 0:
        all_edge_index = np.concatenate(all_edge_index, axis=1)
        all_lower_bounds = np.concatenate(all_lower_bounds, axis=0)
        all_upper_bounds = np.concatenate(all_upper_bounds, axis=0)
        all_bond_mask = np.concatenate(all_bond_mask, axis=0)
        all_angle_mask = np.concatenate(all_angle_mask, axis=0)

        all_chiral_atom_index = np.concatenate(all_chiral_atom_index, axis=1)
        all_chiral_check_mask = np.concatenate(all_chiral_check_mask, axis=0)
        all_chiral_atom_orientations = np.concatenate(
            all_chiral_atom_orientations, axis=0
        )

        all_stereo_bond_index = np.concatenate(all_stereo_bond_index, axis=1)
        all_stereo_check_mask = np.concatenate(all_stereo_check_mask, axis=0)
        all_stereo_bond_orientations = np.concatenate(
            all_stereo_bond_orientations, axis=0
        )

        all_aromatic_5_ring_index = np.concatenate(all_aromatic_5_ring_index, axis=1)
        all_aromatic_6_ring_index = np.concatenate(all_aromatic_6_ring_index, axis=1)
        # TODO remove np.concatenate(all_planar_double_bond_index, axis=1)
        # The collected per-molecule values are deliberately not used yet.
        all_planar_double_bond_index = np.empty((6, 0), dtype=np.int64)
    else:
        # No molecule contributed: emit empty arrays of the expected shapes.
        all_edge_index = np.empty((2, 0), dtype=np.int64)
        all_lower_bounds = np.array([], dtype=np.float32)
        all_upper_bounds = np.array([], dtype=np.float32)
        all_bond_mask = np.array([], dtype=bool)
        all_angle_mask = np.array([], dtype=bool)

        all_chiral_atom_index = np.empty((4, 0), dtype=np.int64)
        all_chiral_check_mask = np.array([], dtype=bool)
        all_chiral_atom_orientations = np.array([], dtype=bool)

        all_stereo_bond_index = np.empty((4, 0), dtype=np.int64)
        all_stereo_check_mask = np.array([], dtype=bool)
        all_stereo_bond_orientations = np.array([], dtype=bool)

        all_aromatic_5_ring_index = np.empty((5, 0), dtype=np.int64)
        all_aromatic_6_ring_index = np.empty((6, 0), dtype=np.int64)
        all_planar_double_bond_index = np.empty((6, 0), dtype=np.int64)

    # Keys are inserted in a fixed order; index arrays are cast to int64.
    features = {"ligand_symmetries": molecule_symmetries}
    features["ligand_edge_index"] = torch.tensor(all_edge_index).long()
    features["ligand_edge_lower_bounds"] = torch.tensor(all_lower_bounds)
    features["ligand_edge_upper_bounds"] = torch.tensor(all_upper_bounds)
    features["ligand_edge_bond_mask"] = torch.tensor(all_bond_mask)
    features["ligand_edge_angle_mask"] = torch.tensor(all_angle_mask)
    features["ligand_chiral_atom_index"] = torch.tensor(all_chiral_atom_index).long()
    features["ligand_chiral_check_mask"] = torch.tensor(all_chiral_check_mask)
    features["ligand_chiral_atom_orientations"] = torch.tensor(
        all_chiral_atom_orientations
    )
    features["ligand_stereo_bond_index"] = torch.tensor(all_stereo_bond_index).long()
    features["ligand_stereo_check_mask"] = torch.tensor(all_stereo_check_mask)
    features["ligand_stereo_bond_orientations"] = torch.tensor(
        all_stereo_bond_orientations
    )
    features["ligand_aromatic_5_ring_index"] = torch.tensor(
        all_aromatic_5_ring_index
    ).long()
    features["ligand_aromatic_6_ring_index"] = torch.tensor(
        all_aromatic_6_ring_index
    ).long()
    features["ligand_planar_double_bond_index"] = torch.tensor(
        all_planar_double_bond_index
    ).long()
    return features
|