rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/loss/af3_losses.py ADDED
@@ -0,0 +1,655 @@
import numpy as np
import torch
import torch.nn as nn

from foundry.training.checkpoint import activation_checkpointing
from foundry.utils.alignment import weighted_rigid_align


# resolve residue-level symmetries in native vs pred
class ResidueSymmetryResolution(nn.Module):
    def _get_best(self, x_pred, x_native, x_native_mask, a_i):
        mask = torch.zeros_like(x_native_mask[0])
        mask[a_i[0]] = True
        d_pred = torch.cdist(x_pred[:, mask], x_pred[:, ~mask])
        x_nat_j = x_native.clone()
        for j in range(a_i.shape[0]):
            x_nat_j[:, a_i[0]] = x_native[:, a_i[j]]
            d_nat = torch.cdist(x_nat_j[:, mask], x_nat_j[:, ~mask])
            drms_j = torch.square(d_pred - d_nat).nan_to_num()
            drms_j[drms_j > 15] = 15
            drms_j = torch.mean(drms_j, dim=(-1, -2))
            if j == 0:
                bestj = torch.zeros(
                    x_pred.shape[0], dtype=torch.long, device=x_pred.device
                )
                bestrms = drms_j
            else:
                bestj[drms_j < bestrms] = j
                bestrms[drms_j < bestrms] = drms_j[drms_j < bestrms]
        # x_nat_j[:,a_i[0]] = x_native[:,a_i[j]]
        for j in range(x_pred.shape[0]):
            x_native[j, a_i[0]] = x_native[j, a_i[bestj[j]]]
            x_native_mask[j, a_i[0]] = x_native_mask[j, a_i[bestj[j]]]

        return x_native, x_native_mask

    def forward(self, network_output, loss_input, automorph_input):
        x_pred = network_output["X_L"]
        x_native = loss_input["X_gt_L"]
        x_native_mask = loss_input["crd_mask_L"]
        for a_i in automorph_input:
            if a_i.shape[0] == 1:
                continue
            a_i = torch.tensor(a_i, device=x_pred.device)
            x_native, x_native_mask = self._get_best(
                x_pred, x_native, x_native_mask, a_i
            )

        loss_input["X_gt_L"] = x_native
        loss_input["crd_mask_L"] = x_native_mask

        return loss_input


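# Editor's usage sketch (not part of the released file): resolve a two-fold
# symmetric atom pair (e.g. the OD1/OD2 oxygens of an aspartate) by letting
# the module pick whichever native labeling best matches the prediction.
# The tensors and the automorphism group below are hypothetical.
def _demo_residue_symmetry_resolution():
    resolver = ResidueSymmetryResolution()
    x_pred = torch.randn(2, 4, 3)  # D=2 diffusion samples, L=4 atoms
    loss_input = {
        "X_gt_L": torch.randn(2, 4, 3),
        "crd_mask_L": torch.ones(2, 4, dtype=torch.bool),
    }
    # one automorphism group: atoms 1 and 2 are chemically interchangeable,
    # so the identity (1, 2) and the swap (2, 1) are both valid labelings
    automorphs = [np.array([[1, 2], [2, 1]])]
    loss_input = resolver({"X_L": x_pred}, loss_input, automorphs)
    # loss_input["X_gt_L"] now carries, per sample, the labeling with the
    # smaller clamped distance-matrix error against the prediction

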
# Resolve subunit-level symmetries in native vs pred
class SubunitSymmetryResolution(nn.Module):
    def __init__(self, **losses):
        super().__init__()

    def _rms_align(self, X_fixed, X_moving):
        # input:
        #    X_fixed = predicted = Nbatch x L x 3
        #    X_moving = native = Nambig x L x 3
        # output:
        #    X_pre = Nambig x Nbatch x 3
        #    U = Nambig x Nbatch x 3 x 3
        #    X_post = Nambig x Nbatch x 3
        assert X_fixed.shape[-2:] == X_moving.shape[-2:]
        Nbatch = X_fixed.shape[0]
        Nambig = X_moving.shape[0]
        X_fixed = X_fixed[None, :]
        X_moving = X_moving[:, None]

        u_X_fixed = torch.mean(X_fixed, dim=-2)
        u_X_moving = torch.mean(X_moving, dim=-2)

        X_fixed = X_fixed - u_X_fixed.unsqueeze(-2)
        X_moving = X_moving - u_X_moving.unsqueeze(-2)

        C = torch.einsum("...ji,...jk->...ik", X_moving, X_fixed)
        U, S, V = torch.linalg.svd(C)
        R = U @ V
        F = torch.eye(3, 3, device=X_fixed.device)[None, None].repeat(
            Nambig, Nbatch, 1, 1
        )
        F[..., -1, -1] = torch.sign(torch.linalg.det(R))
        R = U @ F @ V
        return u_X_moving, R, u_X_fixed

    def _greedy_resolve_mapping(
        self,
        dist,
        iid_to_index,
        entity_to_index,
        iids_by_entity,
        entity_by_iids,
        nmodel_by_iid,
    ):
        # returns:
        #    best_xform tensor [i]->transform number
        #    best_assignment dict{pred_iid:[native_iids]} (batch)
        nTransforms = dist.shape[0]
        nIid = dist.shape[1]
        nBatch = dist.shape[-1]
        toAssign = [k for k, v in nmodel_by_iid.items() if v > 0]

        # sort equiv groups by # resolved residues
        # first make that list
        nmodel_by_equiv = {
            int(i): 0 for i in entity_to_index.keys()
        }  # torch.zeros(nEquiv,dtype=torch.long,device=dist.device)
        for i, iid in enumerate(toAssign):
            nmodel_by_equiv[entity_by_iids[iid]] += nmodel_by_iid[iid]
        equiv_order = sorted(
            nmodel_by_equiv, key=nmodel_by_equiv.get
        )  # torch.argsort(nmodel_by_equiv,descending=True)

        best_cost = torch.zeros(nBatch, device=dist.device)
        best_xform = torch.zeros(nBatch, dtype=torch.long, device=dist.device)
        best_assignment = {
            int(i): torch.zeros(nBatch, dtype=torch.long, device=dist.device)
            for i in toAssign
        }
        for t in range(nTransforms):
            # then sort with most res first
            cost = torch.zeros(nBatch, device=dist.device)
            assignment = {
                int(i): torch.full(
                    (nBatch,), int(i), dtype=torch.long, device=dist.device
                )
                for i in toAssign
            }

            for i_equiv in equiv_order:
                mask_equiv = torch.zeros(
                    (nIid, nIid), dtype=torch.bool, device=dist.device
                )
                iids_in_i_equiv = iids_by_entity[i_equiv]
                nIids_in_i_equiv = iids_in_i_equiv.shape[0]
                iid_idxs_in_i_equiv = np.vectorize(iid_to_index.__getitem__)(
                    iids_in_i_equiv
                )

                nResolvedEntities_i = len(
                    [
                        nmodel_by_iid[int(i)]
                        for i in iids_in_i_equiv
                        if nmodel_by_iid[i] > 0
                    ]
                )

                mask_equiv[
                    iid_idxs_in_i_equiv[:, None], iid_idxs_in_i_equiv[None, :]
                ] = True
                wted_dist = dist[t, mask_equiv].nan_to_num(1e9)

                # greedily assign min RMS within each equiv group
                # print ('work on eq group',iid_idxs_in_i_equiv)
                # print ('toAssign',toAssign)
                for i in range(nResolvedEntities_i):
                    wted_dist = wted_dist.view(
                        nIids_in_i_equiv * nIids_in_i_equiv, nBatch
                    )
                    pn = torch.argmin(wted_dist, dim=0)

                    # special case: if there is NO seq overlap between predicted and native peptides,
                    # fall back to identity assignment
                    if (wted_dist[pn] == 1e9).all():
                        break

                    # weight the total cost by #residues
                    cost += (
                        wted_dist[pn, torch.arange(nBatch, device=wted_dist.device)]
                        * nmodel_by_iid[iids_in_i_equiv[i]]
                    )
                    i_nat, i_pred = pn // nIids_in_i_equiv, pn % nIids_in_i_equiv
                    for j, (ii_nat, ii_pred) in enumerate(zip(i_nat, i_pred)):
                        assignment[int(iids_by_entity[int(i_equiv)][ii_pred])][j] = (
                            iids_by_entity[int(i_equiv)][ii_nat]
                        )

                    wted_dist = wted_dist.view(
                        nIids_in_i_equiv, nIids_in_i_equiv, nBatch
                    )
                    for i in range(i_nat.shape[0]):
                        wted_dist[i_nat[i], :, i] = 1e6
                        wted_dist[:, i_pred[i], i] = 1e6
            if t == 0:
                best_cost = cost
                best_assignment = assignment
            else:
                mask = cost < best_cost
                best_cost[mask] = cost[mask]
                for i, bi in best_assignment.items():
                    best_assignment[i][mask] = assignment[i][mask]
                best_xform[mask] = t

        return (best_xform, best_assignment)

    def _resolve_subunits(
        self, mol_entities, mol_iid, crop_mask, x_native, mask_native, x_pred
    ):
        # print('x_native',x_native.shape, x_native)
        Nbatch = x_pred.shape[0]

        # index -> entity
        all_entities = torch.unique(mol_entities)
        # entity -> index
        entity_to_index = {int(ii): i for i, ii in enumerate(all_entities)}

        # index -> iid
        all_iids = torch.unique(mol_iid).cpu().numpy()
        Niids = len(all_iids)
        # iid -> index
        iid_to_index = {int(ii): i for i, ii in enumerate(all_iids)}

        # entity -> iid list
        iids_by_entity = {
            int(i): torch.unique(mol_iid[mol_entities == i]).long().cpu().numpy()
            for i in all_entities
        }
        # iid -> entity list
        entity_by_iids = {
            int(i): torch.unique(mol_entities[mol_iid == i]).long().cpu().item()
            for i in all_iids
        }

        # 1) get the iid with most resolved residues
        mask = torch.zeros(
            mol_entities.shape[0], dtype=torch.bool, device=mol_iid.device
        )
        mask[crop_mask] = 1
        mask_by_iid = {int(i): mask[mol_iid == i] for i in all_iids}
        mask_native_by_iid = {int(i): mask_native[mol_iid == i] for i in all_iids}
        nmodeled_by_iid = {
            int(i): torch.sum(mask_by_iid[i]) for i in mask_native_by_iid.keys()
        }

        iid_src_idx = max(
            nmodeled_by_iid, key=nmodeled_by_iid.get
        )  # int(nmodeled_by_iid.argmax())
        entity_src_idx = entity_by_iids[iid_src_idx]
        native_by_iid = {int(i): x_native[mol_iid == i] for i in all_iids}
        pred_by_iid = {int(ii): x_pred[:, mol_iid[crop_mask] == ii] for ii in all_iids}

        # align it to all equivalent targets
        equiv_native_iids = iids_by_entity[entity_src_idx]

        # output:
        #    xpres = Ntrans x Nbatch x 3
        #    U = Ntrans x Nbatch x 3 x 3
        #    xposts = Ntrans x Nbatch x 3
        xpres, Us, xposts = [], [], []

        for n in equiv_native_iids:
            nat_n = native_by_iid[int(n)][mask_by_iid[int(iid_src_idx)]]
            pred_n = pred_by_iid[int(iid_src_idx)]
            mask_unres = ~nat_n[..., 0].isnan()
            nat_n = nat_n[mask_unres]
            pred_n = pred_n[:, mask_unres]

            if mask_unres.sum() > 3:
                xpre, U, xpost = self._rms_align(pred_n, nat_n[None])
                xpres.append(xpre)
                Us.append(U)
                xposts.append(xpost)

        xpres, Us, xposts = (
            torch.cat(xpres, dim=0),
            torch.cat(Us, dim=0),
            torch.cat(xposts, dim=0),
        )

        # build up the matrix of COMs
        # nat_com[i,j] = com of native iid i using crop mask from pred iid j (if compatible)
        nat_com = torch.full((Niids, Niids, 3), np.nan, device=Us.device)
        for i in all_iids:
            equiv_native_iids = iids_by_entity[entity_by_iids[i]]
            for j in equiv_native_iids:
                mask_ij = mask_by_iid[int(j)] * ~native_by_iid[int(i)][:, 0].isnan()
                if torch.any(mask_ij):
                    nat_com[iid_to_index[i], iid_to_index[j]] = torch.mean(
                        native_by_iid[int(i)][mask_ij], dim=0
                    )

        # pred_com[i,j] = com using native mask from iid i on pred iid j
        pred_com = torch.full((Niids, Niids, Nbatch, 3), np.nan, device=Us.device)
        for i in all_iids:
            equiv_native_iids = iids_by_entity[entity_by_iids[i]]
            for j in equiv_native_iids:
                mask_ij = ~native_by_iid[int(i)][:, 0].isnan()[mask_by_iid[int(j)]]
                if torch.any(mask_ij):
                    pred_com[iid_to_index[i], iid_to_index[j]] = torch.mean(
                        pred_by_iid[int(j)][:, mask_ij], dim=1
                    )
                # else:
                #     print ('no map',i,j)

        # apply all transforms to native
        nat_com = (
            torch.einsum(
                "ijkx,ijlxy->ijkly",
                nat_com[None, :, :, :] - xpres[:, None, :, :],
                Us[:, None],
            )
            + xposts[:, None, None]
        )

        # collect all distances
        # dist[i,j,k,l] - distance assigning ...
        #    transform i of
        #    iid j of native to
        #    iid k of pred for
        #    all l models
        dist = torch.linalg.norm(pred_com[None, :, :] - nat_com, dim=-1)

        # solve mapping
        transforms, assignment = self._greedy_resolve_mapping(
            dist,
            iid_to_index,
            entity_to_index,
            iids_by_entity,
            entity_by_iids,
            nmodeled_by_iid,
        )

        # generate output stack
        x_native_aln = torch.zeros_like(x_pred)
        x_native_mask = torch.zeros(
            x_pred.shape[:2], dtype=torch.bool, device=x_pred.device
        )
        for i, si in assignment.items():
            for t in range(x_native_aln.shape[0]):
                mask_src = mol_iid == i
                x_native_aln[t, mask_src[mask]] = native_by_iid[int(si[t])][
                    mask_by_iid[int(i)]
                ]
                x_native_mask[t, mask_src[mask]] = mask_native_by_iid[int(si[t])][
                    mask_by_iid[int(i)]
                ]

        return (x_native_aln, x_native_mask)

    def forward(self, network_output, loss_input, symm_input):
        x_pred = network_output["X_L"]
        mol_entities = symm_input["molecule_entity"].to(x_pred.device)
        mol_iid = symm_input["molecule_iid"].to(x_pred.device)
        crop_mask = symm_input["crop_mask"].to(x_pred.device)
        x_native = symm_input["coord_atom_lvl"].to(x_pred.device)
        mask_native = symm_input["mask_atom_lvl"].to(x_pred.device)

        x_native_aln, x_native_mask = self._resolve_subunits(
            mol_entities, mol_iid, crop_mask, x_native, mask_native, x_pred
        )

        loss_input["X_gt_L"] = x_native_aln
        loss_input["crd_mask_L"] = x_native_mask

        return loss_input

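# Editor's self-check sketch (not part of the released file): _rms_align is a
# batched Kabsch alignment, so a rigidly moved copy of a point cloud should
# map back onto the original once the returned (xpre, R, xpost) transform is
# applied. All values below are hypothetical.
def _demo_rms_align():
    torch.manual_seed(0)
    aligner = SubunitSymmetryResolution()
    X = torch.randn(1, 10, 3)  # "fixed" cloud: Nbatch=1, L=10
    theta = torch.tensor(0.3)
    c, s = torch.cos(theta).item(), torch.sin(theta).item()
    Rz = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
    X_moved = X @ Rz.T + 1.0  # rotate about z, then translate
    xpre, R, xpost = aligner._rms_align(X, X_moved)  # fixed=X, moving=X_moved
    X_back = (X_moved - xpre) @ R[0] + xpost
    assert torch.allclose(X_back, X, atol=1e-4)


# Editor's sketch of the greedy assignment idea used by
# _greedy_resolve_mapping, on a plain 2x2 cost matrix: repeatedly take the
# globally cheapest (native, pred) pair, then disable its row and column.
def _demo_greedy_assignment():
    cost = torch.tensor([[1.0, 5.0], [2.0, 0.5]])  # cost[native, pred]
    assignment = {}
    for _ in range(2):
        i_nat, i_pred = divmod(int(torch.argmin(cost)), cost.shape[1])
        assignment[i_pred] = i_nat
        cost[i_nat, :] = float("inf")
        cost[:, i_pred] = float("inf")
    assert assignment == {1: 1, 0: 0}
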

class ProteinLigandBondLoss(nn.Module):
    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    def forward(self, network_input, network_output, loss_input):
        # find p/l bonds at token level
        is_ligand = network_input["f"]["is_ligand"]
        is_inter_polymer_ligand = torch.outer(is_ligand, ~is_ligand)
        token_bonds = network_input["f"]["token_bonds"]
        pl_bonds = token_bonds * is_inter_polymer_ligand
        first_tok, second_tok = pl_bonds.nonzero(as_tuple=True)

        # early exit
        if first_tok.numel() == 0:
            return torch.tensor(0.0), {"protein_ligand_bond_loss": torch.tensor(0.0)}

        # map tokens to atom level
        atom2token = network_input["f"]["atom_to_token_map"]
        pl_atoms = torch.zeros(
            (1, atom2token.shape[0], atom2token.shape[0]),
            dtype=torch.bool,
            device=atom2token.device,
        )
        for i, j in zip(first_tok, second_tok):
            pl_atoms += (atom2token == i)[None, :, None] * (atom2token == j)[
                None, None, :
            ]

        crd_mask_LL = (
            loss_input["crd_mask_L"][:, None] * loss_input["crd_mask_L"][:, :, None]
        )
        resolved_bonds = pl_atoms * crd_mask_LL

        # the mask may be different for each structure in the batch, so resolve bonds at the per-batch level
        b, atom1, atom2 = resolved_bonds.nonzero(as_tuple=True)

        # get loss
        X_L = network_output["X_L"]
        X_gt_L = loss_input["X_gt_L"]
        predicted_distances = torch.linalg.norm(X_L[b, atom1] - X_L[b, atom2], dim=-1)
        ground_truth_distances = torch.linalg.norm(
            X_gt_L[b, atom1] - X_gt_L[b, atom2], dim=-1
        )
        mask_bonded = ground_truth_distances < 2.4
        loss = torch.mean(
            torch.square(
                predicted_distances[mask_bonded] - ground_truth_distances[mask_bonded]
            )
        )

        return self.weight * loss, {"protein_ligand_bond_loss": loss.detach()}

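# Editor's usage sketch (not part of the released file): a hypothetical
# three-atom system where a polymer token (atoms 0-1) is covalently bonded to
# a single-atom ligand token (atom 2). With identical predicted and native
# coordinates the loss is exactly zero.
def _demo_protein_ligand_bond_loss():
    f = {
        "is_ligand": torch.tensor([False, True]),
        "token_bonds": torch.tensor([[False, True], [True, False]]),
        "atom_to_token_map": torch.tensor([0, 0, 1]),
    }
    X = torch.tensor([[[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [0.0, 1.4, 0.0]]])
    loss_input = {
        "X_gt_L": X.clone(),
        "crd_mask_L": torch.ones(1, 3, dtype=torch.bool),
    }
    loss, logs = ProteinLigandBondLoss(weight=1.0)({"f": f}, {"X_L": X}, loss_input)
    assert loss.item() == 0.0 and "protein_ligand_bond_loss" in logs
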

class DiffusionLoss(nn.Module):
    def __init__(
        self,
        weight,
        sigma_data,
        alpha_dna,
        alpha_rna,
        alpha_ligand,
        edm_lambda,
        se3_invariant_loss,
        clamp_diffusion_loss,
    ):
        super().__init__()
        self.weight = weight
        self.sigma_data = sigma_data
        self.alpha_dna = alpha_dna
        self.alpha_rna = alpha_rna
        self.alpha_ligand = alpha_ligand
        if edm_lambda:
            # original EDM scaling factor
            self.get_lambda = (
                lambda sigma: (sigma**2 + self.sigma_data**2)
                / (sigma * self.sigma_data) ** 2
            )
        else:
            # AF3 uses a weird scaling factor for their loss
            self.get_lambda = (
                lambda sigma: (sigma**2 + self.sigma_data**2)
                / (sigma + self.sigma_data) ** 2
            )
        self.se3_invariant_loss = se3_invariant_loss
        self.clamp_diffusion_loss = clamp_diffusion_loss

    def forward(self, network_input, network_output, loss_input):
        X_L = network_output["X_L"]  # D, L, 3
        D = X_L.shape[0]
        X_gt_L = loss_input["X_gt_L"]
        crd_mask_L = loss_input["crd_mask_L"]
        tok_idx = network_input["f"]["atom_to_token_map"]
        t = network_input["t"]  # (D,)

        w_L = 1 + (
            network_input["f"]["is_dna"] * self.alpha_dna
            + network_input["f"]["is_rna"] * self.alpha_rna
            + network_input["f"]["is_ligand"] * self.alpha_ligand
        )[tok_idx].to(torch.float)
        w_L = w_L[None].expand(D, -1) * crd_mask_L

        if self.se3_invariant_loss:
            # check if this is correct
            X_gt_aligned_L = weighted_rigid_align(X_L, X_gt_L, crd_mask_L[0], w_L)
        else:
            X_gt_aligned_L = X_gt_L
        X_gt_aligned_L = torch.nan_to_num(X_gt_aligned_L)
        l_mse = (
            1
            / 3
            * torch.div(
                torch.sum(w_L * torch.sum((X_L - X_gt_aligned_L) ** 2, dim=-1), dim=-1),
                torch.sum(crd_mask_L[0]) + 1e-4,
            )
        )  # w_L is already updated by the mask

        assert l_mse.shape == (D,)
        l_diffusion = self.get_lambda(t) * l_mse
        l_diffusion = (
            torch.clamp(l_diffusion, max=2)
            if self.clamp_diffusion_loss
            else l_diffusion
        )

        l_diffusion_total = torch.mean(l_diffusion)
        # smoothed lddt loss
        smoothed_lddt_loss_ = smoothed_lddt_loss(
            X_L,
            X_gt_L,
            crd_mask_L,
            network_input["f"]["is_dna"],
            network_input["f"]["is_rna"],
            tok_idx,
            # tag=network_input["id"]
        )
        l_diffusion_total += smoothed_lddt_loss_.mean()
        loss_dict = {
            "diffusion_loss": l_diffusion.detach(),
            "smoothed_lddt_loss": smoothed_lddt_loss_.detach(),
            "t": t.detach(),
        }

        return self.weight * l_diffusion_total, loss_dict

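# Editor's sketch (not part of the released file) contrasting the two
# noise-level weightings selected by `edm_lambda`. sigma_data = 16 is the
# value used by AF3; it is assumed here purely for illustration.
def _demo_lambda_weighting():
    sigma_data = 16.0
    sigma = torch.tensor([1.0, 16.0, 80.0])
    w_edm = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2
    w_af3 = (sigma**2 + sigma_data**2) / (sigma + sigma_data) ** 2
    # w_edm ~ [1.004, 0.0078, 0.0041]: strongly up-weights low-noise samples
    # w_af3 ~ [0.889, 0.5, 0.722]: stays O(1) across the noise schedule
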

def _smoothed_lddt_loss_naive(X_L, X_gt_L_aligned, crd_mask_L, is_dna, is_rna, tok_idx):
    """
    computes lddt with a sigmoid within each bucket to smooth the loss
    X_L: (D, L, 3)
    X_gt_L_aligned: (D, L, 3)
    crd_mask_L: (D, L)
    is_dna: (L,)
    is_rna: (L,)
    tok_idx: (L,)

    returns: (D,)
    """
    predicted_distances = torch.cdist(X_L, X_L)
    ground_truth_distances = torch.cdist(X_gt_L_aligned, X_gt_L_aligned)
    ground_truth_distances[ground_truth_distances.isnan()] = 9999.0
    difference_distances = torch.abs(ground_truth_distances - predicted_distances)
    lddt_matrix = torch.zeros_like(difference_distances)
    lddt_matrix = (
        0.25 * torch.sigmoid(4.0 - difference_distances)
        + 0.25 * torch.sigmoid(2.0 - difference_distances)
        + 0.25 * torch.sigmoid(1.0 - difference_distances)
        + 0.25 * torch.sigmoid(0.5 - difference_distances)
    )
    # remove unresolved atoms, atoms within same residue
    in_same_residue_LL = tok_idx[:, None] == tok_idx[None, :]
    is_na_L = is_dna[tok_idx] | is_rna[tok_idx]
    is_close_distance = (ground_truth_distances < 30) * is_na_L + (
        ground_truth_distances < 15
    ) * ~is_na_L
    mask = crd_mask_L[0] & ~in_same_residue_LL & is_close_distance[0]
    lddt = (lddt_matrix * mask[None]).sum(dim=(-1, -2)) / (
        mask.sum(dim=(-1, -2)) + 1e-6
    )
    return 1 - lddt

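# Editor's sketch (not part of the released file): the four sigmoids above
# are a differentiable relaxation of the standard LDDT thresholds
# (0.5, 1, 2, 4 Angstrom). A hard bin count changes in steps; the smoothed
# version varies continuously, which is what gives this loss a usable
# gradient.
def _demo_soft_vs_hard_bins():
    delta = torch.tensor([0.0, 1.5, 10.0])  # |d_pred - d_native|, Angstrom
    soft = sum(0.25 * torch.sigmoid(b - delta) for b in (0.5, 1.0, 2.0, 4.0))
    hard = sum(0.25 * (delta < b).float() for b in (0.5, 1.0, 2.0, 4.0))
    # soft ~ [0.80, 0.55, 0.00] vs hard = [1.00, 0.50, 0.00]
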

def smoothed_lddt_loss(X_L, X_gt_L, crd_mask_L, is_dna, is_rna, tok_idx, eps=1e-6):
    @activation_checkpointing
    def _dolddt(X_L, X_gt_L, crd_mask_L, is_dna, is_rna, tok_idx, eps, use_amp=True):
        B, L = X_L.shape[:2]
        first_index, second_index = torch.triu_indices(L, L, 1, device=X_L.device)

        # compute the unique distances between all pairs of atoms
        X_gt_L = X_gt_L.nan_to_num()

        # only use native 1 (assumes dist map identical between all copies)
        ground_truth_distances = torch.linalg.norm(
            X_gt_L[0:1, first_index] - X_gt_L[0:1, second_index], dim=-1
        )

        # only score pairs that are close enough in the ground truth
        is_na_L = is_dna[tok_idx][first_index] | is_rna[tok_idx][first_index]
        pair_mask = torch.logical_and(
            ground_truth_distances > 0,
            ground_truth_distances < torch.where(is_na_L, 30.0, 15.0),
        )
        del is_na_L

        # only score pairs that are resolved in the ground truth
        pair_mask *= crd_mask_L[0:1, first_index] * crd_mask_L[0:1, second_index]
        # don't score pairs that are in the same token
        pair_mask *= tok_idx[None, first_index] != tok_idx[None, second_index]

        _, valid_pairs = pair_mask.nonzero(as_tuple=True)
        pair_mask = pair_mask[:, valid_pairs].to(X_L.dtype)
        ground_truth_distances = ground_truth_distances[:, valid_pairs]
        first_index, second_index = first_index[valid_pairs], second_index[valid_pairs]

        predicted_distances = torch.linalg.norm(
            X_L[:, first_index] - X_L[:, second_index], dim=-1
        )

        delta_distances = torch.abs(predicted_distances - ground_truth_distances + eps)
        del predicted_distances, ground_truth_distances

        lddt = (
            0.25
            * (
                torch.sum(torch.sigmoid(0.5 - delta_distances) * pair_mask, dim=(1))
                + torch.sum(torch.sigmoid(1.0 - delta_distances) * pair_mask, dim=(1))
                + torch.sum(torch.sigmoid(2.0 - delta_distances) * pair_mask, dim=(1))
                + torch.sum(torch.sigmoid(4.0 - delta_distances) * pair_mask, dim=(1))
            )
            / (torch.sum(pair_mask, dim=(1)) + eps)
        )

        return 1 - lddt

    return _dolddt(X_L, X_gt_L, crd_mask_L, is_dna, is_rna, tok_idx, eps)

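# Editor's consistency sketch (not part of the released file): on a small,
# fully resolved batch the pair-list implementation above is expected to
# agree with the dense reference `_smoothed_lddt_loss_naive` up to the eps
# added inside the absolute difference.
def _demo_lddt_implementations_agree():
    torch.manual_seed(0)
    D, L = 2, 6
    X_gt = torch.randn(D, L, 3) * 5.0
    X = X_gt + 0.3 * torch.randn(D, L, 3)
    mask = torch.ones(D, L, dtype=torch.bool)
    is_dna = torch.zeros(L, dtype=torch.bool)
    is_rna = torch.zeros(L, dtype=torch.bool)
    tok_idx = torch.arange(L)  # one atom per token
    dense = _smoothed_lddt_loss_naive(X, X_gt, mask, is_dna, is_rna, tok_idx)
    sparse = smoothed_lddt_loss(X, X_gt, mask, is_dna, is_rna, tok_idx)
    gap = (dense - sparse).abs().max()  # expected to be ~1e-6
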

def distogram_loss(
    pred_distogram,
    X_rep_atoms_I,
    crd_mask_rep_atoms_I,
    cce_loss,
    min_distance=2,
    max_distance=22,
    bins=64,
):
    """
    computes distogram loss
    """
    distance_map = torch.cdist(X_rep_atoms_I, X_rep_atoms_I)
    distance_map[distance_map.isnan()] = 9999.0
    bins = torch.linspace(min_distance, max_distance, bins).to(X_rep_atoms_I.device)
    # Note that torch.bucketize adds a catch-all bin for values outside the range,
    # so we end up with n_bins + 1 bins (65 in the case of AF-3)
    binned_distances = torch.bucketize(distance_map, bins)
    crd_mask_rep_atom_II = crd_mask_rep_atoms_I.unsqueeze(
        -1
    ) * crd_mask_rep_atoms_I.unsqueeze(-2)
    distogram_cce = cce_loss(
        pred_distogram.permute(-1, -2, -3)[None], binned_distances[None]
    )
    return distogram_cce[..., crd_mask_rep_atom_II].sum() / (
        crd_mask_rep_atom_II.sum() + 1e-4
    )

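# Editor's sketch (not part of the released file) of the binning note above:
# 64 boundaries from 2 to 22 Angstrom yield 65 classes, where class 0 means
# "closer than 2 A" and class 64 is the catch-all for anything >= 22 A
# (including the 9999.0 sentinel written over NaN distances).
def _demo_bucketize_bins():
    bins = torch.linspace(2, 22, 64)
    d = torch.tensor([0.5, 2.1, 10.0, 21.9, 30.0, 9999.0])
    print(torch.bucketize(d, bins))  # tensor([ 0,  1, 26, 63, 64, 64])
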

class DistogramLoss(nn.Module):
    def __init__(self, weight):
        super().__init__()
        self.weight = weight
        self.cce_loss = nn.CrossEntropyLoss(reduction="none")
        self.eps = 1e-4

    def forward(self, network_input, network_output, loss_input):
        pred_distogram = network_output["distogram"]
        X_rep_atoms_I = loss_input["coord_token_lvl"]
        crd_mask_rep_atoms_I = loss_input["mask_token_lvl"]
        loss = distogram_loss(
            pred_distogram, X_rep_atoms_I, crd_mask_rep_atoms_I, self.cce_loss
        )
        return self.weight * loss, {"distogram_loss": loss.detach()}

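# Editor's usage sketch (not part of the released file), with hypothetical
# shapes: 5 tokens and 65 distogram classes (64 boundaries plus the
# catch-all). The predicted distogram is channels-last, which is why
# distogram_loss permutes it before the cross-entropy.
def _demo_distogram_loss():
    n_tok, n_classes = 5, 65
    network_output = {"distogram": torch.randn(n_tok, n_tok, n_classes)}
    loss_input = {
        "coord_token_lvl": torch.randn(n_tok, 3) * 10.0,
        "mask_token_lvl": torch.ones(n_tok, dtype=torch.bool),
    }
    loss, logs = DistogramLoss(weight=1.0)(None, network_output, loss_input)
    # loss is a scalar; logs["distogram_loss"] holds the unweighted value
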

class NullLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, network_input, network_output, loss_input):
        loss = 0
        for key, val in network_output.items():
            val[val.isnan()] = 0
            loss += torch.sum(val) * 0

        return loss, {}

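# Editor's sketch (not part of the released file): NullLoss contributes
# nothing to the objective but touches every output tensor, so the autograd
# graph still reaches all parameters (useful, e.g., to avoid unused-parameter
# errors under DDP when a loss term is disabled). Gradients come back as
# exact zeros.
def _demo_null_loss():
    w = torch.randn(3, requires_grad=True)
    loss, logs = NullLoss()(None, {"X_L": w * 2}, None)
    loss.backward()
    assert torch.equal(w.grad, torch.zeros(3)) and logs == {}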