rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/metrics/metric_utils.py ADDED
@@ -0,0 +1,192 @@
+ from itertools import combinations
+ from typing import Union
+
+ import numpy as np
+ import torch
+ from jaxtyping import Bool, Float
+ from numpy.typing import NDArray
+
+
+ def find_bin_midpoints(
+     max_distance: float, num_bins: int, device: Union[str, torch.device] = "cpu"
+ ) -> Float[torch.Tensor, "num_bins"]:
+     """
+     Find the bin midpoints for a given binning scheme. Used to take expectations when converting binned
+     predictions to unbinned predictions. Assumes the minimum of the scheme is 0.
+     Args:
+         max_distance: float, maximum distance
+         num_bins: int, number of bins
+         device: device to run on
+     Returns:
+         midpoints: [num_bins], bin midpoints
+     """
+     bin_size = max_distance / num_bins
+     bins = torch.linspace(
+         bin_size, max_distance - bin_size, num_bins - 1, device=device
+     )
+     midpoints = (bins[1:] + bins[:-1]) / 2
+     midpoints = torch.cat(
+         [(bins[0] - bin_size / 2)[None], midpoints, bins[-1:] + bin_size / 2]
+     )
+
+     return midpoints
+
+
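A quick sanity check (not part of the package): for a scheme with `max_distance=32` and `num_bins=64`, the returned midpoints should coincide with the centers of 64 uniform bins on [0, 32]. A minimal sketch, assuming `find_bin_midpoints` is in scope as defined above:

```python
import torch

max_distance, num_bins = 32.0, 64
bin_size = max_distance / num_bins

# Centers of num_bins uniform bins on [0, max_distance]
expected = torch.arange(num_bins) * bin_size + bin_size / 2

midpoints = find_bin_midpoints(max_distance, num_bins)
assert torch.allclose(midpoints, expected, atol=1e-5)
```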
+ def unbin_logits(
+     logits: Float[torch.Tensor, "B num_bins L X"], max_distance: float, num_bins: int
+ ) -> Float[torch.Tensor, "B L X"]:
+     """
+     Unbin the logits by taking the expectation over the bin midpoints
+     Args:
+         logits: [B, num_bins, L, X], binned logits where X is 23 for plddt and L for pae and pde
+         max_distance: float, maximum distance
+         num_bins: int, number of bins
+     Returns:
+         unbinned: [B, L, X], unbinned matrix
+     """
+     midpoints = find_bin_midpoints(max_distance, num_bins, device=logits.device)
+     probabilities = torch.softmax(logits, dim=1).detach().float()
+     unbinned = (probabilities * midpoints[None, :, None, None]).sum(dim=1)
+     return unbinned
+
+
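As an illustrative check (again assuming the functions above are in scope), uniform logits over the bin dimension should unbin to the mean of the bin midpoints, i.e. `max_distance / 2` for this scheme:

```python
import torch

B, L, X, num_bins, max_distance = 2, 5, 5, 64, 32.0
logits = torch.zeros(B, num_bins, L, X)  # uniform over bins after softmax

out = unbin_logits(logits, max_distance, num_bins)  # (B, L, X)
expected = find_bin_midpoints(max_distance, num_bins).mean()  # 16.0 here
assert torch.allclose(out, torch.full_like(out, expected.item()))
```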
+ def create_chainwise_masks_1d(
+     ch_label: NDArray[np.str_], device: Union[str, torch.device] = "cpu"
+ ) -> dict[str, Bool[torch.Tensor, "L"]]:
+     """
+     Create 1D chainwise masks for a set of chain labels
+     Args:
+         ch_label: np.ndarray [L], chain labels
+         device: torch.device, device to run on
+     Returns:
+         ch_masks: dict mapping each chain letter to the elements to score for that chain
+     """
+     unique_chains = np.unique(ch_label)
+     ch_masks = {}
+     for chain in unique_chains:
+         indices = torch.from_numpy(ch_label == chain).to(
+             dtype=torch.bool, device=device
+         )
+         ch_masks[chain] = indices
+     return ch_masks
+
+
+ def create_chainwise_masks_2d(
+     ch_label: NDArray[np.str_], device: Union[str, torch.device] = "cpu"
+ ) -> dict[str, Bool[torch.Tensor, "L L"]]:
+     """
+     Create 2D chainwise masks for a set of chain labels
+     Args:
+         ch_label: np.ndarray [L], chain labels
+         device: torch.device, device to run on
+     Returns:
+         ch_masks: dict mapping each chain letter to the pairwise elements to score for that chain
+     """
+     unique_chains = np.unique(ch_label)
+     ch_masks = {}
+     for chain in unique_chains:
+         indices = torch.from_numpy(ch_label == chain)
+         mask = torch.outer(indices, indices).to(dtype=torch.bool, device=device)
+         ch_masks[chain] = mask
+     return ch_masks
+
+
+ def create_interface_masks_2d(
+     ch_label: NDArray[np.str_], device: Union[str, torch.device] = "cpu"
+ ) -> dict[tuple[str, str], Bool[torch.Tensor, "L L"]]:
+     """
+     Create interface masks for a set of chain labels
+     Args:
+         ch_label: np.ndarray [L], chain labels
+         device: torch.device, device to run on
+     Returns:
+         pairs_to_score: dict mapping chain pairs to boolean masks
+     """
+     unique_chains = np.unique(ch_label)
+     pairs_to_score = {}
+     for chain_i, chain_j in combinations(unique_chains, 2):
+         chain_i_indices = torch.from_numpy(ch_label == chain_i)
+         chain_j_indices = torch.from_numpy(ch_label == chain_j)
+         to_be_scored = torch.outer(chain_i_indices, chain_j_indices).to(
+             dtype=torch.bool, device=device
+         ) | torch.outer(chain_j_indices, chain_i_indices).to(
+             dtype=torch.bool, device=device
+         )
+         pairs_to_score[(chain_i, chain_j)] = to_be_scored
+     return pairs_to_score
+
+
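A toy example of the three mask helpers above (a sketch, not package code): with two chains of lengths 2 and 3, the 1D masks select each chain, the 2D masks select the intra-chain blocks, and the interface mask selects both off-diagonal blocks:

```python
import numpy as np

ch_label = np.array(["A", "A", "B", "B", "B"])

masks_1d = create_chainwise_masks_1d(ch_label)    # {"A": (5,), "B": (5,)}
masks_2d = create_chainwise_masks_2d(ch_label)    # {"A": (5, 5), "B": (5, 5)}
interfaces = create_interface_masks_2d(ch_label)  # {("A", "B"): (5, 5)}

assert masks_1d["A"].sum() == 2
assert masks_2d["B"].sum() == 9            # 3 x 3 intra-chain block
assert interfaces[("A", "B")].sum() == 12  # two 2 x 3 off-diagonal blocks
```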
+ def compute_mean_over_subsampled_pairs(
+     matrix_to_mean: Float[torch.Tensor, "B L M"],
+     pairs_to_score: Bool[torch.Tensor, "L M"],
+     eps: float = 1e-6,
+ ) -> Float[torch.Tensor, "B"]:
+     """
+     Compute the mean over a subsample of pairs in a 2D matrix. Returns a tensor with one element per batch.
+     Args:
+         matrix_to_mean: tensor of shape (batch, L, M)
+         pairs_to_score: 2D tensor of shape (L, M) with 1s where pairs should be scored and 0s elsewhere
+         eps: small epsilon value to avoid division by zero
+     Returns:
+         1D tensor of shape (batch,) with the mean over the subsampled pairs for each batch
+     """
+     B, L, M = matrix_to_mean.shape
+     assert pairs_to_score.shape == (L, M), "Pairs to score should be of shape (L, M)"
+     batch = (matrix_to_mean * pairs_to_score).sum(dim=(-1, -2)) / (
+         pairs_to_score.sum() + eps
+     )
+     assert batch.shape == (B,), "Batch should be of shape (batch,)"
+     return batch
+
+
+ def compute_min_over_subsampled_pairs(
+     matrix_to_min: Float[torch.Tensor, "B L M"],
+     pairs_to_score: Bool[torch.Tensor, "L M"],
+ ) -> Float[torch.Tensor, "B"]:
+     """
+     Compute the min over a subsample of pairs in a 2D matrix. Returns a tensor with one element per batch.
+     Args:
+         matrix_to_min: tensor of shape (batch, L, M)
+         pairs_to_score: 2D tensor of shape (L, M) with 1s where pairs should be scored and 0s elsewhere
+     Returns:
+         1D tensor of shape (batch,) with the min over the subsampled pairs for each batch
+     """
+     B, L, M = matrix_to_min.shape
+     assert pairs_to_score.shape == (L, M), "Pairs to score should be of shape (L, M)"
+     # Use torch.where to mask without cloning the entire matrix;
+     # pairs_to_score broadcasts across the batch dimension
+     masked_matrix = torch.where(
+         pairs_to_score.bool(),  # condition (L, M) -> broadcasts to (B, L, M)
+         matrix_to_min,  # if True: use original values (B, L, M)
+         torch.tensor(
+             float("inf"), device=matrix_to_min.device, dtype=matrix_to_min.dtype
+         ),  # if False: use inf
+     )
+
+     # Flatten the last two dimensions and compute the min across them
+     batch = masked_matrix.view(B, -1).min(dim=-1)[0]
+
+     assert batch.shape == (B,), "Batch should be of shape (batch,)"
+     return batch
+
+
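A small worked example of the two reductions above (a sketch assuming the functions are in scope): scoring two pairs of a (1, 3, 4) matrix, the masked mean is (1 + 11) / 2 and the masked min is 1:

```python
import torch

matrix = torch.arange(12.0).reshape(1, 3, 4)  # one batch element
pairs = torch.zeros(3, 4, dtype=torch.bool)
pairs[0, 1] = pairs[2, 3] = True              # selected values: 1.0 and 11.0

mean = compute_mean_over_subsampled_pairs(matrix, pairs)
mins = compute_min_over_subsampled_pairs(matrix, pairs)
assert torch.allclose(mean, torch.tensor([6.0]), atol=1e-3)
assert torch.equal(mins, torch.tensor([1.0]))
```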
+ def spread_batch_into_dictionary(batch: Float[torch.Tensor, "B"]) -> dict[int, float]:
+     """
+     Given a batch of data, create a dictionary keyed by batch index with the corresponding data as values
+     Args:
+         batch: 1D tensor of shape (B,)
+     Returns:
+         Dictionary mapping batch indices to float values
+     """
+     assert len(batch.shape) == 1, f"Batch should be a 1D tensor, got {batch}"
+     return {i: data.item() for i, data in enumerate(batch)}
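For completeness, a one-line usage sketch of the helper above:

```python
import torch

assert spread_batch_into_dictionary(torch.tensor([0.25, 0.75])) == {0: 0.25, 1: 0.75}
```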
rf3/metrics/predicted_error.py ADDED
@@ -0,0 +1,134 @@
+ from typing import Any
+
+ import torch
+ from rf3.metrics.metric_utils import find_bin_midpoints
+
+ from foundry.metrics.metric import Metric
+
+
+ def compute_ptm(
+     pae: torch.Tensor,
+     to_calculate: torch.Tensor | None,
+     max_distance: float = 32,
+     bin_count: int = 64,
+ ):
+     """Compute the predicted TM-score (PTM) from the predicted aligned error (PAE) logits.
+
+     Args:
+         pae: Predicted aligned error logits of shape (D, I, I, bin_count).
+         to_calculate: Boolean (I, I) tensor indicating which residue pairs to include;
+             if None, all pairs are included.
+         max_distance: Maximum distance of the binning scheme.
+         bin_count: Number of bins.
+
+     Returns:
+         ptm: Computed predicted TM-score of shape (D,).
+     """
+     D, I = pae.shape[0], pae.shape[1]
+     if to_calculate is None:
+         to_calculate = torch.ones((I, I), dtype=torch.bool, device=pae.device)
+
+     bin_centers = find_bin_midpoints(
+         max_distance, bin_count, device=pae.device
+     )  # TODO: get this from config
+     pae = torch.softmax(pae, dim=-1).detach().float()
+     normalization_factor = 1.24 * (max(I, 19) - 15.0) ** (1 / 3) - 1.8
+     denominator = 1 / (1 + (bin_centers / normalization_factor) ** 2)
+     pae = pae * denominator[None, None, None, :]  # Broadcast to match pae shape
+
+     pae = pae.sum(dim=-1)  # Expectation over the bin dimension
+     pae = (pae * to_calculate[None]).sum(dim=-1) / (to_calculate.sum(dim=-1) + 1e-6)
+     ptm = pae.max(dim=-1).values
+     assert ptm.shape == (D,)
+     return ptm
+
+
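A hedged usage sketch of `compute_ptm` (not package code): with uniform PAE logits the exact value is uninteresting, but the output should be one score per model in the stack, bounded in [0, 1]:

```python
import torch

D, I, num_bins = 2, 30, 64
pae_logits = torch.zeros(D, I, I, num_bins)  # uniform over bins after softmax

ptm = compute_ptm(pae_logits, to_calculate=None)
assert ptm.shape == (D,)
assert ((ptm >= 0) & (ptm <= 1)).all()
```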
+ class ComputePTM(Metric):
+     @property
+     def kwargs_to_compute_args(self) -> dict[str, Any]:
+         return {
+             "pae": ("network_output", "pae"),
+             "asym_id": ("network_input", "f", "asym_id"),
+         }
+
+     def compute(
+         self,
+         pae: torch.Tensor,
+         asym_id: torch.Tensor,
+     ) -> dict[str, float]:
+         """Compute the predicted TM-score (PTM) from the predicted aligned error (PAE).
+         Args:
+             pae: Predicted aligned error tensor.
+             asym_id: Per-token chain identifiers (not used for the global PTM).
+         Returns:
+             ptm: Dictionary of computed predicted TM-scores, one key per batch element.
+         """
+         ptm = compute_ptm(pae, None)
+         # split the batch dimension into separate keys in the output dictionary
+         ptm = ptm.cpu().numpy()
+         ptm = {f"ptm_{i}": ptm[i] for i in range(len(ptm))}
+         return ptm
+
+
+ class ComputeIPTM(Metric):
+     @property
+     def kwargs_to_compute_args(self) -> dict[str, Any]:
+         return {
+             "pae": ("network_output", "pae"),
+             "asym_id": ("network_input", "f", "asym_id"),
+             "is_ligand": ("network_input", "f", "is_ligand"),
+         }
+
+     def compute(
+         self,
+         pae: torch.Tensor,
+         asym_id: torch.Tensor,
+         is_ligand: torch.Tensor,
+     ) -> dict[str, float]:
+         """Compute the predicted interface TM-score (iPTM) from the predicted aligned error (PAE).
+         Args:
+             pae: Predicted aligned error tensor.
+             asym_id: Per-token chain identifiers.
+             is_ligand: Per-token flag that is 1 for ligand tokens and 0 otherwise.
+         Returns:
+             iptm: Dictionary of interface TM-scores, keyed per batch element and interface type.
+         """
+         # score only inter-chain pairs
+         to_calculate = asym_id[None, :] != asym_id[:, None]
+         iptm = compute_ptm(pae, to_calculate)
+
+         # make protein and ligand masks
+         protein_mask = is_ligand == 0
+         ligand_mask = is_ligand == 1
+         # build masks for protein-protein, protein-ligand, and ligand-ligand interfaces
+         protein_protein_mask = (
+             protein_mask[None, :] & protein_mask[:, None]
+         ) & to_calculate
+         protein_ligand_mask = (
+             (protein_mask[None, :] & ligand_mask[:, None])
+             | (ligand_mask[None, :] & protein_mask[:, None])
+         ) & to_calculate
+         ligand_ligand_mask = (
+             ligand_mask[None, :] & ligand_mask[:, None]
+         ) & to_calculate
+         # calculate iptm for each interface type
+         iptm_protein_protein = compute_ptm(pae, protein_protein_mask)
+         iptm_protein_ligand = compute_ptm(pae, protein_ligand_mask)
+         iptm_ligand_ligand = compute_ptm(pae, ligand_ligand_mask)
+
+         # split the batch dimension into separate keys in the output dictionary
+         iptm = iptm.cpu().numpy()
+         iptm = {f"iptm_{i}": iptm[i] for i in range(len(iptm))}
+         iptm_protein_protein = iptm_protein_protein.cpu().numpy()
+         iptm_protein_protein = {
+             f"iptm_protein_protein_{i}": iptm_protein_protein[i]
+             for i in range(len(iptm_protein_protein))
+         }
+         iptm_protein_ligand = iptm_protein_ligand.cpu().numpy()
+         iptm_protein_ligand = {
+             f"iptm_protein_ligand_{i}": iptm_protein_ligand[i]
+             for i in range(len(iptm_protein_ligand))
+         }
+         iptm_ligand_ligand = iptm_ligand_ligand.cpu().numpy()
+         iptm_ligand_ligand = {
+             f"iptm_ligand_ligand_{i}": iptm_ligand_ligand[i]
+             for i in range(len(iptm_ligand_ligand))
+         }
+         iptm.update(iptm_protein_protein)
+         iptm.update(iptm_protein_ligand)
+         iptm.update(iptm_ligand_ligand)
+         return iptm
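To make the interface-mask construction above concrete, here is a standalone sketch (toy tensors, not package code) that builds the protein-ligand mask for a 10-token protein chain bound to a 10-token ligand chain and feeds it to `compute_ptm`:

```python
import torch

I, num_bins = 20, 64
asym_id = torch.tensor([0] * 10 + [1] * 10)
is_ligand = torch.tensor([0] * 10 + [1] * 10)

to_calculate = asym_id[None, :] != asym_id[:, None]  # inter-chain pairs only
protein_mask = is_ligand == 0
ligand_mask = is_ligand == 1
protein_ligand_mask = (
    (protein_mask[None, :] & ligand_mask[:, None])
    | (ligand_mask[None, :] & protein_mask[:, None])
) & to_calculate

pae_logits = torch.zeros(1, I, I, num_bins)
iptm_pl = compute_ptm(pae_logits, protein_ligand_mask)
assert iptm_pl.shape == (1,)
```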
rf3/metrics/rasa.py ADDED
@@ -0,0 +1,108 @@
+ import numpy as np
+ from atomworks.ml.transforms.sasa import calculate_atomwise_rasa
+ from beartype.typing import Any
+ from biotite.structure import AtomArrayStack
+
+ from foundry.metrics.metric import Metric
+
+
+ class UnresolvedRegionRASA(Metric):
+     """
+     This metric computes the RASA score for unresolved regions in a protein structure.
+     The RASA score is defined as the ratio of the solvent-accessible surface area (SASA)
+     of a residue in a protein structure to the SASA of the same residue in an extended conformation.
+     """
+
+     def __init__(
+         self,
+         probe_radius: float = 1.4,
+         atom_radii: str | np.ndarray = "ProtOr",
+         point_number: int = 100,
+         include_resolved: bool = False,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.probe_radius = probe_radius
+         self.atom_radii = atom_radii
+         self.point_number = point_number
+         self.include_resolved = include_resolved
+
+     @property
+     def kwargs_to_compute_args(self) -> dict[str, Any]:
+         return {
+             "predicted_atom_array_stack": ("predicted_atom_array_stack",),
+             "ground_truth_atom_array_stack": ("ground_truth_atom_array_stack",),
+         }
+
+     def compute(
+         self,
+         predicted_atom_array_stack: AtomArrayStack,
+         ground_truth_atom_array_stack: AtomArrayStack,
+     ) -> dict[str, Any]:
+         """Compute the RASA score for unresolved regions in a protein structure.
+
+         The probe radius, atom radii set, point number, and include_resolved flag
+         are configured in ``__init__``.
+
+         Args:
+             predicted_atom_array_stack (AtomArrayStack): Atom array stack of the predicted protein structure.
+             ground_truth_atom_array_stack (AtomArrayStack): Atom array stack of the ground truth protein structure.
+
+         Returns:
+             dict: A dictionary containing the RASA scores and other relevant information.
+         """
+
+         # find unresolved regions
+         # (polymer atoms with occupancy 0.0)
+         atoms_to_score_unresolved = ground_truth_atom_array_stack.is_polymer & (
+             ground_truth_atom_array_stack.occupancy == 0.0
+         )
+
+         # find resolved regions (polymer atoms with occupancy > 0.0)
+         atoms_to_score_resolved = ground_truth_atom_array_stack.is_polymer & (
+             ground_truth_atom_array_stack.occupancy > 0.0
+         )
+
+         unresolved_rasas = []
+         resolved_rasas = []
+
+         # Calculate RASA for each model in the stack
+         for atom_array in predicted_atom_array_stack:
+             try:
+                 rasa = calculate_atomwise_rasa(
+                     atom_array=atom_array,
+                     probe_radius=self.probe_radius,
+                     atom_radii=self.atom_radii,
+                     point_number=self.point_number,
+                 )
+                 unresolved_rasas.append(rasa[atoms_to_score_unresolved].mean())
+                 if self.include_resolved:
+                     resolved_rasas.append(rasa[atoms_to_score_resolved].mean())
+             except KeyError:
+                 unresolved_rasas.append(np.nan)
+                 if self.include_resolved:
+                     resolved_rasas.append(np.nan)
+
+         # Calculate the mean RASA scores.
+         # Pattern-match other metrics by appending "_i" to the metric name to represent multiple batches
+         # (e.g., "unresolved_polymer_rasa_0", "unresolved_polymer_rasa_1", etc.)
+         unresolved_rasa = np.nanmean(unresolved_rasas)
+         output_dictionary = {
+             f"unresolved_polymer_rasa_{i}": rasa
+             for i, rasa in enumerate(unresolved_rasas)
+         }
+         output_dictionary["mean_unresolved_polymer_rasa"] = unresolved_rasa
+
+         # ... and add resolved region RASA scores if the flag is enabled
+         if self.include_resolved:
+             resolved_rasa = np.nanmean(resolved_rasas)
+             output_dictionary.update(
+                 {
+                     f"resolved_polymer_rasa_{i}": rasa
+                     for i, rasa in enumerate(resolved_rasas)
+                 }
+             )
+             output_dictionary["mean_resolved_polymer_rasa"] = resolved_rasa
+
+         return output_dictionary
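The occupancy-based masking above reduces to simple boolean annotation algebra; a standalone sketch with hypothetical per-atom values (not package data):

```python
import numpy as np

# Hypothetical stand-ins for the per-atom annotations used above
is_polymer = np.array([True, True, True, False])
occupancy = np.array([0.0, 1.0, 0.0, 1.0])
rasa = np.array([0.8, 0.2, 0.6, 0.5])  # per-atom RASA from one model

unresolved_mask = is_polymer & (occupancy == 0.0)  # atoms 0 and 2
print(rasa[unresolved_mask].mean())                # (0.8 + 0.6) / 2 = 0.7
```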
rf3/metrics/selected_distances.py ADDED
@@ -0,0 +1,91 @@
+ import numpy as np
+ from atomworks.ml.utils import nested_dict
+ from atomworks.ml.utils.selection import (
+     get_mask_from_atom_selection,
+     parse_selection_string,
+ )
+ from beartype.typing import Any
+ from biotite.structure import AtomArrayStack
+
+ from foundry.metrics.metric import Metric
+
+
+ class SelectedAtomByAtomDistances(Metric):
+     """Computes all-by-all 2D distances given a list of selection strings"""
+
+     def compute_from_kwargs(self, **kwargs: Any) -> dict[str, Any]:
+         """Override the parent class to handle the optional selection_strings parameter"""
+         compute_inputs = {
+             "atom_array_stack": nested_dict.getitem(
+                 kwargs, key="predicted_atom_array_stack"
+             )
+         }
+
+         # Add selection_strings only if it exists
+         try:
+             compute_inputs["selection_strings"] = nested_dict.getitem(
+                 kwargs, key=("extra_info", "selection_strings")
+             )
+         except (KeyError, IndexError, TypeError):
+             pass
+
+         return self.compute(**compute_inputs)
+
+     def compute(
+         self,
+         atom_array_stack: AtomArrayStack,
+         selection_strings: list[str] | None = None,
+     ) -> dict[str, Any]:
+         # Short-circuit if no selection strings are provided
+         if not selection_strings:
+             return {}
+
+         # ... select the specified atoms
+         mask = np.zeros(atom_array_stack.array_length(), dtype=bool)
+         atom_selections = [parse_selection_string(s) for s in selection_strings]
+         for atom_selection in atom_selections:
+             mask |= get_mask_from_atom_selection(atom_array_stack, atom_selection)
+         selected_atom_array_stack = atom_array_stack[:, mask]
+
+         # Create views with added dimensions for broadcasting;
+         # coord is (D, L, 3) and we want pairwise distances for each D
+         coord_i = selected_atom_array_stack.coord[:, :, np.newaxis, :]  # (D, L, 1, 3)
+         coord_j = selected_atom_array_stack.coord[:, np.newaxis, :, :]  # (D, 1, L, 3)
+
+         # Calculate pairwise differences and distances
+         differences = coord_i - coord_j  # broadcasts to (D, L, L, 3)
+         distances = np.linalg.norm(differences, axis=-1)  # (D, L, L)
+
+         # Compute the mean and standard deviation across the D dimension
+         mean_distances = np.mean(distances, axis=0)  # Shape: (L, L)
+         std_distances = np.std(distances, axis=0)  # Shape: (L, L)
+
+         # Name the features with the chain_id, res_name, res_id, atom_name
+         def _format_atom_id(chain_id, res_name, res_id, atom_name):
+             return f"{chain_id}/{res_name}/{res_id}/{atom_name}"
+
+         vectorized_format = np.vectorize(_format_atom_id)
+         atom_ids = vectorized_format(
+             selected_atom_array_stack.chain_id,
+             selected_atom_array_stack.res_name,
+             selected_atom_array_stack.res_id,
+             selected_atom_array_stack.atom_name,
+         )
+
+         # Create an L x L array of pair names by concatenating the atom ids ...
+         ids_with_sep = np.char.add(atom_ids, "-")
+         pair_ids = np.char.add(ids_with_sep[:, np.newaxis], atom_ids[np.newaxis, :])
+
+         # ... and store the results in a dictionary, naming the columns with the concatenated ids
+         results = {}
+         for i in range(len(atom_ids)):
+             for j in range(
+                 i + 1, len(atom_ids)
+             ):  # Only consider j > i to avoid symmetric duplicates
+                 col_id = pair_ids[i, j]
+                 mean = mean_distances[i, j]
+                 std = std_distances[i, j]
+                 results[f"{col_id}_mean"] = mean
+                 results[f"{col_id}_std"] = std
+
+         return results
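The broadcasting pattern used above generalizes to any (D, L, 3) coordinate stack; a minimal numpy sketch:

```python
import numpy as np

coord = np.random.rand(4, 6, 3)                     # (D, L, 3)
diff = coord[:, :, None, :] - coord[:, None, :, :]  # (D, L, L, 3)
dist = np.linalg.norm(diff, axis=-1)                # (D, L, L)

assert np.allclose(dist, dist.transpose(0, 2, 1))             # symmetric
assert np.allclose(np.diagonal(dist, axis1=1, axis2=2), 0.0)  # zero diagonal
```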