rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/utils/frames.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
# TODO: REFACTOR; COPIED FROM RF2AA. WE NEED TO ADD DOCSTRINGS, EXAMPLES, HOPEFULLY TESTS, AND CLEAN UP
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from rf3.chemical import NFRAMES, NNAPROTAAS, costgtNA
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def is_atom(seq):
    """Return a mask that is True where a token index lies above ``NNAPROTAAS``.

    Works elementwise for tensors (yielding a boolean tensor) and for plain ints.

    NOTE(review): presumably token indices past ``NNAPROTAAS`` encode
    non-polymer (single-atom) tokens in the ``rf3.chemical`` vocabulary —
    confirm against that module.
    """
    return NNAPROTAAS < seq
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_frames(xyz_in, xyz_mask, seq, frame_indices, atom_frames=None):
    """Look up frame atom indices for every token and flag usable frames.

    Args:
        xyz_in: unused by this function (kept for interface compatibility).
        xyz_mask: unused by this function (kept for interface compatibility).
        seq: token indices used to index ``frame_indices``; ``seq[0]`` is
            consulted to locate atom tokens, so a leading batch axis is
            assumed (TODO confirm callers pass a batch of size 1).
        frame_indices: per-token-type frame definition table, indexed by
            ``seq``. NOTE(review): presumably shaped
            (n_token_types, nframes, 3, 2) — confirm against rf3.chemical.
        atom_frames: frame definitions written into the first frame slot of
            every token for which ``is_atom`` is True. Required whenever
            ``seq`` contains such tokens.

    Returns:
        (frames, frame_mask): ``frames`` is ``frame_indices[seq]`` with atom
        frames patched in; ``frame_mask`` is False for frames whose first and
        second atom entries are identical (degenerate frames).
    """
    frames = frame_indices[seq]

    atom_positions = is_atom(seq)
    if atom_positions.any():
        # Only batch element 0 is used to find where the atom tokens sit.
        atom_token_idx = atom_positions[0].nonzero().flatten()
        frames[:, atom_token_idx, 0] = atom_frames

    # A frame is valid unless its first two atom entries match exactly.
    first_atom = frames[..., 0, :]
    second_atom = frames[..., 1, :]
    frame_mask = (first_atom != second_atom).any(dim=-1)

    return frames, frame_mask
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def rigid_from_3_points(N, Ca, C, is_na=None, eps=1e-4):
    """Build a rotation frame from N, CA and C coordinates.

    Construction:
      * e1 points from ``Ca`` to ``C``;
      * e2 is ``N - Ca`` with its e1 component removed (Gram-Schmidt);
      * e3 = e1 x e2.
    The resulting frame is then rotated about e3 by half of the deviation
    between the observed N-CA-C angle (cosine ``cosref``) and a target angle
    (cosine ``costgt``). Per the original RF2AA note, splitting the angle
    deviation between the CA-N and CA-C directions gives a more accurate CB
    position.

    Args:
        N, Ca, C: coordinate tensors with xyz in the last dimension; any
            leading shape (including none) is accepted, but the three inputs
            are expected to share the same shape. ``dims = N.shape[:-1]``.
        is_na: optional boolean mask over ``dims``; where True, the target
            cosine is ``costgtNA`` instead of -0.3616.
        eps: small constant added to norms and inside square roots to keep
            the computation finite.

    Returns:
        (R, Ca): ``R`` has shape ``dims + (3, 3)`` with the frame axes stored
        as columns, in ``N``'s dtype; ``Ca`` is returned unchanged as the
        frame origin.
    """
    dims = N.shape[:-1]
    dtype, device = N.dtype, N.device

    v1 = C - Ca
    v2 = N - Ca
    e1 = v1 / (torch.norm(v1, dim=-1, keepdim=True) + eps)
    # Gram-Schmidt step. Summing over the last axis (instead of an einsum that
    # needs an extra residue axis) keeps this valid for inputs of any rank.
    u2 = v2 - torch.sum(e1 * v2, dim=-1, keepdim=True) * e1
    e2 = u2 / (torch.norm(u2, dim=-1, keepdim=True) + eps)
    e3 = torch.cross(e1, e2, dim=-1)
    R = torch.stack([e1, e2, e3], dim=-1)  # dims + (3, 3), axes as columns

    # Cosine of the observed angle between CA->C and CA->N.
    v2 = v2 / (torch.norm(v2, dim=-1, keepdim=True) + eps)
    cosref = torch.sum(e1 * v2, dim=-1)

    # Build the target cosine in the input dtype so that float64/float16
    # inputs do not produce a dtype mismatch in the final matrix product.
    costgt = torch.full(dims, -0.3616, dtype=dtype, device=device)
    if is_na is not None:
        costgt[is_na] = costgtNA

    # cos(observed - target) = cos(2*delta).
    cos2del = torch.clamp(
        cosref * costgt
        + torch.sqrt((1 - cosref * cosref) * (1 - costgt * costgt) + eps),
        min=-1.0,
        max=1.0,
    )

    # Half-angle formulas; the sign decides which way the frame is rotated.
    cosdel = torch.sqrt(0.5 * (1 + cos2del) + eps)
    sindel = torch.sign(costgt - cosref) * torch.sqrt(1 - 0.5 * (1 + cos2del) + eps)

    # Rotation by delta about the third (e3) axis.
    Rp = torch.eye(3, dtype=dtype, device=device).repeat(*dims, 1, 1)
    Rp[..., 0, 0] = cosdel
    Rp[..., 0, 1] = -sindel
    Rp[..., 1, 0] = sindel
    Rp[..., 1, 1] = cosdel
    R = torch.einsum("...ij,...jk->...ik", R, Rp)

    return R, Ca
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def mask_unresolved_frames_batched(frames, frame_mask, atom_mask):
    """Convert relative frame indices to flat absolute indices and mask bad frames.

    Each frame atom is given as a (relative residue offset, atom index) pair
    and is mapped to ``(residue + offset) * natoms + atom_index``, i.e. an
    index into the flattened ``L * natoms`` atom axis. A frame is removed
    from the mask if any of its atoms falls outside ``[0, L * natoms)`` or
    is unresolved according to ``atom_mask``.

    Args:
        frames: integer tensor (B, L, nframes, 3, 2); the last axis holds
            (relative residue offset, atom index within residue).
        frame_mask: (B, L, nframes) mask of frames valid for FAPE/losses.
            Not modified; a clone is updated and returned.
        atom_mask: (B, L, natoms) mask of resolved (seen) atoms.

    Returns:
        frames_reindex: (B, L, nframes, 3) absolute indices into the
            flattened atom axis; entries of out-of-range frames are reset to
            0 so they can be safely gathered.
        frame_mask_update: copy of ``frame_mask`` with out-of-range frames
            and frames touching unresolved atoms masked out.
    """
    B, L, natoms = atom_mask.shape
    n_flat = L * natoms

    # Absolute index into the flattened (L * natoms) atom axis.
    frames_reindex = (
        torch.arange(L, device=frames.device)[None, :, None, None] + frames[..., 0]
    ) * natoms + frames[..., 1]

    # A frame is unusable if ANY of its atoms lies outside [0, L * natoms).
    # Both bounds must be OR-ed, and the upper bound is exclusive.
    out_of_range = (frames_reindex < 0) | (frames_reindex >= n_flat)
    masked_atom_frames = torch.any(out_of_range, dim=-1)

    # Reset out-of-range indices to 0 so the gather below cannot fail; these
    # frames are masked out anyway.
    frames_reindex[masked_atom_frames, :] = 0

    frame_mask_update = frame_mask.clone()
    frame_mask_update *= ~masked_atom_frames
    # Require every atom of the frame to be resolved. reshape(B, -1) avoids
    # hard-coding the number of frames per token.
    atoms_resolved = torch.gather(
        atom_mask.reshape(B, n_flat),
        1,
        frames_reindex.reshape(B, -1),
    ).reshape(B, L, -1, 3)
    frame_mask_update *= torch.all(atoms_resolved, dim=-1)

    return frames_reindex, frame_mask_update
|