boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/model/optim/scheduler.py
@@ -0,0 +1,99 @@
+ import torch
+
+
+ class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
+     """Implements the learning rate schedule defined in AF3.
+
+     A linear warmup is followed by a plateau at the maximum
+     learning rate and then exponential decay. Note that the
+     initial learning rate of the optimizer in question is
+     ignored; use this class's base_lr parameter to specify
+     the starting point of the warmup.
+
+     """
+
+     def __init__(
+         self,
+         optimizer: torch.optim.Optimizer,
+         last_epoch: int = -1,
+         base_lr: float = 0.0,
+         max_lr: float = 1.8e-3,
+         warmup_no_steps: int = 1000,
+         start_decay_after_n_steps: int = 50000,
+         decay_every_n_steps: int = 50000,
+         decay_factor: float = 0.95,
+     ) -> None:
+         """Initialize the learning rate scheduler.
+
+         Parameters
+         ----------
+         optimizer : torch.optim.Optimizer
+             The optimizer.
+         last_epoch : int, optional
+             The last epoch, by default -1.
+         base_lr : float, optional
+             The base learning rate, by default 0.0.
+         max_lr : float, optional
+             The maximum learning rate, by default 1.8e-3.
+         warmup_no_steps : int, optional
+             The number of warmup steps, by default 1000.
+         start_decay_after_n_steps : int, optional
+             The number of steps after which to start decay, by default 50000.
+         decay_every_n_steps : int, optional
+             The number of steps between decays, by default 50000.
+         decay_factor : float, optional
+             The decay factor, by default 0.95.
+
+         """
+         step_counts = {
+             "warmup_no_steps": warmup_no_steps,
+             "start_decay_after_n_steps": start_decay_after_n_steps,
+         }
+
+         for k, v in step_counts.items():
+             if v < 0:
+                 msg = f"{k} must be nonnegative"
+                 raise ValueError(msg)
+
+         if warmup_no_steps > start_decay_after_n_steps:
+             msg = "warmup_no_steps must not exceed start_decay_after_n_steps"
+             raise ValueError(msg)
+
+         self.optimizer = optimizer
+         self.last_epoch = last_epoch
+         self.base_lr = base_lr
+         self.max_lr = max_lr
+         self.warmup_no_steps = warmup_no_steps
+         self.start_decay_after_n_steps = start_decay_after_n_steps
+         self.decay_every_n_steps = decay_every_n_steps
+         self.decay_factor = decay_factor
+
+         super().__init__(optimizer, last_epoch=last_epoch)
+
+     def state_dict(self) -> dict:
+         return {k: v for k, v in self.__dict__.items() if k != "optimizer"}
+
+     def load_state_dict(self, state_dict: dict) -> None:
+         self.__dict__.update(state_dict)
+
+     def get_lr(self):
+         if not self._get_lr_called_within_step:
+             msg = (
+                 "To get the last learning rate computed by the scheduler, use "
+                 "get_last_lr()"
+             )
+             raise RuntimeError(msg)
+
+         step_no = self.last_epoch
+
+         if step_no <= self.warmup_no_steps:
+             lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr
+         elif step_no > self.start_decay_after_n_steps:
+             steps_since_decay = step_no - self.start_decay_after_n_steps
+             exp = (steps_since_decay // self.decay_every_n_steps) + 1
+             lr = self.max_lr * (self.decay_factor**exp)
+         else:  # plateau
+             lr = self.max_lr
+
+         return [lr for _ in self.optimizer.param_groups]
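For context, a minimal usage sketch of this scheduler (the model, optimizer choice, and step counts below are placeholders, not part of the package):

import torch

from boltz.model.optim.scheduler import AlphaFoldLRScheduler

model = torch.nn.Linear(8, 8)  # placeholder module
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)  # initial lr is ignored
scheduler = AlphaFoldLRScheduler(optimizer, max_lr=1.8e-3, warmup_no_steps=1000)

for step in range(2000):
    optimizer.step()   # forward/backward omitted
    scheduler.step()   # advances the internal step counter and recomputes the lr
    if step in (0, 999, 1999):
        print(step, scheduler.get_last_lr())  # mid-warmup, peak, plateau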
boltz/model/potentials/__init__.py
File without changes
boltz/model/potentials/potentials.py
@@ -0,0 +1,497 @@
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import Optional, Union
+
+ import torch
+
+ from boltz.data import const
+ from boltz.model.potentials.schedules import (
+     ExponentialInterpolation,
+     ParameterSchedule,
+     PiecewiseStepFunction,
+ )
+
+
+ class Potential(ABC):
+     def __init__(
+         self,
+         parameters: Optional[
+             dict[str, Union[ParameterSchedule, float, int, bool]]
+         ] = None,
+     ):
+         self.parameters = parameters
+
+     def compute(self, coords, feats, parameters):
+         index, args, com_args = self.compute_args(feats, parameters)
+
+         if index.shape[1] == 0:
+             return torch.zeros(coords.shape[:-2], device=coords.device)
+
+         if com_args is not None:
+             com_index, atom_pad_mask = com_args
+             unpad_com_index = com_index[atom_pad_mask]
+             unpad_coords = coords[..., atom_pad_mask, :]
+             coords = torch.zeros(
+                 (*unpad_coords.shape[:-2], unpad_com_index.max() + 1, 3),
+                 device=coords.device,
+             ).scatter_reduce(
+                 -2,
+                 unpad_com_index.unsqueeze(-1).expand_as(unpad_coords),
+                 unpad_coords,
+                 "mean",
+             )
+         value = self.compute_variable(coords, index, compute_gradient=False)
+         energy = self.compute_function(value, *args)
+         return energy.sum(dim=-1)
+
+     def compute_gradient(self, coords, feats, parameters):
+         index, args, com_args = self.compute_args(feats, parameters)
+         if com_args is not None:
+             com_index, atom_pad_mask = com_args
+         else:
+             com_index, atom_pad_mask = None, None
+
+         if index.shape[1] == 0:
+             return torch.zeros_like(coords)
+
+         if com_index is not None:
+             unpad_coords = coords[..., atom_pad_mask, :]
+             unpad_com_index = com_index[atom_pad_mask]
+             coords = torch.zeros(
+                 (*unpad_coords.shape[:-2], unpad_com_index.max() + 1, 3),
+                 device=coords.device,
+             ).scatter_reduce(
+                 -2,
+                 unpad_com_index.unsqueeze(-1).expand_as(unpad_coords),
+                 unpad_coords,
+                 "mean",
+             )
+             com_counts = torch.bincount(com_index[atom_pad_mask])
+
+         value, grad_value = self.compute_variable(coords, index, compute_gradient=True)
+         energy, dEnergy = self.compute_function(value, *args, compute_derivative=True)
+
+         grad_atom = torch.zeros_like(coords).scatter_reduce(
+             -2,
+             index.flatten(start_dim=0, end_dim=1)
+             .unsqueeze(-1)
+             .expand((*coords.shape[:-2], -1, 3)),
+             dEnergy.tile(grad_value.shape[-3]).unsqueeze(-1)
+             * grad_value.flatten(start_dim=-3, end_dim=-2),
+             "sum",
+         )
+
+         if com_index is not None:
+             grad_atom = grad_atom[..., com_index, :]
+
+         return grad_atom
+
+     def compute_parameters(self, t):
+         if self.parameters is None:
+             return None
+         parameters = {
+             name: parameter
+             if not isinstance(parameter, ParameterSchedule)
+             else parameter.compute(t)
+             for name, parameter in self.parameters.items()
+         }
+         return parameters
+
+     @abstractmethod
+     def compute_function(self, value, *args, compute_derivative=False):
+         raise NotImplementedError
+
+     @abstractmethod
+     def compute_variable(self, coords, index, compute_gradient=False):
+         raise NotImplementedError
+
+     @abstractmethod
+     def compute_args(self, feats, parameters):
+         raise NotImplementedError
+
+
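The parameters dict may mix plain numbers with ParameterSchedule objects; compute_parameters resolves any schedules at a given time t. A minimal sketch (the _Dummy subclass is hypothetical, defined only to exercise compute_parameters):

from boltz.model.potentials.potentials import Potential
from boltz.model.potentials.schedules import ExponentialInterpolation

class _Dummy(Potential):  # hypothetical stub; the abstract methods are never called here
    def compute_function(self, value, *args, compute_derivative=False): ...
    def compute_variable(self, coords, index, compute_gradient=False): ...
    def compute_args(self, feats, parameters): ...

pot = _Dummy(parameters={"k": 1.0, "buffer": ExponentialInterpolation(1.0, 5.0, -2.0)})
print(pot.compute_parameters(0.0))  # {'k': 1.0, 'buffer': 1.0}
print(pot.compute_parameters(1.0))  # {'k': 1.0, 'buffer': 5.0}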
+ class FlatBottomPotential(Potential):
+     def compute_function(
+         self, value, k, lower_bounds, upper_bounds, compute_derivative=False
+     ):
+         if lower_bounds is None:
+             lower_bounds = torch.full_like(value, float("-inf"))
+         if upper_bounds is None:
+             upper_bounds = torch.full_like(value, float("inf"))
+
+         neg_overflow_mask = value < lower_bounds
+         pos_overflow_mask = value > upper_bounds
+
+         energy = torch.zeros_like(value)
+         energy[neg_overflow_mask] = (k * (lower_bounds - value))[neg_overflow_mask]
+         energy[pos_overflow_mask] = (k * (value - upper_bounds))[pos_overflow_mask]
+         if not compute_derivative:
+             return energy
+
+         dEnergy = torch.zeros_like(value)
+         dEnergy[neg_overflow_mask] = (
+             -1 * k.expand_as(neg_overflow_mask)[neg_overflow_mask]
+         )
+         dEnergy[pos_overflow_mask] = (
+             1 * k.expand_as(pos_overflow_mask)[pos_overflow_mask]
+         )
+
+         return energy, dEnergy
+
+
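The flat-bottom energy is zero while the restrained variable stays inside [lower, upper] and grows linearly with slope k outside it, so the derivative is a constant -k or +k beyond the bounds. A standalone numeric check of that piecewise rule (values chosen arbitrarily):

import torch

v = torch.tensor([0.5, 1.5, 3.5])  # restrained variable, e.g. a distance
k = torch.ones_like(v)
lo = torch.full_like(v, 1.0)
hi = torch.full_like(v, 3.0)

energy = torch.zeros_like(v)
energy[v < lo] = (k * (lo - v))[v < lo]
energy[v > hi] = (k * (v - hi))[v > hi]
print(energy)  # tensor([0.5000, 0.0000, 0.5000])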
+ class DistancePotential(Potential):
+     def compute_variable(self, coords, index, compute_gradient=False):
+         r_ij = coords.index_select(-2, index[0]) - coords.index_select(-2, index[1])
+         r_ij_norm = torch.linalg.norm(r_ij, dim=-1)
+         r_hat_ij = r_ij / r_ij_norm.unsqueeze(-1)
+
+         if not compute_gradient:
+             return r_ij_norm
+
+         grad_i = r_hat_ij
+         grad_j = -1 * r_hat_ij
+         grad = torch.stack((grad_i, grad_j), dim=1)
+
+         return r_ij_norm, grad
+
+
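The index tensor has shape (2, n_pairs): row 0 holds the first atom of each pair and row 1 the second, so distances and gradients are computed per pair. A quick illustration of the convention:

import torch

coords = torch.randn(4, 3)
index = torch.tensor([[0, 0], [1, 2]])  # two pairs: (0, 1) and (0, 2)
r_ij = coords[index[0]] - coords[index[1]]
print(torch.linalg.norm(r_ij, dim=-1))  # the two pair distances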
+ class DihedralPotential(Potential):
+     def compute_variable(self, coords, index, compute_gradient=False):
+         r_ij = coords.index_select(-2, index[0]) - coords.index_select(-2, index[1])
+         r_kj = coords.index_select(-2, index[2]) - coords.index_select(-2, index[1])
+         r_kl = coords.index_select(-2, index[2]) - coords.index_select(-2, index[3])
+
+         n_ijk = torch.cross(r_ij, r_kj, dim=-1)
+         n_jkl = torch.cross(r_kj, r_kl, dim=-1)
+
+         r_kj_norm = torch.linalg.norm(r_kj, dim=-1)
+         n_ijk_norm = torch.linalg.norm(n_ijk, dim=-1)
+         n_jkl_norm = torch.linalg.norm(n_jkl, dim=-1)
+
+         sign_phi = torch.sign(
+             r_kj.unsqueeze(-2) @ torch.cross(n_ijk, n_jkl, dim=-1).unsqueeze(-1)
+         ).squeeze(-1, -2)
+         phi = sign_phi * torch.arccos(
+             torch.clamp(
+                 (n_ijk.unsqueeze(-2) @ n_jkl.unsqueeze(-1)).squeeze(-1, -2)
+                 / (n_ijk_norm * n_jkl_norm),
+                 -1 + 1e-8,
+                 1 - 1e-8,
+             )
+         )
+
+         if not compute_gradient:
+             return phi
+
+         a = (
+             (r_ij.unsqueeze(-2) @ r_kj.unsqueeze(-1)).squeeze(-1, -2) / (r_kj_norm**2)
+         ).unsqueeze(-1)
+         b = (
+             (r_kl.unsqueeze(-2) @ r_kj.unsqueeze(-1)).squeeze(-1, -2) / (r_kj_norm**2)
+         ).unsqueeze(-1)
+
+         grad_i = n_ijk * (r_kj_norm / n_ijk_norm**2).unsqueeze(-1)
+         grad_l = -1 * n_jkl * (r_kj_norm / n_jkl_norm**2).unsqueeze(-1)
+         grad_j = (a - 1) * grad_i - b * grad_l
+         grad_k = (b - 1) * grad_l - a * grad_i
+         grad = torch.stack((grad_i, grad_j, grad_k, grad_l), dim=1)
+         return phi, grad
+
+
+ class AbsDihedralPotential(DihedralPotential):
+     def compute_variable(self, coords, index, compute_gradient=False):
+         if not compute_gradient:
+             phi = super().compute_variable(
+                 coords, index, compute_gradient=compute_gradient
+             )
+             phi = torch.abs(phi)
+             return phi
+
+         phi, grad = super().compute_variable(
+             coords, index, compute_gradient=compute_gradient
+         )
+         grad[(phi < 0)[..., None, :, None].expand_as(grad)] *= -1
+         phi = torch.abs(phi)
+
+         return phi, grad
+
+
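Each concrete potential below is a thin composition: FlatBottomPotential supplies compute_function, one of the variable classes (DistancePotential, DihedralPotential, AbsDihedralPotential) supplies compute_variable, and the subclass only defines compute_args. A hypothetical new restraint following the same pattern (the feature key and buffer semantics here are invented for illustration):

import torch

from boltz.model.potentials.potentials import DistancePotential, FlatBottomPotential

class MinSeparationPotential(FlatBottomPotential, DistancePotential):
    """Hypothetical: keep each listed atom pair at least `buffer` apart."""

    def compute_args(self, feats, parameters):
        pair_index = feats["my_pair_index"][0]  # hypothetical feature key
        lower_bounds = torch.full(
            (pair_index.shape[1],), parameters["buffer"], device=pair_index.device
        )
        k = torch.ones_like(lower_bounds)
        return pair_index, (k, lower_bounds, None), None  # None upper bound = +inf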
219
+ class PoseBustersPotential(FlatBottomPotential, DistancePotential):
220
+ def compute_args(self, feats, parameters):
221
+ pair_index = feats["rdkit_bounds_index"][0]
222
+ lower_bounds = feats["rdkit_lower_bounds"][0].clone()
223
+ upper_bounds = feats["rdkit_upper_bounds"][0].clone()
224
+ bond_mask = feats["rdkit_bounds_bond_mask"][0]
225
+ angle_mask = feats["rdkit_bounds_angle_mask"][0]
226
+
227
+ lower_bounds[bond_mask * ~angle_mask] *= 1.0 - parameters["bond_buffer"]
228
+ upper_bounds[bond_mask * ~angle_mask] *= 1.0 + parameters["bond_buffer"]
229
+ lower_bounds[~bond_mask * angle_mask] *= 1.0 - parameters["angle_buffer"]
230
+ upper_bounds[~bond_mask * angle_mask] *= 1.0 + parameters["angle_buffer"]
231
+ lower_bounds[bond_mask * angle_mask] *= 1.0 - min(
232
+ parameters["angle_buffer"], parameters["angle_buffer"]
233
+ )
234
+ upper_bounds[bond_mask * angle_mask] *= 1.0 + min(
235
+ parameters["angle_buffer"], parameters["angle_buffer"]
236
+ )
237
+ lower_bounds[~bond_mask * ~angle_mask] *= 1.0 - parameters["clash_buffer"]
238
+ upper_bounds[~bond_mask * ~angle_mask] = float("inf")
239
+
240
+ k = torch.ones_like(lower_bounds)
241
+
242
+ return pair_index, (k, lower_bounds, upper_bounds), None
243
+
244
+
+ class ConnectionsPotential(FlatBottomPotential, DistancePotential):
+     def compute_args(self, feats, parameters):
+         pair_index = feats["connected_atom_index"][0]
+         lower_bounds = None
+         upper_bounds = torch.full(
+             (pair_index.shape[1],), parameters["buffer"], device=pair_index.device
+         )
+         k = torch.ones_like(upper_bounds)
+
+         return pair_index, (k, lower_bounds, upper_bounds), None
+
+
+ class VDWOverlapPotential(FlatBottomPotential, DistancePotential):
+     def compute_args(self, feats, parameters):
+         atom_chain_id = (
+             torch.bmm(
+                 feats["atom_to_token"].float(), feats["asym_id"].unsqueeze(-1).float()
+             )
+             .squeeze(-1)
+             .long()
+         )[0]
+         atom_pad_mask = feats["atom_pad_mask"][0].bool()
+         chain_sizes = torch.bincount(atom_chain_id[atom_pad_mask])
+         single_ion_mask = (chain_sizes > 1)[atom_chain_id]
+
+         vdw_radii = torch.zeros(
+             const.num_elements, dtype=torch.float32, device=atom_chain_id.device
+         )
+         vdw_radii[1:119] = torch.tensor(
+             const.vdw_radii, dtype=torch.float32, device=atom_chain_id.device
+         )
+         atom_vdw_radii = (
+             feats["ref_element"].float() @ vdw_radii.unsqueeze(-1)
+         ).squeeze(-1)[0]
+
+         pair_index = torch.triu_indices(
+             atom_chain_id.shape[0],
+             atom_chain_id.shape[0],
+             1,
+             device=atom_chain_id.device,
+         )
+
+         pair_pad_mask = atom_pad_mask[pair_index].all(dim=0)
+         pair_ion_mask = single_ion_mask[pair_index[0]] * single_ion_mask[pair_index[1]]
+
+         num_chains = atom_chain_id.max() + 1
+         connected_chain_index = feats["connected_chain_index"][0]
+         connected_chain_matrix = torch.eye(
+             num_chains, device=atom_chain_id.device, dtype=torch.bool
+         )
+         connected_chain_matrix[connected_chain_index[0], connected_chain_index[1]] = (
+             True
+         )
+         connected_chain_matrix[connected_chain_index[1], connected_chain_index[0]] = (
+             True
+         )
+         connected_chain_mask = connected_chain_matrix[
+             atom_chain_id[pair_index[0]], atom_chain_id[pair_index[1]]
+         ]
+
+         pair_index = pair_index[
+             :, pair_pad_mask * pair_ion_mask * ~connected_chain_mask
+         ]
+
+         lower_bounds = atom_vdw_radii[pair_index].sum(dim=0) * (
+             1.0 - parameters["buffer"]
+         )
+         upper_bounds = None
+         k = torch.ones_like(lower_bounds)
+
+         return pair_index, (k, lower_bounds, upper_bounds), None
+
+
+ class SymmetricChainCOMPotential(FlatBottomPotential, DistancePotential):
+     def compute_args(self, feats, parameters):
+         atom_chain_id = (
+             torch.bmm(
+                 feats["atom_to_token"].float(), feats["asym_id"].unsqueeze(-1).float()
+             )
+             .squeeze(-1)
+             .long()
+         )[0]
+         atom_pad_mask = feats["atom_pad_mask"][0].bool()
+         chain_sizes = torch.bincount(atom_chain_id[atom_pad_mask])
+         single_ion_mask = chain_sizes > 1
+
+         pair_index = feats["symmetric_chain_index"][0]
+         pair_ion_mask = single_ion_mask[pair_index[0]] * single_ion_mask[pair_index[1]]
+         pair_index = pair_index[:, pair_ion_mask]
+         lower_bounds = torch.full(
+             (pair_index.shape[1],),
+             parameters["buffer"],
+             dtype=torch.float32,
+             device=pair_index.device,
+         )
+         upper_bounds = None
+         k = torch.ones_like(lower_bounds)
+
+         return (
+             pair_index,
+             (k, lower_bounds, upper_bounds),
+             (atom_chain_id, atom_pad_mask),
+         )
+
+
+ class StereoBondPotential(FlatBottomPotential, AbsDihedralPotential):
+     def compute_args(self, feats, parameters):
+         stereo_bond_index = feats["stereo_bond_index"][0]
+         stereo_bond_orientations = feats["stereo_bond_orientations"][0].bool()
+
+         lower_bounds = torch.zeros(
+             stereo_bond_orientations.shape, device=stereo_bond_orientations.device
+         )
+         upper_bounds = torch.zeros(
+             stereo_bond_orientations.shape, device=stereo_bond_orientations.device
+         )
+         lower_bounds[stereo_bond_orientations] = torch.pi - parameters["buffer"]
+         upper_bounds[stereo_bond_orientations] = float("inf")
+         lower_bounds[~stereo_bond_orientations] = float("-inf")
+         upper_bounds[~stereo_bond_orientations] = parameters["buffer"]
+
+         k = torch.ones_like(lower_bounds)
+
+         return stereo_bond_index, (k, lower_bounds, upper_bounds), None
+
+
+ class ChiralAtomPotential(FlatBottomPotential, DihedralPotential):
+     def compute_args(self, feats, parameters):
+         chiral_atom_index = feats["chiral_atom_index"][0]
+         chiral_atom_orientations = feats["chiral_atom_orientations"][0].bool()
+
+         lower_bounds = torch.zeros(
+             chiral_atom_orientations.shape, device=chiral_atom_orientations.device
+         )
+         upper_bounds = torch.zeros(
+             chiral_atom_orientations.shape, device=chiral_atom_orientations.device
+         )
+         lower_bounds[chiral_atom_orientations] = parameters["buffer"]
+         upper_bounds[chiral_atom_orientations] = float("inf")
+         upper_bounds[~chiral_atom_orientations] = -1 * parameters["buffer"]
+         lower_bounds[~chiral_atom_orientations] = float("-inf")
+
+         k = torch.ones_like(lower_bounds)
+         return chiral_atom_index, (k, lower_bounds, upper_bounds), None
+
+
+ class PlanarBondPotential(FlatBottomPotential, AbsDihedralPotential):
+     def compute_args(self, feats, parameters):
+         double_bond_index = feats["planar_bond_index"][0].T
+         double_bond_improper_index = torch.tensor(
+             [
+                 [1, 2, 3, 0],
+                 [4, 5, 0, 3],
+             ],
+             device=double_bond_index.device,
+         ).T
+         improper_index = (
+             double_bond_index[:, double_bond_improper_index]
+             .swapaxes(0, 1)
+             .flatten(start_dim=1)
+         )
+         lower_bounds = None
+         upper_bounds = torch.full(
+             (improper_index.shape[1],),
+             parameters["buffer"],
+             device=improper_index.device,
+         )
+         k = torch.ones_like(upper_bounds)
+
+         return improper_index, (k, lower_bounds, upper_bounds), None
+
+
+ def get_potentials():
+     potentials = [
+         SymmetricChainCOMPotential(
+             parameters={
+                 "guidance_interval": 4,
+                 "guidance_weight": 0.5,
+                 "resampling_weight": 0.5,
+                 "buffer": ExponentialInterpolation(start=1.0, end=5.0, alpha=-2.0),
+             }
+         ),
+         VDWOverlapPotential(
+             parameters={
+                 "guidance_interval": 5,
+                 "guidance_weight": PiecewiseStepFunction(
+                     thresholds=[0.4], values=[0.125, 0.0]
+                 ),
+                 "resampling_weight": PiecewiseStepFunction(
+                     thresholds=[0.6], values=[0.01, 0.0]
+                 ),
+                 "buffer": 0.225,
+             }
+         ),
+         ConnectionsPotential(
+             parameters={
+                 "guidance_interval": 1,
+                 "guidance_weight": 0.15,
+                 "resampling_weight": 1.0,
+                 "buffer": 2.0,
+             }
+         ),
+         PoseBustersPotential(
+             parameters={
+                 "guidance_interval": 1,
+                 "guidance_weight": 0.05,
+                 "resampling_weight": 0.1,
+                 "bond_buffer": 0.20,
+                 "angle_buffer": 0.20,
+                 "clash_buffer": 0.15,
+             }
+         ),
+         ChiralAtomPotential(
+             parameters={
+                 "guidance_interval": 1,
+                 "guidance_weight": 0.10,
+                 "resampling_weight": 1.0,
+                 "buffer": 0.52360,  # ~pi/6 rad, i.e. 30 degrees
+             }
+         ),
+         StereoBondPotential(
+             parameters={
+                 "guidance_interval": 1,
+                 "guidance_weight": 0.05,
+                 "resampling_weight": 1.0,
+                 "buffer": 0.52360,  # ~pi/6 rad, i.e. 30 degrees
+             }
+         ),
+         PlanarBondPotential(
+             parameters={
+                 "guidance_interval": 1,
+                 "guidance_weight": 0.05,
+                 "resampling_weight": 1.0,
+                 "buffer": 0.26180,  # ~pi/12 rad, i.e. 15 degrees
+             }
+         ),
+     ]
+     return potentials
+
+
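A sketch of how the schedule-bearing parameters resolve over time (assuming t is the normalized denoising time in [0, 1], as the schedule thresholds suggest):

from boltz.model.potentials.potentials import get_potentials

for pot in get_potentials():
    params = pot.compute_parameters(t=0.5)
    print(type(pot).__name__, params)
# e.g. VDWOverlapPotential's guidance_weight resolves to 0.0 at t=0.5
# (past its 0.4 threshold), while its resampling_weight is still 0.01.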
+ @dataclass
+ class GuidanceConfig:
+     """Guidance configuration."""
+
+     potentials: Optional[list[Potential]] = None
+     guidance_update: Optional[bool] = None
+     num_guidance_gd_steps: Optional[int] = None
+     guidance_gd_step_size: Optional[int] = None
+     fk_steering: Optional[bool] = None
+     fk_resampling_interval: Optional[int] = 1
+     fk_lambda: Optional[float] = 1.0
+     fk_method: Optional[str] = None
+     fk_batch_size: Optional[int] = 2
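A construction sketch (all values are placeholders chosen for illustration; fk_method is left unset because its accepted strings are not shown in this file):

from boltz.model.potentials.potentials import GuidanceConfig, get_potentials

config = GuidanceConfig(
    potentials=get_potentials(),
    guidance_update=True,
    num_guidance_gd_steps=20,
    fk_steering=True,
    fk_resampling_interval=1,
    fk_lambda=1.0,
    fk_batch_size=2,
)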
boltz/model/potentials/schedules.py
@@ -0,0 +1,32 @@
+ import math
+ from abc import ABC
+
+
+ class ParameterSchedule(ABC):
+     def compute(self, t):
+         raise NotImplementedError
+
+
+ class ExponentialInterpolation(ParameterSchedule):
+     def __init__(self, start, end, alpha):
+         self.start = start
+         self.end = end
+         self.alpha = alpha
+
+     def compute(self, t):
+         if self.alpha != 0:
+             return self.start + (self.end - self.start) * (
+                 math.exp(self.alpha * t) - 1
+             ) / (math.exp(self.alpha) - 1)
+         else:
+             return self.start + (self.end - self.start) * t
+
+
+ class PiecewiseStepFunction(ParameterSchedule):
+     def __init__(self, thresholds, values):
+         self.thresholds = thresholds
+         self.values = values
+
+     def compute(self, t):
+         assert len(self.thresholds) > 0
+         assert len(self.values) == len(self.thresholds) + 1
+
+         idx = 0
+         while idx < len(self.thresholds) and t > self.thresholds[idx]:
+             idx += 1
+         return self.values[idx]
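Both schedules map a scalar time t to a parameter value; the endpoint behavior is easy to verify by hand:

from boltz.model.potentials.schedules import (
    ExponentialInterpolation,
    PiecewiseStepFunction,
)

interp = ExponentialInterpolation(start=1.0, end=5.0, alpha=-2.0)
print(interp.compute(0.0), interp.compute(1.0))  # 1.0 5.0 (endpoints are exact)

step = PiecewiseStepFunction(thresholds=[0.4], values=[0.125, 0.0])
print(step.compute(0.3), step.compute(0.5))  # 0.125 0.0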