rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
# '''
|
|
2
|
+
# Tailored dataset wrappers for design tasks
|
|
3
|
+
# '''
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import textwrap
|
|
8
|
+
from os import PathLike
|
|
9
|
+
from typing import Any, Dict, List
|
|
10
|
+
|
|
11
|
+
import yaml
|
|
12
|
+
from atomworks.ml.datasets import MolecularDataset
|
|
13
|
+
from atomworks.ml.transforms.base import Compose, Transform
|
|
14
|
+
from omegaconf import DictConfig, OmegaConf
|
|
15
|
+
from rfd3.inference.input_parsing import (
|
|
16
|
+
DesignInputSpecification,
|
|
17
|
+
)
|
|
18
|
+
from rfd3.utils.inference import ensure_input_is_abspath
|
|
19
|
+
from torch.utils.data import (
|
|
20
|
+
DataLoader,
|
|
21
|
+
SequentialSampler,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
from foundry.utils.datasets import assemble_distributed_loader
|
|
25
|
+
from foundry.utils.ddp import RankedLogger
|
|
26
|
+
|
|
27
|
+
logger = RankedLogger(__name__, rank_zero_only=True)
|
|
28
|
+
all_ranks_logger = RankedLogger(__name__, rank_zero_only=False)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ContigJsonDataset(MolecularDataset):
    """
    Enables loading of JSON files containing contig data for benchmark design tasks,
    or the passing of examples through analogously-structured hydra configs.

    The underlying data is a mapping from example id -> per-example specification
    (a plain dict or an already-built ``DesignInputSpecification``). Each example is
    converted lazily, in ``__getitem__``, into a pipeline input and run through the
    configured transform.
    """

    def __init__(
        self,
        *,
        data: PathLike | Dict[str, dict | DesignInputSpecification],
        cif_parser_args: dict | None,
        transform: Transform | Compose | None,
        name: str | None,
        subset_to_keys: List[str] | None,
        eval_every_n: int,
    ):
        """
        Args:
            - data: path to the JSON file containing the contig data
            - cif_parser_args: arguments for the CIF parser
            - transform: transform to apply to the data
            - name: name of the dataset
            - subset_to_keys: list of keys to subset the data to
            - evaluate_every_n: how many times should this dataset be evaluated?
        """

        # Three accepted input forms: path to a JSON/YAML file, a hydra DictConfig,
        # or an already-materialized dict. json_path is kept so that relative file
        # paths inside specs can later be resolved against the source directory.
        if isinstance(data, (PathLike, str)):
            self.json_path = data
            original_data = self._load_from_path(data)
        elif isinstance(data, DictConfig):
            self.json_path = None
            original_data = OmegaConf.to_object(data)
        else:
            self.json_path = None
            original_data = data

        # These will have already been added at inference time, but this block is useful for validation.
        # NOTE(review): ``update(global_args)`` makes global entries overwrite any
        # per-example keys that collide — presumably intentional; confirm with callers.
        # NOTE(review): when ``data`` is a plain dict, ``pop`` mutates the caller's
        # object in place — verify no caller reuses it afterwards.
        if "global_args" in original_data:
            global_args = original_data.pop("global_args")
            for k, v in original_data.items():
                original_data[k].update(global_args)

        self._data = original_data

        # Optionally restrict to an explicit subset of example ids, then validate
        # that every remaining entry is dict-shaped.
        if subset_to_keys is not None:
            assert (
                len(subset_to_keys) > 0
            ), "subset_to_keys must be a non-empty list of keys."
            self._data = {k: v for k, v in self._data.items() if k in subset_to_keys}
        self._check_json_keys()

        # ...basic assignments
        self.name = name if name is not None else "json-dataset"
        self.transform = transform

        self.cif_parser_args = cif_parser_args
        self.eval_every_n = eval_every_n

        # Sanity checks on dataset size: warn when suspiciously large, fail hard
        # when empty (an empty eval set is almost certainly a configuration error).
        if len(self) > 1_000:
            logger.warning(
                "ContigJsonDataset contains more than 1,000 entries. This may lead to performance issues."
            )
        elif len(self) == 0:
            raise ValueError(
                "ContigJsonDataset is empty, data: {}. Names: {}".format(
                    data, self.names
                )
            )

        # ``l`` is the banner width (in characters) for the log box below.
        l = 46
        fmt_names = textwrap.fill(
            ", ".join(self.names), width=l
        )  # .replace('\n', '+\n+ ')
        logger.info(
            f"\n+{l * '-'}+\n"
            f"Dataset {self.name}:\n"
            f" - Found {len(self):,} examples:\n"
            f"{fmt_names}\n"
            f"\n+{l * '-'}+\n"
        )

    @staticmethod
    def _load_from_path(data: PathLike | str) -> dict:
        """Load data from a JSON or YAML file.

        Dispatches on the file extension; anything other than ``.json`` or
        ``.yaml`` is rejected. Raises AssertionError if the path does not exist.
        """
        assert os.path.exists(data), f"Input file {data} does not exist."
        with open(data, "r") as f:
            if data.endswith(".json"):
                data = json.load(f)
            elif data.endswith(".yaml"):
                data = yaml.safe_load(f)
            else:
                raise ValueError(f"Input file {data} must be a JSON or YAML file.")
        return data

    def _check_json_keys(self) -> None:
        """Check if the JSON keys are valid (every value must be dict-shaped)."""
        for k, data in self.data.items():
            if not isinstance(data, (dict, DesignInputSpecification)):
                raise ValueError("Each item in the JSON data must be a dictionary.")

    @property
    def data(self) -> Dict[str, dict | DesignInputSpecification]:
        """Expose the underlying mapping as a read-only property to discourage changing it (can lead to unexpected behavior with torch ConcatDatasets)."""
        return self._data

    @property
    def names(self) -> List[str]:
        # Example ids, in insertion order; index positions define __getitem__ order.
        return list(self.data.keys())

    def __len__(self) -> int:
        """Pass through the length of the wrapped dataset."""
        return len(self.names)

    def __contains__(self, example_id: str) -> bool:
        """Pass through the contains method of the wrapped dataset."""
        return example_id in self.names

    def id_to_idx(self, example_id: str) -> int:
        """Pass through the id_to_idx method of the wrapped dataset."""
        return self.names.index(example_id)

    def idx_to_id(self, idx: int) -> str:
        """Pass through the idx_to_id method of the wrapped dataset."""
        return self.names[idx]

    def __getitem__(self, idx: int) -> Any:
        """Pass through the getitem method of the wrapped dataset."""
        example_id = self.idx_to_id(idx)
        spec = self.data[example_id]

        # if 'input' in metadata and not abspath, prepend the source json directory to the file path
        # Raw dict specs are normalized here; prebuilt DesignInputSpecification
        # objects are assumed to already carry absolute paths and parser args.
        if not isinstance(spec, DesignInputSpecification):
            spec = ensure_input_is_abspath(spec, self.json_path)
            spec["cif_parser_args"] = self.cif_parser_args
            spec = DesignInputSpecification.safe_init(**spec)

        # Create pipeline input
        data = spec.to_pipeline_input(example_id=example_id)

        # Apply transforms and return
        data = self.transform(data)
        return data
|
+
|
|
174
|
+
|
|
175
|
+
def assemble_distributed_inference_loader_from_json(
    *, rank: int, world_size: int, **dataset_kwargs
) -> DataLoader:
    """
    Build a rank-aware inference DataLoader backed by a ContigJsonDataset.

    All keyword arguments besides ``rank`` and ``world_size`` are forwarded
    verbatim to the ``ContigJsonDataset`` constructor, e.g.::

        data={
            "backbone_0": {**args},
            "backbone_1": {**args}
        }

    Examples are visited in deterministic sequential order before being
    sharded across ranks by ``assemble_distributed_loader``.
    """
    json_dataset = ContigJsonDataset(**dataset_kwargs)
    return assemble_distributed_loader(
        dataset=json_dataset,
        sampler=SequentialSampler(json_dataset),
        rank=rank,
        world_size=world_size,
    )