boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/model/modules/affinity.py
@@ -0,0 +1,223 @@
+import torch
+from torch import nn
+
+import boltz.model.layers.initialize as init
+from boltz.model.layers.pairformer import PairformerNoSeqModule
+from boltz.model.modules.encodersv2 import PairwiseConditioning
+from boltz.model.modules.transformersv2 import DiffusionTransformer
+from boltz.model.modules.utils import LinearNoBias
+
+
+class GaussianSmearing(torch.nn.Module):
+    """Gaussian smearing."""
+
+    def __init__(
+        self,
+        start: float = 0.0,
+        stop: float = 5.0,
+        num_gaussians: int = 50,
+    ) -> None:
+        super().__init__()
+        offset = torch.linspace(start, stop, num_gaussians)
+        self.num_gaussians = num_gaussians
+        self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
+        self.register_buffer("offset", offset)
+
+    def forward(self, dist):
+        shape = dist.shape
+        dist = dist.view(-1, 1) - self.offset.view(1, -1)
+        return torch.exp(self.coeff * torch.pow(dist, 2)).reshape(
+            *shape, self.num_gaussians
+        )
+
+
+class AffinityModule(nn.Module):
+    """Algorithm 31"""
+
+    def __init__(
+        self,
+        token_s,
+        token_z,
+        pairformer_args: dict,
+        transformer_args: dict,
+        num_dist_bins=64,
+        max_dist=22,
+        use_cross_transformer: bool = False,
+        groups: dict = {},
+    ):
+        super().__init__()
+        boundaries = torch.linspace(2, max_dist, num_dist_bins - 1)
+        self.register_buffer("boundaries", boundaries)
+        self.dist_bin_pairwise_embed = nn.Embedding(num_dist_bins, token_z)
+        init.gating_init_(self.dist_bin_pairwise_embed.weight)
+
+        self.s_to_z_prod_in1 = LinearNoBias(token_s, token_z)
+        self.s_to_z_prod_in2 = LinearNoBias(token_s, token_z)
+
+        self.z_norm = nn.LayerNorm(token_z)
+        self.z_linear = LinearNoBias(token_z, token_z)
+
+        self.pairwise_conditioner = PairwiseConditioning(
+            token_z=token_z,
+            dim_token_rel_pos_feats=token_z,
+            num_transitions=2,
+        )
+
+        self.pairformer_stack = PairformerNoSeqModule(token_z, **pairformer_args)
+        self.affinity_heads = AffinityHeadsTransformer(
+            token_z,
+            transformer_args["token_s"],
+            transformer_args["num_blocks"],
+            transformer_args["num_heads"],
+            transformer_args["activation_checkpointing"],
+            False,
+            groups=groups,
+        )
+
+    def forward(
+        self,
+        s_inputs,
+        z,
+        x_pred,
+        feats,
+        multiplicity=1,
+        use_kernels=False,
+    ):
+        z = self.z_linear(self.z_norm(z))
+        z = z.repeat_interleave(multiplicity, 0)
+
+        z = (
+            z
+            + self.s_to_z_prod_in1(s_inputs)[:, :, None, :]
+            + self.s_to_z_prod_in2(s_inputs)[:, None, :, :]
+        )
+
+        token_to_rep_atom = feats["token_to_rep_atom"]
+        token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0)
+        if len(x_pred.shape) == 4:
+            B, mult, N, _ = x_pred.shape
+            x_pred = x_pred.reshape(B * mult, N, -1)
+        else:
+            BM, N, _ = x_pred.shape
+            B = BM // multiplicity
+            mult = multiplicity
+        x_pred_repr = torch.bmm(token_to_rep_atom.float(), x_pred)
+        d = torch.cdist(x_pred_repr, x_pred_repr)
+
+        distogram = (d.unsqueeze(-1) > self.boundaries).sum(dim=-1).long()
+        distogram = self.dist_bin_pairwise_embed(distogram)
+
+        z = z + self.pairwise_conditioner(z_trunk=z, token_rel_pos_feats=distogram)
+
+        pad_token_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
+        rec_mask = (feats["mol_type"] == 0).repeat_interleave(multiplicity, 0)
+        rec_mask = rec_mask * pad_token_mask
+        lig_mask = (
+            feats["affinity_token_mask"]
+            .repeat_interleave(multiplicity, 0)
+            .to(torch.bool)
+        )
+        lig_mask = lig_mask * pad_token_mask
+        cross_pair_mask = (
+            lig_mask[:, :, None] * rec_mask[:, None, :]
+            + rec_mask[:, :, None] * lig_mask[:, None, :]
+            + lig_mask[:, :, None] * lig_mask[:, None, :]
+        )
+        z = self.pairformer_stack(
+            z,
+            pair_mask=cross_pair_mask,
+            use_kernels=use_kernels,
+        )
+
+        out_dict = {}
+
+        # affinity heads
+        out_dict.update(
+            self.affinity_heads(z=z, feats=feats, multiplicity=multiplicity)
+        )
+
+        return out_dict
+
+
+class AffinityHeadsTransformer(nn.Module):
+    def __init__(
+        self,
+        token_z,
+        input_token_s,
+        num_blocks,
+        num_heads,
+        activation_checkpointing,
+        use_cross_transformer,
+        groups={},
+    ):
+        super().__init__()
+        self.affinity_out_mlp = nn.Sequential(
+            nn.Linear(token_z, token_z),
+            nn.ReLU(),
+            nn.Linear(token_z, input_token_s),
+            nn.ReLU(),
+        )
+
+        self.to_affinity_pred_value = nn.Sequential(
+            nn.Linear(input_token_s, input_token_s),
+            nn.ReLU(),
+            nn.Linear(input_token_s, input_token_s),
+            nn.ReLU(),
+            nn.Linear(input_token_s, 1),
+        )
+
+        self.to_affinity_pred_score = nn.Sequential(
+            nn.Linear(input_token_s, input_token_s),
+            nn.ReLU(),
+            nn.Linear(input_token_s, input_token_s),
+            nn.ReLU(),
+            nn.Linear(input_token_s, 1),
+        )
+        self.to_affinity_logits_binary = nn.Linear(1, 1)
+
+    def forward(
+        self,
+        z,
+        feats,
+        multiplicity=1,
+    ):
+        pad_token_mask = (
+            feats["token_pad_mask"].repeat_interleave(multiplicity, 0).unsqueeze(-1)
+        )
+        rec_mask = (
+            (feats["mol_type"] == 0).repeat_interleave(multiplicity, 0).unsqueeze(-1)
+        )
+        rec_mask = rec_mask * pad_token_mask
+        lig_mask = (
+            feats["affinity_token_mask"]
+            .repeat_interleave(multiplicity, 0)
+            .to(torch.bool)
+            .unsqueeze(-1)
+        ) * pad_token_mask
+        cross_pair_mask = (
+            lig_mask[:, :, None] * rec_mask[:, None, :]
+            + rec_mask[:, :, None] * lig_mask[:, None, :]
+            + (lig_mask[:, :, None] * lig_mask[:, None, :])
+        ) * (
+            1
+            - torch.eye(lig_mask.shape[1], device=lig_mask.device)
+            .unsqueeze(-1)
+            .unsqueeze(0)
+        )
+
+        g = torch.sum(z * cross_pair_mask, dim=(1, 2)) / (
+            torch.sum(cross_pair_mask, dim=(1, 2)) + 1e-7
+        )
+
+        g = self.affinity_out_mlp(g)
+
+        affinity_pred_value = self.to_affinity_pred_value(g).reshape(-1, 1)
+        affinity_pred_score = self.to_affinity_pred_score(g).reshape(-1, 1)
+        affinity_logits_binary = self.to_affinity_logits_binary(
+            affinity_pred_score
+        ).reshape(-1, 1)
+        out_dict = {
+            "affinity_pred_value": affinity_pred_value,
+            "affinity_logits_binary": affinity_logits_binary,
+        }
+        return out_dict
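
Both this module and the confidence module below turn predicted coordinates into a distogram with the same idiom: (d.unsqueeze(-1) > boundaries).sum(dim=-1), which buckets each pairwise distance into one of num_dist_bins indices without building a one-hot tensor. A minimal standalone sketch of that step, reusing only the constructor defaults from this file (num_dist_bins=64, boundaries spanning 2 to 22); the toy coordinates and the embedding width 16 are illustrative, not values from the package:

import torch

num_dist_bins, max_dist = 64, 22
boundaries = torch.linspace(2, max_dist, num_dist_bins - 1)

# Toy representative-atom coordinates: batch of 1, 4 tokens, xyz.
x = torch.randn(1, 4, 3) * 10
d = torch.cdist(x, x)

# Count how many boundaries each distance exceeds -> bin index in [0, 63].
bins = (d.unsqueeze(-1) > boundaries).sum(dim=-1).long()

# Equivalent to searching the sorted boundaries directly.
assert torch.equal(bins, torch.bucketize(d, boundaries))

# The module then embeds the indices into the pair representation.
embed = torch.nn.Embedding(num_dist_bins, 16)
pair_feats = embed(bins)  # shape (1, 4, 4, 16)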
boltz/model/modules/confidence.py
@@ -0,0 +1,481 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+import boltz.model.layers.initialize as init
+from boltz.data import const
+from boltz.model.modules.confidence_utils import (
+    compute_aggregated_metric,
+    compute_ptms,
+)
+from boltz.model.modules.encoders import RelativePositionEncoder
+from boltz.model.modules.trunk import (
+    InputEmbedder,
+    MSAModule,
+    PairformerModule,
+)
+from boltz.model.modules.utils import LinearNoBias
+
+
+class ConfidenceModule(nn.Module):
+    """Confidence module."""
+
+    def __init__(
+        self,
+        token_s,
+        token_z,
+        pairformer_args: dict,
+        num_dist_bins=64,
+        max_dist=22,
+        add_s_to_z_prod=False,
+        add_s_input_to_s=False,
+        use_s_diffusion=False,
+        add_z_input_to_z=False,
+        confidence_args: dict = None,
+        compute_pae: bool = False,
+        imitate_trunk=False,
+        full_embedder_args: dict = None,
+        msa_args: dict = None,
+        compile_pairformer=False,
+    ):
+        """Initialize the confidence module.
+
+        Parameters
+        ----------
+        token_s : int
+            The single representation dimension.
+        token_z : int
+            The pair representation dimension.
+        pairformer_args : dict
+            The pairformer arguments.
+        num_dist_bins : int, optional
+            The number of distance bins, by default 64.
+        max_dist : int, optional
+            The maximum distance, by default 22.
+        add_s_to_z_prod : bool, optional
+            Whether to add the s product to z, by default False.
+        add_s_input_to_s : bool, optional
+            Whether to add the s input to s, by default False.
+        use_s_diffusion : bool, optional
+            Whether to use s diffusion, by default False.
+        add_z_input_to_z : bool, optional
+            Whether to add the z input to z, by default False.
+        confidence_args : dict, optional
+            The confidence arguments, by default None.
+        compute_pae : bool, optional
+            Whether to compute PAE, by default False.
+        imitate_trunk : bool, optional
+            Whether to imitate the trunk, by default False.
+        full_embedder_args : dict, optional
+            The full embedder arguments, by default None.
+        msa_args : dict, optional
+            The MSA arguments, by default None.
+        compile_pairformer : bool, optional
+            Whether to compile the pairformer, by default False.
+
+        """
+        super().__init__()
+        self.max_num_atoms_per_token = 23
+        self.no_update_s = pairformer_args.get("no_update_s", False)
+        boundaries = torch.linspace(2, max_dist, num_dist_bins - 1)
+        self.register_buffer("boundaries", boundaries)
+        self.dist_bin_pairwise_embed = nn.Embedding(num_dist_bins, token_z)
+        init.gating_init_(self.dist_bin_pairwise_embed.weight)
+        s_input_dim = (
+            token_s + 2 * const.num_tokens + 1 + len(const.pocket_contact_info)
+        )
+
+        self.use_s_diffusion = use_s_diffusion
+        if use_s_diffusion:
+            self.s_diffusion_norm = nn.LayerNorm(2 * token_s)
+            self.s_diffusion_to_s = LinearNoBias(2 * token_s, token_s)
+            init.gating_init_(self.s_diffusion_to_s.weight)
+
+        self.s_to_z = LinearNoBias(s_input_dim, token_z)
+        self.s_to_z_transpose = LinearNoBias(s_input_dim, token_z)
+        init.gating_init_(self.s_to_z.weight)
+        init.gating_init_(self.s_to_z_transpose.weight)
+
+        self.add_s_to_z_prod = add_s_to_z_prod
+        if add_s_to_z_prod:
+            self.s_to_z_prod_in1 = LinearNoBias(s_input_dim, token_z)
+            self.s_to_z_prod_in2 = LinearNoBias(s_input_dim, token_z)
+            self.s_to_z_prod_out = LinearNoBias(token_z, token_z)
+            init.gating_init_(self.s_to_z_prod_out.weight)
+
+        self.imitate_trunk = imitate_trunk
+        if self.imitate_trunk:
+            s_input_dim = (
+                token_s + 2 * const.num_tokens + 1 + len(const.pocket_contact_info)
+            )
+            self.s_init = nn.Linear(s_input_dim, token_s, bias=False)
+            self.z_init_1 = nn.Linear(s_input_dim, token_z, bias=False)
+            self.z_init_2 = nn.Linear(s_input_dim, token_z, bias=False)
+
+            # Input embeddings
+            self.input_embedder = InputEmbedder(**full_embedder_args)
+            self.rel_pos = RelativePositionEncoder(token_z)
+            self.token_bonds = nn.Linear(1, token_z, bias=False)
+
+            # Normalization layers
+            self.s_norm = nn.LayerNorm(token_s)
+            self.z_norm = nn.LayerNorm(token_z)
+
+            # Recycling projections
+            self.s_recycle = nn.Linear(token_s, token_s, bias=False)
+            self.z_recycle = nn.Linear(token_z, token_z, bias=False)
+            init.gating_init_(self.s_recycle.weight)
+            init.gating_init_(self.z_recycle.weight)
+
+            # Pairwise stack
+            self.msa_module = MSAModule(
+                token_z=token_z,
+                s_input_dim=s_input_dim,
+                **msa_args,
+            )
+            self.pairformer_module = PairformerModule(
+                token_s,
+                token_z,
+                **pairformer_args,
+            )
+            if compile_pairformer:
+                # Big models hit the default cache limit (8)
+                self.is_pairformer_compiled = True
+                torch._dynamo.config.cache_size_limit = 512
+                torch._dynamo.config.accumulated_cache_size_limit = 512
+                self.pairformer_module = torch.compile(
+                    self.pairformer_module,
+                    dynamic=False,
+                    fullgraph=False,
+                )
+
+            self.final_s_norm = nn.LayerNorm(token_s)
+            self.final_z_norm = nn.LayerNorm(token_z)
+        else:
+            self.s_inputs_norm = nn.LayerNorm(s_input_dim)
+            if not self.no_update_s:
+                self.s_norm = nn.LayerNorm(token_s)
+            self.z_norm = nn.LayerNorm(token_z)
+
+            self.add_s_input_to_s = add_s_input_to_s
+            if add_s_input_to_s:
+                self.s_input_to_s = LinearNoBias(s_input_dim, token_s)
+                init.gating_init_(self.s_input_to_s.weight)
+
+            self.add_z_input_to_z = add_z_input_to_z
+            if add_z_input_to_z:
+                self.rel_pos = RelativePositionEncoder(token_z)
+                self.token_bonds = nn.Linear(1, token_z, bias=False)
+
+            self.pairformer_stack = PairformerModule(
+                token_s,
+                token_z,
+                **pairformer_args,
+            )
+
+        self.confidence_heads = ConfidenceHeads(
+            token_s,
+            token_z,
+            compute_pae=compute_pae,
+            **confidence_args,
+        )
+
+    def forward(
+        self,
+        s_inputs,
+        s,
+        z,
+        x_pred,
+        feats,
+        pred_distogram_logits,
+        multiplicity=1,
+        s_diffusion=None,
+        run_sequentially=False,
+        use_kernels: bool = False,
+    ):
+        if run_sequentially and multiplicity > 1:
+            assert z.shape[0] == 1, "Not supported with batch size > 1"
+            out_dicts = []
+            for sample_idx in range(multiplicity):
+                out_dicts.append(  # noqa: PERF401
+                    self.forward(
+                        s_inputs,
+                        s,
+                        z,
+                        x_pred[sample_idx : sample_idx + 1],
+                        feats,
+                        pred_distogram_logits,
+                        multiplicity=1,
+                        s_diffusion=s_diffusion[sample_idx : sample_idx + 1]
+                        if s_diffusion is not None
+                        else None,
+                        run_sequentially=False,
+                        use_kernels=use_kernels,
+                    )
+                )
+
+            out_dict = {}
+            for key in out_dicts[0]:
+                if key != "pair_chains_iptm":
+                    out_dict[key] = torch.cat([out[key] for out in out_dicts], dim=0)
+                else:
+                    pair_chains_iptm = {}
+                    for chain_idx1 in out_dicts[0][key].keys():
+                        chains_iptm = {}
+                        for chain_idx2 in out_dicts[0][key][chain_idx1].keys():
+                            chains_iptm[chain_idx2] = torch.cat(
+                                [out[key][chain_idx1][chain_idx2] for out in out_dicts],
+                                dim=0,
+                            )
+                        pair_chains_iptm[chain_idx1] = chains_iptm
+                    out_dict[key] = pair_chains_iptm
+            return out_dict
+        if self.imitate_trunk:
+            s_inputs = self.input_embedder(feats)
+
+            # Initialize the sequence and pairwise embeddings
+            s_init = self.s_init(s_inputs)
+            z_init = (
+                self.z_init_1(s_inputs)[:, :, None]
+                + self.z_init_2(s_inputs)[:, None, :]
+            )
+            relative_position_encoding = self.rel_pos(feats)
+            z_init = z_init + relative_position_encoding
+            z_init = z_init + self.token_bonds(feats["token_bonds"].float())
+
+            # Apply recycling
+            s = s_init + self.s_recycle(self.s_norm(s))
+            z = z_init + self.z_recycle(self.z_norm(z))
+
+        else:
+            s_inputs = self.s_inputs_norm(s_inputs).repeat_interleave(multiplicity, 0)
+            if not self.no_update_s:
+                s = self.s_norm(s)
+
+            if self.add_s_input_to_s:
+                s = s + self.s_input_to_s(s_inputs)
+
+            z = self.z_norm(z)
+
+            if self.add_z_input_to_z:
+                relative_position_encoding = self.rel_pos(feats)
+                z = z + relative_position_encoding
+                z = z + self.token_bonds(feats["token_bonds"].float())
+
+        s = s.repeat_interleave(multiplicity, 0)
+
+        if self.use_s_diffusion:
+            assert s_diffusion is not None
+            s_diffusion = self.s_diffusion_norm(s_diffusion)
+            s = s + self.s_diffusion_to_s(s_diffusion)
+
+        z = z.repeat_interleave(multiplicity, 0)
+        z = (
+            z
+            + self.s_to_z(s_inputs)[:, :, None, :]
+            + self.s_to_z_transpose(s_inputs)[:, None, :, :]
+        )
+
+        if self.add_s_to_z_prod:
+            z = z + self.s_to_z_prod_out(
+                self.s_to_z_prod_in1(s_inputs)[:, :, None, :]
+                * self.s_to_z_prod_in2(s_inputs)[:, None, :, :]
+            )
+
+        token_to_rep_atom = feats["token_to_rep_atom"]
+        token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0)
+        if len(x_pred.shape) == 4:
+            B, mult, N, _ = x_pred.shape
+            x_pred = x_pred.reshape(B * mult, N, -1)
+        x_pred_repr = torch.bmm(token_to_rep_atom.float(), x_pred)
+        d = torch.cdist(x_pred_repr, x_pred_repr)
+
+        distogram = (d.unsqueeze(-1) > self.boundaries).sum(dim=-1).long()
+        distogram = self.dist_bin_pairwise_embed(distogram)
+
+        z = z + distogram
+
+        mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
+        pair_mask = mask[:, :, None] * mask[:, None, :]
+
+        if self.imitate_trunk:
+            z = z + self.msa_module(z, s_inputs, feats, use_kernels=use_kernels)
+
+            s, z = self.pairformer_module(
+                s, z, mask=mask, pair_mask=pair_mask, use_kernels=use_kernels
+            )
+
+            s, z = self.final_s_norm(s), self.final_z_norm(z)
+
+        else:
+            s_t, z_t = self.pairformer_stack(
+                s, z, mask=mask, pair_mask=pair_mask, use_kernels=use_kernels
+            )
+
+            # AF3 has residual connections; we remove them
+            s = s_t
+            z = z_t
+
+        out_dict = {}
+
+        # confidence heads
+        out_dict.update(
+            self.confidence_heads(
+                s=s,
+                z=z,
+                x_pred=x_pred,
+                d=d,
+                feats=feats,
+                multiplicity=multiplicity,
+                pred_distogram_logits=pred_distogram_logits,
+            )
+        )
+
+        return out_dict
+
+
+class ConfidenceHeads(nn.Module):
+    """Confidence heads."""
+
+    def __init__(
+        self,
+        token_s,
+        token_z,
+        num_plddt_bins=50,
+        num_pde_bins=64,
+        num_pae_bins=64,
+        compute_pae: bool = True,
+    ):
+        """Initialize the confidence head.
+
+        Parameters
+        ----------
+        token_s : int
+            The single representation dimension.
+        token_z : int
+            The pair representation dimension.
+        num_plddt_bins : int
+            The number of pLDDT bins, by default 50.
+        num_pde_bins : int
+            The number of PDE bins, by default 64.
+        num_pae_bins : int
+            The number of PAE bins, by default 64.
+        compute_pae : bool
+            Whether to compute PAE, by default True.
+
+        """
+        super().__init__()
+        self.max_num_atoms_per_token = 23
+        self.to_pde_logits = LinearNoBias(token_z, num_pde_bins)
+        self.to_plddt_logits = LinearNoBias(token_s, num_plddt_bins)
+        self.to_resolved_logits = LinearNoBias(token_s, 2)
+        self.compute_pae = compute_pae
+        if self.compute_pae:
+            self.to_pae_logits = LinearNoBias(token_z, num_pae_bins)
+
+    def forward(
+        self,
+        s,
+        z,
+        x_pred,
+        d,
+        feats,
+        pred_distogram_logits,
+        multiplicity=1,
+    ):
+        # Compute the pLDDT, PDE, PAE, and resolved logits
+        plddt_logits = self.to_plddt_logits(s)
+        pde_logits = self.to_pde_logits(z + z.transpose(1, 2))
+        resolved_logits = self.to_resolved_logits(s)
+        if self.compute_pae:
+            pae_logits = self.to_pae_logits(z)
+
+        # Weights used to compute the interface pLDDT
+        ligand_weight = 2
+        interface_weight = 1
+
+        # Retrieve relevant features
+        token_type = feats["mol_type"]
+        token_type = token_type.repeat_interleave(multiplicity, 0)
+        is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
+
+        # Compute the aggregated pLDDT and ipLDDT
+        plddt = compute_aggregated_metric(plddt_logits)
+        token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
+        complex_plddt = (plddt * token_pad_mask).sum(dim=-1) / token_pad_mask.sum(
+            dim=-1
+        )
+
+        is_contact = (d < 8).float()
+        is_different_chain = (
+            feats["asym_id"].unsqueeze(-1) != feats["asym_id"].unsqueeze(-2)
+        ).float()
+        is_different_chain = is_different_chain.repeat_interleave(multiplicity, 0)
+        token_interface_mask = torch.max(
+            is_contact * is_different_chain * (1 - is_ligand_token).unsqueeze(-1),
+            dim=-1,
+        ).values
+        iplddt_weight = (
+            is_ligand_token * ligand_weight + token_interface_mask * interface_weight
+        )
+        complex_iplddt = (plddt * token_pad_mask * iplddt_weight).sum(dim=-1) / (
+            torch.sum(token_pad_mask * iplddt_weight, dim=-1) + 1e-5
+        )
+
+        # Compute the aggregated PDE and iPDE
+        pde = compute_aggregated_metric(pde_logits, end=32)
+        pred_distogram_prob = nn.functional.softmax(
+            pred_distogram_logits, dim=-1
+        ).repeat_interleave(multiplicity, 0)
+        contacts = torch.zeros((1, 1, 1, 64), dtype=pred_distogram_prob.dtype).to(
+            pred_distogram_prob.device
+        )
+        contacts[:, :, :, :20] = 1.0
+        prob_contact = (pred_distogram_prob * contacts).sum(-1)
+        token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
+        token_pad_pair_mask = (
+            token_pad_mask.unsqueeze(-1)
+            * token_pad_mask.unsqueeze(-2)
+            * (
+                1
+                - torch.eye(
+                    token_pad_mask.shape[1], device=token_pad_mask.device
+                ).unsqueeze(0)
+            )
+        )
+        token_pair_mask = token_pad_pair_mask * prob_contact
+        complex_pde = (pde * token_pair_mask).sum(dim=(1, 2)) / token_pair_mask.sum(
+            dim=(1, 2)
+        )
+        asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
+        token_interface_pair_mask = token_pair_mask * (
+            asym_id.unsqueeze(-1) != asym_id.unsqueeze(-2)
+        )
+        complex_ipde = (pde * token_interface_pair_mask).sum(dim=(1, 2)) / (
+            token_interface_pair_mask.sum(dim=(1, 2)) + 1e-5
+        )
+
+        out_dict = dict(
+            pde_logits=pde_logits,
+            plddt_logits=plddt_logits,
+            resolved_logits=resolved_logits,
+            pde=pde,
+            plddt=plddt,
+            complex_plddt=complex_plddt,
+            complex_iplddt=complex_iplddt,
+            complex_pde=complex_pde,
+            complex_ipde=complex_ipde,
+        )
+        if self.compute_pae:
+            out_dict["pae_logits"] = pae_logits
+            out_dict["pae"] = compute_aggregated_metric(pae_logits, end=32)
+            ptm, iptm, ligand_iptm, protein_iptm, pair_chains_iptm = compute_ptms(
+                pae_logits, x_pred, feats, multiplicity
+            )
+            out_dict["ptm"] = ptm
+            out_dict["iptm"] = iptm
+            out_dict["ligand_iptm"] = ligand_iptm
+            out_dict["protein_iptm"] = protein_iptm
+            out_dict["pair_chains_iptm"] = pair_chains_iptm
+
+        return out_dict
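
With the wheel installed, ConfidenceHeads can be exercised on dummy tensors. The sketch below follows the forward signature above; the dimensions (token_s=32, token_z=8, 16 tokens) and feature values are illustrative placeholders, not defaults from the package, and compute_pae is disabled so compute_ptms is never invoked:

import torch
from boltz.model.modules.confidence import ConfidenceHeads

B, N = 1, 16                     # batch size, number of tokens
token_s, token_z = 32, 8

heads = ConfidenceHeads(token_s, token_z, compute_pae=False)
feats = {
    "mol_type": torch.zeros(B, N, dtype=torch.long),  # all-protein complex
    "token_pad_mask": torch.ones(B, N),               # no padding
    "asym_id": torch.zeros(B, N, dtype=torch.long),   # single chain
}
x_pred = torch.randn(B, N, 3)                         # representative coords
d = torch.cdist(x_pred, x_pred)                       # pairwise distances

out = heads(
    s=torch.randn(B, N, token_s),
    z=torch.randn(B, N, N, token_z),
    x_pred=x_pred,
    d=d,
    feats=feats,
    pred_distogram_logits=torch.randn(B, N, N, 64),
)
print(out["plddt"].shape, out["complex_plddt"].shape)  # per-token and per-complex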