rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/data/paired_msa.py ADDED
@@ -0,0 +1,206 @@
+ import os
+ import socket
+ import time
+ from pathlib import Path
+ from typing import Any
+
+ import numpy as np
+ from atomworks.common import exists
+ from atomworks.enums import ChainType
+ from atomworks.ml.datasets import (
+     StructuralDatasetWrapper,
+     logger,
+     save_failed_example_to_disk,  # assumed to be exported here; called below but never imported in the original
+ )
+ from atomworks.ml.datasets.parsers import (
+     MetadataRowParser,
+     load_example_from_metadata_row,
+ )
+ from atomworks.ml.transforms._checks import (
+     check_contains_keys,
+     check_is_instance,
+     check_nonzero_length,
+ )
+ from atomworks.ml.transforms.base import Transform, TransformedDict
+ from atomworks.ml.transforms.msa._msa_loading_utils import load_msa_data_from_path
+ from atomworks.ml.utils.rng import capture_rng_states
+ from biotite.structure import AtomArray, concatenate
+
+
+ # Input data wrapper that allows multiple input files separated by ':'.
+ # Data is loaded as the concatenation of all inputs.
+ class MultiInputDatasetWrapper(StructuralDatasetWrapper):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def __getitem__(self, idx: int) -> Any:
+         # Capture example ID & current rng state (for reproducibility & debugging)
+         if hasattr(self, "idx_to_id"):
+             # ...if the dataset has a custom idx_to_id method, use it (e.g., for a PandasDataset)
+             example_id = self.idx_to_id(idx)
+         else:
+             # ...otherwise, fall back to the `id_column` or a string representation of the index
+             example_id = (
+                 self.dataset[idx][self.id_column] if self.id_column else f"row_{idx}"
+             )
+
+         # Get process id and hostname (for debugging)
+         logger.debug(
+             f"({socket.gethostname()}:{os.getpid()}) Processing example ID: {example_id}"
+         )
+
+         # Load the row, using the __getitem__ method of the dataset
+         row = self.dataset[idx]
+         pdb_path = row["pdb_path"].split(":")
+
+         # Process the row into a transform-ready dictionary with the given CIF and dataset parsers.
+         # We require the "data" dictionary output from `load_example_from_metadata_row` to contain, at a minimum:
+         #   (a) An "id" key, which uniquely identifies the example within the dataframe; and,
+         #   (b) The "path" key, which is the path to the CIF file
+         _start_parse_time = time.time()
+         data = None
+         assert len(pdb_path) <= 2, "at most two ':'-separated input paths are supported"
+
+         for pdb_i in pdb_path:
+             row_i = {"example_id": row["example_id"], "path": pdb_i}
+             data_i = load_example_from_metadata_row(
+                 row_i, self.dataset_parser, cif_parser_args=self.cif_parser_args
+             )
+
+             if data is None:
+                 data = data_i
+             else:
+                 # Relabel the second input so its chains stay distinguishable after concatenation
+                 data_i["atom_array"].pn_unit_id = np.full(
+                     len(data_i["atom_array"]), "B_1"
+                 )  # unique pn unit id
+                 data_i["atom_array"].pn_unit_iid = np.full(
+                     len(data_i["atom_array"]), "B_1"
+                 )  # unique pn unit iid
+                 data_i["atom_array"].chain_id = np.full(
+                     len(data_i["atom_array"]), "B"
+                 )  # unique chain id
+                 data_i["atom_array"].chain_iid = np.full(
+                     len(data_i["atom_array"]), "B"
+                 )  # unique chain iid
+                 data["atom_array"] = concatenate(
+                     [data["atom_array"], data_i["atom_array"]]
+                 )
+                 data["atom_array_stack"] = concatenate(
+                     [data["atom_array_stack"], data_i["atom_array_stack"]]
+                 )
+                 data["chain_info"]["B"] = data_i["chain_info"]["A"]
+
+         # `data` also carries 'example_id', 'assembly_id', and 'query_pn_unit_iids' from the parser
+         data["path"] = row["pdb_path"]
+         data["msa_path"] = Path(row["msa_path"])  # save the MSA path for LoadPairedMSAs
+         _stop_parse_time = time.time()
+
+         # Manually add timing for cif-parsing
+         data = TransformedDict(data)
+         data.__transform_history__.append(
+             dict(
+                 name="load_example_from_metadata_row",
+                 instance=hex(id(load_example_from_metadata_row)),
+                 start_time=_start_parse_time,
+                 end_time=_stop_parse_time,
+                 processing_time=_stop_parse_time - _start_parse_time,
+             )
+         )
+
+         # Apply the transformation pipeline to the data
+         if exists(self.transform):
+             try:
+                 rng_state_dict = capture_rng_states(include_cuda=False)
+                 data = self.transform(data)
+             except KeyboardInterrupt as e:
+                 raise e
+             except Exception as e:
+                 # Log the error and save the failed example to disk (optional)
+                 logger.info(f"Error processing row {idx} ({example_id}): {e}")
+
+                 if exists(self.save_failed_examples_to_dir):
+                     save_failed_example_to_disk(
+                         example_id=example_id,
+                         error_msg=e,
+                         rng_state_dict=rng_state_dict,
+                         data={},  # We do not save the data, since it may be large.
+                         fail_dir=self.save_failed_examples_to_dir,
+                     )
+                 raise e
+
+         return data
+
+
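For orientation, here is a minimal sketch of the metadata row this wrapper consumes. The column names ("example_id", "pdb_path", "msa_path") and the ':'-separated path convention come from the code above; the file paths themselves are hypothetical.

row = {
    "example_id": "domA_domB",
    # Up to two structure files joined with ':'; the second is relabeled as
    # chain "B" / pn_unit "B_1" before the two atom arrays are concatenated.
    "pdb_path": "/data/cifs/domA.cif:/data/cifs/domB.cif",
    # One paired MSA spanning both chains, split per chain by LoadPairedMSAs below.
    "msa_path": "/data/msas/domA_domB.paired.a3m",
}
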
+ class MultidomainDFParser(MetadataRowParser):
+     """Parser for Qian's multidomain data."""
+
+     def __init__(
+         self,
+         example_id_colname: str = "example_id",
+         path_colname: str = "path",
+     ):
+         self.example_id_colname = example_id_colname
+         self.path_colname = path_colname
+
+     def _parse(self, row: dict) -> dict[str, Any]:
+         query_pn_unit_iids = None
+         assembly_id = "1"
+
+         return {
+             "example_id": row[self.example_id_colname],
+             "path": Path(row[self.path_colname]),
+             "assembly_id": assembly_id,
+             "query_pn_unit_iids": query_pn_unit_iids,
+             "extra_info": row,
+         }
+
+
+ class LoadPairedMSAs(Transform):
+     """
+     LoadPairedMSAs adds paired MSAs from disk, overwriting previously paired MSA data.
+     """
+
+     def check_input(self, data: dict[str, Any]):
+         check_contains_keys(data, ["atom_array", "msa_path"])
+         check_is_instance(data, "atom_array", AtomArray)
+         check_nonzero_length(data, "atom_array")
+
+     def forward(self, data: dict[str, Any]) -> dict[str, Any]:
+         atom_array = data["atom_array"]
+         msa_file_path = data["msa_path"]
+         chain_type = data["chain_info"]["A"]["chain_type"]
+         max_msa_sequences = 10000
+
+         msa_data = load_msa_data_from_path(
+             msa_file_path=msa_file_path,
+             chain_type=chain_type,
+             max_msa_sequences=max_msa_sequences,
+         )
+
+         # Split the paired MSA into per-chain column ranges. Note that np.unique
+         # returns chain IDs in sorted order, so the paired MSA columns must be
+         # ordered to match.
+         start_idx = 0
+         all_polymer_chains = np.unique(
+             atom_array.chain_id[
+                 np.isin(atom_array.chain_type, ChainType.get_polymers())
+             ]
+         )
+
+         data["polymer_msas_by_chain_id"] = {}  # overwrite any previously paired MSA data
+         for chain_id in all_polymer_chains:
+             sequence = data["chain_info"][chain_id][
+                 "processed_entity_non_canonical_sequence"
+             ]
+             stop_idx = start_idx + len(sequence)
+
+             data["polymer_msas_by_chain_id"][chain_id] = {}
+
+             # Trim all MSA info to this chain only
+             for mkey in msa_data.keys():
+                 data["polymer_msas_by_chain_id"][chain_id][mkey] = msa_data[mkey][
+                     ..., start_idx:stop_idx
+                 ]
+
+             # Mock msa_is_padded_mask (all zeros, i.e. nothing is padding)
+             data["polymer_msas_by_chain_id"][chain_id]["msa_is_padded_mask"] = np.zeros(
+                 data["polymer_msas_by_chain_id"][chain_id]["msa"].shape, dtype=bool
+             )
+
+             start_idx = stop_idx
+
+         return data
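
A sketch of the resulting layout, assuming a two-chain example and a 10-row paired MSA whose arrays are (rows, columns); the "msa" and "msa_is_padded_mask" keys come from the code above, while the remaining keys returned by load_msa_data_from_path are not shown and will vary.

# After LoadPairedMSAs, each polymer chain owns a column slice of the paired MSA:
# data["polymer_msas_by_chain_id"]["A"]["msa"]                -> shape (10, len_A)
# data["polymer_msas_by_chain_id"]["A"]["msa_is_padded_mask"] -> all-False, same shape
# data["polymer_msas_by_chain_id"]["B"]["msa"]                -> shape (10, len_B)
# where len_A and len_B are the lengths of each chain's processed entity sequence.
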
rf3/data/pipeline_utils.py ADDED
@@ -0,0 +1,128 @@
+ from functools import partial
+
+ import torch
+ from atomworks.enums import ChainType
+ from atomworks.ml.transforms._checks import check_atom_array_annotation
+ from atomworks.ml.transforms.crop import compute_local_hash
+ from omegaconf import DictConfig
+ from rf3.data.ground_truth_template import (
+     FeaturizeNoisedGroundTruthAsTemplateDistogram,
+     TokenGroupNoiseScaleSampler,
+     af3_noise_scale_distribution_wrapped,
+     af3_noise_scale_to_noise_level,
+ )
+
+
+ def annotate_pre_crop_hash(data: dict) -> dict:
+     hash_pre = compute_local_hash(data["atom_array"])
+     data["atom_array"].set_annotation("hash_pre", hash_pre)
+     return data
+
+
+ def annotate_post_crop_hash(data: dict) -> dict:
+     hash_post = compute_local_hash(data["atom_array"])
+     data["atom_array"].set_annotation("hash_post", hash_post)
+     return data
+
+
+ def set_to_occupancy_0_where_crop_hashes_differ(data: dict) -> dict:
+     check_atom_array_annotation(
+         data["atom_array"], ["hash_pre", "hash_post", "occupancy"]
+     )
+
+     # Create a mask of where hash_pre != hash_post
+     atom_array = data["atom_array"]
+     mask = atom_array.get_annotation("hash_pre") != atom_array.get_annotation(
+         "hash_post"
+     )
+
+     # Where the hashes differ, set occupancy to 0
+     atom_array.occupancy[mask] = 0
+
+     return data
+
+
+ *,
47
+ template_noise_scales: dict[str, float | None] | DictConfig,
48
+ allowed_chain_types_for_conditioning: list[ChainType] | None = None,
49
+ p_condition_per_token: float = 0.0,
50
+ p_provide_inter_molecule_distances: float = 0.0,
51
+ is_inference: bool = False,
52
+ ) -> FeaturizeNoisedGroundTruthAsTemplateDistogram:
53
+ """
54
+ Build a FeaturizeNoisedGroundTruthAsTemplateDistogram transform for either training or inference.
55
+
56
+ For inference, we must be deterministic, so we:
57
+ - Use constant noise scales (1.0)
58
+ - Always apply token-level conditioning
59
+
60
+ Args:
61
+ template_noise_scales (dict[str, float | None] | DictConfig):
62
+ Noise scales for 'atomized' and 'not_atomized' tokens. If is_inference=True, these are used as constants.
63
+ If is_inference=False, these are used as upper bounds for the noise scale distribution.
64
+ allowed_chain_types_for_conditioning (list[ChainType] | None):
65
+ List of allowed chain types for conditioning. None disables conditioning.
66
+ p_condition_per_token (float):
67
+ Probability of conditioning each eligible token.
68
+ p_provide_inter_molecule_distances (float):
69
+ Probability of providing inter-molecule (inter-chain) distances.
70
+ is_inference (bool):
71
+ If True, use constant noise scales for conditioning. If False, sample from provided distributions.
72
+
73
+ Returns:
74
+ FeaturizeNoisedGroundTruthAsTemplateDistogram: The configured transform.
75
+ """
76
+ mask_and_sampling_fns = []
77
+ if is_inference:
78
+ # Use constant noise scales for inference, rather than sampling (no stochasticity)
79
+ if template_noise_scales["atomized"] is not None:
80
+ mask_and_sampling_fns.append(
81
+ (
82
+ lambda arr: arr.atomize,
83
+ lambda size: torch.ones(size) * template_noise_scales["atomized"],
84
+ )
85
+ )
86
+ if template_noise_scales["not_atomized"] is not None:
87
+ mask_and_sampling_fns.append(
88
+ (
89
+ lambda arr: ~arr.atomize,
90
+ lambda size: torch.ones(size)
91
+ * template_noise_scales["not_atomized"],
92
+ )
93
+ )
94
+ else:
95
+ # Use noise scale distributions for training
96
+ if template_noise_scales["atomized"] is not None:
97
+ mask_and_sampling_fns.append(
98
+ (
99
+ lambda arr: arr.atomize,
100
+ partial(
101
+ af3_noise_scale_distribution_wrapped,
102
+ upper_noise_level=af3_noise_scale_to_noise_level(
103
+ template_noise_scales["atomized"]
104
+ ).item(),
105
+ ),
106
+ )
107
+ )
108
+ if template_noise_scales["not_atomized"] is not None:
109
+ mask_and_sampling_fns.append(
110
+ (
111
+ lambda arr: ~arr.atomize,
112
+ partial(
113
+ af3_noise_scale_distribution_wrapped,
114
+ upper_noise_level=af3_noise_scale_to_noise_level(
115
+ template_noise_scales["not_atomized"]
116
+ ).item(),
117
+ ),
118
+ )
119
+ )
120
+
121
+ return FeaturizeNoisedGroundTruthAsTemplateDistogram(
122
+ noise_scale_distribution=TokenGroupNoiseScaleSampler(
123
+ mask_and_sampling_fns=tuple(mask_and_sampling_fns),
124
+ ),
125
+ allowed_chain_types=allowed_chain_types_for_conditioning,
126
+ p_condition_per_token=p_condition_per_token,
127
+ p_provide_inter_molecule_distances=p_provide_inter_molecule_distances,
128
+ )