boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2208 @@
+ import math
+ from typing import Optional
+
+ import numba
+ import numpy as np
+ import numpy.typing as npt
+ import rdkit.Chem.Descriptors
+ import torch
+ from numba import types
+ from rdkit.Chem import Mol
+ from scipy.spatial.distance import cdist
+ from torch import Tensor, from_numpy
+ from torch.nn.functional import one_hot
+
+ from boltz.data import const
+ from boltz.data.mol import (
+     get_amino_acids_symmetries,
+     get_chain_symmetries,
+     get_ligand_symmetries,
+     get_symmetries,
+ )
+ from boltz.data.pad import pad_dim
+ from boltz.data.types import (
+     MSA,
+     MSADeletion,
+     MSAResidue,
+     MSASequence,
+     TemplateInfo,
+     Tokenized,
+ )
+ from boltz.model.modules.utils import center_random_augmentation
+
+ ####################################################################################################
+ # HELPERS
+ ####################################################################################################
+
+
+ def convert_atom_name(name: str) -> tuple[int, int, int, int]:
+     """Convert an atom name to a standard format.
+
+     Parameters
+     ----------
+     name : str
+         The atom name.
+
+     Returns
+     -------
+     tuple[int, int, int, int]
+         The converted atom name.
+
+     """
+     name = str(name).strip()
+     name = [ord(c) - 32 for c in name]
+     name = name + [0] * (4 - len(name))
+     return tuple(name)
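+
+ # For example, convert_atom_name("CA") -> (35, 33, 0, 0): each character is
+ # shifted by ord(" ") = 32 and the result is zero-padded to four slots, which
+ # keeps every value within the 64-class one-hot used for ref_atom_name_chars.
+
+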
+ def sample_d(
+     min_d: float,
+     max_d: float,
+     n_samples: int,
+     random: np.random.Generator,
+ ) -> np.ndarray:
+     """Generate samples from a 1/d distribution between min_d and max_d.
+
+     Parameters
+     ----------
+     min_d : float
+         Minimum value of d.
+     max_d : float
+         Maximum value of d.
+     n_samples : int
+         Number of samples to generate.
+     random : numpy.random.Generator
+         Random number generator.
+
+     Returns
+     -------
+     numpy.ndarray
+         Array of samples drawn from the distribution.
+
+     Notes
+     -----
+     The probability density function is:
+         f(d) = 1 / (d * ln(max_d / min_d))  for d in [min_d, max_d]
+
+     The inverse CDF transform is:
+         d = min_d * (max_d / min_d) ** u  where u ~ Uniform(0, 1)
+
+     """
+     # Generate n_samples uniform random numbers in [0, 1]
+     u = random.random(n_samples)
+     # Transform u using the inverse CDF
+     return min_d * (max_d / min_d) ** u
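+
+ # Sanity check for the transform above: with u ~ Uniform(0, 1), the median
+ # sample is min_d * (max_d / min_d) ** 0.5 = sqrt(min_d * max_d), e.g. about
+ # 8.9 Å for the default (4, 20) binder-pocket cutoff range used below.
+
+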
+ def compute_frames_nonpolymer(
+     data: Tokenized,
+     coords,
+     resolved_mask,
+     atom_to_token,
+     frame_data: list,
+     resolved_frame_data: list,
+ ) -> tuple[list, list]:
+     """Get the frames for non-polymer tokens.
+
+     Parameters
+     ----------
+     data : Tokenized
+         The input data to the model.
+     coords
+         The atom coordinates.
+     resolved_mask
+         The per-atom resolved mask.
+     atom_to_token
+         The atom-to-token index map.
+     frame_data : list
+         The frame data.
+     resolved_frame_data : list
+         The resolved frame data.
+
+     Returns
+     -------
+     tuple[list, list]
+         The frame data and resolved frame data.
+
+     """
+     frame_data = np.array(frame_data)
+     resolved_frame_data = np.array(resolved_frame_data)
+     asym_id_token = data.tokens["asym_id"]
+     asym_id_atom = data.tokens["asym_id"][atom_to_token]
+     token_idx = 0
+     atom_idx = 0
+     for chain_id in np.unique(data.tokens["asym_id"]):
+         mask_chain_token = asym_id_token == chain_id
+         mask_chain_atom = asym_id_atom == chain_id
+         num_tokens = mask_chain_token.sum()
+         num_atoms = mask_chain_atom.sum()
+         if (
+             data.tokens[token_idx]["mol_type"] != const.chain_type_ids["NONPOLYMER"]
+             or num_atoms < 3  # noqa: PLR2004
+         ):
+             token_idx += num_tokens
+             atom_idx += num_atoms
+             continue
+         dist_mat = (
+             (
+                 coords.reshape(-1, 3)[mask_chain_atom][:, None, :]
+                 - coords.reshape(-1, 3)[mask_chain_atom][None, :, :]
+             )
+             ** 2
+         ).sum(-1) ** 0.5
+         resolved_pair = 1 - (
+             resolved_mask[mask_chain_atom][None, :]
+             * resolved_mask[mask_chain_atom][:, None]
+         ).astype(np.float32)
+         resolved_pair[resolved_pair == 1] = math.inf
+         indices = np.argsort(dist_mat + resolved_pair, axis=1)
+         # Column 0 of the argsort is the atom itself (distance zero), so each
+         # frame is (nearest resolved neighbor, atom, second-nearest neighbor).
+         frames = (
+             np.concatenate(
+                 [
+                     indices[:, 1:2],
+                     indices[:, 0:1],
+                     indices[:, 2:3],
+                 ],
+                 axis=1,
+             )
+             + atom_idx
+         )
+         # For non-polymer chains each atom is its own token, so num_atoms
+         # token slots are written here.
+         frame_data[token_idx : token_idx + num_atoms, :] = frames
+         resolved_frame_data[token_idx : token_idx + num_atoms] = resolved_mask[
+             frames
+         ].all(axis=1)
+         token_idx += num_tokens
+         atom_idx += num_atoms
+     frames_expanded = coords.reshape(-1, 3)[frame_data]
+
+     mask_collinear = compute_collinear_mask(
+         frames_expanded[:, 1] - frames_expanded[:, 0],
+         frames_expanded[:, 1] - frames_expanded[:, 2],
+     )
+     return frame_data, resolved_frame_data & mask_collinear
+
+
+ def compute_collinear_mask(v1, v2):
+     """Mask out frames whose axis vectors are nearly collinear or degenerate."""
+     norm1 = np.linalg.norm(v1, axis=1, keepdims=True)
+     norm2 = np.linalg.norm(v2, axis=1, keepdims=True)
+     v1 = v1 / (norm1 + 1e-6)
+     v2 = v2 / (norm2 + 1e-6)
+     # 0.9063 is approximately cos(25 degrees): reject frames whose two axis
+     # vectors are within ~25 degrees of parallel or antiparallel.
+     mask_angle = np.abs(np.sum(v1 * v2, axis=1)) < 0.9063
+     # Also reject frames whose atoms nearly overlap (axis shorter than 0.01).
+     mask_overlap1 = norm1.reshape(-1) > 1e-2
+     mask_overlap2 = norm2.reshape(-1) > 1e-2
+     return mask_angle & mask_overlap1 & mask_overlap2
+
+
+ def dummy_msa(residues: np.ndarray) -> MSA:
+     """Create a dummy MSA for a chain.
+
+     Parameters
+     ----------
+     residues : np.ndarray
+         The residues for the chain.
+
+     Returns
+     -------
+     MSA
+         The dummy MSA.
+
+     """
+     residues = [res["res_type"] for res in residues]
+     deletions = []
+     # A single placeholder row spanning the whole chain; the tuple fields map
+     # onto the MSASequence dtype as used elsewhere in this module (seq_idx=0,
+     # taxonomy=-1, res_start=0, res_end=len(residues), empty del_start/del_end).
+     sequences = [(0, -1, 0, len(residues), 0, 0)]
+     return MSA(
+         residues=np.array(residues, dtype=MSAResidue),
+         deletions=np.array(deletions, dtype=MSADeletion),
+         sequences=np.array(sequences, dtype=MSASequence),
+     )
+
+
+ def construct_paired_msa(  # noqa: C901, PLR0915, PLR0912
+     data: Tokenized,
+     random: np.random.Generator,
+     max_seqs: int,
+     max_pairs: int = 8192,
+     max_total: int = 16384,
+     random_subset: bool = False,
+ ) -> tuple[Tensor, Tensor, Tensor]:
+     """Pair the MSA data.
+
+     Parameters
+     ----------
+     data : Tokenized
+         The input data to the model.
+     random : np.random.Generator
+         The random number generator.
+     max_seqs : int
+         The maximum number of rows to keep.
+     max_pairs : int
+         The maximum number of paired rows.
+     max_total : int
+         The maximum number of rows in total.
+     random_subset : bool
+         Whether to sample a random subset of rows instead of truncating.
+
+     Returns
+     -------
+     Tensor
+         The MSA data.
+     Tensor
+         The deletion data.
+     Tensor
+         Mask indicating paired sequences.
+
+     """
+     # Get unique chains (ensuring monotonicity in the order)
+     assert np.all(np.diff(data.tokens["asym_id"], n=1) >= 0)
+     chain_ids = np.unique(data.tokens["asym_id"])
+
+     # Get the relevant MSA, and create a dummy for chains without one
+     msa: dict[int, MSA] = {}
+     for chain_id in chain_ids:
+         # Get input sequence
+         chain = data.structure.chains[chain_id]
+         res_start = chain["res_idx"]
+         res_end = res_start + chain["res_num"]
+         residues = data.structure.residues[res_start:res_end]
+
+         # Check if we have an MSA, and that the
+         # first sequence matches the input sequence
+         if chain_id in data.msa:
+             # Set the MSA
+             msa[chain_id] = data.msa[chain_id]
+
+             # Run length and residue type checks
+             first = data.msa[chain_id].sequences[0]
+             first_start = first["res_start"]
+             first_end = first["res_end"]
+             msa_residues = data.msa[chain_id].residues
+             first_residues = msa_residues[first_start:first_end]
+
+             warning = "Warning: MSA does not match input sequence, creating dummy."
+             if len(residues) == len(first_residues):
+                 # If there is a mismatch, check if it is between MET & UNK.
+                 # If so, replace the first sequence with the input sequence.
+                 # Otherwise, replace with a dummy MSA for this chain.
+                 mismatches = residues["res_type"] != first_residues["res_type"]
+                 if mismatches.sum().item():
+                     idx = np.where(mismatches)[0]
+                     is_met = residues["res_type"][idx] == const.token_ids["MET"]
+                     is_unk = residues["res_type"][idx] == const.token_ids["UNK"]
+                     is_msa_unk = (
+                         first_residues["res_type"][idx] == const.token_ids["UNK"]
+                     )
+                     if (np.all(is_met) and np.all(is_msa_unk)) or np.all(is_unk):
+                         msa_residues[first_start:first_end]["res_type"] = residues[
+                             "res_type"
+                         ]
+                     else:
+                         print(
+                             warning,
+                             "1",
+                             residues["res_type"],
+                             first_residues["res_type"],
+                             data.record.id,
+                         )
+                         msa[chain_id] = dummy_msa(residues)
+             else:
+                 print(
+                     warning,
+                     "2",
+                     residues["res_type"],
+                     first_residues["res_type"],
+                     data.record.id,
+                 )
+                 msa[chain_id] = dummy_msa(residues)
+         else:
+             msa[chain_id] = dummy_msa(residues)
+
+     # Map taxonomies to (chain_id, seq_idx)
+     taxonomy_map: dict[str, list] = {}
+     for chain_id, chain_msa in msa.items():
+         sequences = chain_msa.sequences
+         sequences = sequences[sequences["taxonomy"] != -1]
+         for sequence in sequences:
+             seq_idx = sequence["seq_idx"]
+             taxon = sequence["taxonomy"]
+             taxonomy_map.setdefault(taxon, []).append((chain_id, seq_idx))
+
+     # Remove taxonomies with only one sequence and sort by the
+     # number of chain_ids present in each of the taxonomies
+     taxonomy_map = {k: v for k, v in taxonomy_map.items() if len(v) > 1}
+     taxonomy_map = sorted(
+         taxonomy_map.items(),
+         key=lambda x: len({c for c, _ in x[1]}),
+         reverse=True,
+     )
+
+     # Keep track of the sequences available per chain, keeping the original
+     # order of the sequences in the MSA to favor the best matching sequences.
+     # visited holds the (chain_id, seq_idx) pairs consumed by the pairing.
+     visited = {(c, s) for _, items in taxonomy_map for (c, s) in items}
+     available = {}
+     for c in chain_ids:
+         available[c] = [
+             i for i in range(1, len(msa[c].sequences)) if (c, i) not in visited
+         ]
+
+     # Create sequence pairs
+     is_paired = []
+     pairing = []
+
+     # Start with the first sequence for each chain
+     is_paired.append({c: 1 for c in chain_ids})
+     pairing.append({c: 0 for c in chain_ids})
+
+     # Then add paired rows, up to max_pairs rows in total
+     for _, pairs in taxonomy_map:
+         # Group occurrences by chain_id in case we have multiple
+         # sequences from the same chain and same taxonomy
+         chain_occurrences = {}
+         for chain_id, seq_idx in pairs:
+             chain_occurrences.setdefault(chain_id, []).append(seq_idx)
+
+         # We create as many pairings as the maximum number of occurrences
+         max_occurrences = max(len(v) for v in chain_occurrences.values())
+         for i in range(max_occurrences):
+             row_pairing = {}
+             row_is_paired = {}
+
+             # Add the chains present in the taxonomy
+             for chain_id, seq_idxs in chain_occurrences.items():
+                 # Roll over the sequence index to maximize diversity
+                 idx = i % len(seq_idxs)
+                 seq_idx = seq_idxs[idx]
+
+                 # Add the sequence to the pairing
+                 row_pairing[chain_id] = seq_idx
+                 row_is_paired[chain_id] = 1
+
+             # Add any missing chains
+             for chain_id in chain_ids:
+                 if chain_id not in row_pairing:
+                     row_is_paired[chain_id] = 0
+                     if available[chain_id]:
+                         # Add the next available sequence
+                         seq_idx = available[chain_id].pop(0)
+                         row_pairing[chain_id] = seq_idx
+                     else:
+                         # No more sequences available, we place a gap
+                         row_pairing[chain_id] = -1
+
+             pairing.append(row_pairing)
+             is_paired.append(row_is_paired)
+
+             # Break if we have enough pairs
+             if len(pairing) >= max_pairs:
+                 break
+
+         # Break if we have enough pairs
+         if len(pairing) >= max_pairs:
+             break
+
+     # Now add unpaired rows, up to max_total rows in total
+     max_left = max(len(v) for v in available.values())
+     for _ in range(min(max_total - len(pairing), max_left)):
+         row_pairing = {}
+         row_is_paired = {}
+         for chain_id in chain_ids:
+             row_is_paired[chain_id] = 0
+             if available[chain_id]:
+                 # Add the next available sequence
+                 seq_idx = available[chain_id].pop(0)
+                 row_pairing[chain_id] = seq_idx
+             else:
+                 # No more sequences available, we place a gap
+                 row_pairing[chain_id] = -1
+
+         pairing.append(row_pairing)
+         is_paired.append(row_is_paired)
+
+         # Break if we have enough sequences
+         if len(pairing) >= max_total:
+             break
+
+     # Randomly sample a subset of the pairs,
+     # ensuring the first row is always present
+     if random_subset:
+         num_seqs = len(pairing)
+         if num_seqs > max_seqs:
+             indices = random.choice(
+                 np.arange(1, num_seqs), size=max_seqs - 1, replace=False
+             )  # noqa: NPY002
+             pairing = [pairing[0]] + [pairing[i] for i in indices]
+             is_paired = [is_paired[0]] + [is_paired[i] for i in indices]
+     else:
+         # Deterministic downsample to max_seqs
+         pairing = pairing[:max_seqs]
+         is_paired = is_paired[:max_seqs]
+
+     # Map (chain_id, seq_idx, res_idx) to deletion
+     deletions = {}
+     for chain_id, chain_msa in msa.items():
+         for sequence in chain_msa.sequences:
+             del_start = sequence["del_start"]
+             del_end = sequence["del_end"]
+             chain_deletions = chain_msa.deletions[del_start:del_end]
+             for deletion_data in chain_deletions:
+                 seq_idx = sequence["seq_idx"]
+                 res_idx = deletion_data["res_idx"]
+                 deletion = deletion_data["deletion"]
+                 deletions[(chain_id, seq_idx, res_idx)] = deletion
+
+     # Add all the token MSA data
+     msa_data, del_data, paired_data = prepare_msa_arrays(
+         data.tokens, pairing, is_paired, deletions, msa
+     )
+
+     msa_data = torch.tensor(msa_data, dtype=torch.long)
+     del_data = torch.tensor(del_data, dtype=torch.float)
+     paired_data = torch.tensor(paired_data, dtype=torch.float)
+
+     return msa_data, del_data, paired_data
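+
+ # Minimal illustration of the pairing above (hypothetical indices): suppose
+ # chain 0 has MSA rows {1: taxonomy 9606, 2: taxonomy 10090} and chain 1 has
+ # {1: taxonomy 9606}. Row 0 always pairs the query sequences (0, 0); taxonomy
+ # 9606 yields the paired row (1, 1); chain 0's leftover row 2 is emitted as
+ # the unpaired row (2, -1), where -1 becomes a gap token downstream.
+
+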
+ def prepare_msa_arrays(
+     tokens,
+     pairing: list[dict[int, int]],
+     is_paired: list[dict[int, int]],
+     deletions: dict[tuple[int, int, int], int],
+     msa: dict[int, MSA],
+ ) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
+     """Reshape data to play nicely with numba jit."""
+     token_asym_ids_arr = np.array([t["asym_id"] for t in tokens], dtype=np.int64)
+     token_res_idxs_arr = np.array([t["res_idx"] for t in tokens], dtype=np.int64)
+
+     chain_ids = sorted(msa.keys())
+
+     # chain_ids are not necessarily contiguous (e.g. they might be 0, 24, 25).
+     # This allows us to look up a chain_id by its index in the chain_ids list.
+     chain_id_to_idx = {chain_id: i for i, chain_id in enumerate(chain_ids)}
+     token_asym_ids_idx_arr = np.array(
+         [chain_id_to_idx[asym_id] for asym_id in token_asym_ids_arr], dtype=np.int64
+     )
+
+     pairing_arr = np.zeros((len(pairing), len(chain_ids)), dtype=np.int64)
+     is_paired_arr = np.zeros((len(is_paired), len(chain_ids)), dtype=np.int64)
+
+     for i, row_pairing in enumerate(pairing):
+         for chain_id in chain_ids:
+             pairing_arr[i, chain_id_to_idx[chain_id]] = row_pairing[chain_id]
+
+     for i, row_is_paired in enumerate(is_paired):
+         for chain_id in chain_ids:
+             is_paired_arr[i, chain_id_to_idx[chain_id]] = row_is_paired[chain_id]
+
+     max_seq_len = max(len(msa[chain_id].sequences) for chain_id in chain_ids)
+
+     # We want res_start from sequences
+     msa_sequences = np.full((len(chain_ids), max_seq_len), -1, dtype=np.int64)
+     for chain_id in chain_ids:
+         for i, seq in enumerate(msa[chain_id].sequences):
+             msa_sequences[chain_id_to_idx[chain_id], i] = seq["res_start"]
+
+     max_residues_len = max(len(msa[chain_id].residues) for chain_id in chain_ids)
+     msa_residues = np.full((len(chain_ids), max_residues_len), -1, dtype=np.int64)
+     for chain_id in chain_ids:
+         residues = msa[chain_id].residues.astype(np.int64)
+         idxs = np.arange(len(residues))
+         chain_idx = chain_id_to_idx[chain_id]
+         msa_residues[chain_idx, idxs] = residues
+
+     deletions_dict = numba.typed.Dict.empty(
+         key_type=numba.types.Tuple(
+             [numba.types.int64, numba.types.int64, numba.types.int64]
+         ),
+         value_type=numba.types.int64,
+     )
+     deletions_dict.update(deletions)
+
+     return _prepare_msa_arrays_inner(
+         token_asym_ids_arr,
+         token_res_idxs_arr,
+         token_asym_ids_idx_arr,
+         pairing_arr,
+         is_paired_arr,
+         deletions_dict,
+         msa_sequences,
+         msa_residues,
+         const.token_ids["-"],
+     )
+
+
+ deletions_dict_type = types.DictType(types.UniTuple(types.int64, 3), types.int64)
+
+
+ @numba.njit(
+     [
+         types.Tuple(
+             (
+                 types.int64[:, ::1],  # msa_data
+                 types.int64[:, ::1],  # del_data
+                 types.int64[:, ::1],  # paired_data
+             )
+         )(
+             types.int64[::1],  # token_asym_ids
+             types.int64[::1],  # token_res_idxs
+             types.int64[::1],  # token_asym_ids_idx
+             types.int64[:, ::1],  # pairing
+             types.int64[:, ::1],  # is_paired
+             deletions_dict_type,  # deletions
+             types.int64[:, ::1],  # msa_sequences
+             types.int64[:, ::1],  # msa_residues
+             types.int64,  # gap_token
+         )
+     ],
+     cache=True,
+ )
+ def _prepare_msa_arrays_inner(
+     token_asym_ids: npt.NDArray[np.int64],
+     token_res_idxs: npt.NDArray[np.int64],
+     token_asym_ids_idx: npt.NDArray[np.int64],
+     pairing: npt.NDArray[np.int64],
+     is_paired: npt.NDArray[np.int64],
+     deletions: dict[tuple[int, int, int], int],
+     msa_sequences: npt.NDArray[np.int64],
+     msa_residues: npt.NDArray[np.int64],
+     gap_token: int,
+ ) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
+     n_tokens = len(token_asym_ids)
+     n_pairs = len(pairing)
+     msa_data = np.full((n_tokens, n_pairs), gap_token, dtype=np.int64)
+     paired_data = np.zeros((n_tokens, n_pairs), dtype=np.int64)
+     del_data = np.zeros((n_tokens, n_pairs), dtype=np.int64)
+
+     # Add all the token MSA data
+     for token_idx in range(n_tokens):
+         chain_id_idx = token_asym_ids_idx[token_idx]
+         chain_id = token_asym_ids[token_idx]
+         res_idx = token_res_idxs[token_idx]
+
+         for pair_idx in range(n_pairs):
+             seq_idx = pairing[pair_idx, chain_id_idx]
+             paired_data[token_idx, pair_idx] = is_paired[pair_idx, chain_id_idx]
+
+             # Add residue type
+             if seq_idx != -1:
+                 res_start = msa_sequences[chain_id_idx, seq_idx]
+                 res_type = msa_residues[chain_id_idx, res_start + res_idx]
+                 k = (chain_id, seq_idx, res_idx)
+                 if k in deletions:
+                     del_data[token_idx, pair_idx] = deletions[k]
+                 msa_data[token_idx, pair_idx] = res_type
+
+     return msa_data, del_data, paired_data
+ ####################################################################################################
+ # FEATURES
+ ####################################################################################################
+
+
+ def select_subset_from_mask(mask, p, random: np.random.Generator) -> np.ndarray:
+     """Select a random subset of the true entries in a mask."""
+     num_true = np.sum(mask)
+     v = random.geometric(p) + 1
+     k = min(v, num_true)
+
+     true_indices = np.where(mask)[0]
+
+     # Randomly select k indices from the true_indices
+     selected_indices = random.choice(true_indices, size=k, replace=False)
+
+     new_mask = np.zeros_like(mask)
+     new_mask[selected_indices] = 1
+
+     return new_mask
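+
+ # Note: numpy's Generator.geometric has support {1, 2, ...}, so v above is at
+ # least 2 and the selected subset keeps at least two tokens whenever the mask
+ # contains that many.
+
+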
+ def get_range_bin(value: float, range_dict: dict[tuple[float, float], int], default=0):
+     """Get the bin of a value given a range dictionary."""
+     value = float(value)
+     for k, idx in range_dict.items():
+         if k == "other":
+             continue
+         low, high = k
+         if low <= value < high:
+             return idx
+     return default
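+
+ # For example, with a hypothetical range_dict = {(0.0, 2.0): 0, (2.0, 4.0): 1,
+ # "other": 2}, get_range_bin(3.1, range_dict) returns 1; values outside every
+ # range fall through to the `default` argument rather than the "other" entry.
+
+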
+ def process_token_features(  # noqa: C901, PLR0915, PLR0912
+     data: Tokenized,
+     random: np.random.Generator,
+     max_tokens: Optional[int] = None,
+     binder_pocket_conditioned_prop: Optional[float] = 0.0,
+     contact_conditioned_prop: Optional[float] = 0.0,
+     binder_pocket_cutoff_min: Optional[float] = 4.0,
+     binder_pocket_cutoff_max: Optional[float] = 20.0,
+     binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
+     only_ligand_binder_pocket: Optional[bool] = False,
+     only_pp_contact: Optional[bool] = False,
+     inference_pocket_constraints: Optional[list] = None,
+     override_method: Optional[str] = None,
+ ) -> dict[str, Tensor]:
+     """Get the token features.
+
+     Parameters
+     ----------
+     data : Tokenized
+         The input data to the model.
+     random : np.random.Generator
+         The random number generator.
+     max_tokens : int
+         The maximum number of tokens.
+
+     Returns
+     -------
+     dict[str, Tensor]
+         The token features.
+
+     """
+     # Token data
+     token_data = data.tokens
+     token_bonds = data.bonds
+
+     # Token core features
+     token_index = torch.arange(len(token_data), dtype=torch.long)
+     residue_index = from_numpy(token_data["res_idx"]).long()
+     asym_id = from_numpy(token_data["asym_id"]).long()
+     entity_id = from_numpy(token_data["entity_id"]).long()
+     sym_id = from_numpy(token_data["sym_id"]).long()
+     mol_type = from_numpy(token_data["mol_type"]).long()
+     res_type = from_numpy(token_data["res_type"]).long()
+     res_type = one_hot(res_type, num_classes=const.num_tokens)
+     disto_center = from_numpy(token_data["disto_coords"])
+     modified = from_numpy(token_data["modified"]).long()
+     cyclic_period = from_numpy(token_data["cyclic_period"].copy())
+     affinity_mask = from_numpy(token_data["affinity_mask"]).float()
+
+     ## Conditioning features ##
+     method = (
+         np.zeros(len(token_data))
+         + const.method_types_ids[
+             (
+                 "x-ray diffraction"
+                 if override_method is None
+                 else override_method.lower()
+             )
+         ]
+     )
+     if data.record is not None:
+         if (
+             override_method is None
+             and data.record.structure.method is not None
+             and data.record.structure.method.lower() in const.method_types_ids
+         ):
+             method = (method * 0) + const.method_types_ids[
+                 data.record.structure.method.lower()
+             ]
+
+     method_feature = from_numpy(method).long()
+
+     # Token mask features
+     pad_mask = torch.ones(len(token_data), dtype=torch.float)
+     resolved_mask = from_numpy(token_data["resolved_mask"]).float()
+     disto_mask = from_numpy(token_data["disto_mask"]).float()
+
+     # Token bond features
+     if max_tokens is not None:
+         pad_len = max_tokens - len(token_data)
+         num_tokens = max_tokens if pad_len > 0 else len(token_data)
+     else:
+         num_tokens = len(token_data)
+
+     tok_to_idx = {tok["token_idx"]: idx for idx, tok in enumerate(token_data)}
+     bonds = torch.zeros(num_tokens, num_tokens, dtype=torch.float)
+     bonds_type = torch.zeros(num_tokens, num_tokens, dtype=torch.long)
+     for token_bond in token_bonds:
+         token_1 = tok_to_idx[token_bond["token_1"]]
+         token_2 = tok_to_idx[token_bond["token_2"]]
+         bonds[token_1, token_2] = 1
+         bonds[token_2, token_1] = 1
+         bond_type = token_bond["type"]
+         bonds_type[token_1, token_2] = bond_type
+         bonds_type[token_2, token_1] = bond_type
+
+     bonds = bonds.unsqueeze(-1)
+
+     # Pocket conditioning feature
+     contact_conditioning = (
+         np.zeros((len(token_data), len(token_data)))
+         + const.contact_conditioning_info["UNSELECTED"]
+     )
+     contact_threshold = np.zeros((len(token_data), len(token_data)))
+
+     if inference_pocket_constraints is not None:
+         for binder, contacts, max_distance in inference_pocket_constraints:
+             binder_mask = token_data["asym_id"] == binder
+
+             for idx, token in enumerate(token_data):
+                 if (
+                     token["mol_type"] != const.chain_type_ids["NONPOLYMER"]
+                     and (token["asym_id"], token["res_idx"]) in contacts
+                 ) or (
+                     token["mol_type"] == const.chain_type_ids["NONPOLYMER"]
+                     and (token["asym_id"], token["atom_idx"]) in contacts
+                 ):
+                     # Use direct (mask, index) assignment: chaining a boolean
+                     # mask lookup before the column index would write into a
+                     # copy and silently drop the update.
+                     contact_conditioning[binder_mask, idx] = (
+                         const.contact_conditioning_info["BINDER>POCKET"]
+                     )
+                     contact_conditioning[idx, binder_mask] = (
+                         const.contact_conditioning_info["POCKET>BINDER"]
+                     )
+                     contact_threshold[binder_mask, idx] = max_distance
+                     contact_threshold[idx, binder_mask] = max_distance
+
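+     # Example: an inference constraint (binder=0, contacts={(1, 42)},
+     # max_distance=6.0) marks chain 0 tokens as BINDER>POCKET toward residue
+     # 42 of chain 1 (and POCKET>BINDER in the transpose) with a 6.0 Å cutoff.
+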
+     if binder_pocket_conditioned_prop > 0.0:
+         # Choose a random ligand in the crop as the binder; if there are
+         # no ligands, fall back to a protein chain.
+         binder_asym_ids = np.unique(
+             token_data["asym_id"][
+                 token_data["mol_type"] == const.chain_type_ids["NONPOLYMER"]
+             ]
+         )
+
+         if len(binder_asym_ids) == 0:
+             if not only_ligand_binder_pocket:
+                 binder_asym_ids = np.unique(token_data["asym_id"])
+
+         while random.random() < binder_pocket_conditioned_prop:
+             if len(binder_asym_ids) == 0:
+                 break
+
+             pocket_asym_id = random.choice(binder_asym_ids)
+             binder_asym_ids = binder_asym_ids[binder_asym_ids != pocket_asym_id]
+
+             binder_pocket_cutoff = sample_d(
+                 min_d=binder_pocket_cutoff_min,
+                 max_d=binder_pocket_cutoff_max,
+                 n_samples=1,
+                 random=random,
+             )
+
+             binder_mask = token_data["asym_id"] == pocket_asym_id
+
+             binder_coords = []
+             for token in token_data:
+                 if token["asym_id"] == pocket_asym_id:
+                     _coords = data.structure.atoms["coords"][
+                         token["atom_idx"] : token["atom_idx"] + token["atom_num"]
+                     ]
+                     _is_present = data.structure.atoms["is_present"][
+                         token["atom_idx"] : token["atom_idx"] + token["atom_num"]
+                     ]
+                     binder_coords.append(_coords[_is_present])
+             binder_coords = np.concatenate(binder_coords, axis=0)
+
+             # Find the tokens in the pocket
+             token_dist = np.zeros(len(token_data)) + 1000
+             for i, token in enumerate(token_data):
+                 if (
+                     token["mol_type"] != const.chain_type_ids["NONPOLYMER"]
+                     and token["asym_id"] != pocket_asym_id
+                     and token["resolved_mask"] == 1
+                 ):
+                     token_coords = data.structure.atoms["coords"][
+                         token["atom_idx"] : token["atom_idx"] + token["atom_num"]
+                     ]
+                     token_is_present = data.structure.atoms["is_present"][
+                         token["atom_idx"] : token["atom_idx"] + token["atom_num"]
+                     ]
+                     token_coords = token_coords[token_is_present]
+
+                     # Find the chain this token belongs to
+                     for chain in data.structure.chains:
+                         if chain["asym_id"] == token["asym_id"]:
+                             break
+
+                     token_dist[i] = np.min(
+                         np.linalg.norm(
+                             token_coords[:, None, :] - binder_coords[None, :, :],
+                             axis=-1,
+                         )
+                     )
+
+             pocket_mask = token_dist < binder_pocket_cutoff
+
+             if np.sum(pocket_mask) > 0:
+                 if binder_pocket_sampling_geometric_p > 0.0:
+                     # Select a subset of the pocket, according to a
+                     # geometric distribution with one as minimum
+                     pocket_mask = select_subset_from_mask(
+                         pocket_mask,
+                         binder_pocket_sampling_geometric_p,
+                         random,
+                     )
+
+                 contact_conditioning[np.ix_(binder_mask, pocket_mask)] = (
+                     const.contact_conditioning_info["BINDER>POCKET"]
+                 )
+                 contact_conditioning[np.ix_(pocket_mask, binder_mask)] = (
+                     const.contact_conditioning_info["POCKET>BINDER"]
+                 )
+                 contact_threshold[np.ix_(binder_mask, pocket_mask)] = (
+                     binder_pocket_cutoff
+                 )
+                 contact_threshold[np.ix_(pocket_mask, binder_mask)] = (
+                     binder_pocket_cutoff
+                 )
+
+     # Contact conditioning feature
+     if contact_conditioned_prop > 0.0:
+         while random.random() < contact_conditioned_prop:
+             contact_cutoff = sample_d(
+                 min_d=binder_pocket_cutoff_min,
+                 max_d=binder_pocket_cutoff_max,
+                 n_samples=1,
+                 random=random,
+             )
+             if only_pp_contact:
+                 chain_asym_ids = np.unique(
+                     token_data["asym_id"][
+                         token_data["mol_type"] == const.chain_type_ids["PROTEIN"]
+                     ]
+                 )
+             else:
+                 chain_asym_ids = np.unique(token_data["asym_id"])
+
+             if len(chain_asym_ids) > 1:
+                 chain_asym_id = random.choice(chain_asym_ids)
+
+                 chain_coords = []
+                 for token in token_data:
+                     if token["asym_id"] == chain_asym_id:
+                         _coords = data.structure.atoms["coords"][
+                             token["atom_idx"] : token["atom_idx"] + token["atom_num"]
+                         ]
+                         _is_present = data.structure.atoms["is_present"][
+                             token["atom_idx"] : token["atom_idx"] + token["atom_num"]
+                         ]
+                         chain_coords.append(_coords[_is_present])
+                 chain_coords = np.concatenate(chain_coords, axis=0)
+
+                 # Find contacts in other chains
+                 possible_other_chains = []
+                 for other_chain_id in chain_asym_ids[chain_asym_ids != chain_asym_id]:
+                     for token in token_data:
+                         if token["asym_id"] == other_chain_id:
+                             _coords = data.structure.atoms["coords"][
+                                 token["atom_idx"] : token["atom_idx"]
+                                 + token["atom_num"]
+                             ]
+                             _is_present = data.structure.atoms["is_present"][
+                                 token["atom_idx"] : token["atom_idx"]
+                                 + token["atom_num"]
+                             ]
+                             if _is_present.sum() == 0:
+                                 continue
+                             token_coords = _coords[_is_present]
+
+                             # Check minimum distance
+                             if (
+                                 np.min(cdist(chain_coords, token_coords))
+                                 < contact_cutoff
+                             ):
+                                 possible_other_chains.append(other_chain_id)
+                                 break
+
+                 if len(possible_other_chains) > 0:
+                     other_chain_id = random.choice(possible_other_chains)
+
+                     pairs = []
+                     for token_1 in token_data:
+                         if token_1["asym_id"] == chain_asym_id:
+                             _coords = data.structure.atoms["coords"][
+                                 token_1["atom_idx"] : token_1["atom_idx"]
+                                 + token_1["atom_num"]
+                             ]
+                             _is_present = data.structure.atoms["is_present"][
+                                 token_1["atom_idx"] : token_1["atom_idx"]
+                                 + token_1["atom_num"]
+                             ]
+                             if _is_present.sum() == 0:
+                                 continue
+                             token_1_coords = _coords[_is_present]
+
+                             for token_2 in token_data:
+                                 if token_2["asym_id"] == other_chain_id:
+                                     _coords = data.structure.atoms["coords"][
+                                         token_2["atom_idx"] : token_2["atom_idx"]
+                                         + token_2["atom_num"]
+                                     ]
+                                     _is_present = data.structure.atoms["is_present"][
+                                         token_2["atom_idx"] : token_2["atom_idx"]
+                                         + token_2["atom_num"]
+                                     ]
+                                     if _is_present.sum() == 0:
+                                         continue
+                                     token_2_coords = _coords[_is_present]
+
+                                     if (
+                                         np.min(cdist(token_1_coords, token_2_coords))
+                                         < contact_cutoff
+                                     ):
+                                         pairs.append(
+                                             (token_1["token_idx"], token_2["token_idx"])
+                                         )
+
+                     assert len(pairs) > 0
+
+                     pair = random.choice(pairs)
+                     token_1_mask = token_data["token_idx"] == pair[0]
+                     token_2_mask = token_data["token_idx"] == pair[1]
+
+                     contact_conditioning[np.ix_(token_1_mask, token_2_mask)] = (
+                         const.contact_conditioning_info["CONTACT"]
+                     )
+                     contact_conditioning[np.ix_(token_2_mask, token_1_mask)] = (
+                         const.contact_conditioning_info["CONTACT"]
+                     )
+
+             elif not only_pp_contact:
+                 # Only one chain: find contacts within the chain, with a
+                 # minimum residue separation of more than 8
+                 pairs = []
+                 for token_1 in token_data:
+                     _coords = data.structure.atoms["coords"][
+                         token_1["atom_idx"] : token_1["atom_idx"] + token_1["atom_num"]
+                     ]
+                     _is_present = data.structure.atoms["is_present"][
+                         token_1["atom_idx"] : token_1["atom_idx"] + token_1["atom_num"]
+                     ]
+                     if _is_present.sum() == 0:
+                         continue
+                     token_1_coords = _coords[_is_present]
+
+                     for token_2 in token_data:
+                         if np.abs(token_1["res_idx"] - token_2["res_idx"]) <= 8:
+                             continue
+
+                         _coords = data.structure.atoms["coords"][
+                             token_2["atom_idx"] : token_2["atom_idx"]
+                             + token_2["atom_num"]
+                         ]
+                         _is_present = data.structure.atoms["is_present"][
+                             token_2["atom_idx"] : token_2["atom_idx"]
+                             + token_2["atom_num"]
+                         ]
+                         if _is_present.sum() == 0:
+                             continue
+                         token_2_coords = _coords[_is_present]
+
+                         if (
+                             np.min(cdist(token_1_coords, token_2_coords))
+                             < contact_cutoff
+                         ):
+                             pairs.append((token_1["token_idx"], token_2["token_idx"]))
+
+                 if len(pairs) > 0:
+                     pair = random.choice(pairs)
+                     token_1_mask = token_data["token_idx"] == pair[0]
+                     token_2_mask = token_data["token_idx"] == pair[1]
+
+                     contact_conditioning[np.ix_(token_1_mask, token_2_mask)] = (
+                         const.contact_conditioning_info["CONTACT"]
+                     )
+                     contact_conditioning[np.ix_(token_2_mask, token_1_mask)] = (
+                         const.contact_conditioning_info["CONTACT"]
+                     )
+
+     if np.all(contact_conditioning == const.contact_conditioning_info["UNSELECTED"]):
+         contact_conditioning = (
+             contact_conditioning
+             - const.contact_conditioning_info["UNSELECTED"]
+             + const.contact_conditioning_info["UNSPECIFIED"]
+         )
+     contact_conditioning = from_numpy(contact_conditioning).long()
+     contact_conditioning = one_hot(
+         contact_conditioning, num_classes=len(const.contact_conditioning_info)
+     )
+     contact_threshold = from_numpy(contact_threshold).float()
+
+     # Compute the cyclic polymer mask
+     cyclic_ids = {}
+     for idx_chain, asym_id_iter in enumerate(data.structure.chains["asym_id"]):
+         for connection in data.structure.bonds:
+             if (
+                 idx_chain == connection["chain_1"] == connection["chain_2"]
+                 and data.structure.chains[connection["chain_1"]]["res_num"] > 2
+                 and connection["res_1"]
+                 != connection["res_2"]  # Avoid same-residue bonds!
+             ):
+                 if (
+                     data.structure.chains[connection["chain_1"]]["res_num"]
+                     == (connection["res_2"] + 1)
+                     and connection["res_1"] == 0
+                 ) or (
+                     data.structure.chains[connection["chain_1"]]["res_num"]
+                     == (connection["res_1"] + 1)
+                     and connection["res_2"] == 0
+                 ):
+                     cyclic_ids[asym_id_iter] = data.structure.chains[
+                         connection["chain_1"]
+                     ]["res_num"]
+     cyclic = from_numpy(
+         np.array(
+             [
+                 (cyclic_ids[asym_id_iter] if asym_id_iter in cyclic_ids else 0)
+                 for asym_id_iter in token_data["asym_id"]
+             ]
+         )
+     ).float()
+
+     # The cyclic period is either computed from the bonds or given as an input flag
+     cyclic_period = torch.maximum(cyclic, cyclic_period)
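+
+     # Illustration: a head-to-tail bond between residue 0 and residue N-1 of
+     # an N-residue chain marks it as cyclic, so all of its tokens receive
+     # cyclic_period = N (e.g. 12 for a cyclic 12-mer peptide).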
+
+     # Pad to max tokens if given
+     if max_tokens is not None:
+         pad_len = max_tokens - len(token_data)
+         if pad_len > 0:
+             token_index = pad_dim(token_index, 0, pad_len)
+             residue_index = pad_dim(residue_index, 0, pad_len)
+             asym_id = pad_dim(asym_id, 0, pad_len)
+             entity_id = pad_dim(entity_id, 0, pad_len)
+             sym_id = pad_dim(sym_id, 0, pad_len)
+             mol_type = pad_dim(mol_type, 0, pad_len)
+             res_type = pad_dim(res_type, 0, pad_len)
+             disto_center = pad_dim(disto_center, 0, pad_len)
+             pad_mask = pad_dim(pad_mask, 0, pad_len)
+             resolved_mask = pad_dim(resolved_mask, 0, pad_len)
+             disto_mask = pad_dim(disto_mask, 0, pad_len)
+             contact_conditioning = pad_dim(contact_conditioning, 0, pad_len)
+             contact_conditioning = pad_dim(contact_conditioning, 1, pad_len)
+             contact_threshold = pad_dim(contact_threshold, 0, pad_len)
+             contact_threshold = pad_dim(contact_threshold, 1, pad_len)
+             method_feature = pad_dim(method_feature, 0, pad_len)
+             modified = pad_dim(modified, 0, pad_len)
+             cyclic_period = pad_dim(cyclic_period, 0, pad_len)
+             affinity_mask = pad_dim(affinity_mask, 0, pad_len)
+
+     token_features = {
+         "token_index": token_index,
+         "residue_index": residue_index,
+         "asym_id": asym_id,
+         "entity_id": entity_id,
+         "sym_id": sym_id,
+         "mol_type": mol_type,
+         "res_type": res_type,
+         "disto_center": disto_center,
+         "token_bonds": bonds,
+         "type_bonds": bonds_type,
+         "token_pad_mask": pad_mask,
+         "token_resolved_mask": resolved_mask,
+         "token_disto_mask": disto_mask,
+         "contact_conditioning": contact_conditioning,
+         "contact_threshold": contact_threshold,
+         "method_feature": method_feature,
+         "modified": modified,
+         "cyclic_period": cyclic_period,
+         "affinity_token_mask": affinity_mask,
+     }
+
+     return token_features
+
+
+ def process_atom_features(
+     data: Tokenized,
+     random: np.random.Generator,
+     ensemble_features: dict,
+     molecules: dict[str, Mol],
+     atoms_per_window_queries: int = 32,
+     min_dist: float = 2.0,
+     max_dist: float = 22.0,
+     num_bins: int = 64,
+     max_atoms: Optional[int] = None,
+     max_tokens: Optional[int] = None,
+     disto_use_ensemble: Optional[bool] = False,
+     override_bfactor: bool = False,
+     compute_frames: bool = False,
+     override_coords: Optional[Tensor] = None,
+     bfactor_md_correction: bool = False,
+ ) -> dict[str, Tensor]:
+     """Get the atom features.
+
+     Parameters
+     ----------
+     data : Tokenized
+         The input to the model.
+     max_atoms : int, optional
+         The maximum number of atoms.
+
+     Returns
+     -------
+     dict[str, Tensor]
+         The atom features.
+
+     """
+     # Filter to the tokens' atoms
+     atom_data = []
+     atom_name = []
+     atom_element = []
+     atom_charge = []
+     atom_conformer = []
+     atom_chirality = []
+     ref_space_uid = []
+     coord_data = []
+     if compute_frames:
+         frame_data = []
+         resolved_frame_data = []
+     atom_to_token = []
+     token_to_rep_atom = []  # index on cropped atom table
+     r_set_to_rep_atom = []
+     disto_coords_ensemble = []
+     backbone_feat_index = []
+     token_to_center_atom = []
+
+     e_offsets = data.structure.ensemble["atom_coord_idx"]
+     atom_idx = 0
+
+     # Starting atom index in the full atom table for each chosen structure,
+     # up to num_ensembles entries
+     ensemble_atom_starts = [
+         data.structure.ensemble[idx]["atom_coord_idx"]
+         for idx in ensemble_features["ensemble_ref_idxs"]
+     ]
+
+     # Set the unknown chirality id
+     unk_chirality = const.chirality_type_ids[const.unk_chirality_type]
+
+     chain_res_ids = {}
+     res_index_to_conf_id = {}
+     for token_id, token in enumerate(data.tokens):
+         # Get the chain residue ids
+         chain_idx, res_id = token["asym_id"], token["res_idx"]
+         chain = data.structure.chains[chain_idx]
+
+         if (chain_idx, res_id) not in chain_res_ids:
+             new_idx = len(chain_res_ids)
+             chain_res_ids[(chain_idx, res_id)] = new_idx
+         else:
+             new_idx = chain_res_ids[(chain_idx, res_id)]
+
+         # Get the molecule and conformer
+         mol = molecules[token["res_name"]]
+         atom_name_to_ref = {a.GetProp("name"): a for a in mol.GetAtoms()}
+
+         # Sample a random conformer
+         if (chain_idx, res_id) not in res_index_to_conf_id:
+             conf_ids = [int(conf.GetId()) for conf in mol.GetConformers()]
+             conf_id = int(random.choice(conf_ids))
+             res_index_to_conf_id[(chain_idx, res_id)] = conf_id
+
+         conf_id = res_index_to_conf_id[(chain_idx, res_id)]
+         conformer = mol.GetConformer(conf_id)
+
+         # Map atoms to token indices
+         ref_space_uid.extend([new_idx] * token["atom_num"])
+         atom_to_token.extend([token_id] * token["atom_num"])
+
+         # Add atom data
+         start = token["atom_idx"]
+         end = token["atom_idx"] + token["atom_num"]
+         token_atoms = data.structure.atoms[start:end]
+
+         # Add atom reference data:
+         # element, charge, conformer, chirality
+         token_atom_name = np.array([convert_atom_name(a["name"]) for a in token_atoms])
+         token_atoms_ref = np.array([atom_name_to_ref[a["name"]] for a in token_atoms])
+         token_atoms_element = np.array([a.GetAtomicNum() for a in token_atoms_ref])
+         token_atoms_charge = np.array([a.GetFormalCharge() for a in token_atoms_ref])
+         token_atoms_conformer = np.array(
+             [
+                 (
+                     conformer.GetAtomPosition(a.GetIdx()).x,
+                     conformer.GetAtomPosition(a.GetIdx()).y,
+                     conformer.GetAtomPosition(a.GetIdx()).z,
+                 )
+                 for a in token_atoms_ref
+             ]
+         )
+         token_atoms_chirality = np.array(
+             [
+                 const.chirality_type_ids.get(a.GetChiralTag().name, unk_chirality)
+                 for a in token_atoms_ref
+             ]
+         )
+
+         # Map token to representative atom
+         token_to_rep_atom.append(atom_idx + token["disto_idx"] - start)
+         token_to_center_atom.append(atom_idx + token["center_idx"] - start)
+         if (chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]) and token[
+             "resolved_mask"
+         ]:
+             r_set_to_rep_atom.append(atom_idx + token["center_idx"] - start)
+
+         if chain["mol_type"] == const.chain_type_ids["PROTEIN"]:
+             backbone_index = [
+                 (
+                     const.protein_backbone_atom_index[atom_name] + 1
+                     if atom_name in const.protein_backbone_atom_index
+                     else 0
+                 )
+                 for atom_name in token_atoms["name"]
+             ]
+         elif (
+             chain["mol_type"] == const.chain_type_ids["DNA"]
+             or chain["mol_type"] == const.chain_type_ids["RNA"]
+         ):
+             backbone_index = [
+                 (
+                     const.nucleic_backbone_atom_index[atom_name]
+                     + 1
+                     + len(const.protein_backbone_atom_index)
+                     if atom_name in const.nucleic_backbone_atom_index
+                     else 0
+                 )
+                 for atom_name in token_atoms["name"]
+             ]
+         else:
+             backbone_index = [0] * token["atom_num"]
+         backbone_feat_index.extend(backbone_index)
+
+         # Get token coordinates across sampled ensembles and apply transforms
+         token_coords = np.array(
+             [
+                 data.structure.coords[
+                     ensemble_atom_start + start : ensemble_atom_start + end
+                 ]["coords"]
+                 for ensemble_atom_start in ensemble_atom_starts
+             ]
+         )
+         coord_data.append(token_coords)
+
+         if compute_frames:
+             # Get frame data
+             res_type = const.tokens[token["res_type"]]
+             res_name = str(token["res_name"])
+
+             if token["atom_num"] < 3 or res_type in ["PAD", "UNK", "-"]:
+                 idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
+                 mask_frame = False
+             elif (token["mol_type"] == const.chain_type_ids["PROTEIN"]) and (
+                 res_name in const.ref_atoms
+             ):
+                 idx_frame_a, idx_frame_b, idx_frame_c = (
+                     const.ref_atoms[res_name].index("N"),
+                     const.ref_atoms[res_name].index("CA"),
+                     const.ref_atoms[res_name].index("C"),
+                 )
+                 mask_frame = (
+                     token_atoms["is_present"][idx_frame_a]
+                     and token_atoms["is_present"][idx_frame_b]
+                     and token_atoms["is_present"][idx_frame_c]
+                 )
+             elif (
+                 token["mol_type"] == const.chain_type_ids["DNA"]
+                 or token["mol_type"] == const.chain_type_ids["RNA"]
+             ) and (res_name in const.ref_atoms):
+                 idx_frame_a, idx_frame_b, idx_frame_c = (
+                     const.ref_atoms[res_name].index("C1'"),
+                     const.ref_atoms[res_name].index("C3'"),
+                     const.ref_atoms[res_name].index("C4'"),
+                 )
+                 mask_frame = (
+                     token_atoms["is_present"][idx_frame_a]
+                     and token_atoms["is_present"][idx_frame_b]
+                     and token_atoms["is_present"][idx_frame_c]
+                 )
+             elif token["mol_type"] == const.chain_type_ids["PROTEIN"]:
+                 # Try to look for the atom names in the modified residue
+                 is_ca = token_atoms["name"] == "CA"
+                 idx_frame_a = is_ca.argmax()
+                 ca_present = (
+                     token_atoms[idx_frame_a]["is_present"] if is_ca.any() else False
+                 )
+
+                 is_n = token_atoms["name"] == "N"
+                 idx_frame_b = is_n.argmax()
+                 n_present = (
+                     token_atoms[idx_frame_b]["is_present"] if is_n.any() else False
+                 )
+
+                 is_c = token_atoms["name"] == "C"
+                 idx_frame_c = is_c.argmax()
+                 c_present = (
+                     token_atoms[idx_frame_c]["is_present"] if is_c.any() else False
+                 )
+                 mask_frame = ca_present and n_present and c_present
+
+             elif (token["mol_type"] == const.chain_type_ids["DNA"]) or (
+                 token["mol_type"] == const.chain_type_ids["RNA"]
+             ):
+                 # Try to look for the atom names in the modified residue
+                 is_c1 = token_atoms["name"] == "C1'"
+                 idx_frame_a = is_c1.argmax()
+                 c1_present = (
+                     token_atoms[idx_frame_a]["is_present"] if is_c1.any() else False
+                 )
+
+                 is_c3 = token_atoms["name"] == "C3'"
+                 idx_frame_b = is_c3.argmax()
+                 c3_present = (
+                     token_atoms[idx_frame_b]["is_present"] if is_c3.any() else False
+                 )
+
+                 is_c4 = token_atoms["name"] == "C4'"
+                 idx_frame_c = is_c4.argmax()
+                 c4_present = (
+                     token_atoms[idx_frame_c]["is_present"] if is_c4.any() else False
+                 )
+                 mask_frame = c1_present and c3_present and c4_present
+             else:
+                 idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
+                 mask_frame = False
+             frame_data.append(
+                 [
+                     idx_frame_a + atom_idx,
+                     idx_frame_b + atom_idx,
+                     idx_frame_c + atom_idx,
+                 ]
+             )
+             resolved_frame_data.append(mask_frame)
+
+         # Get distogram coordinates
+         disto_coords_ensemble_tok = data.structure.coords[
+             e_offsets + token["disto_idx"]
+         ]["coords"]
+         disto_coords_ensemble.append(disto_coords_ensemble_tok)
+
+         # Update atom data. This is technically never used again (we rely on
+         # coord_data), but we update it for consistency and to make sure the
+         # Atom object has valid, transformed coordinates.
+         token_atoms = token_atoms.copy()
+         # Each atom keeps a copy of the first coordinates in the ensemble
+         token_atoms["coords"] = token_coords[0]
+         atom_data.append(token_atoms)
+         atom_name.append(token_atom_name)
+         atom_element.append(token_atoms_element)
+         atom_charge.append(token_atoms_charge)
+         atom_conformer.append(token_atoms_conformer)
+         atom_chirality.append(token_atoms_chirality)
+         atom_idx += len(token_atoms)
+
+     disto_coords_ensemble = np.array(disto_coords_ensemble)  # (N_TOK, N_ENS, 3)
+
+     # Compute the ensemble distogram
+     L = len(data.tokens)
+
+     if disto_use_ensemble:
+         # Use all available structures to create the distogram
+         idx_list = range(disto_coords_ensemble.shape[1])
+     else:
+         # Only use the sampled structures to create the distogram
+         idx_list = ensemble_features["ensemble_ref_idxs"]
+
+     # Create the distogram
+     disto_target = torch.zeros(L, L, len(idx_list), num_bins)  # TODO1
+
+     # disto_target = torch.zeros(L, L, num_bins)
+     for i, e_idx in enumerate(idx_list):
+         t_center = torch.Tensor(disto_coords_ensemble[:, e_idx, :])
+         t_dists = torch.cdist(t_center, t_center)
+         boundaries = torch.linspace(min_dist, max_dist, num_bins - 1)
+         distogram = (t_dists.unsqueeze(-1) > boundaries).sum(dim=-1).long()
+         # disto_target += one_hot(distogram, num_classes=num_bins)
+         disto_target[:, :, i, :] = one_hot(distogram, num_classes=num_bins)  # TODO1
+
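+     # Binning note: with the defaults (min_dist=2.0, max_dist=22.0,
+     # num_bins=64), the 63 boundaries span 2-22 Å, so distances under 2 Å land
+     # in bin 0 and distances beyond 22 Å in bin 63, one-hot encoded per
+     # ensemble member.
+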
+     # Normalize distogram
+     # disto_target = disto_target / disto_target.sum(-1)[..., None]  # remove TODO1
+     atom_data = np.concatenate(atom_data)
+     atom_name = np.concatenate(atom_name)
+     atom_element = np.concatenate(atom_element)
+     atom_charge = np.concatenate(atom_charge)
+     atom_conformer = np.concatenate(atom_conformer)
+     atom_chirality = np.concatenate(atom_chirality)
+     coord_data = np.concatenate(coord_data, axis=1)
+     ref_space_uid = np.array(ref_space_uid)
+
+     # Compute features
+     disto_coords_ensemble = from_numpy(disto_coords_ensemble)
+     disto_coords_ensemble = disto_coords_ensemble[
+         :, ensemble_features["ensemble_ref_idxs"]
+     ].permute(1, 0, 2)
+     backbone_feat_index = from_numpy(np.asarray(backbone_feat_index)).long()
+     ref_atom_name_chars = from_numpy(atom_name).long()
+     ref_element = from_numpy(atom_element).long()
+     ref_charge = from_numpy(atom_charge).float()
+     ref_pos = from_numpy(atom_conformer).float()
+     ref_space_uid = from_numpy(ref_space_uid)
+     ref_chirality = from_numpy(atom_chirality).long()
+     coords = from_numpy(coord_data.copy())
+     resolved_mask = from_numpy(atom_data["is_present"])
+     pad_mask = torch.ones(len(atom_data), dtype=torch.float)
+     atom_to_token = torch.tensor(atom_to_token, dtype=torch.long)
+     token_to_rep_atom = torch.tensor(token_to_rep_atom, dtype=torch.long)
+     r_set_to_rep_atom = torch.tensor(r_set_to_rep_atom, dtype=torch.long)
+     token_to_center_atom = torch.tensor(token_to_center_atom, dtype=torch.long)
+     bfactor = from_numpy(atom_data["bfactor"].copy())
+     plddt = from_numpy(atom_data["plddt"].copy())
+     if override_bfactor:
+         bfactor = bfactor * 0.0
+
+     if bfactor_md_correction and data.record.structure.method.lower() == "md":
+         # The MD bfactor was computed as an RMSF; convert it to a b-factor
+         bfactor = 8 * (np.pi**2) * (bfactor**2)
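+         # This is the isotropic Debye-Waller relation B = 8 * pi^2 * <u^2>,
+         # with the RMSF standing in for the positional fluctuation u.
+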
+     # We compute frames within the ensemble
+     if compute_frames:
+         frames = []
+         frame_resolved_mask = []
+         for i in range(coord_data.shape[0]):
+             # Compute frames for NONPOLYMER tokens
+             frame_data_, resolved_frame_data_ = compute_frames_nonpolymer(
+                 data,
+                 coord_data[i],
+                 atom_data["is_present"],
+                 atom_to_token,
+                 frame_data,
+                 resolved_frame_data,
+             )
+             frames.append(frame_data_.copy())
+             frame_resolved_mask.append(resolved_frame_data_.copy())
+         frames = from_numpy(np.stack(frames))  # (N_ENS, N_TOK, 3)
+         frame_resolved_mask = from_numpy(np.stack(frame_resolved_mask))
+
+     # Convert to one-hot
+     backbone_feat_index = one_hot(
+         backbone_feat_index,
+         num_classes=1
+         + len(const.protein_backbone_atom_index)
+         + len(const.nucleic_backbone_atom_index),
+     )
+     ref_atom_name_chars = one_hot(ref_atom_name_chars, num_classes=64)
+     ref_element = one_hot(ref_element, num_classes=const.num_elements)
+     atom_to_token = one_hot(atom_to_token, num_classes=token_id + 1)
+     token_to_rep_atom = one_hot(token_to_rep_atom, num_classes=len(atom_data))
+     r_set_to_rep_atom = one_hot(r_set_to_rep_atom, num_classes=len(atom_data))
+     token_to_center_atom = one_hot(token_to_center_atom, num_classes=len(atom_data))
+
+     # Center the ground truth coordinates
+     center = (coords * resolved_mask[None, :, None]).sum(dim=1)
+     center = center / resolved_mask.sum().clamp(min=1)
+     coords = coords - center[:, None]
+
+     if isinstance(override_coords, Tensor):
+         coords = override_coords.unsqueeze(0)
+
+     # Apply a random roto-translation to the input conformers, one conformer
+     # group (ref_space_uid) at a time. The upper bound is inclusive so the
+     # final group is augmented as well.
+     for i in range(int(torch.max(ref_space_uid)) + 1):
+         included = ref_space_uid == i
+         if torch.sum(included) > 0 and torch.any(resolved_mask[included]):
+             ref_pos[included] = center_random_augmentation(
+                 ref_pos[included][None], resolved_mask[included][None], centering=True
+             )[0]
+
+ # Compute padding and apply
1475
+ if max_atoms is not None:
1476
+ assert max_atoms % atoms_per_window_queries == 0
1477
+ pad_len = max_atoms - len(atom_data)
1478
+ else:
1479
+ pad_len = (
1480
+ (len(atom_data) - 1) // atoms_per_window_queries + 1
1481
+ ) * atoms_per_window_queries - len(atom_data)
1482
+
1483
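+     # Editor's note (worked example): with atoms_per_window_queries = 32 and
+     # 70 atoms, pad_len = ((70 - 1) // 32 + 1) * 32 - 70 = 96 - 70 = 26,
+     # rounding the atom count up to the next multiple of the window size.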
+     if pad_len > 0:
+         pad_mask = pad_dim(pad_mask, 0, pad_len)
+         ref_pos = pad_dim(ref_pos, 0, pad_len)
+         resolved_mask = pad_dim(resolved_mask, 0, pad_len)
+         ref_atom_name_chars = pad_dim(ref_atom_name_chars, 0, pad_len)
+         ref_element = pad_dim(ref_element, 0, pad_len)
+         ref_charge = pad_dim(ref_charge, 0, pad_len)
+         ref_chirality = pad_dim(ref_chirality, 0, pad_len)
+         backbone_feat_index = pad_dim(backbone_feat_index, 0, pad_len)
+         ref_space_uid = pad_dim(ref_space_uid, 0, pad_len)
+         coords = pad_dim(coords, 1, pad_len)
+         atom_to_token = pad_dim(atom_to_token, 0, pad_len)
+         token_to_rep_atom = pad_dim(token_to_rep_atom, 1, pad_len)
+         token_to_center_atom = pad_dim(token_to_center_atom, 1, pad_len)
+         r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 1, pad_len)
+         bfactor = pad_dim(bfactor, 0, pad_len)
+         plddt = pad_dim(plddt, 0, pad_len)
+
+     if max_tokens is not None:
+         pad_len = max_tokens - token_to_rep_atom.shape[0]
+         if pad_len > 0:
+             atom_to_token = pad_dim(atom_to_token, 1, pad_len)
+             token_to_rep_atom = pad_dim(token_to_rep_atom, 0, pad_len)
+             r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 0, pad_len)
+             token_to_center_atom = pad_dim(token_to_center_atom, 0, pad_len)
+             disto_target = pad_dim(pad_dim(disto_target, 0, pad_len), 1, pad_len)
+             disto_coords_ensemble = pad_dim(disto_coords_ensemble, 1, pad_len)
+
+             if compute_frames:
+                 frames = pad_dim(frames, 1, pad_len)
+                 frame_resolved_mask = pad_dim(frame_resolved_mask, 1, pad_len)
+
+     atom_features = {
+         "ref_pos": ref_pos,
+         "atom_resolved_mask": resolved_mask,
+         "ref_atom_name_chars": ref_atom_name_chars,
+         "ref_element": ref_element,
+         "ref_charge": ref_charge,
+         "ref_chirality": ref_chirality,
+         "atom_backbone_feat": backbone_feat_index,
+         "ref_space_uid": ref_space_uid,
+         "coords": coords,
+         "atom_pad_mask": pad_mask,
+         "atom_to_token": atom_to_token,
+         "token_to_rep_atom": token_to_rep_atom,
+         "r_set_to_rep_atom": r_set_to_rep_atom,
+         "token_to_center_atom": token_to_center_atom,
+         "disto_target": disto_target,
+         "disto_coords_ensemble": disto_coords_ensemble,
+         "bfactor": bfactor,
+         "plddt": plddt,
+     }
+
+     if compute_frames:
+         atom_features["frames_idx"] = frames
+         atom_features["frame_resolved_mask"] = frame_resolved_mask
+
+     return atom_features
+
+
+ def process_msa_features(
+     data: Tokenized,
+     random: np.random.Generator,
+     max_seqs_batch: int,
+     max_seqs: int,
+     max_tokens: Optional[int] = None,
+     pad_to_max_seqs: bool = False,
+     msa_sampling: bool = False,
+     affinity: bool = False,
+ ) -> dict[str, Tensor]:
+     """Get the MSA features.
+
+     Parameters
+     ----------
+     data : Tokenized
+         The input to the model.
+     random : np.random.Generator
+         The random number generator.
+     max_seqs_batch : int
+         The number of MSA sequences to sample for this batch.
+     max_seqs : int
+         The maximum number of MSA sequences.
+     max_tokens : int, optional
+         The maximum number of tokens.
+     pad_to_max_seqs : bool
+         Whether to pad to the maximum number of sequences.
+     msa_sampling : bool
+         Whether to sample the MSA.
+     affinity : bool
+         Whether to return only the affinity-specific features.
+
+     Returns
+     -------
+     dict[str, Tensor]
+         The MSA features.
+
+     """
+     # Create the paired MSA
+     msa, deletion, paired = construct_paired_msa(
+         data=data,
+         random=random,
+         max_seqs=max_seqs_batch,
+         random_subset=msa_sampling,
+     )
+     msa, deletion, paired = (
+         msa.transpose(1, 0),
+         deletion.transpose(1, 0),
+         paired.transpose(1, 0),
+     )  # (N_MSA, N_RES, N_AA)
+
+     # Prepare features
+     assert torch.all(msa >= 0) and torch.all(msa < const.num_tokens)
+     msa_one_hot = torch.nn.functional.one_hot(msa, num_classes=const.num_tokens)
+     msa_mask = torch.ones_like(msa)
+     profile = msa_one_hot.float().mean(dim=0)
+     has_deletion = deletion > 0
+     deletion = np.pi / 2 * np.arctan(deletion / 3)
+     deletion_mean = deletion.mean(axis=0)
+
+     # Pad in the MSA dimension (dim=0)
+     if pad_to_max_seqs:
+         pad_len = max_seqs - msa.shape[0]
+         if pad_len > 0:
+             msa = pad_dim(msa, 0, pad_len, const.token_ids["-"])
+             paired = pad_dim(paired, 0, pad_len)
+             msa_mask = pad_dim(msa_mask, 0, pad_len)
+             has_deletion = pad_dim(has_deletion, 0, pad_len)
+             deletion = pad_dim(deletion, 0, pad_len)
+
+     # Pad in the token dimension (dim=1)
+     if max_tokens is not None:
+         pad_len = max_tokens - msa.shape[1]
+         if pad_len > 0:
+             msa = pad_dim(msa, 1, pad_len, const.token_ids["-"])
+             paired = pad_dim(paired, 1, pad_len)
+             msa_mask = pad_dim(msa_mask, 1, pad_len)
+             has_deletion = pad_dim(has_deletion, 1, pad_len)
+             deletion = pad_dim(deletion, 1, pad_len)
+             profile = pad_dim(profile, 0, pad_len)
+             deletion_mean = pad_dim(deletion_mean, 0, pad_len)
+
+     if affinity:
+         return {
+             "deletion_mean_affinity": deletion_mean,
+             "profile_affinity": profile,
+         }
+     return {
+         "msa": msa,
+         "msa_paired": paired,
+         "deletion_value": deletion,
+         "has_deletion": has_deletion,
+         "deletion_mean": deletion_mean,
+         "profile": profile,
+         "msa_mask": msa_mask,
+     }
+
+
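+ # Editor's sketch (illustrative, not part of the package): a minimal call
+ # of process_msa_features under assumed inputs; `tokenized` is a Tokenized
+ # sample and the seed is arbitrary.
+ def _example_msa_features(tokenized: Tokenized) -> dict[str, Tensor]:
+     rng = np.random.default_rng(0)
+     feats = process_msa_features(
+         data=tokenized,
+         random=rng,
+         max_seqs_batch=128,
+         max_seqs=256,
+         pad_to_max_seqs=True,
+     )
+     # With pad_to_max_seqs=True the MSA dimension is padded up to max_seqs
+     assert feats["msa"].shape[0] == 256
+     return feats
+
+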
+ def load_dummy_templates_features(tdim: int, num_tokens: int) -> dict:
+     """Load dummy templates for v2."""
+     # Allocate features
+     res_type = np.zeros((tdim, num_tokens), dtype=np.int64)
+     frame_rot = np.zeros((tdim, num_tokens, 3, 3), dtype=np.float32)
+     frame_t = np.zeros((tdim, num_tokens, 3), dtype=np.float32)
+     cb_coords = np.zeros((tdim, num_tokens, 3), dtype=np.float32)
+     ca_coords = np.zeros((tdim, num_tokens, 3), dtype=np.float32)
+     frame_mask = np.zeros((tdim, num_tokens), dtype=np.float32)
+     cb_mask = np.zeros((tdim, num_tokens), dtype=np.float32)
+     template_mask = np.zeros((tdim, num_tokens), dtype=np.float32)
+     query_to_template = np.zeros((tdim, num_tokens), dtype=np.int64)
+     visibility_ids = np.zeros((tdim, num_tokens), dtype=np.float32)
+
+     # Convert to one-hot
+     res_type = torch.from_numpy(res_type)
+     res_type = one_hot(res_type, num_classes=const.num_tokens)
+
+     return {
+         "template_restype": res_type,
+         "template_frame_rot": torch.from_numpy(frame_rot),
+         "template_frame_t": torch.from_numpy(frame_t),
+         "template_cb": torch.from_numpy(cb_coords),
+         "template_ca": torch.from_numpy(ca_coords),
+         "template_mask_cb": torch.from_numpy(cb_mask),
+         "template_mask_frame": torch.from_numpy(frame_mask),
+         "template_mask": torch.from_numpy(template_mask),
+         "query_to_template": torch.from_numpy(query_to_template),
+         "visibility_ids": torch.from_numpy(visibility_ids),
+     }
+
+
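+ # Editor's note (illustration): load_dummy_templates_features(tdim=1,
+ # num_tokens=N) returns all-zero placeholders, e.g. template_restype of
+ # shape (1, N, const.num_tokens) and template_frame_rot of shape
+ # (1, N, 3, 3), so the template stack stays well-formed when a sample has
+ # no templates.
+
+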
+ def compute_template_features(
+     query_tokens: Tokenized,
+     tmpl_tokens: list[dict],
+     num_tokens: int,
+ ) -> dict:
+     """Compute the template features."""
+     # Allocate features
+     res_type = np.zeros((num_tokens,), dtype=np.int64)
+     frame_rot = np.zeros((num_tokens, 3, 3), dtype=np.float32)
+     frame_t = np.zeros((num_tokens, 3), dtype=np.float32)
+     cb_coords = np.zeros((num_tokens, 3), dtype=np.float32)
+     ca_coords = np.zeros((num_tokens, 3), dtype=np.float32)
+     frame_mask = np.zeros((num_tokens,), dtype=np.float32)
+     cb_mask = np.zeros((num_tokens,), dtype=np.float32)
+     template_mask = np.zeros((num_tokens,), dtype=np.float32)
+     query_to_template = np.zeros((num_tokens,), dtype=np.int64)
+     visibility_ids = np.zeros((num_tokens,), dtype=np.float32)
+
+     # Now create features per token
+     asym_id_to_pdb_id = {}
+
+     for token_dict in tmpl_tokens:
+         idx = token_dict["q_idx"]
+         pdb_id = token_dict["pdb_id"]
+         token = token_dict["token"]
+         query_token = query_tokens.tokens[idx]
+         asym_id_to_pdb_id[query_token["asym_id"]] = pdb_id
+         res_type[idx] = token["res_type"]
+         frame_rot[idx] = token["frame_rot"].reshape(3, 3)
+         frame_t[idx] = token["frame_t"]
+         cb_coords[idx] = token["disto_coords"]
+         ca_coords[idx] = token["center_coords"]
+         cb_mask[idx] = token["disto_mask"]
+         frame_mask[idx] = token["frame_mask"]
+         template_mask[idx] = 1.0
+
+     # Set visibility_id for templated chains
+     for asym_id, pdb_id in asym_id_to_pdb_id.items():
+         indices = (query_tokens.tokens["asym_id"] == asym_id).nonzero()
+         visibility_ids[indices] = pdb_id
+
+     # Set visibility for non-templated chains and oligomers
+     for asym_id in np.unique(query_tokens.structure.chains["asym_id"]):
+         if asym_id not in asym_id_to_pdb_id:
+             # Use negative ids so they cannot collide with the template ids above
+             indices = (query_tokens.tokens["asym_id"] == asym_id).nonzero()
+             visibility_ids[indices] = -1 - asym_id
+
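+     # Editor's note (worked example): if template rows 0 and 1 cover chains
+     # with asym_ids 0 and 2, their tokens get visibility ids 0 and 1, while
+     # an untemplated chain with asym_id 1 gets -1 - 1 = -2, so the positive
+     # and negative id spaces never overlap.
+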
+     # Convert to one-hot
+     res_type = torch.from_numpy(res_type)
+     res_type = one_hot(res_type, num_classes=const.num_tokens)
+
+     return {
+         "template_restype": res_type,
+         "template_frame_rot": torch.from_numpy(frame_rot),
+         "template_frame_t": torch.from_numpy(frame_t),
+         "template_cb": torch.from_numpy(cb_coords),
+         "template_ca": torch.from_numpy(ca_coords),
+         "template_mask_cb": torch.from_numpy(cb_mask),
+         "template_mask_frame": torch.from_numpy(frame_mask),
+         "template_mask": torch.from_numpy(template_mask),
+         "query_to_template": torch.from_numpy(query_to_template),
+         "visibility_ids": torch.from_numpy(visibility_ids),
+     }
+
+
+ def process_template_features(
+     data: Tokenized,
+     max_tokens: int,
+ ) -> dict[str, torch.Tensor]:
+     """Compute the template features for the given input.
+
+     Parameters
+     ----------
+     data : Tokenized
+         The input to the model.
+     max_tokens : int
+         The maximum number of tokens.
+
+     Returns
+     -------
+     dict[str, torch.Tensor]
+         The template features.
+
+     """
+     # Group templates by name
+     name_to_templates: dict[str, list[TemplateInfo]] = {}
+     for template_info in data.record.templates:
+         name_to_templates.setdefault(template_info.name, []).append(template_info)
+
+     # Map chain name to asym_id
+     chain_name_to_asym_id = {}
+     for chain in data.structure.chains:
+         chain_name_to_asym_id[chain["name"]] = chain["asym_id"]
+
+     # Compute the features for each template
+     template_features = []
+     for template_id, (template_name, templates) in enumerate(
+         name_to_templates.items()
+     ):
+         row_tokens = []
+         template_structure = data.templates[template_name]
+         template_tokens = data.template_tokens[template_name]
+         tmpl_chain_name_to_asym_id = {}
+         for chain in template_structure.chains:
+             tmpl_chain_name_to_asym_id[chain["name"]] = chain["asym_id"]
+
+         for template in templates:
+             offset = template.template_st - template.query_st
+
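+             # Editor's note (worked example): with query_st=10 and
+             # template_st=1, offset = 1 - 10 = -9, so a template token with
+             # res_idx=5 maps to query residue 5 - (-9) = 14 below.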
+             # Get query and template tokens to map residues
+             query_tokens = data.tokens
+             chain_id = chain_name_to_asym_id[template.query_chain]
+             q_tokens = query_tokens[query_tokens["asym_id"] == chain_id]
+             q_indices = dict(zip(q_tokens["res_idx"], q_tokens["token_idx"]))
+
+             # Get the template tokens at the query residues
+             chain_id = tmpl_chain_name_to_asym_id[template.template_chain]
+             toks = template_tokens[template_tokens["asym_id"] == chain_id]
+             toks = [t for t in toks if t["res_idx"] - offset in q_indices]
+             for t in toks:
+                 q_idx = q_indices[t["res_idx"] - offset]
+                 row_tokens.append(
+                     {
+                         "token": t,
+                         "pdb_id": template_id,
+                         "q_idx": q_idx,
+                     }
+                 )
+
+         # Compute template features for each row
+         row_features = compute_template_features(data, row_tokens, max_tokens)
+         template_features.append(row_features)
+
+     # Stack each feature
+     out = {}
+     for k in template_features[0]:
+         out[k] = torch.stack([f[k] for f in template_features])
+
+     return out
+
+
+ def process_symmetry_features(
+     cropped: Tokenized, symmetries: dict
+ ) -> dict[str, Tensor]:
+     """Get the symmetry features.
+
+     Parameters
+     ----------
+     cropped : Tokenized
+         The cropped input to the model.
+     symmetries : dict
+         The precomputed molecule symmetries.
+
+     Returns
+     -------
+     dict[str, Tensor]
+         The symmetry features.
+
+     """
+     features = get_chain_symmetries(cropped)
+     features.update(get_amino_acids_symmetries(cropped))
+     features.update(get_ligand_symmetries(cropped, symmetries))
+
+     return features
+
+
+ def process_ensemble_features(
+     data: Tokenized,
+     random: np.random.Generator,
+     num_ensembles: int,
+     ensemble_sample_replacement: bool,
+     fix_single_ensemble: bool,
+ ) -> dict[str, Tensor]:
+     """Get the ensemble features.
+
+     Parameters
+     ----------
+     data : Tokenized
+         The input to the model.
+     random : np.random.Generator
+         The random number generator.
+     num_ensembles : int
+         The maximum number of conformers to sample.
+     ensemble_sample_replacement : bool
+         Whether to sample with replacement.
+     fix_single_ensemble : bool
+         Whether to always use the first conformer only.
+
+     Returns
+     -------
+     dict[str, Tensor]
+         The ensemble features.
+
+     """
+     assert num_ensembles > 0, "Number of conformers sampled must be greater than 0."
+
+     # Number of available conformers in the structure
+     s_ensemble_num = len(data.structure.ensemble)
+
+     if fix_single_ensemble:
+         # Always take the first conformer for train and validation
+         assert (
+             num_ensembles == 1
+         ), "Number of conformers sampled must be 1 with fix_single_ensemble=True."
+         ensemble_ref_idxs = np.array([0])
+     elif ensemble_sample_replacement:
+         # Used in training
+         ensemble_ref_idxs = random.integers(0, s_ensemble_num, (num_ensembles,))
+     elif s_ensemble_num < num_ensembles:
+         # Used in validation: take all available conformers
+         ensemble_ref_idxs = np.arange(0, s_ensemble_num)
+     else:
+         # Used in validation: sample without replacement
+         ensemble_ref_idxs = random.choice(s_ensemble_num, num_ensembles, replace=False)
+
+     ensemble_features = {
+         "ensemble_ref_idxs": torch.Tensor(ensemble_ref_idxs).long(),
+     }
+
+     return ensemble_features
+
+
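+ # Editor's note (illustration): with 5 available conformers and
+ # num_ensembles=3, training (sampling with replacement) may return e.g.
+ # [4, 0, 4]; validation without replacement returns 3 distinct indices;
+ # and fix_single_ensemble=True (with num_ensembles=1) always returns [0].
+
+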
+ def process_residue_constraint_features(data: Tokenized) -> dict[str, Tensor]:
+     """Get the residue constraint features."""
+     residue_constraints = data.residue_constraints
+     if residue_constraints is not None:
+         rdkit_bounds_constraints = residue_constraints.rdkit_bounds_constraints
+         chiral_atom_constraints = residue_constraints.chiral_atom_constraints
+         stereo_bond_constraints = residue_constraints.stereo_bond_constraints
+         planar_bond_constraints = residue_constraints.planar_bond_constraints
+         planar_ring_5_constraints = residue_constraints.planar_ring_5_constraints
+         planar_ring_6_constraints = residue_constraints.planar_ring_6_constraints
+
+         rdkit_bounds_index = torch.tensor(
+             rdkit_bounds_constraints["atom_idxs"].copy(), dtype=torch.long
+         ).T
+         rdkit_bounds_bond_mask = torch.tensor(
+             rdkit_bounds_constraints["is_bond"].copy(), dtype=torch.bool
+         )
+         rdkit_bounds_angle_mask = torch.tensor(
+             rdkit_bounds_constraints["is_angle"].copy(), dtype=torch.bool
+         )
+         rdkit_upper_bounds = torch.tensor(
+             rdkit_bounds_constraints["upper_bound"].copy(), dtype=torch.float
+         )
+         rdkit_lower_bounds = torch.tensor(
+             rdkit_bounds_constraints["lower_bound"].copy(), dtype=torch.float
+         )
+
+         chiral_atom_index = torch.tensor(
+             chiral_atom_constraints["atom_idxs"].copy(), dtype=torch.long
+         ).T
+         chiral_reference_mask = torch.tensor(
+             chiral_atom_constraints["is_reference"].copy(), dtype=torch.bool
+         )
+         chiral_atom_orientations = torch.tensor(
+             chiral_atom_constraints["is_r"].copy(), dtype=torch.bool
+         )
+
+         stereo_bond_index = torch.tensor(
+             stereo_bond_constraints["atom_idxs"].copy(), dtype=torch.long
+         ).T
+         stereo_reference_mask = torch.tensor(
+             stereo_bond_constraints["is_reference"].copy(), dtype=torch.bool
+         )
+         stereo_bond_orientations = torch.tensor(
+             stereo_bond_constraints["is_e"].copy(), dtype=torch.bool
+         )
+
+         planar_bond_index = torch.tensor(
+             planar_bond_constraints["atom_idxs"].copy(), dtype=torch.long
+         ).T
+         planar_ring_5_index = torch.tensor(
+             planar_ring_5_constraints["atom_idxs"].copy(), dtype=torch.long
+         ).T
+         planar_ring_6_index = torch.tensor(
+             planar_ring_6_constraints["atom_idxs"].copy(), dtype=torch.long
+         ).T
+     else:
+         rdkit_bounds_index = torch.empty((2, 0), dtype=torch.long)
+         rdkit_bounds_bond_mask = torch.empty((0,), dtype=torch.bool)
+         rdkit_bounds_angle_mask = torch.empty((0,), dtype=torch.bool)
+         rdkit_upper_bounds = torch.empty((0,), dtype=torch.float)
+         rdkit_lower_bounds = torch.empty((0,), dtype=torch.float)
+         chiral_atom_index = torch.empty((4, 0), dtype=torch.long)
+         chiral_reference_mask = torch.empty((0,), dtype=torch.bool)
+         chiral_atom_orientations = torch.empty((0,), dtype=torch.bool)
+         stereo_bond_index = torch.empty((4, 0), dtype=torch.long)
+         stereo_reference_mask = torch.empty((0,), dtype=torch.bool)
+         stereo_bond_orientations = torch.empty((0,), dtype=torch.bool)
+         planar_bond_index = torch.empty((6, 0), dtype=torch.long)
+         planar_ring_5_index = torch.empty((5, 0), dtype=torch.long)
+         planar_ring_6_index = torch.empty((6, 0), dtype=torch.long)
+
+     return {
+         "rdkit_bounds_index": rdkit_bounds_index,
+         "rdkit_bounds_bond_mask": rdkit_bounds_bond_mask,
+         "rdkit_bounds_angle_mask": rdkit_bounds_angle_mask,
+         "rdkit_upper_bounds": rdkit_upper_bounds,
+         "rdkit_lower_bounds": rdkit_lower_bounds,
+         "chiral_atom_index": chiral_atom_index,
+         "chiral_reference_mask": chiral_reference_mask,
+         "chiral_atom_orientations": chiral_atom_orientations,
+         "stereo_bond_index": stereo_bond_index,
+         "stereo_reference_mask": stereo_reference_mask,
+         "stereo_bond_orientations": stereo_bond_orientations,
+         "planar_bond_index": planar_bond_index,
+         "planar_ring_5_index": planar_ring_5_index,
+         "planar_ring_6_index": planar_ring_6_index,
+     }
+
+
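+ # Editor's note (illustration): each *_index tensor follows an
+ # (arity, n_constraints) layout after the transpose, e.g. chiral_atom_index
+ # is (4, N) and planar_ring_6_index is (6, N); the empty placeholders above
+ # keep the same leading arity so downstream code can concatenate or batch
+ # without special-casing missing constraints.
+
+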
+ def process_chain_feature_constraints(data: Tokenized) -> dict[str, Tensor]:
+     """Get the chain-level constraint features."""
+     structure = data.structure
+     if structure.bonds.shape[0] > 0:
+         connected_chain_index, connected_atom_index = [], []
+         for connection in structure.bonds:
+             if connection["chain_1"] == connection["chain_2"]:
+                 continue
+             connected_chain_index.append([connection["chain_1"], connection["chain_2"]])
+             connected_atom_index.append([connection["atom_1"], connection["atom_2"]])
+         if len(connected_chain_index) > 0:
+             connected_chain_index = torch.tensor(
+                 connected_chain_index, dtype=torch.long
+             ).T
+             connected_atom_index = torch.tensor(
+                 connected_atom_index, dtype=torch.long
+             ).T
+         else:
+             connected_chain_index = torch.empty((2, 0), dtype=torch.long)
+             connected_atom_index = torch.empty((2, 0), dtype=torch.long)
+     else:
+         connected_chain_index = torch.empty((2, 0), dtype=torch.long)
+         connected_atom_index = torch.empty((2, 0), dtype=torch.long)
+
+     symmetric_chain_index = []
+     for i, chain_i in enumerate(structure.chains):
+         for j, chain_j in enumerate(structure.chains):
+             if j <= i:
+                 continue
+             if chain_i["entity_id"] == chain_j["entity_id"]:
+                 symmetric_chain_index.append([i, j])
+     if len(symmetric_chain_index) > 0:
+         symmetric_chain_index = torch.tensor(symmetric_chain_index, dtype=torch.long).T
+     else:
+         symmetric_chain_index = torch.empty((2, 0), dtype=torch.long)
+
+     return {
+         "connected_chain_index": connected_chain_index,
+         "connected_atom_index": connected_atom_index,
+         "symmetric_chain_index": symmetric_chain_index,
+     }
+
+
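+ # Editor's note (worked example): for chains [A, B, C] with entity_ids
+ # [0, 0, 1], only the (A, B) pair shares an entity, so
+ # symmetric_chain_index is [[0], [1]] (shape (2, 1)).
+
+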
+ class Boltz2Featurizer:
+     """Boltz2 featurizer."""
+
+     def process(
+         self,
+         data: Tokenized,
+         random: np.random.Generator,
+         molecules: dict[str, Mol],
+         training: bool,
+         max_seqs: int,
+         atoms_per_window_queries: int = 32,
+         min_dist: float = 2.0,
+         max_dist: float = 22.0,
+         num_bins: int = 64,
+         num_ensembles: int = 1,
+         ensemble_sample_replacement: bool = False,
+         disto_use_ensemble: Optional[bool] = False,
+         fix_single_ensemble: Optional[bool] = True,
+         max_tokens: Optional[int] = None,
+         max_atoms: Optional[int] = None,
+         pad_to_max_seqs: bool = False,
+         compute_symmetries: bool = False,
+         binder_pocket_conditioned_prop: Optional[float] = 0.0,
+         contact_conditioned_prop: Optional[float] = 0.0,
+         binder_pocket_cutoff_min: Optional[float] = 4.0,
+         binder_pocket_cutoff_max: Optional[float] = 20.0,
+         binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
+         only_ligand_binder_pocket: Optional[bool] = False,
+         only_pp_contact: Optional[bool] = False,
+         pocket_constraints: Optional[list] = None,
+         single_sequence_prop: Optional[float] = 0.0,
+         msa_sampling: bool = False,
+         override_bfactor: bool = False,
+         override_method: Optional[str] = None,
+         compute_frames: bool = False,
+         override_coords: Optional[Tensor] = None,
+         bfactor_md_correction: bool = False,
+         compute_constraint_features: bool = False,
+         inference_pocket_constraints: Optional[list] = None,
+         compute_affinity: bool = False,
+     ) -> dict[str, Tensor]:
+         """Compute features.
+
+         Parameters
+         ----------
+         data : Tokenized
+             The input to the model.
+         random : np.random.Generator
+             The random number generator.
+         molecules : dict[str, Mol]
+             The loaded molecules, mapped by name.
+         training : bool
+             Whether the model is in training mode.
+         max_seqs : int
+             The maximum number of sequences.
+         max_tokens : int, optional
+             The maximum number of tokens.
+         max_atoms : int, optional
+             The maximum number of atoms.
+
+         Returns
+         -------
+         dict[str, Tensor]
+             The features for model training.
+
+         """
+         # Compute random number of sequences
+         if training and max_seqs is not None:
+             if random.random() > single_sequence_prop:
+                 max_seqs_batch = random.integers(1, max_seqs + 1)
+             else:
+                 max_seqs_batch = 1
+         else:
+             max_seqs_batch = max_seqs
+
+         # Compute ensemble features
+         ensemble_features = process_ensemble_features(
+             data=data,
+             random=random,
+             num_ensembles=num_ensembles,
+             ensemble_sample_replacement=ensemble_sample_replacement,
+             fix_single_ensemble=fix_single_ensemble,
+         )
+
+         # Compute token features
+         token_features = process_token_features(
+             data=data,
+             random=random,
+             max_tokens=max_tokens,
+             binder_pocket_conditioned_prop=binder_pocket_conditioned_prop,
+             contact_conditioned_prop=contact_conditioned_prop,
+             binder_pocket_cutoff_min=binder_pocket_cutoff_min,
+             binder_pocket_cutoff_max=binder_pocket_cutoff_max,
+             binder_pocket_sampling_geometric_p=binder_pocket_sampling_geometric_p,
+             only_ligand_binder_pocket=only_ligand_binder_pocket,
+             only_pp_contact=only_pp_contact,
+             override_method=override_method,
+             inference_pocket_constraints=inference_pocket_constraints,
+         )
+
+         # Compute atom features
+         atom_features = process_atom_features(
+             data=data,
+             random=random,
+             molecules=molecules,
+             ensemble_features=ensemble_features,
+             atoms_per_window_queries=atoms_per_window_queries,
+             min_dist=min_dist,
+             max_dist=max_dist,
+             num_bins=num_bins,
+             max_atoms=max_atoms,
+             max_tokens=max_tokens,
+             disto_use_ensemble=disto_use_ensemble,
+             override_bfactor=override_bfactor,
+             compute_frames=compute_frames,
+             override_coords=override_coords,
+             bfactor_md_correction=bfactor_md_correction,
+         )
+
+         # Compute MSA features
+         msa_features = process_msa_features(
+             data=data,
+             random=random,
+             max_seqs_batch=max_seqs_batch,
+             max_seqs=max_seqs,
+             max_tokens=max_tokens,
+             pad_to_max_seqs=pad_to_max_seqs,
+             msa_sampling=training and msa_sampling,
+         )
+
+         # Compute single-sequence MSA features for affinity
+         msa_features_affinity = {}
+         if compute_affinity:
+             msa_features_affinity = process_msa_features(
+                 data=data,
+                 random=random,
+                 max_seqs_batch=1,
+                 max_seqs=1,
+                 max_tokens=max_tokens,
+                 pad_to_max_seqs=pad_to_max_seqs,
+                 msa_sampling=training and msa_sampling,
+                 affinity=True,
+             )
+
+         # Compute the affinity ligand molecular weight
+         ligand_to_mw = {}
+         if compute_affinity:
+             ligand_to_mw["affinity_mw"] = data.record.affinity.mw
+
+         # Compute template features
+         num_tokens = data.tokens.shape[0] if max_tokens is None else max_tokens
+         if data.templates:
+             template_features = process_template_features(
+                 data=data,
+                 max_tokens=num_tokens,
+             )
+         else:
+             template_features = load_dummy_templates_features(
+                 tdim=1,
+                 num_tokens=num_tokens,
+             )
+
+         # Compute symmetry features
+         symmetry_features = {}
+         if compute_symmetries:
+             symmetries = get_symmetries(molecules)
+             symmetry_features = process_symmetry_features(data, symmetries)
+
+         # Compute residue and chain constraint features
+         residue_constraint_features = {}
+         chain_constraint_features = {}
+         if compute_constraint_features:
+             residue_constraint_features = process_residue_constraint_features(data)
+             chain_constraint_features = process_chain_feature_constraints(data)
+
+         return {
+             **token_features,
+             **atom_features,
+             **msa_features,
+             **msa_features_affinity,
+             **template_features,
+             **symmetry_features,
+             **ensemble_features,
+             **residue_constraint_features,
+             **chain_constraint_features,
+             **ligand_to_mw,
+         }
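+
+
+ # Editor's sketch (illustrative, not part of the package): a minimal
+ # end-to-end featurization call; `tokenized` and `molecules` are assumed
+ # inputs produced upstream by the Boltz2 tokenizer and molecule loader.
+ def _example_featurize(
+     tokenized: Tokenized, molecules: dict[str, Mol]
+ ) -> dict[str, Tensor]:
+     featurizer = Boltz2Featurizer()
+     return featurizer.process(
+         data=tokenized,
+         random=np.random.default_rng(42),
+         molecules=molecules,
+         training=False,
+         max_seqs=256,
+         compute_frames=True,
+     )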