boltz_vsynthes-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/model/loss/diffusion.py
@@ -0,0 +1,171 @@
+ # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
+
+ from einops import einsum
+ import torch
+ import torch.nn.functional as F
+
+
+ def weighted_rigid_align(
+     true_coords,
+     pred_coords,
+     weights,
+     mask,
+ ):
+     """Compute the weighted rigid alignment.
+
+     Parameters
+     ----------
+     true_coords: torch.Tensor
+         The ground truth atom coordinates
+     pred_coords: torch.Tensor
+         The predicted atom coordinates
+     weights: torch.Tensor
+         The weights for alignment
+     mask: torch.Tensor
+         The atoms mask
+
+     Returns
+     -------
+     torch.Tensor
+         Aligned coordinates
+
+     """
+     batch_size, num_points, dim = true_coords.shape
+     weights = (mask * weights).unsqueeze(-1)
+
+     # Compute weighted centroids
+     true_centroid = (true_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
+         dim=1, keepdim=True
+     )
+     pred_centroid = (pred_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
+         dim=1, keepdim=True
+     )
+
+     # Center the coordinates
+     true_coords_centered = true_coords - true_centroid
+     pred_coords_centered = pred_coords - pred_centroid
+
+     if num_points < (dim + 1):
+         print(
+             "Warning: The size of one of the point clouds is <= dim+1. "
+             + "`WeightedRigidAlign` cannot return a unique rotation."
+         )
+
+     # Compute the weighted covariance matrix
+     cov_matrix = einsum(
+         weights * pred_coords_centered, true_coords_centered, "b n i, b n j -> b i j"
+     )
+
+     # Compute the SVD of the covariance matrix; float32 is required for svd and determinant
+     original_dtype = cov_matrix.dtype
+     cov_matrix_32 = cov_matrix.to(dtype=torch.float32)
+     U, S, V = torch.linalg.svd(
+         cov_matrix_32, driver="gesvd" if cov_matrix_32.is_cuda else None
+     )
+     V = V.mH
+
+     # Catch ambiguous rotation by checking the magnitude of singular values
+     if (S.abs() <= 1e-15).any() and not (num_points < (dim + 1)):
+         print(
+             "Warning: Excessively low rank of "
+             + "cross-correlation between aligned point clouds. "
+             + "`WeightedRigidAlign` cannot return a unique rotation."
+         )
+
+     # Compute the rotation matrix
+     rot_matrix = torch.einsum("b i j, b k j -> b i k", U, V).to(dtype=torch.float32)
+
+     # Ensure a proper rotation matrix with determinant 1. Note: this local `F`
+     # is the Kabsch reflection-correction matrix; it shadows the
+     # `torch.nn.functional` import within this function.
+     F = torch.eye(dim, dtype=cov_matrix_32.dtype, device=cov_matrix.device)[
+         None
+     ].repeat(batch_size, 1, 1)
+     F[:, -1, -1] = torch.det(rot_matrix)
+     rot_matrix = einsum(U, F, V, "b i j, b j k, b l k -> b i l")
+     rot_matrix = rot_matrix.to(dtype=original_dtype)
+
+     # Apply the rotation and translation
+     aligned_coords = (
+         einsum(true_coords_centered, rot_matrix, "b n i, b j i -> b n j")
+         + pred_centroid
+     )
+     aligned_coords.detach_()
+
+     return aligned_coords
+
+
+ def smooth_lddt_loss(
+     pred_coords,
+     true_coords,
+     is_nucleotide,
+     coords_mask,
+     nucleic_acid_cutoff: float = 30.0,
+     other_cutoff: float = 15.0,
+     multiplicity: int = 1,
+ ):
+     """Compute the smooth lDDT loss.
+
+     Parameters
+     ----------
+     pred_coords: torch.Tensor
+         The predicted atom coordinates
+     true_coords: torch.Tensor
+         The ground truth atom coordinates
+     is_nucleotide: torch.Tensor
+         Mask indicating which atoms belong to nucleotides
+     coords_mask: torch.Tensor
+         The atoms mask
+     nucleic_acid_cutoff: float
+         The pair-distance cutoff for nucleic acid atoms
+     other_cutoff: float
+         The pair-distance cutoff for all other atoms
+     multiplicity: int
+         The diffusion sample multiplicity
+
+     Returns
+     -------
+     torch.Tensor
+         The smooth lDDT loss
+
+     """
+     B, N, _ = true_coords.shape
+     true_dists = torch.cdist(true_coords, true_coords)
+     is_nucleotide = is_nucleotide.repeat_interleave(multiplicity, 0)
+
+     coords_mask = coords_mask.repeat_interleave(multiplicity, 0)
+     is_nucleotide_pair = is_nucleotide.unsqueeze(-1).expand(
+         -1, -1, is_nucleotide.shape[-1]
+     )
+
+     mask = (
+         is_nucleotide_pair * (true_dists < nucleic_acid_cutoff).float()
+         + (1 - is_nucleotide_pair) * (true_dists < other_cutoff).float()
+     )
+     mask = mask * (1 - torch.eye(pred_coords.shape[1], device=pred_coords.device))
+     mask = mask * (coords_mask.unsqueeze(-1) * coords_mask.unsqueeze(-2))
+
+     # Compute distances between all pairs of atoms
+     pred_dists = torch.cdist(pred_coords, pred_coords)
+     dist_diff = torch.abs(true_dists - pred_dists)
+
+     # Compute epsilon values
+     eps = (
+         (
+             (
+                 F.sigmoid(0.5 - dist_diff)
+                 + F.sigmoid(1.0 - dist_diff)
+                 + F.sigmoid(2.0 - dist_diff)
+                 + F.sigmoid(4.0 - dist_diff)
+             )
+             / 4.0
+         )
+         .view(multiplicity, B // multiplicity, N, N)
+         .mean(dim=0)
+     )
+
+     # Calculate masked averaging
+     eps = eps.repeat_interleave(multiplicity, 0)
+     num = (eps * mask).sum(dim=(-1, -2))
+     den = mask.sum(dim=(-1, -2)).clamp(min=1)
+     lddt = num / den
+
+     return 1.0 - lddt.mean()
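For reference, a minimal usage sketch of the two losses above, assuming the module is importable as boltz.model.loss.diffusion (the path shown in the file list) and using arbitrary dummy shapes; this is an illustration, not part of the packaged code.

    import torch

    from boltz.model.loss.diffusion import smooth_lddt_loss, weighted_rigid_align

    B, N = 2, 16  # dummy batch of 2 structures with 16 atoms each
    true_coords = torch.randn(B, N, 3)
    pred_coords = torch.randn(B, N, 3)
    weights = torch.ones(B, N)
    mask = torch.ones(B, N)

    # Rigidly align the ground truth onto the prediction frame, then score it.
    aligned = weighted_rigid_align(true_coords, pred_coords, weights, mask)
    loss = smooth_lddt_loss(
        pred_coords,
        aligned,
        is_nucleotide=torch.zeros(B, N),
        coords_mask=mask,
    )
    print(aligned.shape, loss.item())  # torch.Size([2, 16, 3]) and a scalar loss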
boltz/model/loss/diffusionv2.py
@@ -0,0 +1,134 @@
+ # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
+
+ import einx
+ import torch
+ import torch.nn.functional as F
+ from einops import einsum, rearrange
+
+
+ def weighted_rigid_align(
+     true_coords,  # Float['b n 3'], # true coordinates
+     pred_coords,  # Float['b n 3'], # predicted coordinates
+     weights,  # Float['b n'], # weights for each atom
+     mask,  # Bool['b n'] | None = None # mask for variable lengths
+ ):  # -> Float['b n 3']:
+     """Algorithm 28. Note there is a problem with the pseudocode in the paper:
+     predicted and GT coordinates are swapped in Algorithm 28, but equation (2)
+     has them right."""
+
+     batch_size, num_points, dim = true_coords.shape
+     weights = (mask * weights).unsqueeze(-1)
+
+     # Compute weighted centroids
+     true_centroid = (true_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
+         dim=1, keepdim=True
+     )
+     pred_centroid = (pred_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
+         dim=1, keepdim=True
+     )
+
+     # Center the coordinates
+     true_coords_centered = true_coords - true_centroid
+     pred_coords_centered = pred_coords - pred_centroid
+
+     if torch.any(mask.sum(dim=-1) < (dim + 1)):
+         print(
+             "Warning: The size of one of the point clouds is <= dim+1. "
+             + "`WeightedRigidAlign` cannot return a unique rotation."
+         )
+
+     # Compute the weighted covariance matrix
+     cov_matrix = einsum(
+         weights * pred_coords_centered, true_coords_centered, "b n i, b n j -> b i j"
+     )
+
+     # Compute the SVD of the covariance matrix; float32 is required for svd and determinant
+     original_dtype = cov_matrix.dtype
+     cov_matrix_32 = cov_matrix.to(dtype=torch.float32)
+
+     U, S, V = torch.linalg.svd(
+         cov_matrix_32, driver="gesvd" if cov_matrix_32.is_cuda else None
+     )
+     V = V.mH
+
+     # Catch ambiguous rotation by checking the magnitude of singular values
+     if (S.abs() <= 1e-15).any() and not (num_points < (dim + 1)):
+         print(
+             "Warning: Excessively low rank of "
+             + "cross-correlation between aligned point clouds. "
+             + "`WeightedRigidAlign` cannot return a unique rotation."
+         )
+
+     # Compute the rotation matrix
+     rot_matrix = torch.einsum("b i j, b k j -> b i k", U, V).to(dtype=torch.float32)
+
+     # Ensure a proper rotation matrix with determinant 1. Note: this local `F`
+     # is the Kabsch reflection-correction matrix; it shadows the
+     # `torch.nn.functional` import within this function.
+     F = torch.eye(dim, dtype=cov_matrix_32.dtype, device=cov_matrix.device)[
+         None
+     ].repeat(batch_size, 1, 1)
+     F[:, -1, -1] = torch.det(rot_matrix)
+     rot_matrix = einsum(U, F, V, "b i j, b j k, b l k -> b i l")
+     rot_matrix = rot_matrix.to(dtype=original_dtype)
+
+     # Apply the rotation and translation
+     aligned_coords = (
+         einsum(true_coords_centered, rot_matrix, "b n i, b j i -> b n j")
+         + pred_centroid
+     )
+     aligned_coords.detach_()
+
+     return aligned_coords
+
+
+ def smooth_lddt_loss(
+     pred_coords,  # Float['b n 3'],
+     true_coords,  # Float['b n 3'],
+     is_nucleotide,  # Bool['b n'],
+     coords_mask,  # Bool['b n'] | None = None,
+     nucleic_acid_cutoff: float = 30.0,
+     other_cutoff: float = 15.0,
+     multiplicity: int = 1,
+ ):  # -> Float['']:
+     """Algorithm 27.
+     pred_coords: predicted coordinates
+     true_coords: true coordinates
+     Note: for efficiency, pred_coords is the only input with the multiplicity expanded.
+     TODO: add a weighting that overweights the smooth lDDT contribution close to t=0 (not present in the paper).
+     """
+     lddt = []
+     for i in range(true_coords.shape[0]):
+         true_dists = torch.cdist(true_coords[i], true_coords[i])
+
+         is_nucleotide_i = is_nucleotide[i // multiplicity]
+         coords_mask_i = coords_mask[i // multiplicity]
+
+         is_nucleotide_pair = is_nucleotide_i.unsqueeze(-1).expand(
+             -1, is_nucleotide_i.shape[-1]
+         )
+
+         mask = is_nucleotide_pair * (true_dists < nucleic_acid_cutoff).float()
+         mask += (1 - is_nucleotide_pair) * (true_dists < other_cutoff).float()
+         mask *= 1 - torch.eye(pred_coords.shape[1], device=pred_coords.device)
+         mask *= coords_mask_i.unsqueeze(-1)
+         mask *= coords_mask_i.unsqueeze(-2)
+
+         valid_pairs = mask.nonzero()
+         true_dists_i = true_dists[valid_pairs[:, 0], valid_pairs[:, 1]]
+
+         pred_coords_i1 = pred_coords[i, valid_pairs[:, 0]]
+         pred_coords_i2 = pred_coords[i, valid_pairs[:, 1]]
+         pred_dists_i = F.pairwise_distance(pred_coords_i1, pred_coords_i2)
+
+         dist_diff_i = torch.abs(true_dists_i - pred_dists_i)
+
+         eps_i = (
+             F.sigmoid(0.5 - dist_diff_i)
+             + F.sigmoid(1.0 - dist_diff_i)
+             + F.sigmoid(2.0 - dist_diff_i)
+             + F.sigmoid(4.0 - dist_diff_i)
+         ) / 4.0
+
+         lddt_i = eps_i.sum() / (valid_pairs.shape[0] + 1e-5)
+         lddt.append(lddt_i)
+
+     # average over batch & multiplicity
+     return 1.0 - torch.stack(lddt, dim=0).mean(dim=0)
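The determinant correction in weighted_rigid_align guards against reflections. A toy check of that behavior, as a sketch assuming the module path boltz.model.loss.diffusionv2 from the file list: mirror a point cloud and confirm that the alignment, restricted to proper rotations, cannot reproduce it exactly.

    import torch

    from boltz.model.loss.diffusionv2 import weighted_rigid_align

    torch.manual_seed(0)
    x = torch.randn(1, 8, 3)
    y = x.clone()
    y[..., 0] *= -1  # reflect across the yz-plane

    w = torch.ones(1, 8)
    aligned = weighted_rigid_align(x, y, w, w)

    # A proper rotation (det = +1) cannot match a pure reflection, so the
    # residual stays well above zero instead of collapsing to a mirror fit.
    print((aligned - y).norm())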
boltz/model/loss/distogram.py
@@ -0,0 +1,48 @@
+ import torch
+ from torch import Tensor
+
+
+ def distogram_loss(
+     output: dict[str, Tensor],
+     feats: dict[str, Tensor],
+ ) -> tuple[Tensor, Tensor]:
+     """Compute the distogram loss.
+
+     Parameters
+     ----------
+     output : Dict[str, Tensor]
+         Output of the model
+     feats : Dict[str, Tensor]
+         Input features
+
+     Returns
+     -------
+     Tensor
+         The globally averaged loss.
+     Tensor
+         Per example loss.
+
+     """
+     # Get predicted distograms
+     pred = output["pdistogram"]
+
+     # Compute target distogram
+     target = feats["disto_target"]
+
+     # Combine target mask and padding mask
+     mask = feats["token_disto_mask"]
+     mask = mask[:, None, :] * mask[:, :, None]
+     mask = mask * (1 - torch.eye(mask.shape[1])[None]).to(pred)
+
+     # Compute the distogram loss
+     errors = -1 * torch.sum(
+         target * torch.nn.functional.log_softmax(pred, dim=-1),
+         dim=-1,
+     )
+     denom = 1e-5 + torch.sum(mask, dim=(-1, -2))
+     mean = errors * mask
+     mean = torch.sum(mean, dim=-1)
+     mean = mean / denom[..., None]
+     batch_loss = torch.sum(mean, dim=-1)
+     global_loss = torch.mean(batch_loss)
+     return global_loss, batch_loss
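A minimal call sketch for the loss above, with hypothetical sizes (2 examples, 8 tokens, 64 distance bins) and a one-hot target; the feature keys match those read in the function, and the module path follows the file list.

    import torch

    from boltz.model.loss.distogram import distogram_loss

    B, L, bins = 2, 8, 64  # dummy sizes
    pred = torch.randn(B, L, L, bins)
    target = torch.nn.functional.one_hot(
        torch.randint(bins, (B, L, L)), num_classes=bins
    ).float()
    feats = {"disto_target": target, "token_disto_mask": torch.ones(B, L)}

    global_loss, batch_loss = distogram_loss({"pdistogram": pred}, feats)
    print(global_loss.item(), batch_loss.shape)  # scalar, torch.Size([2])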
boltz/model/loss/distogramv2.py
@@ -0,0 +1,105 @@
+ import torch
+ from torch import Tensor
+
+
+ def distogram_loss(
+     output: dict[str, Tensor],
+     feats: dict[str, Tensor],
+     aggregate_distogram: bool = True,
+ ) -> tuple[Tensor, Tensor]:
+     """Compute the distogram loss.
+
+     Parameters
+     ----------
+     output : Dict[str, Tensor]
+         Output of the model
+     feats : Dict[str, Tensor]
+         Input features
+     aggregate_distogram : bool
+         Whether to aggregate the target distogram over conformers before
+         computing the loss (requires a single predicted distogram).
+
+     Returns
+     -------
+     Tensor
+         The globally averaged loss.
+     Tensor
+         Per example loss.
+
+     """
+     with torch.autocast("cuda", enabled=False):
+         # Get predicted distograms
+         pred = output["pdistogram"].float()  # (B, L, L, num_distograms, disto_bins)
+         D = pred.shape[3]  # num_distograms  # noqa: N806
+         assert len(pred.shape) == 5  # noqa: PLR2004
+
+         # Compute target distogram
+         target = feats["disto_target"]  # (B, L, L, K, disto_bins)
+         assert len(target.shape) == 5  # noqa: PLR2004
+
+         if aggregate_distogram:
+             msg = "Cannot aggregate GT distogram when num_distograms > 1"
+             assert pred.shape[3] == 1, msg
+
+             pred = pred.squeeze(3)  # (B, L, L, disto_bins)
+
+             # Aggregate distogram over K conformers
+             target = target.sum(dim=3)  # (B, L, L, disto_bins)
+
+             # Normalize distogram
+             P = target / target.sum(-1)[..., None].clamp(min=1)  # noqa: N806
+
+             # Combine target mask and padding mask
+             mask = feats["token_disto_mask"]
+             mask = mask[:, None, :] * mask[:, :, None]
+             mask = mask * (1 - torch.eye(mask.shape[1])[None]).to(pred)
+
+             # Compute the distogram loss
+             log_Q = torch.nn.functional.log_softmax(pred, dim=-1)  # noqa: N806
+             errors = -1 * torch.sum(
+                 P * log_Q,
+                 dim=-1,
+             )
+             denom = 1e-5 + torch.sum(mask, dim=(-1, -2))
+             mean = errors * mask
+             mean = torch.sum(mean, dim=-1)
+             mean = mean / denom[..., None]
+             batch_loss = torch.sum(mean, dim=-1)
+             global_loss = torch.mean(batch_loss)
+         else:
+             # We want to compute the loss for each pair of conformer K and
+             # predicted distogram
+
+             # Loop through conformers and compute the loss
+             batch_loss = []
+             for k in range(target.shape[3]):
+                 # Get the target distogram for conformer k
+                 # (B, L, L, K, disto_bins) -> (B, L, L, D, disto_bins)
+                 P_k = target[:, :, :, k : k + 1, :].repeat_interleave(D, dim=3)  # noqa: N806
+
+                 # Compute the distogram loss against all predicted distograms
+                 log_Q = torch.nn.functional.log_softmax(pred, dim=-1)  # noqa: N806
+                 errors = -1 * torch.sum(
+                     P_k * log_Q,
+                     dim=-1,
+                 )  # (B, L, L, D)
+
+                 # Compute mask
+                 mask = feats["token_disto_mask"]
+                 mask = mask[:, None, :] * mask[:, :, None]
+                 mask = mask * (1 - torch.eye(mask.shape[1])[None]).to(pred)
+                 mask = mask.unsqueeze(-1).repeat_interleave(D, -1)  # (B, L, L, D)
+
+                 denom = 1e-5 + torch.sum(mask, dim=(-2, -3))  # (B, D)
+                 mean = errors * mask
+                 mean = torch.sum(mean, dim=-2)  # (B, L, D)
+                 mean = mean / denom[..., None, :]
+                 b_loss = torch.sum(mean, dim=-2)  # (B, D)
+
+                 batch_loss.append(b_loss)
+
+             batch_loss = torch.stack(batch_loss, dim=1)  # (B, K, D)
+
+             # Compute the batch loss by taking the min over the predicted
+             # distograms and the average across conformers
+             batch_loss = torch.min(batch_loss, dim=-1).values.mean(dim=1)
+             global_loss = torch.mean(batch_loss)
+
+     return global_loss, batch_loss
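A sketch of the non-aggregated path, assuming K = 3 target conformers and D = 2 predicted distograms; per example, the loss takes the best (min) predicted distogram for each conformer and averages over conformers, as in the code above.

    import torch

    from boltz.model.loss.distogramv2 import distogram_loss

    B, L, K, D, bins = 2, 8, 3, 2, 64  # dummy sizes
    pred = torch.randn(B, L, L, D, bins)  # D predicted distograms
    target = torch.nn.functional.one_hot(
        torch.randint(bins, (B, L, L, K)), num_classes=bins
    ).float()  # K one-hot conformer targets
    feats = {"disto_target": target, "token_disto_mask": torch.ones(B, L)}

    global_loss, batch_loss = distogram_loss(
        {"pdistogram": pred}, feats, aggregate_distogram=False
    )
    print(batch_loss.shape)  # torch.Size([2]): min over D, mean over K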