rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/utils/frames.py ADDED
@@ -0,0 +1,109 @@
1
+ # TODO: REFACTOR; COPIED FROM RF2AA. WE NEED TO ADD DOCSTRINGS, EXAMPLES, HOPEFULLY TESTS, AND CLEAN UP
2
+
3
+ import torch
4
+ from rf3.chemical import NFRAMES, NNAPROTAAS, costgtNA
5
+
6
+
7
def is_atom(seq):
    """Return an elementwise mask of tokens whose index exceeds ``NNAPROTAAS``.

    ``NNAPROTAAS`` is imported from ``rf3.chemical``. Tokens above it are
    presumably the free-atom / ligand tokens that follow the protein and
    nucleic-acid tokens -- TODO confirm against ``rf3.chemical``.

    Args:
        seq: Token indices (anything supporting comparison with an int,
            e.g. a torch tensor).

    Returns:
        The result of ``seq > NNAPROTAAS``.
    """
    return NNAPROTAAS < seq
9
+
10
+
11
def get_frames(xyz_in, xyz_mask, seq, frame_indices, atom_frames=None):
    """Gather per-token frame definitions and flag frames whose atoms differ.

    Args:
        xyz_in: Unused by this implementation (kept for call compatibility).
        xyz_mask: Unused by this implementation (kept for call compatibility).
        seq: Integer token tensor used to index ``frame_indices``; presumably
            shaped (B, L) -- TODO confirm.
        frame_indices: Per-token lookup table of frame atom references.
            Presumably (ntokens, nframes, 3, 2), with the last axis holding
            (residue offset, atom index) -- TODO confirm.
        atom_frames: Values written into frame slot 0 of every token for which
            ``is_atom(seq)`` is true. Expected to be given whenever such tokens
            are present (assigning ``None`` into a tensor fails).

    Returns:
        (frames, frame_mask): ``frames`` is ``frame_indices[seq]`` with the
        atom-token entries overwritten. ``frame_mask`` is True for every frame
        whose first and second atom references differ.
    """
    # With an integer-tensor index this is advanced indexing, which returns a
    # copy, so the in-place write below does not touch frame_indices.
    frames = frame_indices[seq]

    atom_tokens = is_atom(seq)
    if torch.any(atom_tokens):
        # NOTE(review): only batch element 0 decides which positions are atoms,
        # so this presumably assumes B == 1 (or identical seq per batch).
        atom_positions = torch.nonzero(atom_tokens[0]).flatten()
        frames[:, atom_positions, 0] = atom_frames

    # Frames whose first two atom references coincide are masked out
    # (presumably unused / padding frame slots).
    first_atom = frames[..., 0, :]
    second_atom = frames[..., 1, :]
    frame_mask = torch.any(first_atom != second_atom, dim=-1)

    return frames, frame_mask
25
+
26
+
27
def rigid_from_3_points(N, Ca, C, is_na=None, eps=1e-4):
    """Build a local rotation frame from three points (N, Ca, C).

    The frame is built in three steps:

    - The first axis is the ``Ca -> C`` direction.
    - The second axis is the ``Ca -> N`` direction, Gram-Schmidt-orthogonalized
      against the first.
    - The third axis is their cross product.

    The frame is then rotated about its third axis by half of the deviation
    between the observed N-Ca-C angle and a target angle. This splits the
    deviation between the Ca-N and Ca-C directions, which gives a more
    accurate CB position than the unrotated frame.

    Only the last dimension (xyz) is assumed. Any leading batch shape works,
    including a single point of shape ``(3,)``.

    Args:
        N, Ca, C: Coordinates of identical shape ``(..., 3)``.
        is_na: Optional boolean mask of shape ``(...)``. Where True, the target
            cosine is ``costgtNA`` instead of ``-0.3616`` (~= cos 111.2 deg,
            presumably the ideal protein N-Ca-C angle).
        eps: Stabilizer added to vector norms and inside square roots.

    Returns:
        (R, Ca): ``R`` has shape ``(..., 3, 3)`` with the frame axes as
        columns, in the same dtype and on the same device as ``N``. ``Ca`` is
        returned unchanged as the frame origin.
    """
    dims = N.shape[:-1]
    # Create helper tensors in N's dtype; mixing float32 helpers with float64
    # or half inputs makes the final einsum fail.
    factory = {"dtype": N.dtype, "device": N.device}

    v1 = C - Ca
    v2 = N - Ca
    e1 = v1 / (torch.norm(v1, dim=-1, keepdim=True) + eps)
    # Plain dot product over xyz. Unlike einsum("...li,...li->...l"), this
    # also works for inputs with no leading dimensions.
    u2 = v2 - torch.sum(e1 * v2, dim=-1, keepdim=True) * e1
    e2 = u2 / (torch.norm(u2, dim=-1, keepdim=True) + eps)
    e3 = torch.cross(e1, e2, dim=-1)
    R = torch.stack([e1, e2, e3], dim=-1)  # (..., 3, 3), axes as columns

    # Cosine of the observed N-Ca-C angle.
    v2 = v2 / (torch.norm(v2, dim=-1, keepdim=True) + eps)
    cosref = torch.sum(e1 * v2, dim=-1)

    costgt = torch.full(dims, -0.3616, **factory)
    if is_na is not None:
        costgt[is_na] = costgtNA

    # cos(2*delta) = cos(theta_ref)cos(theta_tgt) + sin(theta_ref)sin(theta_tgt)
    #              = cos(theta_ref - theta_tgt)
    cos2del = torch.clamp(
        cosref * costgt
        + torch.sqrt((1 - cosref * cosref) * (1 - costgt * costgt) + eps),
        min=-1.0,
        max=1.0,
    )

    # Half-angle formulas; the sign chooses the rotation direction.
    cosdel = torch.sqrt(0.5 * (1 + cos2del) + eps)
    sindel = torch.sign(costgt - cosref) * torch.sqrt(1 - 0.5 * (1 + cos2del) + eps)

    # In-plane rotation by delta about the frame's third axis.
    Rp = torch.eye(3, **factory).repeat(*dims, 1, 1)
    Rp[..., 0, 0] = cosdel
    Rp[..., 0, 1] = -sindel
    Rp[..., 1, 0] = sindel
    Rp[..., 1, 1] = cosdel
    R = torch.einsum("...ij,...jk->...ik", R, Rp)

    return R, Ca
69
+
70
+
71
def mask_unresolved_frames_batched(frames, frame_mask, atom_mask):
    """Reindex frames to flat absolute atom indices and mask unusable frames.

    Args:
        frames: Integer tensor of shape (B, L, nframes, 3, 2). For each of the
            3 atoms defining a frame, the last axis holds (residue offset
            relative to the owning position, atom index within that residue).
        frame_mask: Mask of frames valid for FAPE/losses, shape (B, L, nframes).
            It is cloned, not modified.
        atom_mask: Mask of resolved atom coordinates, shape (B, L, natoms).

    Returns:
        frames_reindex: Tensor of shape (B, L, nframes, 3) holding indices into
            the flattened (B, L * natoms) atom axis. Every index of a frame that
            points outside [0, L * natoms) is reset to 0.
        frame_mask_update: Copy of ``frame_mask`` that additionally masks out
            frames referencing an out-of-range atom or an atom that is
            unresolved in ``atom_mask``.
    """
    B, L, natoms = atom_mask.shape
    n_flat = L * natoms

    # Relative (residue offset, atom) pairs -> absolute indices into flat atoms.
    frames_reindex = (
        torch.arange(L, device=frames.device)[None, :, None, None] + frames[..., 0]
    ) * natoms + frames[..., 1]

    # A frame is unusable if ANY of its atoms indexes outside the valid range
    # [0, n_flat), whether before the start or at/after the end.
    out_of_range = (frames_reindex < 0) | (frames_reindex >= n_flat)
    masked_atom_frames = torch.any(out_of_range, dim=-1)
    # Reset those indices to 0 so the gather below cannot index out of bounds.
    frames_reindex[masked_atom_frames, :] = 0

    frame_mask_update = frame_mask.clone()
    frame_mask_update *= ~masked_atom_frames
    # Keep a frame only if all 3 of its atoms are resolved.
    frame_mask_update *= torch.all(
        torch.gather(
            atom_mask.reshape(B, n_flat),
            1,
            frames_reindex.reshape(B, -1),
        ).reshape(B, L, -1, 3),
        dim=-1,
    )

    return frames_reindex, frame_mask_update