rc-foundry 0.1.1 (rc_foundry-0.1.1-py3-none-any.whl)

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,303 @@ (new file; by its +303 line count this appears to be rf3/model/RF3_structure.py from the list above)
import logging

import torch
import torch.nn as nn
from rf3.model.layers.af3_diffusion_transformer import (
    AtomAttentionEncoderDiffusion,
    AtomTransformer,
    DiffusionTransformer,
)
from rf3.model.layers.layer_utils import Transition, linearNoBias
from rf3.model.layers.pairformer_layers import (
    MSAModule,
    PairformerBlock,
    RelativePositionEncoding,
    RF3TemplateEmbedder,
)

from foundry.model.layers.blocks import FourierEmbedding
from foundry.training.checkpoint import activation_checkpointing

logger = logging.getLogger(__name__)

"""
Glossary (tensor-dimension suffixes used throughout this module):
    I: # tokens (coarse representation)
    L: # atoms (fine representation)
    M: # msa
    T: # templates
    D: # diffusion structure batch dim
"""

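The suffix convention above is used on every tensor name below (S_trunk_I is token-level, X_noisy_L is atom-level). A minimal sketch of how the names map to shapes, with all sizes invented for illustration:

import torch

B, I, L = 2, 16, 120                 # batch, tokens, atoms (illustrative sizes)
c_s, c_z = 384, 128                  # single / pair channel widths (illustrative)

S_trunk_I = torch.randn(B, I, c_s)             # per-token single representation
Z_trunk_II = torch.randn(B, I, I, c_z)         # token-pair representation
X_noisy_L = torch.randn(B, L, 3)               # per-atom coordinates
atom_to_token_map = torch.randint(0, I, (L,))  # parent token of each atom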
class AtomAttentionDecoder(nn.Module):
    def __init__(self, c_token, c_atom, c_atompair, atom_transformer):
        super().__init__()
        self.atom_transformer = AtomTransformer(
            c_atom=c_atom, c_atompair=c_atompair, **atom_transformer
        )
        self.linear_1 = linearNoBias(c_token, c_atom)
        self.to_r_update = nn.Sequential(
            nn.LayerNorm((c_atom,)), linearNoBias(c_atom, 3)
        )

    def forward(
        self,
        f,  # input feature dict (uses "atom_to_token_map")
        Ai,  # [I, C_token]
        Ql_skip,  # [L, C_atom]
        Cl_skip,  # [L, C_atom]
        Plm_skip,  # [L, L, C_atompair]
    ):
        tok_idx = f["atom_to_token_map"]

        @activation_checkpointing
        def atom_decoder(Ai, Ql_skip, Cl_skip, Plm_skip, tok_idx):
            # Broadcast per-token activations to per-atom activations and add the skip connection
            Ql = self.linear_1(Ai[..., tok_idx, :]) + Ql_skip

            # Cross-attention transformer.
            Ql = self.atom_transformer(Ql, Cl_skip, Plm_skip)

            # Map to positions update
            Rl_update = self.to_r_update(Ql)

            return Rl_update

        return atom_decoder(Ai, Ql_skip, Cl_skip, Plm_skip, tok_idx)

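The decoder's first step is a pure gather: per-token activations are broadcast to atoms through f["atom_to_token_map"] before the skip connection is added. A standalone sketch of that indexing, with hypothetical sizes and values:

import torch

I, L, c_token = 4, 10, 8
A_I = torch.randn(I, c_token)                                     # per-token activations
atom_to_token_map = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 3, 3])  # atom -> parent token

A_L = A_I[..., atom_to_token_map, :]   # [L, c_token]: each atom receives its token's row
assert A_L.shape == (L, c_token)
assert torch.equal(A_L[3], A_I[1])     # atom 3 belongs to token 1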
class DiffusionModule(nn.Module):
    def __init__(
        self,
        sigma_data,
        c_atom,
        c_atompair,
        c_token,
        c_s,
        c_z,
        f_pred,
        diffusion_conditioning,
        atom_attention_encoder,
        diffusion_transformer,
        atom_attention_decoder,
    ):
        super().__init__()
        self.sigma_data = sigma_data
        self.c_atom = c_atom
        self.c_atompair = c_atompair
        self.c_token = c_token
        self.c_s = c_s
        self.f_pred = f_pred

        self.diffusion_conditioning = DiffusionConditioning(
            sigma_data=sigma_data, c_s=c_s, c_z=c_z, **diffusion_conditioning
        )
        self.atom_attention_encoder = AtomAttentionEncoderDiffusion(
            c_token=c_token,
            c_s=c_s,
            c_atom=c_atom,
            c_atompair=c_atompair,
            **atom_attention_encoder,
        )
        self.process_s = nn.Sequential(
            nn.LayerNorm((c_s,)),
            linearNoBias(c_s, c_token),
        )
        self.diffusion_transformer = DiffusionTransformer(
            c_token=c_token, c_s=c_s, c_tokenpair=c_z, **diffusion_transformer
        )
        self.layer_norm_1 = nn.LayerNorm(c_token)
        self.atom_attention_decoder = AtomAttentionDecoder(
            c_token=c_token,
            c_atom=c_atom,
            c_atompair=c_atompair,
            **atom_attention_decoder,
        )

    def forward(
        self,
        X_noisy_L,  # [B, L, 3]
        t,  # [B] (0 is ground truth)
        f,  # Dict (input feature dictionary)
        S_inputs_I,  # [B, I, C_S_input]
        S_trunk_I,  # [B, I, C_S_trunk]
        Z_trunk_II,  # [B, I, I, C_Z]
    ):
        # Conditioning
        S_I, Z_II = self.diffusion_conditioning(
            t, f, S_inputs_I.float(), S_trunk_I.float(), Z_trunk_II.float()
        )

        # Scale positions to dimensionless vectors with approximately unit variance
        if self.f_pred == "edm":
            R_noisy_L = X_noisy_L / torch.sqrt(
                t[..., None, None] ** 2 + self.sigma_data**2
            )
        elif self.f_pred == "unconditioned":
            R_noisy_L = torch.zeros_like(X_noisy_L)
        elif self.f_pred == "noise_pred":
            R_noisy_L = X_noisy_L
        else:
            raise ValueError(f"{self.f_pred=} unrecognized")

        # Sequence-local atom attention and aggregation to coarse-grained tokens
        A_I, Q_skip_L, C_skip_L, P_skip_LL = self.atom_attention_encoder(
            f, R_noisy_L, S_trunk_I.float(), Z_II
        )

        # Full self-attention on the token level
        A_I = A_I + self.process_s(S_I)
        A_I = self.diffusion_transformer(A_I, S_I, Z_II, Beta_II=None)
        A_I = self.layer_norm_1(A_I)

        # Broadcast token activations to atoms and run sequence-local atom attention
        R_update_L = self.atom_attention_decoder(
            f, A_I.float(), Q_skip_L, C_skip_L, P_skip_LL
        )

        # Rescale updates to positions and combine with input positions
        if self.f_pred == "edm":
            X_out_L = (self.sigma_data**2 / (self.sigma_data**2 + t**2))[
                ..., None, None
            ] * X_noisy_L + (self.sigma_data * t / (self.sigma_data**2 + t**2) ** 0.5)[
                ..., None, None
            ] * R_update_L
        elif self.f_pred == "unconditioned":
            X_out_L = R_update_L
        elif self.f_pred == "noise_pred":
            X_out_L = X_noisy_L + R_update_L
        else:
            raise ValueError(f"{self.f_pred=} unrecognized")

        return X_out_L

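The two "edm" branches implement Karras-style preconditioning as used in AF3: the input is scaled by c_in = 1/sqrt(t^2 + sigma_data^2), and the output is blended as c_skip * X_noisy + c_out * R_update with c_skip = sigma_data^2 / (sigma_data^2 + t^2) and c_out = sigma_data * t / sqrt(sigma_data^2 + t^2). A quick numeric look at how the coefficients trade off across noise levels (sigma_data = 16 as in AF3; the t values are illustrative):

import torch

sigma_data = 16.0
t = torch.tensor([0.1, 1.0, 16.0, 160.0])

c_in = 1.0 / torch.sqrt(t**2 + sigma_data**2)
c_skip = sigma_data**2 / (sigma_data**2 + t**2)
c_out = sigma_data * t / (sigma_data**2 + t**2) ** 0.5

# Low noise: c_skip -> 1, c_out -> 0 (keep the noisy input).
# High noise: c_skip -> 0, c_out -> sigma_data (trust the network output).
print(torch.stack([t, c_skip, c_out], dim=-1))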
class DiffusionConditioning(nn.Module):
    def __init__(
        self, sigma_data, c_z, c_s, c_s_inputs, c_t_embed, relative_position_encoding
    ):
        super().__init__()
        self.sigma_data = sigma_data
        self.relative_position_encoding = RelativePositionEncoding(
            c_z=c_z, **relative_position_encoding
        )
        self.to_zii = nn.Sequential(
            # Operates on the concatenation of Z_trunk ([..., c_z]) and the
            # relative position encoding ([..., c_z])
            nn.LayerNorm(c_z * 2),
            linearNoBias(c_z * 2, c_z),
        )
        self.transition_1 = nn.ModuleList(
            [
                Transition(c=c_z, n=2),
                Transition(c=c_z, n=2),
            ]
        )
        self.to_si = nn.Sequential(
            nn.LayerNorm(c_s + c_s_inputs), linearNoBias(c_s + c_s_inputs, c_s)
        )
        c_t_embed = 256  # NOTE: hardcoded; silently overrides the c_t_embed constructor argument
        self.fourier_embedding = FourierEmbedding(c_t_embed)
        self.process_n = nn.Sequential(
            nn.LayerNorm(c_t_embed), linearNoBias(c_t_embed, c_s)
        )
        self.transition_2 = nn.ModuleList(
            [
                Transition(c=c_s, n=2),
                Transition(c=c_s, n=2),
            ]
        )

    def forward(self, t, f, S_inputs_I, S_trunk_I, Z_trunk_II):
        # Pair conditioning
        Z_II = torch.cat([Z_trunk_II, self.relative_position_encoding(f)], dim=-1)

        @activation_checkpointing
        def _run_conditioning(Z_II, S_trunk_I, S_inputs_I):
            Z_II = self.to_zii(Z_II)
            for b in range(2):
                Z_II = Z_II + self.transition_1[b](Z_II)

            # Single conditioning
            S_I = torch.cat([S_trunk_I, S_inputs_I], dim=-1)
            S_I = self.to_si(S_I)
            # Embed the noise level on a log scale relative to sigma_data
            N_D = self.fourier_embedding(1 / 4 * torch.log(t / self.sigma_data))
            S_I = self.process_n(N_D).unsqueeze(-2) + S_I
            for b in range(2):
                S_I = S_I + self.transition_2[b](S_I)

            return S_I, Z_II

        return _run_conditioning(Z_II, S_trunk_I, S_inputs_I)

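Only the noise level's log-ratio to sigma_data reaches the network, via FourierEmbedding of (1/4) * log(t / sigma_data). A self-contained sketch of AF3-style Fourier features of that scalar; the frozen random frequencies/phases below are an assumption about foundry.model.layers.blocks.FourierEmbedding, not its actual code:

import torch

def fourier_features(x, dim=256, seed=0):
    # Frozen random frequencies and phases, cos(2*pi*(x*w + b)), per the AF3 recipe
    g = torch.Generator().manual_seed(seed)
    w = torch.randn(dim, generator=g)
    b = torch.rand(dim, generator=g)
    return torch.cos(2 * torch.pi * (x[..., None] * w + b))

sigma_data = 16.0
t = torch.tensor([0.5, 4.0, 64.0])
n_emb = fourier_features(0.25 * torch.log(t / sigma_data))
print(n_emb.shape)  # torch.Size([3, 256])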
class DistogramHead(nn.Module):
    def __init__(
        self,
        c_z,
        bins,
    ):
        super().__init__()
        self.predictor = nn.Linear(c_z, bins)
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize the final logit-prediction layer to zeros
        nn.init.zeros_(self.predictor.weight)
        nn.init.zeros_(self.predictor.bias)

    def forward(
        self,
        Z_II,
    ):
        return self.predictor(
            Z_II + Z_II.transpose(-2, -3)  # symmetrize pair features
        )

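Because the pair features are symmetrized before the projection, the distogram logits are symmetric in the two token dimensions by construction. A quick standalone check with hypothetical sizes:

import torch
import torch.nn as nn

c_z, bins, I = 8, 38, 5
predictor = nn.Linear(c_z, bins)
Z_II = torch.randn(2, I, I, c_z)

logits = predictor(Z_II + Z_II.transpose(-2, -3))
assert torch.allclose(logits, logits.transpose(-2, -3), atol=1e-6)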
class Recycler(nn.Module):
    def __init__(
        self,
        c_s,
        c_z,
        template_embedder,
        msa_module,
        n_pairformer_blocks,
        pairformer_block,
    ):
        super().__init__()
        self.c_z = c_z
        self.process_zh = nn.Sequential(
            nn.LayerNorm(c_z),
            linearNoBias(c_z, c_z),
        )
        self.template_embedder = RF3TemplateEmbedder(c_z=c_z, **template_embedder)
        self.msa_module = MSAModule(**msa_module)
        self.process_sh = nn.Sequential(
            nn.LayerNorm(c_s),
            linearNoBias(c_s, c_s),
        )
        self.pairformer_stack = nn.ModuleList(
            [
                PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block)
                for _ in range(n_pairformer_blocks)
            ]
        )

    def forward(
        self,
        f,
        S_inputs_I,
        S_init_I,
        Z_init_II,
        S_I,
        Z_II,
    ):
        Z_II = Z_init_II + self.process_zh(Z_II)
        Z_II = Z_II + self.template_embedder(f, Z_II)
        # NOTE: Implements the bugfix from the Protenix technical report: a residual
        # connection around the MSA module is redundant, so its output is assigned directly.
        # Reference: https://github.com/bytedance/Protenix/blob/main/Protenix_Technical_Report.pdf
        Z_II = self.msa_module(f, Z_II, S_inputs_I)
        S_I = S_init_I + self.process_sh(S_I)
        for block in self.pairformer_stack:
            S_I, Z_II = block(S_I, Z_II)
        return S_I, Z_II
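Recycler is the per-recycle trunk update; the loop that drives it lives elsewhere (cf. rf3/utils/recycling.py in the file list). A minimal sketch of a plausible AF3-style driver, written only from the forward signature above; this is an assumption, not the package's actual recycling code:

def run_trunk(recycler, f, S_inputs_I, S_init_I, Z_init_II, n_recycles=3):
    # Seed the recycled state with the initial embeddings, then refine repeatedly;
    # each pass re-injects S_init/Z_init through the residual paths above.
    S_I, Z_II = S_init_I, Z_init_II
    for _ in range(n_recycles + 1):
        S_I, Z_II = recycler(f, S_inputs_I, S_init_I, Z_init_II, S_I, Z_II)
    return S_I, Z_II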
@@ -0,0 +1,255 @@ (new file; by its +255 line count this appears to be rf3/model/layers/af3_auxiliary_heads.py from the list above)
import torch
import torch.nn as nn
import torch.nn.functional as F
from rf3.model.RF3_structure import PairformerBlock, linearNoBias

# TODO: Get from RF2AA encoding instead
CHEM_DATA_LEGACY = {"NHEAVY": 23, "aa2num": {"UNK": 20, "GLY": 7, "MAS": 21}}


def discretize_distance_matrix(
    distance_matrix, num_bins=38, min_distance=3.25, max_distance=50.75
):
    # Calculate the bin width and the bin boundaries
    bin_width = (max_distance - min_distance) / num_bins
    bins = (
        torch.arange(num_bins, device=distance_matrix.device) * bin_width
        + min_distance
    )

    # Discretize distances into bins (bucketize sends values beyond the last
    # boundary to index num_bins, the final, open-ended bin)
    binned_distances = torch.bucketize(distance_matrix, bins)

    return binned_distances

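Boundary behavior of the binning: torch.bucketize sends values below the first edge to bin 0 and values past the last edge to index num_bins, which is why every F.one_hot call below uses num_classes = num_bins + 1. With the confidence-head settings (min 3.375, max 20.875, 10 bins, i.e. edges in 1.75 A steps):

import torch

bins = torch.arange(10) * 1.75 + 3.375        # edges 3.375, 5.125, ..., 19.125
d = torch.tensor([0.0, 3.375, 4.0, 19.2, 100.0])
print(torch.bucketize(d, bins))               # tensor([ 0,  0,  1, 10, 10])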
class ConfidenceHead(nn.Module):
    """Algorithm 31"""

    def __init__(
        self,
        c_s,
        c_z,
        n_pairformer_layers,
        pairformer,
        n_bins_pae,
        n_bins_pde,
        n_bins_plddt,
        n_bins_exp_resolved,
        use_Cb_distances=False,
        use_af3_style_binning_and_final_layer_norms=False,
        symmetrize_Cb_logits=True,
        layer_norm_along_feature_dimension=False,
    ):
        super().__init__()
        self.process_s_inputs_right = linearNoBias(449, c_z)
        self.process_s_inputs_left = linearNoBias(449, c_z)
        self.use_af3_style_binning_and_final_layer_norms = (
            use_af3_style_binning_and_final_layer_norms
        )
        self.layer_norm_along_feature_dimension = layer_norm_along_feature_dimension
        if self.use_af3_style_binning_and_final_layer_norms:
            self.layernorm_pde = nn.LayerNorm(c_z)
            self.layernorm_pae = nn.LayerNorm(c_z)
            self.layernorm_plddt = nn.LayerNorm(c_s)
            self.layernorm_exp_resolved = nn.LayerNorm(c_s)
            self.process_pred_distances = linearNoBias(40, c_z)
        else:
            self.process_pred_distances = linearNoBias(11, c_z)

        self.pairformer = nn.ModuleList(
            [
                PairformerBlock(c_s=c_s, c_z=c_z, **pairformer)
                for _ in range(n_pairformer_layers)
            ]
        )

        self.predict_pae = linearNoBias(c_z, n_bins_pae)
        self.predict_pde = linearNoBias(c_z, n_bins_pde)
        self.predict_plddt = linearNoBias(
            c_s, CHEM_DATA_LEGACY["NHEAVY"] * n_bins_plddt
        )
        self.predict_exp_resolved = linearNoBias(
            c_s, CHEM_DATA_LEGACY["NHEAVY"] * n_bins_exp_resolved
        )
        self.use_Cb_distances = use_Cb_distances
        if self.use_Cb_distances:
            self.process_Cb_distances = linearNoBias(25, c_z)
        self.symmetrize_Cb_logits = symmetrize_Cb_logits

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(
        self,
        S_inputs_I,
        S_trunk_I,
        Z_trunk_II,
        X_pred_L,
        seq,
        rep_atoms,
        frame_atom_idxs=None,
    ):
        # stopgrad on S_trunk_I, Z_trunk_II, X_pred_L (4.3.5);
        # note that S_inputs_I is detached here as well
        S_trunk_I = S_trunk_I.detach().float()  # B, L, 384
        Z_trunk_II = Z_trunk_II.detach().float()  # B, L, L, 128
        if X_pred_L is not None:
            X_pred_L = X_pred_L.detach().float()  # B, n_atoms, 3
        S_inputs_I = S_inputs_I.detach().float()  # B, L, 384
        seq = seq.detach()

        if self.layer_norm_along_feature_dimension:
            # layer norm over the feature (last) dimension only
            S_trunk_I = F.layer_norm(
                S_trunk_I, normalized_shape=(S_trunk_I.shape[-1],)
            )
            Z_trunk_II = F.layer_norm(
                Z_trunk_II, normalized_shape=(Z_trunk_II.shape[-1],)
            )
            S_inputs_I = F.layer_norm(
                S_inputs_I, normalized_shape=(S_inputs_I.shape[-1],)
            )
        else:
            # legacy behavior: normalizes over the full tensor shape, batch dims included
            S_trunk_I = F.layer_norm(S_trunk_I, normalized_shape=S_trunk_I.shape)
            Z_trunk_II = F.layer_norm(Z_trunk_II, normalized_shape=Z_trunk_II.shape)
            S_inputs_I = F.layer_norm(S_inputs_I, normalized_shape=S_inputs_I.shape)

        # embed S_inputs_I twice
        S_inputs_I_right = self.process_s_inputs_right(S_inputs_I)
        S_inputs_I_left = self.process_s_inputs_left(S_inputs_I)
        # add the outer sum (broadcast addition) of the two linear embeddings of S_inputs_I to Z_II
        # TODO: check the unsqueezed dimension is the correct one
        Z_trunk_II = Z_trunk_II + (
            S_inputs_I_right.unsqueeze(-2) + S_inputs_I_left.unsqueeze(-3)
        )

        # embed distances between the representative atoms of every token pair
        # in the pair representation; if no coords are input, skip this connection
        if X_pred_L is not None:
            X_pred_rep_I = X_pred_L.index_select(1, rep_atoms)
            dist = torch.cdist(X_pred_rep_I, X_pred_rep_I)
            if not self.use_af3_style_binning_and_final_layer_norms:
                # bins are 3.375 to 20.375 in 1.75 increments according to pseudocode
                dist_one_hot = F.one_hot(
                    discretize_distance_matrix(
                        dist, min_distance=3.375, max_distance=20.875, num_bins=10
                    ),
                    num_classes=11,
                )
            else:
                # published code is 3.25 to 50.75, with 39 bins
                dist_one_hot = F.one_hot(
                    discretize_distance_matrix(
                        dist, min_distance=3.25, max_distance=50.75, num_bins=39
                    ),
                    num_classes=40,
                )

            Z_trunk_II = Z_trunk_II + self.process_pred_distances(dist_one_hot.float())

            if self.use_Cb_distances:
                # embed the deviation between observed and ideal Cb positions
                Cb_distances = calc_Cb_distances(
                    X_pred_L, seq, rep_atoms, frame_atom_idxs
                )
                Cb_distances_one_hot = F.one_hot(
                    discretize_distance_matrix(
                        Cb_distances,
                        min_distance=0.0001,
                        max_distance=0.25,
                        num_bins=24,
                    ),
                    num_classes=25,
                )
                Cb_logits = self.process_Cb_distances(Cb_distances_one_hot.float())
                # symmetrize the logits
                if self.symmetrize_Cb_logits:
                    Cb_logits = Cb_logits[:, None, :, :] + Cb_logits[:, :, None, :]
                else:
                    Cb_logits = Cb_logits[:, None, :, :]

                Z_trunk_II = Z_trunk_II + Cb_logits

        if not self.use_af3_style_binning_and_final_layer_norms:
            S_trunk_residual_I = S_trunk_I.clone()
            Z_trunk_residual_II = Z_trunk_II.clone()

        # process with the pairformer stack
        for n in range(len(self.pairformer)):
            S_trunk_I, Z_trunk_II = self.pairformer[n](S_trunk_I, Z_trunk_II)

        # despite doing so in the pseudocode, AF3's published code does not add the residual back
        if not self.use_af3_style_binning_and_final_layer_norms:
            S_trunk_I = S_trunk_residual_I + S_trunk_I
            Z_trunk_II = Z_trunk_residual_II + Z_trunk_II

            # linearly project for each prediction task
            pde_logits = self.predict_pde(
                Z_trunk_II + Z_trunk_II.transpose(-2, -3)
            )  # BUG: needs to be symmetrized correctly

            pae_logits = self.predict_pae(Z_trunk_II)

            plddt_logits = self.predict_plddt(S_trunk_I)
            exp_resolved_logits = self.predict_exp_resolved(S_trunk_I)

        # AF3's published code skips the residual and adds layernorms before the
        # linear projections; it also computes the PDE slightly differently,
        # adding the transpose after the linear projection
        else:
            left_distance_logits = self.predict_pde(self.layernorm_pde(Z_trunk_II))
            right_distance_logits = left_distance_logits.transpose(-2, -3)
            pde_logits = left_distance_logits + right_distance_logits

            pae_logits = self.predict_pae(self.layernorm_pae(Z_trunk_II))
            plddt_logits = self.predict_plddt(self.layernorm_plddt(S_trunk_I))
            exp_resolved_logits = self.predict_exp_resolved(
                self.layernorm_exp_resolved(S_trunk_I)
            )

        return dict(
            pde_logits=pde_logits,
            pae_logits=pae_logits,
            plddt_logits=plddt_logits,
            exp_resolved_logits=exp_resolved_logits,
        )

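A note on the two PDE symmetrization orders: when the projection has no bias, symmetrizing before or after the linear layer is mathematically identical (by linearity), so it is the extra layernorm in the AF3-style branch that makes the two paths genuinely different. A quick check of the linearity identity, with hypothetical sizes:

import torch
import torch.nn as nn

c_z, n_bins, I = 8, 64, 5
proj = nn.Linear(c_z, n_bins, bias=False)   # stand-in for linearNoBias
Z = torch.randn(2, I, I, c_z)

pre = proj(Z + Z.transpose(-2, -3))          # symmetrize, then project
post = proj(Z) + proj(Z).transpose(-2, -3)   # project, then symmetrize
assert torch.allclose(pre, post, atol=1e-5)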
222
+ frame_atom_idxs = frame_atom_idxs.unsqueeze(0).expand(X_pred_L.shape[0], -1, -1)
223
+
224
+ N = torch.gather(
225
+ X_pred_L, 1, frame_atom_idxs[..., 0].unsqueeze(-1).expand(-1, -1, 3)
226
+ )
227
+ Ca = torch.gather(
228
+ X_pred_L, 1, frame_atom_idxs[..., 1].unsqueeze(-1).expand(-1, -1, 3)
229
+ )
230
+ C = torch.gather(
231
+ X_pred_L, 1, frame_atom_idxs[..., 2].unsqueeze(-1).expand(-1, -1, 3)
232
+ )
233
+ Cb = X_pred_L.index_select(1, rep_atoms)
234
+
235
+ is_valid_Cb = (
236
+ (seq != CHEM_DATA_LEGACY.aa2num["UNK"])
237
+ & (seq != CHEM_DATA_LEGACY.aa2num["GLY"])
238
+ & (seq != CHEM_DATA_LEGACY.aa2num["MAS"])
239
+ )
240
+
241
+ def _legacy_is_protein(seq):
242
+ return (seq >= 0).all() & (seq < 20).all()
243
+
244
+ is_valid_Cb = is_valid_Cb & _legacy_is_protein(seq)
245
+
246
+ b = Ca - N
247
+ c = C - Ca
248
+ a = torch.cross(b, c, dim=-1)
249
+
250
+ ideal_Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca
251
+
252
+ Cb_distances = torch.norm(Cb - ideal_Cb, dim=-1)
253
+ Cb_distances[:, ~is_valid_Cb] = 0.0
254
+
255
+ return Cb_distances
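The ideal-Cb coefficients are the familiar virtual-Cb construction that appears in trRosetta-style codebases: Cb is placed from the local N-Ca-C frame. A standalone sketch with made-up coordinates:

import torch

# Backbone atoms for a single residue (illustrative coordinates, in Angstroms)
N = torch.tensor([1.458, 0.0, 0.0])
Ca = torch.tensor([0.0, 0.0, 0.0])
C = torch.tensor([-0.551, 1.420, 0.0])

b = Ca - N
c = C - Ca
a = torch.cross(b, c, dim=-1)
ideal_Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca

# Deviation between an observed Cb and the reconstructed ideal position
observed_Cb = ideal_Cb + 0.05 * torch.randn(3)
print(torch.norm(observed_Cb - ideal_Cb))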