boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,181 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from boltz.data import const
5
+ from boltz.model.loss.confidence import compute_frame_pred
6
+
7
+
8
def compute_aggregated_metric(logits, end=1.0):
    """Collapse binned logits into the expected value of the metric.

    The ``logits.shape[-1]`` bins evenly split ``[0, end]``. The result is
    the softmax-weighted average of the bin centers.

    Parameters
    ----------
    logits : torch.Tensor
        Logits over the metric bins, with the bins on the last dimension.
    end : float
        Max value of the metric, by default 1.0

    Returns
    -------
    Tensor
        The metric value, i.e. ``logits`` with its last dimension reduced.

    """
    n_bins = logits.shape[-1]
    width = end / n_bins
    # Bin centers: width/2, 3*width/2, ..., end - width/2.
    centers = torch.arange(
        start=0.5 * width, end=end, step=width, device=logits.device
    )
    # Prepend singleton axes so the centers broadcast over every leading dim.
    leading = (1,) * (logits.dim() - 1)
    weights = nn.functional.softmax(logits, dim=-1)
    return (weights * centers.view(*leading, -1)).sum(dim=-1)
35
+
36
+
37
def tm_function(d, Nres):
    """Apply the TM-score style rescaling ``1 / (1 + (d / d0)**2)``.

    ``d0 = 1.24 * (max(Nres, 19) - 15) ** (1/3) - 1.8`` depends on the
    number of residues.

    Parameters
    ----------
    d : torch.Tensor
        The input (e.g. distance / error values); broadcast against ``Nres``.
    Nres : torch.Tensor
        The number of residues.

    Returns
    -------
    Tensor
        Output of the function, in ``(0, 1]``.

    """
    # Clipping at 19 keeps (n - 15) >= 4, so the cube root is real and d0 > 0.
    n_eff = torch.clip(Nres, min=19)
    d0 = 1.24 * (n_eff - 15) ** (1 / 3) - 1.8
    ratio = d / d0
    return 1 / (1 + ratio**2)
55
+
56
+
57
def _masked_max_mean(values, pair_mask, eps=1e-5):
    """Reduce a pairwise score: masked mean over the last axis, then max over axis 1.

    Parameters
    ----------
    values : torch.Tensor
        Pairwise scores, e.g. shape (B, N, N).
    pair_mask : torch.Tensor
        0/1 mask broadcastable to ``values``.
    eps : float
        Added to the per-row mask count so fully masked rows stay finite.

    Returns
    -------
    torch.Tensor
        ``values`` with axes 1 and 2 reduced, e.g. shape (B,).

    """
    row_mean = torch.sum(values * pair_mask, dim=-1) / (
        torch.sum(pair_mask, dim=-1) + eps
    )
    return torch.max(row_mean, dim=1).values


def compute_ptms(logits, x_preds, feats, multiplicity, end=32.0):
    """Compute pTM and ipTM scores.

    Parameters
    ----------
    logits : torch.Tensor
        pae logits; the last dimension holds bins evenly covering ``[0, end]``.
    x_preds : torch.Tensor
        The predicted coordinates
    feats : Dict[str, torch.Tensor]
        The input features. Reads ``frames_idx``, ``token_pad_mask``,
        ``asym_id`` and ``mol_type`` (plus whatever ``compute_frame_pred``
        reads).
    multiplicity : int
        The batch size of the diffusion roll-out; per-example features are
        repeated this many times along dim 0.
    end : float
        Upper bound of the PAE bin range, by default 32.0.

    Returns
    -------
    Tensor
        pTM score
    Tensor
        ipTM score (pairs on different chains)
    Tensor
        ligand ipTM score (ligand-protein pairs on different chains)
    Tensor
        protein ipTM score (protein-protein pairs on different chains)
    Dict[int, Dict[int, Tensor]]
        Per chain pair scores: ``result[c1][c2]`` restricts the row axis
        (axis 1, which carries the frame mask) to chain ``c2`` and the
        column axis (axis 2) to chain ``c1``.

    """
    # Row mask for collinear and overlapping tokens (from the predicted frames)
    _, mask_collinear_pred = compute_frame_pred(
        x_preds, feats["frames_idx"], feats, multiplicity, inference=True
    )
    mask_pad = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
    maski = mask_collinear_pred.reshape(-1, mask_collinear_pred.shape[-1])
    # Base mask shared by every score below: valid frame rows x non-pad pairs.
    pair_mask_ptm = maski[:, :, None] * mask_pad[:, None, :] * mask_pad[:, :, None]

    asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
    pair_mask_iptm = pair_mask_ptm * (asym_id[:, None, :] != asym_id[:, :, None])

    # Expected TM term per pair, using the PAE bin centers as error values.
    num_bins = logits.shape[-1]
    bin_width = end / num_bins
    pae_value = torch.arange(
        start=0.5 * bin_width, end=end, step=bin_width, device=logits.device
    ).unsqueeze(0)
    N_res = mask_pad.sum(dim=-1, keepdim=True)
    tm_value = tm_function(pae_value, N_res).unsqueeze(1).unsqueeze(2)
    probs = nn.functional.softmax(logits, dim=-1)
    tm_expected_value = torch.sum(probs * tm_value, dim=-1)  # shape (B, N, N)

    ptm = _masked_max_mean(tm_expected_value, pair_mask_ptm)
    iptm = _masked_max_mean(tm_expected_value, pair_mask_iptm)

    # Ligand / protein restricted ipTM
    token_type = feats["mol_type"].repeat_interleave(multiplicity, 0)
    is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
    is_protein_token = (token_type == const.chain_type_ids["PROTEIN"]).float()
    ligand_protein_pair = (
        is_ligand_token[:, :, None] * is_protein_token[:, None, :]
    ) + (is_protein_token[:, :, None] * is_ligand_token[:, None, :])
    protein_protein_pair = is_protein_token[:, :, None] * is_protein_token[:, None, :]

    ligand_iptm = _masked_max_mean(
        tm_expected_value, pair_mask_iptm * ligand_protein_pair
    )
    protein_iptm = _masked_max_mean(
        tm_expected_value, pair_mask_iptm * protein_protein_pair
    )

    # Per chain-pair ipTM; chain membership masks are computed once.
    chain_pair_iptm = {}
    asym_ids_list = torch.unique(asym_id).tolist()
    in_chain = {idx: asym_id == idx for idx in asym_ids_list}
    for idx1 in asym_ids_list:
        cols_in_chain1 = pair_mask_ptm * in_chain[idx1][:, None, :]
        chain_pair_iptm[idx1] = {
            idx2: _masked_max_mean(
                tm_expected_value, cols_in_chain1 * in_chain[idx2][:, :, None]
            )
            for idx2 in asym_ids_list
        }

    return ptm, iptm, ligand_iptm, protein_iptm, chain_pair_iptm
@@ -0,0 +1,495 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn.functional import pad
4
+
5
+ import boltz.model.layers.initialize as init
6
+ from boltz.data import const
7
+ from boltz.model.layers.confidence_utils import (
8
+ compute_aggregated_metric,
9
+ compute_ptms,
10
+ )
11
+ from boltz.model.layers.pairformer import PairformerModule
12
+ from boltz.model.modules.encodersv2 import RelativePositionEncoder
13
+ from boltz.model.modules.trunkv2 import (
14
+ ContactConditioning,
15
+ )
16
+ from boltz.model.modules.utils import LinearNoBias
17
+
18
+
19
class ConfidenceModule(nn.Module):
    """Confidence module (Algorithm 31).

    Re-embeds the trunk single/pair representations together with a distance
    embedding of the predicted structure, refines them with a Pairformer
    stack, and feeds the result to ``ConfidenceHeads``.
    """

    def __init__(
        self,
        token_s,
        token_z,
        pairformer_args: dict,
        num_dist_bins=64,
        token_level_confidence=True,
        max_dist=22,
        add_s_to_z_prod=False,
        add_s_input_to_s=False,
        add_z_input_to_z=False,
        maximum_bond_distance=0,
        bond_type_feature=False,
        confidence_args: dict = None,
        compile_pairformer=False,
        fix_sym_check=False,
        cyclic_pos_enc=False,
        return_latent_feats=False,
        conditioning_cutoff_min=None,
        conditioning_cutoff_max=None,
        **kwargs,
    ):
        """Build the confidence trunk and heads.

        Parameters
        ----------
        token_s : int
            Single-representation width.
        token_z : int
            Pair-representation width.
        pairformer_args : dict
            Keyword arguments for ``PairformerModule``. Not mutated; a copy
            with ``v2=True`` is passed on.
        num_dist_bins : int
            Number of buckets for the predicted-distance embedding.
        token_level_confidence : bool
            Forwarded to ``ConfidenceHeads``.
        max_dist : float
            Last distance-bin boundary (boundaries span ``[2, max_dist]``).
        add_s_to_z_prod, add_s_input_to_s, add_z_input_to_z : bool
            Optional extra input projections.
        maximum_bond_distance : int
            Controls the input width of the token-bond projection.
        bond_type_feature : bool
            Embed ``feats["type_bonds"]`` (only with ``add_z_input_to_z``).
        confidence_args : dict, optional
            Extra keyword arguments for ``ConfidenceHeads``; ``None`` means none.
        compile_pairformer : bool
            Accepted for interface compatibility; not used here.
        fix_sym_check, cyclic_pos_enc : bool
            Forwarded to ``RelativePositionEncoder``.
        return_latent_feats : bool
            Also return the refined ``s``/``z`` as ``s_conf``/``z_conf``.
        conditioning_cutoff_min, conditioning_cutoff_max : optional
            Forwarded to ``ContactConditioning``.
        **kwargs
            Ignored.

        """
        super().__init__()
        self.max_num_atoms_per_token = 23
        self.no_update_s = pairformer_args.get("no_update_s", False)
        # num_dist_bins - 1 boundaries give num_dist_bins distance buckets.
        boundaries = torch.linspace(2, max_dist, num_dist_bins - 1)
        self.register_buffer("boundaries", boundaries)
        self.dist_bin_pairwise_embed = nn.Embedding(num_dist_bins, token_z)
        init.gating_init_(self.dist_bin_pairwise_embed.weight)
        self.token_level_confidence = token_level_confidence

        self.s_to_z = LinearNoBias(token_s, token_z)
        self.s_to_z_transpose = LinearNoBias(token_s, token_z)
        init.gating_init_(self.s_to_z.weight)
        init.gating_init_(self.s_to_z_transpose.weight)

        self.add_s_to_z_prod = add_s_to_z_prod
        if add_s_to_z_prod:
            self.s_to_z_prod_in1 = LinearNoBias(token_s, token_z)
            self.s_to_z_prod_in2 = LinearNoBias(token_s, token_z)
            self.s_to_z_prod_out = LinearNoBias(token_z, token_z)
            init.gating_init_(self.s_to_z_prod_out.weight)

        self.s_inputs_norm = nn.LayerNorm(token_s)
        if not self.no_update_s:
            self.s_norm = nn.LayerNorm(token_s)
        self.z_norm = nn.LayerNorm(token_z)

        self.add_s_input_to_s = add_s_input_to_s
        if add_s_input_to_s:
            self.s_input_to_s = LinearNoBias(token_s, token_s)
            init.gating_init_(self.s_input_to_s.weight)

        self.add_z_input_to_z = add_z_input_to_z
        if add_z_input_to_z:
            self.rel_pos = RelativePositionEncoder(
                token_z, fix_sym_check=fix_sym_check, cyclic_pos_enc=cyclic_pos_enc
            )
            self.token_bonds = nn.Linear(
                1 if maximum_bond_distance == 0 else maximum_bond_distance + 2,
                token_z,
                bias=False,
            )
            self.bond_type_feature = bond_type_feature
            if bond_type_feature:
                self.token_bonds_type = nn.Embedding(len(const.bond_types) + 1, token_z)

            self.contact_conditioning = ContactConditioning(
                token_z=token_z,
                cutoff_min=conditioning_cutoff_min,
                cutoff_max=conditioning_cutoff_max,
            )
        # Copy instead of mutating the caller's dict (often a shared hparams dict).
        pairformer_args = {**pairformer_args, "v2": True}
        self.pairformer_stack = PairformerModule(
            token_s,
            token_z,
            **pairformer_args,
        )
        self.return_latent_feats = return_latent_feats

        self.confidence_heads = ConfidenceHeads(
            token_s,
            token_z,
            token_level_confidence=token_level_confidence,
            **(confidence_args or {}),
        )

    def forward(
        self,
        s_inputs,  # Float['b n ts']
        s,  # Float['b n ts']
        z,  # Float['b n n tz']
        x_pred,  # Float['bm m 3']
        feats,
        pred_distogram_logits,
        multiplicity=1,
        run_sequentially=False,
        use_kernels: bool = False,
    ):
        """Run the confidence trunk and the confidence heads.

        Parameters
        ----------
        s_inputs : torch.Tensor
            Input single features, ``Float['b n ts']``.
        s : torch.Tensor
            Trunk single representation, ``Float['b n ts']``.
        z : torch.Tensor
            Trunk pair representation, ``Float['b n n tz']``.
        x_pred : torch.Tensor
            Predicted atom coordinates, ``Float['bm m 3']``; a 4-D
            ``(b, mult, m, 3)`` tensor is flattened to that layout.
        feats : dict
            Input features.
        pred_distogram_logits : torch.Tensor
            Trunk distogram logits, forwarded to the heads.
        multiplicity : int
            Samples per input; per-input tensors are repeated this many
            times along dim 0.
        run_sequentially : bool
            If True and ``multiplicity > 1``, process one sample at a time
            (batch size must be 1) and concatenate the outputs.
        use_kernels : bool
            Forwarded to the Pairformer stack.

        Returns
        -------
        dict
            Outputs of ``ConfidenceHeads``, plus ``s_conf``/``z_conf`` when
            ``return_latent_feats`` is set.

        """
        if run_sequentially and multiplicity > 1:
            assert z.shape[0] == 1, "Not supported with batch size > 1"
            out_dicts = []
            for sample_idx in range(multiplicity):
                out_dicts.append(  # noqa: PERF401
                    self.forward(
                        s_inputs,
                        s,
                        z,
                        x_pred[sample_idx : sample_idx + 1],
                        feats,
                        pred_distogram_logits,
                        multiplicity=1,
                        run_sequentially=False,
                        use_kernels=use_kernels,
                    )
                )

            # Stitch per-sample outputs back together along the batch axis;
            # pair_chains_iptm is a nested dict and is merged leaf by leaf.
            out_dict = {}
            for key in out_dicts[0]:
                if key != "pair_chains_iptm":
                    out_dict[key] = torch.cat([out[key] for out in out_dicts], dim=0)
                else:
                    pair_chains_iptm = {}
                    for chain_idx1 in out_dicts[0][key]:
                        chains_iptm = {}
                        for chain_idx2 in out_dicts[0][key][chain_idx1]:
                            chains_iptm[chain_idx2] = torch.cat(
                                [out[key][chain_idx1][chain_idx2] for out in out_dicts],
                                dim=0,
                            )
                        pair_chains_iptm[chain_idx1] = chains_iptm
                    out_dict[key] = pair_chains_iptm
            return out_dict

        s_inputs = self.s_inputs_norm(s_inputs)
        if not self.no_update_s:
            s = self.s_norm(s)

        if self.add_s_input_to_s:
            s = s + self.s_input_to_s(s_inputs)

        z = self.z_norm(z)

        if self.add_z_input_to_z:
            relative_position_encoding = self.rel_pos(feats)
            z = z + relative_position_encoding
            z = z + self.token_bonds(feats["token_bonds"].float())
            if self.bond_type_feature:
                z = z + self.token_bonds_type(feats["type_bonds"].long())
            z = z + self.contact_conditioning(feats)

        s = s.repeat_interleave(multiplicity, 0)

        z = (
            z
            + self.s_to_z(s_inputs)[:, :, None, :]
            + self.s_to_z_transpose(s_inputs)[:, None, :, :]
        )
        if self.add_s_to_z_prod:
            z = z + self.s_to_z_prod_out(
                self.s_to_z_prod_in1(s_inputs)[:, :, None, :]
                * self.s_to_z_prod_in2(s_inputs)[:, None, :, :]
            )

        z = z.repeat_interleave(multiplicity, 0)
        s_inputs = s_inputs.repeat_interleave(multiplicity, 0)

        # Embed bucketed distances between the tokens' representative atoms.
        token_to_rep_atom = feats["token_to_rep_atom"]
        token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0)
        if len(x_pred.shape) == 4:
            B, mult, N, _ = x_pred.shape
            x_pred = x_pred.reshape(B * mult, N, -1)
        else:
            BM, N, _ = x_pred.shape
        x_pred_repr = torch.bmm(token_to_rep_atom.float(), x_pred)
        d = torch.cdist(x_pred_repr, x_pred_repr)
        distogram = (d.unsqueeze(-1) > self.boundaries).sum(dim=-1).long()
        distogram = self.dist_bin_pairwise_embed(distogram)
        z = z + distogram

        mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        pair_mask = mask[:, :, None] * mask[:, None, :]

        s_t, z_t = self.pairformer_stack(
            s, z, mask=mask, pair_mask=pair_mask, use_kernels=use_kernels
        )

        # AF3 has residual connections, we remove them
        s = s_t
        z = z_t

        out_dict = {}

        if self.return_latent_feats:
            out_dict["s_conf"] = s
            out_dict["z_conf"] = z

        # confidence heads
        out_dict.update(
            self.confidence_heads(
                s=s,
                z=z,
                x_pred=x_pred,
                d=d,
                feats=feats,
                multiplicity=multiplicity,
                pred_distogram_logits=pred_distogram_logits,
            )
        )
        return out_dict
232
+
233
+
234
class ConfidenceHeads(nn.Module):
    """Project refined confidence features to PAE/PDE/pLDDT/resolved logits.

    Also derives aggregate scores: complex pLDDT / ipLDDT, complex PDE / iPDE,
    and pTM / ipTM variants (via ``compute_ptms``).
    """

    def __init__(
        self,
        token_s,
        token_z,
        num_plddt_bins=50,
        num_pde_bins=64,
        num_pae_bins=64,
        token_level_confidence=True,
        use_separate_heads: bool = False,
        **kwargs,
    ):
        """Build the projection heads.

        Parameters
        ----------
        token_s : int
            Single-representation width.
        token_z : int
            Pair-representation width.
        num_plddt_bins, num_pde_bins, num_pae_bins : int
            Number of bins of each binned metric.
        token_level_confidence : bool
            Predict pLDDT / resolved per token; otherwise per atom, with
            ``max_num_atoms_per_token`` slots per token.
        use_separate_heads : bool
            Use separate intra-chain and inter-chain PAE/PDE projections.
        **kwargs
            Ignored.

        """
        super().__init__()
        self.max_num_atoms_per_token = 23
        self.token_level_confidence = token_level_confidence
        self.use_separate_heads = use_separate_heads

        # Registration order (pae before pde) matches the original layout.
        if self.use_separate_heads:
            self.to_pae_intra_logits = LinearNoBias(token_z, num_pae_bins)
            self.to_pae_inter_logits = LinearNoBias(token_z, num_pae_bins)
            self.to_pde_intra_logits = LinearNoBias(token_z, num_pde_bins)
            self.to_pde_inter_logits = LinearNoBias(token_z, num_pde_bins)
        else:
            self.to_pae_logits = LinearNoBias(token_z, num_pae_bins)
            self.to_pde_logits = LinearNoBias(token_z, num_pde_bins)

        if self.token_level_confidence:
            self.to_plddt_logits = LinearNoBias(token_s, num_plddt_bins)
            self.to_resolved_logits = LinearNoBias(token_s, 2)
        else:
            self.to_plddt_logits = LinearNoBias(
                token_s, num_plddt_bins * self.max_num_atoms_per_token
            )
            self.to_resolved_logits = LinearNoBias(
                token_s, 2 * self.max_num_atoms_per_token
            )

    def forward(
        self,
        s,  # Float['b n ts']
        z,  # Float['b n n tz']
        x_pred,  # Float['bm m 3']
        d,
        feats,
        pred_distogram_logits,
        multiplicity=1,
    ):
        """Compute confidence logits and aggregate scores.

        Parameters
        ----------
        s, z : torch.Tensor
            Refined single / pair representations; their leading dim is
            batch * multiplicity.
        x_pred : torch.Tensor
            Predicted atom coordinates, ``Float['bm m 3']``.
        d : torch.Tensor
            Pairwise distances between token representative atoms.
        feats : dict
            Input features (per input, not yet repeated by multiplicity).
        pred_distogram_logits : torch.Tensor
            Trunk distogram logits; used to weight the PDE aggregate.
        multiplicity : int
            Samples per input.

        Returns
        -------
        dict
            Logits (``pde_logits``, ``plddt_logits``, ``resolved_logits``,
            ``pae_logits``), per-pair/per-token metrics (``pde``, ``plddt``,
            ``pae``), aggregates (``complex_plddt``, ``complex_iplddt``,
            ``complex_pde``, ``complex_ipde``, ``ptm``, ``iptm``,
            ``ligand_iptm``, ``protein_iptm``) and ``pair_chains_iptm``
            (nested dict ``{chain: {chain: tensor}}``).

        """
        if self.use_separate_heads:
            # Chain masks must match z's batch * multiplicity leading dim.
            asym_id_token = feats["asym_id"].repeat_interleave(multiplicity, 0)
            is_same_chain = asym_id_token.unsqueeze(-1) == asym_id_token.unsqueeze(-2)
            same_chain = is_same_chain.float().unsqueeze(-1)
            different_chain = (~is_same_chain).float().unsqueeze(-1)

            pae_logits = (
                self.to_pae_inter_logits(z) * different_chain
                + self.to_pae_intra_logits(z) * same_chain
            )

            z_sym = z + z.transpose(1, 2)
            pde_logits = (
                self.to_pde_inter_logits(z_sym) * different_chain
                + self.to_pde_intra_logits(z_sym) * same_chain
            )
        else:
            pae_logits = self.to_pae_logits(z)
            pde_logits = self.to_pde_logits(z + z.transpose(1, 2))
        resolved_logits = self.to_resolved_logits(s)
        plddt_logits = self.to_plddt_logits(s)

        # Relative weights of ligand / interface / other entries in ipLDDT.
        ligand_weight = 20
        non_interface_weight = 1
        interface_weight = 10

        token_type = feats["mol_type"].repeat_interleave(multiplicity, 0)
        is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()

        if self.token_level_confidence:
            plddt = compute_aggregated_metric(plddt_logits)
            token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
            complex_plddt = (plddt * token_pad_mask).sum(dim=-1) / token_pad_mask.sum(
                dim=-1
            )

            # Interface token: non-ligand token with some token of another
            # chain at d < 8 (same units as d).
            is_contact = (d < 8).float()
            is_inter_chain = (
                feats["asym_id"].unsqueeze(-1) != feats["asym_id"].unsqueeze(-2)
            ).float()
            is_inter_chain = is_inter_chain.repeat_interleave(multiplicity, 0)
            token_interface_mask = torch.max(
                is_contact * is_inter_chain * (1 - is_ligand_token).unsqueeze(-1),
                dim=-1,
            ).values
            token_non_interface_mask = (1 - token_interface_mask) * (
                1 - is_ligand_token
            )
            iplddt_weight = (
                is_ligand_token * ligand_weight
                + token_interface_mask * interface_weight
                + token_non_interface_mask * non_interface_weight
            )
            complex_iplddt = (plddt * token_pad_mask * iplddt_weight).sum(
                dim=-1
            ) / torch.sum(token_pad_mask * iplddt_weight, dim=-1)

        else:
            # token to atom conversion for resolved logits
            B, N, _ = resolved_logits.shape
            resolved_logits = resolved_logits.reshape(
                B, N, self.max_num_atoms_per_token, 2
            )

            arange_max_num_atoms = (
                torch.arange(self.max_num_atoms_per_token)
                .reshape(1, 1, -1)
                .to(resolved_logits.device)
            )
            max_num_atoms_mask = (
                feats["atom_to_token"].sum(1).unsqueeze(-1) > arange_max_num_atoms
            )
            # NOTE(review): squeeze(0) and the batch-wide pad count assume a
            # batch size of 1 — confirm before using with larger batches.
            valid_slots = max_num_atoms_mask.squeeze(0)
            num_pad_atoms = int(
                feats["atom_pad_mask"].shape[1] - feats["atom_pad_mask"].sum().item()
            )
            resolved_logits = pad(
                resolved_logits[:, valid_slots],
                (0, 0, 0, num_pad_atoms),
                value=0,
            )
            plddt_logits = plddt_logits.reshape(B, N, self.max_num_atoms_per_token, -1)
            plddt_logits = pad(
                plddt_logits[:, valid_slots],
                (0, 0, 0, num_pad_atoms),
                value=0,
            )
            atom_pad_mask = feats["atom_pad_mask"].repeat_interleave(multiplicity, 0)
            plddt = compute_aggregated_metric(plddt_logits)

            complex_plddt = (plddt * atom_pad_mask).sum(dim=-1) / atom_pad_mask.sum(
                dim=-1
            )
            mol_type_token = feats["mol_type"].float()
            atom_to_token = feats["atom_to_token"].float()
            chain_id_token = feats["asym_id"].float()
            atom_type = torch.bmm(atom_to_token, mol_type_token.unsqueeze(-1)).squeeze(
                -1
            )
            is_ligand_atom = (atom_type == const.chain_type_ids["NONPOLYMER"]).float()
            d_atom = torch.cdist(x_pred, x_pred)
            is_contact = (d_atom < 8).float()
            chain_id_atom = torch.bmm(
                atom_to_token, chain_id_token.unsqueeze(-1)
            ).squeeze(-1)
            is_inter_chain = (
                chain_id_atom.unsqueeze(-1) != chain_id_atom.unsqueeze(-2)
            ).float()

            atom_interface_mask = torch.max(
                is_contact * is_inter_chain * (1 - is_ligand_atom).unsqueeze(-1),
                dim=-1,
            ).values
            atom_non_interface_mask = (1 - atom_interface_mask) * (1 - is_ligand_atom)
            iplddt_weight = (
                is_ligand_atom * ligand_weight
                + atom_interface_mask * interface_weight
                + atom_non_interface_mask * non_interface_weight
            )

            # Use the multiplicity-repeated mask, as complex_plddt does.
            complex_iplddt = (plddt * atom_pad_mask * iplddt_weight).sum(
                dim=-1
            ) / torch.sum(atom_pad_mask * iplddt_weight, dim=-1)

        # Compute the gPDE and giPDE
        pde = compute_aggregated_metric(pde_logits, end=32)
        pred_distogram_prob = nn.functional.softmax(
            pred_distogram_logits, dim=-1
        ).repeat_interleave(multiplicity, 0)
        # Probability mass in the first 20 distogram bins counts as "contact".
        # NOTE(review): assumes a 64-bin distogram — confirm against the model.
        contacts = torch.zeros(
            (1, 1, 1, 64),
            dtype=pred_distogram_prob.dtype,
            device=pred_distogram_prob.device,
        )
        contacts[:, :, :, :20] = 1.0
        prob_contact = (pred_distogram_prob * contacts).sum(-1)
        token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        # Off-diagonal pairs of non-pad tokens.
        token_pad_pair_mask = (
            token_pad_mask.unsqueeze(-1)
            * token_pad_mask.unsqueeze(-2)
            * (
                1
                - torch.eye(
                    token_pad_mask.shape[1], device=token_pad_mask.device
                ).unsqueeze(0)
            )
        )
        token_pair_mask = token_pad_pair_mask * prob_contact
        complex_pde = (pde * token_pair_mask).sum(dim=(1, 2)) / token_pair_mask.sum(
            dim=(1, 2)
        )
        asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
        token_interface_pair_mask = token_pair_mask * (
            asym_id.unsqueeze(-1) != asym_id.unsqueeze(-2)
        )
        complex_ipde = (pde * token_interface_pair_mask).sum(dim=(1, 2)) / (
            token_interface_pair_mask.sum(dim=(1, 2)) + 1e-5
        )
        out_dict = dict(
            pde_logits=pde_logits,
            plddt_logits=plddt_logits,
            resolved_logits=resolved_logits,
            pde=pde,
            plddt=plddt,
            complex_plddt=complex_plddt,
            complex_iplddt=complex_iplddt,
            complex_pde=complex_pde,
            complex_ipde=complex_ipde,
        )
        out_dict["pae_logits"] = pae_logits
        out_dict["pae"] = compute_aggregated_metric(pae_logits, end=32)

        try:
            ptm, iptm, ligand_iptm, protein_iptm, pair_chains_iptm = compute_ptms(
                pae_logits, x_pred, feats, multiplicity
            )
        except Exception as e:  # best effort: keep the other confidence outputs
            print(f"Error in compute_ptms: {e}")
            ptm = torch.zeros_like(complex_plddt)
            iptm = torch.zeros_like(complex_plddt)
            ligand_iptm = torch.zeros_like(complex_plddt)
            protein_iptm = torch.zeros_like(complex_plddt)
            # Same nested {chain: {chain: tensor}} layout as the success path,
            # so downstream merging/writing code does not break.
            chain_ids = torch.unique(feats["asym_id"]).tolist()
            pair_chains_iptm = {
                c1: {c2: torch.zeros_like(complex_plddt) for c2 in chain_ids}
                for c1 in chain_ids
            }
        out_dict["ptm"] = ptm
        out_dict["iptm"] = iptm
        out_dict["ligand_iptm"] = ligand_iptm
        out_dict["protein_iptm"] = protein_iptm
        out_dict["pair_chains_iptm"] = pair_chains_iptm

        return out_dict