rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
mpnn/trainers/mpnn.py ADDED
@@ -0,0 +1,193 @@
+import torch
+from beartype.typing import Any
+from lightning_utilities import apply_to_collection
+from mpnn.loss.nll_loss import LabelSmoothedNLLLoss
+from mpnn.metrics.nll import NLL, InterfaceNLL
+from mpnn.metrics.sequence_recovery import (
+    InterfaceSequenceRecovery,
+    SequenceRecovery,
+)
+from mpnn.model.mpnn import LigandMPNN, ProteinMPNN
+from omegaconf import DictConfig
+
+from foundry.metrics.metric import MetricManager
+from foundry.trainers.fabric import FabricTrainer
+from foundry.utils.ddp import RankedLogger
+from foundry.utils.torch import assert_no_nans
+
+ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+class MPNNTrainer(FabricTrainer):
+    """Standard trainer for MPNN-style models."""
+
+    def __init__(
+        self,
+        *,
+        model_type: str,
+        loss: DictConfig | dict | None = None,
+        metrics: DictConfig | dict | None = None,
+        **kwargs,
+    ):
+        """
+        See `FabricTrainer` for the additional initialization arguments.
+
+        Args:
+            model_type (str): Type of model to use ("protein_mpnn" or
+                "ligand_mpnn").
+            loss (DictConfig | dict | None): Configuration for the loss
+                function. If None, default parameters will be used.
+            metrics (DictConfig | dict | None): Configuration for the metrics.
+                Ignored - metrics are hard-coded.
+        """
+        super().__init__(**kwargs)
+
+        self.model_type = model_type
+
+        # Metrics
+        metrics = {
+            "nll": NLL(),
+            "sequence_recovery": SequenceRecovery(),
+        }
+        if self.model_type == "ligand_mpnn":
+            metrics["interface_nll"] = InterfaceNLL()
+            metrics["interface_sequence_recovery"] = InterfaceSequenceRecovery()
+        self.metrics = MetricManager(metrics)
+
+        # Loss
+        loss_params = loss if loss else {}
+        self.loss = LabelSmoothedNLLLoss(**loss_params)
+
+    def construct_model(self):
+        """Construct the model with hard-coded parameters."""
+        with self.fabric.init_module():
+            ranked_logger.info(f"Instantiating {self.model_type} model...")
+
+            # Hard-coded model selection
+            if self.model_type == "protein_mpnn":
+                model = ProteinMPNN()
+            elif self.model_type == "ligand_mpnn":
+                model = LigandMPNN()
+            else:
+                raise ValueError(f"Invalid model type: {self.model_type}")
+
+            # Initialize model weights
+            model.apply(model.init_weights)
+
+        self.initialize_or_update_trainer_state({"model": model})
+
+    def training_step(
+        self,
+        batch: Any,
+        batch_idx: int,
+        is_accumulating: bool,
+    ) -> None:
+        """
+        Training step, running forward and backward passes.
+
+        Args:
+            batch (Any): The current batch; can be of any form.
+            batch_idx (int): The index of the current batch.
+            is_accumulating (bool): Whether we are accumulating gradients
+                (i.e., not yet calling optimizer.step()). If this is the case,
+                we should skip the synchronization during the backward pass.
+
+        Returns:
+            None; we call `loss.backward()` directly, and store the outputs in
+            `self._current_train_return`.
+        """
+        model = self.state["model"]
+        assert model.training, "Model must be training!"
+
+        network_input = batch
+
+        with self.fabric.no_backward_sync(model, enabled=is_accumulating):
+            # Forward pass
+            network_output = model.forward(network_input)
+            assert_no_nans(
+                network_output["decoder_features"],
+                msg="network_output['decoder_features'] "
+                + f"for batch_idx: {batch_idx}",
+            )
+
+            total_loss, loss_dict = self.loss(
+                network_input=batch,
+                network_output=network_output,
+                loss_input={},
+            )
+
+            # Backward pass
+            self.fabric.backward(total_loss)
+
+        # Optionally compute training metrics
+        train_return = {"total_loss": total_loss, "loss_dict": loss_dict}
+
+        # Store the outputs without gradients for use in logging,
+        # callbacks, learning rate schedulers, etc.
+        self._current_train_return = apply_to_collection(
+            train_return,
+            dtype=torch.Tensor,
+            function=lambda x: x.detach(),
+        )
+
+    def validation_step(
+        self,
+        batch: Any,
+        batch_idx: int,
+        compute_metrics: bool = True,
+    ) -> dict:
+        """
+        Validation step, running the forward pass and computing validation
+        metrics.
+
+        Args:
+            batch (Any): The current batch; can be of any form.
+            batch_idx (int): The index of the current batch.
+            compute_metrics (bool): Whether to compute metrics. If False, we
+                will not compute metrics, and the output will be None. Set to
+                False during the inference pipeline, where we need the network
+                output but cannot compute metrics (since we do not have the
+                ground truth).
+
+        Returns:
+            dict: Output dictionary containing the validation metrics and
+                network output.
+        """
+        model = self.state["model"]
+        assert not model.training, "Model must be in evaluation mode during validation!"
+
+        network_input = batch
+
+        # Forward pass
+        network_output = model.forward(network_input)
+
+        assert_no_nans(
+            network_output["decoder_features"],
+            msg="network_output['decoder_features'] " + f"for batch_idx: {batch_idx}",
+        )
+
+        metrics_output = {}
+        if compute_metrics:
+            # Compute all metrics using MetricManager
+            metrics_output = self.metrics(
+                network_input=batch,
+                network_output=network_output,
+                extra_info={},
+            )
+
+        # Avoid gradients in stored values to prevent memory leaks
+        if metrics_output:
+            metrics_output = apply_to_collection(
+                metrics_output, torch.Tensor, lambda x: x.detach()
+            )
+
+        network_output = apply_to_collection(
+            network_output, torch.Tensor, lambda x: x.detach()
+        )
+
+        validation_return = {
+            "metrics_output": metrics_output,
+            "network_output": network_output,
+        }
+
+        return validation_return
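The `training_step` above skips gradient synchronization while accumulating. Below is a minimal, self-contained sketch of that accumulate-then-step pattern with Lightning Fabric; `DummyModel`, the toy dataloader, and `grad_accum_steps` are illustrative stand-ins, not part of this package.

import torch
from lightning.fabric import Fabric


class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 1)

    def forward(self, x):
        return self.linear(x)


fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

model = DummyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = fabric.setup(model, optimizer)

grad_accum_steps = 4
dataloader = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(8)]

for batch_idx, (x, y) in enumerate(dataloader):
    # Mirrors `is_accumulating` in MPNNTrainer.training_step: while
    # accumulating, skip distributed gradient sync on the backward pass.
    is_accumulating = (batch_idx + 1) % grad_accum_steps != 0
    with fabric.no_backward_sync(model, enabled=is_accumulating):
        loss = torch.nn.functional.mse_loss(model(x), y)
        fabric.backward(loss)
    if not is_accumulating:
        optimizer.step()
        optimizer.zero_grad()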
mpnn/transforms/feature_aggregation/mpnn.py ADDED
@@ -0,0 +1,184 @@
+from typing import Any
+
+import numpy as np
+from atomworks.common import KeyToIntMapper
+from atomworks.ml.transforms._checks import (
+    check_atom_array_annotation,
+)
+from atomworks.ml.transforms.base import Transform
+from atomworks.ml.transforms.encoding import atom_array_to_encoding
+from atomworks.ml.utils.token import get_token_starts
+from mpnn.transforms.feature_aggregation.token_encodings import MPNN_TOKEN_ENCODING
+
+
+class EncodeMPNNNonAtomizedTokens(Transform):
+    """Encode non-atomized tokens for MPNN with X, X_m, and S features.
+
+    Creates:
+        - X: (L, 37, 3) coordinates for non-atomized tokens
+        - X_m: (L, 37) mask for atom existence and occupancy > occupancy_threshold
+        - S: (L,) sequence encoding
+
+    Args:
+        occupancy_threshold (float): Minimum occupancy to consider an atom as
+            present. Defaults to 0.5.
+    """
+
+    def __init__(self, occupancy_threshold: float = 0.5):
+        self.occupancy_threshold = occupancy_threshold
+        self.encoding = MPNN_TOKEN_ENCODING
+
+    def check_input(self, data: dict[str, Any]) -> None:
+        check_atom_array_annotation(data, ["atomize", "res_name", "occupancy"])
+
+    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
+        atom_array = data["atom_array"]
+
+        # Check that atom_array is not empty.
+        assert len(atom_array) > 0, "atom_array cannot be empty"
+
+        # Get non-atomized tokens only
+        non_atomized_mask = ~atom_array.atomize
+        non_atomized_array = atom_array[non_atomized_mask]
+
+        if len(non_atomized_array) == 0:
+            # No non-atomized tokens; create empty arrays
+            data["input_features"].update(
+                {
+                    "X": np.zeros((0, 37, 3), dtype=np.float32),
+                    "X_m": np.zeros((0, 37), dtype=np.bool_),
+                    "S": np.zeros((0,), dtype=np.int64),
+                }
+            )
+            return data
+
+        # Encode using the MPNN token encoding
+        encoded = atom_array_to_encoding(
+            non_atomized_array,
+            encoding=self.encoding,
+            default_coord=0.0,  # Use 0.0 instead of NaN for MPNN
+            occupancy_threshold=self.occupancy_threshold,
+        )
+
+        # X: coordinates (L, 37, 3)
+        X = encoded["xyz"].astype(np.float32)
+
+        # X_m: mask for existence and occupancy > threshold (L, 37).
+        # encoded["mask"] already accounts for occupancy; we only need to
+        # check whether atoms exist in the encoding.
+        X_m = encoded["mask"].astype(np.bool_)
+
+        # S: sequence encoding (L,)
+        S = encoded["seq"].astype(np.int64)
+
+        data["input_features"].update(
+            {
+                "X": X,
+                "X_m": X_m,
+                "S": S,
+            }
+        )
+
+        return data
+
+
+class FeaturizeNonAtomizedTokens(Transform):
+    """Add additional features for non-atomized tokens: R_idx, chain_labels,
+    residue_mask."""
+
+    def check_input(self, data: dict[str, Any]) -> None:
+        check_atom_array_annotation(
+            data, ["atomize", "within_chain_res_idx", "chain_iid"]
+        )
+
+    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
+        atom_array = data["atom_array"]
+
+        # Subset to non-atomized atoms.
+        non_atomized_array = atom_array[~atom_array.atomize]
+
+        # Get token starts for non-atomized tokens
+        non_atomized_token_starts = get_token_starts(non_atomized_array)
+        non_atomized_token_level = non_atomized_array[non_atomized_token_starts]
+
+        if len(non_atomized_token_level) == 0:
+            # No non-atomized tokens
+            data["input_features"].update(
+                {
+                    "R_idx": np.zeros((0,), dtype=np.int32),
+                    "chain_labels": np.zeros((0,), dtype=np.int64),
+                    "residue_mask": np.zeros((0,), dtype=np.bool_),
+                }
+            )
+            return data
+
+        # R_idx: residue indices within chains (0-indexed)
+        R_idx = non_atomized_token_level.within_chain_res_idx.astype(np.int32)
+
+        # chain_labels: convert chain_iid to unique integers
+        chain_mapper = KeyToIntMapper()
+        chain_labels = np.array(
+            [
+                chain_mapper(chain_iid)
+                for chain_iid in non_atomized_token_level.chain_iid
+            ],
+            dtype=np.int64,
+        )
+
+        # residue_mask: all 1's for non-atomized tokens
+        residue_mask = np.ones(len(non_atomized_token_level), dtype=np.bool_)
+
+        data["input_features"].update(
+            {
+                "R_idx": R_idx,
+                "chain_labels": chain_labels,
+                "residue_mask": residue_mask,
+            }
+        )
+        return data
+
+
+class FeaturizeAtomizedTokens(Transform):
+    """Add features for atomized tokens: Y, Y_t, Y_m."""
+
+    def check_input(self, data: dict[str, Any]) -> None:
+        check_atom_array_annotation(data, ["atomize", "atomic_number"])
+
+    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
+        atom_array = data["atom_array"]
+
+        # Get atomized tokens only
+        atomized_mask = atom_array.atomize
+        atomized_array = atom_array[atomized_mask]
+
+        if len(atomized_array) == 0:
+            # No atomized tokens
+            data["input_features"].update(
+                {
+                    "Y": np.zeros((0, 3), dtype=np.float32),
+                    "Y_t": np.zeros((0,), dtype=np.int32),
+                    "Y_m": np.zeros((0,), dtype=np.bool_),
+                }
+            )
+            return data
+
+        # Y: coordinates of atomized tokens (n_atomized, 3)
+        Y = atomized_array.coord.astype(np.float32)
+
+        # Y_t: atomic numbers of atomized tokens (n_atomized,)
+        Y_t = atomized_array.atomic_number.astype(np.int32)
+
+        # Y_m: mask for atomized tokens (all 1's, since they exist) (n_atomized,)
+        Y_m = np.ones(len(atomized_array), dtype=np.bool_)
+
+        data["input_features"].update(
+            {
+                "Y": Y,
+                "Y_t": Y_t,
+                "Y_m": Y_m,
+            }
+        )
+        return data
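The three transforms above populate data["input_features"] with a fixed contract of arrays. A small numpy sketch of the expected shapes and dtypes, taken directly from the code; the sizes L (non-atomized tokens) and n_atomized are illustrative:

import numpy as np

L, n_atomized = 5, 3
input_features = {
    "X": np.zeros((L, 37, 3), dtype=np.float32),      # atom37 coordinates
    "X_m": np.zeros((L, 37), dtype=np.bool_),          # existence/occupancy mask
    "S": np.zeros((L,), dtype=np.int64),               # sequence encoding
    "R_idx": np.zeros((L,), dtype=np.int32),           # within-chain residue index
    "chain_labels": np.zeros((L,), dtype=np.int64),    # integer chain ids
    "residue_mask": np.ones((L,), dtype=np.bool_),     # all 1's for residues
    "Y": np.zeros((n_atomized, 3), dtype=np.float32),  # atomized (context) coords
    "Y_t": np.zeros((n_atomized,), dtype=np.int32),    # atomic numbers
    "Y_m": np.ones((n_atomized,), dtype=np.bool_),     # atomized-atom mask
}
for name, arr in input_features.items():
    print(name, arr.shape, arr.dtype)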
mpnn/transforms/feature_aggregation/polymer_ligand_interface.py ADDED
@@ -0,0 +1,76 @@
+"""
+Feature aggregation for polymer-ligand interface masks.
+
+This module provides transforms to compute interface masks for polymer
+residues that are at the interface with ligand molecules.
+"""
+
+from typing import Any
+
+import numpy as np
+from atomworks.ml.transforms._checks import check_atom_array_annotation
+from atomworks.ml.transforms.base import Transform
+from atomworks.ml.utils.token import get_token_starts
+
+
+class FeaturizePolymerLigandInterfaceMask(Transform):
+    """
+    Compute a polymer mask indicating which residues are at the
+    polymer-ligand interface.
+
+    This transform processes an atom array to identify polymer residues that
+    have any atoms within the specified distance threshold of ligand atoms.
+    It expects that the atom array already has the
+    'at_polymer_ligand_interface' annotation computed by the
+    ComputePolymerLigandInterface transform.
+    """
+
+    def check_input(self, data: dict[str, Any]) -> None:
+        """Check that required annotations are present."""
+        check_atom_array_annotation(
+            {"atom_array": data["atom_array"]},
+            required=["element", "atomize", "at_polymer_ligand_interface"],
+        )
+
+    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
+        """Compute the polymer-ligand interface mask and add it to input_features."""
+        atom_array = data["atom_array"]
+
+        # Get the interface annotation, which should already be computed
+        interface_atoms = atom_array.at_polymer_ligand_interface
+
+        # Get token starts to map atoms to residues
+        token_starts = get_token_starts(atom_array)
+
+        # Create a residue-level interface mask for all tokens
+        all_residue_interface_mask = np.zeros(len(token_starts), dtype=bool)
+
+        # For each token (residue), check if any of its atoms are at the
+        # interface.
+        for i, start_idx in enumerate(token_starts):
+            if i < len(token_starts) - 1:
+                end_idx = token_starts[i + 1]
+            else:
+                end_idx = len(atom_array)
+
+            # Check if any atom in this residue is at the interface
+            residue_atoms = interface_atoms[start_idx:end_idx]
+            all_residue_interface_mask[i] = np.any(residue_atoms)
+
+        # Get the token-level atomize annotation
+        token_level_array = atom_array[token_starts]
+        non_atomized_mask = ~token_level_array.atomize
+
+        # Keep the interface mask for non-atomized residues only
+        polymer_interface_mask = all_residue_interface_mask[non_atomized_mask]
+
+        # Initialize input_features if it doesn't exist.
+        if "input_features" not in data:
+            data["input_features"] = {}
+
+        # Add the interface mask to input_features
+        data["input_features"]["polymer_ligand_interface_mask"] = (
+            polymer_interface_mask.astype(np.bool_)
+        )
+
+        return data
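The per-residue loop in forward can also be expressed as a single vectorized reduction. A sketch on made-up data, assuming token_starts is a strictly increasing array of segment starts beginning at 0 (as get_token_starts provides):

import numpy as np

interface_atoms = np.array([0, 1, 0, 0, 0, 1, 1, 0], dtype=bool)  # per-atom flags
token_starts = np.array([0, 3, 5])  # residues span atoms [0:3], [3:5], [5:8]

# OR-reduce each [token_starts[i], token_starts[i + 1]) segment in one call;
# the last segment runs to the end of the array.
residue_mask = np.logical_or.reduceat(interface_atoms, token_starts)
print(residue_mask)  # [ True False  True]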
mpnn/transforms/feature_aggregation/token_encodings.py ADDED
@@ -0,0 +1,132 @@
+from atomworks.constants import AA_LIKE_CHEM_TYPES, STANDARD_AA, UNKNOWN_AA
+from atomworks.ml.encoding_definitions import TokenEncoding
+
+# Token ordering for MPNN.
+token_order = STANDARD_AA + (UNKNOWN_AA,)
+
+# Token ordering for old versions of MPNN.
+legacy_token_order = (
+    "ALA",
+    "CYS",
+    "ASP",
+    "GLU",
+    "PHE",
+    "GLY",
+    "HIS",
+    "ILE",
+    "LYS",
+    "LEU",
+    "MET",
+    "ASN",
+    "PRO",
+    "GLN",
+    "ARG",
+    "SER",
+    "THR",
+    "VAL",
+    "TRP",
+    "TYR",
+    "UNK",
+)
+
+# Atom ordering for new versions of MPNN.
+atom_order = (
+    "N",
+    "CA",
+    "C",
+    "O",
+    "CB",
+    "CG",
+    "CG1",
+    "CG2",
+    "OG",
+    "OG1",
+    "SG",
+    "CD",
+    "CD1",
+    "CD2",
+    "ND1",
+    "ND2",
+    "OD1",
+    "OD2",
+    "SD",
+    "CE",
+    "CE1",
+    "CE2",
+    "CE3",
+    "NE",
+    "NE1",
+    "NE2",
+    "OE1",
+    "OE2",
+    "CH2",
+    "NH1",
+    "NH2",
+    "OH",
+    "CZ",
+    "CZ2",
+    "CZ3",
+    "NZ",
+    "OXT",
+)
+
+# Atom ordering for old versions of MPNN.
+legacy_atom_order = (
+    "N",
+    "CA",
+    "C",
+    "CB",
+    "O",
+    "CG",
+    "CG1",
+    "CG2",
+    "OG",
+    "OG1",
+    "SG",
+    "CD",
+    "CD1",
+    "CD2",
+    "ND1",
+    "ND2",
+    "OD1",
+    "OD2",
+    "SD",
+    "CE",
+    "CE1",
+    "CE2",
+    "CE3",
+    "NE",
+    "NE1",
+    "NE2",
+    "OE1",
+    "OE2",
+    "CH2",
+    "NH1",
+    "NH2",
+    "OH",
+    "CZ",
+    "CZ2",
+    "CZ3",
+    "NZ",
+    "OXT",
+)
+
+# Token encoding for MPNN.
+MPNN_TOKEN_ENCODING = TokenEncoding(
+    token_atoms={token: atom_order for token in token_order},
+    chemcomp_type_to_unknown={chem_type: "UNK" for chem_type in AA_LIKE_CHEM_TYPES},
+)
+
+# Token encoding for versions of MPNN using the legacy token order and
+# the new atom order.
+MPNN_LEGACY_TOKEN_ENCODING = TokenEncoding(
+    token_atoms={token: atom_order for token in legacy_token_order},
+    chemcomp_type_to_unknown={chem_type: "UNK" for chem_type in AA_LIKE_CHEM_TYPES},
+)
+
+# Token encoding for versions of MPNN using the legacy token order and
+# the legacy atom order.
+MPNN_LEGACY_TOKEN_LEGACY_ATOM_ENCODING = TokenEncoding(
+    token_atoms={token: legacy_atom_order for token in legacy_token_order},
+    chemcomp_type_to_unknown={chem_type: "UNK" for chem_type in AA_LIKE_CHEM_TYPES},
+)
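Keeping both token orderings side by side suggests the usual checkpoint-migration step: permuting per-token logits from the legacy vocabulary to the new one. A hedged sketch of that permutation; new_order below is a hypothetical stand-in for whatever STANDARD_AA + (UNKNOWN_AA,) actually resolves to:

import numpy as np

legacy = ("ALA", "CYS", "ASP", "GLU", "PHE", "GLY", "HIS", "ILE", "LYS",
          "LEU", "MET", "ASN", "PRO", "GLN", "ARG", "SER", "THR", "VAL",
          "TRP", "TYR", "UNK")
# Hypothetical target order: alphabetical by three-letter code, then UNK.
new_order = tuple(sorted(legacy[:-1])) + ("UNK",)

# Permutation: position i in the new order pulls from legacy index perm[i].
perm = np.array([legacy.index(tok) for tok in new_order])

legacy_logits = np.random.randn(4, len(legacy))  # (L, n_tokens) per-residue logits
new_logits = legacy_logits[:, perm]              # reordered to the new vocabulary

# Each residue's score for a given token is unchanged; only the column moves.
assert new_logits[0, new_order.index("TRP")] == legacy_logits[0, legacy.index("TRP")]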