boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/model/__init__.py
ADDED
File without changes
|
File without changes
|
@@ -0,0 +1,132 @@
|
|
1
|
+
import torch
|
2
|
+
from einops.layers.torch import Rearrange
|
3
|
+
from torch import Tensor, nn
|
4
|
+
|
5
|
+
import boltz.model.layers.initialize as init
|
6
|
+
|
7
|
+
|
8
|
+
class AttentionPairBias(nn.Module):
    """Gated multi-head attention whose logits receive an additive pair bias."""

    def __init__(
        self,
        c_s: int,
        c_z: int,
        num_heads: int,
        inf: float = 1e6,
        initial_norm: bool = True,
    ) -> None:
        """Create the projections used by the layer.

        Parameters
        ----------
        c_s : int
            Channel size of the sequence input. Must be divisible by
            ``num_heads``.
        c_z : int
            Channel size of the pairwise input.
        num_heads : int
            Number of attention heads.
        inf : float, optional
            Magnitude subtracted from the logits of masked keys,
            by default 1e6
        initial_norm : bool, optional
            Whether ``s`` is layer-normalised before the projections,
            by default True

        """
        super().__init__()

        assert c_s % num_heads == 0

        self.c_s = c_s
        self.num_heads = num_heads
        self.head_dim = c_s // num_heads
        self.inf = inf
        self.initial_norm = initial_norm

        # Sub-modules are created in the same order as before so parameter
        # registration and random initialisation are unchanged.
        if initial_norm:
            self.norm_s = nn.LayerNorm(c_s)

        self.proj_q = nn.Linear(c_s, c_s)
        self.proj_k = nn.Linear(c_s, c_s, bias=False)
        self.proj_v = nn.Linear(c_s, c_s, bias=False)
        self.proj_g = nn.Linear(c_s, c_s, bias=False)

        # Pair features -> one bias value per head, with the head axis moved
        # right after the batch axis.
        pair_bias_layers = [
            nn.LayerNorm(c_z),
            nn.Linear(c_z, num_heads, bias=False),
            Rearrange("b ... h -> b h ..."),
        ]
        self.proj_z = nn.Sequential(*pair_bias_layers)

        self.proj_o = nn.Linear(c_s, c_s, bias=False)
        init.final_init_(self.proj_o.weight)

    def forward(
        self,
        s: Tensor,
        z: Tensor,
        mask: Tensor,
        multiplicity: int = 1,
        to_keys=None,
        model_cache=None,
    ) -> Tensor:
        """Apply the attention layer.

        Parameters
        ----------
        s : torch.Tensor
            The input sequence tensor (B, S, D)
        z : torch.Tensor
            The input pairwise tensor (B, N, N, D). Ignored when a
            projected bias is already stored in ``model_cache``.
        mask : torch.Tensor
            The mask tensor (B, N); entries equal to 0 are suppressed.
        multiplicity : int, optional
            Number of times each projected bias is repeated along the batch
            axis, by default 1
        to_keys : callable, optional
            When given, it is applied to ``s`` to obtain the key/value input
            and to ``mask`` (with a trailing singleton axis) to obtain the
            key mask.
        model_cache : dict, optional
            When given, the projected pair bias is stored under ``"z"`` on
            the first call and reused on later calls.

        Returns
        -------
        torch.Tensor
            The output sequence tensor.

        """
        batch = s.shape[0]
        heads, width = self.num_heads, self.head_dim

        if self.initial_norm:
            s = self.norm_s(s)

        # Keys/values default to the (normalised) queries' input.
        k_in = s
        if to_keys is not None:
            k_in = to_keys(s)
            mask = to_keys(mask.unsqueeze(-1)).squeeze(-1)

        q = self.proj_q(s).view(batch, -1, heads, width)
        k = self.proj_k(k_in).view(batch, -1, heads, width)
        v = self.proj_v(k_in).view(batch, -1, heads, width)

        # Reuse the projected pair bias when the cache already holds one.
        if model_cache is not None and "z" in model_cache:
            z = model_cache["z"]
        else:
            z = self.proj_z(z)
            if model_cache is not None:
                model_cache["z"] = z
        z = z.repeat_interleave(multiplicity, 0)

        g = self.proj_g(s).sigmoid()

        # Logits, softmax and the weighted sum are evaluated in float32 with
        # CUDA autocast switched off.
        with torch.autocast("cuda", enabled=False):
            attn = torch.einsum("bihd,bjhd->bhij", q.float(), k.float())
            attn = attn / (self.head_dim**0.5) + z.float()
            # mask is expanded to (B, 1, 1, N) and broadcast over heads and
            # queries.
            attn = attn + (1 - mask[:, None, None].float()) * -self.inf
            attn = attn.softmax(dim=-1)

            o = torch.einsum("bhij,bjhd->bihd", attn, v.float()).to(v.dtype)

        # Merge the heads, apply the sigmoid gate, then project out.
        o = o.reshape(batch, -1, self.c_s)
        return self.proj_o(g * o)
|
@@ -0,0 +1,111 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from einops.layers.torch import Rearrange
|
5
|
+
from torch import Tensor, nn
|
6
|
+
|
7
|
+
import boltz.model.layers.initialize as init
|
8
|
+
|
9
|
+
|
10
|
+
class AttentionPairBias(nn.Module):
    """Attention pair bias layer."""

    def __init__(
        self,
        c_s: int,
        c_z: Optional[int] = None,
        num_heads: Optional[int] = None,
        inf: float = 1e6,
        compute_pair_bias: bool = True,
    ) -> None:
        """Initialize the attention pair bias layer.

        Parameters
        ----------
        c_s : int
            The input sequence dimension. Must be divisible by ``num_heads``.
        c_z : int, optional
            The input pairwise dimension. Required when
            ``compute_pair_bias`` is True.
        num_heads : int
            The number of heads. Required despite the ``None`` default.
        inf : float, optional
            The inf value, by default 1e6
        compute_pair_bias : bool, optional
            When True the layer projects ``z`` to one bias value per head
            (LayerNorm + Linear). When False ``z`` is only rearranged so its
            last axis becomes the head axis, by default True

        Raises
        ------
        TypeError
            If ``num_heads`` is missing, or ``c_z`` is missing while
            ``compute_pair_bias`` is True.

        """
        super().__init__()

        # The None defaults used to fail later with opaque errors
        # (``int % None`` / ``LayerNorm(None)``); fail early and clearly,
        # keeping the same exception type.
        if num_heads is None:
            raise TypeError("AttentionPairBias requires `num_heads`")
        if compute_pair_bias and c_z is None:
            raise TypeError(
                "AttentionPairBias requires `c_z` when compute_pair_bias=True"
            )

        assert c_s % num_heads == 0

        self.c_s = c_s
        self.num_heads = num_heads
        self.head_dim = c_s // num_heads
        self.inf = inf

        self.proj_q = nn.Linear(c_s, c_s)
        self.proj_k = nn.Linear(c_s, c_s, bias=False)
        self.proj_v = nn.Linear(c_s, c_s, bias=False)
        self.proj_g = nn.Linear(c_s, c_s, bias=False)

        self.compute_pair_bias = compute_pair_bias
        if compute_pair_bias:
            self.proj_z = nn.Sequential(
                nn.LayerNorm(c_z),
                nn.Linear(c_z, num_heads, bias=False),
                Rearrange("b ... h -> b h ..."),
            )
        else:
            # z is taken as an already projected bias: only move its last
            # axis next to the batch axis.
            self.proj_z = Rearrange("b ... h -> b h ...")

        self.proj_o = nn.Linear(c_s, c_s, bias=False)
        init.final_init_(self.proj_o.weight)

    def forward(
        self,
        s: Tensor,
        z: Tensor,
        mask: Tensor,
        k_in: Tensor,
        multiplicity: int = 1,
    ) -> Tensor:
        """Forward pass.

        Parameters
        ----------
        s : torch.Tensor
            The input sequence tensor (B, S, D); queries and the gate are
            projected from it.
        z : torch.Tensor
            The input pairwise tensor or bias (B, N, N, D)
        mask : torch.Tensor
            The mask tensor; entries equal to 0 are suppressed. It is
            indexed as ``mask[:, None, None]`` before being added to the
            (B, H, N, N) logits.
            NOTE(review): that indexing suggests a (B, N) key mask rather
            than the previously documented (B, N, N) - confirm with callers.
        k_in : torch.Tensor
            The tensor keys and values are projected from.
        multiplicity : int, optional
            Number of times each bias is repeated along the batch axis,
            by default 1

        Returns
        -------
        torch.Tensor
            The output sequence tensor.

        """
        B = s.shape[0]

        # Compute projections
        q = self.proj_q(s).view(B, -1, self.num_heads, self.head_dim)
        k = self.proj_k(k_in).view(B, -1, self.num_heads, self.head_dim)
        v = self.proj_v(k_in).view(B, -1, self.num_heads, self.head_dim)

        bias = self.proj_z(z)
        bias = bias.repeat_interleave(multiplicity, 0)

        g = self.proj_g(s).sigmoid()

        # Logits and softmax are evaluated in float32 with CUDA autocast off.
        with torch.autocast("cuda", enabled=False):
            # Compute attention weights
            attn = torch.einsum("bihd,bjhd->bhij", q.float(), k.float())
            attn = attn / (self.head_dim**0.5) + bias.float()
            attn = attn + (1 - mask[:, None, None].float()) * -self.inf
            attn = attn.softmax(dim=-1)

            # Compute output
            o = torch.einsum("bhij,bjhd->bihd", attn, v.float()).to(v.dtype)
        o = o.reshape(B, -1, self.c_s)
        o = self.proj_o(g * o)

        return o
|
@@ -0,0 +1,231 @@
|
|
1
|
+
import torch
|
2
|
+
from torch import nn
|
3
|
+
|
4
|
+
from boltz.data import const
|
5
|
+
|
6
|
+
|
7
|
+
def compute_collinear_mask(v1, v2):
    """Flag vector pairs that are neither near-collinear nor degenerate.

    Both inputs are normalised along ``dim=1``. A row is True when the
    absolute dot product of the two unit vectors is below 0.9063 and both
    original vectors have a norm above 1e-2.
    """
    eps = 1e-6
    length1 = torch.norm(v1, dim=1, keepdim=True)
    length2 = torch.norm(v2, dim=1, keepdim=True)

    unit1 = v1 / (length1 + eps)
    unit2 = v2 / (length2 + eps)
    abs_cosine = torch.abs(torch.sum(unit1 * unit2, dim=1))

    angle_ok = abs_cosine < 0.9063
    first_ok = length1.reshape(-1) > 1e-2
    second_ok = length2.reshape(-1) > 1e-2
    return angle_ok & first_ok & second_ok
|
16
|
+
|
17
|
+
|
18
|
+
def _nearest_neighbor_frames(chain_coords, chain_valid, atom_offset):
    """Build per-atom index triplets (nearest, self, second nearest) for one chain.

    ``chain_coords`` is indexed as (samples, atoms, 3) and ``chain_valid`` as
    (atoms,). Pairs involving an invalid atom get an infinite distance so they
    sort last. ``atom_offset`` is added to the chain-local indices.
    """
    diff = chain_coords[:, None, :, :] - chain_coords[:, :, None, :]
    dist_mat = (diff**2).sum(-1) ** 0.5
    invalid_pair = 1 - (chain_valid[None, :] * chain_valid[:, None]).to(
        torch.float32
    )
    invalid_pair[invalid_pair == 1] = torch.inf
    order = torch.sort(dist_mat + invalid_pair, dim=2).indices
    # Column 0 of the sort is the closest entry (presumably the atom itself,
    # at distance 0); it becomes the middle element of the frame.
    return (
        torch.cat([order[:, :, 1:2], order[:, :, 0:1], order[:, :, 2:3]], dim=2)
        + atom_offset
    )


def compute_frame_pred(
    pred_atom_coords,
    frames_idx_true,
    feats,
    multiplicity,
    resolved_mask=None,
    inference=False,
):
    """Recompute non-polymer frames from predicted coordinates.

    Starts from a copy of ``frames_idx_true`` repeated ``multiplicity`` times
    and, for every chain whose first token is NONPOLYMER and that has at least
    three atoms, overwrites the frame indices with nearest-neighbour triplets
    computed from ``pred_atom_coords``.

    Parameters
    ----------
    pred_atom_coords : torch.Tensor
        Predicted coordinates; must be 3-D and is reshaped to
        (B // multiplicity, multiplicity, -1, 3).
    frames_idx_true : torch.Tensor
        Reference frame indices, one index triplet per token.
    feats : dict
        Reads ``asym_id``, ``atom_to_token``, ``token_pad_mask``,
        ``atom_pad_mask``, ``mol_type`` and, when ``inference`` is False and
        no ``resolved_mask`` is given, ``atom_resolved_mask``.
    multiplicity : int
        Number of samples per input example.
    resolved_mask : torch.Tensor, optional
        Atom validity mask used when ``inference`` is False.
    inference : bool, optional
        When True the atom padding mask is used as the validity mask.

    Returns
    -------
    tuple
        ``(frames_idx_pred, mask)`` where ``mask`` is the non-collinearity
        mask of the gathered frames multiplied by the token padding mask.

    """
    with torch.amp.autocast("cuda", enabled=False):
        asym_id_token = feats["asym_id"]
        asym_id_atom = torch.bmm(
            feats["atom_to_token"].float(), asym_id_token.unsqueeze(-1).float()
        ).squeeze(-1)

    B, _, _ = pred_atom_coords.shape
    num_examples = B // multiplicity
    pred_atom_coords = pred_atom_coords.reshape(num_examples, multiplicity, -1, 3)
    frames_idx_pred = (
        frames_idx_true.clone()
        .repeat_interleave(multiplicity, 0)
        .reshape(num_examples, multiplicity, -1, 3)
    )

    # Iterate through the batch and modify the frames for nonpolymers
    for i, sample_coords in enumerate(pred_atom_coords):
        token_idx = 0
        atom_idx = 0
        for chain_id in torch.unique(asym_id_token[i]):
            mask_chain_token = (asym_id_token[i] == chain_id) * feats[
                "token_pad_mask"
            ][i]
            mask_chain_atom = (asym_id_atom[i] == chain_id) * feats[
                "atom_pad_mask"
            ][i]
            num_tokens = int(mask_chain_token.sum().item())
            num_atoms = int(mask_chain_atom.sum().item())
            if (
                feats["mol_type"][i, token_idx] != const.chain_type_ids["NONPOLYMER"]
                or num_atoms < 3
            ):
                token_idx += num_tokens
                atom_idx += num_atoms
                continue

            # Only the source of the validity mask differs between modes.
            if inference:
                valid_mask = feats["atom_pad_mask"]
            else:
                if resolved_mask is None:
                    resolved_mask = feats["atom_resolved_mask"]
                valid_mask = resolved_mask

            chain_atoms = mask_chain_atom.bool()
            frames = _nearest_neighbor_frames(
                sample_coords[:, chain_atoms],
                valid_mask[i][chain_atoms],
                atom_idx,
            )
            # Best effort: a failed assignment is reported and the original
            # frames are kept for this chain.
            try:
                frames_idx_pred[i, :, token_idx : token_idx + num_atoms, :] = frames
            except Exception as e:
                print(f"Failed to process {feats['pdb_id']} due to {e}")
            token_idx += num_tokens
            atom_idx += num_atoms

    device = frames_idx_pred.device
    example_index = torch.arange(0, num_examples, 1)[:, None, None, None].to(device)
    sample_index = torch.arange(0, multiplicity, 1)[None, :, None, None].to(device)
    frames_expanded = pred_atom_coords[
        example_index, sample_index, frames_idx_pred
    ].reshape(-1, 3, 3)

    # Compute masks for collinearity / overlap
    mask_collinear_pred = compute_collinear_mask(
        frames_expanded[:, 1] - frames_expanded[:, 0],
        frames_expanded[:, 1] - frames_expanded[:, 2],
    ).reshape(num_examples, multiplicity, -1)
    return frames_idx_pred, mask_collinear_pred * feats["token_pad_mask"][:, None, :]
|
113
|
+
|
114
|
+
|
115
|
+
def compute_aggregated_metric(logits, end=1.0):
    """Return the expected bin centre under the softmax of ``logits``.

    The last axis of ``logits`` indexes equally wide bins covering
    ``[0, end]``; the result drops that axis.
    """
    bin_count = logits.shape[-1]
    width = end / bin_count
    centers = torch.arange(
        start=0.5 * width, end=end, step=width, device=logits.device
    )

    probs = nn.functional.softmax(logits, dim=-1)
    # Shape the centres as (1, ..., 1, bins) so they line up with probs.
    leading_ones = (1,) * (probs.dim() - 1)
    weighted = probs * centers.view(*leading_ones, *centers.shape)
    return torch.sum(weighted, dim=-1)
|
128
|
+
|
129
|
+
|
130
|
+
def tm_function(d, Nres):
    """Evaluate ``1 / (1 + (d / d0)^2)``.

    ``d0 = 1.24 * (max(Nres, 19) - 15)^(1/3) - 1.8``.
    """
    clipped = torch.clamp(Nres, min=19)
    d0 = 1.24 * (clipped - 15) ** (1 / 3) - 1.8
    ratio = d / d0
    return 1 / (1 + ratio**2)
|
133
|
+
|
134
|
+
|
135
|
+
def _max_masked_mean(values, mask):
    """Average ``values`` over the masked last axis, then take the max over axis 1."""
    masked_mean = torch.sum(values * mask, dim=-1) / (torch.sum(mask, dim=-1) + 1e-5)
    return torch.max(masked_mean, dim=1).values


def compute_ptms(logits, x_preds, feats, multiplicity):
    """Compute pTM-style scores from binned pairwise-error logits.

    Parameters
    ----------
    logits : torch.Tensor
        Logits whose last axis indexes equally wide error bins covering
        0 to 32.
    x_preds : torch.Tensor
        Predicted coordinates, forwarded to ``compute_frame_pred``.
    feats : dict
        Reads ``frames_idx``, ``token_pad_mask``, ``asym_id`` and
        ``mol_type`` (plus whatever ``compute_frame_pred`` reads).
    multiplicity : int
        Number of samples per input example; per-example features are
        repeated this many times along the batch axis.

    Returns
    -------
    tuple
        ``(ptm, iptm, ligand_iptm, protein_iptm, chain_pair_iptm)``. The
        last item is a nested dict keyed by the pairs of unique ``asym_id``
        values.

    """
    # Only the frame-validity mask is needed here, not the frames themselves.
    _, mask_collinear_pred = compute_frame_pred(
        x_preds, feats["frames_idx"], feats, multiplicity, inference=True
    )

    # mask overlapping, collinear tokens and ions (invalid frames)
    mask_pad = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
    maski = mask_collinear_pred.reshape(-1, mask_collinear_pred.shape[-1])
    asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)

    # Valid frame on the first token axis, non-padded tokens on both axes.
    pair_mask_ptm = maski[:, :, None] * mask_pad[:, None, :] * mask_pad[:, :, None]
    # Same, restricted to token pairs from different chains.
    pair_mask_iptm = pair_mask_ptm * (asym_id[:, None, :] != asym_id[:, :, None])

    # Expected value of tm_function under the per-pair bin distribution.
    num_bins = logits.shape[-1]
    end = 32.0
    bin_width = end / num_bins
    pae_value = torch.arange(
        start=0.5 * bin_width, end=end, step=bin_width, device=logits.device
    ).unsqueeze(0)
    N_res = mask_pad.sum(dim=-1, keepdim=True)
    tm_value = tm_function(pae_value, N_res).unsqueeze(1).unsqueeze(2)
    probs = nn.functional.softmax(logits, dim=-1)
    tm_expected_value = torch.sum(probs * tm_value, dim=-1)  # shape (B, N, N)

    ptm = _max_masked_mean(tm_expected_value, pair_mask_ptm)
    iptm = _max_masked_mean(tm_expected_value, pair_mask_iptm)

    # compute ligand and protein iPTM
    token_type = feats["mol_type"].repeat_interleave(multiplicity, 0)
    is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
    is_protein_token = (token_type == const.chain_type_ids["PROTEIN"]).float()

    ligand_protein_pairs = (
        is_ligand_token[:, :, None] * is_protein_token[:, None, :]
    ) + (is_protein_token[:, :, None] * is_ligand_token[:, None, :])
    protein_protein_pairs = is_protein_token[:, :, None] * is_protein_token[:, None, :]

    ligand_iptm = _max_masked_mean(
        tm_expected_value, pair_mask_iptm * ligand_protein_pairs
    )
    protein_iptm = _max_masked_mean(
        tm_expected_value, pair_mask_iptm * protein_protein_pairs
    )

    # Compute pair chain ipTM
    chain_pair_iptm = {}
    asym_ids_list = torch.unique(asym_id).tolist()
    for idx1 in asym_ids_list:
        in_chain1 = asym_id[:, None, :] == idx1
        chain_iptm = {}
        for idx2 in asym_ids_list:
            mask_pair_chain = (
                pair_mask_ptm * in_chain1 * (asym_id[:, :, None] == idx2)
            )
            chain_iptm[idx2] = _max_masked_mean(tm_expected_value, mask_pair_chain)
        chain_pair_iptm[idx1] = chain_iptm

    return ptm, iptm, ligand_iptm, protein_iptm, chain_pair_iptm
|
@@ -0,0 +1,34 @@
|
|
1
|
+
import torch
|
2
|
+
from torch import Tensor
|
3
|
+
|
4
|
+
|
5
|
+
def get_dropout_mask(
    dropout: float,
    z: Tensor,
    training: bool,
    columnwise: bool = False,
) -> Tensor:
    """Sample a broadcastable, rescaled dropout mask for ``z``.

    Parameters
    ----------
    dropout : float
        The dropout rate; multiplied by ``training`` so it is zero outside
        training
    z : torch.Tensor
        The tensor the mask is built for; indexed with four axes
    training : bool
        Whether the model is in training mode
    columnwise : bool, optional
        If True the mask keeps axes 0 and 2 of ``z`` and has size 1 on axes
        1 and 3; otherwise it keeps axes 0 and 1 and has size 1 on axes 2
        and 3

    Returns
    -------
    torch.Tensor
        The mask: 0 where an entry is dropped and ``1 / (1 - rate)`` where
        it is kept

    """
    rate = dropout * training

    # Only the shape, dtype and device of the slice matter.
    if columnwise:
        template = z[:, 0:1, :, 0:1]
    else:
        template = z[:, :, 0:1, 0:1]

    keep = torch.rand_like(template) > rate
    return keep * 1.0 / (1.0 - rate)
|
@@ -0,0 +1,100 @@
|
|
1
|
+
"""Utility functions for initializing weights and biases."""
|
2
|
+
|
3
|
+
# Copyright 2021 AlQuraishi Laboratory
|
4
|
+
# Copyright 2021 DeepMind Technologies Limited
|
5
|
+
#
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
+
# you may not use this file except in compliance with the License.
|
8
|
+
# You may obtain a copy of the License at
|
9
|
+
#
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
+
#
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
+
# See the License for the specific language governing permissions and
|
16
|
+
# limitations under the License.
|
17
|
+
|
18
|
+
import math
|
19
|
+
|
20
|
+
import numpy as np
|
21
|
+
import torch
|
22
|
+
from scipy.stats import truncnorm
|
23
|
+
|
24
|
+
|
25
|
+
def _prod(nums):
    """Return the product of all items in ``nums`` (1 when it is empty)."""
    return math.prod(nums)
|
30
|
+
|
31
|
+
|
32
|
+
def _calculate_fan(linear_weight_shape, fan="fan_in"):
    """Pick the fan value of a ``(fan_out, fan_in)`` weight shape.

    ``fan`` selects ``"fan_in"``, ``"fan_out"`` or ``"fan_avg"`` (their
    mean); any other value raises ``ValueError``.
    """
    outputs, inputs = linear_weight_shape

    if fan == "fan_in":
        return inputs
    if fan == "fan_out":
        return outputs
    if fan == "fan_avg":
        return (inputs + outputs) / 2
    raise ValueError("Invalid fan option")
|
45
|
+
|
46
|
+
|
47
|
+
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    """Fill ``weights`` in place with truncated-normal samples.

    The target variance is ``scale / max(1, fan)``, where the fan is taken
    from the 2-D shape of ``weights``. Samples are drawn with SciPy from a
    normal truncated at two standard deviations on each side.

    Parameters
    ----------
    weights : torch.Tensor
        The 2-D tensor to overwrite.
    scale : float, optional
        Numerator of the target variance, by default 1.0
    fan : str, optional
        One of ``"fan_in"``, ``"fan_out"``, ``"fan_avg"``,
        by default ``"fan_in"``

    """
    shape = weights.shape
    variance = scale / max(1, _calculate_fan(shape, fan))

    lower, upper = -2, 2
    # Divide by the std of the unit truncated normal so the truncated
    # samples end up with the requested variance.
    std = math.sqrt(variance) / truncnorm.std(a=lower, b=upper, loc=0, scale=1)
    samples = truncnorm.rvs(a=lower, b=upper, loc=0, scale=std, size=_prod(shape))
    samples = np.reshape(samples, shape)

    with torch.no_grad():
        # SciPy returns float64. Convert to the parameter dtype before moving
        # to its device: some backends cannot hold float64 tensors, and the
        # copy would round to this dtype anyway.
        weights.copy_(
            torch.tensor(samples, dtype=weights.dtype, device=weights.device)
        )
|
59
|
+
|
60
|
+
|
61
|
+
def lecun_normal_init_(weights):
    """Truncated-normal init of ``weights`` with variance ``1 / fan_in``."""
    trunc_normal_init_(weights, scale=1.0, fan="fan_in")
|
63
|
+
|
64
|
+
|
65
|
+
def he_normal_init_(weights):
    """Truncated-normal init of ``weights`` with variance ``2 / fan_in``."""
    trunc_normal_init_(weights, scale=2.0, fan="fan_in")
|
67
|
+
|
68
|
+
|
69
|
+
def glorot_uniform_init_(weights):
    """Fill ``weights`` in place with Xavier/Glorot uniform samples (gain 1)."""
    xavier_uniform = torch.nn.init.xavier_uniform_
    xavier_uniform(weights, gain=1)
|
71
|
+
|
72
|
+
|
73
|
+
def final_init_(weights):
    """Zero ``weights`` in place without recording the write in autograd."""
    with torch.no_grad():
        weights.zero_()
|
76
|
+
|
77
|
+
|
78
|
+
def gating_init_(weights):
    """Zero the gating ``weights`` in place without recording it in autograd."""
    with torch.no_grad():
        weights.zero_()
|
81
|
+
|
82
|
+
|
83
|
+
def bias_init_zero_(bias):
    """Set every element of ``bias`` to zero, in place and outside autograd."""
    with torch.no_grad():
        bias.zero_()
|
86
|
+
|
87
|
+
|
88
|
+
def bias_init_one_(bias):
    """Set every element of ``bias`` to one, in place and outside autograd."""
    # torch.nn.init.ones_ performs the fill under no_grad; its return value
    # is dropped so this function still returns None.
    torch.nn.init.ones_(bias)
|
91
|
+
|
92
|
+
|
93
|
+
def normal_init_(weights):
    """Fill ``weights`` in place with Kaiming-normal samples (linear gain)."""
    kaiming_normal = torch.nn.init.kaiming_normal_
    kaiming_normal(weights, nonlinearity="linear")
|
95
|
+
|
96
|
+
|
97
|
+
def ipa_point_weights_init_(weights):
    """Fill ``weights`` in place with the softplus pre-image of 1."""
    # log(e - 1): softplus of this value equals 1.
    softplus_preimage_of_one = 0.541324854612918
    with torch.no_grad():
        weights.fill_(softplus_preimage_of_one)
|