rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/kinematics.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
1
|
+
# TODO: Many of these functions are unused; we will deprecate and delete
|
|
2
|
+
# (They are holdovers from previous frameworks)
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
# Default binning configuration shared by the helpers in this module.
PARAMS = dict(
    DMIN=1,  # lower distance bound used when building the bin edges
    DMID=4,  # boundary between the first and second distance bin ranges
    DMAX=20.0,  # upper distance bound; larger distances count as "no contact"
    DBINS1=30,  # number of bins between DMIN and DMID
    DBINS2=30,  # number of bins between DMID and DMAX
    ABINS=36,  # number of angular bins spanning 2*pi
    USE_CB=False,  # xyz_to_c6d measures Cb-Cb distances if True, Ca-Ca otherwise
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# ============================================================
|
|
19
|
+
def normQ(Q):
    """Scale quaternions to unit length.

    The norm is taken over the last dimension; no epsilon is added, so an
    all-zero quaternion produces a division by zero.
    """
    magnitude = torch.linalg.norm(Q, dim=-1, keepdim=True)
    return Q / magnitude
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# ============================================================
|
|
25
|
+
def avgQ(Qs):
    """Average a set of quaternions across the 'N' dimension.

    Parameters
    ----------
    Qs : pytorch tensor of shape (B,N,R,4)
        quaternions to average

    Returns
    -------
    normalized average over dimension 1. Each quaternion is added with the
    sign that gives it a non-negative dot product with the first quaternion
    along N, so that q and -q contribute in the same direction.
    """
    count = Qs.shape[1]
    reference = Qs[:, 0]
    running = reference / count

    for idx in range(1, count):
        current = Qs[:, idx]
        # True where `current` lies in the same hemisphere as the reference
        aligned = (reference * current).sum(dim=-1) >= 0.0
        running = running + torch.where(aligned[..., None], current, -current) / count

    return normQ(running)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def Rs2Qs(Rs):
    """Convert rotation matrices to quaternions.

    Parameters
    ----------
    Rs : pytorch tensor of shape [...,3,3]
        rotation matrices

    Returns
    -------
    Qs : pytorch tensor of shape [...,4]
        quaternions with the scalar component first (the layout Qs2Rs
        consumes). The scalar component is always >= 0. The result keeps the
        dtype and device of a floating-point ``Rs`` instead of silently
        falling back to the default dtype.
    """
    r00 = Rs[..., 0, 0]
    r11 = Rs[..., 1, 1]
    r22 = Rs[..., 2, 2]

    # Four times the squared magnitude of each quaternion component.
    squared = torch.stack(
        (
            1.0 + r00 + r11 + r22,
            1.0 + r00 - r11 - r22,
            1.0 - r00 + r11 - r22,
            1.0 - r00 - r11 + r22,
        ),
        dim=-1,
    )
    # Round-off can push these slightly below zero; clamp before the sqrt.
    magnitudes = torch.sqrt(squared.clamp(min=0.0)) / 2.0

    # Signs of the vector part come from the antisymmetric part of Rs.
    # NOTE(review): torch.sign returns 0 when that part vanishes, so a rotation
    # by exactly 180 degrees yields a zeroed vector part. This matches the
    # previous behavior; confirm whether callers need that case handled.
    signs = torch.stack(
        (
            torch.ones_like(r00),
            torch.sign(Rs[..., 2, 1] - Rs[..., 1, 2]),
            torch.sign(Rs[..., 0, 2] - Rs[..., 2, 0]),
            torch.sign(Rs[..., 1, 0] - Rs[..., 0, 1]),
        ),
        dim=-1,
    )

    return magnitudes * signs
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def Qs2Rs(Qs):
    """Convert quaternions to rotation matrices.

    Parameters
    ----------
    Qs : pytorch tensor of shape [...,4]
        quaternions with the scalar component first; they are used as given
        (no normalization is applied here)

    Returns
    -------
    Rs : pytorch tensor of shape [...,3,3]
        matrices built from the quaternion components. The result keeps the
        dtype and device of a floating-point ``Qs`` instead of silently
        falling back to the default dtype.
    """
    if not Qs.is_floating_point():
        # Non-float input previously produced default-dtype output; keep that.
        Qs = Qs.to(torch.get_default_dtype())

    w, x, y, z = Qs.unbind(dim=-1)

    # Row-major entries of the 3x3 matrix.
    entries = (
        w * w + x * x - y * y - z * z,
        2 * x * y - 2 * w * z,
        2 * x * z + 2 * w * y,
        2 * x * y + 2 * w * z,
        w * w - x * x + y * y - z * z,
        2 * y * z - 2 * w * x,
        2 * x * z - 2 * w * y,
        2 * y * z + 2 * w * x,
        w * w - x * x - y * y + z * z,
    )

    return torch.stack(entries, dim=-1).reshape(*Qs.shape[:-1], 3, 3)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# ============================================================
|
|
94
|
+
def get_pair_dist(a, b):
    """Compute all pairwise Euclidean distances between two point sets.

    Parameters
    ----------
    a,b : pytorch tensors of shape [batch,nres,3]
        Cartesian coordinates of the two sets of atoms

    Returns
    -------
    pytorch tensor of shape [batch,nres,nres]
        entry [n,i,j] is the distance between a[n,i] and b[n,j]
    """
    return torch.cdist(a, b, p=2)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
# ============================================================
|
|
112
|
+
def get_ang(a, b, c, eps=1e-4):
    """Compute the planar angle at b for every triple (a[i],b[i],c[i]).

    Parameters
    ----------
    a,b,c : pytorch tensors of shape [batch,nres,3]
        Cartesian coordinates of three sets of atoms
    eps : float
        added to the vector norms to avoid division by zero

    Returns
    -------
    ang : pytorch tensor of shape [batch,nres]
        angles in radians. The cosine is clamped to [-0.999, 0.999] before
        acos, so the result never reaches exactly 0 or pi.
    """
    arm_a = a - b
    arm_c = c - b
    unit_a = arm_a / (torch.norm(arm_a, dim=-1, keepdim=True) + eps)
    unit_c = arm_c / (torch.norm(arm_c, dim=-1, keepdim=True) + eps)
    cosine = torch.sum(unit_a * unit_c, dim=-1)

    clamped = torch.clamp(cosine, -0.999, 0.999)
    return torch.acos(clamped)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# ============================================================
|
|
135
|
+
def get_dih(a, b, c, d, eps=1e-4):
    """Compute the dihedral angle for every quadruple (a[i],b[i],c[i],d[i]).

    Parameters
    ----------
    a,b,c,d : pytorch tensors of shape [batch,nres,3]
        Cartesian coordinates of four sets of atoms
    eps : float
        added to the norm of the central bond and to both atan2 arguments

    Returns
    -------
    dih : pytorch tensor of shape [batch,nres]
        dihedral angles in radians, as returned by atan2
    """
    bond_ba = a - b
    bond_bc = c - b
    bond_cd = d - c

    axis = bond_bc / (torch.norm(bond_bc, dim=-1, keepdim=True) + eps)

    # Remove the components of the outer bonds that lie along the central axis.
    perp_ba = bond_ba - torch.sum(bond_ba * axis, dim=-1, keepdim=True) * axis
    perp_cd = bond_cd - torch.sum(bond_cd * axis, dim=-1, keepdim=True) * axis

    cos_part = torch.sum(perp_ba * perp_cd, dim=-1)
    sin_part = torch.sum(torch.cross(axis, perp_ba, dim=-1) * perp_cd, dim=-1)

    return torch.atan2(sin_part + eps, cos_part + eps)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
# ============================================================
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def generate_Cbeta(N, Ca, C):
    """Rebuild Cb coordinates from the backbone N, Ca and C coordinates.

    Cb is placed at Ca plus a fixed linear combination of the N->Ca bond,
    the Ca->C bond and their cross product.
    """
    n_to_ca = Ca - N
    ca_to_c = C - Ca
    plane_normal = torch.cross(n_to_ca, ca_to_c, dim=-1)
    # Earlier coefficients: -0.58273431, 0.56802827, -0.54067466
    # fd: the values below match the sidechain generator (=Rosetta params)
    return -0.57910144 * plane_normal + 0.5689693 * n_to_ca - 0.5441217 * ca_to_c + Ca
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
# ============================================================
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def xyz_to_c6d(xyz, params=PARAMS):
    """Convert backbone coordinates into 2D distance and orientation maps.

    Parameters
    ----------
    xyz : pytorch tensor of shape [batch,nres,3,3]
        Cartesian coordinates of backbone N,Ca,C atoms
    params : dict
        reads "USE_CB" (Cb-Cb vs Ca-Ca distances) and "DMAX" (contact cutoff)

    Returns
    -------
    c6d : pytorch tensor of shape [batch,nres,nres,4]
        stacked dist,omega,theta,phi 2D maps. Orientations are only filled
        for pairs closer than DMAX; the diagonal and all pairs at or beyond
        DMAX get distance 999.9 and zero orientations.
    """
    n_batch = xyz.shape[0]
    n_res = xyz.shape[1]

    # three anchor atoms
    N = xyz[:, :, 0]
    Ca = xyz[:, :, 1]
    C = xyz[:, :, 2]

    # Cb is always rebuilt because the orientation terms use it
    Cb = generate_Cbeta(N, Ca, C)

    anchor = Cb if params["USE_CB"] else Ca
    pair_dist = get_pair_dist(anchor, anchor)
    pair_dist[torch.isnan(pair_dist)] = 999.9

    # 6d coordinates order: (dist,omega,theta,phi)
    c6d = torch.zeros([n_batch, n_res, n_res, 4], dtype=xyz.dtype, device=xyz.device)
    # the large diagonal offset keeps self-pairs out of the contact set
    c6d[..., 0] = pair_dist + 999.9 * torch.eye(n_res, device=xyz.device)[None, ...]

    b, i, j = torch.where(c6d[..., 0] < params["DMAX"])

    c6d[b, i, j, 1] = get_dih(Ca[b, i], Cb[b, i], Cb[b, j], Ca[b, j])
    c6d[b, i, j, 2] = get_dih(N[b, i], Ca[b, i], Cb[b, i], Cb[b, j])
    c6d[b, i, j, 3] = get_ang(Ca[b, i], Cb[b, i], Cb[b, j])

    # fix long-range distances (written through a view of channel 0)
    dist_channel = c6d[..., 0]
    dist_channel[dist_channel >= params["DMAX"]] = 999.9

    return torch.nan_to_num(c6d)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def xyz_to_t2d(xyz_t, mask, has_rotation=None, params=PARAMS):
    """Convert template coordinates into 2D distance and orientation features.

    Parameters
    ----------
    xyz_t : pytorch tensor of shape [batch,templ,nres,natoms,3]
        Cartesian coordinates of the templates; only the first three atoms
        (backbone N,Ca,C) of each residue are used
    mask : pytorch tensor [batch,templ,nres,nres]
        indicates whether valid residue pairs or not
    has_rotation : pytorch tensor [nres] or [batch,templ,nres], optional
        indicates whether a node has a rotation or not (eg atoms do not have
        rotations). Orientation features are zeroed for every pair in which
        either node has no rotation. Any shape whose outer product broadcasts
        against [batch,templ,nres,nres] is accepted.
    params : dict
        binning parameters forwarded to xyz_to_c6d and dist_to_onehot

    Returns
    -------
    t2d : pytorch tensor of shape [batch,templ,nres,nres,D]
        concatenation of the one-hot distance bins
        (DBINS1 + DBINS2 + 1 channels), sin/cos of the three orientation
        angles (6 channels) and the pair mask (1 channel)
    """
    B, T, L = xyz_t.shape[:3]
    # reshape (not view) so non-contiguous inputs are accepted as well
    c6d = xyz_to_c6d(xyz_t[:, :, :, :3].reshape(B * T, L, 3, 3), params=params)
    c6d = c6d.view(B, T, L, L, 4)

    # dist to one-hot encoded
    mask = mask[..., None]
    dist = dist_to_onehot(c6d[..., 0], params) * mask
    orien = (
        torch.cat((torch.sin(c6d[..., 1:]), torch.cos(c6d[..., 1:])), dim=-1) * mask
    )  # (B, T, L, L, 6)

    if has_rotation is not None:
        # Pairwise "both nodes have a rotation" map. The trailing singleton
        # axis lets it broadcast over the feature channels, and over batch and
        # template when has_rotation is given per residue only.
        has_rotation_2d = (has_rotation[..., None, :] * has_rotation[..., None]).bool()
        orien = orien.masked_fill(~has_rotation_2d[..., None], 0.0)

    t2d = torch.cat((dist, orien, mask), dim=-1)
    return t2d
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def xyz_to_bbtor(xyz, params=PARAMS):
    """Bin the backbone phi/psi torsions computed from N,Ca,C coordinates.

    Parameters
    ----------
    xyz : pytorch tensor indexed as [batch,nres,atom,3]
        atoms 0,1,2 along the third axis are read as N,Ca,C
    params : dict
        reads "ABINS", the number of angular bins spanning 2*pi

    Returns
    -------
    long tensor of shape [batch,nres,2] holding (phi_bin, psi_bin)
    """
    # three anchor atoms
    N = xyz[:, :, 0]
    Ca = xyz[:, :, 1]
    C = xyz[:, :, 2]

    # neighbours along the residue axis; torch.roll wraps around at the ends
    c_before = torch.roll(C, 1, dims=1)
    n_after = torch.roll(N, -1, dims=1)

    phi = get_dih(c_before, N, Ca, C)
    psi = get_dih(N, Ca, C, n_after)

    # the wrapped-around neighbours are meaningless for the end residues
    phi[:, 0] = 0.0
    psi[:, -1] = 0.0

    bin_width = 2.0 * np.pi / params["ABINS"]
    binned = [
        torch.round((angle + np.pi - bin_width / 2) / bin_width)
        for angle in (phi, psi)
    ]
    return torch.stack(binned, dim=-1).long()
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
# ============================================================
|
|
289
|
+
def dist_to_onehot(dist, params=PARAMS):
    """One-hot encode distances using the bins produced by dist_to_bins.

    Returns a float tensor with one extra trailing dimension of size
    DBINS1 + DBINS2 + 1.
    """
    n_classes = params["DBINS1"] + params["DBINS2"] + 1
    bin_index = dist_to_bins(dist, params)
    encoded = torch.nn.functional.one_hot(bin_index, num_classes=n_classes)
    return encoded.float()
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
# ============================================================
|
|
298
|
+
def dist_to_bins(dist, params=PARAMS):
    """Bin 2d distance maps.

    Parameters
    ----------
    dist : pytorch tensor
        distances to bin; NaN entries are treated as 999.9. The input tensor
        is left unmodified.
    params : dict
        reads "DMIN", "DMID", "DMAX", "DBINS1" and "DBINS2"

    Returns
    -------
    db : long tensor with the same shape as ``dist``
        bin indices in [0, DBINS1 + DBINS2]. Values at or below the first
        edge fall into bin 0 and values above DMAX into the last bin.
    """
    # Out-of-place NaN replacement: callers often pass a view into a larger
    # tensor, which must not be modified as a side effect of binning.
    dist = dist.masked_fill(torch.isnan(dist), 999.9)

    dstep1 = (params["DMID"] - params["DMIN"]) / params["DBINS1"]
    dstep2 = (params["DMAX"] - params["DMID"]) / params["DBINS2"]

    # Upper edges of the bins: DBINS1 evenly spaced edges ending at DMID,
    # followed by DBINS2 evenly spaced edges ending at DMAX.
    dbins = torch.cat(
        [
            torch.linspace(
                params["DMIN"] + dstep1,
                params["DMID"],
                params["DBINS1"],
                dtype=dist.dtype,
                device=dist.device,
            ),
            torch.linspace(
                params["DMID"] + dstep2,
                params["DMAX"],
                params["DBINS2"],
                dtype=dist.dtype,
                device=dist.device,
            ),
        ]
    )
    db = torch.bucketize(dist.contiguous(), dbins).long()

    return db
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
# ============================================================
|
|
327
|
+
def c6d_to_bins(c6d, same_chain, negative=False, params=PARAMS):
    """Bin 2d distance and orientation maps.

    Parameters
    ----------
    c6d : pytorch tensor of shape [...,4]
        stacked dist,omega,theta,phi maps
    same_chain : pytorch tensor broadcastable to c6d[..., 0]
        only used when ``negative`` is True; pairs where it is False are
        forced into the no-contact bins
    negative : bool
        whether to overwrite the bins of pairs that are not on the same chain
    params : dict
        reads "DBINS1", "DBINS2" and "ABINS" (plus the keys dist_to_bins
        reads). The dict is not modified.

    Returns
    -------
    long tensor of shape [...,4] holding (dist, omega, theta, phi) bin indices
    """
    db = dist_to_bins(c6d[..., 0], params)

    astep = 2.0 * np.pi / params["ABINS"]
    ob = torch.round((c6d[..., 1] + np.pi - astep / 2) / astep)
    tb = torch.round((c6d[..., 2] + np.pi - astep / 2) / astep)
    # phi is binned without the +pi shift and uses half as many bins
    pb = torch.round((c6d[..., 3] - astep / 2) / astep)

    # Index of the "no contact" distance bin. Kept in a local so the shared
    # default PARAMS dict (or the caller's dict) is not written to.
    no_contact_bin = params["DBINS1"] + params["DBINS2"]
    angle_no_contact = params["ABINS"]
    phi_no_contact = params["ABINS"] // 2

    # synchronize no-contact bins
    no_contact = db == no_contact_bin
    ob[no_contact] = angle_no_contact
    tb[no_contact] = angle_no_contact
    pb[no_contact] = phi_no_contact

    if negative:
        intra_chain = same_chain.bool()
        db = torch.where(intra_chain, db.long(), no_contact_bin)
        ob = torch.where(intra_chain, ob.long(), angle_no_contact)
        tb = torch.where(intra_chain, tb.long(), angle_no_contact)
        pb = torch.where(intra_chain, pb.long(), phi_no_contact)

    return torch.stack([db, ob, tb, pb], dim=-1).long()
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def standardize_dihedral_retain_first(a, b, c, d):
    """Return the smaller of (a, b, c, d) and (a, c, b, d) under tuple ordering.

    The first and last entries keep their positions; only the two middle
    entries may be swapped.
    """
    as_given = (a, b, c, d)
    middle_swapped = (a, c, b, d)
    return min(as_given, middle_swapped)
|