rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,193 @@
1
+ # '''
2
+ # Tailored dataset wrappers for design tasks
3
+ # '''
4
+
5
+ import json
6
+ import os
7
+ import textwrap
8
+ from os import PathLike
9
+ from typing import Any, Dict, List
10
+
11
+ import yaml
12
+ from atomworks.ml.datasets import MolecularDataset
13
+ from atomworks.ml.transforms.base import Compose, Transform
14
+ from omegaconf import DictConfig, OmegaConf
15
+ from rfd3.inference.input_parsing import (
16
+ DesignInputSpecification,
17
+ )
18
+ from rfd3.utils.inference import ensure_input_is_abspath
19
+ from torch.utils.data import (
20
+ DataLoader,
21
+ SequentialSampler,
22
+ )
23
+
24
+ from foundry.utils.datasets import assemble_distributed_loader
25
+ from foundry.utils.ddp import RankedLogger
26
+
27
# Rank-aware loggers: `logger` emits only on rank zero so messages are not
# duplicated across processes in distributed runs; `all_ranks_logger` emits
# on every rank (useful for per-rank diagnostics).
logger = RankedLogger(__name__, rank_zero_only=True)
all_ranks_logger = RankedLogger(__name__, rank_zero_only=False)
29
+
30
+
31
class ContigJsonDataset(MolecularDataset):
    """
    Enables loading of JSON files containing contig data for benchmark design tasks,
    or the passing of examples through analogously-structured hydra configs.

    The dataset maps example names (the keys of the loaded mapping) to
    ``DesignInputSpecification``-style dictionaries; raw dicts are coerced into
    specs lazily in ``__getitem__``.
    """

    def __init__(
        self,
        *,
        data: PathLike | Dict[str, dict | DesignInputSpecification],
        cif_parser_args: dict | None,
        transform: Transform | Compose | None,
        name: str | None,
        subset_to_keys: List[str] | None,
        eval_every_n: int,
    ):
        """
        Args:
            - data: path to the JSON/YAML file containing the contig data, or an
              already-loaded mapping (plain dict or hydra ``DictConfig``)
            - cif_parser_args: arguments for the CIF parser
            - transform: transform to apply to the data
            - name: name of the dataset
            - subset_to_keys: list of keys to subset the data to
            - eval_every_n: how many times should this dataset be evaluated?

        Raises:
            ValueError: if the resulting dataset is empty, if ``subset_to_keys``
                is an empty list, or if an entry is not a dict/spec.
        """
        if isinstance(data, (PathLike, str)):
            self.json_path = data
            original_data = self._load_from_path(data)
        elif isinstance(data, DictConfig):
            self.json_path = None
            original_data = OmegaConf.to_object(data)
        else:
            self.json_path = None
            original_data = data

        # These will have already been added at inference time, but this block is
        # useful for validation. NOTE: global_args overwrite per-example values
        # because dict.update favors the incoming mapping.
        if "global_args" in original_data:
            global_args = original_data.pop("global_args")
            for key in original_data:
                original_data[key].update(global_args)

        self._data = original_data

        if subset_to_keys is not None:
            # Explicit raise rather than `assert`: asserts vanish under -O.
            if not subset_to_keys:
                raise ValueError("subset_to_keys must be a non-empty list of keys.")
            self._data = {k: v for k, v in self._data.items() if k in subset_to_keys}
        # Validate every retained entry regardless of subsetting so malformed
        # inputs fail fast here instead of deep inside __getitem__.
        self._check_json_keys()

        # ...basic assignments
        self.name = name if name is not None else "json-dataset"
        self.transform = transform
        self.cif_parser_args = cif_parser_args
        self.eval_every_n = eval_every_n

        if len(self) > 1_000:
            logger.warning(
                "ContigJsonDataset contains more than 1,000 entries. This may lead to performance issues."
            )
        elif len(self) == 0:
            raise ValueError(
                "ContigJsonDataset is empty, data: {}. Names: {}".format(
                    data, self.names
                )
            )

        # Pretty-print the example names as a fixed-width banner in the log.
        width = 46  # banner width ('l' renamed: ambiguous identifier, E741)
        fmt_names = textwrap.fill(", ".join(self.names), width=width)
        logger.info(
            f"\n+{width * '-'}+\n"
            f"Dataset {self.name}:\n"
            f" - Found {len(self):,} examples:\n"
            f"{fmt_names}\n"
            f"\n+{width * '-'}+\n"
        )

    @staticmethod
    def _load_from_path(data):
        """Load data from a JSON or YAML file.

        Accepts any path-like object (``__init__`` admits ``PathLike``), so the
        path is normalized to ``str`` first — the previous implementation called
        ``.endswith`` directly on the argument, which crashes for ``Path`` inputs.

        Raises:
            FileNotFoundError: if the file does not exist.
            ValueError: if the extension is not .json/.yaml/.yml.
        """
        path = os.fspath(data)
        if not os.path.exists(path):
            raise FileNotFoundError(f"Input file {path} does not exist.")
        with open(path, "r") as f:
            if path.endswith(".json"):
                return json.load(f)
            # Accept both conventional YAML extensions.
            if path.endswith((".yaml", ".yml")):
                return yaml.safe_load(f)
        raise ValueError(f"Input file {path} must be a JSON or YAML file.")

    def _check_json_keys(self):
        """Check that every entry maps to a dict (or an already-built spec)."""
        for key, entry in self.data.items():
            if not isinstance(entry, (dict, DesignInputSpecification)):
                # Name the offending key so the user can locate the bad entry.
                raise ValueError(
                    f"Each item in the JSON data must be a dictionary (bad entry: {key!r})."
                )

    @property
    def data(self):
        """Expose underlying dataframe as property to discourage changing it (can lead to unexpected behavior with torch ConcatDatasets)."""
        return self._data

    @property
    def names(self) -> List[str]:
        """Example identifiers, in insertion order of the underlying mapping."""
        return list(self.data.keys())

    def __len__(self) -> int:
        """Pass through the length of the wrapped dataset."""
        return len(self.names)

    def __contains__(self, example_id: str) -> bool:
        """Pass through the contains method of the wrapped dataset."""
        return example_id in self.names

    def id_to_idx(self, example_id: str) -> int:
        """Pass through the id_to_idx method of the wrapped dataset."""
        return self.names.index(example_id)

    def idx_to_id(self, idx: int) -> str:
        """Pass through the idx_to_id method of the wrapped dataset."""
        return self.names[idx]

    def __getitem__(self, idx: int) -> Any:
        """Build, transform, and return the pipeline input for example ``idx``."""
        example_id = self.idx_to_id(idx)
        spec = self.data[example_id]

        # if 'input' in metadata and not abspath, prepend the source json
        # directory to the file path, then coerce the raw dict into a spec
        if not isinstance(spec, DesignInputSpecification):
            spec = ensure_input_is_abspath(spec, self.json_path)
            spec["cif_parser_args"] = self.cif_parser_args
            spec = DesignInputSpecification.safe_init(**spec)

        # Create pipeline input
        data = spec.to_pipeline_input(example_id=example_id)

        # Apply transforms and return
        data = self.transform(data)
        return data
173
+
174
+
175
def assemble_distributed_inference_loader_from_json(
    *, rank: int, world_size: int, **dataset_kwargs
) -> DataLoader:
    """
    Build a distributed inference DataLoader backed by a ContigJsonDataset.

    The dataset kwargs are forwarded verbatim; a sequential sampler preserves
    the example order from the source JSON. Example ``data`` kwarg::

        data={
            "backbone_0": {**args},
            "backbone_1": {**args}
        }
    """
    json_dataset = ContigJsonDataset(**dataset_kwargs)
    return assemble_distributed_loader(
        dataset=json_dataset,
        sampler=SequentialSampler(json_dataset),
        rank=rank,
        world_size=world_size,
    )