rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,153 @@
1
+ import biotite.structure as struct
2
+ import numpy as np
3
+ import torch
4
+ from atomworks.ml.preprocessing.constants import ChainType
5
+ from atomworks.ml.transforms._checks import check_contains_keys
6
+ from atomworks.ml.transforms.base import Transform
7
+ from rfd3.transforms.conditioning_base import (
8
+ convert_existing_annotations_to_bool,
9
+ )
10
+
11
# Residue name of each standard L-amino acid -> residue name of its mirror
# image (the D-enantiomer). Glycine maps to itself.
MIRROR_IMAGE_MAPPING = dict(
    ALA="DAL",
    SER="DSN",
    CYS="DCY",
    PRO="DPR",
    VAL="DVA",
    THR="DTH",
    LEU="DLE",
    ILE="DIL",
    ASN="DSG",
    ASP="DAS",
    MET="MED",
    GLN="DGN",
    GLU="DGL",
    LYS="DLY",
    HIS="DHI",
    PHE="DPN",
    ARG="DAR",
    TYR="DTY",
    TRP="DTR",
    GLY="GLY",
)

# Inverse lookup (D name -> L name). The GLY -> GLY entry is left out, so this
# table contains D-residue names only as keys.
D_TO_L_MAPPING = {
    d_name: l_name
    for l_name, d_name in MIRROR_IMAGE_MAPPING.items()
    if l_name != "GLY"
}

# Lookup that works in both directions: L -> D, D -> L and GLY -> GLY.
TWO_WAY_MIRROR_IMAGE_MAPPING = dict(MIRROR_IMAGE_MAPPING)
TWO_WAY_MIRROR_IMAGE_MAPPING.update(D_TO_L_MAPPING)

# All D-amino-acid residue names (glycine excluded), in the same order as the
# entries of MIRROR_IMAGE_MAPPING.
D_AA = list(D_TO_L_MAPPING)
39
+
40
+
41
# Three points are always coplanar, so a residue with fewer than four atoms is
# identical to its own mirror image and does not need a new name.
_MIN_ATOMS_FOR_MIRROR_RENAME = 4


class RandomlyMirrorInputs(Transform):
    """
    Reflect a training example through the XY plane when its conditions ask for it.

    The reflection is applied only when ``data["conditions"]["mirror_input"]`` is
    truthy. If the atom array contains any DNA, RNA or DNA/RNA-hybrid chain the
    example is returned completely unchanged: nucleic acids are never reflected,
    and neither is anything that accompanies them.

    When the example is mirrored:

    * residues listed in ``TWO_WAY_MIRROR_IMAGE_MAPPING`` are renamed to their
      enantiomer (L <-> D; glycine keeps its name);
    * every other residue with at least four atoms is renamed to a placeholder
      ``"L:<n>"``, where ``<n>`` is assigned per distinct original residue name
      in order of first appearance;
    * the Z coordinate is negated in ``atom_array.coord`` and, when present, in
      ``data["coord_atom_lvl_to_be_noised"]`` and
      ``data["ground_truth"]["coord_atom_lvl"]``.

    This transform must not run at inference time.
    """

    def forward(self, data: dict) -> dict:
        """Mirror ``data`` in place when requested and return it."""
        assert not data.get(
            "is_inference", False
        ), "RandomlyMirrorInputs must not be used at inference time"
        mirror_input = data["conditions"].get("mirror_input", False)
        atom_array = data["atom_array"]

        chain_type = atom_array.chain_type
        has_nucleic_acid = (
            (chain_type == ChainType.DNA).any()
            or (chain_type == ChainType.RNA).any()
            or (chain_type == ChainType.DNA_RNA_HYBRID).any()
        )
        if has_nucleic_acid or not mirror_input:
            return data

        # Residue start indices followed by one exclusive stop for the last
        # residue, so consecutive pairs delimit each residue.
        res_starts = struct.get_residue_starts(atom_array)
        res_bounds = np.append(res_starts, len(atom_array))

        renamed_map = {}
        for r_i, r_j in zip(res_bounds[:-1], res_bounds[1:]):
            resname = atom_array.res_name[r_i]
            if resname in TWO_WAY_MIRROR_IMAGE_MAPPING:
                # case 1: standard amino acid (L or D) -> its enantiomer
                atom_array.res_name[r_i:r_j] = TWO_WAY_MIRROR_IMAGE_MAPPING[resname]
            elif r_j - r_i >= _MIN_ATOMS_FOR_MIRROR_RENAME:
                # case 2: non-standard amino acid or ligand with >= 4 atoms.
                # The original name would now describe the wrong enantiomer,
                # so replace it with a placeholder shared by all residues that
                # had the same original name.
                if resname not in renamed_map:
                    renamed_map[resname] = "L:" + str(len(renamed_map))
                atom_array.res_name[r_i:r_j] = renamed_map[resname]

        # Negate Z. The sign vector is built in the coordinates' own dtype so
        # the multiplication does not upcast them (e.g. float32 -> float64).
        coord = atom_array.coord
        atom_array.coord = coord * np.array([1, 1, -1], dtype=coord.dtype)

        xyz = data.get("coord_atom_lvl_to_be_noised", None)
        if xyz is not None:
            data["coord_atom_lvl_to_be_noised"] = xyz * torch.tensor(
                [1, 1, -1], dtype=xyz.dtype, device=xyz.device
            )

        ground_truth_coord = (
            data["ground_truth"].get("coord_atom_lvl", None)
            if "ground_truth" in data
            else None
        )
        if ground_truth_coord is not None:
            data["ground_truth"]["coord_atom_lvl"] = ground_truth_coord * torch.tensor(
                [1, 1, -1],
                dtype=ground_truth_coord.dtype,
                device=ground_truth_coord.device,
            )

        return data
107
+
108
+
109
class AddIsDAminoAcidFeat(Transform):
    """
    Adds an annotation to the atom array indicating whether each residue is a D-amino acid.

    If the atom array has no ``is_d_amino_acid`` annotation yet, one is created:
    atoms whose residue name is in ``D_AA`` are marked True, and each glycine
    residue is marked True with probability 0.5 (using NumPy's global random
    state). The annotation is then copied to ``data["feats"]["is_d_amino_acid"]``
    unless that feature already exists.
    """

    def check_input(self, data) -> None:
        """Require the ``atom_array`` and ``feats`` entries."""
        check_contains_keys(data, ["atom_array", "feats"])

    def forward(self, data: dict) -> dict:
        """Create the annotation/feature if missing and return ``data``."""
        atom_array = data["atom_array"]

        # Respect an annotation that is already present.
        if "is_d_amino_acid" not in atom_array.get_annotation_categories():
            # Residues whose name is one of the known D-amino acids.
            is_d_aa = np.isin(atom_array.res_name, D_AA)

            # Half the time, we will set glycine to be D-glycine. The coin is
            # flipped once per residue and then broadcast to that residue's
            # atoms, so all atoms of a given glycine carry the same label.
            res_starts = np.asarray(struct.get_residue_starts(atom_array), dtype=int)
            atoms_per_residue = np.diff(np.append(res_starts, len(atom_array)))
            flip_per_residue = np.random.rand(len(res_starts)) < 0.5
            flip_per_atom = np.repeat(flip_per_residue, atoms_per_residue)

            glycines = atom_array.res_name == "GLY"
            is_d_aa = np.logical_or(is_d_aa, np.logical_and(glycines, flip_per_atom))

            atom_array.set_annotation(
                "is_d_amino_acid",
                is_d_aa,
            )

        # Add feature for is_d_amino_acid
        if "is_d_amino_acid" not in data["feats"]:
            data["feats"]["is_d_amino_acid"] = atom_array.get_annotation(
                "is_d_amino_acid"
            )

        data["atom_array"] = atom_array

        return data
144
+
145
+
146
class StrtoBoolforIsDAminoAcidFeature(Transform):
    """
    Run ``convert_existing_annotations_to_bool`` on the ``is_d_amino_acid`` annotation.

    NOTE(review): judging by the class and helper names, this turns a
    string-valued annotation into booleans -- confirm against the helper in
    ``rfd3.transforms.conditioning_base``.
    """

    def forward(self, data):
        """Convert the annotation on ``data["atom_array"]`` and return ``data``."""
        structure = data["atom_array"]
        convert_existing_annotations_to_bool(
            structure,
            annotations=["is_d_amino_acid"],
        )
        # Store the same atom array object back under its key.
        data["atom_array"] = structure
        return data