rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,313 @@
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from jaxtyping import Float
+ from opt_einsum import contract as einsum
+ from rf3.util_module import init_lecun_normal
+
+ from foundry import SHOULD_USE_CUEQUIVARIANCE
+ from foundry.training.checkpoint import activation_checkpointing
+
+ if SHOULD_USE_CUEQUIVARIANCE:
+     import cuequivariance_torch as cuet
+
+
+ class TriangleAttention(nn.Module):
+     """Implementation of Triangle Attention from AlphaFold3.
+
+     Routes to either the cuEquivariance or the vanilla implementation based on configuration.
+
+     Args:
+         pair: Pair representation tensor of shape (B, L, L, d_pair)
+
+     Returns:
+         Updated pair representation tensor of shape (B, L, L, d_pair)
+
+     """
+
+     def __init__(
+         self,
+         d_pair,
+         n_head=4,
+         d_hidden=32,
+         p_drop=0.1,  # noqa: E402
+         start_node=True,
+         use_cuequivariance=True,
+     ):
+         super(TriangleAttention, self).__init__()
+
+         self.norm = nn.LayerNorm(d_pair)
+
+         self.to_q = nn.Linear(d_pair, n_head * d_hidden, bias=False)
+         self.to_k = nn.Linear(d_pair, n_head * d_hidden, bias=False)
+         self.to_v = nn.Linear(d_pair, n_head * d_hidden, bias=False)
+         self.to_b = nn.Linear(d_pair, n_head, bias=False)
+         self.to_g = nn.Linear(d_pair, n_head * d_hidden)
+         self.to_out = nn.Linear(n_head * d_hidden, d_pair)
+
+         self.scaling = 1 / math.sqrt(d_hidden)
+
+         self.h = n_head
+         self.dim = d_hidden
+         self.start_node = start_node
+
+         self.use_cuequivariance = use_cuequivariance
+
+         self.reset_parameter()
+
+     def reset_parameter(self):
+         # query/key/value projection: Glorot uniform / Xavier uniform
+         nn.init.xavier_uniform_(self.to_q.weight)
+         nn.init.xavier_uniform_(self.to_k.weight)
+         nn.init.xavier_uniform_(self.to_v.weight)
+
+         # bias: normal distribution
+         self.to_b = init_lecun_normal(self.to_b)
+
+         # gating: zero weights, one biases (mostly open gate at the beginning)
+         nn.init.zeros_(self.to_g.weight)
+         nn.init.ones_(self.to_g.bias)
+
+         # to_out sits right before the residual connection: zero-initialize it so the
+         # residual update is the identity at the beginning of training
+         nn.init.zeros_(self.to_out.weight)
+         nn.init.zeros_(self.to_out.bias)
+
+     @activation_checkpointing
+     def forward(self, pair):
+         """Forward pass of triangle attention."""
+         pair = self.norm(pair)
+         bias = self.to_b(pair)  # (B, L, L, h)
+
+         if not self.start_node:
+             pair = rearrange(pair, "b i j d -> b j i d")
+
+         # Route to the appropriate implementation
+         if self.use_cuequivariance and SHOULD_USE_CUEQUIVARIANCE:
+             out = self._forward_cuequivariance(pair, bias)
+         else:
+             out = self._forward_vanilla(pair, bias)
+
+         if not self.start_node:
+             out = rearrange(out, "b i j d -> b j i d")
+
+         # output projection
+         out = self.to_out(out)
+         return out
+
+     def _forward_cuequivariance(self, pair, bias):
+         """cuEquivariance triangle attention implementation."""
+         # Handle autocast conversion
+         if torch.is_autocast_enabled():
+             dtype = torch.get_autocast_dtype("cuda")
+             pair = pair.to(dtype=dtype)
+             bias = bias.to(dtype=dtype)
+
+         assert (
+             pair.dtype == torch.bfloat16 and bias.dtype == torch.bfloat16
+         ), f"cuEquivariance requires bfloat16 inputs (got pair={pair.dtype}, bias={bias.dtype})"
+
+         # Gate computation
+         gate = torch.sigmoid(self.to_g(pair))  # (B, L, L, h*dim)
+
+         # Project and reshape to cuEquivariance format: (B, L, H, L, D)
+         query = rearrange(self.to_q(pair), "b i j (h d) -> b i h j d", h=self.h)
+         key = rearrange(self.to_k(pair), "b i k (h d) -> b i h k d", h=self.h)
+         value = rearrange(self.to_v(pair), "b i k (h d) -> b i h k d", h=self.h)
+
+         # Bias: (B, L, L, H) -> (B, 1, H, L, L)
+         bias_cueq = rearrange(bias, "b i j h -> b 1 h i j")
+
+         # Call cuEquivariance triangle attention
+         out_cueq = cuet.triangle_attention(
+             query, key, value, bias=bias_cueq, scale=self.scaling
+         )
+
+         # Reshape back: (B, L, H, L, D) -> (B, L, L, H*D)
+         out = rearrange(out_cueq, "b i h j d -> b i j (h d)")
+         out = gate * out  # gated attention
+         return out
+
+     def _forward_vanilla(self, pair, bias):
+         """Vanilla PyTorch triangle attention implementation."""
+         B, L = pair.shape[:2]
+
+         # Gate computation
+         gate = torch.sigmoid(self.to_g(pair))  # (B, L, L, h*dim)
+
+         # Project and reshape to vanilla format: (B, L, L, H, D)
+         query = self.to_q(pair).reshape(B, L, L, self.h, -1)
+         key = self.to_k(pair).reshape(B, L, L, self.h, -1)
+         value = self.to_v(pair).reshape(B, L, L, self.h, -1)
+
+         query = query * self.scaling
+
+         attn = einsum("bijhd,bikhd->bijkh", query, key)
+         attn = attn + bias.unsqueeze(1).expand(-1, L, -1, -1, -1)  # (bijkh)
+         attn = F.softmax(attn, dim=-2)
+
+         out = einsum("bijkh,bikhd->bijhd", attn, value).reshape(B, L, L, -1)
+         out = gate * out  # gated attention
+
+         return out
+
+
+ class TriangleMultiplication(nn.Module):
+     """Implementation of Triangle Multiplicative Update from AlphaFold3.
+
+     Routes to either the cuEquivariance or the naive implementation based on configuration.
+
+     Args:
+         d_pair: Pair representation dimension (must equal d_hidden for cuEquivariance)
+         d_hidden: Hidden dimension (must equal d_pair for cuEquivariance)
+         direction: "outgoing" or "incoming" triangle multiplication direction
+         bias: Whether to use bias in normalization layers
+         use_cuequivariance: Whether to use the cuEquivariance fused kernel when available
+     """
+
+     def __init__(
+         self,
+         d_pair,
+         d_hidden=None,
+         direction="outgoing",
+         bias=True,
+         use_cuequivariance=True,
+     ):
+         super(TriangleMultiplication, self).__init__()
+
+         # Set d_hidden to d_pair if not specified
+         if d_hidden is None:
+             d_hidden = d_pair
+
+         self.d_pair = d_pair
+         self.d_hidden = d_hidden
+
+         # Validate direction parameter
+         if direction not in ["outgoing", "incoming"]:
+             raise ValueError(
+                 f"direction must be 'outgoing' or 'incoming', got '{direction}'"
+             )
+         self.direction = direction
+
+         self.use_cuequivariance = use_cuequivariance
+
+         if self.use_cuequivariance:
+             # cuEquivariance kernel requires d_pair == d_hidden...
+             assert (
+                 d_pair == d_hidden
+             ), "cuEquivariance triangle multiplication requires d_pair == d_hidden"
+             # ... and d_pair must be a multiple of 32
+             assert (
+                 d_pair % 32 == 0
+             ), "cuEquivariance triangle multiplication requires d_pair to be a multiple of 32"
+
+         # Input normalization (optional bias)
+         self.norm_in = nn.LayerNorm(d_pair, bias=bias)
+
+         # Input projections: combine left and right projections (2*d_hidden, d_pair) (no bias)
+         self.p_in = nn.Linear(d_pair, 2 * d_hidden, bias=False)
+
+         # Input gating: combine left and right gates (2*d_hidden, d_pair) (no bias)
+         self.g_in = nn.Linear(d_pair, 2 * d_hidden, bias=False)
+
+         # Output normalization (optional bias)
+         self.norm_out = nn.LayerNorm(d_hidden, bias=bias)
+
+         # Output projection (no bias)
+         self.p_out = nn.Linear(d_hidden, d_pair, bias=False)
+
+         # Output gating (no bias)
+         self.g_out = nn.Linear(d_pair, d_pair, bias=False)
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         # Input projections: LeCun normal distribution for regular linear weights
+         self.p_in = init_lecun_normal(self.p_in)
+
+         # We use the default PyTorch initialization for the other parameters, as AF-3 does
+         # not specify its weight initialization schemes; without a bias, e.g., the gate
+         # initialization from AF-2 is not applicable.
+
+     def forward(
+         self, pair: Float[torch.Tensor, "B N N D"]
+     ) -> Float[torch.Tensor, "B N N D"]:
+         """Forward pass of triangle multiplication."""
+         # Route to the appropriate implementation
+         if self.use_cuequivariance and SHOULD_USE_CUEQUIVARIANCE:
+             return self._forward_cuequivariance(pair)
+         else:
+             return self._forward_vanilla(pair)
+
+     def _forward_vanilla(
+         self, pair: Float[torch.Tensor, "B N N D"]
+     ) -> Float[torch.Tensor, "B N N D"]:
+         """Vanilla PyTorch triangle multiplication implementation."""
+         B, L = pair.shape[:2]
+
+         # Input normalization
+         pair_norm = self.norm_in(pair)
+
+         # Input projections: get combined output and split
+         p_combined = self.p_in(pair_norm)  # (B, L, L, 2*d_hidden)
+         left = p_combined[..., : self.d_hidden]  # (B, L, L, d_hidden)
+         right = p_combined[..., self.d_hidden :]  # (B, L, L, d_hidden)
+
+         # Input gating: get combined output and split
+         g_combined = self.g_in(pair_norm)  # (B, L, L, 2*d_hidden)
+         left_gate = torch.sigmoid(g_combined[..., : self.d_hidden])
+         right_gate = torch.sigmoid(g_combined[..., self.d_hidden :])
+
+         # Apply gating
+         left = left_gate * left
+         right = right_gate * right
+
+         # Triangle multiplication based on direction
+         if self.direction == "outgoing":
+             out = torch.einsum("bikd,bjkd->bijd", left, right / float(L))
+         else:  # incoming
+             out = torch.einsum("bkid,bkjd->bijd", left, right / float(L))
+
+         # Output normalization
+         out = self.norm_out(out)
+
+         # Output projection
+         out = self.p_out(out)
+
+         # Output gating
+         gate = torch.sigmoid(self.g_out(pair_norm))
+         out = gate * out
+
+         return out
+
+     def _forward_cuequivariance(
+         self, pair: Float[torch.Tensor, "B N N D"]
+     ) -> Float[torch.Tensor, "B N N D"]:
+         """cuEquivariance triangle multiplication implementation."""
+         # Handle autocast conversion
+         # (Use bfloat16 for optimal performance)
+         if torch.is_autocast_enabled():
+             dtype = torch.get_autocast_dtype("cuda")
+             pair = pair.to(dtype=dtype)
+
+         assert (
+             pair.dtype == torch.bfloat16
+         ), "cuEquivariance requires bfloat16 inputs for optimal performance"
+
+         output = cuet.triangle_multiplicative_update(
+             x=pair,
+             direction=self.direction,
+             mask=None,
+             norm_in_weight=self.norm_in.weight,
+             norm_in_bias=self.norm_in.bias,
+             p_in_weight=self.p_in.weight,  # (2*d_hidden, d_pair)
+             g_in_weight=self.g_in.weight,  # (2*d_hidden, d_pair)
+             norm_out_weight=self.norm_out.weight,
+             norm_out_bias=self.norm_out.bias,
+             p_out_weight=self.p_out.weight,  # (d_pair, d_pair) since d_hidden == d_pair
+             g_out_weight=self.g_out.weight,  # (d_pair, d_pair)
+             eps=1e-5,
+         )
+
+         return output
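The 313-line hunk above matches the rf3/model/layers/attention.py entry in the file list, so the two pair-update modules are presumably importable from there. The following is a minimal, hedged usage sketch (the import path is an inference, and `use_cuequivariance=False` is passed so the vanilla PyTorch branches run without the optional cuEquivariance kernels):

    import torch
    # Assumed import path, inferred from the "+313" entry in the file list above.
    from rf3.model.layers.attention import TriangleAttention, TriangleMultiplication

    B, L, d_pair = 1, 8, 64
    pair = torch.randn(B, L, L, d_pair)

    # Force the vanilla branches so the sketch does not require cuEquivariance.
    tri_attn = TriangleAttention(
        d_pair=d_pair, n_head=4, d_hidden=32, start_node=True, use_cuequivariance=False
    )
    tri_mult = TriangleMultiplication(
        d_pair=d_pair, direction="outgoing", use_cuequivariance=False
    )

    pair = pair + tri_attn(pair)   # (B, L, L, d_pair); zero update at init (to_out is zero-initialized)
    pair = pair + tri_mult(pair)   # (B, L, L, d_pair)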
@@ -0,0 +1,127 @@
+ from functools import partial
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.nn.functional import silu
+
+ from foundry.training.checkpoint import activation_checkpointing
+
+ linearNoBias = partial(torch.nn.Linear, bias=False)
+
+
+ def collapse(x, L):
+     return x.reshape((L, x.numel() // L))
+
+
+ class MultiDimLinear(nn.Linear):
+     def __init__(self, in_features, out_shape, **kwargs):
+         self.out_shape = out_shape
+         out_features = np.prod(out_shape)
+         super().__init__(in_features, out_features, **kwargs)
+         self.reset_parameters()
+
+     def reset_parameters(self, **kwargs) -> None:
+         super().reset_parameters()
+         nn.init.xavier_uniform_(self.weight)
+
+     def forward(self, x):
+         out = super().forward(x)
+         return out.reshape(x.shape[:-1] + self.out_shape)
+
+
+ class LinearBiasInit(nn.Linear):
+     def __init__(self, *args, biasinit, **kwargs):
+         assert biasinit == -2.0  # Sanity check
+         self.biasinit = biasinit
+         super().__init__(*args, **kwargs)
+
+     def reset_parameters(self) -> None:
+         super().reset_parameters()
+         self.bias.data.fill_(self.biasinit)
+
+
+ class Transition(nn.Module):
+     def __init__(self, n, c):
+         super().__init__()
+         self.layer_norm_1 = nn.LayerNorm(c)
+         self.linear_1 = linearNoBias(c, n * c)
+         self.linear_2 = linearNoBias(c, n * c)
+         self.linear_3 = linearNoBias(n * c, c)
+
+     @activation_checkpointing
+     def forward(
+         self,
+         X,
+     ):
+         X = self.layer_norm_1(X)
+         A = self.linear_1(X)
+         B = self.linear_2(X)
+         X = self.linear_3(silu(A) * B)
+         return X
+
+
+ class AdaLN(nn.Module):
+     def __init__(self, c_a, c_s, n=2):
+         super().__init__()
+         self.ln_a = nn.LayerNorm(normalized_shape=(c_a,), elementwise_affine=False)
+         self.ln_s = nn.LayerNorm(normalized_shape=(c_s,), bias=False)
+         self.to_gain = nn.Sequential(
+             nn.Linear(c_s, c_a),
+             nn.Sigmoid(),
+         )
+         self.to_bias = linearNoBias(c_s, c_a)
+
+     def forward(
+         self,
+         Ai,  # [B, I, C_a]
+         Si,  # [B, I, C_s]
+     ):
+         """
+         Output:
+             [B, I, C_a]
+         """
+         Ai = self.ln_a(Ai)
+         Si = self.ln_s(Si)
+         return self.to_gain(Si) * Ai + self.to_bias(Si)
+
+
+ def create_batch_dimension_if_not_present(batched_n_dim):
+     """
+     Decorator for adapting a function that expects batched arguments with ndim `batched_n_dim`
+     so that it also accepts unbatched arguments.
+     """
+
+     def wrap(f):
+         def _wrap(arg):
+             inserted_batch_dim = False
+             if arg.ndim == batched_n_dim - 1:
+                 arg = arg[None]
+                 inserted_batch_dim = True
+             elif arg.ndim == batched_n_dim:
+                 pass
+             else:
+                 raise Exception(
+                     f"arg must have {batched_n_dim - 1} or {batched_n_dim} dimensions, got shape {arg.shape=}"
+                 )
+             o = f(arg)
+
+             if inserted_batch_dim:
+                 assert o.shape[0] == 1, f"{o.shape=}[0] != 1"
+                 return o[0]
+             return o
+
+         return _wrap
+
+     return wrap
+
+
+ def unpack_args_for_checkpointing(arg_names):
+     def wrap(f):
+         def _wrap(*args):
+             f = args[0]
+             return f(**dict(zip(arg_names, args)))
+
+         return _wrap
+
+     return wrap
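This 127-line hunk matches rf3/model/layers/layer_utils.py in the file list and collects small building blocks: a SwiGLU-style Transition, adaptive layer norm (AdaLN), and helper decorators. A hedged sketch of how they might be exercised, with the import path assumed from the file list:

    import torch
    # Assumed import path, inferred from the "+127" entry in the file list above.
    from rf3.model.layers.layer_utils import (
        AdaLN,
        Transition,
        create_batch_dimension_if_not_present,
    )

    c_a, c_s = 128, 384
    Ai = torch.randn(2, 10, c_a)       # [B, I, C_a] token activations
    Si = torch.randn(2, 10, c_s)       # [B, I, C_s] conditioning signal

    ada_ln = AdaLN(c_a=c_a, c_s=c_s)
    gated = ada_ln(Ai, Si)             # [B, I, C_a]: sigmoid gain * normed Ai + bias(Si)

    transition = Transition(n=4, c=c_a)
    gated = gated + transition(gated)  # SwiGLU transition, applied residually here

    # The decorator lets a function written for batched [B, I, C] inputs also accept [I, C].
    @create_batch_dimension_if_not_present(batched_n_dim=3)
    def pool_tokens(x):                # expects [B, I, C]
        return x.mean(dim=1, keepdim=True)

    pooled = pool_tokens(torch.randn(10, c_a))  # [1, C_a]: batch dim added, then stripped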
@@ -0,0 +1,118 @@
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+ from jaxtyping import Float
+
+
+ class ConformerEmbeddingWeightedAverage(nn.Module):
+     """Learned weighted average of reference conformer embeddings.
+
+     Args:
+         atom_level_embedding_dim: Dimension of the input atom-level embeddings (default: 384, for EGRET)
+         c_atompair: Dimension of the atom-pair embeddings (default: 16)
+         c_atom: Dimension of the output atom embeddings (default: 128)
+         n_conformers: Number of conformers to expect (default: 8)
+         dropout_rate: Dropout rate for regularization (default: 0.1)
+         use_layer_norm: Whether to apply layer normalization to the output (default: True)
+     """
+
+     def __init__(
+         self,
+         atom_level_embedding_dim: int,
+         c_atompair: int,
+         c_atom: int,
+         n_conformers: int = 8,
+         dropout_rate: float = 0.1,
+         use_layer_norm: bool = True,
+     ):
+         super().__init__()
+
+         self.n_conformers = n_conformers
+         self.atom_level_embedding_dim = atom_level_embedding_dim
+
+         # Downcast MLP from atom_level_embedding_dim to c_atompair
+         self.process_atom_level_embedding = nn.Sequential(
+             nn.Linear(atom_level_embedding_dim, atom_level_embedding_dim // 2),
+             nn.ReLU(),
+             nn.Dropout(dropout_rate),
+             nn.Linear(atom_level_embedding_dim // 2, atom_level_embedding_dim // 4),
+             nn.ReLU(),
+             nn.Dropout(dropout_rate),
+             nn.Linear(atom_level_embedding_dim // 4, atom_level_embedding_dim // 8),
+             nn.ReLU(),
+             nn.Dropout(dropout_rate),
+             nn.Linear(atom_level_embedding_dim // 8, c_atompair),
+         )
+
+         # Final MLP to convert from (n_conformers * c_atompair) to c_atom
+         self.conformers_to_atom_single_embedding = nn.Sequential(
+             nn.Linear(n_conformers * c_atompair, c_atom, bias=False),
+             nn.LayerNorm(c_atom) if use_layer_norm else nn.Identity(),
+         )
+
+         # Zero-init the final linear layer to ensure the model starts with identity function (output ≈ 0)
+         nn.init.zeros_(self.conformers_to_atom_single_embedding[0].weight)
+
+     def forward(
+         self,
+         atom_level_embeddings: Float[
+             torch.Tensor, "n_conformers n_atom atom_level_embedding_dim"
+         ],
+     ) -> Float[torch.Tensor, "n_atom c_atom"]:
+         """Forward pass: process atom-level embeddings and return the processed result.
+
+         Args:
+             atom_level_embeddings: Input tensor of shape [n_conformers, n_atom, atom_level_embedding_dim]
+
+         Returns:
+             Processed tensor of shape [n_atom, c_atom] ready for residual addition
+         """
+         assert (
+             atom_level_embeddings.shape[0] == self.n_conformers
+         ), "Number of conformers must be consistent"
+
+         # Subset to [:atom_level_embedding_dim]
+         if atom_level_embeddings.shape[-1] > self.atom_level_embedding_dim:
+             atom_level_embeddings = atom_level_embeddings[
+                 ..., : self.atom_level_embedding_dim
+             ]
+         elif atom_level_embeddings.shape[-1] < self.atom_level_embedding_dim:
+             raise ValueError(
+                 f"Atom-level embedding dimension {atom_level_embeddings.shape[-1]} is less than the expected dimension {self.atom_level_embedding_dim}"
+             )
+
+         # Process atom-level embeddings to get shape [n_conformers, n_atom, c_atompair]
+         processed_embeddings: Float[torch.Tensor, "n_conformers n_atom c_atompair"] = (
+             self.process_atom_level_embedding(atom_level_embeddings)
+         )
+
+         # Pad with zeros if we don't have enough conformers
+         current_n_conformers = processed_embeddings.shape[0]
+         if current_n_conformers < self.n_conformers:
+             # Pad with zeros at the beginning
+             padding_size = self.n_conformers - current_n_conformers
+             padding: Float[torch.Tensor, "padding_size n_atom c_atompair"] = (
+                 torch.zeros(
+                     padding_size,
+                     processed_embeddings.shape[1],
+                     processed_embeddings.shape[2],
+                     device=processed_embeddings.device,
+                     dtype=processed_embeddings.dtype,
+                 )
+             )
+             processed_embeddings = torch.cat([padding, processed_embeddings], dim=0)
+         elif current_n_conformers > self.n_conformers:
+             # Truncate to n_conformers
+             processed_embeddings = processed_embeddings[: self.n_conformers]
+
+         # Reshape from [n_conformers, n_atom, c_atompair] to [n_atom, n_conformers * c_atompair]
+         reshaped_embeddings: Float[torch.Tensor, "n_atom n_conformers*c_atompair"] = (
+             rearrange(processed_embeddings, "c n d -> n (c d)")
+         )
+
+         # Final MLP to get [n_atom, c_atom]
+         result: Float[torch.Tensor, "n_atom c_atom"] = (
+             self.conformers_to_atom_single_embedding(reshaped_embeddings)
+         )
+
+         return result
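This 118-line hunk defines ConformerEmbeddingWeightedAverage, which downcasts per-conformer atom embeddings, concatenates them across conformers, and projects to the atom single representation (zero-initialized so it starts as a no-op when added residually). A minimal sketch, assuming the module lives at rf3/model/layers/mlff.py (the "+118" entry in the file list; the path is an inference, not stated in the hunk):

    import torch
    # Assumed import path; the file is inferred from the "+118" entry in the file list.
    from rf3.model.layers.mlff import ConformerEmbeddingWeightedAverage

    embed = ConformerEmbeddingWeightedAverage(
        atom_level_embedding_dim=384,  # e.g. EGRET per-atom embeddings
        c_atompair=16,
        c_atom=128,
        n_conformers=8,
    )

    # 8 reference conformers x 50 atoms x 384 features
    conformer_embeddings = torch.randn(8, 50, 384)
    atom_single = embed(conformer_embeddings)  # [50, 128], approximately zero at init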
@@ -0,0 +1,59 @@
+ import torch
+ import torch.nn as nn
+ from rf3.util_module import init_lecun_normal
+
+ from foundry.training.checkpoint import activation_checkpointing
+
+
+ class OuterProductMean(nn.Module):
+     def __init__(self, d_msa=256, d_pair=128, d_hidden=16, p_drop=0.15):
+         super(OuterProductMean, self).__init__()
+         self.norm = nn.LayerNorm(d_msa)
+         self.proj_left = nn.Linear(d_msa, d_hidden)
+         self.proj_right = nn.Linear(d_msa, d_hidden)
+         self.proj_out = nn.Linear(d_hidden * d_hidden, d_pair)
+
+         self.reset_parameter()
+
+     def reset_parameter(self):
+         # normal initialization
+         self.proj_left = init_lecun_normal(self.proj_left)
+         self.proj_right = init_lecun_normal(self.proj_right)
+         nn.init.zeros_(self.proj_left.bias)
+         nn.init.zeros_(self.proj_right.bias)
+
+         # zero initialize output
+         nn.init.zeros_(self.proj_out.weight)
+         nn.init.zeros_(self.proj_out.bias)
+
+     def forward(self, msa):
+         B, N, L = msa.shape[:3]
+         msa = self.norm(msa)
+         left = self.proj_left(msa)
+         right = self.proj_right(msa)
+         right = right / float(N)
+         out = torch.einsum("bsli,bsmj->blmij", left, right).reshape(B, L, L, -1)
+         out = self.proj_out(out)
+
+         return out
+
+
+ class OuterProductMean_AF3(nn.Module):
+     def __init__(self, c_msa_embed, c_outer_product, c_out):
+         super(OuterProductMean_AF3, self).__init__()
+         self.norm = nn.LayerNorm(c_msa_embed)
+         self.proj_left = nn.Linear(c_msa_embed, c_outer_product)
+         self.proj_right = nn.Linear(c_msa_embed, c_outer_product)
+         self.proj_out = nn.Linear(c_outer_product * c_outer_product, c_out)
+
+     @activation_checkpointing
+     def forward(self, msa):
+         B, N, L = msa.shape[:3]
+         msa = self.norm(msa)
+         left = self.proj_left(msa)
+         right = self.proj_right(msa)
+         right = right / float(N)
+         out = torch.einsum("bsli,bsmj->blmij", left, right).reshape(B, L, L, -1)
+         out = self.proj_out(out)
+
+         return out
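The final 59-line hunk matches rf3/model/layers/outer_product.py and provides two outer-product-mean blocks that project MSA features into pair-representation updates. A hedged usage sketch, with the import path assumed from the "+59" entry in the file list:

    import torch
    # Assumed import path, inferred from the "+59" entry in the file list above.
    from rf3.model.layers.outer_product import OuterProductMean_AF3

    opm = OuterProductMean_AF3(c_msa_embed=64, c_outer_product=16, c_out=128)
    msa = torch.randn(1, 4, 32, 64)  # (B, n_seq, L, c_msa_embed)
    pair_update = opm(msa)           # (B, L, L, c_out) == (1, 32, 32, 128)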