boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,99 @@
|
|
1
|
+
import torch
|
2
|
+
|
3
|
+
|
4
|
+
class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Implements the learning rate schedule defined in AF3.

    A linear warmup is followed by a plateau at the maximum
    learning rate and then exponential decay. Note that the
    initial learning rate of the optimizer in question is
    ignored; use this class' base_lr parameter to specify
    the starting point of the warmup.

    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        last_epoch: int = -1,
        base_lr: float = 0.0,
        max_lr: float = 1.8e-3,
        warmup_no_steps: int = 1000,
        start_decay_after_n_steps: int = 50000,
        decay_every_n_steps: int = 50000,
        decay_factor: float = 0.95,
    ) -> None:
        """Initialize the learning rate scheduler.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            The optimizer.
        last_epoch : int, optional
            The last epoch, by default -1
        base_lr : float, optional
            The learning rate at step 0, i.e. the start of the warmup,
            by default 0.0
        max_lr : float, optional
            The maximum (plateau) learning rate, by default 1.8e-3
        warmup_no_steps : int, optional
            The number of warmup steps, by default 1000. A value of 0
            disables warmup.
        start_decay_after_n_steps : int, optional
            The number of steps after which to start decay, by default 50000
        decay_every_n_steps : int, optional
            The number of steps between successive decays, by default 50000
        decay_factor : float, optional
            The decay factor, by default 0.95

        Raises
        ------
        ValueError
            If ``warmup_no_steps`` or ``start_decay_after_n_steps`` is
            negative, if ``warmup_no_steps`` exceeds
            ``start_decay_after_n_steps``, or if ``decay_every_n_steps`` is
            not positive.

        """
        step_counts = {
            "warmup_no_steps": warmup_no_steps,
            "start_decay_after_n_steps": start_decay_after_n_steps,
        }

        for k, v in step_counts.items():
            if v < 0:
                msg = f"{k} must be nonnegative"
                raise ValueError(msg)

        if warmup_no_steps > start_decay_after_n_steps:
            msg = "warmup_no_steps must not exceed start_decay_after_n_steps"
            raise ValueError(msg)

        # decay_every_n_steps is a floor-division divisor in get_lr; reject
        # it here instead of failing with ZeroDivisionError once decay starts.
        if decay_every_n_steps <= 0:
            msg = "decay_every_n_steps must be positive"
            raise ValueError(msg)

        self.optimizer = optimizer
        self.last_epoch = last_epoch
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.warmup_no_steps = warmup_no_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.decay_every_n_steps = decay_every_n_steps
        self.decay_factor = decay_factor

        # The base class __init__ performs the initial step, which calls
        # get_lr(), so every attribute above must be set first.
        super().__init__(optimizer, last_epoch=last_epoch)

    def state_dict(self) -> dict:
        """Return the scheduler state (every attribute except the optimizer)."""
        state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore scheduler state produced by :meth:`state_dict`."""
        self.__dict__.update(state_dict)

    def get_lr(self):
        """Compute the learning rate for the current step (``last_epoch``).

        Returns
        -------
        list[float]
            The same learning rate for every optimizer param group.

        Raises
        ------
        RuntimeError
            If called outside of ``step()``.

        """
        if not self._get_lr_called_within_step:
            msg = (
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )
            raise RuntimeError(msg)

        step_no = self.last_epoch

        if self.warmup_no_steps > 0 and step_no <= self.warmup_no_steps:
            # Linear ramp from base_lr (step 0) to max_lr (end of warmup), so
            # the warmup joins the plateau without a jump.
            frac = step_no / self.warmup_no_steps
            lr = self.base_lr + frac * (self.max_lr - self.base_lr)
        elif step_no > self.start_decay_after_n_steps:
            steps_since_decay = step_no - self.start_decay_after_n_steps
            # The first decay applies on the first step past the plateau.
            exp = (steps_since_decay // self.decay_every_n_steps) + 1
            lr = self.max_lr * (self.decay_factor**exp)
        else:  # plateau
            lr = self.max_lr

        return [lr] * len(self.optimizer.param_groups)
|
File without changes
|
@@ -0,0 +1,497 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from typing import Optional, Union
|
4
|
+
|
5
|
+
import torch
|
6
|
+
|
7
|
+
from boltz.data import const
|
8
|
+
from boltz.model.potentials.schedules import (
|
9
|
+
ExponentialInterpolation,
|
10
|
+
ParameterSchedule,
|
11
|
+
PiecewiseStepFunction,
|
12
|
+
)
|
13
|
+
|
14
|
+
|
15
|
+
class Potential(ABC):
    """Abstract base class for a guidance potential on atom coordinates.

    A potential is assembled from three hooks implemented by subclasses:

    * ``compute_args`` selects the atom index tuples, the positional
      arguments for ``compute_function`` and optional center-of-mass
      pooling information.
    * ``compute_variable`` maps coordinates to one scalar per index tuple
      and, on request, its gradient w.r.t. the atoms of the tuple.
    * ``compute_function`` maps that scalar to an energy and, on request,
      its derivative.

    Parameters
    ----------
    parameters : dict, optional
        Named parameters. Values that are ``ParameterSchedule`` instances are
        evaluated at a given ``t`` by :meth:`compute_parameters`; all other
        values are passed through unchanged.

    """

    def __init__(
        self,
        parameters: Optional[
            dict[str, Union[ParameterSchedule, float, int, bool]]
        ] = None,
    ):
        self.parameters = parameters

    @staticmethod
    def _pool_by_com(coords, com_index, atom_pad_mask):
        """Average the non-padded atom coordinates of each group in ``com_index``.

        The returned tensor has one row along the atom dimension for every
        group id ``0 .. max(com_index[atom_pad_mask])``; ids that receive no
        atom remain zero.
        """
        unpad_coords = coords[..., atom_pad_mask, :]
        unpad_com_index = com_index[atom_pad_mask]
        # include_self=False: with the default (True) the zero initial value
        # is counted as an extra sample, yielding sum / (n + 1) instead of
        # the true mean sum / n.
        return torch.zeros(
            (*unpad_coords.shape[:-2], unpad_com_index.max() + 1, 3),
            dtype=coords.dtype,
            device=coords.device,
        ).scatter_reduce(
            -2,
            unpad_com_index.unsqueeze(-1).expand_as(unpad_coords),
            unpad_coords,
            "mean",
            include_self=False,
        )

    def compute(self, coords, feats, parameters):
        """Return the total energy of this potential.

        Parameters
        ----------
        coords : torch.Tensor
            Atom coordinates with a trailing ``(num_atoms, 3)`` shape.
        feats : dict
            Feature dictionary consumed by ``compute_args``.
        parameters : dict
            Evaluated parameters (see :meth:`compute_parameters`).

        Returns
        -------
        torch.Tensor
            The per-term energies summed over the last dimension; zeros of
            shape ``coords.shape[:-2]`` when there are no terms.

        """
        index, args, com_args = self.compute_args(feats, parameters)

        if index.shape[1] == 0:
            return torch.zeros(coords.shape[:-2], device=coords.device)

        if com_args is not None:
            # Evaluate the potential on group centers of mass instead of atoms.
            com_index, atom_pad_mask = com_args
            coords = self._pool_by_com(coords, com_index, atom_pad_mask)

        value = self.compute_variable(coords, index, compute_gradient=False)
        energy = self.compute_function(value, *args)
        return energy.sum(dim=-1)

    def compute_gradient(self, coords, feats, parameters):
        """Return the gradient of the total energy w.r.t. ``coords``.

        Parameters
        ----------
        coords : torch.Tensor
            Atom coordinates with a trailing ``(num_atoms, 3)`` shape.
        feats : dict
            Feature dictionary consumed by ``compute_args``.
        parameters : dict
            Evaluated parameters (see :meth:`compute_parameters`).

        Returns
        -------
        torch.Tensor
            A tensor with the same shape as ``coords``.

        """
        index, args, com_args = self.compute_args(feats, parameters)

        if index.shape[1] == 0:
            return torch.zeros_like(coords)

        com_index = None
        if com_args is not None:
            com_index, atom_pad_mask = com_args
            coords = self._pool_by_com(coords, com_index, atom_pad_mask)

        value, grad_value = self.compute_variable(coords, index, compute_gradient=True)
        _, d_energy = self.compute_function(value, *args, compute_derivative=True)

        # Chain rule: scatter dE/dvalue * dvalue/dx onto every atom referenced
        # by each index row. grad_value is stacked with the index rows along
        # dim -3, so flattening it lines up with index.flatten(0, 1).
        grad_atom = torch.zeros_like(coords).scatter_reduce(
            -2,
            index.flatten(start_dim=0, end_dim=1)
            .unsqueeze(-1)
            .expand((*coords.shape[:-2], -1, 3)),
            d_energy.tile(grad_value.shape[-3]).unsqueeze(-1)
            * grad_value.flatten(start_dim=-3, end_dim=-2),
            "sum",
        )

        if com_index is not None:
            # Broadcast each group's gradient back to all of its atoms.
            # NOTE(review): this is not divided by the group size, so it is
            # n times the exact d(COM)/dx_i -- kept as-is because guidance
            # weights were tuned against this scaling; confirm intent.
            grad_atom = grad_atom[..., com_index, :]

        return grad_atom

    def compute_parameters(self, t):
        """Evaluate the parameters at schedule time ``t``.

        Returns ``None`` when the potential has no parameters; otherwise a
        dict in which every ``ParameterSchedule`` is replaced by its value at
        ``t`` and every other entry is copied unchanged.
        """
        if self.parameters is None:
            return None
        return {
            name: (
                parameter.compute(t)
                if isinstance(parameter, ParameterSchedule)
                else parameter
            )
            for name, parameter in self.parameters.items()
        }

    @abstractmethod
    def compute_function(self, value, *args, compute_derivative=False):
        """Map per-term values to energies (and dE/dvalue if requested)."""
        raise NotImplementedError

    @abstractmethod
    def compute_variable(self, coords, index, compute_gradient=False):
        """Map coordinates to per-term values (and their gradient if requested)."""
        raise NotImplementedError

    @abstractmethod
    def compute_args(self, feats, parameters):
        """Return ``(index, args, com_args)`` for the given features.

        ``index`` is the atom-index tensor whose second dimension counts the
        terms; ``args`` is unpacked into ``compute_function``; ``com_args`` is
        ``None`` or ``(com_index, atom_pad_mask)`` to evaluate the potential
        on group centers of mass.
        """
        raise NotImplementedError
|
111
|
+
|
112
|
+
|
113
|
+
class FlatBottomPotential(Potential):
    """Linear-wall potential that is zero inside ``[lower_bounds, upper_bounds]``.

    Below the lower bound the energy is ``k * (lower - value)``; above the
    upper bound it is ``k * (value - upper)``. A missing bound (``None``) is
    treated as infinite.
    """

    def compute_function(
        self, value, k, lower_bounds, upper_bounds, compute_derivative=False
    ):
        """Return the energy per term, plus ``dE/dvalue`` if requested."""
        lo = torch.full_like(value, float("-inf")) if lower_bounds is None else lower_bounds
        hi = torch.full_like(value, float("inf")) if upper_bounds is None else upper_bounds

        below = value < lo
        above = value > hi

        energy = torch.zeros_like(value)
        energy[below] = (k * (lo - value))[below]
        energy[above] = (k * (value - hi))[above]
        if not compute_derivative:
            return energy

        # Slope of each linear wall: -k below the bound, +k above it.
        slope = torch.zeros_like(value)
        slope[below] = -k.expand_as(below)[below]
        slope[above] = k.expand_as(above)[above]
        return energy, slope
|
140
|
+
|
141
|
+
|
142
|
+
class DistancePotential(Potential):
    """Potential whose per-term variable is the distance between two atoms."""

    def compute_variable(self, coords, index, compute_gradient=False):
        """Return ``|x_i - x_j|`` per pair, and optionally its gradient.

        ``index[0]`` and ``index[1]`` select atoms i and j. The gradient is
        stacked along dim 1 as ``(d/dx_i, d/dx_j)``.
        """
        delta = coords.index_select(-2, index[0]) - coords.index_select(-2, index[1])
        dist = torch.linalg.norm(delta, dim=-1)
        if not compute_gradient:
            return dist

        # d|r|/dx_i is the unit vector along r; d|r|/dx_j is its negation.
        unit = delta / dist.unsqueeze(-1)
        return dist, torch.stack((unit, -unit), dim=1)
|
156
|
+
|
157
|
+
|
158
|
+
class DihedralPotential(Potential):
    """Potential whose per-term variable is the signed dihedral angle phi.

    ``index`` supplies four rows of atom indices (i, j, k, l); phi is the
    angle between the normals of planes (i, j, k) and (j, k, l), signed by
    the orientation of those normals about the j-k axis.
    """

    def compute_variable(self, coords, index, compute_gradient=False):
        """Compute the dihedral angle and, optionally, its gradient.

        Parameters
        ----------
        coords : torch.Tensor
            Atom coordinates with a trailing dimension of size 3 (presumably
            ``(batch, num_atoms, 3)`` since the gradient is stacked on dim 1
            -- TODO confirm against callers).
        index : torch.Tensor
            Atom indices; rows 0..3 are atoms i, j, k, l.
        compute_gradient : bool, optional
            If True, also return d(phi)/dx for the four atoms.

        Returns
        -------
        torch.Tensor or tuple[torch.Tensor, torch.Tensor]
            ``phi`` in radians, within [-pi, pi]; with ``compute_gradient``
            also the gradient stacked along dim 1 in the order (i, j, k, l).

        """
        r_ij = coords.index_select(-2, index[0]) - coords.index_select(-2, index[1])
        r_kj = coords.index_select(-2, index[2]) - coords.index_select(-2, index[1])
        r_kl = coords.index_select(-2, index[2]) - coords.index_select(-2, index[3])

        # Normals of the planes (i, j, k) and (j, k, l).
        n_ijk = torch.cross(r_ij, r_kj, dim=-1)
        n_jkl = torch.cross(r_kj, r_kl, dim=-1)

        r_kj_norm = torch.linalg.norm(r_kj, dim=-1)
        # NOTE(review): collinear atoms give a zero normal norm, and the
        # divisions below then produce NaN/inf; there is no guard here.
        n_ijk_norm = torch.linalg.norm(n_ijk, dim=-1)
        n_jkl_norm = torch.linalg.norm(n_jkl, dim=-1)

        # Sign of r_kj . (n_ijk x n_jkl) sets the handedness of the angle.
        sign_phi = torch.sign(
            r_kj.unsqueeze(-2) @ torch.cross(n_ijk, n_jkl, dim=-1).unsqueeze(-1)
        ).squeeze(-1, -2)
        # Unsigned angle from the normalized dot product of the normals; the
        # clamp keeps arccos inside its domain despite rounding (in float32
        # the 1e-8 margin presumably rounds away, making it [-1, 1]).
        phi = sign_phi * torch.arccos(
            torch.clamp(
                (n_ijk.unsqueeze(-2) @ n_jkl.unsqueeze(-1)).squeeze(-1, -2)
                / (n_ijk_norm * n_jkl_norm),
                -1 + 1e-8,
                1 - 1e-8,
            )
        )

        if not compute_gradient:
            return phi

        # Projections of the outer bonds onto the central bond, scaled by
        # |r_kj|^2; they distribute the outer-atom gradients onto j and k.
        a = (
            (r_ij.unsqueeze(-2) @ r_kj.unsqueeze(-1)).squeeze(-1, -2) / (r_kj_norm**2)
        ).unsqueeze(-1)
        b = (
            (r_kl.unsqueeze(-2) @ r_kj.unsqueeze(-1)).squeeze(-1, -2) / (r_kj_norm**2)
        ).unsqueeze(-1)

        # Closed-form gradients: the outer atoms move along their plane
        # normals, and the inner atoms get combinations of those two terms.
        grad_i = n_ijk * (r_kj_norm / n_ijk_norm**2).unsqueeze(-1)
        grad_l = -1 * n_jkl * (r_kj_norm / n_jkl_norm**2).unsqueeze(-1)
        grad_j = (a - 1) * grad_i - b * grad_l
        grad_k = (b - 1) * grad_l - a * grad_i
        grad = torch.stack((grad_i, grad_j, grad_k, grad_l), dim=1)
        return phi, grad
|
199
|
+
|
200
|
+
|
201
|
+
class AbsDihedralPotential(DihedralPotential):
    """Dihedral potential acting on ``|phi|`` instead of the signed angle."""

    def compute_variable(self, coords, index, compute_gradient=False):
        """Return ``|phi|`` and, optionally, its gradient.

        Where ``phi < 0`` the gradient of the signed angle is negated, since
        ``d|phi|/dx = sign(phi) * dphi/dx``.
        """
        result = super().compute_variable(
            coords, index, compute_gradient=compute_gradient
        )
        if not compute_gradient:
            return torch.abs(result)

        phi, grad = result
        # Broadcast the per-term sign mask over the stacked-atom and xyz axes.
        negative = (phi < 0)[..., None, :, None].expand_as(grad)
        grad[negative] *= -1
        return torch.abs(phi), grad
|
217
|
+
|
218
|
+
|
219
|
+
class PoseBustersPotential(FlatBottomPotential, DistancePotential):
    """Flat-bottom distance restraints built from RDKit distance bounds.

    Each atom pair's RDKit bounds are widened by a relative buffer chosen by
    the pair's bond/angle flags:

    * bond only: ``bond_buffer``
    * angle only: ``angle_buffer``
    * both: the smaller of ``bond_buffer`` and ``angle_buffer``
    * neither: the lower bound shrinks by ``clash_buffer`` and the upper
      bound is removed (clash check only).
    """

    def compute_args(self, feats, parameters):
        """Return RDKit pair indices and buffered ``(k, lower, upper)`` bounds."""
        pair_index = feats["rdkit_bounds_index"][0]
        lower_bounds = feats["rdkit_lower_bounds"][0].clone()
        upper_bounds = feats["rdkit_upper_bounds"][0].clone()
        bond_mask = feats["rdkit_bounds_bond_mask"][0].bool()
        angle_mask = feats["rdkit_bounds_angle_mask"][0].bool()

        bond_buffer = parameters["bond_buffer"]
        angle_buffer = parameters["angle_buffer"]
        clash_buffer = parameters["clash_buffer"]

        bond_only = bond_mask & ~angle_mask
        angle_only = ~bond_mask & angle_mask
        bond_and_angle = bond_mask & angle_mask
        neither = ~bond_mask & ~angle_mask

        lower_bounds[bond_only] *= 1.0 - bond_buffer
        upper_bounds[bond_only] *= 1.0 + bond_buffer
        lower_bounds[angle_only] *= 1.0 - angle_buffer
        upper_bounds[angle_only] *= 1.0 + angle_buffer
        # Pairs flagged as both bond and angle get the tighter of the two
        # buffers (previously a typo compared angle_buffer with itself).
        shared_buffer = min(bond_buffer, angle_buffer)
        lower_bounds[bond_and_angle] *= 1.0 - shared_buffer
        upper_bounds[bond_and_angle] *= 1.0 + shared_buffer
        lower_bounds[neither] *= 1.0 - clash_buffer
        upper_bounds[neither] = float("inf")

        k = torch.ones_like(lower_bounds)

        return pair_index, (k, lower_bounds, upper_bounds), None
|
243
|
+
|
244
|
+
|
245
|
+
class ConnectionsPotential(FlatBottomPotential, DistancePotential):
    """Upper-bound restraint on the distance between connected atom pairs."""

    def compute_args(self, feats, parameters):
        """Return connected pairs with no lower bound and ``buffer`` as upper bound."""
        connected = feats["connected_atom_index"][0]
        n_pairs = connected.shape[1]
        max_distance = torch.full(
            (n_pairs,), parameters["buffer"], device=connected.device
        )
        weights = torch.ones_like(max_distance)
        return connected, (weights, None, max_distance), None
|
255
|
+
|
256
|
+
|
257
|
+
class VDWOverlapPotential(FlatBottomPotential, DistancePotential):
    """Lower-bound (clash) restraint on inter-chain atom pairs from VdW radii.

    Only the first batch element of ``feats`` is used.
    """

    def compute_args(self, feats, parameters):
        """Return inter-chain atom pairs and ``(k, lower, None)`` bounds.

        A pair is kept when both atoms are non-padded, neither atom belongs
        to a single-atom chain, and the two chains are neither the same chain
        nor listed in ``connected_chain_index``. The lower bound is the sum
        of the two atoms' VdW radii shrunk by ``parameters["buffer"]``.

        NOTE(review): all i < j atom pairs are enumerated before filtering,
        so memory is O(num_atoms^2).
        """
        # Per-atom chain id, assuming atom_to_token is a one-hot atom->token
        # map so the product picks each atom's token asym_id -- TODO confirm.
        atom_chain_id = (
            torch.bmm(
                feats["atom_to_token"].float(), feats["asym_id"].unsqueeze(-1).float()
            )
            .squeeze(-1)
            .long()
        )[0]
        atom_pad_mask = feats["atom_pad_mask"][0].bool()
        chain_sizes = torch.bincount(atom_chain_id[atom_pad_mask])
        # Despite its name, this is True for atoms whose chain has MORE than
        # one non-padded atom, i.e. it marks atoms NOT in single-atom chains.
        single_ion_mask = (chain_sizes > 1)[atom_chain_id]

        # Radius lookup table: entry 0 stays 0 and entries 1..118 come from
        # const.vdw_radii (presumably indexed by atomic number -- confirm).
        vdw_radii = torch.zeros(
            const.num_elements, dtype=torch.float32, device=atom_chain_id.device
        )
        vdw_radii[1:119] = torch.tensor(
            const.vdw_radii, dtype=torch.float32, device=atom_chain_id.device
        )
        # Assumes ref_element is one-hot over const.num_elements -- TODO confirm.
        atom_vdw_radii = (
            feats["ref_element"].float() @ vdw_radii.unsqueeze(-1)
        ).squeeze(-1)[0]

        # All pairs (i, j) with i < j (offset 1 excludes the diagonal).
        pair_index = torch.triu_indices(
            atom_chain_id.shape[0],
            atom_chain_id.shape[0],
            1,
            device=atom_chain_id.device,
        )

        pair_pad_mask = atom_pad_mask[pair_index].all(dim=0)
        pair_ion_mask = single_ion_mask[pair_index[0]] * single_ion_mask[pair_index[1]]

        # Chain adjacency: the identity marks same-chain pairs; connected
        # chains are added symmetrically. Both are excluded below.
        num_chains = atom_chain_id.max() + 1
        connected_chain_index = feats["connected_chain_index"][0]
        connected_chain_matrix = torch.eye(
            num_chains, device=atom_chain_id.device, dtype=torch.bool
        )
        connected_chain_matrix[connected_chain_index[0], connected_chain_index[1]] = (
            True
        )
        connected_chain_matrix[connected_chain_index[1], connected_chain_index[0]] = (
            True
        )
        connected_chain_mask = connected_chain_matrix[
            atom_chain_id[pair_index[0]], atom_chain_id[pair_index[1]]
        ]

        pair_index = pair_index[
            :, pair_pad_mask * pair_ion_mask * ~connected_chain_mask
        ]

        lower_bounds = atom_vdw_radii[pair_index].sum(dim=0) * (
            1.0 - parameters["buffer"]
        )
        upper_bounds = None
        k = torch.ones_like(lower_bounds)

        return pair_index, (k, lower_bounds, upper_bounds), None
|
316
|
+
|
317
|
+
|
318
|
+
class SymmetricChainCOMPotential(FlatBottomPotential, DistancePotential):
    """Lower-bound restraint on the distance between symmetric chains' COMs.

    Chain pairs come from ``feats["symmetric_chain_index"]``; pairs involving
    a chain with at most one non-padded atom are dropped. The returned
    ``com_args`` make the base class evaluate distances between per-chain
    centers of mass instead of atoms. Only the first batch element is used.
    """

    def compute_args(self, feats, parameters):
        """Return chain pairs, ``(k, lower, None)`` bounds and COM pooling args."""
        atom_to_token = feats["atom_to_token"].float()
        token_chain = feats["asym_id"].unsqueeze(-1).float()
        atom_chain_id = torch.bmm(atom_to_token, token_chain).squeeze(-1).long()[0]
        atom_pad_mask = feats["atom_pad_mask"][0].bool()

        # Per chain id: does the chain have more than one non-padded atom?
        multi_atom_chain = torch.bincount(atom_chain_id[atom_pad_mask]) > 1

        candidates = feats["symmetric_chain_index"][0]
        keep = multi_atom_chain[candidates[0]] * multi_atom_chain[candidates[1]]
        pair_index = candidates[:, keep]

        min_distance = torch.full(
            (pair_index.shape[1],),
            parameters["buffer"],
            dtype=torch.float32,
            device=pair_index.device,
        )
        weights = torch.ones_like(min_distance)

        return (
            pair_index,
            (weights, min_distance, None),
            (atom_chain_id, atom_pad_mask),
        )
|
348
|
+
|
349
|
+
|
350
|
+
class StereoBondPotential(FlatBottomPotential, AbsDihedralPotential):
    """Restrain ``|phi|`` over stereo-bond dihedrals to its required side.

    Orientation ``True`` requires ``|phi| >= pi - buffer``; orientation
    ``False`` requires ``|phi| <= buffer``.
    """

    def compute_args(self, feats, parameters):
        """Return stereo-bond dihedral indices and ``(k, lower, upper)`` bounds."""
        dihedral_index = feats["stereo_bond_index"][0]
        near_pi = feats["stereo_bond_orientations"][0].bool()
        margin = parameters["buffer"]

        lower_bounds = torch.zeros(near_pi.shape, device=near_pi.device)
        upper_bounds = torch.zeros(near_pi.shape, device=near_pi.device)

        # Orientation True: |phi| must stay within `margin` of pi.
        lower_bounds[near_pi] = torch.pi - margin
        upper_bounds[near_pi] = float("inf")
        # Orientation False: |phi| must stay within `margin` of 0.
        lower_bounds[~near_pi] = float("-inf")
        upper_bounds[~near_pi] = margin

        k = torch.ones_like(lower_bounds)

        return dihedral_index, (k, lower_bounds, upper_bounds), None
|
369
|
+
|
370
|
+
|
371
|
+
class ChiralAtomPotential(FlatBottomPotential, DihedralPotential):
    """Keep the signed dihedral over each chiral-atom tuple on its required side.

    Orientation ``True`` requires ``phi >= buffer``; orientation ``False``
    requires ``phi <= -buffer``.
    """

    def compute_args(self, feats, parameters):
        """Return chiral-atom dihedral indices and ``(k, lower, upper)`` bounds."""
        dihedral_index = feats["chiral_atom_index"][0]
        positive = feats["chiral_atom_orientations"][0].bool()
        margin = parameters["buffer"]

        lower_bounds = torch.zeros(positive.shape, device=positive.device)
        upper_bounds = torch.zeros(positive.shape, device=positive.device)

        # Orientation True: phi at least +margin, no upper limit.
        lower_bounds[positive] = margin
        upper_bounds[positive] = float("inf")
        # Orientation False: phi at most -margin, no lower limit.
        lower_bounds[~positive] = float("-inf")
        upper_bounds[~positive] = -margin

        k = torch.ones_like(lower_bounds)
        return dihedral_index, (k, lower_bounds, upper_bounds), None
|
389
|
+
|
390
|
+
|
391
|
+
class PlanarBondPotential(FlatBottomPotential, AbsDihedralPotential):
    """Keep two improper dihedrals per planar bond within ``buffer`` of 0.

    Each row of ``feats["planar_bond_index"][0].T`` supplies (at least) six
    atom slots; two 4-atom improper dihedrals are formed from them per bond,
    and ``|phi|`` of each is bounded above by ``parameters["buffer"]``.
    """

    def compute_args(self, feats, parameters):
        """Return improper-dihedral indices and ``(k, None, upper)`` bounds."""
        bonds = feats["planar_bond_index"][0].T
        # Row r holds the atom slot used as dihedral atom r; column c is
        # improper c. Improper 0 = slots (1, 2, 3, 0), improper 1 = (4, 5, 0, 3).
        slot_table = torch.tensor(
            [[1, 4], [2, 5], [3, 0], [0, 3]],
            device=bonds.device,
        )
        # (num_bonds, 4, 2) -> (4, num_bonds, 2) -> (4, 2 * num_bonds), so the
        # two impropers of each bond end up adjacent.
        improper_index = bonds[:, slot_table].transpose(0, 1).flatten(start_dim=1)

        max_abs_phi = torch.full(
            (improper_index.shape[1],),
            parameters["buffer"],
            device=improper_index.device,
        )
        weights = torch.ones_like(max_abs_phi)

        return improper_index, (weights, None, max_abs_phi), None
|
415
|
+
|
416
|
+
|
417
|
+
def get_potentials():
    """Build the default list of guidance potentials.

    Every potential carries ``guidance_interval``, ``guidance_weight`` and
    ``resampling_weight`` entries plus the buffer value(s) its
    ``compute_args`` reads. Some entries are ``ParameterSchedule`` objects
    that ``Potential.compute_parameters`` evaluates at time ``t``.

    Returns
    -------
    list[Potential]
        Freshly constructed potentials, in a fixed order.

    """
    chain_com = SymmetricChainCOMPotential(
        parameters={
            "guidance_interval": 4,
            "guidance_weight": 0.5,
            "resampling_weight": 0.5,
            "buffer": ExponentialInterpolation(start=1.0, end=5.0, alpha=-2.0),
        }
    )
    vdw_overlap = VDWOverlapPotential(
        parameters={
            "guidance_interval": 5,
            "guidance_weight": PiecewiseStepFunction(
                thresholds=[0.4], values=[0.125, 0.0]
            ),
            "resampling_weight": PiecewiseStepFunction(
                thresholds=[0.6], values=[0.01, 0.0]
            ),
            "buffer": 0.225,
        }
    )
    connections = ConnectionsPotential(
        parameters={
            "guidance_interval": 1,
            "guidance_weight": 0.15,
            "resampling_weight": 1.0,
            "buffer": 2.0,
        }
    )
    pose_busters = PoseBustersPotential(
        parameters={
            "guidance_interval": 1,
            "guidance_weight": 0.05,
            "resampling_weight": 0.1,
            "bond_buffer": 0.20,
            "angle_buffer": 0.20,
            "clash_buffer": 0.15,
        }
    )
    chiral_atoms = ChiralAtomPotential(
        parameters={
            "guidance_interval": 1,
            "guidance_weight": 0.10,
            "resampling_weight": 1.0,
            "buffer": 0.52360,
        }
    )
    stereo_bonds = StereoBondPotential(
        parameters={
            "guidance_interval": 1,
            "guidance_weight": 0.05,
            "resampling_weight": 1.0,
            "buffer": 0.52360,
        }
    )
    planar_bonds = PlanarBondPotential(
        parameters={
            "guidance_interval": 1,
            "guidance_weight": 0.05,
            "resampling_weight": 1.0,
            "buffer": 0.26180,
        }
    )
    return [
        chain_com,
        vdw_overlap,
        connections,
        pose_busters,
        chiral_atoms,
        stereo_bonds,
        planar_bonds,
    ]
|
483
|
+
|
484
|
+
|
485
|
+
@dataclass
class GuidanceConfig:
    """Guidance configuration.

    Plain container of optional settings; no validation happens here. The
    consumers of these fields are not visible in this module, so the field
    notes below are inferred from names and should be confirmed.
    """

    # Potentials to apply, e.g. the list returned by get_potentials().
    potentials: Optional[list[Potential]] = None
    # Presumably toggles gradient-based guidance updates -- confirm.
    guidance_update: Optional[bool] = None
    # Presumably the number of gradient-descent guidance steps -- confirm.
    num_guidance_gd_steps: Optional[int] = None
    # NOTE(review): annotated as int although a step size is usually
    # fractional; confirm the intended type.
    guidance_gd_step_size: Optional[int] = None
    # Presumably enables Feynman-Kac style steering/resampling -- confirm.
    fk_steering: Optional[bool] = None
    # Presumably the number of steps between resampling events (default 1).
    fk_resampling_interval: Optional[int] = 1
    # Presumably a temperature-like scale on the resampling weights (default 1.0).
    fk_lambda: Optional[float] = 1.0
    # Name of the steering method; accepted values are not visible here.
    fk_method: Optional[str] = None
    # Presumably the number of particles per sample (default 2) -- confirm.
    fk_batch_size: Optional[int] = 2
|
@@ -0,0 +1,32 @@
|
|
1
|
+
import math
|
2
|
+
from abc import ABC
|
3
|
+
|
4
|
+
class ParameterSchedule(ABC):
    """Base class for a parameter whose value depends on a schedule time ``t``.

    Subclasses override :meth:`compute`; this base implementation only
    signals that it is missing.
    """

    def compute(self, t):
        """Return the parameter value at time ``t`` (must be overridden)."""
        raise NotImplementedError()
|
7
|
+
|
8
|
+
class ExponentialInterpolation(ParameterSchedule):
    """Interpolate from ``start`` (t=0) to ``end`` (t=1) along an exponential curve.

    value(t) = start + (end - start) * (exp(alpha * t) - 1) / (exp(alpha) - 1),
    which reduces to linear interpolation when ``alpha == 0``.
    """

    def __init__(self, start, end, alpha):
        self.start, self.end, self.alpha = start, end, alpha

    def compute(self, t):
        """Return the interpolated value at time ``t``."""
        span = self.end - self.start
        if self.alpha == 0:
            # Limit of the exponential form as alpha -> 0.
            return self.start + span * t
        return self.start + span * (math.exp(self.alpha * t) - 1) / (
            math.exp(self.alpha) - 1
        )
|
19
|
+
|
20
|
+
class PiecewiseStepFunction(ParameterSchedule):
    """Piecewise-constant schedule over ``t``.

    ``values[i]`` is returned where ``i`` is the first position, scanning
    left to right, with ``t <= thresholds[i]`` (``len(thresholds)`` if there
    is none). For sorted thresholds this is the number of thresholds strictly
    below ``t``. An empty ``thresholds`` with a single value is a constant.
    """

    def __init__(self, thresholds, values):
        self.thresholds = thresholds
        self.values = values

    def compute(self, t):
        """Return the step value at time ``t``.

        Raises
        ------
        ValueError
            If ``values`` does not have exactly one more entry than
            ``thresholds``.

        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(self.values) != len(self.thresholds) + 1:
            msg = "values must have exactly one more entry than thresholds"
            raise ValueError(msg)

        idx = 0
        while idx < len(self.thresholds) and t > self.thresholds[idx]:
            idx += 1
        return self.values[idx]
|