boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
File without changes
File without changes
@@ -0,0 +1,132 @@
1
+ import torch
2
+ from einops.layers.torch import Rearrange
3
+ from torch import Tensor, nn
4
+
5
+ import boltz.model.layers.initialize as init
6
+
7
+
8
class AttentionPairBias(nn.Module):
    """Multi-head attention over a sequence with an additive per-head pair bias."""

    def __init__(
        self,
        c_s: int,
        c_z: int,
        num_heads: int,
        inf: float = 1e6,
        initial_norm: bool = True,
    ) -> None:
        """Initialize the attention pair bias layer.

        Parameters
        ----------
        c_s : int
            The input sequence dimension. Must be divisible by ``num_heads``.
        c_z : int
            The input pairwise dimension.
        num_heads : int
            The number of heads.
        inf : float, optional
            Magnitude of the additive penalty applied to masked-out keys,
            by default 1e6.
        initial_norm: bool, optional
            Whether to apply layer norm to the input, by default True

        Raises
        ------
        ValueError
            If ``c_s`` is not divisible by ``num_heads``.

        """
        super().__init__()

        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if c_s % num_heads != 0:
            msg = f"c_s ({c_s}) must be divisible by num_heads ({num_heads})."
            raise ValueError(msg)

        self.c_s = c_s
        self.num_heads = num_heads
        self.head_dim = c_s // num_heads
        self.inf = inf

        self.initial_norm = initial_norm
        if self.initial_norm:
            self.norm_s = nn.LayerNorm(c_s)

        self.proj_q = nn.Linear(c_s, c_s)
        self.proj_k = nn.Linear(c_s, c_s, bias=False)
        self.proj_v = nn.Linear(c_s, c_s, bias=False)
        self.proj_g = nn.Linear(c_s, c_s, bias=False)

        # Pair features -> one bias value per head, with heads moved to dim 1.
        self.proj_z = nn.Sequential(
            nn.LayerNorm(c_z),
            nn.Linear(c_z, num_heads, bias=False),
            Rearrange("b ... h -> b h ..."),
        )

        self.proj_o = nn.Linear(c_s, c_s, bias=False)
        # final_init_ zeroes the (bias-free) output projection, so the layer
        # outputs zeros at initialization.
        init.final_init_(self.proj_o.weight)

    def forward(
        self,
        s: Tensor,
        z: Tensor,
        mask: Tensor,
        multiplicity: int = 1,
        to_keys=None,
        model_cache=None,
    ) -> Tensor:
        """Forward pass.

        Parameters
        ----------
        s : torch.Tensor
            The input sequence tensor (B, S, D).
        z : torch.Tensor
            The input pairwise tensor. After projection its batch dimension
            is repeated ``multiplicity`` times to match ``s``.
        mask : torch.Tensor
            Key mask, broadcast as ``mask[:, None, None]`` against the
            (B, H, S, K) attention logits. Keys where the mask is 0 receive
            an additive ``-self.inf``.
        multiplicity : int, optional
            The diffusion batch size, by default 1
        to_keys : callable, optional
            If given, it is applied to the (normalized) ``s`` to produce the
            key/value input. It is also applied to ``mask[..., None]`` to
            produce the key mask.
        model_cache : dict, optional
            If given, the projected pair bias is stored under ``"z"`` on the
            first call and reused on later calls. In that case the ``z``
            argument is ignored.

        Returns
        -------
        torch.Tensor
            The output sequence tensor (B, S, c_s).

        """
        B = s.shape[0]

        # Layer norms
        if self.initial_norm:
            s = self.norm_s(s)

        if to_keys is not None:
            k_in = to_keys(s)
            mask = to_keys(mask.unsqueeze(-1)).squeeze(-1)
        else:
            k_in = s

        # Compute projections
        q = self.proj_q(s).view(B, -1, self.num_heads, self.head_dim)
        k = self.proj_k(k_in).view(B, -1, self.num_heads, self.head_dim)
        v = self.proj_v(k_in).view(B, -1, self.num_heads, self.head_dim)

        # Caching z projection during diffusion roll-out: the bias depends
        # only on z, so it is computed once and reused from the cache.
        if model_cache is None or "z" not in model_cache:
            z = self.proj_z(z)

            if model_cache is not None:
                model_cache["z"] = z
        else:
            z = model_cache["z"]
        z = z.repeat_interleave(multiplicity, 0)

        g = self.proj_g(s).sigmoid()

        # Attention math in float32 regardless of any surrounding CUDA autocast.
        with torch.autocast("cuda", enabled=False):
            # Compute attention weights
            attn = torch.einsum("bihd,bjhd->bhij", q.float(), k.float())
            attn = attn / (self.head_dim**0.5) + z.float()
            # Key mask (B, K) is broadcast to (B, 1, 1, K) over heads and queries.
            attn = attn + (1 - mask[:, None, None].float()) * -self.inf
            attn = attn.softmax(dim=-1)

            # Compute output
            o = torch.einsum("bhij,bjhd->bihd", attn, v.float()).to(v.dtype)
        o = o.reshape(B, -1, self.c_s)
        o = self.proj_o(g * o)

        return o
@@ -0,0 +1,111 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from einops.layers.torch import Rearrange
5
+ from torch import Tensor, nn
6
+
7
+ import boltz.model.layers.initialize as init
8
+
9
+
10
class AttentionPairBias(nn.Module):
    """Multi-head attention with an additive per-head bias from pair features."""

    def __init__(
        self,
        c_s: int,
        c_z: Optional[int] = None,
        num_heads: Optional[int] = None,
        inf: float = 1e6,
        compute_pair_bias: bool = True,
    ) -> None:
        """Initialize the attention pair bias layer.

        Parameters
        ----------
        c_s : int
            The input sequence dimension. Must be divisible by ``num_heads``.
        c_z : int, optional
            The input pairwise dimension. Required when ``compute_pair_bias``
            is True.
        num_heads : int
            The number of heads. Required despite the ``None`` default.
        inf : float, optional
            Magnitude of the additive penalty applied to masked-out keys,
            by default 1e6.
        compute_pair_bias : bool, optional
            If True, ``z`` is layer-normed and linearly projected to one bias
            value per head. If False, ``z`` is used as-is as a per-head bias
            with heads on the last dimension. By default True.

        Raises
        ------
        TypeError
            If ``num_heads`` is None, or ``c_z`` is None while
            ``compute_pair_bias`` is True.
        ValueError
            If ``c_s`` is not divisible by ``num_heads``.

        """
        super().__init__()

        if num_heads is None:
            msg = "num_heads must be provided."
            raise TypeError(msg)
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if c_s % num_heads != 0:
            msg = f"c_s ({c_s}) must be divisible by num_heads ({num_heads})."
            raise ValueError(msg)
        if compute_pair_bias and c_z is None:
            msg = "c_z must be provided when compute_pair_bias is True."
            raise TypeError(msg)

        self.c_s = c_s
        self.num_heads = num_heads
        self.head_dim = c_s // num_heads
        self.inf = inf

        self.proj_q = nn.Linear(c_s, c_s)
        self.proj_k = nn.Linear(c_s, c_s, bias=False)
        self.proj_v = nn.Linear(c_s, c_s, bias=False)
        self.proj_g = nn.Linear(c_s, c_s, bias=False)

        self.compute_pair_bias = compute_pair_bias
        if compute_pair_bias:
            self.proj_z = nn.Sequential(
                nn.LayerNorm(c_z),
                nn.Linear(c_z, num_heads, bias=False),
                Rearrange("b ... h -> b h ..."),
            )
        else:
            # z already holds per-head biases; only move heads to dim 1.
            self.proj_z = Rearrange("b ... h -> b h ...")

        self.proj_o = nn.Linear(c_s, c_s, bias=False)
        # final_init_ zeroes the (bias-free) output projection, so the layer
        # outputs zeros at initialization.
        init.final_init_(self.proj_o.weight)

    def forward(
        self,
        s: Tensor,
        z: Tensor,
        mask: Tensor,
        k_in: Tensor,
        multiplicity: int = 1,
    ) -> Tensor:
        """Forward pass.

        Parameters
        ----------
        s : torch.Tensor
            The input (query) sequence tensor (B, S, D)
        z : torch.Tensor
            The input pairwise tensor, or a precomputed per-head bias (heads
            last) when ``compute_pair_bias`` is False. After projection its
            batch dimension is repeated ``multiplicity`` times.
        mask : torch.Tensor
            Key mask of shape (B, K). It is broadcast as ``mask[:, None, None]``
            against the (B, H, S, K) attention logits. Keys where the mask is 0
            receive an additive ``-self.inf``.
        k_in : torch.Tensor
            The key/value input sequence (B, K, D).
        multiplicity : int, optional
            Number of times the bias is repeated along the batch dimension,
            by default 1.

        Returns
        -------
        torch.Tensor
            The output sequence tensor (B, S, c_s).

        """
        B = s.shape[0]

        # Compute projections
        q = self.proj_q(s).view(B, -1, self.num_heads, self.head_dim)
        k = self.proj_k(k_in).view(B, -1, self.num_heads, self.head_dim)
        v = self.proj_v(k_in).view(B, -1, self.num_heads, self.head_dim)

        bias = self.proj_z(z)
        bias = bias.repeat_interleave(multiplicity, 0)

        g = self.proj_g(s).sigmoid()

        # Attention math in float32 regardless of any surrounding CUDA autocast.
        with torch.autocast("cuda", enabled=False):
            # Compute attention weights
            attn = torch.einsum("bihd,bjhd->bhij", q.float(), k.float())
            attn = attn / (self.head_dim**0.5) + bias.float()
            attn = attn + (1 - mask[:, None, None].float()) * -self.inf
            attn = attn.softmax(dim=-1)

            # Compute output
            o = torch.einsum("bhij,bjhd->bihd", attn, v.float()).to(v.dtype)
        o = o.reshape(B, -1, self.c_s)
        o = self.proj_o(g * o)

        return o
@@ -0,0 +1,231 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from boltz.data import const
5
+
6
+
7
def compute_collinear_mask(v1, v2):
    """Flag rows whose two vectors span a usable (non-degenerate) frame.

    Norms are taken along dim 1, so the inputs are presumably (N, 3) batches
    of vectors compared row by row. A row is kept (True) only if:

    * the vectors are not near-(anti)parallel, i.e. ``|cos angle| < 0.9063``
      (about 25 degrees), and
    * neither vector is shorter than 1e-2 (no overlapping atoms).

    Returns a flat boolean tensor with one entry per row.
    """
    len1 = torch.norm(v1, dim=1, keepdim=True)
    len2 = torch.norm(v2, dim=1, keepdim=True)

    # The small epsilon avoids division by zero for zero-length vectors.
    unit1 = v1 / (len1 + 1e-6)
    unit2 = v2 / (len2 + 1e-6)
    cosine = torch.sum(unit1 * unit2, dim=1)

    not_collinear = torch.abs(cosine) < 0.9063
    long_enough_1 = len1.reshape(-1) > 1e-2
    long_enough_2 = len2.reshape(-1) > 1e-2
    return not_collinear & long_enough_1 & long_enough_2
16
+
17
+
18
def compute_frame_pred(
    pred_atom_coords,
    frames_idx_true,
    feats,
    multiplicity,
    resolved_mask=None,
    inference=False,
):
    """Compute per-token frame atom indices from predicted coordinates.

    The frames start from ``frames_idx_true``, repeated ``multiplicity`` times.
    For every non-polymer chain with at least three atoms, the frames are then
    rebuilt from the predicted coordinates:

    * For each atom, the chain's atoms are sorted by distance to it.
    * The frame becomes ``(1st-closest, 0th-closest, 2nd-closest)``. The
      0th-closest is normally the atom itself, at distance 0.
    * Masked-out atom pairs get an infinite distance so they sort last:
      padding at inference, unresolved atoms otherwise.

    Finally, frames whose three atoms are near-collinear or overlapping are
    flagged with ``compute_collinear_mask``.

    Parameters
    ----------
    pred_atom_coords : torch.Tensor
        Predicted atom coordinates (B, N_atoms, 3), where
        B = n_structures * multiplicity.
    frames_idx_true : torch.Tensor
        Reference frame atom indices, presumably (n_structures, T, 3).
    feats : dict
        Feature dict. Reads ``asym_id``, ``atom_to_token``,
        ``token_pad_mask``, ``atom_pad_mask`` and ``mol_type``. Also reads
        ``atom_resolved_mask`` when not in inference mode and
        ``resolved_mask`` is None. ``pdb_id``, if present, is used only in
        the error message.
    multiplicity : int
        Number of samples per structure.
    resolved_mask : torch.Tensor, optional
        Atom resolved mask used outside inference; defaults to
        ``feats["atom_resolved_mask"]``.
    inference : bool, optional
        If True, only padding (not resolution) is used to exclude atoms.

    Returns
    -------
    frames_idx_pred : torch.Tensor
        Frame atom indices, (n_structures, multiplicity, T, 3).
    mask : torch.Tensor
        Valid-frame mask, (n_structures, multiplicity, T), also multiplied by
        the token padding mask.

    """
    with torch.amp.autocast("cuda", enabled=False):
        asym_id_token = feats["asym_id"]
        # Propagate token chain ids to atoms through the atom->token matrix.
        asym_id_atom = torch.bmm(
            feats["atom_to_token"].float(), asym_id_token.unsqueeze(-1).float()
        ).squeeze(-1)

    B, _, _ = pred_atom_coords.shape
    num_structures = B // multiplicity
    pred_atom_coords = pred_atom_coords.reshape(num_structures, multiplicity, -1, 3)
    frames_idx_pred = (
        frames_idx_true.clone()
        .repeat_interleave(multiplicity, 0)
        .reshape(num_structures, multiplicity, -1, 3)
    )

    # Iterate through the batch and modify the frames for nonpolymers
    for i, pred_atom_coord in enumerate(pred_atom_coords):
        token_idx = 0
        atom_idx = 0
        for chain_id in torch.unique(asym_id_token[i]):
            mask_chain_token = (asym_id_token[i] == chain_id) * feats["token_pad_mask"][i]
            mask_chain_atom = (asym_id_atom[i] == chain_id) * feats["atom_pad_mask"][i]
            num_tokens = int(mask_chain_token.sum().item())
            num_atoms = int(mask_chain_atom.sum().item())
            if (
                feats["mol_type"][i, token_idx] != const.chain_type_ids["NONPOLYMER"]
                or num_atoms < 3
            ):
                token_idx += num_tokens
                atom_idx += num_atoms
                continue

            chain_atoms = mask_chain_atom.bool()
            chain_coords = pred_atom_coord[:, chain_atoms]
            # (multiplicity, n, n) pairwise distances within the chain.
            dist_mat = (
                (chain_coords[:, None, :, :] - chain_coords[:, :, None, :]) ** 2
            ).sum(-1) ** 0.5

            if inference:
                atom_mask = feats["atom_pad_mask"][i]
            else:
                if resolved_mask is None:
                    resolved_mask = feats["atom_resolved_mask"]
                atom_mask = resolved_mask[i]
            chain_mask = atom_mask[chain_atoms]
            # Pairs involving an excluded atom get +inf so they sort last.
            resolved_pair = 1 - (chain_mask[None, :] * chain_mask[:, None]).to(
                torch.float32
            )
            resolved_pair[resolved_pair == 1] = torch.inf
            indices = torch.sort(dist_mat + resolved_pair, dim=2).indices

            # Frame = (nearest neighbour, the atom itself, 2nd-nearest neighbour),
            # shifted from chain-local to global atom indices.
            frames = (
                torch.cat(
                    [
                        indices[:, :, 1:2],
                        indices[:, :, 0:1],
                        indices[:, :, 2:3],
                    ],
                    dim=2,
                )
                + atom_idx
            )
            # NOTE(review): writes num_atoms frames into token slots starting at
            # token_idx, i.e. assumes one token per atom for non-polymers.
            try:
                frames_idx_pred[i, :, token_idx : token_idx + num_atoms, :] = frames
            except Exception as e:
                # Best effort: keep the reference frames for this chain. Look up
                # pdb_id defensively so a missing key cannot mask the error.
                pdb_id = feats["pdb_id"] if "pdb_id" in feats else "<unknown>"
                print(f"Failed to process {pdb_id} due to {e}")
            token_idx += num_tokens
            atom_idx += num_atoms

    device = frames_idx_pred.device
    # Gather the coordinates of each frame's three atoms: (-1, 3, 3).
    frames_expanded = pred_atom_coords[
        torch.arange(0, num_structures, 1)[:, None, None, None].to(device),
        torch.arange(0, multiplicity, 1)[None, :, None, None].to(device),
        frames_idx_pred,
    ].reshape(-1, 3, 3)

    # Compute masks for collinearity / overlap
    mask_collinear_pred = compute_collinear_mask(
        frames_expanded[:, 1] - frames_expanded[:, 0],
        frames_expanded[:, 1] - frames_expanded[:, 2],
    ).reshape(num_structures, multiplicity, -1)
    return frames_idx_pred, mask_collinear_pred * feats["token_pad_mask"][:, None, :]
113
+
114
+
115
def compute_aggregated_metric(logits, end=1.0):
    """Return the expected value of a binned quantity predicted by ``logits``.

    The last dimension of ``logits`` indexes bins that split ``[0, end)``
    into equal-width intervals. Each bin is represented by its midpoint. The
    result is the softmax-weighted mean of those midpoints, with the last
    dimension reduced away.
    """
    n_bins = logits.shape[-1]
    width = end / n_bins
    midpoints = torch.arange(
        start=0.5 * width, end=end, step=width, device=logits.device
    )
    weights = nn.functional.softmax(logits, dim=-1)
    # The 1-D midpoints broadcast against the trailing bin dimension.
    return torch.sum(weights * midpoints, dim=-1)
128
+
129
+
130
def tm_function(d, Nres):
    """Map distances ``d`` to TM-score terms ``1 / (1 + (d / d0) ** 2)``.

    The distance scale is ``d0 = 1.24 * (max(Nres, 19) - 15) ** (1/3) - 1.8``.
    Clamping ``Nres`` at 19 keeps ``d0`` positive. ``Nres`` must be a tensor.
    """
    n_eff = torch.clamp(Nres, min=19)
    d0 = 1.24 * (n_eff - 15) ** (1 / 3) - 1.8
    ratio = d / d0
    return 1 / (1 + ratio**2)
133
+
134
+
135
def _max_masked_mean(scores, pair_mask):
    """Masked mean of ``scores`` over the last dim, then max over dim 1.

    A 1e-5 term in the denominator guards against empty masks. Returns one
    value per batch row.
    """
    row_means = torch.sum(scores * pair_mask, dim=-1) / (
        torch.sum(pair_mask, dim=-1) + 1e-5
    )
    return torch.max(row_means, dim=1).values


def compute_ptms(logits, x_preds, feats, multiplicity):
    """Compute pTM, ipTM and interface variants from aligned-error logits.

    Parameters
    ----------
    logits : torch.Tensor
        Binned aligned-error logits with bins on the last dimension,
        spanning [0, 32). Assumed (B, N, N, num_bins) with
        B = n_structures * multiplicity (TODO confirm).
    x_preds : torch.Tensor
        Predicted atom coordinates passed to ``compute_frame_pred``.
    feats : dict
        Feature dict. Reads ``frames_idx``, ``token_pad_mask``, ``asym_id``
        and ``mol_type``, plus whatever ``compute_frame_pred`` reads.
    multiplicity : int
        Number of samples per structure.

    Returns
    -------
    tuple
        ``(ptm, iptm, ligand_iptm, protein_iptm, chain_pair_iptm)``.

        * The first four are per-batch-row tensors.
        * ``chain_pair_iptm[idx1][idx2]`` restricts the aligned rows (dim 1,
          maximised over) to chain ``idx2`` and the scored columns (last dim,
          averaged) to chain ``idx1``.

    """
    # Tokens whose predicted frames are collinear/overlapping (or are
    # padding) must not be used as alignment rows.
    _, frame_valid = compute_frame_pred(
        x_preds, feats["frames_idx"], feats, multiplicity, inference=True
    )
    row_valid = frame_valid.reshape(-1, frame_valid.shape[-1])[:, :, None]

    pad = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
    pad_col = pad[:, None, :]
    pad_row = pad[:, :, None]
    asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
    other_chain = asym_id[:, None, :] != asym_id[:, :, None]

    pair_mask_ptm = row_valid * pad_col * pad_row
    pair_mask_iptm = row_valid * other_chain * pad_col * pad_row

    # Expected TM term per token pair, using bin-midpoint errors.
    num_bins = logits.shape[-1]
    bin_width = 32.0 / num_bins
    pae_value = torch.arange(
        start=0.5 * bin_width, end=32.0, step=bin_width, device=logits.device
    ).unsqueeze(0)
    n_res = pad.sum(dim=-1, keepdim=True)
    tm_value = tm_function(pae_value, n_res).unsqueeze(1).unsqueeze(2)
    probs = nn.functional.softmax(logits, dim=-1)
    tm_expected_value = torch.sum(probs * tm_value, dim=-1)  # shape (B, N, N)

    ptm = _max_masked_mean(tm_expected_value, pair_mask_ptm)
    iptm = _max_masked_mean(tm_expected_value, pair_mask_iptm)

    # Ligand-protein and protein-protein interface scores.
    token_type = feats["mol_type"].repeat_interleave(multiplicity, 0)
    is_ligand = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
    is_protein = (token_type == const.chain_type_ids["PROTEIN"]).float()

    ligand_protein_pairs = (is_ligand[:, :, None] * is_protein[:, None, :]) + (
        is_protein[:, :, None] * is_ligand[:, None, :]
    )
    ligand_iptm_mask = row_valid * other_chain * pad_col * pad_row * ligand_protein_pairs
    protein_iptm_mask = (
        row_valid
        * other_chain
        * pad_col
        * pad_row
        * (is_protein[:, :, None] * is_protein[:, None, :])
    )

    ligand_iptm = _max_masked_mean(tm_expected_value, ligand_iptm_mask)
    protein_iptm = _max_masked_mean(tm_expected_value, protein_iptm_mask)

    # Per chain-pair ipTM.
    chain_ids = torch.unique(asym_id).tolist()
    chain_pair_iptm = {
        idx1: {
            idx2: _max_masked_mean(
                tm_expected_value,
                row_valid
                * (asym_id[:, None, :] == idx1)
                * (asym_id[:, :, None] == idx2)
                * pad_col
                * pad_row,
            )
            for idx2 in chain_ids
        }
        for idx1 in chain_ids
    }

    return ptm, iptm, ligand_iptm, protein_iptm, chain_pair_iptm
@@ -0,0 +1,34 @@
1
+ import torch
2
+ from torch import Tensor
3
+
4
+
5
def get_dropout_mask(
    dropout: float,
    z: Tensor,
    training: bool,
    columnwise: bool = False,
) -> Tensor:
    """Get a rescaled, broadcastable dropout mask for a 4-D tensor ``z``.

    Parameters
    ----------
    dropout : float
        The dropout rate. Ignored (treated as 0) when not training.
    z : torch.Tensor
        The tensor to apply dropout to; indexed as 4-D.
    training : bool
        Whether the model is in training mode
    columnwise : bool, optional
        If True, one value is drawn per (batch, dim-2) position and shared
        along dim 1. Otherwise one value is drawn per (batch, dim-1)
        position and shared along dim 2.

    Returns
    -------
    torch.Tensor
        The dropout mask, shaped ``(B, 1, z.shape[2], 1)`` if ``columnwise``
        else ``(B, z.shape[1], 1, 1)``. Kept entries equal
        ``1 / (1 - dropout)`` and dropped entries equal 0. When
        ``dropout >= 1`` the mask is all zeros instead of NaN.

    """
    dropout = dropout * training
    v = z[:, 0:1, :, 0:1] if columnwise else z[:, :, 0:1, 0:1]
    keep = torch.rand_like(v) > dropout
    if dropout >= 1.0:
        # Everything is dropped; rescaling would compute 0 / 0 = NaN.
        return keep * 0.0
    return keep * 1.0 / (1.0 - dropout)
@@ -0,0 +1,100 @@
1
+ """Utility functions for initializing weights and biases."""
2
+
3
+ # Copyright 2021 AlQuraishi Laboratory
4
+ # Copyright 2021 DeepMind Technologies Limited
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import math
19
+
20
+ import numpy as np
21
+ import torch
22
+ from scipy.stats import truncnorm
23
+
24
+
25
+ def _prod(nums):
26
+ out = 1
27
+ for n in nums:
28
+ out = out * n
29
+ return out
30
+
31
+
32
+ def _calculate_fan(linear_weight_shape, fan="fan_in"):
33
+ fan_out, fan_in = linear_weight_shape
34
+
35
+ if fan == "fan_in":
36
+ f = fan_in
37
+ elif fan == "fan_out":
38
+ f = fan_out
39
+ elif fan == "fan_avg":
40
+ f = (fan_in + fan_out) / 2
41
+ else:
42
+ raise ValueError("Invalid fan option")
43
+
44
+ return f
45
+
46
+
47
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    """Fill ``weights`` in place from a truncated normal with variance scale / fan.

    Samples come from a normal truncated to +/-2 of its own std (scipy
    ``truncnorm`` with ``a=-2``, ``b=2``). The std is divided by the std of
    the unit truncated normal, so the truncated samples have variance
    ``scale / max(1, fan)``. Sampling uses scipy/numpy's RNG, not torch's.
    """
    shape = weights.shape
    variance = scale / max(1, _calculate_fan(shape, fan))
    lower, upper = -2, 2
    unit_std = truncnorm.std(a=lower, b=upper, loc=0, scale=1)
    std = math.sqrt(variance) / unit_std
    flat = truncnorm.rvs(a=lower, b=upper, loc=0, scale=std, size=_prod(shape))
    samples = np.reshape(flat, shape)
    with torch.no_grad():
        weights.copy_(torch.tensor(samples, device=weights.device))
59
+
60
+
61
def lecun_normal_init_(weights):
    """LeCun init: truncated normal with variance 1 / max(1, fan_in), in place."""
    trunc_normal_init_(weights, 1.0, "fan_in")
63
+
64
+
65
def he_normal_init_(weights):
    """He init: truncated normal with variance 2 / max(1, fan_in), in place."""
    trunc_normal_init_(weights, 2.0, "fan_in")
67
+
68
+
69
def glorot_uniform_init_(weights):
    """Glorot/Xavier uniform init with unit gain, applied in place."""
    torch.nn.init.xavier_uniform_(weights, gain=1.0)
71
+
72
+
73
def final_init_(weights):
    """Zero ``weights`` in place without recording the op in autograd."""
    torch.nn.init.zeros_(weights)
76
+
77
+
78
def gating_init_(weights):
    """Set gating weights to zero in place (outside autograd)."""
    torch.nn.init.constant_(weights, 0.0)
81
+
82
+
83
def bias_init_zero_(bias):
    """Set every bias entry to 0.0 in place (outside autograd)."""
    torch.nn.init.zeros_(bias)
86
+
87
+
88
def bias_init_one_(bias):
    """Set every bias entry to 1.0 in place (outside autograd)."""
    torch.nn.init.ones_(bias)
91
+
92
+
93
def normal_init_(weights):
    """Kaiming-normal init (fan_in mode, linear gain), applied in place."""
    torch.nn.init.kaiming_normal_(weights, a=0, mode="fan_in", nonlinearity="linear")
95
+
96
+
97
def ipa_point_weights_init_(weights):
    """Fill ``weights`` with softplus^-1(1) = log(e - 1), so softplus(w) == 1."""
    inverse_softplus_of_one = 0.541324854612918
    torch.nn.init.constant_(weights, inverse_softplus_of_one)