rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/diffusion_samplers/inference_sampler.py ADDED
@@ -0,0 +1,222 @@
+ import torch
+ from beartype.typing import Any, Literal
+ from jaxtyping import Float
+
+ from foundry.utils.ddp import RankedLogger
+ from foundry.utils.rotation_augmentation import centre_random_augmentation
+
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+ class SampleDiffusion:
+     """Algorithm 18"""
+
+     def __init__(
+         self,
+         *,
+         # Hyperparameters
+         num_timesteps: int,  # AF-3: 200
+         min_t: float,  # AF-3: 0
+         max_t: float,  # AF-3: 1
+         sigma_data: float,  # AF-3: 16
+         s_min: float,  # AF-3: 4e-4
+         s_max: float,  # AF-3: 160
+         p: int,  # AF-3: 7
+         gamma_0: float,  # AF-3: 0.8
+         gamma_min: float,  # AF-3: 1.0
+         noise_scale: float,  # AF-3: 1.003
+         step_scale: float,  # AF-3: 1.5
+         solver: Literal["af3"],
+     ):
+         """Initialize the diffusion sampler, which performs a complete diffusion roll-out with the given recycling outputs.
+
+         We deliberately define no default values for the parameters, so that the Hydra configuration remains the single source of truth and silent failures are avoided.
+
+         Args:
+             num_timesteps (int): The number of timesteps for which the noise schedule is constructed. AF3 uses 200.
+             min_t (float): The minimum value of t in the schedule. AF3 uses 0.
+             max_t (float): The maximum value of t in the schedule. AF3 uses 1.
+             sigma_data (float): A constant determined by the variance of the data. AF3 uses 16, as defined in the AlphaFold 3 Supplement (Algorithm 20, Diffusion Module).
+             s_min (float): The minimum value of the noise schedule. AF3 uses 4e-4.
+             s_max (float): The maximum value of the noise schedule. AF3 uses 160.
+             p (int): A constant that determines the shape of the noise schedule. AF3 uses 7.
+             gamma_0 (float): The value of gamma when t > gamma_min. AF3 uses 0.8.
+             solver (str): The solver to use for the diffusion process. Currently only "af3" is supported.
+
+         TODO: Document the remaining parameters.
+         """
+         self.num_timesteps = num_timesteps
+         self.min_t = min_t
+         self.max_t = max_t
+         self.sigma_data = sigma_data
+         self.s_min = s_min
+         self.s_max = s_max
+         self.p = p
+         self.gamma_0 = gamma_0
+         self.gamma_min = gamma_min
+         self.noise_scale = noise_scale
+         self.step_scale = step_scale
+         self.solver = solver
+
+     def _construct_inference_noise_schedule(self, device: torch.device) -> torch.Tensor:
+         """Constructs a noise schedule for use during inference.
+
+         The inference noise schedule is defined in the AF-3 supplement as:
+
+             t_hat = sigma_data * (s_max**(1/p) + t * (s_min**(1/p) - s_max**(1/p)))**p
+
+         Returns:
+             torch.Tensor: A tensor representing the noise schedule `t_hat`.
+
+         Reference:
+             AlphaFold 3 Supplement, Section 3.7.1.
+         """
+         # Create a linearly spaced tensor of timesteps between min_t and max_t
+         t = torch.linspace(self.min_t, self.max_t, self.num_timesteps, device=device)
+
+         # Construct the noise schedule using the formula from the reference
+         t_hat = (
+             self.sigma_data
+             * (
+                 self.s_max ** (1 / self.p)
+                 + t * (self.s_min ** (1 / self.p) - self.s_max ** (1 / self.p))
+             )
+             ** self.p
+         )
+
+         return t_hat
+
+     def _get_initial_structure(
+         self,
+         c0: torch.Tensor,
+         D: int,
+         L: int,
+         coord_atom_lvl_to_be_noised: torch.Tensor,
+     ) -> torch.Tensor:
+         """Sample the initial point cloud from a normal distribution.
+
+         Args:
+             c0 (torch.Tensor): A scalar tensor used to scale the initial point cloud; effectively the same as
+                 directly changing the standard deviation of the normal distribution. Derived from noise_schedule[0].
+             D (int): The number of structures to sample.
+             L (int): The number of atoms in the structure.
+             coord_atom_lvl_to_be_noised (torch.Tensor): The atom-level coordinates to be noised (either completely or partially).
+         """
+         noise = c0 * torch.normal(mean=0.0, std=1.0, size=(D, L, 3), device=c0.device)
+         X_L = noise + coord_atom_lvl_to_be_noised
+
+         return X_L
+
+     def sample_diffusion_like_af3(
+         self,
+         *,
+         S_inputs_I: Float[torch.Tensor, "I c_s_inputs"],
+         S_trunk_I: Float[torch.Tensor, "I c_s"],
+         Z_trunk_II: Float[torch.Tensor, "I I c_z"],
+         f: dict[str, Any],
+         diffusion_module: torch.nn.Module,
+         diffusion_batch_size: int,
+         coord_atom_lvl_to_be_noised: Float[torch.Tensor, "D L 3"],
+     ) -> dict[str, Any]:
+         """Perform a complete diffusion roll-out with the given recycling outputs.
+
+         Args:
+             diffusion_module (torch.nn.Module): The diffusion module to use for denoising. If using EMA and
+                 performing validation or inference, this should be the EMA model.
+         """
+         # Construct the noise schedule t_hat for inference on the appropriate device
+         noise_schedule = self._construct_inference_noise_schedule(
+             device=S_inputs_I.device
+         )
+
+         # Infer the number of atoms from any atom-level feature
+         L = f["ref_element"].shape[0]
+         D = diffusion_batch_size
+
+         # The initial X_L is drawn from a normal distribution with a mean vector of 0 and a
+         # covariance matrix equal to the 3x3 identity matrix, scaled by the noise schedule
+         X_L = self._get_initial_structure(
+             c0=noise_schedule[0],
+             D=D,
+             L=L,
+             coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised,
+         )  # (D, L, 3)
+
+         X_noisy_L_traj = []
+         X_denoised_L_traj = []
+         t_hats = []
+
+         for c_t_minus_1, c_t in zip(noise_schedule, noise_schedule[1:]):
+             # All predicted atoms exist
+             X_exists_L = torch.ones((D, L), dtype=torch.bool, device=X_L.device)  # (D, L)
+
+             # Apply a random rotation and translation to the structure
+             # TODO: Make s_trans a hyperparameter
+             s_trans = 1.0
+             X_L = centre_random_augmentation(X_L, X_exists_L, s_trans)
+
+             # Update gamma
+             gamma = self.gamma_0 if c_t > self.gamma_min else 0
+
+             # Compute the value of t_hat
+             t_hat = c_t_minus_1 * (gamma + 1)
+
+             # Noise the coordinates with scaled Gaussian noise
+             epsilon_L = (
+                 self.noise_scale
+                 * torch.sqrt(torch.square(t_hat) - torch.square(c_t_minus_1))
+                 * torch.normal(mean=0.0, std=1.0, size=X_L.shape, device=X_L.device)
+             )
+             X_noisy_L = X_L + epsilon_L
+
+             # Denoise the coordinates
+             X_denoised_L = diffusion_module(
+                 X_noisy_L=X_noisy_L,
+                 t=t_hat.tile(D),
+                 f=f,
+                 S_inputs_I=S_inputs_I,
+                 S_trunk_I=S_trunk_I,
+                 Z_trunk_II=Z_trunk_II,
+             )
+
+             # Compute the delta between the noisy and denoised coordinates, scaled by t_hat
+             delta_L = (X_noisy_L - X_denoised_L) / t_hat
+             d_t = c_t - t_hat
+
+             # Update the coordinates, scaled by the step size
+             X_L = X_noisy_L + self.step_scale * d_t * delta_L
+
+             X_noisy_L_scaled = (
+                 X_noisy_L
+                 / torch.sqrt(t_hat[..., None, None] ** 2 + self.sigma_data**2)
+             ) * self.sigma_data
+             # Append the results to the trajectory (for visualization of the diffusion process)
+             X_noisy_L_traj.append(X_noisy_L_scaled)
+             X_denoised_L_traj.append(X_denoised_L)
+             t_hats.append(t_hat)
+
+         return dict(
+             X_L=X_L,  # (D, L, 3)
+             X_noisy_L_traj=X_noisy_L_traj,  # list[Tensor[D, L, 3]]
+             X_denoised_L_traj=X_denoised_L_traj,  # list[Tensor[D, L, 3]]
+             t_hats=t_hats,  # list[Tensor[D]], where D is shared across all diffusion batches
+         )
+
+
+ class SamplePartialDiffusion(SampleDiffusion):
+     def __init__(self, partial_t: int, **kwargs):
+         super().__init__(**kwargs)
+         self.partial_t = partial_t
+
+     def _construct_inference_noise_schedule(self, device: torch.device) -> torch.Tensor:
+         """Constructs a noise schedule for use during inference with partial t."""
+         t_hat_full = super()._construct_inference_noise_schedule(device)
+
+         assert self.partial_t < self.num_timesteps, (
+             f"Partial t ({self.partial_t}) must be less than num_timesteps ({self.num_timesteps})"
+         )
+         ranked_logger.info(
+             f"Using partial t index {self.partial_t} (t_hat = {t_hat_full[self.partial_t]:.4}), "
+             f"i.e. {self.partial_t / self.num_timesteps:.2%} of the schedule (100% is data, 0% is noise)"
+         )
+
+         return t_hat_full[self.partial_t :]
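
The noise schedule above is easy to sanity-check in isolation. The following sketch is not part of the package; it rebuilds `t_hat` with the AF-3 hyperparameters quoted in the constructor comments and shows what the `SamplePartialDiffusion` slice amounts to:

```python
import torch

# AF-3 hyperparameters, as quoted in the constructor comments above
sigma_data, s_min, s_max, p = 16.0, 4e-4, 160.0, 7
num_timesteps, min_t, max_t = 200, 0.0, 1.0

# Same formula as SampleDiffusion._construct_inference_noise_schedule
t = torch.linspace(min_t, max_t, num_timesteps)
t_hat = sigma_data * (s_max ** (1 / p) + t * (s_min ** (1 / p) - s_max ** (1 / p))) ** p

# The schedule decays monotonically from pure noise towards the data:
#   t_hat[0]  = sigma_data * s_max ≈ 2560
#   t_hat[-1] = sigma_data * s_min ≈ 0.0064
assert t_hat[0] > t_hat[-1]

# SamplePartialDiffusion(partial_t=150, ...) simply starts the roll-out further
# down this same schedule, i.e. it returns t_hat[150:].
partial_schedule = t_hat[150:]
```

Because the roll-out iterates over consecutive pairs of this schedule, `num_timesteps` schedule entries yield `num_timesteps - 1` denoising steps.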
rf3/inference.py ADDED
@@ -0,0 +1,65 @@
+ #!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"'
+
+ import os
+
+ import hydra
+ import rootutils
+ from dotenv import load_dotenv
+ from hydra.utils import instantiate
+ from omegaconf import DictConfig, OmegaConf
+
+ from foundry.utils.logging import suppress_warnings
+
+ # Set up the root dir and environment variables (more info: https://github.com/ashleve/rootutils)
+ # NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
+ rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+ load_dotenv(override=True)
+
+ _config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs")
+
+
+ @hydra.main(
+     config_path=_config_path,
+     config_name="inference",
+     version_base="1.3",
+ )
+ def run_inference(cfg: DictConfig) -> None:
+     """Execute RF3 inference pipeline."""
+
+     # Extract run() parameters from the config.
+     # Preserve string inputs; convert other sequence-like inputs to a Python list (None -> [])
+     inputs_param = cfg.inputs if isinstance(cfg.inputs, str) else list(cfg.inputs or [])
+
+     run_params = {
+         "inputs": inputs_param,
+         "out_dir": str(cfg.out_dir) if cfg.get("out_dir") else None,
+         "dump_predictions": cfg.get("dump_predictions", True),
+         "dump_trajectories": cfg.get("dump_trajectories", False),
+         "one_model_per_file": cfg.get("one_model_per_file", False),
+         "annotate_b_factor_with_plddt": cfg.get("annotate_b_factor_with_plddt", False),
+         "sharding_pattern": cfg.get("sharding_pattern", None),
+         "skip_existing": cfg.get("skip_existing", False),
+         "template_selection": cfg.get("template_selection", None),
+         "ground_truth_conformer_selection": cfg.get(
+             "ground_truth_conformer_selection", None
+         ),
+         "cyclic_chains": cfg.get("cyclic_chains", []),
+     }
+
+     # Create an init config containing only __init__ params
+     cfg_dict = OmegaConf.to_container(cfg, resolve=True)
+     run_param_keys = set(run_params.keys())
+     init_cfg_dict = {k: v for k, v in cfg_dict.items() if k not in run_param_keys}
+     init_cfg = OmegaConf.create(init_cfg_dict)
+
+     # Instantiate the engine (only __init__ params)
+     inference_engine = instantiate(init_cfg, _convert_="partial", _recursive_=False)
+
+     # Run inference
+     with suppress_warnings(is_inference=True):
+         inference_engine.run(**run_params)
+
+
+ if __name__ == "__main__":
+     run_inference()
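
Splitting the config this way keeps Hydra's `instantiate` from passing `run()`-time keys to the engine constructor. Below is a minimal, self-contained sketch of the same pattern, with hypothetical config keys standing in for the real engine signature:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "rf3.inference_engines.RF3InferenceEngine",
        "some_init_option": 3,        # hypothetical __init__ parameter
        "inputs": ["example.fasta"],  # run() parameter
        "out_dir": "out",             # run() parameter
    }
)

run_param_keys = {"inputs", "out_dir"}
cfg_dict = OmegaConf.to_container(cfg, resolve=True)

run_params = {k: cfg_dict[k] for k in run_param_keys}
init_cfg = OmegaConf.create(
    {k: v for k, v in cfg_dict.items() if k not in run_param_keys}
)

# hydra.utils.instantiate(init_cfg, ...) now sees only constructor kwargs;
# the remaining keys go to inference_engine.run(**run_params).
```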
rf3/inference_engines/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """RF3 inference engines."""
+
+ from rf3.inference_engines.rf3 import RF3InferenceEngine
+
+ __all__ = ["RF3InferenceEngine"]