boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,171 @@
+# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
+
+from einops import einsum
+import torch
+import torch.nn.functional as F
+
+
+def weighted_rigid_align(
+    true_coords,
+    pred_coords,
+    weights,
+    mask,
+):
+    """Compute weighted alignment.
+
+    Parameters
+    ----------
+    true_coords: torch.Tensor
+        The ground truth atom coordinates
+    pred_coords: torch.Tensor
+        The predicted atom coordinates
+    weights: torch.Tensor
+        The weights for alignment
+    mask: torch.Tensor
+        The atoms mask
+
+    Returns
+    -------
+    torch.Tensor
+        Aligned coordinates
+
+    """
+
+    batch_size, num_points, dim = true_coords.shape
+    weights = (mask * weights).unsqueeze(-1)
+
+    # Compute weighted centroids
+    true_centroid = (true_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
+        dim=1, keepdim=True
+    )
+    pred_centroid = (pred_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
+        dim=1, keepdim=True
+    )
+
+    # Center the coordinates
+    true_coords_centered = true_coords - true_centroid
+    pred_coords_centered = pred_coords - pred_centroid
+
+    if num_points < (dim + 1):
+        print(
+            "Warning: The size of one of the point clouds is <= dim+1. "
+            + "`WeightedRigidAlign` cannot return a unique rotation."
+        )
+
+    # Compute the weighted covariance matrix
+    cov_matrix = einsum(
+        weights * pred_coords_centered, true_coords_centered, "b n i, b n j -> b i j"
+    )
+
+    # Compute the SVD of the covariance matrix, required float32 for svd and determinant
+    original_dtype = cov_matrix.dtype
+    cov_matrix_32 = cov_matrix.to(dtype=torch.float32)
+    U, S, V = torch.linalg.svd(
+        cov_matrix_32, driver="gesvd" if cov_matrix_32.is_cuda else None
+    )
+    V = V.mH
+
+    # Catch ambiguous rotation by checking the magnitude of singular values
+    if (S.abs() <= 1e-15).any() and not (num_points < (dim + 1)):
+        print(
+            "Warning: Excessively low rank of "
+            + "cross-correlation between aligned point clouds. "
+            + "`WeightedRigidAlign` cannot return a unique rotation."
+        )
+
+    # Compute the rotation matrix
+    rot_matrix = torch.einsum("b i j, b k j -> b i k", U, V).to(dtype=torch.float32)
+
+    # Ensure proper rotation matrix with determinant 1
+    F = torch.eye(dim, dtype=cov_matrix_32.dtype, device=cov_matrix.device)[
+        None
+    ].repeat(batch_size, 1, 1)
+    F[:, -1, -1] = torch.det(rot_matrix)
+    rot_matrix = einsum(U, F, V, "b i j, b j k, b l k -> b i l")
+    rot_matrix = rot_matrix.to(dtype=original_dtype)
+
+    # Apply the rotation and translation
+    aligned_coords = (
+        einsum(true_coords_centered, rot_matrix, "b n i, b j i -> b n j")
+        + pred_centroid
+    )
+    aligned_coords.detach_()
+
+    return aligned_coords
+
+
+def smooth_lddt_loss(
+    pred_coords,
+    true_coords,
+    is_nucleotide,
+    coords_mask,
+    nucleic_acid_cutoff: float = 30.0,
+    other_cutoff: float = 15.0,
+    multiplicity: int = 1,
+):
+    """Compute the smooth lddt loss.
+
+    Parameters
+    ----------
+    pred_coords: torch.Tensor
+        The predicted atom coordinates
+    true_coords: torch.Tensor
+        The ground truth atom coordinates
+    is_nucleotide: torch.Tensor
+        Whether each atom belongs to a nucleotide
+    coords_mask: torch.Tensor
+        The atoms mask
+    nucleic_acid_cutoff: float
+        The nucleic acid cutoff
+    other_cutoff: float
+        The non nucleic acid cutoff
+    multiplicity: int
+        The multiplicity
+    Returns
+    -------
+    torch.Tensor
+        The smooth lddt loss
+
+    """
+    B, N, _ = true_coords.shape
+    true_dists = torch.cdist(true_coords, true_coords)
+    is_nucleotide = is_nucleotide.repeat_interleave(multiplicity, 0)
+
+    coords_mask = coords_mask.repeat_interleave(multiplicity, 0)
+    is_nucleotide_pair = is_nucleotide.unsqueeze(-1).expand(
+        -1, -1, is_nucleotide.shape[-1]
+    )
+
+    mask = (
+        is_nucleotide_pair * (true_dists < nucleic_acid_cutoff).float()
+        + (1 - is_nucleotide_pair) * (true_dists < other_cutoff).float()
+    )
+    mask = mask * (1 - torch.eye(pred_coords.shape[1], device=pred_coords.device))
+    mask = mask * (coords_mask.unsqueeze(-1) * coords_mask.unsqueeze(-2))
+
+    # Compute distances between all pairs of atoms
+    pred_dists = torch.cdist(pred_coords, pred_coords)
+    dist_diff = torch.abs(true_dists - pred_dists)
+
+    # Compute epsilon values
+    eps = (
+        (
+            (
+                F.sigmoid(0.5 - dist_diff)
+                + F.sigmoid(1.0 - dist_diff)
+                + F.sigmoid(2.0 - dist_diff)
+                + F.sigmoid(4.0 - dist_diff)
+            )
+            / 4.0
+        )
+        .view(multiplicity, B // multiplicity, N, N)
+        .mean(dim=0)
+    )
+
+    # Calculate masked averaging
+    eps = eps.repeat_interleave(multiplicity, 0)
+    num = (eps * mask).sum(dim=(-1, -2))
+    den = mask.sum(dim=(-1, -2)).clamp(min=1)
+    lddt = num / den
+
+    return 1.0 - lddt.mean()
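The hunk above adds a batched weighted rigid alignment and a smooth lddt loss (judging by the +171 line count, it appears to correspond to boltz/model/loss/diffusion.py in the file list). As a quick orientation to the expected tensor shapes, here is a minimal, hypothetical usage sketch with toy tensors; the shape values and the call pattern are illustrative assumptions, not code from the package.

import torch

# Toy shapes: batch of 2 structures, 16 atoms each (illustrative only).
B, N = 2, 16
true_coords = torch.randn(B, N, 3)   # ground-truth atom coordinates
pred_coords = torch.randn(B, N, 3)   # predicted atom coordinates
weights = torch.ones(B, N)           # per-atom alignment weights
mask = torch.ones(B, N)              # 1.0 where an atom exists

# Rigidly align the ground truth onto the prediction (result is detached).
aligned_true = weighted_rigid_align(true_coords, pred_coords, weights, mask)

# Smooth lddt loss between the prediction and the aligned ground truth.
is_nucleotide = torch.zeros(B, N)    # 0.0 = not a nucleotide atom
loss = smooth_lddt_loss(pred_coords, aligned_true, is_nucleotide, mask, multiplicity=1)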
@@ -0,0 +1,134 @@
+# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
+
+import einx
+import torch
+import torch.nn.functional as F
+from einops import einsum, rearrange
+
+
+def weighted_rigid_align(
+    true_coords,  # Float['b n 3'],  # true coordinates
+    pred_coords,  # Float['b n 3'],  # predicted coordinates
+    weights,  # Float['b n'],  # weights for each atom
+    mask,  # Bool['b n'] | None = None  # mask for variable lengths
+):  # -> Float['b n 3']:
+    """Algorithm 28 : note there is a problem with the pseudocode in the paper where predicted and
+    GT are swapped in algorithm 28, but correct in equation (2)."""
+
+    batch_size, num_points, dim = true_coords.shape
+    weights = (mask * weights).unsqueeze(-1)
+
+    # Compute weighted centroids
+    true_centroid = (true_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
+        dim=1, keepdim=True
+    )
+    pred_centroid = (pred_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
+        dim=1, keepdim=True
+    )
+
+    # Center the coordinates
+    true_coords_centered = true_coords - true_centroid
+    pred_coords_centered = pred_coords - pred_centroid
+
+    if torch.any(mask.sum(dim=-1) < (dim + 1)):
+        print(
+            "Warning: The size of one of the point clouds is <= dim+1. "
+            + "`WeightedRigidAlign` cannot return a unique rotation."
+        )
+
+    # Compute the weighted covariance matrix
+    cov_matrix = einsum(
+        weights * pred_coords_centered, true_coords_centered, "b n i, b n j -> b i j"
+    )
+
+    # Compute the SVD of the covariance matrix, required float32 for svd and determinant
+    original_dtype = cov_matrix.dtype
+    cov_matrix_32 = cov_matrix.to(dtype=torch.float32)
+
+    U, S, V = torch.linalg.svd(
+        cov_matrix_32, driver="gesvd" if cov_matrix_32.is_cuda else None
+    )
+    V = V.mH
+
+    # Catch ambiguous rotation by checking the magnitude of singular values
+    if (S.abs() <= 1e-15).any() and not (num_points < (dim + 1)):
+        print(
+            "Warning: Excessively low rank of "
+            + "cross-correlation between aligned point clouds. "
+            + "`WeightedRigidAlign` cannot return a unique rotation."
+        )
+
+    # Compute the rotation matrix
+    rot_matrix = torch.einsum("b i j, b k j -> b i k", U, V).to(dtype=torch.float32)
+
+    # Ensure proper rotation matrix with determinant 1
+    F = torch.eye(dim, dtype=cov_matrix_32.dtype, device=cov_matrix.device)[
+        None
+    ].repeat(batch_size, 1, 1)
+    F[:, -1, -1] = torch.det(rot_matrix)
+    rot_matrix = einsum(U, F, V, "b i j, b j k, b l k -> b i l")
+    rot_matrix = rot_matrix.to(dtype=original_dtype)
+
+    # Apply the rotation and translation
+    aligned_coords = (
+        einsum(true_coords_centered, rot_matrix, "b n i, b j i -> b n j")
+        + pred_centroid
+    )
+    aligned_coords.detach_()
+
+    return aligned_coords
+
+
+def smooth_lddt_loss(
+    pred_coords,  # Float['b n 3'],
+    true_coords,  # Float['b n 3'],
+    is_nucleotide,  # Bool['b n'],
+    coords_mask,  # Bool['b n'] | None = None,
+    nucleic_acid_cutoff: float = 30.0,
+    other_cutoff: float = 15.0,
+    multiplicity: int = 1,
+):  # -> Float['']:
+    """Algorithm 27
+    pred_coords: predicted coordinates
+    true_coords: true coordinates
+    Note: for efficiency pred_coords is the only one with the multiplicity expanded
+    TODO: add weighing which overweight the smooth lddt contribution close to t=0 (not present in the paper)
+    """
+    lddt = []
+    for i in range(true_coords.shape[0]):
+        true_dists = torch.cdist(true_coords[i], true_coords[i])
+
+        is_nucleotide_i = is_nucleotide[i // multiplicity]
+        coords_mask_i = coords_mask[i // multiplicity]
+
+        is_nucleotide_pair = is_nucleotide_i.unsqueeze(-1).expand(
+            -1, is_nucleotide_i.shape[-1]
+        )
+
+        mask = is_nucleotide_pair * (true_dists < nucleic_acid_cutoff).float()
+        mask += (1 - is_nucleotide_pair) * (true_dists < other_cutoff).float()
+        mask *= 1 - torch.eye(pred_coords.shape[1], device=pred_coords.device)
+        mask *= coords_mask_i.unsqueeze(-1)
+        mask *= coords_mask_i.unsqueeze(-2)
+
+        valid_pairs = mask.nonzero()
+        true_dists_i = true_dists[valid_pairs[:, 0], valid_pairs[:, 1]]
+
+        pred_coords_i1 = pred_coords[i, valid_pairs[:, 0]]
+        pred_coords_i2 = pred_coords[i, valid_pairs[:, 1]]
+        pred_dists_i = F.pairwise_distance(pred_coords_i1, pred_coords_i2)
+
+        dist_diff_i = torch.abs(true_dists_i - pred_dists_i)
+
+        eps_i = (
+            F.sigmoid(0.5 - dist_diff_i)
+            + F.sigmoid(1.0 - dist_diff_i)
+            + F.sigmoid(2.0 - dist_diff_i)
+            + F.sigmoid(4.0 - dist_diff_i)
+        ) / 4.0
+
+        lddt_i = eps_i.sum() / (valid_pairs.shape[0] + 1e-5)
+        lddt.append(lddt_i)
+
+    # average over batch & multiplicity
+    return 1.0 - torch.stack(lddt, dim=0).mean(dim=0)
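This second hunk (by the +134 line count, apparently the v2 file boltz/model/loss/diffusionv2.py) computes the smooth lddt per sample over only the valid atom pairs, and indexes is_nucleotide and coords_mask with i // multiplicity while the coordinate tensors carry one entry per diffusion sample. The following is a small hypothetical sketch of that calling convention; all shapes and the repeat_interleave expansion of the ground truth are illustrative assumptions.

import torch

# Toy shapes: 2 structures, 16 atoms, 2 diffusion samples per structure (illustrative only).
B, N, mult = 2, 16, 2
true_coords = torch.randn(B, N, 3).repeat_interleave(mult, dim=0)  # (B * mult, N, 3)
pred_coords = torch.randn(B * mult, N, 3)                          # one prediction per diffusion sample
is_nucleotide = torch.zeros(B, N)                                   # indexed internally with i // multiplicity
coords_mask = torch.ones(B, N)

loss = smooth_lddt_loss(
    pred_coords,
    true_coords,
    is_nucleotide,
    coords_mask,
    multiplicity=mult,
)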
@@ -0,0 +1,48 @@
+import torch
+from torch import Tensor
+
+
+def distogram_loss(
+    output: dict[str, Tensor],
+    feats: dict[str, Tensor],
+) -> tuple[Tensor, Tensor]:
+    """Compute the distogram loss.
+
+    Parameters
+    ----------
+    output : Dict[str, Tensor]
+        Output of the model
+    feats : Dict[str, Tensor]
+        Input features
+
+    Returns
+    -------
+    Tensor
+        The globally averaged loss.
+    Tensor
+        Per example loss.
+
+    """
+    # Get predicted distograms
+    pred = output["pdistogram"]
+
+    # Compute target distogram
+    target = feats["disto_target"]
+
+    # Combine target mask and padding mask
+    mask = feats["token_disto_mask"]
+    mask = mask[:, None, :] * mask[:, :, None]
+    mask = mask * (1 - torch.eye(mask.shape[1])[None]).to(pred)
+
+    # Compute the distogram loss
+    errors = -1 * torch.sum(
+        target * torch.nn.functional.log_softmax(pred, dim=-1),
+        dim=-1,
+    )
+    denom = 1e-5 + torch.sum(mask, dim=(-1, -2))
+    mean = errors * mask
+    mean = torch.sum(mean, dim=-1)
+    mean = mean / denom[..., None]
+    batch_loss = torch.sum(mean, dim=-1)
+    global_loss = torch.mean(batch_loss)
+    return global_loss, batch_loss
@@ -0,0 +1,105 @@
+import torch
+from torch import Tensor
+
+
+def distogram_loss(
+    output: dict[str, Tensor],
+    feats: dict[str, Tensor],
+    aggregate_distogram: bool = True,
+) -> tuple[Tensor, Tensor]:
+    """Compute the distogram loss.
+
+    Parameters
+    ----------
+    output : Dict[str, Tensor]
+        Output of the model
+    feats : Dict[str, Tensor]
+        Input features
+
+    Returns
+    -------
+    Tensor
+        The globally averaged loss.
+    Tensor
+        Per example loss.
+
+    """
+    with torch.autocast("cuda", enabled=False):
+        # Get predicted distograms
+        pred = output["pdistogram"].float()  # (B, L, L, num_distograms, disto_bins)
+        D = pred.shape[3]  # num_distograms  # noqa: N806
+        assert len(pred.shape) == 5  # noqa: PLR2004
+
+        # Compute target distogram
+        target = feats["disto_target"]  # (B, L, L, K, disto_bins)
+        assert len(target.shape) == 5  # noqa: PLR2004
+
+        if aggregate_distogram:
+            msg = "Cannot aggregate GT distogram when num_distograms > 1"
+            assert pred.shape[3] == 1, msg
+
+            pred = pred.squeeze(3)  # (B, L, L, disto_bins)
+
+            # Aggregate distogram over K conformers
+            target = target.sum(dim=3)  # (B, L, L, disto_bins)
+
+            # Normalize distogram
+            P = target / target.sum(-1)[..., None].clamp(min=1)  # noqa: N806
+
+            # Combine target mask and padding mask
+            mask = feats["token_disto_mask"]
+            mask = mask[:, None, :] * mask[:, :, None]
+            mask = mask * (1 - torch.eye(mask.shape[1])[None]).to(pred)
+
+            # Compute the distogram loss
+            log_Q = torch.nn.functional.log_softmax(pred, dim=-1)  # noqa: N806
+            errors = -1 * torch.sum(
+                P * log_Q,
+                dim=-1,
+            )
+            denom = 1e-5 + torch.sum(mask, dim=(-1, -2))
+            mean = errors * mask
+            mean = torch.sum(mean, dim=-1)
+            mean = mean / denom[..., None]
+            batch_loss = torch.sum(mean, dim=-1)
+            global_loss = torch.mean(batch_loss)
+        else:
+            # We want to compute the loss for each pair of conformer K and predicted
+            # distogram
+
+            # Loop through conformers and compute the loss
+            batch_loss = []
+            for k in range(target.shape[3]):
+                # Get the target distogram for conformer k
+                # (B, L, L, K, disto_bins) -> (B, L, L, D, disto_bins)
+                P_k = target[:, :, :, k : k + 1, :].repeat_interleave(D, dim=3)  # noqa: N806
+
+                # Compute the distogram loss to all predicted distograms
+                log_Q = torch.nn.functional.log_softmax(pred, dim=-1)  # noqa: N806
+                errors = -1 * torch.sum(
+                    P_k * log_Q,
+                    dim=-1,
+                )  # (B, L, L, D)
+
+                # Compute mask
+                mask = feats["token_disto_mask"]
+                mask = mask[:, None, :] * mask[:, :, None]
+                mask = mask * (1 - torch.eye(mask.shape[1])[None]).to(pred)
+                mask = mask.unsqueeze(-1).repeat_interleave(D, -1)  # (B, L, L, D)
+
+                denom = 1e-5 + torch.sum(mask, dim=(-2, -3))  # (B, D)
+                mean = errors * mask
+                mean = torch.sum(mean, dim=-2)  # (B, L, D)
+                mean = mean / denom[..., None, :]
+                b_loss = torch.sum(mean, dim=-2)  # (B, D)
+
+                batch_loss.append(b_loss)
+
+            batch_loss = torch.stack(batch_loss, dim=1)  # (B, K, D)
+
+            # Compute the batch loss by taking the min over the predicted distograms
+            # and the average across conformers
+            batch_loss = torch.min(batch_loss, dim=-1).values.mean(dim=1)
+            global_loss = torch.mean(batch_loss)
+
+    return global_loss, batch_loss
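The last two hunks add the distogram cross-entropy loss and a v2 counterpart that runs in float32 under a disabled autocast region and can either aggregate the target over conformers or take the min over predicted distogram heads (by the +48 and +105 line counts these appear to be boltz/model/loss/distogram.py and distogramv2.py). Below is a minimal, hypothetical call to the v2 version with toy one-hot targets; the bin count, token count, and feature keys being constructed by hand are illustrative assumptions for a quick shape check, not how the package builds its features.

import torch

# Toy shapes: batch 2, 8 tokens, 64 distance bins, one predicted distogram head, one conformer.
B, L, bins = 2, 8, 64
pred = torch.randn(B, L, L, 1, bins)

# One-hot target distogram: scatter a 1.0 into a random bin for every token pair.
target = torch.zeros(B, L, L, 1, bins)
target.scatter_(-1, torch.randint(bins, (B, L, L, 1, 1)), 1.0)

output = {"pdistogram": pred}
feats = {
    "disto_target": target,
    "token_disto_mask": torch.ones(B, L),
}

global_loss, batch_loss = distogram_loss(output, feats, aggregate_distogram=True)
print(global_loss.item(), batch_loss.shape)  # scalar loss, per-example loss of shape (B,)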