rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/data/paired_msa.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import socket
|
|
3
|
+
import time
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from atomworks.common import exists
|
|
9
|
+
from atomworks.enums import ChainType
|
|
10
|
+
from atomworks.ml.datasets import StructuralDatasetWrapper, logger
|
|
11
|
+
from atomworks.ml.datasets.parsers import (
|
|
12
|
+
MetadataRowParser,
|
|
13
|
+
load_example_from_metadata_row,
|
|
14
|
+
)
|
|
15
|
+
from atomworks.ml.transforms._checks import (
|
|
16
|
+
check_contains_keys,
|
|
17
|
+
check_is_instance,
|
|
18
|
+
check_nonzero_length,
|
|
19
|
+
)
|
|
20
|
+
from atomworks.ml.transforms.base import Transform, TransformedDict
|
|
21
|
+
from atomworks.ml.transforms.msa._msa_loading_utils import load_msa_data_from_path
|
|
22
|
+
from atomworks.ml.utils.rng import capture_rng_states
|
|
23
|
+
from biotite.structure import AtomArray, concatenate
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# input data wrapper that allows multiple input files separated by ':'
|
|
27
|
+
# data is loaded as the concatenation of all inputs
|
|
28
|
+
class MultiInputDatasetWrapper(StructuralDatasetWrapper):
    """Structural dataset wrapper whose examples may combine two structure files.

    The ``pdb_path`` column of a metadata row may hold one path, or two paths
    joined by ``":"``. Each path is loaded with ``load_example_from_metadata_row``
    and the resulting examples are merged into one:

    * the first input's example dictionary is used as the base;
    * every atom of the second input is relabelled to chain ``"B"`` /
      pn_unit ``"B_1"`` and appended to ``atom_array``; its ``atom_array_stack``
      is appended as well;
    * the second input's ``chain_info["A"]`` entry is stored as
      ``chain_info["B"]``.

    After merging, ``data["path"]`` holds the full ``pdb_path`` string and
    ``data["msa_path"]`` holds ``Path(row["msa_path"])``. The optional transform
    pipeline is then applied.
    """

    def __getitem__(self, idx: int) -> Any:
        """Load, merge and (optionally) transform the example at row ``idx``.

        Args:
            idx: Row index into ``self.dataset``.

        Returns:
            The example dictionary, wrapped in a ``TransformedDict`` and passed
            through ``self.transform`` when one is configured.

        Raises:
            ValueError: If ``pdb_path`` lists more than two ``":"``-separated paths.
            Exception: Any error raised by the transform pipeline is logged and
                re-raised unchanged.
        """
        # Capture example ID (for reproducibility & debugging)
        if hasattr(self, "idx_to_id"):
            # ...if the dataset has a custom idx_to_id method, use it (e.g., for a PandasDataset)
            example_id = self.idx_to_id(idx)
        else:
            # ...otherwise, fall back to the `id_column` or a string representation of the index
            example_id = (
                self.dataset[idx][self.id_column] if self.id_column else f"row_{idx}"
            )

        # Get process id and hostname (for debugging)
        logger.debug(
            f"({socket.gethostname()}:{os.getpid()}) Processing example ID: {example_id}"
        )

        # Load the row, using the __getitem__ method of the dataset
        row = self.dataset[idx]
        pdb_paths = row["pdb_path"].split(":")

        # Every input after the first is relabelled onto the same chain "B", so a
        # third input would collide with the second. Validate explicitly rather
        # than with `assert`, which is stripped when Python runs with -O.
        if len(pdb_paths) > 2:
            raise ValueError(
                f"Example {example_id}: expected at most 2 ':'-separated paths in "
                f"'pdb_path', got {len(pdb_paths)}: {row['pdb_path']!r}"
            )

        _start_parse_time = time.time()
        data = None
        for pdb_i in pdb_paths:
            row_i = {"example_id": row["example_id"], "path": pdb_i}
            data_i = load_example_from_metadata_row(
                row_i, self.dataset_parser, cif_parser_args=self.cif_parser_args
            )

            if data is None:
                # The first input becomes the base example.
                data = data_i
                continue

            # Give the additional input unique chain / pn_unit labels so it does
            # not collide with the first input (assumed not to use chain "B").
            extra_atoms = data_i["atom_array"]
            n_extra = len(extra_atoms)
            extra_atoms.pn_unit_id = np.full(n_extra, "B_1")  # unique pn unit id
            extra_atoms.pn_unit_iid = np.full(n_extra, "B_1")  # unique pn unit iid
            extra_atoms.chain_id = np.full(n_extra, "B")  # unique chain id
            extra_atoms.chain_iid = np.full(n_extra, "B")  # unique chain iid

            data["atom_array"] = concatenate([data["atom_array"], extra_atoms])
            # NOTE(review): only `atom_array` is relabelled above; the appended
            # `atom_array_stack` keeps the second input's original chain labels.
            # Confirm whether downstream code relies on the stack's labels.
            data["atom_array_stack"] = concatenate(
                [data["atom_array_stack"], data_i["atom_array_stack"]]
            )
            # NOTE(review): assumes the second input's relevant chain was "A".
            data["chain_info"]["B"] = data_i["chain_info"]["A"]

        # Remaining metadata keys come from the first input's example; record the
        # full multi-input path string and the MSA location on top of it.
        data["path"] = row["pdb_path"]
        data["msa_path"] = Path(row["msa_path"])  # save msa
        _stop_parse_time = time.time()

        # Manually add timing for cif-parsing
        data = TransformedDict(data)
        data.__transform_history__.append(
            dict(
                name="load_example_from_metadata_row",
                instance=hex(id(load_example_from_metadata_row)),
                start_time=_start_parse_time,
                end_time=_stop_parse_time,
                processing_time=_stop_parse_time - _start_parse_time,
            )
        )

        # Apply the transformation pipeline to the data
        if exists(self.transform):
            rng_state_dict = None  # stays bound even if capturing the state fails
            try:
                rng_state_dict = capture_rng_states(include_cuda=False)
                data = self.transform(data)
            except KeyboardInterrupt:
                raise
            except Exception as e:
                # Log the error and save the failed example to disk (optional)
                logger.info(f"Error processing row {idx} ({example_id}): {e}")
                if exists(self.save_failed_examples_to_dir):
                    self._save_failed_example(example_id, e, rng_state_dict)
                raise

        return data

    def _save_failed_example(self, example_id, error, rng_state_dict) -> None:
        """Best-effort dump of a failed example's metadata to ``save_failed_examples_to_dir``.

        Failures while saving are logged, not raised, so the caller's original
        transform error is what propagates.
        """
        try:
            # Previously this helper was called without ever being imported, so
            # every transform failure became a NameError that hid the real error.
            # NOTE(review): import path assumed to be atomworks' own debug utility — confirm.
            from atomworks.ml.utils.debug import save_failed_example_to_disk

            save_failed_example_to_disk(
                example_id=example_id,
                error_msg=error,
                rng_state_dict=rng_state_dict,
                data={},  # We do not save the data, since it may be large.
                fail_dir=self.save_failed_examples_to_dir,
            )
        except Exception as save_error:
            logger.warning(
                f"Could not save failed example {example_id} to "
                f"{self.save_failed_examples_to_dir}: {save_error}"
            )
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class MultidomainDFParser(MetadataRowParser):
    """Metadata-row parser for Qian's multidomain dataset.

    Every row is mapped to assembly ``"1"`` with no query pn_units; the complete
    row is forwarded unchanged under ``"extra_info"``.

    Args:
        example_id_colname: Name of the column holding the example identifier.
        path_colname: Name of the column holding the structure file path.
    """

    def __init__(
        self,
        example_id_colname: str = "example_id",
        path_colname: str = "path",
    ):
        self.example_id_colname = example_id_colname
        self.path_colname = path_colname

    def _parse(self, row: dict) -> dict[str, Any]:
        """Translate one metadata ``row`` into the loader's example description."""
        identifier = row[self.example_id_colname]
        structure_file = Path(row[self.path_colname])
        return dict(
            example_id=identifier,
            path=structure_file,
            assembly_id="1",
            query_pn_unit_iids=None,
            extra_info=row,
        )
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class LoadPairedMSAs(Transform):
    """Load a paired MSA from ``data["msa_path"]`` and split it into per-chain MSAs.

    Any previously computed ``data["polymer_msas_by_chain_id"]`` is replaced.

    Every array returned by ``load_msa_data_from_path`` is sliced along its last
    axis into consecutive column ranges, one per polymer chain. Chains are taken
    in sorted chain-id order (the order returned by ``np.unique``), and each range
    is as wide as that chain's ``processed_entity_non_canonical_sequence``.
    NOTE(review): this assumes the MSA file stores the chains concatenated in
    exactly that order — confirm against the MSA generation script.

    The chain type of chain ``"A"`` in ``data["chain_info"]`` is used to parse the
    whole file, so chain ``"A"`` must be present.

    Args:
        max_msa_sequences: Maximum number of sequences requested from
            ``load_msa_data_from_path``. Defaults to 10000, the value that was
            previously hard-coded.
    """

    def __init__(self, *args, max_msa_sequences: int = 10000, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_msa_sequences = max_msa_sequences

    def check_input(self, data: dict[str, Any]):
        # `chain_info` is read in `forward`, so require it up front too.
        check_contains_keys(data, ["atom_array", "msa_path", "chain_info"])
        check_is_instance(data, "atom_array", AtomArray)
        check_nonzero_length(data, "atom_array")

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        """Replace ``data["polymer_msas_by_chain_id"]`` with slices of the loaded paired MSA."""
        atom_array = data["atom_array"]
        chain_info = data["chain_info"]

        msa_data = load_msa_data_from_path(
            msa_file_path=data["msa_path"],
            chain_type=chain_info["A"]["chain_type"],
            max_msa_sequences=self.max_msa_sequences,
        )

        polymer_chain_ids = np.unique(
            atom_array.chain_id[
                np.isin(atom_array.chain_type, ChainType.get_polymers())
            ]
        )

        polymer_msas_by_chain_id = {}
        start_idx = 0
        for chain_id in polymer_chain_ids:
            sequence = chain_info[chain_id]["processed_entity_non_canonical_sequence"]
            stop_idx = start_idx + len(sequence)

            # Trim every MSA field to this chain's columns only.
            chain_msa = {
                key: value[..., start_idx:stop_idx] for key, value in msa_data.items()
            }
            # The paired MSA has no padding, so the padding mask is all False.
            chain_msa["msa_is_padded_mask"] = np.zeros(
                chain_msa["msa"].shape, dtype=bool
            )
            polymer_msas_by_chain_id[chain_id] = chain_msa

            start_idx = stop_idx

        # Overwrite (not merge with) any previously computed per-chain MSAs.
        data["polymer_msas_by_chain_id"] = polymer_msas_by_chain_id
        return data
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from atomworks.enums import ChainType
|
|
5
|
+
from atomworks.ml.transforms._checks import check_atom_array_annotation
|
|
6
|
+
from atomworks.ml.transforms.crop import compute_local_hash
|
|
7
|
+
from omegaconf import DictConfig
|
|
8
|
+
from rf3.data.ground_truth_template import (
|
|
9
|
+
FeaturizeNoisedGroundTruthAsTemplateDistogram,
|
|
10
|
+
TokenGroupNoiseScaleSampler,
|
|
11
|
+
af3_noise_scale_distribution_wrapped,
|
|
12
|
+
af3_noise_scale_to_noise_level,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def annotate_pre_crop_hash(data: dict) -> dict:
    """Attach ``compute_local_hash(data["atom_array"])`` as the ``hash_pre`` annotation.

    Meant to run before cropping, so the value can later be compared with the
    ``hash_post`` annotation. Mutates ``data["atom_array"]`` in place and
    returns ``data`` for pipeline chaining.
    """
    atoms = data["atom_array"]
    atoms.set_annotation("hash_pre", compute_local_hash(atoms))
    return data
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def annotate_post_crop_hash(data: dict) -> dict:
    """Attach ``compute_local_hash(data["atom_array"])`` as the ``hash_post`` annotation.

    Meant to run after cropping, so the value can be compared with the
    ``hash_pre`` annotation. Mutates ``data["atom_array"]`` in place and
    returns ``data`` for pipeline chaining.
    """
    atoms = data["atom_array"]
    atoms.set_annotation("hash_post", compute_local_hash(atoms))
    return data
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def set_to_occupancy_0_where_crop_hashes_differ(data: dict) -> dict:
    """Zero the occupancy of atoms whose ``hash_pre`` and ``hash_post`` annotations disagree.

    Requires the ``hash_pre``, ``hash_post`` and ``occupancy`` annotations on
    ``data["atom_array"]``, which is modified in place; ``data`` is returned.
    """
    atoms = data["atom_array"]
    check_atom_array_annotation(atoms, ["hash_pre", "hash_post", "occupancy"])

    before_crop = atoms.get_annotation("hash_pre")
    after_crop = atoms.get_annotation("hash_post")
    hash_changed = before_crop != after_crop

    atoms.occupancy[hash_changed] = 0
    return data
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _is_atomized(arr):
    """Token-group mask: select tokens whose ``atomize`` flag is set."""
    return arr.atomize


def _is_not_atomized(arr):
    """Token-group mask: select tokens whose ``atomize`` flag is not set."""
    return ~arr.atomize


def _constant_noise_scale(size, *, noise_scale):
    """Deterministic sampler: return ``torch.ones(size)`` scaled by ``noise_scale``."""
    return torch.ones(size) * noise_scale


def build_ground_truth_distogram_transform(
    *,
    template_noise_scales: dict[str, float | None] | DictConfig,
    allowed_chain_types_for_conditioning: list[ChainType] | None = None,
    p_condition_per_token: float = 0.0,
    p_provide_inter_molecule_distances: float = 0.0,
    is_inference: bool = False,
) -> FeaturizeNoisedGroundTruthAsTemplateDistogram:
    """
    Build a FeaturizeNoisedGroundTruthAsTemplateDistogram transform for either training or inference.

    Tokens are split into two groups, "atomized" (``arr.atomize``) and
    "not_atomized" (``~arr.atomize``). Each group whose entry in
    ``template_noise_scales`` is not None gets a noise-scale sampler:

    - Inference (``is_inference=True``): deterministic. Every token in the group is
      assigned the configured noise scale itself (not a fixed 1.0).
    - Training: noise scales are drawn from ``af3_noise_scale_distribution_wrapped``,
      with ``upper_noise_level`` derived from the configured scale.

    Token-level conditioning is still governed by ``p_condition_per_token`` in both
    modes; this function does not override it.

    All sampler and mask callables are module-level functions or ``functools.partial``
    objects rather than lambdas, so the returned transform can be pickled (e.g. when
    sent to spawned DataLoader worker processes). Noise values are read from
    ``template_noise_scales`` once, at build time.

    Args:
        template_noise_scales (dict[str, float | None] | DictConfig):
            Noise scales for the 'atomized' and 'not_atomized' token groups; None skips a group.
            If is_inference=True, these are used as constants.
            If is_inference=False, they set the upper bound of the noise scale distribution.
        allowed_chain_types_for_conditioning (list[ChainType] | None):
            Chain types eligible for conditioning; forwarded unchanged to the transform.
        p_condition_per_token (float):
            Probability of conditioning each eligible token.
        p_provide_inter_molecule_distances (float):
            Probability of providing inter-molecule (inter-chain) distances.
        is_inference (bool):
            If True, use constant noise scales. If False, sample from the AF3-style distribution.

    Returns:
        FeaturizeNoisedGroundTruthAsTemplateDistogram: The configured transform.
    """
    # Group order matters for TokenGroupNoiseScaleSampler: "atomized" first, then "not_atomized".
    token_groups = (
        ("atomized", _is_atomized),
        ("not_atomized", _is_not_atomized),
    )

    mask_and_sampling_fns = []
    for group_name, mask_fn in token_groups:
        noise_scale = template_noise_scales[group_name]
        if noise_scale is None:
            continue

        if is_inference:
            # Constant noise scale for inference (no stochasticity).
            sampling_fn = partial(_constant_noise_scale, noise_scale=noise_scale)
        else:
            # Sampled noise scale for training, capped by the configured scale.
            sampling_fn = partial(
                af3_noise_scale_distribution_wrapped,
                upper_noise_level=af3_noise_scale_to_noise_level(noise_scale).item(),
            )
        mask_and_sampling_fns.append((mask_fn, sampling_fn))

    return FeaturizeNoisedGroundTruthAsTemplateDistogram(
        noise_scale_distribution=TokenGroupNoiseScaleSampler(
            mask_and_sampling_fns=tuple(mask_and_sampling_fns),
        ),
        allowed_chain_types=allowed_chain_types_for_conditioning,
        p_condition_per_token=p_condition_per_token,
        p_provide_inter_molecule_distances=p_provide_inter_molecule_distances,
    )
|