boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,396 @@
1
+ from dataclasses import astuple, dataclass
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+
6
+ from boltz.data import const
7
+ from boltz.data.tokenize.tokenizer import Tokenizer
8
+ from boltz.data.types import (
9
+ AffinityInfo,
10
+ Input,
11
+ StructureV2,
12
+ TokenBondV2,
13
+ Tokenized,
14
+ TokenV2,
15
+ )
16
+
17
+
18
@dataclass
class TokenData:
    """Plain record describing one token produced by the tokenizer.

    Instances are converted with ``dataclasses.astuple`` and packed into a
    structured array, so the declaration order of the fields below must not
    be changed.
    """

    # --- Position of the token and of the atoms it covers ---
    token_idx: int  # running index of the token
    atom_idx: int  # index of the first atom covered by the token
    atom_num: int  # number of atoms covered by the token

    # --- Residue-level identity ---
    res_idx: int  # index of the residue the token comes from
    res_type: int  # residue type id (an "unknown" id for non-standard residues)
    res_name: str  # residue name

    # --- Chain-level identity (copied from the chain record) ---
    sym_id: int
    asym_id: int
    entity_id: int
    mol_type: int

    # --- Representative atoms and their coordinates ---
    center_idx: int  # atom index of the center atom
    disto_idx: int  # atom index of the distogram atom
    center_coords: np.ndarray
    disto_coords: np.ndarray

    # --- Presence flags ---
    resolved_mask: bool  # residue and center atom are both present
    disto_mask: bool  # residue and distogram atom are both present
    modified: bool  # non-standard polymer residue tokenized as one token

    # --- Local frame (identity rotation / zero translation when unset) ---
    frame_rot: np.ndarray  # flattened rotation matrix
    frame_t: np.ndarray  # translation
    frame_mask: bool  # whether a frame was actually computed

    cyclic_period: int  # copied from the chain record
    affinity_mask: bool = False  # token belongs to the affinity chain
44
+
45
+
46
def compute_frame(
    n: np.ndarray,
    ca: np.ndarray,
    c: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute the local backbone frame of a residue.

    The frame is obtained by Gram-Schmidt orthonormalisation: the first
    axis points from CA to C, the second is the part of the CA -> N vector
    that is orthogonal to the first axis, and the third is their cross
    product.

    Parameters
    ----------
    n : np.ndarray
        Coordinates of the N atom (1-D vector).
    ca : np.ndarray
        Coordinates of the CA atom (1-D vector); used as the frame origin.
    c : np.ndarray
        Coordinates of the C atom (1-D vector).

    Returns
    -------
    np.ndarray
        The rotation matrix whose columns are the three frame axes
        (3x3 for 3-D coordinates).
    np.ndarray
        The translation of the frame, i.e. ``ca`` itself (not a copy).

    """
    v1 = c - ca
    v2 = n - ca

    # The small epsilon keeps the normalisations finite for degenerate
    # (zero-length or collinear) inputs.
    e1 = v1 / (np.linalg.norm(v1) + 1e-10)
    u2 = v2 - e1 * np.dot(e1, v2)
    e2 = u2 / (np.linalg.norm(u2) + 1e-10)
    e3 = np.cross(e1, e2)

    rot = np.column_stack([e1, e2, e3])
    t = ca
    return rot, t
77
+
78
+
79
def get_unk_token(chain: np.ndarray) -> int:
    """Return the id of the "unknown residue" token matching a chain.

    Parameters
    ----------
    chain : np.ndarray
        The chain record; only its ``mol_type`` field is read.

    Returns
    -------
    int
        The token id of the DNA or RNA unknown token when the chain has
        that molecule type, and of the protein unknown token otherwise.

    """
    mol_type = chain["mol_type"]

    # Protein is the fallback for every molecule type that is not DNA/RNA.
    kind = "PROTEIN"
    for candidate in ("DNA", "RNA"):
        if mol_type == const.chain_type_ids[candidate]:
            kind = candidate
            break

    return const.token_ids[const.unk_token[kind]]
102
+
103
+
104
def _residue_token(
    struct: StructureV2,
    chain: np.ndarray,
    res: np.ndarray,
    offset: int,
    token_idx: int,
    affinity_mask: bool,
    is_protein: bool,
) -> TokenData:
    """Build the single token that represents a whole residue.

    Used both for standard residues and for modified (non-standard) polymer
    residues; the latter get the chain's unknown residue type and
    ``modified=True``.
    """
    is_standard = bool(res["is_standard"])
    center_idx = res["atom_center"]
    disto_idx = res["atom_disto"]

    # A token is resolved only if both the residue and its center atom are.
    resolved = res["is_present"] & struct.atoms[center_idx]["is_present"]
    disto_resolved = res["is_present"] & struct.atoms[disto_idx]["is_present"]

    # Default to the identity frame; a real frame is only computed for
    # standard protein residues (only used for templates).
    frame_rot = np.eye(3).flatten()
    frame_t = np.zeros(3)
    frame_mask = False

    if is_standard and is_protein:
        atom_start = res["atom_idx"]
        atoms = struct.atoms[atom_start : atom_start + res["atom_num"]]

        # Atoms are always in the order N, CA, C
        atom_n = atoms[0]
        atom_ca = atoms[1]
        atom_c = atoms[2]

        frame_mask = bool(
            atom_ca["is_present"] & atom_c["is_present"] & atom_n["is_present"]
        )
        if frame_mask:
            frame_rot, frame_t = compute_frame(
                atom_n["coords"],
                atom_ca["coords"],
                atom_c["coords"],
            )
            frame_rot = frame_rot.flatten()

    return TokenData(
        token_idx=token_idx,
        atom_idx=res["atom_idx"],
        atom_num=res["atom_num"],
        res_idx=res["res_idx"],
        res_type=res["res_type"] if is_standard else get_unk_token(chain),
        res_name=res["name"],
        sym_id=chain["sym_id"],
        asym_id=chain["asym_id"],
        entity_id=chain["entity_id"],
        mol_type=chain["mol_type"],
        center_idx=center_idx,
        disto_idx=disto_idx,
        # Coordinates come from the selected conformer of the coords table.
        center_coords=struct.coords[offset + center_idx]["coords"],
        disto_coords=struct.coords[offset + disto_idx]["coords"],
        resolved_mask=resolved,
        disto_mask=disto_resolved,
        modified=not is_standard,
        frame_rot=frame_rot,
        frame_t=frame_t,
        frame_mask=frame_mask,
        cyclic_period=chain["cyclic_period"],
        affinity_mask=affinity_mask,
    )


def _atom_tokens(
    struct: StructureV2,
    chain: np.ndarray,
    res: np.ndarray,
    offset: int,
    token_idx: int,
    affinity_mask: bool,
    res_type: int,
) -> list[TokenData]:
    """Build one token per atom of a residue, numbered from ``token_idx``."""
    atom_start = res["atom_idx"]
    atom_end = atom_start + res["atom_num"]
    atoms = struct.atoms[atom_start:atom_end]
    coords = struct.coords[offset + atom_start : offset + atom_end]["coords"]

    tokens = []
    for i, atom in enumerate(atoms):
        # Token is present if both the residue and the atom are
        is_present = res["is_present"] & atom["is_present"]
        index = atom_start + i
        tokens.append(
            TokenData(
                token_idx=token_idx + i,
                atom_idx=index,
                atom_num=1,
                res_idx=res["res_idx"],
                res_type=res_type,
                res_name=res["name"],
                sym_id=chain["sym_id"],
                asym_id=chain["asym_id"],
                entity_id=chain["entity_id"],
                mol_type=chain["mol_type"],
                center_idx=index,
                disto_idx=index,
                center_coords=coords[i],
                disto_coords=coords[i],
                resolved_mask=is_present,
                disto_mask=is_present,
                # Per-atom tokens are only created for non-polymer chains,
                # which are never flagged as modified.
                modified=False,
                frame_rot=np.eye(3).flatten(),
                frame_t=np.zeros(3),
                frame_mask=False,
                cyclic_period=chain["cyclic_period"],
                affinity_mask=affinity_mask,
            )
        )
    return tokens


def tokenize_structure(
    struct: StructureV2,
    affinity: Optional[AffinityInfo] = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Tokenize a structure.

    Standard residues and modified polymer residues become one token each;
    non-standard residues of non-polymer chains become one token per atom.
    Bonds whose two atoms both belong to a token are translated into
    token-level bonds.

    Parameters
    ----------
    struct : StructureV2
        The structure to tokenize. Only chains selected by ``struct.mask``
        are tokenized.
    affinity : Optional[AffinityInfo]
        The affinity information. When given, tokens of the chain whose
        ``asym_id`` equals ``affinity.chain_id`` get ``affinity_mask=True``.

    Returns
    -------
    np.ndarray
        The tokenized data, as an array of dtype ``TokenV2``.
    np.ndarray
        The tokenized bonds, as an array of dtype ``TokenBondV2``.

    """
    token_data = []
    token_idx = 0

    # Maps every tokenized atom index to the index of its token.
    atom_to_token = {}

    # Ensemble atom id start in coords table.
    # For cropper and other operations, hardcoded to 0th conformer.
    offset = struct.ensemble[0]["atom_coord_idx"]

    protein_id = const.chain_type_ids["PROTEIN"]
    nonpolymer_id = const.chain_type_ids["NONPOLYMER"]

    # Per-atom (ligand) tokens use the unk protein token as res_type.
    ligand_res_type = const.token_ids[const.unk_token["PROTEIN"]]

    # Filter to valid chains only
    for chain in struct.chains[struct.mask]:
        res_start = chain["res_idx"]
        res_end = res_start + chain["res_num"]
        is_protein = chain["mol_type"] == protein_id
        is_nonpolymer = chain["mol_type"] == nonpolymer_id
        affinity_mask = (affinity is not None) and (
            int(chain["asym_id"]) == int(affinity.chain_id)
        )

        for res in struct.residues[res_start:res_end]:
            atom_start = res["atom_idx"]
            atom_end = atom_start + res["atom_num"]

            if res["is_standard"] or not is_nonpolymer:
                # Standard residues, and modified residues of polymer
                # chains, are tokenized at residue level.
                token = _residue_token(
                    struct,
                    chain,
                    res,
                    offset,
                    token_idx,
                    affinity_mask,
                    is_protein,
                )
                token_data.append(astuple(token))

                # Every atom of the residue maps to this one token.
                for atom_idx in range(atom_start, atom_end):
                    atom_to_token[atom_idx] = token_idx
                token_idx += 1
            else:
                # Non-standard residues of non-polymers: one token per atom.
                for token in _atom_tokens(
                    struct,
                    chain,
                    res,
                    offset,
                    token_idx,
                    affinity_mask,
                    ligand_res_type,
                ):
                    token_data.append(astuple(token))
                    atom_to_token[token.atom_idx] = token.token_idx
                    token_idx += 1

    # Translate atom-atom bonds into token-token bonds, skipping bonds that
    # touch an atom that was not tokenized. The stored bond type is shifted
    # by one (presumably so that 0 can mean "no bond" - confirm against the
    # consumers of TokenBondV2).
    token_bonds = [
        (
            atom_to_token[bond["atom_1"]],
            atom_to_token[bond["atom_2"]],
            bond["type"] + 1,
        )
        for bond in struct.bonds
        if bond["atom_1"] in atom_to_token and bond["atom_2"] in atom_to_token
    ]

    token_data = np.array(token_data, dtype=TokenV2)
    token_bonds = np.array(token_bonds, dtype=TokenBondV2)

    return token_data, token_bonds
347
+
348
+
349
class Boltz2Tokenizer(Tokenizer):
    """Tokenizer that tokenizes a structure and, if present, its templates."""

    def tokenize(self, data: Input) -> Tokenized:
        """Tokenize the input structure and every template attached to it.

        Parameters
        ----------
        data : Input
            The input data. Its structure is tokenized together with the
            record's affinity information; templates, when present, are
            tokenized without affinity information.

        Returns
        -------
        Tokenized
            The tokenized data. ``template_tokens`` and ``template_bonds``
            are dictionaries keyed by template id, or ``None`` when the
            input has no templates.

        """
        # Main structure
        tokens, bonds = tokenize_structure(data.structure, data.record.affinity)

        # Templates
        if data.templates is None:
            template_tokens = None
            template_bonds = None
        else:
            template_tokens = {}
            template_bonds = {}
            for name, tmpl in data.templates.items():
                (
                    template_tokens[name],
                    template_bonds[name],
                ) = tokenize_structure(tmpl)

        return Tokenized(
            tokens=tokens,
            bonds=bonds,
            structure=data.structure,
            msa=data.msa,
            record=data.record,
            residue_constraints=data.residue_constraints,
            templates=data.templates,
            template_tokens=template_tokens,
            template_bonds=template_bonds,
            extra_mols=data.extra_mols,
        )
@@ -0,0 +1,24 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ from boltz.data.types import Input, Tokenized
4
+
5
+
6
class Tokenizer(ABC):
    """Abstract interface for objects that turn an input into tokens."""

    @abstractmethod
    def tokenize(self, data: Input) -> Tokenized:
        """Convert the input data into its tokenized form.

        Concrete tokenizers must override this method.

        Parameters
        ----------
        data : Input
            The input data.

        Returns
        -------
        Tokenized
            The tokenized data.

        Raises
        ------
        NotImplementedError
            Always, when the base implementation itself is invoked.

        """
        raise NotImplementedError