boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,37 @@
1
+ from boltz.data.types import Record
2
+ from boltz.data.filter.dynamic.filter import DynamicFilter
3
+
4
+
5
class MaxResiduesFilter(DynamicFilter):
    """A filter that keeps records whose total residue count is within bounds."""

    def __init__(self, min_residues: int = 1, max_residues: int = 500) -> None:
        """Initialize the filter.

        Parameters
        ----------
        min_residues : int, optional
            The minimum total number of residues allowed (inclusive).
        max_residues : int, optional
            The maximum total number of residues allowed (inclusive).

        """
        self.min_residues = min_residues
        self.max_residues = max_residues

    def filter(self, record: Record) -> bool:
        """Filter records based on their total number of residues.

        The residue count is summed over every chain listed in the record.

        Parameters
        ----------
        record : Record
            The record to filter.

        Returns
        -------
        bool
            True if the total residue count lies within
            ``[min_residues, max_residues]`` (the record passes the
            filter), False otherwise.

        """
        num_residues = sum(chain.num_residues for chain in record.chains)
        return self.min_residues <= num_residues <= self.max_residues
@@ -0,0 +1,34 @@
1
+ from boltz.data.types import Record
2
+ from boltz.data.filter.dynamic.filter import DynamicFilter
3
+
4
+
5
class ResolutionFilter(DynamicFilter):
    """Keep only complexes whose reported resolution meets a cutoff."""

    def __init__(self, resolution: float = 9.0) -> None:
        """Set up the resolution cutoff.

        Parameters
        ----------
        resolution : float, optional
            Largest resolution value (inclusive) a complex may have and
            still pass the filter.

        """
        self.resolution = resolution

    def filter(self, record: Record) -> bool:
        """Check a record against the resolution cutoff.

        Parameters
        ----------
        record : Record
            The record to check.

        Returns
        -------
        bool
            True when ``record.structure.resolution`` is at most the
            cutoff (the record passes), False otherwise.

        """
        info = record.structure
        cutoff = self.resolution
        return info.resolution <= cutoff
@@ -0,0 +1,38 @@
1
+ from boltz.data.types import Record
2
+ from boltz.data.filter.dynamic.filter import DynamicFilter
3
+
4
+
5
class SizeFilter(DynamicFilter):
    """A filter that filters structures based on their number of chains."""

    def __init__(self, min_chains: int = 1, max_chains: int = 300) -> None:
        """Initialize the filter.

        Parameters
        ----------
        min_chains : int, optional
            The minimum number of valid chains required (inclusive).
        max_chains : int, optional
            The maximum total number of chains allowed (inclusive).

        """
        self.min_chains = min_chains
        self.max_chains = max_chains

    def filter(self, record: Record) -> bool:
        """Filter structures based on their number of chains.

        The upper bound is checked against the total number of chains in
        ``record.structure``, while the lower bound is checked against the
        number of chains in ``record.chains`` flagged as ``valid``.

        Parameters
        ----------
        record : Record
            The record to filter.

        Returns
        -------
        bool
            True if the record passes the filter, False otherwise.

        """
        num_chains = record.structure.num_chains
        num_valid = sum(1 for chain in record.chains if chain.valid)
        return num_chains <= self.max_chains and num_valid >= self.min_chains
@@ -0,0 +1,42 @@
1
+ from pathlib import Path
2
+
3
+ from boltz.data.types import Record
4
+ from boltz.data.filter.dynamic.filter import DynamicFilter
5
+
6
+
7
class SubsetFilter(DynamicFilter):
    """Filter a data record based on a subset of the data."""

    def __init__(self, subset: str, reverse: bool = False) -> None:
        """Initialize the filter.

        Parameters
        ----------
        subset : str
            Path to a text file listing the record ids of the subset, one
            per line. Ids are matched case-insensitively; surrounding
            whitespace on each line is ignored and blank lines are skipped.
        reverse : bool, optional
            If True, keep the records that are NOT in the subset instead.

        """
        with Path(subset).open("r", encoding="utf-8") as f:
            lines = f.read().splitlines()

        # Strip stray whitespace (e.g. trailing spaces) so ids still match,
        # and ignore blank lines.
        self.subset = {line.strip().lower() for line in lines if line.strip()}
        self.reverse = reverse

    def filter(self, record: Record) -> bool:
        """Filter a data record.

        Parameters
        ----------
        record : Record
            The object to consider filtering in / out.

        Returns
        -------
        bool
            True if the data passes the filter, False otherwise.

        """
        in_subset = record.id.lower() in self.subset
        return not in_subset if self.reverse else in_subset
File without changes
@@ -0,0 +1,26 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import numpy as np
4
+
5
+ from boltz.data.types import Structure
6
+
7
+
8
class StaticFilter(ABC):
    """Abstract interface for filters that select chains within a structure."""

    @abstractmethod
    def filter(self, structure: Structure) -> np.ndarray:
        """Compute which chains of ``structure`` survive this filter.

        Concrete subclasses must override this method.

        Parameters
        ----------
        structure : Structure
            Structure whose chains are evaluated.

        Returns
        -------
        np.ndarray
            Boolean mask over the structure's chains; True marks a chain
            to keep.

        Raises
        ------
        NotImplementedError
            When this base implementation is invoked directly.

        """
        raise NotImplementedError
@@ -0,0 +1,37 @@
1
+ import numpy as np
2
+
3
+ from boltz.data import const
4
+ from boltz.data.filter.static.filter import StaticFilter
5
+ from boltz.data.types import Structure
6
+
7
+
8
class ExcludedLigands(StaticFilter):
    """Drop non-polymer chains that contain an excluded ligand residue."""

    def filter(self, structure: Structure) -> np.ndarray:
        """Flag non-polymer chains containing any residue from the exclusion list.

        Parameters
        ----------
        structure : Structure
            Structure whose chains are evaluated.

        Returns
        -------
        np.ndarray
            Boolean mask over the structure's chains; True marks a chain
            to keep.

        """
        keep = np.ones(len(structure.chains), dtype=bool)

        for idx, chain in enumerate(structure.chains):
            # Only non-polymer (ligand) chains are subject to this filter.
            if chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]:
                first = chain["res_idx"]
                chain_residues = structure.residues[first : first + chain["res_num"]]
                excluded = any(
                    residue["name"] in const.ligand_exclusion
                    for residue in chain_residues
                )
                if excluded:
                    keep[idx] = 0

        return keep
@@ -0,0 +1,299 @@
1
+ import itertools
2
+ from dataclasses import dataclass
3
+
4
+ import numpy as np
5
+ from sklearn.neighbors import KDTree
6
+
7
+ from boltz.data import const
8
+ from boltz.data.filter.static.filter import StaticFilter
9
+ from boltz.data.types import Structure
10
+
11
+
12
class MinimumLengthFilter(StaticFilter):
    """Filter polymers based on their length.

    We use the number of resolved residues when considering
    the minimum, and the sequence length for the maximum.

    """

    def __init__(self, min_len: int = 4, max_len: int = 5000) -> None:
        """Initialize the filter.

        Parameters
        ----------
        min_len : int, optional
            The minimum allowed number of resolved residues.
        max_len : int, optional
            The maximum allowed sequence length (number of residues,
            resolved or not).

        """
        self._min = min_len
        self._max = max_len

    def filter(self, structure: Structure) -> np.ndarray:
        """Filter chains based on their length.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        valid = np.ones(len(structure.chains), dtype=bool)

        for i, chain in enumerate(structure.chains):
            # Length limits only apply to polymer chains.
            if chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]:
                continue

            res_start = chain["res_idx"]
            res_num = chain["res_num"]
            residues = structure.residues[res_start : res_start + res_num]
            resolved = residues["is_present"].sum()

            # Minimum is checked on resolved residues, maximum on the full
            # sequence length, as stated in the class docstring.
            if (resolved < self._min) or (res_num > self._max):
                valid[i] = False

        return valid
63
+
64
+
65
class UnknownFilter(StaticFilter):
    """Drop polymer chains made up entirely of unknown residues."""

    def filter(self, structure: Structure) -> np.ndarray:
        """Flag polymer chains whose every residue is the unknown token.

        Parameters
        ----------
        structure : Structure
            Structure whose chains are evaluated.

        Returns
        -------
        np.ndarray
            Boolean mask over the structure's chains; True marks a chain
            to keep.

        """
        keep = np.ones(len(structure.chains), dtype=bool)
        # Chain type id -> residue type id of that polymer's "unknown" token.
        unknown_by_type = {
            const.chain_type_ids[kind]: const.unk_token_ids[kind]
            for kind in ("PROTEIN", "DNA", "RNA")
        }

        for idx, chain in enumerate(structure.chains):
            mol_type = chain["mol_type"]
            if mol_type == const.chain_type_ids["NONPOLYMER"]:
                continue

            first = chain["res_idx"]
            chain_residues = structure.residues[first : first + chain["res_num"]]

            unknown_id = unknown_by_type[mol_type]
            if np.all(chain_residues["res_type"] == unknown_id):
                keep[idx] = 0

        return keep
102
+
103
+
104
class ConsecutiveCA(StaticFilter):
    """Filter proteins with consecutive CA atoms above a threshold."""

    def __init__(self, max_dist: float = 10.0) -> None:
        """Initialize the filter.

        Parameters
        ----------
        max_dist : float, optional
            The maximum allowed distance between the center atoms of
            consecutive residues.

        """
        self._max_dist = max_dist

    def filter(self, structure: Structure) -> np.ndarray:
        """Filter protein if consecutive CA atoms above a threshold.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        valid = np.ones(len(structure.chains), dtype=bool)

        for i, chain in enumerate(structure.chains):
            # Only protein chains are checked.
            if chain["mol_type"] != const.chain_type_ids["PROTEIN"]:
                continue

            res_start = chain["res_idx"]
            res_end = res_start + chain["res_num"]
            residues = structure.residues[res_start:res_end]

            # Center atom of each residue (assumed to be the CA for protein
            # residues, per the class name -- TODO confirm).
            ca_atoms = structure.atoms[residues["atom_center"]]
            ca_valid = ca_atoms["is_present"] & residues["is_present"]
            ca_coords = ca_atoms["coords"]

            # Distances between consecutive center atoms.
            step = np.linalg.norm(ca_coords[1:] - ca_coords[:-1], axis=1)

            # Only pairs where both atoms are resolved count; remove the chain
            # if any such pair is farther apart than the threshold.
            pair_valid = ca_valid[1:] & ca_valid[:-1]
            if np.any(step[pair_valid] > self._max_dist):
                valid[i] = False

        return valid
163
+
164
+
165
@dataclass(frozen=True)
class Clash:
    """A directed clash of one chain against another.

    Attributes
    ----------
    chain : int
        Index of the chain whose atoms are clashing.
    other : int
        Index of the chain it clashes with.
    num_atoms : int
        Number of present atoms in ``chain``.
    num_clashes : int
        Number of those atoms lying within the clash distance of ``other``.

    """

    chain: int  # index of the clashing chain
    other: int  # index of the chain it clashes against
    num_atoms: int  # present atoms in `chain`
    num_clashes: int  # atoms of `chain` that clash with `other`
173
+
174
+
175
class ClashingChainsFilter(StaticFilter):
    """A filter that filters clashing chains.

    Clashing chains are defined as those with more than ``freq`` (default
    30%) of their present atoms within ``dist`` (default 1.7 Å) of an atom
    in another chain. If two chains are clashing with each other, the chain
    with the greater fraction of clashing atoms is removed. If the same
    fraction of atoms are clashing, the chain with fewer total atoms is
    removed. If the chains have the same number of atoms, then the chain
    with the larger chain index is removed.

    """

    def __init__(self, dist: float = 1.7, freq: float = 0.3) -> None:
        """Initialize the filter.

        Parameters
        ----------
        dist : float, optional
            The maximum distance for a clash.
        freq : float, optional
            The maximum allowed fraction of clashing atoms in a chain.

        """
        self._dist = dist
        self._freq = freq

    def filter(self, structure: Structure) -> np.ndarray:
        """Filter out clashing chains.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        num_chains = len(structure.chains)
        if num_chains < 2:  # noqa: PLR2004
            return np.ones(num_chains, dtype=bool)

        clashes = self._find_clashes(structure)
        removed = self._resolve_clashes(clashes)

        valid = np.ones(num_chains, dtype=bool)
        for i in removed:
            valid[i] = False
        return valid

    def _find_clashes(self, structure: Structure) -> list[Clash]:
        """Return every directed chain clash whose fraction exceeds ``freq``.

        Chain pairs are visited in ``itertools.combinations`` order; the
        resolution step depends on this ordering.
        """
        # Present atoms of each chain, extracted once per chain instead of
        # once per chain pair.
        present = []
        for chain in structure.chains:
            start = chain["atom_idx"]
            atoms = structure.atoms[start : start + chain["atom_num"]]
            present.append(atoms[atoms["is_present"]])

        clashes: list[Clash] = []
        num_chains = len(present)
        for i in range(num_chains - 1):
            atoms1 = present[i]
            if len(atoms1) == 0:
                continue

            # One KD-tree per chain, reused for every partner j > i.
            tree = KDTree(atoms1["coords"], metric="euclidean")

            for j in range(i + 1, num_chains):
                atoms2 = present[j]
                if len(atoms2) == 0:
                    continue

                # For each atom of chain j: indices of chain i atoms within
                # the clash distance.
                query = tree.query_radius(atoms2["coords"], self._dist)
                c2_clashes = sum(len(neighbors) > 0 for neighbors in query)
                c1_clashes = len(set(itertools.chain.from_iterable(query)))

                if (c1_clashes / len(atoms1)) > self._freq:
                    clashes.append(Clash(i, j, len(atoms1), c1_clashes))
                if (c2_clashes / len(atoms2)) > self._freq:
                    clashes.append(Clash(j, i, len(atoms2), c2_clashes))

        return clashes

    @staticmethod
    def _resolve_clashes(clashes: list[Clash]) -> set[int]:
        """Apply the removal ruleset, returning the indices of removed chains."""
        removed: set[int] = set()
        ids_to_clash = {(c.chain, c.other): c for c in clashes}

        for clash in clashes:
            # If either chain is already removed, this clash is moot.
            if clash.chain in removed or clash.other in removed:
                continue

            other_clash = ids_to_clash.get((clash.other, clash.chain))
            if other_clash is None:
                # One-sided clash: remove the clashing chain directly.
                removed.add(clash.chain)
                continue

            # Mutual clash: remove the chain with the larger clash fraction.
            clash1_freq = clash.num_clashes / clash.num_atoms
            clash2_freq = other_clash.num_clashes / other_clash.num_atoms
            if clash1_freq > clash2_freq:
                removed.add(clash.chain)
            elif clash1_freq < clash2_freq:
                removed.add(clash.other)
            # Same fraction: remove the chain with fewer atoms.
            elif clash.num_atoms < other_clash.num_atoms:
                removed.add(clash.chain)
            elif clash.num_atoms > other_clash.num_atoms:
                removed.add(clash.other)
            # Same size: remove the chain with the larger index.
            else:
                removed.add(max(clash.chain, clash.other))

        return removed
File without changes