rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rfd3/model/layers/encoders.py
@@ -0,0 +1,417 @@
+ import functools
+ import logging
+
+ import torch
+ import torch.nn as nn
+ from rfd3.model.layers.block_utils import (
+     bucketize_scaled_distogram,
+     pairwise_mean_pool,
+ )
+ from rfd3.model.layers.blocks import (
+     Downcast,
+     LocalAtomTransformer,
+     OneDFeatureEmbedder,
+     PositionPairDistEmbedder,
+     RelativePositionEncodingWithIndexRemoval,
+     SinusoidalDistEmbed,
+ )
+ from rfd3.model.layers.chunked_pairwise import (
+     ChunkedPairwiseEmbedder,
+     ChunkedPositionPairDistEmbedder,
+     ChunkedSinusoidalDistEmbed,
+ )
+ from rfd3.model.layers.layer_utils import (
+     RMSNorm,
+     Transition,
+     linearNoBias,
+ )
+ from rfd3.model.layers.pairformer_layers import PairformerBlock
+
+ from foundry.common import exists
+ from foundry.training.checkpoint import activation_checkpointing
+
+ logger = logging.getLogger(__name__)
+
+
+ class TokenInitializer(nn.Module):
+     """
+     Token embedding module for RFD3
+     """
+
+     def __init__(
+         self,
+         c_s,
+         c_z,
+         c_atom,
+         c_atompair,
+         relative_position_encoding,
+         n_pairformer_blocks,
+         pairformer_block,
+         downcast,
+         token_1d_features,
+         atom_1d_features,
+         atom_transformer,
+         use_chunked_pll=False,  # New parameter for memory optimization
+     ):
+         super().__init__()
+
+         # Store chunked mode flag
+         self.use_chunked_pll = use_chunked_pll
+
+         # Features
+         self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s)
+         self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom)
+         self.token_1d_embedder = OneDFeatureEmbedder(token_1d_features, c_s)
+
+         self.downcast_atom = Downcast(c_atom=c_s, c_token=c_s, c_s=None, **downcast)
+         self.transition_post_token = Transition(c=c_s, n=2)
+         self.transition_post_atom = Transition(c=c_s, n=2)
+         self.process_s_init = nn.Sequential(
+             RMSNorm(c_s),
+             linearNoBias(c_s, c_s),
+         )
+
+         # Operations to mix into Z_II and S_I
+         self.to_z_init_i = linearNoBias(c_s, c_z)
+         self.to_z_init_j = linearNoBias(c_s, c_z)
+         self.relative_position_encoding = RelativePositionEncodingWithIndexRemoval(
+             c_z=c_z, **relative_position_encoding
+         )
+         self.relative_position_encoding2 = RelativePositionEncodingWithIndexRemoval(
+             c_z=c_z, **relative_position_encoding
+         )
+         self.process_token_bonds = linearNoBias(1, c_z)
+
+         # Processing of Z_init
+         self.process_z_init = nn.Sequential(
+             RMSNorm(c_z * 2),
+             linearNoBias(c_z * 2, c_z),
+         )
+         self.transition_1 = nn.ModuleList(
+             [
+                 Transition(c=c_z, n=2),
+                 Transition(c=c_z, n=2),
+             ]
+         )
+         self.ref_pos_embedder_tok = PositionPairDistEmbedder(c_z, embed_frame=False)
+
+         # Pairformer without triangle updates
+         self.transformer_stack = nn.ModuleList(
+             [
+                 PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block)
+                 for _ in range(n_pairformer_blocks)
+             ]
+         )
+
+         #############################################################################
+         # Token track processing
+         self.process_s_trunk = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_atom))
+         self.process_single_l = nn.Sequential(
+             nn.ReLU(), linearNoBias(c_atom, c_atompair)
+         )
+         self.process_single_m = nn.Sequential(
+             nn.ReLU(), linearNoBias(c_atom, c_atompair)
+         )
+         self.process_z = nn.Sequential(RMSNorm(c_z), linearNoBias(c_z, c_atompair))
+
+         # ALWAYS create these MLPs - they will be shared between chunked and standard modes
+         self.motif_pos_embedder = SinusoidalDistEmbed(c_atompair=c_atompair)
+         self.ref_pos_embedder = PositionPairDistEmbedder(c_atompair, embed_frame=False)
+         self.pair_mlp = nn.Sequential(
+             nn.ReLU(),
+             linearNoBias(c_atompair, c_atompair),
+             nn.ReLU(),
+             linearNoBias(c_atompair, c_atompair),
+             nn.ReLU(),
+             linearNoBias(c_atompair, c_atompair),
+         )
+
+         # Atom pair feature processing
+         if self.use_chunked_pll:
+             # Initialize chunked embedders and share the trained MLPs!
+             self.chunked_pairwise_embedder = ChunkedPairwiseEmbedder(
+                 c_atompair=c_atompair,
+                 motif_pos_embedder=ChunkedSinusoidalDistEmbed(c_atompair=c_atompair),
+                 ref_pos_embedder=ChunkedPositionPairDistEmbedder(
+                     c_atompair, embed_frame=False
+                 ),
+                 process_single_l=self.process_single_l,  # Share trained parameters!
+                 process_single_m=self.process_single_m,  # Share trained parameters!
+                 process_z=self.process_z,  # Share trained parameters!
+                 pair_mlp=self.pair_mlp,  # Share trained parameters!
+             )
+         self.process_pll = linearNoBias(c_atompair, c_atompair)
+         self.project_pll = linearNoBias(c_atompair, c_z)
+
+         if atom_transformer["n_blocks"] > 0:
+             self.atom_transformer = LocalAtomTransformer(
+                 c_atom=c_atom, c_s=None, c_atompair=c_atompair, **atom_transformer
+             )
+         else:
+             self.atom_transformer = None
+
+         # Post-processing
+         # self.process_s_post = nn.Sequential(
+         #     RMSNorm(c_s),
+         #     linearNoBias(c_s, c_s),
+         # )
+         # self.process_z_post = nn.Sequential(
+         #     RMSNorm(c_z),
+         #     linearNoBias(c_z, c_z),
+         # )
+
+     def forward(self, f):
+         """
+         Provides initial representation for atom and token representations
+         """
+         tok_idx = f["atom_to_token_map"]
+         L = len(tok_idx)
+         f["ref_atom_name_chars"] = f["ref_atom_name_chars"].reshape(L, -1)
+         I = len(f["restype"])
+
+         def init_tokens():
+             # Embed token features
+             S_I = self.token_1d_embedder(f, I)
+             S_I = S_I + self.transition_post_token(S_I)
+
+             # Embed atom features and downcast to token features
+             S_I = self.downcast_atom(
+                 Q_L=self.atom_1d_embedder_1(f, L), A_I=S_I, tok_idx=tok_idx
+             )
+             S_I = S_I + self.transition_post_atom(S_I)
+             S_I = self.process_s_init(S_I)
+
+             # Embed Z_II
+             Z_init_II = self.to_z_init_i(S_I).unsqueeze(-3) + self.to_z_init_j(
+                 S_I
+             ).unsqueeze(-2)
+             Z_init_II = Z_init_II + self.relative_position_encoding(f)
+             Z_init_II = Z_init_II + self.process_token_bonds(
+                 f["token_bonds"].unsqueeze(-1).float()
+             )
+
+             # Embed reference coordinates of ligands
+             token_id = f["ref_space_uid"][f["is_ca"]]
+             valid_mask = (token_id.unsqueeze(-1) == token_id.unsqueeze(-2)).unsqueeze(
+                 -1
+             )
+             Z_init_II = Z_init_II + self.ref_pos_embedder_tok(
+                 f["ref_pos"][f["is_ca"]], valid_mask
+             )
+
+             # Run a small transformer to provide position encodings to single.
+             for block in self.transformer_stack:
+                 S_I, Z_init_II = block(S_I, Z_init_II)
+
+             # Also cat the relative position encoding and mix
+             Z_init_II = torch.cat(
+                 [
+                     Z_init_II,
+                     self.relative_position_encoding2(f),
+                 ],
+                 dim=-1,
+             )
+             Z_init_II = self.process_z_init(Z_init_II)
+             for b in range(2):
+                 Z_init_II = Z_init_II + self.transition_1[b](Z_init_II)
+
+             return {"S_init_I": S_I, "Z_init_II": Z_init_II}
+
+         @activation_checkpointing
+         def init_atoms(S_init_I, Z_init_II):
+             Q_L_init = self.atom_1d_embedder_2(f, L)
+             C_L = Q_L_init + self.process_s_trunk(S_init_I)[..., tok_idx, :]
+
+             if self.use_chunked_pll:
+                 # Chunked mode: return embedder for later sparse computation
+                 return {
+                     "Q_L_init": Q_L_init,
+                     "C_L": C_L,
+                     "chunked_pairwise_embedder": self.chunked_pairwise_embedder,
+                     "S_I": S_init_I,
+                     "Z_II": Z_init_II,
+                 }
+             else:
+                 # Original full P_LL computation
+                 ##################################################################################
+                 # Embed motif coordinates
+                 valid_mask = (
+                     f["is_motif_atom_with_fixed_coord"].unsqueeze(-1)
+                     & f["is_motif_atom_with_fixed_coord"].unsqueeze(-2)
+                 ).unsqueeze(-1)
+                 P_LL = self.motif_pos_embedder(
+                     f["motif_pos"], valid_mask
+                 )  # (L, L, c_atompair)
+
+                 # Embed ref pos
+                 atoms_in_same_token = (
+                     f["ref_space_uid"].unsqueeze(-1) == f["ref_space_uid"].unsqueeze(-2)
+                 ).unsqueeze(-1)
+                 # Only consider ref_pos for atoms given seq (otherwise ref_pos is 0, doesn't make sense to compute)
+                 atoms_has_seq = (
+                     f["is_motif_atom_with_fixed_seq"].unsqueeze(-1)
+                     & f["is_motif_atom_with_fixed_seq"].unsqueeze(-2)
+                 ).unsqueeze(-1)
+                 valid_mask = atoms_in_same_token & atoms_has_seq
+                 P_LL = P_LL + self.ref_pos_embedder(f["ref_pos"], valid_mask)
+
+                 ##################################################################################
+
+                 P_LL = P_LL + (
+                     self.process_single_l(C_L).unsqueeze(-2)
+                     + self.process_single_m(C_L).unsqueeze(-3)
+                 )
+                 P_LL = (
+                     P_LL
+                     + self.process_z(Z_init_II)[..., tok_idx, :, :][..., tok_idx, :]
+                 )
+                 P_LL = P_LL + self.pair_mlp(P_LL)
+                 P_LL = P_LL.contiguous()
+
+                 # Pool P_LL to token level to provide atom-level resolution for token track
+                 pooled_atom_level_features = pairwise_mean_pool(
+                     pairwise_atom_features=self.process_pll(P_LL).unsqueeze(0),
+                     atom_to_token_map=tok_idx,
+                     I=int(tok_idx.max().item()) + 1,
+                     dtype=P_LL.dtype,
+                 ).squeeze(0)
+                 Z_init_II = Z_init_II + self.project_pll(pooled_atom_level_features)
+
+                 # Mix atom conditioning features via sequence-local attention
+                 if exists(self.atom_transformer):
+                     C_L = self.atom_transformer(
+                         C_L.unsqueeze(0), None, P_LL, indices=None, f=f, X_L=None
+                     ).squeeze(0)
+
+                 return {
+                     "Q_L_init": Q_L_init,
+                     "C_L": C_L,
+                     "P_LL": P_LL,
+                     "S_I": S_init_I,
+                     "Z_II": Z_init_II,
+                 }
+
+         tokens = init_tokens()
+         return init_atoms(**tokens)
+
+
+ class DiffusionTokenEncoder(nn.Module):
+     def __init__(
+         self,
+         c_s,
+         c_z,
+         c_token,
+         c_atompair,
+         sigma_data,
+         n_pairformer_blocks,
+         pairformer_block,
+         use_distogram,
+         use_self,
+         use_sinusoidal_distogram_embedder=True,
+         **_,
+     ):
+         super().__init__()
+
+         # Sequence processing
+         self.transition_1 = nn.ModuleList(
+             [
+                 Transition(c=c_s, n=2),
+                 Transition(c=c_s, n=2),
+             ]
+         )
+
+         # Post-processing of z
+         self.n_bins_distogram = 65  # n bins for both self distogram and distogram
+         n_bins_noise = self.n_bins_distogram
+         self.use_self = use_self
+         self.use_distogram = use_distogram
+         self.use_sinusoidal_distogram_embedder = use_sinusoidal_distogram_embedder
+         if self.use_distogram:
+             if self.use_sinusoidal_distogram_embedder:
+                 self.dist_embedder = SinusoidalDistEmbed(c_atompair=c_z)
+                 n_bins_noise = c_z
+             else:
+                 self.bucketize_fn = functools.partial(
+                     bucketize_scaled_distogram,
+                     min_dist=1,
+                     max_dist=30,
+                     sigma_data=sigma_data,
+                     n_bins=self.n_bins_distogram,
+                 )
+         cat_c_z = (
+             c_z
+             + int(self.use_distogram) * n_bins_noise
+             + int(self.use_self) * self.n_bins_distogram
+         )
+         self.process_z = nn.Sequential(
+             RMSNorm(cat_c_z),
+             linearNoBias(cat_c_z, c_z),
+         )
+
+         self.transition_2 = nn.ModuleList(
+             [
+                 Transition(c=c_z, n=2),
+                 Transition(c=c_z, n=2),
+             ]
+         )
+
+         # Pairformer without triangle updates
+         self.pairformer_stack = nn.ModuleList(
+             [
+                 PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block)
+                 for _ in range(n_pairformer_blocks)
+             ]
+         )
+
+     def forward(self, f, R_L, S_init_I, Z_init_II, C_L, P_LL, **kwargs):
+         B = R_L.shape[0]
+         """
+         Pools atom-level features to token-level features and encodes them into Z_II, S_I and prepares A_I.
+         """
+
+         @activation_checkpointing
+         def token_embed(S_init_I, Z_init_II):
+             S_I = S_init_I
+             for b in range(2):
+                 S_I = S_I + self.transition_1[b](S_I)
+
+             Z_II = Z_init_II.unsqueeze(0).expand(B, -1, -1, -1)  # B, I, I, c_z
+
+             Z_II_list = [Z_II]
+             if self.use_distogram:
+                 # Noise / self conditioning pair
+                 if self.use_sinusoidal_distogram_embedder:
+                     mask = f["is_motif_atom_with_fixed_coord"][f["is_ca"]]
+                     mask = (mask[None, :] != mask[:, None]).unsqueeze(
+                         -1
+                     )  # remove off-diagonals where distances don't make sense across time
+                     D_LL = self.dist_embedder(R_L[..., f["is_ca"], :], ~mask)
+                 else:
+                     D_LL = self.bucketize_fn(
+                         R_L[..., f["is_ca"], :]
+                     )  # [B, L, I, n_bins]
+                 Z_II_list.append(D_LL)
+             if self.use_self:
+                 D_II_self = kwargs.get("D_II_self")
+                 if D_II_self is None:
+                     D_II_self = torch.zeros(
+                         Z_II.shape[:-1] + (self.n_bins_distogram,),
+                         device=Z_II.device,
+                         dtype=Z_II.dtype,
+                     )
+                 Z_II_list.append(D_II_self)
+             Z_II = torch.cat(Z_II_list, dim=-1)
+
+             # Flatten concatenated dims
+             Z_II = self.process_z(Z_II)
+
+             for b in range(2):
+                 Z_II = Z_II + self.transition_2[b](Z_II)
+
+             # Pairformer to mix
+             for block in self.pairformer_stack:
+                 S_I, Z_II = block(S_I, Z_II)
+
+             return S_I, Z_II
+
+         return token_embed(S_init_I, Z_init_II)
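
Note: the pair initialization in TokenInitializer.init_tokens above uses a broadcasting idiom that is easy to misread when the diff is skimmed: two per-token projections are expanded along different axes and summed, turning (I, c_s) single features into an (I, I, c_z) pair tensor. The minimal sketch below reproduces just that idiom with plain PyTorch; the sizes and the names to_i / to_j are illustrative stand-ins and are not part of the rc-foundry API.

import torch
import torch.nn as nn

I, c_s, c_z = 8, 16, 32                  # illustrative sizes only
S_I = torch.randn(I, c_s)                # per-token ("single") features

to_i = nn.Linear(c_s, c_z, bias=False)   # stand-ins for to_z_init_i / to_z_init_j
to_j = nn.Linear(c_s, c_z, bias=False)

# (1, I, c_z) + (I, 1, c_z) broadcasts to (I, I, c_z),
# so Z_II[i, j] = to_i(S_I[j]) + to_j(S_I[i]).
Z_II = to_i(S_I).unsqueeze(-3) + to_j(S_I).unsqueeze(-2)
assert Z_II.shape == (I, I, c_z)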
rfd3/model/layers/layer_utils.py
@@ -0,0 +1,197 @@
+ import math
+ from functools import partial
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.nn.functional import silu
+
+ from foundry.training.checkpoint import activation_checkpointing
+ from foundry.utils.ddp import RankedLogger
+
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+ try:
+     from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+     ranked_logger.info("Fused RMSNorm enabled!")
+     RMSNorm_ = FusedRMSNorm
+ except (ImportError, ModuleNotFoundError):
+     ranked_logger.warning(
+         "Using nn.RMSNorm instead of apex.normalization.fused_layer_norm.FusedRMSNorm."
+         "Ensure you're using the correct apptainer"
+     )
+     RMSNorm_ = nn.RMSNorm
+
+
+ # Allow bias=False to be passed for RMSNorm
+ def RMSNorm(*args, **kwargs):
+     if "bias" in kwargs:
+         kwargs.pop("bias")
+     return RMSNorm_(*args, **kwargs)
+
+
+ SWAP_LAYER_NORM_FOR_RMS_NORM = True
+ RMSNorm = RMSNorm if SWAP_LAYER_NORM_FOR_RMS_NORM else nn.LayerNorm
+ linearNoBias = partial(torch.nn.Linear, bias=False)
+
+
+ class EmbeddingLayer(nn.Linear):
+     """
+     Specialized linear layer for correct weight initialization for embedding layers.
+
+     Embedding layers are functionally a multiplication of an N channel input by an NxC weight matrix to produce an
+     embedding of length C. However, we compute the components separately with a ModuleDict, then sum at the end, for
+     embedding reusability and interoperability purposes.
+
+     This layer uses Xavier initialization as described in [1]_.
+
+     References
+     ----------
+     .. [1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty
+        of training deep feedforward neural networks." (2010)
+        http://proceedings.mlr.press/v9/glorot10a.html
+     """
+
+     def __init__(
+         self,
+         this_in_features,
+         total_embedding_features,
+         out_features,
+         device=None,
+         dtype=None,
+     ):
+         self.total_embedding_features = total_embedding_features
+         self.out_features = out_features
+         super().__init__(
+             this_in_features, out_features, bias=False, device=device, dtype=dtype
+         )
+         self.reset_parameters()
+
+     def reset_parameters(self, **kwargs):
+         super().reset_parameters()
+         a = math.sqrt(6.0 / float(self.total_embedding_features + self.out_features))
+         nn.init._no_grad_uniform_(self.weight, -a, a)
+
+
+ def collapse(x, L):
+     return x.reshape((L, x.numel() // L))
+
+
+ class MultiDimLinear(nn.Linear):
+     def __init__(self, in_features, out_shape, norm=False, **kwargs):
+         self.out_shape = out_shape
+         out_features = np.prod(out_shape)
+         super().__init__(in_features, out_features, **kwargs)
+         if norm:
+             self.ln = RMSNorm((out_features,))
+             self.use_ln = True
+         else:
+             self.use_ln = False
+         self.reset_parameters()
+
+     def reset_parameters(self, **kwargs) -> None:
+         super().reset_parameters()
+         nn.init.xavier_uniform_(self.weight)
+
+     def forward(self, x):
+         out = super().forward(x)
+         if self.use_ln:
+             out = self.ln(out)
+         return out.reshape(x.shape[:-1] + self.out_shape)
+
+
+ class LinearBiasInit(nn.Linear):
+     def __init__(self, *args, biasinit, **kwargs):
+         assert biasinit == -2.0  # Sanity check
+         self.biasinit = biasinit
+         super().__init__(*args, **kwargs)
+
+     def reset_parameters(self) -> None:
+         super().reset_parameters()
+         self.bias.data.fill_(self.biasinit)
+
+
+ class Transition(nn.Module):
+     def __init__(self, n, c):
+         super().__init__()
+         self.layer_norm_1 = RMSNorm(c)
+         self.linear_1 = linearNoBias(c, n * c)
+         self.linear_2 = linearNoBias(c, n * c)
+         self.linear_3 = linearNoBias(n * c, c)
+
+     @activation_checkpointing
+     def forward(
+         self,
+         X,
+     ):
+         X = self.layer_norm_1(X)
+         A = self.linear_1(X)
+         B = self.linear_2(X)
+         X = self.linear_3(silu(A) * B)
+         return X
+
+
+ class AdaLN(nn.Module):
+     def __init__(self, c_a, c_s, n=2):
+         super().__init__()
+         self.ln_a = RMSNorm(normalized_shape=(c_a,), elementwise_affine=False)
+         self.ln_s = RMSNorm(normalized_shape=(c_s,), bias=False)
+         self.to_gain = nn.Sequential(
+             nn.Linear(c_s, c_a),
+             nn.Sigmoid(),
+         )
+         self.to_bias = linearNoBias(c_s, c_a)
+
+     def forward(
+         self,
+         Ai,  # [B, I, C_a]
+         Si,  # [B, I, C_s]
+     ):
+         """
+         Output:
+             [B, I, C_a]
+         """
+         Ai = self.ln_a(Ai)
+         Si = self.ln_s(Si)
+         return self.to_gain(Si) * Ai + self.to_bias(Si)
+
+
+ def create_batch_dimension_if_not_present(batched_n_dim):
+     """
+     Decorator for adapting a function which expects batched arguments with ndim `batched_n_dim` to also
+     accept unbatched arguments.
+     """
+
+     def wrap(f):
+         def _wrap(arg):
+             inserted_batch_dim = False
+             if arg.ndim == batched_n_dim - 1:
+                 arg = arg[None]
+                 inserted_batch_dim = True
+             elif arg.ndim == batched_n_dim:
+                 pass
+             else:
+                 raise Exception(
+                     f"arg must have {batched_n_dim - 1} or {batched_n_dim} dimensions, got shape {arg.shape=}"
+                 )
+             o = f(arg)
+
+             if inserted_batch_dim:
+                 assert o.shape[0] == 1, f"{o.shape=}[0] != 1"
+                 return o[0]
+             return o
+
+         return _wrap
+
+     return wrap
+
+
+ def unpack_args_for_checkpointing(arg_names):
+     def wrap(f):
+         def _wrap(*args):
+             f = args[0]
+             return f(**dict(zip(arg_names, args)))
+
+         return _wrap
+
+     return wrap
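
Note: Transition above is the SwiGLU-style feed-forward block used throughout the encoders in the previous file: the input is RMS-normalized, projected to two n*c-wide branches, gated with SiLU, and projected back to width c, with callers adding the result residually (for example S_I = S_I + transition(S_I)). The following is a minimal standalone re-implementation of that pattern for illustration only; it uses torch.nn.RMSNorm (PyTorch 2.4+) in place of the fused Apex kernel and omits activation checkpointing, so it is not the package class itself.

import torch
import torch.nn as nn
from torch.nn.functional import silu

class TransitionSketch(nn.Module):
    """SwiGLU-style transition: c -> two n*c branches -> c, all bias-free."""

    def __init__(self, c, n=2):
        super().__init__()
        self.norm = nn.RMSNorm(c)                       # fused RMSNorm in the real module
        self.linear_a = nn.Linear(c, n * c, bias=False)
        self.linear_b = nn.Linear(c, n * c, bias=False)
        self.linear_out = nn.Linear(n * c, c, bias=False)

    def forward(self, x):
        x = self.norm(x)
        return self.linear_out(silu(self.linear_a(x)) * self.linear_b(x))

x = torch.randn(4, 10, 64)                  # (batch, tokens, c), illustrative sizes
x = x + TransitionSketch(c=64, n=2)(x)      # residual update, as used in the encoders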