rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
mpnn/utils/probability.py ADDED
@@ -0,0 +1,37 @@
+ import numpy as np
+
+
+ def sample_bernoulli_rv(p):
+     """
+     Given a probability p, representing the success probability of a Bernoulli
+     distribution, sample X ~ Bernoulli(p).
+
+     Arguments:
+         p (float): a float between 0 and 1, representing the success probability
+             of a Bernoulli distribution.
+
+     Returns:
+         x (int): the result of sampling the random variable X ~ Bernoulli(p).
+             P(X = 1) = p
+             P(X = 0) = 1 - p.
+     """
+     # Check that 0 <= p <= 1.
+     if p < 0 or p > 1:
+         raise ValueError("The success probability p must be between 0 and 1 inclusive.")
+
+     # Handle the edge cases; otherwise, use the numpy uniform distribution.
+     if p == 0:
+         x = 0
+     elif p == 1:
+         x = 1
+     else:
+         # Sample Y ~ Uniform(0, 1).
+         uniform_sample = np.random.uniform(0.0, 1.0)
+
+         # P(Y < p) = p.
+         if uniform_sample < p:
+             x = 1
+         else:
+             x = 0
+
+     return x
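The sampler is easy to sanity-check. A minimal usage sketch (not part of the package), assuming the file above is `mpnn/utils/probability.py` as the "Files changed" list suggests:

```python
import numpy as np

# Hypothetical import path, inferred from the file list above.
from mpnn.utils.probability import sample_bernoulli_rv

# By the law of large numbers, the mean of many Bernoulli(p) draws should
# approach p.
p = 0.3
samples = [sample_bernoulli_rv(p) for _ in range(10_000)]
print(np.mean(samples))  # ~0.3
```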
mpnn/utils/weights.py ADDED
@@ -0,0 +1,309 @@
+ import torch
+ from mpnn.transforms.feature_aggregation.token_encodings import (
+     legacy_token_order,
+     token_order,
+ )
+
+
+ def load_legacy_weights(model: torch.nn.Module, weights_path: str) -> None:
+     """
+     Load a legacy MPNN checkpoint from 'weights_path' into 'model' (the
+     refactored MPNN implementation).
+
+     This performs several transformations:
+     - Copies certain non-learned registries (e.g., periodic table info) from
+       the new model into the checkpoint state dict (to match the new code).
+     - Renames legacy parameter/buffer names into the new module naming scheme.
+     - Fixes a 120->119 atom-type embedding weight size mismatch by dropping
+       the unused legacy atom type.
+     - Reorders pairwise backbone distance embedding weights to match the new
+       atom-pair ordering.
+     - Reorders token (AA) embedding/projection weights from the legacy order
+       (alphabetical 1-letter) to the new order (alphabetical 3-letter).
+     """
+     # Load legacy checkpoint state dict.
+     checkpoint = torch.load(weights_path, map_location="cpu", weights_only=False)
+     checkpoint_state_dict = checkpoint["model_state_dict"]
+
+     # Values to copy directly from the *current* model into the state dict.
+     # These are effectively "configuration" tensors/registries, not learned
+     # weights we want to preserve from the legacy model.
+     values_to_copy = [
+         "graph_featurization_module.side_chain_atom_types",
+         "graph_featurization_module.periodic_table_groups",
+         "graph_featurization_module.periodic_table_periods",
+     ]
+     # Copy over some hard-coded registers/values.
+     for value_name in values_to_copy:
+         # Walk the attribute chain.
+         attr_list = value_name.split(".")
+         sub_module = model
+         while len(attr_list) > 1 and sub_module is not None:
+             attr = attr_list.pop(0)
+             if hasattr(sub_module, attr):
+                 sub_module = getattr(sub_module, attr)
+             else:
+                 sub_module = None
+
+         # If the current sub-module exists, and it has the final attribute,
+         # copy it into the checkpoint state dict.
+         if sub_module is not None:
+             if hasattr(sub_module, attr_list[0]):
+                 checkpoint_state_dict[value_name] = getattr(sub_module, attr_list[0])
+
+     # Mapping of legacy weight names to new weight names.
+     # Left side = name in the old checkpoint.
+     # Right side = name expected by the new model implementation.
+     legacy_weight_to_new_weight = {
+         "features.embeddings.linear.weight": "graph_featurization_module.positional_embedding.embed_positional_features.weight",
+         "features.embeddings.linear.bias": "graph_featurization_module.positional_embedding.embed_positional_features.bias",
+         "features.edge_embedding.weight": "graph_featurization_module.edge_embedding.weight",
+         "features.norm_edges.weight": "graph_featurization_module.edge_norm.weight",
+         "features.norm_edges.bias": "graph_featurization_module.edge_norm.bias",
+         "context_encoder_layers.0.norm1.weight": "protein_ligand_context_encoder_layers.0.norm1.weight",
+         "context_encoder_layers.0.norm1.bias": "protein_ligand_context_encoder_layers.0.norm1.bias",
+         "context_encoder_layers.0.norm2.weight": "protein_ligand_context_encoder_layers.0.norm2.weight",
+         "context_encoder_layers.0.norm2.bias": "protein_ligand_context_encoder_layers.0.norm2.bias",
+         "context_encoder_layers.0.W1.weight": "protein_ligand_context_encoder_layers.0.W1.weight",
+         "context_encoder_layers.0.W1.bias": "protein_ligand_context_encoder_layers.0.W1.bias",
+         "context_encoder_layers.0.W2.weight": "protein_ligand_context_encoder_layers.0.W2.weight",
+         "context_encoder_layers.0.W2.bias": "protein_ligand_context_encoder_layers.0.W2.bias",
+         "context_encoder_layers.0.W3.weight": "protein_ligand_context_encoder_layers.0.W3.weight",
+         "context_encoder_layers.0.W3.bias": "protein_ligand_context_encoder_layers.0.W3.bias",
+         "context_encoder_layers.0.dense.W_in.weight": "protein_ligand_context_encoder_layers.0.dense.W_in.weight",
+         "context_encoder_layers.0.dense.W_in.bias": "protein_ligand_context_encoder_layers.0.dense.W_in.bias",
+         "context_encoder_layers.0.dense.W_out.weight": "protein_ligand_context_encoder_layers.0.dense.W_out.weight",
+         "context_encoder_layers.0.dense.W_out.bias": "protein_ligand_context_encoder_layers.0.dense.W_out.bias",
+         "context_encoder_layers.1.norm1.weight": "protein_ligand_context_encoder_layers.1.norm1.weight",
+         "context_encoder_layers.1.norm1.bias": "protein_ligand_context_encoder_layers.1.norm1.bias",
+         "context_encoder_layers.1.norm2.weight": "protein_ligand_context_encoder_layers.1.norm2.weight",
+         "context_encoder_layers.1.norm2.bias": "protein_ligand_context_encoder_layers.1.norm2.bias",
+         "context_encoder_layers.1.W1.weight": "protein_ligand_context_encoder_layers.1.W1.weight",
+         "context_encoder_layers.1.W1.bias": "protein_ligand_context_encoder_layers.1.W1.bias",
+         "context_encoder_layers.1.W2.weight": "protein_ligand_context_encoder_layers.1.W2.weight",
+         "context_encoder_layers.1.W2.bias": "protein_ligand_context_encoder_layers.1.W2.bias",
+         "context_encoder_layers.1.W3.weight": "protein_ligand_context_encoder_layers.1.W3.weight",
+         "context_encoder_layers.1.W3.bias": "protein_ligand_context_encoder_layers.1.W3.bias",
+         "context_encoder_layers.1.dense.W_in.weight": "protein_ligand_context_encoder_layers.1.dense.W_in.weight",
+         "context_encoder_layers.1.dense.W_in.bias": "protein_ligand_context_encoder_layers.1.dense.W_in.bias",
+         "context_encoder_layers.1.dense.W_out.weight": "protein_ligand_context_encoder_layers.1.dense.W_out.weight",
+         "context_encoder_layers.1.dense.W_out.bias": "protein_ligand_context_encoder_layers.1.dense.W_out.bias",
+         "y_context_encoder_layers.0.norm1.weight": "ligand_context_encoder_layers.0.norm1.weight",
+         "y_context_encoder_layers.0.norm1.bias": "ligand_context_encoder_layers.0.norm1.bias",
+         "y_context_encoder_layers.0.norm2.weight": "ligand_context_encoder_layers.0.norm2.weight",
+         "y_context_encoder_layers.0.norm2.bias": "ligand_context_encoder_layers.0.norm2.bias",
+         "y_context_encoder_layers.0.W1.weight": "ligand_context_encoder_layers.0.W1.weight",
+         "y_context_encoder_layers.0.W1.bias": "ligand_context_encoder_layers.0.W1.bias",
+         "y_context_encoder_layers.0.W2.weight": "ligand_context_encoder_layers.0.W2.weight",
+         "y_context_encoder_layers.0.W2.bias": "ligand_context_encoder_layers.0.W2.bias",
+         "y_context_encoder_layers.0.W3.weight": "ligand_context_encoder_layers.0.W3.weight",
+         "y_context_encoder_layers.0.W3.bias": "ligand_context_encoder_layers.0.W3.bias",
+         "y_context_encoder_layers.0.dense.W_in.weight": "ligand_context_encoder_layers.0.dense.W_in.weight",
+         "y_context_encoder_layers.0.dense.W_in.bias": "ligand_context_encoder_layers.0.dense.W_in.bias",
+         "y_context_encoder_layers.0.dense.W_out.weight": "ligand_context_encoder_layers.0.dense.W_out.weight",
+         "y_context_encoder_layers.0.dense.W_out.bias": "ligand_context_encoder_layers.0.dense.W_out.bias",
+         "y_context_encoder_layers.1.norm1.weight": "ligand_context_encoder_layers.1.norm1.weight",
+         "y_context_encoder_layers.1.norm1.bias": "ligand_context_encoder_layers.1.norm1.bias",
+         "y_context_encoder_layers.1.norm2.weight": "ligand_context_encoder_layers.1.norm2.weight",
+         "y_context_encoder_layers.1.norm2.bias": "ligand_context_encoder_layers.1.norm2.bias",
+         "y_context_encoder_layers.1.W1.weight": "ligand_context_encoder_layers.1.W1.weight",
+         "y_context_encoder_layers.1.W1.bias": "ligand_context_encoder_layers.1.W1.bias",
+         "y_context_encoder_layers.1.W2.weight": "ligand_context_encoder_layers.1.W2.weight",
+         "y_context_encoder_layers.1.W2.bias": "ligand_context_encoder_layers.1.W2.bias",
+         "y_context_encoder_layers.1.W3.weight": "ligand_context_encoder_layers.1.W3.weight",
+         "y_context_encoder_layers.1.W3.bias": "ligand_context_encoder_layers.1.W3.bias",
+         "y_context_encoder_layers.1.dense.W_in.weight": "ligand_context_encoder_layers.1.dense.W_in.weight",
+         "y_context_encoder_layers.1.dense.W_in.bias": "ligand_context_encoder_layers.1.dense.W_in.bias",
+         "y_context_encoder_layers.1.dense.W_out.weight": "ligand_context_encoder_layers.1.dense.W_out.weight",
+         "y_context_encoder_layers.1.dense.W_out.bias": "ligand_context_encoder_layers.1.dense.W_out.bias",
+         "features.node_project_down.weight": "graph_featurization_module.node_embedding.weight",
+         "features.node_project_down.bias": "graph_featurization_module.node_embedding.bias",
+         "features.norm_nodes.weight": "graph_featurization_module.node_norm.weight",
+         "features.norm_nodes.bias": "graph_featurization_module.node_norm.bias",
+         "features.type_linear.weight": "graph_featurization_module.embed_atom_type_features.weight",
+         "features.type_linear.bias": "graph_featurization_module.embed_atom_type_features.bias",
+         "features.y_nodes.weight": "graph_featurization_module.ligand_subgraph_node_embedding.weight",
+         "features.y_edges.weight": "graph_featurization_module.ligand_subgraph_edge_embedding.weight",
+         "features.norm_y_edges.weight": "graph_featurization_module.ligand_subgraph_edge_norm.weight",
+         "features.norm_y_edges.bias": "graph_featurization_module.ligand_subgraph_edge_norm.bias",
+         "features.norm_y_nodes.weight": "graph_featurization_module.ligand_subgraph_node_norm.weight",
+         "features.norm_y_nodes.bias": "graph_featurization_module.ligand_subgraph_node_norm.bias",
+         "W_v.weight": "W_protein_to_ligand_edges_embed.weight",
+         "W_v.bias": "W_protein_to_ligand_edges_embed.bias",
+         "W_c.weight": "W_protein_encoding_embed.weight",
+         "W_c.bias": "W_protein_encoding_embed.bias",
+         "W_nodes_y.weight": "W_ligand_nodes_embed.weight",
+         "W_nodes_y.bias": "W_ligand_nodes_embed.bias",
+         "W_edges_y.weight": "W_ligand_edges_embed.weight",
+         "W_edges_y.bias": "W_ligand_edges_embed.bias",
+         "V_C.weight": "W_final_context_embed.weight",
+         "V_C_norm.weight": "final_context_norm.weight",
+         "V_C_norm.bias": "final_context_norm.bias",
+     }
+     # Rename the weights in the checkpoint state dict.
+     for legacy_weight_name, new_weight_name in legacy_weight_to_new_weight.items():
+         if legacy_weight_name in checkpoint_state_dict:
+             checkpoint_state_dict[new_weight_name] = checkpoint_state_dict.pop(
+                 legacy_weight_name
+             )
+
+     # Remove unused atom type embedding weight.
+     # - Previous LigandMPNN used 120 atom types, but the last one was unused.
+     # - The new model uses 119 atom types.
+     atom_type_embedding_keys = [
+         "graph_featurization_module.embed_atom_type_features.weight",
+         "graph_featurization_module.ligand_subgraph_node_embedding.weight",
+     ]
+     # For each of these keys, drop the unused atom type embedding.
+     for key in atom_type_embedding_keys:
+         if key in checkpoint_state_dict:
+             legacy_weight = checkpoint_state_dict[key]
+             num_atomic_numbers = model.graph_featurization_module.num_atomic_numbers
+             checkpoint_state_dict[key] = torch.cat(
+                 [
+                     legacy_weight[:, :num_atomic_numbers],
+                     legacy_weight[:, num_atomic_numbers + 1 :],
+                 ],
+                 dim=1,
+             )
+
+     # Permute weights for embedding of pairwise backbone atom distances.
+     # - The legacy model used the order specified in 'legacy_order' dict.
+     # - The new model uses the order specified in 'new_order' list (the
+     #   outer product of the atom types in the order N, Ca, C, O, Cb).
+     legacy_order = {
+         "Ca-Ca": 0,
+         "N-N": 1,
+         "C-C": 2,
+         "O-O": 3,
+         "Cb-Cb": 4,
+         "Ca-N": 5,
+         "Ca-C": 6,
+         "Ca-O": 7,
+         "Ca-Cb": 8,
+         "N-C": 9,
+         "N-O": 10,
+         "N-Cb": 11,
+         "Cb-C": 12,
+         "Cb-O": 13,
+         "O-C": 14,
+         "N-Ca": 15,
+         "C-Ca": 16,
+         "O-Ca": 17,
+         "Cb-Ca": 18,
+         "C-N": 19,
+         "O-N": 20,
+         "Cb-N": 21,
+         "C-Cb": 22,
+         "O-Cb": 23,
+         "C-O": 24,
+     }
+     new_order = [
+         "N-N",
+         "N-Ca",
+         "N-C",
+         "N-O",
+         "N-Cb",
+         "Ca-N",
+         "Ca-Ca",
+         "Ca-C",
+         "Ca-O",
+         "Ca-Cb",
+         "C-N",
+         "C-Ca",
+         "C-C",
+         "C-O",
+         "C-Cb",
+         "O-N",
+         "O-Ca",
+         "O-C",
+         "O-O",
+         "O-Cb",
+         "Cb-N",
+         "Cb-Ca",
+         "Cb-C",
+         "Cb-O",
+         "Cb-Cb",
+     ]
+     pairwise_backbone_atom_embeddings_keys = [
+         "graph_featurization_module.edge_embedding.weight",
+     ]
+     for key in pairwise_backbone_atom_embeddings_keys:
+         if key in checkpoint_state_dict:
+             # Grab the legacy weight and shape.
+             legacy_weight = checkpoint_state_dict[key]
+             out_dim, _ = legacy_weight.shape
+
+             # Grab the necessary dimensions from the model.
+             num_positional_embeddings = (
+                 model.graph_featurization_module.num_positional_embeddings
+             )
+             num_atoms = (
+                 model.graph_featurization_module.num_backbone_atoms
+                 + model.graph_featurization_module.num_virtual_atoms
+             )
+             num_rbf = model.graph_featurization_module.num_rbf
+
+             # Split positional and RBF embedding weights.
+             legacy_weight_positional_embeddings = legacy_weight[
+                 :, :num_positional_embeddings
+             ]
+             legacy_weight_rbf_embeddings_flat = legacy_weight[
+                 :, num_positional_embeddings:
+             ]
+
+             # Reshape the weights to separate atom pairs and the rbf dimension.
+             legacy_weight_rbf_embeddings_atom_pairs = (
+                 legacy_weight_rbf_embeddings_flat.view(
+                     out_dim, num_atoms * num_atoms, num_rbf
+                 )
+             )
+
+             # Reorder the atom pairs to match the new order.
+             new_weight_rbf_embeddings_atom_pairs = (
+                 legacy_weight_rbf_embeddings_atom_pairs[
+                     :, [legacy_order[atom_pair_name] for atom_pair_name in new_order], :
+                 ]
+             )
+
+             # Flatten the reordered weights back to 2D.
+             new_weight_rbf_embeddings_flat = (
+                 new_weight_rbf_embeddings_atom_pairs.reshape(
+                     out_dim, num_atoms * num_atoms * num_rbf
+                 )
+             )
+
+             # Concatenate positional + reordered RBF
+             checkpoint_state_dict[key] = torch.cat(
+                 [legacy_weight_positional_embeddings, new_weight_rbf_embeddings_flat],
+                 dim=1,
+             )
+
+     # Permute the token order of amino acids coming out of the model to match
+     # the new vocabulary order.
+     # - The legacy model used an order specified by alphabetic order of one-
+     #   letter amino acid codes.
+     # - The new model uses an order specified by alphabetic order of three-
+     #   letter amino acid codes.
+     token_embedding_keys = ["W_s.weight"]
+     token_projection_keys = ["W_out.weight", "W_out.bias"]
+     # For each of these keys, reorder the embeddings/projections.
+     for key in token_embedding_keys + token_projection_keys:
+         if key in checkpoint_state_dict:
+             # Grab the old weight.
+             legacy_weight = checkpoint_state_dict[key]
+
+             # Reorder the weight/bias according to the new token order.
+             if "weight" in key:
+                 checkpoint_state_dict[key] = legacy_weight[
+                     [legacy_token_order.index(aa) for aa in token_order], :
+                 ]
+             elif "bias" in key:
+                 checkpoint_state_dict[key] = legacy_weight[
+                     [legacy_token_order.index(aa) for aa in token_order]
+                 ]
+             else:
+                 raise ValueError(f"Unrecognized key for token projection: {key}")
+
+     # Load the modified state dict into the model.
+     model.load_state_dict(checkpoint_state_dict, strict=True)
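The two reordering tricks above (a row permutation for the token embeddings, and a view/index/reshape round trip for the atom-pair RBF blocks) are easiest to see on toy tensors. A minimal sketch, not the package's API, where `legacy_aa` and `new_aa` are hypothetical stand-ins for the imported `legacy_token_order` and `token_order`:

```python
import torch

# 1) Token reorder: build an index mapping each token in the new vocabulary
#    order back to its row in the legacy embedding matrix. Assumed orders:
#    alphabetical by one-letter code (legacy) vs. alphabetical by three-letter
#    code (new), written here with one-letter equivalents.
legacy_aa = list("ACDEFGHIKLMNPQRSTVWY")  # A, C, D, E, ...
new_aa = list("ARNDCQEGHILKMFPSTWYV")     # ALA, ARG, ASN, ASP, CYS, ...

legacy_embedding = torch.randn(20, 8)  # toy [vocab, hidden] weight
perm = [legacy_aa.index(aa) for aa in new_aa]
new_embedding = legacy_embedding[perm, :]
assert new_embedding[1].equal(legacy_embedding[legacy_aa.index("R")])  # ARG row

# 2) Atom-pair reorder: expose the (pair, rbf) structure of a flat weight,
#    permute along the pair axis, then flatten back, as done above for
#    'edge_embedding.weight'.
out_dim, num_pairs, num_rbf = 4, 25, 3
flat = torch.randn(out_dim, num_pairs * num_rbf)
pair_perm = torch.randperm(num_pairs)  # stand-in for the legacy->new mapping
reordered = (
    flat.view(out_dim, num_pairs, num_rbf)[:, pair_perm, :]
    .reshape(out_dim, num_pairs * num_rbf)
)
assert reordered.shape == flat.shape
```

Building the index from the new order (rather than permuting the legacy one) keeps the mapping direction explicit: position i of the new matrix pulls row `perm[i]` from the legacy matrix.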
rc_foundry-0.1.1.dist-info/METADATA ADDED
@@ -0,0 +1,239 @@
+ Metadata-Version: 2.4
+ Name: rc-foundry
+ Version: 0.1.1
+ Summary: Shared utilities and training infrastructure for biomolecular structure prediction models.
+ Author-email: Institute for Protein Design <contact@ipd.uw.edu>
+ License: BSD 3-Clause License
+
+ Copyright (c) 2025, Institute for Protein Design, University of Washington
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ License-File: LICENSE.md
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: BSD License
+ Classifier: Natural Language :: English
+ Classifier: Operating System :: MacOS
+ Classifier: Operating System :: Microsoft :: Windows
+ Classifier: Operating System :: POSIX :: Linux
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: Implementation :: CPython
+ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
+ Requires-Python: >=3.12
+ Requires-Dist: assertpy
+ Requires-Dist: atomworks[ml]>=2.1.1
+ Requires-Dist: beartype<1,>=0.18.0
+ Requires-Dist: dm-tree<1,>=0.1.6
+ Requires-Dist: einops<1,>=0.8.0
+ Requires-Dist: einx<1,>=0.1.0
+ Requires-Dist: environs<12,>=11.0.0
+ Requires-Dist: hydra-core<1.4,>=1.3.0
+ Requires-Dist: ipykernel>=6.31.0
+ Requires-Dist: jaxtyping<1,>=0.2.17
+ Requires-Dist: lightning>=2.5.0
+ Requires-Dist: loralib>=0.1.1
+ Requires-Dist: opt-einsum<4,>=3.4.0
+ Requires-Dist: pandas
+ Requires-Dist: rich>=13.9.4
+ Requires-Dist: rootutils<1.1,>=1.0.7
+ Requires-Dist: toolz
+ Requires-Dist: torch<3,>=2.2.0
+ Requires-Dist: typer<1,>=0.20.0
+ Requires-Dist: wandb<1,>=0.15.10
+ Requires-Dist: zstandard
+ Provides-Extra: all
+ Requires-Dist: cuequivariance-ops-cu12>=0.6.1; (sys_platform == 'linux') and extra == 'all'
+ Requires-Dist: cuequivariance-ops-torch-cu12>=0.6.1; (sys_platform == 'linux') and extra == 'all'
+ Requires-Dist: cuequivariance-torch>=0.6.1; (sys_platform == 'linux') and extra == 'all'
+ Requires-Dist: pydantic>=2.8; extra == 'all'
+ Provides-Extra: dev
+ Requires-Dist: assertpy; extra == 'dev'
+ Requires-Dist: atomworks[dev,ml,openbabel]>=2.1.1; extra == 'dev'
+ Requires-Dist: debugpy<2,>=1.8.5; extra == 'dev'
+ Requires-Dist: ipdb; extra == 'dev'
+ Requires-Dist: ipykernel<7,>=6.29.4; extra == 'dev'
+ Requires-Dist: pre-commit; extra == 'dev'
+ Requires-Dist: pytest-benchmark<6,>=5.0.0; extra == 'dev'
+ Requires-Dist: pytest-cov<5,>=4.1.0; extra == 'dev'
+ Requires-Dist: pytest-dotenv<1,>=0.5.2; extra == 'dev'
+ Requires-Dist: pytest-testmon<3,>=2.1.1; extra == 'dev'
+ Requires-Dist: pytest-xdist<4,>=3.6.1; extra == 'dev'
+ Requires-Dist: pytest<9,>=8.2.0; extra == 'dev'
+ Requires-Dist: ruff==0.8.3; extra == 'dev'
+ Provides-Extra: rf3
+ Requires-Dist: cuequivariance-ops-cu12>=0.6.1; (sys_platform == 'linux') and extra == 'rf3'
+ Requires-Dist: cuequivariance-ops-torch-cu12>=0.6.1; (sys_platform == 'linux') and extra == 'rf3'
+ Requires-Dist: cuequivariance-torch>=0.6.1; (sys_platform == 'linux') and extra == 'rf3'
+ Provides-Extra: rfd3
+ Requires-Dist: pydantic>=2.8; extra == 'rfd3'
+ Description-Content-Type: text/markdown
+
+ # Protein design with Foundry
+
+ Foundry provides tooling and infrastructure for using and training all classes of models for protein design, including design (RFD3), inverse folding (ProteinMPNN) and protein folding (RF3).
+
+ All models within Foundry rely on [AtomWorks](https://github.com/RosettaCommons/atomworks) - a unified framework for manipulating and processing biomolecular structures - for both training and inference.
+
+ ## Getting Started
+ ### Quickstart guide
+ **Installation**
+ ```bash
+ pip install rc-foundry[all]
+ ```
+
+ **Downloading weights.** All models can be downloaded to a target folder with:
+
+ ```bash
+ foundry install all --checkpoint_dir <path/to/ckpt/dir>
+ ```
+ This will download all supported models (including multiple checkpoints of RF3), but as a beginner you can start with:
+ ```bash
+ foundry install rfd3 ligandmpnn rf3 --checkpoint_dir <path/to/ckpt/dir>
+ ```
+
+ > *See `examples/all.ipynb` for how to run each model in a notebook.*
+
+ ### RFdiffusion3 (RFD3)
+
+ [RFdiffusion3](https://www.biorxiv.org/content/10.1101/2025.09.18.676967v2) is an all-atom generative model capable of designing protein structures under complex constraints.
+
+ > *See [models/rfd3/README.md](models/rfd3/README.md) for complete documentation.*
+
+ <div align="center">
+ <img src="docs/_static/rfd3_trajectory.png" alt="RFdiffusion3 generation trajectory." width="400">
+ </div>
+
+ ### ProteinMPNN
+ [ProteinMPNN](https://www.science.org/doi/10.1126/science.add2187) and [LigandMPNN](https://www.nature.com/articles/s41592-025-02626-1) are lightweight inverse-folding models that can be used to design diverse sequences for backbones under constrained conditions.
+
+ > *See [models/mpnn/README.md](models/mpnn/README.md) for complete documentation.*
+
+
+ ### RoseTTAFold3 (RF3)
+
+ [RF3](https://doi.org/10.1101/2025.08.14.670328) is a structure prediction neural network that narrows the gap between the closed-source AF3 and open-source alternatives.
+
+ <div align="center">
+ <img src="docs/_static/prot_dna.png" alt="Protein-DNA complex prediction" width="400">
+ </div>
+
+ > *See [models/rf3/README.md](models/rf3/README.md) for complete documentation.*
+
+ ---
+
+ ## Development
+
+ ### Code Organization
+
+ **Strict dependency flow:** `foundry` → `atomworks`
+
+ - **atomworks**: Structure I/O, preprocessing, featurization
+ - **foundry**: Model architectures, training, inference endpoints
+ - **models/\<model\>:** Released models.
+
+ #### For Core Developers (Multiple Packages)
+
+ Install `foundry` and the models in editable mode for development:
+
+ ```bash
+ # Install foundry and all models in editable mode
+ uv pip install -e . -e ./models/rf3 -e ./models/rfd3 -e ./models/mpnn
+
+ # Or install only foundry (no models)
+ uv pip install -e .
+ ```
+
+ This approach allows you to:
+ - Modify `foundry` shared utilities and see changes immediately
+ - Work on specific models without installing all models
+ - Add new models as independent packages in `models/`
+
+ ### Adding New Models
+
+ To add a new model:
+
+ 1. Create a `models/<model_name>/` directory with its own `pyproject.toml`
+ 2. Add `foundry` as a dependency
+ 3. Implement model-specific code in `models/<model_name>/src/`
+ 4. Users can install it with: `uv pip install -e ./models/<model_name>`
+
+ ### Pre-commit Formatting
+
+ We ship a `.pre-commit-config.yaml` that runs `make format` (via `ruff format`) before each commit. Enable it once per clone:
+
+ ```bash
+ pip install pre-commit  # if not already installed
+ pre-commit install
+ ```
+
+ After installation, the hook automatically formats the repo whenever you `git commit`. Use `pre-commit run --all-files` to apply it manually.
+
+ ## Citation
+
+ If you use this repository code or data in your work, please cite the relevant work as below:
+
+ ```bibtex
+ @article{corley2025accelerating,
+   title={Accelerating biomolecular modeling with atomworks and rf3},
+   author={Corley, Nathaniel and Mathis, Simon and Krishna, Rohith and Bauer, Magnus S and Thompson, Tuscan R and Ahern, Woody and Kazman, Maxwell W and Brent, Rafael I and Didi, Kieran and Kubaney, Andrew and others},
+   journal={bioRxiv},
+   year={2025}
+ }
+
+ @article{butcher2025_rfdiffusion3,
+   author = {Butcher, Jasper and Krishna, Rohith and Mitra, Raktim and Brent, Rafael Isaac and Li, Yanjing and Corley, Nathaniel and Kim, Paul T and Funk, Jonathan and Mathis, Simon Valentin and Salike, Saman and Muraishi, Aiko and Eisenach, Helen and Thompson, Tuscan Rock and Chen, Jie and Politanska, Yuliya and Sehgal, Enisha and Coventry, Brian and Zhang, Odin and Qiang, Bo and Didi, Kieran and Kazman, Maxwell and DiMaio, Frank and Baker, David},
+   title = {De novo Design of All-atom Biomolecular Interactions with RFdiffusion3},
+   elocation-id = {2025.09.18.676967},
+   year = {2025},
+   doi = {10.1101/2025.09.18.676967},
+   publisher = {Cold Spring Harbor Laboratory},
+   URL = {https://www.biorxiv.org/content/early/2025/11/19/2025.09.18.676967},
+   eprint = {https://www.biorxiv.org/content/early/2025/11/19/2025.09.18.676967.full.pdf},
+   journal = {bioRxiv}
+ }
+
+ @article{dauparas2022robust,
+   title={Robust deep learning--based protein sequence design using ProteinMPNN},
+   author={Dauparas, Justas and Anishchenko, Ivan and Bennett, Nathaniel and Bai, Hua and Ragotte, Robert J and Milles, Lukas F and Wicky, Basile IM and Courbet, Alexis and de Haas, Rob J and Bethel, Neville and others},
+   journal={Science},
+   volume={378},
+   number={6615},
+   pages={49--56},
+   year={2022},
+   publisher={American Association for the Advancement of Science}
+ }
+
+ @article{dauparas2025atomic,
+   title={Atomic context-conditioned protein sequence design using LigandMPNN},
+   author={Dauparas, Justas and Lee, Gyu Rie and Pecoraro, Robert and An, Linna and Anishchenko, Ivan and Glasscock, Cameron and Baker, David},
+   journal={Nature Methods},
+   pages={1--7},
+   year={2025},
+   publisher={Nature Publishing Group US New York}
+ }
+ ```
+ ## Acknowledgments
+ We thank Rachel Clune and Hope Woods from the RosettaCommons for their collaboration on the codebase, documentation, tutorials, and examples.