rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,332 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from mpnn.model.layers.position_wise_feed_forward import PositionWiseFeedForward
4
+
5
+
6
+ # Gather functions borrowed from ProteinMPNN;
7
+ # originally from:
8
+ # https://github.com/jingraham/neurips19-graph-protein-design/tree/master
9
def gather_edges(edge_features, neighbor_idx):
    """
    Collect, for every node, the edge features of its K neighbors.

    Args:
        edge_features: [B,L,L,H] - dense pairwise edge features
        neighbor_idx: [B,L,K] - neighbor indices
    Returns:
        [B,L,K,H] tensor of edge features selected at the neighbor indices.
    """
    feature_dim = edge_features.size(-1)

    # Broadcast the index tensor across the feature axis so that a single
    # torch.gather call can select whole feature vectors at once.
    expanded_idx = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, feature_dim)

    return torch.gather(edge_features, 2, expanded_idx)
30
+
31
+
32
def gather_nodes(node_features, neighbor_idx):
    """
    Collect, for every query node, the features of its K neighbor nodes.

    NOTE: usually L1 == L2, i.e. features are gathered for every node. L2 is
    allowed to be smaller than L1, which supports gathering features for only
    a subset of nodes (useful during decoding, where nodes are decoded one at
    a time).

    Args:
        node_features: [B,L1,H] - node features
        neighbor_idx: [B,L2,K] - neighbor indices
    Returns:
        [B,L2,K,H] tensor of node features selected at the neighbor indices.
    """
    batch, num_query, num_neighbors = neighbor_idx.shape
    hidden = node_features.size(-1)

    # Collapse the (query, neighbor) axes into one and broadcast over the
    # feature axis, so a single gather along the node dimension picks out
    # every requested feature vector.
    flat_idx = neighbor_idx.reshape(batch, num_query * num_neighbors, 1)
    flat_idx = flat_idx.expand(-1, -1, hidden)

    gathered = torch.gather(node_features, 1, flat_idx)

    # Restore the [B,L2,K,H] layout.
    return gathered.view(batch, num_query, num_neighbors, hidden)
71
+
72
+
73
def cat_neighbors_nodes(node_features, edge_features_at_neighbors, neighbor_idx):
    """
    Gather neighbor node features and concatenate them onto the edge features
    (edge features first along the last axis).

    NOTE: usually L1 == L2, i.e. node features are gathered and concatenated
    for every node. L2 may be smaller than L1, which supports building the
    concatenated features for only a subset of nodes (useful during decoding,
    where nodes are decoded one at a time).

    Args:
        node_features [B,L1,H1]: node features
        edge_features_at_neighbors [B,L2,K,H2]: edge hidden states
        neighbor_idx [B,L2,K]: neighbor indices
    Returns:
        [B,L2,K,H2+H1] concatenated edge-then-node features.
    """
    # Pull the node features of each neighbor into [B,L2,K,H1] layout.
    neighbor_node_features = gather_nodes(node_features, neighbor_idx)

    # Edge features come first, matching the convention used by the layers
    # that consume this output.
    return torch.cat([edge_features_at_neighbors, neighbor_node_features], -1)
104
+
105
+
106
class EncLayer(nn.Module):
    """Encoder message-passing layer that updates both node and edge states."""

    def __init__(self, num_hidden, num_in, dropout=0.1, scale=30):
        """
        Args:
            num_hidden (int): hidden dimension of node/edge states.
            num_in (int): input dimension of the message MLPs (3H here).
            dropout (float): dropout rate on each residual branch.
            scale (int): constant divisor for the summed neighbor messages.
        """
        super(EncLayer, self).__init__()

        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale

        # One dropout + norm per residual branch: node message, node
        # feed-forward, and edge update.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.norm1 = nn.LayerNorm(num_hidden)
        self.norm2 = nn.LayerNorm(num_hidden)
        self.norm3 = nn.LayerNorm(num_hidden)

        # Three-layer MLP producing the node-update message.
        self.W1 = nn.Linear(num_in, num_hidden, bias=True)
        self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)

        # Three-layer MLP producing the edge update.
        self.W11 = nn.Linear(num_in, num_hidden, bias=True)
        self.W12 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W13 = nn.Linear(num_hidden, num_hidden, bias=True)

        self.act = torch.nn.GELU()

        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, E_idx, mask_V=None, mask_E=None):
        """
        Encoder message passing step; updates both the node and edge hidden
        states.

        NOTE: num_in = 3H

        Args:
            h_V [B, L, H] - node hidden states
            h_E [B, L, K, H] - edge hidden states
            E_idx [B, L, K] - edge indices
            mask_V [B, L] - node mask
            mask_E [B, L, K] - edge mask
        Returns:
            h_V [B, L, H] - updated node hidden states
            h_E [B, L, K, H] - updated edge hidden states
        """
        # ---- Node update -------------------------------------------------
        # Assemble h_V_i || h_E_ij || h_V_j per neighbor:
        # [B,L,K,H] + [B,L,H] -> [B,L,K,2H] -> [B,L,K,3H].
        neighbor_cat = cat_neighbors_nodes(h_V, h_E, E_idx)
        src = h_V.unsqueeze(-2).expand(-1, -1, neighbor_cat.size(-2), -1)
        triple = torch.cat([src, neighbor_cat], -1)

        # Per-pair message MLP: [B,L,K,3H] -> [B,L,K,H].
        msg = self.W3(self.act(self.W2(self.act(self.W1(triple)))))
        if mask_E is not None:
            msg = msg * mask_E.unsqueeze(-1)

        # Scaled-sum aggregation over neighbors, then dropout/residual/norm.
        h_V = self.norm1(h_V + self.dropout1(msg.sum(dim=-2) / self.scale))

        # Position-wise feed-forward branch with its own residual/norm.
        h_V = self.norm2(h_V + self.dropout2(self.dense(h_V)))

        if mask_V is not None:
            h_V = h_V * mask_V.unsqueeze(-1)

        # ---- Edge update (uses the freshly updated node states) ----------
        neighbor_cat = cat_neighbors_nodes(h_V, h_E, E_idx)
        src = h_V.unsqueeze(-2).expand(-1, -1, neighbor_cat.size(-2), -1)
        triple = torch.cat([src, neighbor_cat], -1)

        # Edge-update MLP: [B,L,K,3H] -> [B,L,K,H], then residual/norm.
        edge_msg = self.W13(self.act(self.W12(self.act(self.W11(triple)))))
        h_E = self.norm3(h_E + self.dropout3(edge_msg))

        return h_V, h_E
208
+
209
+
210
+ class DecLayer(nn.Module):
211
+ def __init__(self, num_hidden, num_in, dropout=0.1, scale=30):
212
+ super(DecLayer, self).__init__()
213
+
214
+ self.num_hidden = num_hidden
215
+ self.num_in = num_in
216
+ self.scale = scale
217
+
218
+ self.dropout1 = nn.Dropout(dropout)
219
+ self.dropout2 = nn.Dropout(dropout)
220
+
221
+ self.norm1 = nn.LayerNorm(num_hidden)
222
+ self.norm2 = nn.LayerNorm(num_hidden)
223
+
224
+ self.W1 = nn.Linear(num_in, num_hidden, bias=True)
225
+ self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
226
+ self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)
227
+
228
+ self.act = torch.nn.GELU()
229
+
230
+ self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
231
+
232
+ def forward(self, h_V, h_E, mask_V=None, mask_E=None):
233
+ """
234
+ Decoder message passing step; updates only the node hidden states.
235
+ NOTE: this function is used for both the protein decoder and the ligand
236
+ context encoder. As such, this function operates on the "deepest"
237
+ graph in the tensor.
238
+
239
+ Below, the shapes for the protein decoder application will be as
240
+ follows:
241
+ ... = empty
242
+ node_num = L
243
+ neighbor_num = K
244
+ num_in = 4H
245
+ For the ligand subgraph encoder, the shapes will be as follows:
246
+ ... = L
247
+ node_num = M
248
+ neighbor_num = M
249
+ num_in = 2H
250
+ BUG: this should be 3H, but the original LigandMPNN does not
251
+ pre-concatenate the destination node features to the edge
252
+ features, which breaks the message passing in the ligand
253
+ subgraphs.
254
+ For the protein-ligand graph encoder, the shapes will be as follows:
255
+ ... = empty
256
+ node_num = L
257
+ neighbor_num = M
258
+ num_in = 3H
259
+
260
+ Args:
261
+ h_V [B, ..., node_num, H] - node hidden states
262
+ h_E [B, ..., node_num, neighbor_num, num_in - H] - edge hidden
263
+ states;
264
+ NOTE: for message passing to behave in the decoder, the
265
+ destination node features (and sequence if applicable) MUST
266
+ be pre-concatenated to the edge features.
267
+ So, h_E is actually:
268
+ - protein decoder: h_E_ij cat h_S_j cat h_V_j
269
+ - ligand subgraph encoder: h_ligand_subgraph_edges_ij;
270
+ BUG: this should be h_ligand_subgraph_edges_ij cat
271
+ h_ligand_subgraph_nodes_j; in its current form
272
+ (replicating the original LigandMPNN), the
273
+ destination node features are not concatenated to
274
+ the edge features, which breaks the message passing
275
+ in the ligand subgraph.
276
+ - protein-ligand graph encoder: h_E_protein_to_ligand_ij
277
+ cat h_ligand_subgraph_nodes_j
278
+ mask_V [B, ..., node_num] - node mask
279
+ mask_E [B, ..., node_num, neighbor_num] - edge mask
280
+ Returns:
281
+ h_V [B, ..., node_num, H] - updated node hidden states
282
+ """
283
+
284
+ # Concatenate source node features to edge features, which include the
285
+ # destination node features.
286
+ # - protein decoder: concatenate h_V_i to h_E_ij cat h_S_j cat h_V_j
287
+ # result: h_E_ij cat h_S_j cat h_V_j cat h_V_i
288
+ # - ligand subgraph encoder: concatenate h_ligand_subgraph_nodes_i
289
+ # to h_ligand_subgraph_edges_ij
290
+ # result: h_ligand_subgraph_edges_ij cat
291
+ # h_ligand_subgraph_nodes_i
292
+ # - protein-ligand graph encoder: concatenate h_V_i to
293
+ # h_E_protein_to_ligand_ij cat h_ligand_subgraph_nodes_j
294
+ # result: h_E_protein_to_ligand_ij cat h_ligand_subgraph_nodes_j
295
+ # cat h_V_i
296
+ # Shape (h_EV): [B, ..., node_num, neighbor_num, num_in - H] +
297
+ # [B, ..., node_num, H] => [B, ..., node_num, neighbor_num, num_in]
298
+ h_V_expand = h_V.unsqueeze(-2).expand(
299
+ *h_V.shape[:-1], # B, ..., node_num
300
+ h_E.size(-2), # neighbor_num
301
+ h_V.shape[-1], # H
302
+ )
303
+ h_EV = torch.cat([h_V_expand, h_E], -1)
304
+
305
+ # Compute the message.
306
+ # Shape: [B, ..., node_num, neighbor_num, num_in] =>
307
+ # [B, ..., node_num, neighbor_num, H]
308
+ h_message = self.W3(self.act(self.W2(self.act(self.W1(h_EV)))))
309
+
310
+ # Apply the edge mask to the message.
311
+ if mask_E is not None:
312
+ h_message = mask_E.unsqueeze(-1) * h_message
313
+
314
+ # Scaled sum aggregation.
315
+ # Shape: [B, ..., node_num, neighbor_num, H] => [B, ..., node_num, H]
316
+ dh = torch.sum(h_message, -2) / self.scale
317
+
318
+ # Dropout + residual + norm.
319
+ h_V = self.norm1(h_V + self.dropout1(dh))
320
+
321
+ # Position-wise feedforward
322
+ dh = self.dense(h_V)
323
+
324
+ # Dropout + residual + norm.
325
+ h_V = self.norm2(h_V + self.dropout2(dh))
326
+
327
+ # Apply the node mask to the node hidden states.
328
+ if mask_V is not None:
329
+ mask_V = mask_V.unsqueeze(-1)
330
+ h_V = mask_V * h_V
331
+
332
+ return h_V
@@ -0,0 +1,44 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class PositionWiseFeedForward(nn.Module):
    """Two-layer position-wise MLP: Linear -> GELU -> Linear."""

    def __init__(self, num_hidden, num_ff):
        """
        Args:
            num_hidden (int): input/output feature size.
            num_ff (int): inner (expansion) feature size.
        """
        super(PositionWiseFeedForward, self).__init__()

        # Expansion then projection; applied independently at every position,
        # hence "position-wise". Both layers carry a bias.
        self.W_in = nn.Linear(num_hidden, num_ff, bias=True)
        self.W_out = nn.Linear(num_ff, num_hidden, bias=True)

        self.act = torch.nn.GELU()

    def forward(self, h_V):
        """
        Apply the MLP along the last dimension of ``h_V``.

        Args:
            h_V (torch.Tensor): [B, L, num_hidden] - node embeddings.
        Returns:
            torch.Tensor: [B, L, num_hidden] - transformed embeddings.
        """
        # Expand to num_ff, apply the nonlinearity, and project back down.
        return self.W_out(self.act(self.W_in(h_V)))
@@ -0,0 +1,98 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class PositionalEncodings(nn.Module):
    """Chain-aware relative positional encodings for MPNN edge features."""

    def __init__(self, num_positional_embeddings, max_relative_feature=32):
        """
        Args:
            num_positional_embeddings (int): output embedding dimension.
            max_relative_feature (int): largest relative offset represented
                exactly; offsets outside [-max_relative_feature,
                max_relative_feature] are clipped to the nearest bound.
                Default is 32.
        """
        super(PositionalEncodings, self).__init__()

        self.num_positional_embeddings = num_positional_embeddings
        self.max_relative_feature = max_relative_feature

        # Offsets in [-max, max] occupy 2*max + 1 bins; one extra bin is
        # reserved for residue pairs on different chains.
        self.num_positional_features = 2 * max_relative_feature + 1 + 1

        # Maps the one-hot offset bins to dense embeddings.
        self.embed_positional_features = nn.Linear(
            self.num_positional_features, num_positional_embeddings
        )

    def forward(self, positional_offset, same_chain_mask):
        """
        Embed pairwise residue-index offsets, with a dedicated bin for
        cross-chain pairs.

        Args:
            positional_offset (torch.Tensor): [B, L, K] - pairwise index
                differences, gathered for the K nearest neighbors.
            same_chain_mask (torch.Tensor): [B, L, K] - boolean mask, True
                where the two residues are on the same chain.
        Returns:
            torch.Tensor: [B, L, K, num_positional_embeddings] - embeddings
            of the shifted/clipped offsets, with cross-chain pairs assigned
            the special bin (2 * max_relative_feature + 1).
        Raises:
            ValueError: if ``same_chain_mask`` is not of boolean dtype.
        """
        if same_chain_mask.dtype != torch.bool:
            raise ValueError("The same_chain_mask must be of boolean dtype.")

        # Shift offsets to be non-negative (required for one-hot encoding)
        # and clamp into the supported window [0, 2 * max_relative_feature];
        # anything outside collapses onto the nearest boundary bin.
        bins = torch.clip(
            positional_offset + self.max_relative_feature,
            0,
            2 * self.max_relative_feature,
        )

        # Cross-chain pairs all map to the dedicated "different chain" bin,
        # making the encoding chain-aware.
        off_chain_bin = 2 * self.max_relative_feature + 1
        bins = bins * same_chain_mask + (~same_chain_mask) * off_chain_bin

        # One-hot the bin indices, then project to the embedding dimension.
        one_hot_bins = torch.nn.functional.one_hot(
            bins, num_classes=self.num_positional_features
        ).float()

        return self.embed_positional_features(one_hot_bins)