boltz_vsynthes-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/data/feature/featurizer.py
@@ -0,0 +1,1230 @@
+ import math
+ import random
+ from typing import Optional
+
+ import numba
+ import numpy as np
+ import numpy.typing as npt
+ import torch
+ from numba import types
+ from torch import Tensor, from_numpy
+ from torch.nn.functional import one_hot
+
+ from boltz.data import const
+ from boltz.data.feature.symmetry import (
+     get_amino_acids_symmetries,
+     get_chain_symmetries,
+     get_ligand_symmetries,
+ )
+ from boltz.data.pad import pad_dim
+ from boltz.data.types import (
+     MSA,
+     MSADeletion,
+     MSAResidue,
+     MSASequence,
+     Tokenized,
+ )
+ from boltz.model.modules.utils import center_random_augmentation
+
+ ####################################################################################################
+ # HELPERS
+ ####################################################################################################
+
+
+ def compute_frames_nonpolymer(
+     data: Tokenized,
+     coords,
+     resolved_mask,
+     atom_to_token,
+     frame_data: list,
+     resolved_frame_data: list,
+ ) -> tuple[list, list]:
+     """Get the frames for non-polymer tokens.
+
+     Parameters
+     ----------
+     data : Tokenized
+         The tokenized data.
+     frame_data : list
+         The frame data.
+     resolved_frame_data : list
+         The resolved frame data.
+
+     Returns
+     -------
+     tuple[list, list]
+         The frame data and resolved frame data.
+
+     """
+     frame_data = np.array(frame_data)
+     resolved_frame_data = np.array(resolved_frame_data)
+     asym_id_token = data.tokens["asym_id"]
+     asym_id_atom = data.tokens["asym_id"][atom_to_token]
+     token_idx = 0
+     atom_idx = 0
+     for id in np.unique(data.tokens["asym_id"]):
+         mask_chain_token = asym_id_token == id
+         mask_chain_atom = asym_id_atom == id
+         num_tokens = mask_chain_token.sum()
+         num_atoms = mask_chain_atom.sum()
+         if (
+             data.tokens[token_idx]["mol_type"] != const.chain_type_ids["NONPOLYMER"]
+             or num_atoms < 3
+         ):
+             token_idx += num_tokens
+             atom_idx += num_atoms
+             continue
+         dist_mat = (
+             (
+                 coords.reshape(-1, 3)[mask_chain_atom][:, None, :]
+                 - coords.reshape(-1, 3)[mask_chain_atom][None, :, :]
+             )
+             ** 2
+         ).sum(-1) ** 0.5
+         resolved_pair = 1 - (
+             resolved_mask[mask_chain_atom][None, :]
+             * resolved_mask[mask_chain_atom][:, None]
+         ).astype(np.float32)
+         resolved_pair[resolved_pair == 1] = math.inf
+         indices = np.argsort(dist_mat + resolved_pair, axis=1)
+         frames = (
+             np.concatenate(
+                 [
+                     indices[:, 1:2],
+                     indices[:, 0:1],
+                     indices[:, 2:3],
+                 ],
+                 axis=1,
+             )
+             + atom_idx
+         )
+         frame_data[token_idx : token_idx + num_atoms, :] = frames
+         resolved_frame_data[token_idx : token_idx + num_atoms] = resolved_mask[
+             frames
+         ].all(axis=1)
+         token_idx += num_tokens
+         atom_idx += num_atoms
+     frames_expanded = coords.reshape(-1, 3)[frame_data]
+
+     mask_collinear = compute_collinear_mask(
+         frames_expanded[:, 1] - frames_expanded[:, 0],
+         frames_expanded[:, 1] - frames_expanded[:, 2],
+     )
+     return frame_data, resolved_frame_data & mask_collinear
+
+
+ def compute_collinear_mask(v1, v2):
+     norm1 = np.linalg.norm(v1, axis=1, keepdims=True)
+     norm2 = np.linalg.norm(v2, axis=1, keepdims=True)
+     v1 = v1 / (norm1 + 1e-6)
+     v2 = v2 / (norm2 + 1e-6)
+     mask_angle = np.abs(np.sum(v1 * v2, axis=1)) < 0.9063
+     mask_overlap1 = norm1.reshape(-1) > 1e-2
+     mask_overlap2 = norm2.reshape(-1) > 1e-2
+     return mask_angle & mask_overlap1 & mask_overlap2
+
+
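The 0.9063 cutoff in compute_collinear_mask above is the cosine of roughly 25 degrees: a frame is dropped when its two spanning vectors are nearly parallel (|cos| >= 0.9063) or when either vector is shorter than 1e-2. A quick illustration of the angular threshold:

    import math
    print(math.degrees(math.acos(0.9063)))  # ~25.0; frame vectors need at least this much angular separation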
+ def dummy_msa(residues: np.ndarray) -> MSA:
+     """Create a dummy MSA for a chain.
+
+     Parameters
+     ----------
+     residues : np.ndarray
+         The residues for the chain.
+
+     Returns
+     -------
+     MSA
+         The dummy MSA.
+
+     """
+     residues = [res["res_type"] for res in residues]
+     deletions = []
+     sequences = [(0, -1, 0, len(residues), 0, 0)]
+     return MSA(
+         residues=np.array(residues, dtype=MSAResidue),
+         deletions=np.array(deletions, dtype=MSADeletion),
+         sequences=np.array(sequences, dtype=MSASequence),
+     )
+
+
+ def construct_paired_msa(  # noqa: C901, PLR0915, PLR0912
+     data: Tokenized,
+     max_seqs: int,
+     max_pairs: int = 8192,
+     max_total: int = 16384,
+     random_subset: bool = False,
+ ) -> tuple[Tensor, Tensor, Tensor]:
+     """Pair the MSA data.
+
+     Parameters
+     ----------
+     data : Input
+         The input data.
+
+     Returns
+     -------
+     Tensor
+         The MSA data.
+     Tensor
+         The deletion data.
+     Tensor
+         Mask indicating paired sequences.
+
+     """
+     # Get unique chains (ensuring monotonicity in the order)
+     assert np.all(np.diff(data.tokens["asym_id"], n=1) >= 0)
+     chain_ids = np.unique(data.tokens["asym_id"])
+
+     # Get relevant MSA, and create a dummy for chains without
+     msa = {k: data.msa[k] for k in chain_ids if k in data.msa}
+     for chain_id in chain_ids:
+         if chain_id not in msa:
+             chain = data.structure.chains[chain_id]
+             res_start = chain["res_idx"]
+             res_end = res_start + chain["res_num"]
+             residues = data.structure.residues[res_start:res_end]
+             msa[chain_id] = dummy_msa(residues)
+
+     # Map taxonomies to (chain_id, seq_idx)
+     taxonomy_map: dict[str, list] = {}
+     for chain_id, chain_msa in msa.items():
+         sequences = chain_msa.sequences
+         sequences = sequences[sequences["taxonomy"] != -1]
+         for sequence in sequences:
+             seq_idx = sequence["seq_idx"]
+             taxon = sequence["taxonomy"]
+             taxonomy_map.setdefault(taxon, []).append((chain_id, seq_idx))
+
+     # Remove taxonomies with only one sequence and sort by the
+     # number of chain_id present in each of the taxonomies
+     taxonomy_map = {k: v for k, v in taxonomy_map.items() if len(v) > 1}
+     taxonomy_map = sorted(
+         taxonomy_map.items(),
+         key=lambda x: len({c for c, _ in x[1]}),
+         reverse=True,
+     )
+
+     # Keep track of the sequences available per chain, keeping the original
+     # order of the sequences in the MSA to favor the best matching sequences
+     visited = {(c, s) for c, items in taxonomy_map for s in items}
+     available = {}
+     for c in chain_ids:
+         available[c] = [
+             i for i in range(1, len(msa[c].sequences)) if (c, i) not in visited
+         ]
+
+     # Create sequence pairs
+     is_paired = []
+     pairing = []
+
+     # Start with the first sequence for each chain
+     is_paired.append({c: 1 for c in chain_ids})
+     pairing.append({c: 0 for c in chain_ids})
+
+     # Then add up to 8191 paired rows
+     for _, pairs in taxonomy_map:
+         # Group occurences by chain_id in case we have multiple
+         # sequences from the same chain and same taxonomy
+         chain_occurences = {}
+         for chain_id, seq_idx in pairs:
+             chain_occurences.setdefault(chain_id, []).append(seq_idx)
+
+         # We create as many pairings as the maximum number of occurences
+         max_occurences = max(len(v) for v in chain_occurences.values())
+         for i in range(max_occurences):
+             row_pairing = {}
+             row_is_paired = {}
+
+             # Add the chains present in the taxonomy
+             for chain_id, seq_idxs in chain_occurences.items():
+                 # Roll over the sequence index to maximize diversity
+                 idx = i % len(seq_idxs)
+                 seq_idx = seq_idxs[idx]
+
+                 # Add the sequence to the pairing
+                 row_pairing[chain_id] = seq_idx
+                 row_is_paired[chain_id] = 1
+
+             # Add any missing chains
+             for chain_id in chain_ids:
+                 if chain_id not in row_pairing:
+                     row_is_paired[chain_id] = 0
+                     if available[chain_id]:
+                         # Add the next available sequence
+                         seq_idx = available[chain_id].pop(0)
+                         row_pairing[chain_id] = seq_idx
+                     else:
+                         # No more sequences available, we place a gap
+                         row_pairing[chain_id] = -1
+
+             pairing.append(row_pairing)
+             is_paired.append(row_is_paired)
+
+             # Break if we have enough pairs
+             if len(pairing) >= max_pairs:
+                 break
+
+         # Break if we have enough pairs
+         if len(pairing) >= max_pairs:
+             break
+
+     # Now add up to 16384 unpaired rows total
+     max_left = max(len(v) for v in available.values())
+     for _ in range(min(max_total - len(pairing), max_left)):
+         row_pairing = {}
+         row_is_paired = {}
+         for chain_id in chain_ids:
+             row_is_paired[chain_id] = 0
+             if available[chain_id]:
+                 # Add the next available sequence
+                 seq_idx = available[chain_id].pop(0)
+                 row_pairing[chain_id] = seq_idx
+             else:
+                 # No more sequences available, we place a gap
+                 row_pairing[chain_id] = -1
+
+         pairing.append(row_pairing)
+         is_paired.append(row_is_paired)
+
+         # Break if we have enough sequences
+         if len(pairing) >= max_total:
+             break
+
+     # Randomly sample a subset of the pairs
+     # ensuring the first row is always present
+     if random_subset:
+         num_seqs = len(pairing)
+         if num_seqs > max_seqs:
+             indices = np.random.choice(
+                 list(range(1, num_seqs)), size=max_seqs - 1, replace=False
+             )  # noqa: NPY002
+             pairing = [pairing[0]] + [pairing[i] for i in indices]
+             is_paired = [is_paired[0]] + [is_paired[i] for i in indices]
+     else:
+         # Deterministic downsample to max_seqs
+         pairing = pairing[:max_seqs]
+         is_paired = is_paired[:max_seqs]
+
+     # Map (chain_id, seq_idx, res_idx) to deletion
+     deletions = {}
+     for chain_id, chain_msa in msa.items():
+         chain_deletions = chain_msa.deletions
+         for sequence in chain_msa.sequences:
+             del_start = sequence["del_start"]
+             del_end = sequence["del_end"]
+             chain_deletions = chain_msa.deletions[del_start:del_end]
+             for deletion_data in chain_deletions:
+                 seq_idx = sequence["seq_idx"]
+                 res_idx = deletion_data["res_idx"]
+                 deletion = deletion_data["deletion"]
+                 deletions[(chain_id, seq_idx, res_idx)] = deletion
+
+     # Add all the token MSA data
+     msa_data, del_data, paired_data = prepare_msa_arrays(
+         data.tokens, pairing, is_paired, deletions, msa
+     )
+
+     msa_data = torch.tensor(msa_data, dtype=torch.long)
+     del_data = torch.tensor(del_data, dtype=torch.float)
+     paired_data = torch.tensor(paired_data, dtype=torch.float)
+
+     return msa_data, del_data, paired_data
+
+
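construct_paired_msa assembles the MSA row by row: row 0 always takes the query (index 0) of every chain, taxonomy-matched rows follow up to max_pairs, and leftover unpaired sequences fill further rows up to max_total, with -1 marking a gap for chains that have run out. A toy illustration of the intermediate pairing/is_paired structures for two chains (the indices here are made up):

    # hypothetical asym_ids 0 and 1
    pairing = [
        {0: 0, 1: 0},   # row 0: the query sequence of each chain
        {0: 3, 1: 7},   # a taxonomy-paired row
        {0: 1, 1: -1},  # an unpaired row: chain 1 is exhausted, so it gets a gap
    ]
    is_paired = [
        {0: 1, 1: 1},
        {0: 1, 1: 1},
        {0: 0, 1: 0},
    ]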
+ def prepare_msa_arrays(
+     tokens,
+     pairing: list[dict[int, int]],
+     is_paired: list[dict[int, int]],
+     deletions: dict[tuple[int, int, int], int],
+     msa: dict[int, MSA],
+ ) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
+     """Reshape data to play nicely with numba jit."""
+     token_asym_ids_arr = np.array([t["asym_id"] for t in tokens], dtype=np.int64)
+     token_res_idxs_arr = np.array([t["res_idx"] for t in tokens], dtype=np.int64)
+
+     chain_ids = sorted(msa.keys())
+
+     # chain_ids are not necessarily contiguous (e.g. they might be 0, 24, 25).
+     # This allows us to look up a chain_id by it's index in the chain_ids list.
+     chain_id_to_idx = {chain_id: i for i, chain_id in enumerate(chain_ids)}
+     token_asym_ids_idx_arr = np.array(
+         [chain_id_to_idx[asym_id] for asym_id in token_asym_ids_arr], dtype=np.int64
+     )
+
+     pairing_arr = np.zeros((len(pairing), len(chain_ids)), dtype=np.int64)
+     is_paired_arr = np.zeros((len(is_paired), len(chain_ids)), dtype=np.int64)
+
+     for i, row_pairing in enumerate(pairing):
+         for chain_id in chain_ids:
+             pairing_arr[i, chain_id_to_idx[chain_id]] = row_pairing[chain_id]
+
+     for i, row_is_paired in enumerate(is_paired):
+         for chain_id in chain_ids:
+             is_paired_arr[i, chain_id_to_idx[chain_id]] = row_is_paired[chain_id]
+
+     max_seq_len = max(len(msa[chain_id].sequences) for chain_id in chain_ids)
+
+     # we want res_start from sequences
+     msa_sequences = np.full((len(chain_ids), max_seq_len), -1, dtype=np.int64)
+     for chain_id in chain_ids:
+         for i, seq in enumerate(msa[chain_id].sequences):
+             msa_sequences[chain_id_to_idx[chain_id], i] = seq["res_start"]
+
+     max_residues_len = max(len(msa[chain_id].residues) for chain_id in chain_ids)
+     msa_residues = np.full((len(chain_ids), max_residues_len), -1, dtype=np.int64)
+     for chain_id in chain_ids:
+         residues = msa[chain_id].residues.astype(np.int64)
+         idxs = np.arange(len(residues))
+         chain_idx = chain_id_to_idx[chain_id]
+         msa_residues[chain_idx, idxs] = residues
+
+     deletions_dict = numba.typed.Dict.empty(
+         key_type=numba.types.Tuple(
+             [numba.types.int64, numba.types.int64, numba.types.int64]
+         ),
+         value_type=numba.types.int64,
+     )
+     deletions_dict.update(deletions)
+
+     return _prepare_msa_arrays_inner(
+         token_asym_ids_arr,
+         token_res_idxs_arr,
+         token_asym_ids_idx_arr,
+         pairing_arr,
+         is_paired_arr,
+         deletions_dict,
+         msa_sequences,
+         msa_residues,
+         const.token_ids["-"],
+     )
+
+
+ deletions_dict_type = types.DictType(types.UniTuple(types.int64, 3), types.int64)
+
+
+ @numba.njit(
+     [
+         types.Tuple(
+             (
+                 types.int64[:, ::1],  # msa_data
+                 types.int64[:, ::1],  # del_data
+                 types.int64[:, ::1],  # paired_data
+             )
+         )(
+             types.int64[::1],  # token_asym_ids
+             types.int64[::1],  # token_res_idxs
+             types.int64[::1],  # token_asym_ids_idx
+             types.int64[:, ::1],  # pairing
+             types.int64[:, ::1],  # is_paired
+             deletions_dict_type,  # deletions
+             types.int64[:, ::1],  # msa_sequences
+             types.int64[:, ::1],  # msa_residues
+             types.int64,  # gap_token
+         )
+     ],
+     cache=True,
+ )
+ def _prepare_msa_arrays_inner(
+     token_asym_ids: npt.NDArray[np.int64],
+     token_res_idxs: npt.NDArray[np.int64],
+     token_asym_ids_idx: npt.NDArray[np.int64],
+     pairing: npt.NDArray[np.int64],
+     is_paired: npt.NDArray[np.int64],
+     deletions: dict[tuple[int, int, int], int],
+     msa_sequences: npt.NDArray[np.int64],
+     msa_residues: npt.NDArray[np.int64],
+     gap_token: int,
+ ) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
+     n_tokens = len(token_asym_ids)
+     n_pairs = len(pairing)
+     msa_data = np.full((n_tokens, n_pairs), gap_token, dtype=np.int64)
+     paired_data = np.zeros((n_tokens, n_pairs), dtype=np.int64)
+     del_data = np.zeros((n_tokens, n_pairs), dtype=np.int64)
+
+     # Add all the token MSA data
+     for token_idx in range(n_tokens):
+         chain_id_idx = token_asym_ids_idx[token_idx]
+         chain_id = token_asym_ids[token_idx]
+         res_idx = token_res_idxs[token_idx]
+
+         for pair_idx in range(n_pairs):
+             seq_idx = pairing[pair_idx, chain_id_idx]
+             paired_data[token_idx, pair_idx] = is_paired[pair_idx, chain_id_idx]
+
+             # Add residue type
+             if seq_idx != -1:
+                 res_start = msa_sequences[chain_id_idx, seq_idx]
+                 res_type = msa_residues[chain_id_idx, res_start + res_idx]
+                 k = (chain_id, seq_idx, res_idx)
+                 if k in deletions:
+                     del_data[token_idx, pair_idx] = deletions[k]
+                 msa_data[token_idx, pair_idx] = res_type
+
+     return msa_data, del_data, paired_data
+
+
+ ####################################################################################################
+ # FEATURES
+ ####################################################################################################
+
+
+ def select_subset_from_mask(mask, p):
+     num_true = np.sum(mask)
+     v = np.random.geometric(p) + 1
+     k = min(v, num_true)
+
+     true_indices = np.where(mask)[0]
+
+     # Randomly select k indices from the true_indices
+     selected_indices = np.random.choice(true_indices, size=k, replace=False)
+
+     new_mask = np.zeros_like(mask)
+     new_mask[selected_indices] = 1
+
+     return new_mask
+
+
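select_subset_from_mask draws a subset size from a geometric distribution (np.random.geometric(p) + 1, capped at the number of True entries) and keeps that many randomly chosen True positions; process_token_features below uses it to expose only part of a sampled pocket. A minimal sketch of a call, assuming the function above is in scope:

    import numpy as np

    mask = np.array([True, False, True, True, False, True])
    sub = select_subset_from_mask(mask, p=0.3)
    # sub is True on 2 to 4 of the positions that were True in mask, and False elsewhere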
+ def process_token_features(
+     data: Tokenized,
+     max_tokens: Optional[int] = None,
+     binder_pocket_conditioned_prop: Optional[float] = 0.0,
+     binder_pocket_cutoff: Optional[float] = 6.0,
+     binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
+     only_ligand_binder_pocket: Optional[bool] = False,
+     inference_binder: Optional[list[int]] = None,
+     inference_pocket: Optional[list[tuple[int, int]]] = None,
+ ) -> dict[str, Tensor]:
+     """Get the token features.
+
+     Parameters
+     ----------
+     data : Tokenized
+         The tokenized data.
+     max_tokens : int
+         The maximum number of tokens.
+
+     Returns
+     -------
+     dict[str, Tensor]
+         The token features.
+
+     """
+     # Token data
+     token_data = data.tokens
+     token_bonds = data.bonds
+
+     # Token core features
+     token_index = torch.arange(len(token_data), dtype=torch.long)
+     residue_index = from_numpy(token_data["res_idx"].copy()).long()
+     asym_id = from_numpy(token_data["asym_id"].copy()).long()
+     entity_id = from_numpy(token_data["entity_id"].copy()).long()
+     sym_id = from_numpy(token_data["sym_id"].copy()).long()
+     mol_type = from_numpy(token_data["mol_type"].copy()).long()
+     res_type = from_numpy(token_data["res_type"].copy()).long()
+     res_type = one_hot(res_type, num_classes=const.num_tokens)
+     disto_center = from_numpy(token_data["disto_coords"].copy())
+
+     # Token mask features
+     pad_mask = torch.ones(len(token_data), dtype=torch.float)
+     resolved_mask = from_numpy(token_data["resolved_mask"].copy()).float()
+     disto_mask = from_numpy(token_data["disto_mask"].copy()).float()
+     cyclic_period = from_numpy(token_data["cyclic_period"].copy())
+
+     # Token bond features
+     if max_tokens is not None:
+         pad_len = max_tokens - len(token_data)
+         num_tokens = max_tokens if pad_len > 0 else len(token_data)
+     else:
+         num_tokens = len(token_data)
+
+     tok_to_idx = {tok["token_idx"]: idx for idx, tok in enumerate(token_data)}
+     bonds = torch.zeros(num_tokens, num_tokens, dtype=torch.float)
+     for token_bond in token_bonds:
+         token_1 = tok_to_idx[token_bond["token_1"]]
+         token_2 = tok_to_idx[token_bond["token_2"]]
+         bonds[token_1, token_2] = 1
+         bonds[token_2, token_1] = 1
+
+     bonds = bonds.unsqueeze(-1)
+
+     # Pocket conditioned feature
+     pocket_feature = (
+         np.zeros(len(token_data)) + const.pocket_contact_info["UNSPECIFIED"]
+     )
+     if inference_binder is not None:
+         assert inference_pocket is not None
+         pocket_residues = set(inference_pocket)
+         for idx, token in enumerate(token_data):
+             if token["asym_id"] in inference_binder:
+                 pocket_feature[idx] = const.pocket_contact_info["BINDER"]
+             elif (token["asym_id"], token["res_idx"]) in pocket_residues:
+                 pocket_feature[idx] = const.pocket_contact_info["POCKET"]
+             else:
+                 pocket_feature[idx] = const.pocket_contact_info["UNSELECTED"]
+     elif (
+         binder_pocket_conditioned_prop > 0.0
+         and random.random() < binder_pocket_conditioned_prop
+     ):
+         # choose as binder a random ligand in the crop, if there are no ligands select a protein chain
+         binder_asym_ids = np.unique(
+             token_data["asym_id"][
+                 token_data["mol_type"] == const.chain_type_ids["NONPOLYMER"]
+             ]
+         )
+
+         if len(binder_asym_ids) == 0:
+             if not only_ligand_binder_pocket:
+                 binder_asym_ids = np.unique(token_data["asym_id"])
+
+         if len(binder_asym_ids) > 0:
+             pocket_asym_id = random.choice(binder_asym_ids)
+             binder_mask = token_data["asym_id"] == pocket_asym_id
+
+             binder_coords = []
+             for token in token_data:
+                 if token["asym_id"] == pocket_asym_id:
+                     binder_coords.append(
+                         data.structure.atoms["coords"][
+                             token["atom_idx"] : token["atom_idx"] + token["atom_num"]
+                         ]
+                     )
+             binder_coords = np.concatenate(binder_coords, axis=0)
+
+             # find the tokens in the pocket
+             token_dist = np.zeros(len(token_data)) + 1000
+             for i, token in enumerate(token_data):
+                 if (
+                     token["mol_type"] != const.chain_type_ids["NONPOLYMER"]
+                     and token["asym_id"] != pocket_asym_id
+                     and token["resolved_mask"] == 1
+                 ):
+                     token_coords = data.structure.atoms["coords"][
+                         token["atom_idx"] : token["atom_idx"] + token["atom_num"]
+                     ]
+
+                     # find chain and apply chain transformation
+                     for chain in data.structure.chains:
+                         if chain["asym_id"] == token["asym_id"]:
+                             break
+
+                     token_dist[i] = np.min(
+                         np.linalg.norm(
+                             token_coords[:, None, :] - binder_coords[None, :, :],
+                             axis=-1,
+                         )
+                     )
+
+             pocket_mask = token_dist < binder_pocket_cutoff
+
+             if np.sum(pocket_mask) > 0:
+                 pocket_feature = (
+                     np.zeros(len(token_data)) + const.pocket_contact_info["UNSELECTED"]
+                 )
+                 pocket_feature[binder_mask] = const.pocket_contact_info["BINDER"]
+
+                 if binder_pocket_sampling_geometric_p > 0.0:
+                     # select a subset of the pocket, according
+                     # to a geometric distribution with one as minimum
+                     pocket_mask = select_subset_from_mask(
+                         pocket_mask, binder_pocket_sampling_geometric_p
+                     )
+
+                 pocket_feature[pocket_mask] = const.pocket_contact_info["POCKET"]
+     pocket_feature = from_numpy(pocket_feature).long()
+     pocket_feature = one_hot(pocket_feature, num_classes=len(const.pocket_contact_info))
+
+     # Pad to max tokens if given
+     if max_tokens is not None:
+         pad_len = max_tokens - len(token_data)
+         if pad_len > 0:
+             token_index = pad_dim(token_index, 0, pad_len)
+             residue_index = pad_dim(residue_index, 0, pad_len)
+             asym_id = pad_dim(asym_id, 0, pad_len)
+             entity_id = pad_dim(entity_id, 0, pad_len)
+             sym_id = pad_dim(sym_id, 0, pad_len)
+             mol_type = pad_dim(mol_type, 0, pad_len)
+             res_type = pad_dim(res_type, 0, pad_len)
+             disto_center = pad_dim(disto_center, 0, pad_len)
+             pad_mask = pad_dim(pad_mask, 0, pad_len)
+             resolved_mask = pad_dim(resolved_mask, 0, pad_len)
+             disto_mask = pad_dim(disto_mask, 0, pad_len)
+             pocket_feature = pad_dim(pocket_feature, 0, pad_len)
+
+     token_features = {
+         "token_index": token_index,
+         "residue_index": residue_index,
+         "asym_id": asym_id,
+         "entity_id": entity_id,
+         "sym_id": sym_id,
+         "mol_type": mol_type,
+         "res_type": res_type,
+         "disto_center": disto_center,
+         "token_bonds": bonds,
+         "token_pad_mask": pad_mask,
+         "token_resolved_mask": resolved_mask,
+         "token_disto_mask": disto_mask,
+         "pocket_feature": pocket_feature,
+         "cyclic_period": cyclic_period,
+     }
+     return token_features
+
+
+ def process_atom_features(
+     data: Tokenized,
+     atoms_per_window_queries: int = 32,
+     min_dist: float = 2.0,
+     max_dist: float = 22.0,
+     num_bins: int = 64,
+     max_atoms: Optional[int] = None,
+     max_tokens: Optional[int] = None,
+ ) -> dict[str, Tensor]:
+     """Get the atom features.
+
+     Parameters
+     ----------
+     data : Tokenized
+         The tokenized data.
+     max_atoms : int, optional
+         The maximum number of atoms.
+
+     Returns
+     -------
+     dict[str, Tensor]
+         The atom features.
+
+     """
+     # Filter to tokens' atoms
+     atom_data = []
+     ref_space_uid = []
+     coord_data = []
+     frame_data = []
+     resolved_frame_data = []
+     atom_to_token = []
+     token_to_rep_atom = []  # index on cropped atom table
+     r_set_to_rep_atom = []
+     disto_coords = []
+     atom_idx = 0
+
+     chain_res_ids = {}
+     for token_id, token in enumerate(data.tokens):
+         # Get the chain residue ids
+         chain_idx, res_id = token["asym_id"], token["res_idx"]
+         chain = data.structure.chains[chain_idx]
+
+         if (chain_idx, res_id) not in chain_res_ids:
+             new_idx = len(chain_res_ids)
+             chain_res_ids[(chain_idx, res_id)] = new_idx
+         else:
+             new_idx = chain_res_ids[(chain_idx, res_id)]
+
+         # Map atoms to token indices
+         ref_space_uid.extend([new_idx] * token["atom_num"])
+         atom_to_token.extend([token_id] * token["atom_num"])
+
+         # Add atom data
+         start = token["atom_idx"]
+         end = token["atom_idx"] + token["atom_num"]
+         token_atoms = data.structure.atoms[start:end]
+
+         # Map token to representative atom
+         token_to_rep_atom.append(atom_idx + token["disto_idx"] - start)
+         if (chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]) and token[
+             "resolved_mask"
+         ]:
+             r_set_to_rep_atom.append(atom_idx + token["center_idx"] - start)
+
+         # Get token coordinates
+         token_coords = np.array([token_atoms["coords"]])
+         coord_data.append(token_coords)
+
+         # Get frame data
+         res_type = const.tokens[token["res_type"]]
+
+         if token["atom_num"] < 3 or res_type in ["PAD", "UNK", "-"]:
+             idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
+             mask_frame = False
+         elif (token["mol_type"] == const.chain_type_ids["PROTEIN"]) and (
+             res_type in const.ref_atoms
+         ):
+             idx_frame_a, idx_frame_b, idx_frame_c = (
+                 const.ref_atoms[res_type].index("N"),
+                 const.ref_atoms[res_type].index("CA"),
+                 const.ref_atoms[res_type].index("C"),
+             )
+             mask_frame = (
+                 token_atoms["is_present"][idx_frame_a]
+                 and token_atoms["is_present"][idx_frame_b]
+                 and token_atoms["is_present"][idx_frame_c]
+             )
+         elif (
+             token["mol_type"] == const.chain_type_ids["DNA"]
+             or token["mol_type"] == const.chain_type_ids["RNA"]
+         ) and (res_type in const.ref_atoms):
+             idx_frame_a, idx_frame_b, idx_frame_c = (
+                 const.ref_atoms[res_type].index("C1'"),
+                 const.ref_atoms[res_type].index("C3'"),
+                 const.ref_atoms[res_type].index("C4'"),
+             )
+             mask_frame = (
+                 token_atoms["is_present"][idx_frame_a]
+                 and token_atoms["is_present"][idx_frame_b]
+                 and token_atoms["is_present"][idx_frame_c]
+             )
+         else:
+             idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
+             mask_frame = False
+         frame_data.append(
+             [idx_frame_a + atom_idx, idx_frame_b + atom_idx, idx_frame_c + atom_idx]
+         )
+         resolved_frame_data.append(mask_frame)
+
+         # Get distogram coordinates
+         disto_coords_tok = data.structure.atoms[token["disto_idx"]]["coords"]
+         disto_coords.append(disto_coords_tok)
+
+         # Update atom data. This is technically never used again (we rely on coord_data),
+         # but we update for consistency and to make sure the Atom object has valid, transformed coordinates.
+         token_atoms = token_atoms.copy()
+         token_atoms["coords"] = token_coords[0]  # atom has a copy of first coords
+         atom_data.append(token_atoms)
+         atom_idx += len(token_atoms)
+
+     disto_coords = np.array(disto_coords)
+
+     # Compute distogram
+     t_center = torch.Tensor(disto_coords)
+     t_dists = torch.cdist(t_center, t_center)
+     boundaries = torch.linspace(min_dist, max_dist, num_bins - 1)
+     distogram = (t_dists.unsqueeze(-1) > boundaries).sum(dim=-1).long()
+     disto_target = one_hot(distogram, num_classes=num_bins)
+
+     atom_data = np.concatenate(atom_data)
+     coord_data = np.concatenate(coord_data, axis=1)
+     ref_space_uid = np.array(ref_space_uid)
+
+     # Compute features
+     ref_atom_name_chars = from_numpy(atom_data["name"]).long()
+     ref_element = from_numpy(atom_data["element"]).long()
+     ref_charge = from_numpy(atom_data["charge"])
+     ref_pos = from_numpy(
+         atom_data["conformer"].copy()
+     )  # not sure why I need to copy here..
+     ref_space_uid = from_numpy(ref_space_uid)
+     coords = from_numpy(coord_data.copy())
+     resolved_mask = from_numpy(atom_data["is_present"])
+     pad_mask = torch.ones(len(atom_data), dtype=torch.float)
+     atom_to_token = torch.tensor(atom_to_token, dtype=torch.long)
+     token_to_rep_atom = torch.tensor(token_to_rep_atom, dtype=torch.long)
+     r_set_to_rep_atom = torch.tensor(r_set_to_rep_atom, dtype=torch.long)
+     frame_data, resolved_frame_data = compute_frames_nonpolymer(
+         data,
+         coord_data,
+         atom_data["is_present"],
+         atom_to_token,
+         frame_data,
+         resolved_frame_data,
+     )  # Compute frames for NONPOLYMER tokens
+     frames = from_numpy(frame_data.copy())
+     frame_resolved_mask = from_numpy(resolved_frame_data.copy())
+     # Convert to one-hot
+     ref_atom_name_chars = one_hot(
+         ref_atom_name_chars % num_bins, num_classes=num_bins
+     )  # added for lower case letters
+     ref_element = one_hot(ref_element, num_classes=const.num_elements)
+     atom_to_token = one_hot(atom_to_token, num_classes=token_id + 1)
+     token_to_rep_atom = one_hot(token_to_rep_atom, num_classes=len(atom_data))
+     r_set_to_rep_atom = one_hot(r_set_to_rep_atom, num_classes=len(atom_data))
+
+     # Center the ground truth coordinates
+     center = (coords * resolved_mask[None, :, None]).sum(dim=1)
+     center = center / resolved_mask.sum().clamp(min=1)
+     coords = coords - center[:, None]
+
+     # Apply random roto-translation to the input atoms
+     ref_pos = center_random_augmentation(
+         ref_pos[None], resolved_mask[None], centering=False
+     )[0]
+
+     # Compute padding and apply
+     if max_atoms is not None:
+         assert max_atoms % atoms_per_window_queries == 0
+         pad_len = max_atoms - len(atom_data)
+     else:
+         pad_len = (
+             (len(atom_data) - 1) // atoms_per_window_queries + 1
+         ) * atoms_per_window_queries - len(atom_data)
+
+     if pad_len > 0:
+         pad_mask = pad_dim(pad_mask, 0, pad_len)
+         ref_pos = pad_dim(ref_pos, 0, pad_len)
+         resolved_mask = pad_dim(resolved_mask, 0, pad_len)
+         ref_element = pad_dim(ref_element, 0, pad_len)
+         ref_charge = pad_dim(ref_charge, 0, pad_len)
+         ref_atom_name_chars = pad_dim(ref_atom_name_chars, 0, pad_len)
+         ref_space_uid = pad_dim(ref_space_uid, 0, pad_len)
+         coords = pad_dim(coords, 1, pad_len)
+         atom_to_token = pad_dim(atom_to_token, 0, pad_len)
+         token_to_rep_atom = pad_dim(token_to_rep_atom, 1, pad_len)
+         r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 1, pad_len)
+
+     if max_tokens is not None:
+         pad_len = max_tokens - token_to_rep_atom.shape[0]
+         if pad_len > 0:
+             atom_to_token = pad_dim(atom_to_token, 1, pad_len)
+             token_to_rep_atom = pad_dim(token_to_rep_atom, 0, pad_len)
+             r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 0, pad_len)
+             disto_target = pad_dim(pad_dim(disto_target, 0, pad_len), 1, pad_len)
+             frames = pad_dim(frames, 0, pad_len)
+             frame_resolved_mask = pad_dim(frame_resolved_mask, 0, pad_len)
+
+     return {
+         "ref_pos": ref_pos,
+         "atom_resolved_mask": resolved_mask,
+         "ref_element": ref_element,
+         "ref_charge": ref_charge,
+         "ref_atom_name_chars": ref_atom_name_chars,
+         "ref_space_uid": ref_space_uid,
+         "coords": coords,
+         "atom_pad_mask": pad_mask,
+         "atom_to_token": atom_to_token,
+         "token_to_rep_atom": token_to_rep_atom,
+         "r_set_to_rep_atom": r_set_to_rep_atom,
+         "disto_target": disto_target,
+         "frames_idx": frames,
+         "frame_resolved_mask": frame_resolved_mask,
+     }
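The disto_target built above buckets pairwise distances between each token's distogram atoms into num_bins=64 one-hot bins: 63 boundaries evenly spaced from min_dist=2.0 to max_dist=22.0, with the bin index counting how many boundaries a distance exceeds. A standalone sketch of that binning rule:

    import torch

    boundaries = torch.linspace(2.0, 22.0, 63)  # num_bins - 1 boundaries
    dists = torch.tensor([1.5, 10.0, 30.0])
    bins = (dists.unsqueeze(-1) > boundaries).sum(dim=-1)
    # tensor([0, 25, 63]): below 2.0 falls in bin 0, beyond 22.0 in the last bin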
+
+
+ def process_msa_features(
+     data: Tokenized,
+     max_seqs_batch: int,
+     max_seqs: int,
+     max_tokens: Optional[int] = None,
+     pad_to_max_seqs: bool = False,
+ ) -> dict[str, Tensor]:
+     """Get the MSA features.
+
+     Parameters
+     ----------
+     data : Tokenized
+         The tokenized data.
+     max_seqs : int
+         The maximum number of MSA sequences.
+     max_tokens : int
+         The maximum number of tokens.
+     pad_to_max_seqs : bool
+         Whether to pad to the maximum number of sequences.
+
+     Returns
+     -------
+     dict[str, Tensor]
+         The MSA features.
+
+     """
+     # Created paired MSA
+     msa, deletion, paired = construct_paired_msa(data, max_seqs_batch)
+     msa, deletion, paired = (
+         msa.transpose(1, 0),
+         deletion.transpose(1, 0),
+         paired.transpose(1, 0),
+     )  # (N_MSA, N_RES, N_AA)
+
+     # Prepare features
+     msa = torch.nn.functional.one_hot(msa, num_classes=const.num_tokens)
+     msa_mask = torch.ones_like(msa[:, :, 0])
+     profile = msa.float().mean(dim=0)
+     has_deletion = deletion > 0
+     deletion = np.pi / 2 * np.arctan(deletion / 3)
+     deletion_mean = deletion.mean(axis=0)
+
+     # Pad in the MSA dimension (dim=0)
+     if pad_to_max_seqs:
+         pad_len = max_seqs - msa.shape[0]
+         if pad_len > 0:
+             msa = pad_dim(msa, 0, pad_len, const.token_ids["-"])
+             paired = pad_dim(paired, 0, pad_len)
+             msa_mask = pad_dim(msa_mask, 0, pad_len)
+             has_deletion = pad_dim(has_deletion, 0, pad_len)
+             deletion = pad_dim(deletion, 0, pad_len)
+
+     # Pad in the token dimension (dim=1)
+     if max_tokens is not None:
+         pad_len = max_tokens - msa.shape[1]
+         if pad_len > 0:
+             msa = pad_dim(msa, 1, pad_len, const.token_ids["-"])
+             paired = pad_dim(paired, 1, pad_len)
+             msa_mask = pad_dim(msa_mask, 1, pad_len)
+             has_deletion = pad_dim(has_deletion, 1, pad_len)
+             deletion = pad_dim(deletion, 1, pad_len)
+             profile = pad_dim(profile, 0, pad_len)
+             deletion_mean = pad_dim(deletion_mean, 0, pad_len)
+
+     return {
+         "msa": msa,
+         "msa_paired": paired,
+         "deletion_value": deletion,
+         "has_deletion": has_deletion,
+         "deletion_mean": deletion_mean,
+         "profile": profile,
+         "msa_mask": msa_mask,
+     }
+
+
+ def process_symmetry_features(
+     cropped: Tokenized, symmetries: dict
+ ) -> dict[str, Tensor]:
+     """Get the symmetry features.
+
+     Parameters
+     ----------
+     data : Tokenized
+         The tokenized data.
+
+     Returns
+     -------
+     dict[str, Tensor]
+         The symmetry features.
+
+     """
+     features = get_chain_symmetries(cropped)
+     features.update(get_amino_acids_symmetries(cropped))
+     features.update(get_ligand_symmetries(cropped, symmetries))
+
+     return features
+
+
+ def process_residue_constraint_features(
+     data: Tokenized,
+ ) -> dict[str, Tensor]:
+     residue_constraints = data.residue_constraints
+     if residue_constraints is not None:
+         rdkit_bounds_constraints = residue_constraints.rdkit_bounds_constraints
+         chiral_atom_constraints = residue_constraints.chiral_atom_constraints
+         stereo_bond_constraints = residue_constraints.stereo_bond_constraints
+         planar_bond_constraints = residue_constraints.planar_bond_constraints
+         planar_ring_5_constraints = residue_constraints.planar_ring_5_constraints
+         planar_ring_6_constraints = residue_constraints.planar_ring_6_constraints
+
+         rdkit_bounds_index = torch.tensor(
+             rdkit_bounds_constraints["atom_idxs"].copy(), dtype=torch.long
+         ).T
+         rdkit_bounds_bond_mask = torch.tensor(
+             rdkit_bounds_constraints["is_bond"].copy(), dtype=torch.bool
+         )
+         rdkit_bounds_angle_mask = torch.tensor(
+             rdkit_bounds_constraints["is_angle"].copy(), dtype=torch.bool
+         )
+         rdkit_upper_bounds = torch.tensor(
+             rdkit_bounds_constraints["upper_bound"].copy(), dtype=torch.float
+         )
+         rdkit_lower_bounds = torch.tensor(
+             rdkit_bounds_constraints["lower_bound"].copy(), dtype=torch.float
+         )
+
+         chiral_atom_index = torch.tensor(
+             chiral_atom_constraints["atom_idxs"].copy(), dtype=torch.long
+         ).T
+         chiral_reference_mask = torch.tensor(
+             chiral_atom_constraints["is_reference"].copy(), dtype=torch.bool
+         )
+         chiral_atom_orientations = torch.tensor(
+             chiral_atom_constraints["is_r"].copy(), dtype=torch.bool
+         )
+
+         stereo_bond_index = torch.tensor(
+             stereo_bond_constraints["atom_idxs"].copy(), dtype=torch.long
+         ).T
+         stereo_reference_mask = torch.tensor(
+             stereo_bond_constraints["is_reference"].copy(), dtype=torch.bool
+         )
+         stereo_bond_orientations = torch.tensor(
+             stereo_bond_constraints["is_e"].copy(), dtype=torch.bool
+         )
+
+         planar_bond_index = torch.tensor(
+             planar_bond_constraints["atom_idxs"].copy(), dtype=torch.long
+         ).T
+         planar_ring_5_index = torch.tensor(
+             planar_ring_5_constraints["atom_idxs"].copy(), dtype=torch.long
+         ).T
+         planar_ring_6_index = torch.tensor(
+             planar_ring_6_constraints["atom_idxs"].copy(), dtype=torch.long
+         ).T
+     else:
+         rdkit_bounds_index = torch.empty((2, 0), dtype=torch.long)
+         rdkit_bounds_bond_mask = torch.empty((0,), dtype=torch.bool)
+         rdkit_bounds_angle_mask = torch.empty((0,), dtype=torch.bool)
+         rdkit_upper_bounds = torch.empty((0,), dtype=torch.float)
+         rdkit_lower_bounds = torch.empty((0,), dtype=torch.float)
+         chiral_atom_index = torch.empty(
+             (
+                 4,
+                 0,
+             ),
+             dtype=torch.long,
+         )
+         chiral_reference_mask = torch.empty((0,), dtype=torch.bool)
+         chiral_atom_orientations = torch.empty((0,), dtype=torch.bool)
+         stereo_bond_index = torch.empty((4, 0), dtype=torch.long)
+         stereo_reference_mask = torch.empty((0,), dtype=torch.bool)
+         stereo_bond_orientations = torch.empty((0,), dtype=torch.bool)
+         planar_bond_index = torch.empty((6, 0), dtype=torch.long)
+         planar_ring_5_index = torch.empty((5, 0), dtype=torch.long)
+         planar_ring_6_index = torch.empty((6, 0), dtype=torch.long)
+
+     return {
+         "rdkit_bounds_index": rdkit_bounds_index,
+         "rdkit_bounds_bond_mask": rdkit_bounds_bond_mask,
+         "rdkit_bounds_angle_mask": rdkit_bounds_angle_mask,
+         "rdkit_upper_bounds": rdkit_upper_bounds,
+         "rdkit_lower_bounds": rdkit_lower_bounds,
+         "chiral_atom_index": chiral_atom_index,
+         "chiral_reference_mask": chiral_reference_mask,
+         "chiral_atom_orientations": chiral_atom_orientations,
+         "stereo_bond_index": stereo_bond_index,
+         "stereo_reference_mask": stereo_reference_mask,
+         "stereo_bond_orientations": stereo_bond_orientations,
+         "planar_bond_index": planar_bond_index,
+         "planar_ring_5_index": planar_ring_5_index,
+         "planar_ring_6_index": planar_ring_6_index,
+     }
+
+
+ def process_chain_feature_constraints(
+     data: Tokenized,
+ ) -> dict[str, Tensor]:
+     structure = data.structure
+     if structure.connections.shape[0] > 0:
+         connected_chain_index, connected_atom_index = [], []
+         for connection in structure.connections:
+             connected_chain_index.append([connection["chain_1"], connection["chain_2"]])
+             connected_atom_index.append([connection["atom_1"], connection["atom_2"]])
+         connected_chain_index = torch.tensor(connected_chain_index, dtype=torch.long).T
+         connected_atom_index = torch.tensor(connected_atom_index, dtype=torch.long).T
+     else:
+         connected_chain_index = torch.empty((2, 0), dtype=torch.long)
+         connected_atom_index = torch.empty((2, 0), dtype=torch.long)
+
+     symmetric_chain_index = []
+     for i, chain_i in enumerate(structure.chains):
+         for j, chain_j in enumerate(structure.chains):
+             if j <= i:
+                 continue
+             if chain_i["entity_id"] == chain_j["entity_id"]:
+                 symmetric_chain_index.append([i, j])
+     if len(symmetric_chain_index) > 0:
+         symmetric_chain_index = torch.tensor(symmetric_chain_index, dtype=torch.long).T
+     else:
+         symmetric_chain_index = torch.empty((2, 0), dtype=torch.long)
+     return {
+         "connected_chain_index": connected_chain_index,
+         "connected_atom_index": connected_atom_index,
+         "symmetric_chain_index": symmetric_chain_index,
+     }
+
+
1127
+ class BoltzFeaturizer:
1128
+ """Boltz featurizer."""
1129
+
1130
+ def process(
1131
+ self,
1132
+ data: Tokenized,
1133
+ training: bool,
1134
+ max_seqs: int = 4096,
1135
+ atoms_per_window_queries: int = 32,
1136
+ min_dist: float = 2.0,
1137
+ max_dist: float = 22.0,
1138
+ num_bins: int = 64,
1139
+ max_tokens: Optional[int] = None,
1140
+ max_atoms: Optional[int] = None,
1141
+ pad_to_max_seqs: bool = False,
1142
+ compute_symmetries: bool = False,
1143
+ symmetries: Optional[dict] = None,
1144
+ binder_pocket_conditioned_prop: Optional[float] = 0.0,
1145
+ binder_pocket_cutoff: Optional[float] = 6.0,
1146
+ binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
1147
+ only_ligand_binder_pocket: Optional[bool] = False,
1148
+ inference_binder: Optional[int] = None,
1149
+ inference_pocket: Optional[list[tuple[int, int]]] = None,
1150
+ compute_constraint_features: bool = False,
1151
+ ) -> dict[str, Tensor]:
1152
+ """Compute features.
1153
+
1154
+ Parameters
1155
+ ----------
1156
+ data : Tokenized
1157
+ The tokenized data.
1158
+ training : bool
1159
+ Whether the model is in training mode.
1160
+ max_tokens : int, optional
1161
+ The maximum number of tokens.
1162
+ max_atoms : int, optional
1163
+ The maximum number of atoms
1164
+ max_seqs : int, optional
1165
+ The maximum number of sequences.
1166
+
1167
+ Returns
1168
+ -------
1169
+ dict[str, Tensor]
1170
+ The features for model training.
1171
+
1172
+ """
1173
+ # Compute random number of sequences
1174
+ if training and max_seqs is not None:
1175
+ max_seqs_batch = np.random.randint(1, max_seqs + 1) # noqa: NPY002
1176
+ else:
1177
+ max_seqs_batch = max_seqs
1178
+
1179
+ # Compute token features
1180
+ token_features = process_token_features(
1181
+ data,
1182
+ max_tokens,
1183
+ binder_pocket_conditioned_prop,
1184
+ binder_pocket_cutoff,
1185
+ binder_pocket_sampling_geometric_p,
1186
+ only_ligand_binder_pocket,
1187
+ inference_binder=inference_binder,
1188
+ inference_pocket=inference_pocket,
1189
+ )
1190
+
1191
+ # Compute atom features
1192
+ atom_features = process_atom_features(
1193
+ data,
1194
+ atoms_per_window_queries,
1195
+ min_dist,
1196
+ max_dist,
1197
+ num_bins,
1198
+ max_atoms,
1199
+ max_tokens,
1200
+ )
1201
+
1202
+ # Compute MSA features
1203
+ msa_features = process_msa_features(
1204
+ data,
1205
+ max_seqs_batch,
1206
+ max_seqs,
1207
+ max_tokens,
1208
+ pad_to_max_seqs,
1209
+ )
1210
+
1211
+ # Compute symmetry features
1212
+ symmetry_features = {}
1213
+ if compute_symmetries:
1214
+ symmetry_features = process_symmetry_features(data, symmetries)
1215
+
1216
+ # Compute residue constraint features
1217
+ residue_constraint_features = {}
1218
+ chain_constraint_features = {}
1219
+ if compute_constraint_features:
1220
+ residue_constraint_features = process_residue_constraint_features(data)
1221
+ chain_constraint_features = process_chain_feature_constraints(data)
1222
+
1223
+ return {
1224
+ **token_features,
1225
+ **atom_features,
1226
+ **msa_features,
1227
+ **symmetry_features,
1228
+ **residue_constraint_features,
1229
+ **chain_constraint_features,
1230
+ }
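BoltzFeaturizer.process at the end of the file is presumably what the data modules listed above (boltz/data/module/inference.py and training.py) call once a structure has been tokenized and cropped. A rough usage sketch, assuming tokenized is a boltz.data.types.Tokenized object produced by the package's tokenizer (the variable name and call site are illustrative, not taken from this diff):

    from boltz.data.feature.featurizer import BoltzFeaturizer

    featurizer = BoltzFeaturizer()
    features = featurizer.process(
        tokenized,              # a boltz.data.types.Tokenized instance (assumed)
        training=False,
        max_seqs=4096,
        pad_to_max_seqs=False,
        compute_symmetries=False,
    )
    # features is a dict[str, Tensor] merging the token, atom and MSA features,
    # e.g. features["msa"], features["coords"], features["token_pad_mask"]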