rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,67 @@ foundry/training/EMA.py
+from collections import OrderedDict
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+
+
+class EMA(nn.Module):
+    # TODO: Rename shadow to `ema_model` to better match convention
+    def __init__(self, model: nn.Module, decay: float):
+        """Initialize the Exponential Moving Average (EMA) module.
+
+        EMA maintains a shadow model that slowly tracks the weights of the original model.
+
+        Args:
+            model: The original model.
+            decay: The decay rate of the EMA. The shadow model is updated with the formula:
+                shadow_variable -= (1 - decay) * (shadow_variable - variable)
+        """
+        super().__init__()
+        self.decay = decay
+
+        self.model = model
+        self.shadow = deepcopy(self.model)
+
+        # Detach the shadow model from the computation graph
+        for param in self.shadow.parameters():
+            param.detach_()
+
+    @torch.no_grad()
+    def update(self):
+        """Update the shadow model using the weights of the original model and the decay rate."""
+        if not self.training:
+            raise RuntimeError("EMA update should only be called during training")
+
+        # ... get the model and shadow parameters
+        model_params = OrderedDict(self.model.named_parameters())
+        shadow_params = OrderedDict(self.shadow.named_parameters())
+
+        # ... ensure that both models have the same set of keys
+        assert model_params.keys() == shadow_params.keys()
+
+        for name, param in model_params.items():
+            # Update the shadow model with the formula:
+            #   shadow_variable -= (1 - decay) * (shadow_variable - variable)
+            # Reference: https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+            if param.requires_grad:
+                shadow_params[name].sub_(
+                    (1.0 - self.decay) * (shadow_params[name] - param)
+                )
+
+        # ... and do the same with the buffers (e.g., objects that are part of the module state but not trainable parameters)
+        model_buffers = OrderedDict(self.model.named_buffers())
+        shadow_buffers = OrderedDict(self.shadow.named_buffers())
+
+        assert model_buffers.keys() == shadow_buffers.keys()
+
+        for name, buffer in model_buffers.items():
+            # ... copy the buffers from the model to the shadow
+            shadow_buffers[name].copy_(buffer)
+
+    def forward(self, *args, **kwargs):
+        """Dynamic dispatch to the correct model (model or shadow)."""
+        if self.training:
+            return self.model(*args, **kwargs)
+        else:
+            return self.shadow(*args, **kwargs)
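A minimal usage sketch of the wrapper above (the `nn.Linear` model and hyperparameters are illustrative, not from the package): `update()` is called once per optimizer step while the module is in training mode, and `eval()` switches `forward` over to the shadow copy::

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 4)
    ema = EMA(model, decay=0.999)
    optimizer = torch.optim.Adam(ema.model.parameters(), lr=1e-3)

    ema.train()
    for _ in range(10):
        loss = ema(torch.randn(8, 16)).pow(2).mean()  # dispatches to the live model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update()  # pull the shadow weights toward the live weights

    ema.eval()
    with torch.no_grad():
        pred = ema(torch.randn(8, 16))  # dispatches to the shadow (EMA) model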
@@ -0,0 +1,61 @@ foundry/training/checkpoint.py
+"""Utilities for gradient checkpointing.
+
+References:
+    * `PyTorch Checkpoint Documentation`_
+
+.. _PyTorch Checkpoint Documentation: https://pytorch.org/docs/stable/checkpoint.html
+"""
+
+import torch
+from torch.utils.checkpoint import checkpoint
+
+
+def create_custom_forward(module, **kwargs):
+    """Create a custom forward function for gradient checkpointing with fixed kwargs.
+
+    Enables passing keyword arguments to a module when using PyTorch's checkpoint function,
+    which only accepts positional arguments for the function to be checkpointed.
+
+    Args:
+        module: The callable (typically an nn.Module) to wrap.
+        **kwargs: Keyword arguments to pass to the module during forward.
+
+    Returns:
+        A callable that accepts only positional arguments and forwards them along
+        with the fixed kwargs to the original module.
+    """
+
+    def custom_forward(*inputs):
+        return module(*inputs, **kwargs)
+
+    return custom_forward
+
+
+def activation_checkpointing(function):
+    """Decorator to enable gradient checkpointing for a function during training.
+
+    Args:
+        function: The function to apply gradient checkpointing to.
+
+    Returns:
+        Wrapped function that conditionally applies checkpointing based on gradient state.
+
+    Examples:
+        Apply to a forward pass method::
+
+            @activation_checkpointing
+            def forward(self, x, mask=None):
+                return self.layer(x, mask)
+
+    Notes:
+        Uses ``use_reentrant=False`` for compatibility with recent PyTorch versions.
+    """
+
+    def wrapper(*args, **kwargs):
+        if torch.is_grad_enabled():
+            return checkpoint(
+                create_custom_forward(function, **kwargs), *args, use_reentrant=False
+            )
+        return function(*args, **kwargs)
+
+    return wrapper
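A short sketch of how the decorator composes inside a module (the `TinyBlock` class is hypothetical, for illustration only): with gradients enabled, the wrapped forward recomputes activations during backward instead of storing them, and keyword arguments reach the checkpointed call through `create_custom_forward`'s fixed kwargs; under `torch.no_grad()` it falls through to a plain call::

    import torch
    import torch.nn as nn

    class TinyBlock(nn.Module):
        def __init__(self, dim: int = 32):
            super().__init__()
            self.layer = nn.Linear(dim, dim)

        @activation_checkpointing
        def forward(self, x, scale: float = 1.0):
            # `scale` is threaded through create_custom_forward's fixed kwargs
            return self.layer(x) * scale

    block = TinyBlock()
    y = block(torch.randn(4, 32, requires_grad=True), scale=0.5)  # checkpointed path
    y.sum().backward()  # activations recomputed here rather than stored

    with torch.no_grad():
        _ = block(torch.randn(4, 32))  # plain call, no checkpointing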
@@ -0,0 +1,91 @@ foundry/training/schedulers.py
+from dataclasses import dataclass
+
+from torch.optim.lr_scheduler import LRScheduler, _LRScheduler
+from torch.optim.optimizer import Optimizer
+
+
+class AF3Scheduler(_LRScheduler):
+    """Implements a two-phase learning rate schedule à la AF-3:
+    1. The base learning rate is 1.8 · 10^−3, which is linearly increased from 0 over the first 1,000 steps.
+    2. The learning rate is then decreased by a factor of 0.95 every 50,000 steps.
+
+    From the AF-3 Supplement, Section 5.4:
+    > "For training we use the Adam optimizer with parameters β1 = 0.9, β2 = 0.95, ϵ = 10^−8. The base learning rate
+    is 1.8 · 10^−3, which is linearly increased from 0 over the first 1,000 steps. The learning rate is then decreased
+    by a factor of 0.95 every 5 · 10^4 steps."
+
+    References:
+        - AF-3 Supplement
+    """
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        base_lr: float = 1.8e-3,
+        warmup_steps: int = 1000,
+        decay_factor: float = 0.95,
+        decay_steps: int = 50000,
+        last_epoch: int = -1,
+    ) -> None:
+        """Initializes a new instance of AF3Scheduler.
+
+        Note that the "last_epoch" value is incremented every time `scheduler.step()`
+        is called; it is named "epoch" to follow the PyTorch convention.
+
+        Args:
+            optimizer (Optimizer): Wrapped optimizer.
+            base_lr (float): The base learning rate after warmup (which is then decayed).
+            warmup_steps (int): Number of steps for linear warmup.
+            decay_factor (float): Factor by which the learning rate is multiplied every `decay_steps`.
+            decay_steps (int): Number of steps between each decay.
+            last_epoch (int): The index of the last epoch. Default: -1.
+        """
+        self.base_lr = base_lr
+        self.warmup_steps = warmup_steps
+        self.decay_factor = decay_factor
+        self.decay_steps = decay_steps
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self) -> list[float]:
+        if self.last_epoch < self.warmup_steps:
+            # Linear warmup
+            return [
+                self.base_lr * (self.last_epoch / self.warmup_steps)
+                for _ in self.optimizer.param_groups
+            ]
+        else:
+            # Decay after warmup
+            num_decays = (self.last_epoch - self.warmup_steps) // self.decay_steps
+            return [
+                self.base_lr * (self.decay_factor**num_decays)
+                for _ in self.optimizer.param_groups
+            ]
+
+
+@dataclass
+class SchedulerConfig:
+    """Flexible configuration for a learning rate scheduler.
+
+    Modeled on the PyTorch Lightning scheduler configuration.
+
+    Attributes:
+        scheduler (LRScheduler): The learning rate scheduler instance. Must inherit from `torch.optim.lr_scheduler.LRScheduler`.
+        interval (str): The interval at which to apply the scheduler, typically "epoch" or "step". Defaults to "step".
+        frequency (int): How often to apply the scheduler within the chosen interval; for example, a frequency of 1 applies it at every interval. Defaults to 1.
+    """
+
+    scheduler: LRScheduler | None = None
+    interval: str = "step"
+    frequency: int = 1
+
+    def state_dict(self) -> dict:
+        return {
+            "scheduler": self.scheduler.state_dict(),
+            "interval": self.interval,
+            "frequency": self.frequency,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        self.scheduler.load_state_dict(state_dict["scheduler"])
+        self.interval = state_dict["interval"]
+        self.frequency = state_dict["frequency"]
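A minimal wiring sketch (the `nn.Linear` model is illustrative; the Adam β/ϵ values are the ones quoted from the AF-3 supplement): the scheduler is stepped once per optimizer step, matching `SchedulerConfig`'s default `interval="step"`::

    import torch
    import torch.nn as nn

    model = nn.Linear(64, 64)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=1.8e-3, betas=(0.9, 0.95), eps=1e-8
    )
    config = SchedulerConfig(
        scheduler=AF3Scheduler(optimizer), interval="step", frequency=1
    )

    for step in range(2000):
        optimizer.step()         # (loss computation and backward omitted)
        config.scheduler.step()  # advance the schedule once per training step

    # Warmup ramps linearly, then the post-warmup decay kicks in every 50,000 steps:
    #   last_epoch = 500    -> lr = 1.8e-3 * 500 / 1000 = 9.0e-4
    #   last_epoch = 60,000 -> lr = 1.8e-3 * 0.95**1   ≈ 1.71e-3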
@@ -0,0 +1,86 @@ foundry/utils/alignment.py
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def weighted_rigid_align(
+    X_L,  # [B, L, 3]
+    X_gt_L,  # [B, L, 3]
+    X_exists_L=None,  # [L]
+    w_L=None,  # [B, L]
+):
+    """Weighted rigid-body alignment of X_gt_L onto X_L with weights w_L.
+
+    Allows for a "moving target" ground truth that is SE(3)-invariant.
+    Follows Algorithm 28 in the AF3 paper.
+
+    Returns:
+        X_align_L: [B, L, 3]
+    """
+    assert X_L.shape == X_gt_L.shape
+
+    if X_exists_L is None:
+        X_exists_L = torch.ones(X_L.shape[-2], dtype=torch.bool, device=X_L.device)
+    if w_L is None:
+        w_L = torch.ones_like(X_L[..., 0])
+    else:
+        w_L = w_L.to(torch.float32)
+
+    # Validate the weight shape only after the default has been filled in
+    assert X_L.shape[:-1] == w_L.shape
+
+    # Assert `X_exists_L` is a boolean mask
+    assert (
+        X_exists_L.dtype == torch.bool
+    ), "X_exists_L should be a boolean mask! Otherwise, the alignment will be incorrect (silent failure)!"
+
+    X_resolved = X_L[:, X_exists_L]
+    X_gt_resolved = X_gt_L[:, X_exists_L]
+    w_resolved = w_L[:, X_exists_L]
+
+    # Weighted centroids of the resolved positions
+    u_X = torch.sum(X_resolved * w_resolved.unsqueeze(-1), dim=-2) / torch.sum(
+        w_resolved, dim=-1, keepdim=True
+    )
+    u_X_gt = torch.sum(X_gt_resolved * w_resolved.unsqueeze(-1), dim=-2) / torch.sum(
+        w_resolved, dim=-1, keepdim=True
+    )
+
+    X_resolved = X_resolved - u_X.unsqueeze(-2)
+    X_gt_resolved = X_gt_resolved - u_X_gt.unsqueeze(-2)
+
+    # Computation of the weighted covariance matrix
+    C = torch.einsum("bji,bjk->bik", w_resolved[..., None] * X_gt_resolved, X_resolved)
+
+    # Note: torch.linalg.svd returns (U, S, Vh), so `V` here is already transposed
+    U, S, V = torch.linalg.svd(C)
+
+    R = U @ V
+    B, _, _ = X_L.shape
+    F = torch.eye(3, 3, device=X_L.device)[None].tile((B, 1, 1))
+
+    # Flip the last singular direction if needed so R is a proper rotation (det R = +1)
+    F[..., -1, -1] = torch.sign(torch.linalg.det(R))
+    R = U @ F @ V
+
+    X_gt_L = X_gt_L - u_X_gt.unsqueeze(-2)
+    X_align_L = X_gt_L @ R + u_X.unsqueeze(-2)
+
+    return X_align_L.detach()
+
+
+def get_rmsd(xyz1, xyz2, eps=1e-4):
+    """RMSD between coordinate sets of shape [..., L, 3]; `eps` guards the sqrt at zero."""
+    L = xyz1.shape[-2]
+    rmsd = torch.sqrt(torch.sum((xyz2 - xyz1) * (xyz2 - xyz1), dim=(-1, -2)) / L + eps)
+    return rmsd
+
+
+def superimpose(xyz1, xyz2, mask, eps=1e-4):
+    """
+    Superimpose xyz1 onto xyz2 using mask
+    """
+    L = xyz1.shape[-2]
+    assert mask.shape == (L,)
+    assert xyz1.shape == xyz2.shape
+    assert mask.dtype == torch.bool
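A quick self-consistency sketch for `weighted_rigid_align` (the random-rotation construction is illustrative, not from the package): rigidly transforming a structure and aligning the transformed copy back should reproduce the original coordinates, so `get_rmsd` comes out near zero, up to the `eps` regularizer::

    import torch

    B, L = 2, 10
    X = torch.randn(B, L, 3)

    # Build a random proper rotation via QR
    Q, _ = torch.linalg.qr(torch.randn(3, 3))
    Q = Q * torch.sign(torch.linalg.det(Q))  # ensure det(Q) = +1

    X_gt = X @ Q.T + torch.tensor([1.0, -2.0, 0.5])  # rigidly moved copy of X
    w = torch.ones(B, L)

    X_aligned = weighted_rigid_align(X, X_gt, w_L=w)
    print(get_rmsd(X_aligned, X))  # ~0 (bounded below by sqrt(eps))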