rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,265 @@
1
+ """
2
+ Collation utilities for PyTorch data loading.
3
+
4
+ This module provides collation functions for batching examples with
5
+ variable-length features, including padding and scalar feature
6
+ consistency checks.
7
+ """
8
+
9
import copy
from typing import Any, Dict, List, Optional

import torch
from atomworks.constants import UNKNOWN_AA
from mpnn.transforms.feature_aggregation.token_encodings import MPNN_TOKEN_ENCODING
15
+
16
+ MPNN_DEFAULT_PADDING = {
17
+ # Tensor features that require padding
18
+ "X": 0.0,
19
+ "X_m": False,
20
+ "S": MPNN_TOKEN_ENCODING.token_to_idx[UNKNOWN_AA],
21
+ "R_idx": -100,
22
+ "chain_labels": -1,
23
+ "residue_mask": False,
24
+ "Y": 0.0,
25
+ "Y_m": 0,
26
+ "Y_t": 0,
27
+ "designed_residue_mask": False,
28
+ "symmetry_equivalence_group": -1,
29
+ "symmetry_weight": 0,
30
+ "bias": 0,
31
+ "pair_bias": 0,
32
+ "temperature": 1,
33
+ }
34
+
35
+
36
class FeatureCollator:
    """
    Generic PyTorch collation class for handling variable-length features with
    padding.

    This collator processes batches of examples where each example contains
    features that may have different shapes (requiring padding along dimension
    0) or scalar/non-tensor values (requiring consistency checks across the
    batch).
    """

    def __init__(self, default_padding: Optional[Dict[str, Any]] = None):
        """
        Initialize the FeatureCollator.

        Args:
            default_padding: Maps each feature key to a scalar or tensor to use
                for padding/missing values. Each value should be appropriate
                for the expected feature type (e.g., 0 for missing integers,
                0.0 for missing floats, or a tensor for missing tensor
                features). If None, uses MPNN-specific defaults.
        """
        if default_padding is None:
            default_padding = MPNN_DEFAULT_PADDING
        self.default_padding = default_padding

    def __call__(self, pipeline_outputs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Collate a batch of pipeline outputs into network inputs.

        Args:
            pipeline_outputs: List of examples; each example is a dict with
                keys "input_features" (feature dict) and "atom_array" (raw
                atom array).

        Returns:
            Dict[str, Any]: dict with keys 'input_features' (collated features)
            and 'atom_array' (list of atom_arrays from each example).

        Raises:
            ValueError: If pipeline_outputs is empty, if feature values differ
                across examples, or if feature types are unsupported or
                mismatched.
        """
        if not pipeline_outputs:
            raise ValueError("Cannot collate empty batch")

        # Only collate keys present in the first example's input_features.
        keys_to_collate = pipeline_outputs[0]["input_features"].keys()
        input_features = {}
        for key in keys_to_collate:
            # Every example must carry every collated key.
            for i, example in enumerate(pipeline_outputs):
                if key not in example["input_features"]:
                    raise ValueError(
                        f"Feature '{key}' not found in example {i} input_features"
                    )

            values = [example["input_features"][key] for example in pipeline_outputs]
            if all(isinstance(v, torch.Tensor) for v in values):
                input_features[key] = self._collate_tensors(key, values)
            else:
                input_features[key] = self._collate_scalars(key, values)

        atom_arrays = [example["atom_array"] for example in pipeline_outputs]

        return {"input_features": input_features, "atom_array": atom_arrays}

    def _collate_tensors(
        self, key: str, values: List[torch.Tensor]
    ) -> torch.Tensor:
        """Stack tensors for one feature, padding along dim 0 when lengths vary.

        Raises:
            ValueError: On mismatched ranks or trailing shapes, or when padding
                is required but no default padding value is configured.
        """
        ndims = [v.ndim for v in values]
        if len(set(ndims)) != 1:
            raise ValueError(
                f"Tensors for feature '{key}' have mismatched dimensions: {ndims}"
            )

        # 0-d tensors, and >=1-d tensors of equal leading length, can be
        # stacked directly with no padding.
        if ndims[0] == 0:
            return torch.stack(values, dim=0)
        lengths = [v.shape[0] for v in values]
        if len(set(lengths)) == 1:
            return torch.stack(values, dim=0)

        # Lengths vary along dim 0: pad each example up to the longest one.
        if key not in self.default_padding:
            raise ValueError(
                f"No default padding value for feature '{key}' required for "
                "tensor padding."
            )
        # All dimensions beyond the first must agree for padding to make sense.
        other_shapes = [v.shape[1:] for v in values]
        if len(set(other_shapes)) != 1:
            raise ValueError(
                f"Tensors for feature '{key}' have mismatched shapes beyond "
                f"dimension 0: {[v.shape for v in values]}"
            )

        padded = torch.full(
            (len(values), max(lengths), *values[0].shape[1:]),
            self.default_padding[key],
            dtype=values[0].dtype,
            device=values[0].device,
        )
        for i, v in enumerate(values):
            padded[i, : v.shape[0]] = v
        return padded

    def _collate_scalars(self, key: str, values: List[Any]) -> Any:
        """Validate non-tensor values agree across the batch; return one copy.

        Raises:
            ValueError: If the values have mismatched types or differ in value.
        """
        first_type = type(values[0])
        if not all(isinstance(v, first_type) for v in values):
            raise ValueError(
                f"Feature '{key}' has mismatched "
                f"types: {[type(v).__name__ for v in values]}"
            )

        first_val = values[0]
        if not all(self._deep_equal(first_val, v) for v in values):
            raise ValueError(f"Feature '{key}' differs across examples: {values}")
        # Deep-copy so the collated batch does not alias the example's data.
        return copy.deepcopy(first_val)

    def _deep_equal(self, a, b):
        """Recursively compare tensors, dicts, sequences, and scalars."""
        if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
            return torch.equal(a, b)
        if isinstance(a, dict) and isinstance(b, dict):
            if a.keys() != b.keys():
                return False
            return all(self._deep_equal(a[k], b[k]) for k in a)
        if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
            return len(a) == len(b) and all(
                self._deep_equal(x, y) for x, y in zip(a, b)
            )
        return a == b
185
+
186
+
187
class TokenBudgetAwareFeatureCollator(FeatureCollator):
    """
    Feature collator that enforces token budget constraints before collation.

    This collator sorts pipeline outputs by sequence length and removes the
    largest examples first if the batch would exceed the token budget when
    padded to the maximum length.

    Args:
        max_tokens_with_padding: Maximum number of tokens allowed per batch,
            including padding. The constraint is
            max(batch_lengths) * len(batch) <= max_tokens_with_padding.
        default_padding: Default padding values for features.
    """

    def __init__(
        self,
        max_tokens_with_padding: int,
        default_padding: Optional[Dict[str, Any]] = None,
    ):
        # Fail fast on an invalid budget before doing any other setup.
        if max_tokens_with_padding <= 0:
            raise ValueError("max_tokens_with_padding must be greater than 0")

        super().__init__(default_padding)
        self.max_tokens_with_padding = max_tokens_with_padding

    def __call__(self, pipeline_outputs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Collate pipeline outputs while respecting token budget constraints.

        Sorts examples by length and removes largest examples first if needed
        to stay within the token budget.

        Args:
            pipeline_outputs: List of pipeline output dictionaries.

        Returns:
            Dict containing collated features and atom arrays.

        Raises:
            ValueError: If pipeline_outputs is empty, or if no example fits
                within the token budget.
        """
        if not pipeline_outputs:
            raise ValueError("Cannot collate empty batch")

        # Pair each example with its length (leading dimension of the "S"
        # feature) and sort ascending so the smallest examples are kept first.
        examples_with_L = sorted(
            (
                (example["input_features"]["S"].shape[0], example)
                for example in pipeline_outputs
            ),
            key=lambda pair: pair[0],
        )

        # Greedily admit examples while the padded batch fits the budget.
        # Because examples are sorted ascending, the first violation implies
        # every later (longer) example would also violate, so stop there.
        kept = []
        max_length = 0
        for length, example in examples_with_L:
            potential_max_length = max(length, max_length)
            if potential_max_length * (len(kept) + 1) > self.max_tokens_with_padding:
                break
            kept.append(example)
            max_length = potential_max_length

        # If no examples remain after filtering, raise an error.
        if not kept:
            raise ValueError(
                "No examples remain after applying token budget constraint. "
                "All examples exceed max_tokens_with_padding="
                f"{self.max_tokens_with_padding}"
            )

        # Delegate actual collation (padding, consistency checks) to the base.
        return super().__call__(kept)
mpnn/inference.py ADDED
@@ -0,0 +1,53 @@
1
+ #!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/mpnn_exec.sh" "$0" "$@"'
2
+
3
+ from mpnn.inference_engines.mpnn import MPNNInferenceEngine
4
+ from mpnn.utils.inference import (
5
+ MPNN_GLOBAL_INFERENCE_DEFAULTS,
6
+ build_arg_parser,
7
+ cli_to_json,
8
+ )
9
+
10
+
11
def main() -> None:
    """Top-level CLI entry point for MPNN inference.

    This script wires together:
    - CLI / arg parsing
    - JSON config building (or loading an existing JSON)
    - Execution of the MPNNInferenceEngine
    """
    # Parse command-line arguments.
    args = build_arg_parser().parse_args()

    # Build (or load) the JSON inference config from the parsed arguments.
    config = cli_to_json(args)

    # Engine-level (global) settings fall back to the shipped defaults when
    # the config does not override them.
    engine_keys = (
        "model_type",
        "checkpoint_path",
        "is_legacy_weights",
        "out_directory",
        "write_fasta",
        "write_structures",
    )
    engine_kwargs = {
        key: config.get(key, MPNN_GLOBAL_INFERENCE_DEFAULTS[key])
        for key in engine_keys
    }
    engine = MPNNInferenceEngine(**engine_kwargs)

    # In this case, structures are resolved from inputs; no explicit atom arrays
    _ = engine.run(input_dicts=config["inputs"], atom_arrays=None)


if __name__ == "__main__":
    main()