rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/kinematics.py ADDED
@@ -0,0 +1,354 @@
1
+ # TODO: Many of these functions are unused; we will deprecate and delete
2
+ # (They are holdovers from previous frameworks)
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
# Default binning configuration used as the `params` default of the functions
# in this module.
PARAMS = dict(
    # dist_to_bins: DBINS1 bin edges run from DMIN + step up to DMID ...
    DMIN=1,
    DMID=4,
    # ... and DBINS2 edges from DMID + step up to DMAX. xyz_to_c6d also treats
    # DMAX as the cutoff beyond which a pair is flagged with distance 999.9.
    DMAX=20.0,
    DBINS1=30,
    DBINS2=30,
    # number of angular bins per 2*pi (bin width = 2*pi / ABINS)
    ABINS=36,
    # xyz_to_c6d: measure Cb-Cb distances when True, Ca-Ca distances otherwise
    USE_CB=False,
)
16
+
17
+
18
+ # ============================================================
19
def normQ(Q):
    """Scale quaternions to unit length.

    The norm is taken over the last dimension, so ``Q`` may carry any number
    of leading dimensions. No guard is applied against zero-length input.
    """
    magnitude = torch.linalg.norm(Q, dim=-1, keepdim=True)
    return Q / magnitude
22
+
23
+
24
+ # ============================================================
25
def avgQ(Qs):
    """Average a set of quaternions along dimension 1.

    Parameters
    ----------
    Qs : pytorch tensor of shape (B,N,R,4)
        quaternions to average; the average runs over the 'N' dimension
    Returns
    -------
    pytorch tensor of shape (B,R,4)
        normalized mean quaternion

    Each quaternion whose dot product with the first one (index 0 along 'N')
    is negative is negated before it is added, so that q and -q reinforce
    rather than cancel each other.
    """
    count = Qs.shape[1]
    reference = Qs[:, 0]
    total = reference / count

    for other in Qs.unbind(dim=1)[1:]:
        # True where `other` lies in the same hemisphere as the reference
        agrees = ((reference * other).sum(dim=-1) >= 0.0).unsqueeze(-1)
        total = total + torch.where(agrees, other, -other) / count

    return normQ(total)
44
+
45
+
46
def Rs2Qs(Rs):
    """Convert rotation matrices to quaternions ordered (w, x, y, z).

    Parameters
    ----------
    Rs : pytorch tensor of shape [...,3,3]
        rotation matrices
    Returns
    -------
    Qs : pytorch tensor of shape [...,4]
        quaternions with a non-negative first (scalar) component. The dtype
        is float32, or float64 when Rs is float64.

    Component magnitudes come from the matrix diagonal. The signs of x, y, z
    come from the antisymmetric part of the matrix (R21-R12, R02-R20,
    R10-R01). For a half-turn (180 degree rotation) that antisymmetric part
    is exactly zero and carries no sign information; in that case the
    relative signs are taken from the symmetric off-diagonal sums instead,
    with the largest of x, y, z chosen positive.
    """
    # Compute in at least float32 (the historical output dtype), but do not
    # throw away precision when the caller passes float64.
    Rs = Rs.to(torch.promote_types(Rs.dtype, torch.float32))
    r00, r11, r22 = Rs[..., 0, 0], Rs[..., 1, 1], Rs[..., 2, 2]

    # 4*w^2, 4*x^2, 4*y^2, 4*z^2
    squared = torch.stack(
        (
            1.0 + r00 + r11 + r22,
            1.0 + r00 - r11 - r22,
            1.0 - r00 + r11 - r22,
            1.0 - r00 - r11 + r22,
        ),
        dim=-1,
    )
    # round-off can push these slightly below zero; sqrt would yield NaN
    squared = torch.where(squared < 0.0, torch.zeros_like(squared), squared)
    Qs = torch.sqrt(squared) / 2.0

    # 4*w*x, 4*w*y, 4*w*z: the sign of each entry is the sign of x, y, z
    skew = torch.stack(
        (
            Rs[..., 2, 1] - Rs[..., 1, 2],
            Rs[..., 0, 2] - Rs[..., 2, 0],
            Rs[..., 1, 0] - Rs[..., 0, 1],
        ),
        dim=-1,
    )
    signs = torch.sign(skew)

    # Half-turn fallback. 4*x*y, 4*x*z, 4*y*z give the relative signs of the
    # vector components; row k of `table` holds the signs of (x, y, z) under
    # the convention that component k is positive.
    s_xy = torch.sign(Rs[..., 0, 1] + Rs[..., 1, 0])
    s_xz = torch.sign(Rs[..., 0, 2] + Rs[..., 2, 0])
    s_yz = torch.sign(Rs[..., 1, 2] + Rs[..., 2, 1])
    one = torch.ones_like(s_xy)
    table = torch.stack(
        (
            torch.stack((one, s_xy, s_xz), dim=-1),
            torch.stack((s_xy, one, s_yz), dim=-1),
            torch.stack((s_xz, s_yz, one), dim=-1),
        ),
        dim=-2,
    )
    dominant = Qs[..., 1:].argmax(dim=-1)
    pick = dominant[..., None, None].expand(*dominant.shape, 1, 3)
    fallback = torch.gather(table, -2, pick).squeeze(-2)

    # Only switch to the fallback when the antisymmetric part vanishes
    # entirely; everywhere else the result is unchanged from sign(skew).
    no_sign_info = (skew == 0).all(dim=-1, keepdim=True)
    signs = torch.where(no_sign_info, fallback, signs)

    return torch.cat((Qs[..., :1], Qs[..., 1:] * signs), dim=-1)
60
+
61
+
62
def Qs2Rs(Qs):
    """Convert quaternions ordered (w, x, y, z) to rotation matrices.

    Parameters
    ----------
    Qs : pytorch tensor of shape [...,4]
        quaternions; they are used as given (no normalization is applied)
    Returns
    -------
    Rs : pytorch tensor of shape [...,3,3]
        matrices built from the quaternion components. The dtype is float32,
        or float64 when Qs is float64.
    """
    # Compute in at least float32 (the historical output dtype), but keep
    # float64 precision when the caller provides it.
    Qs = Qs.to(torch.promote_types(Qs.dtype, torch.float32))
    w, x, y, z = Qs[..., 0], Qs[..., 1], Qs[..., 2], Qs[..., 3]

    # row-major entries of the 3x3 matrix
    entries = (
        w * w + x * x - y * y - z * z,
        2 * x * y - 2 * w * z,
        2 * x * z + 2 * w * y,
        2 * x * y + 2 * w * z,
        w * w - x * x + y * y - z * z,
        2 * y * z - 2 * w * x,
        2 * x * z - 2 * w * y,
        2 * y * z + 2 * w * x,
        w * w - x * x - y * y + z * z,
    )
    return torch.stack(entries, dim=-1).reshape(*Qs.shape[:-1], 3, 3)
91
+
92
+
93
+ # ============================================================
94
def get_pair_dist(a, b):
    """Compute all pairwise Euclidean distances between two point sets.

    Parameters
    ----------
    a,b : pytorch tensors of shape [batch,nres,3]
        Cartesian coordinates of the two sets of atoms
    Returns
    -------
    dist : pytorch tensor of shape [batch,nres,nres]
        dist[n,i,j] is the distance between a[n,i] and b[n,j]
    """
    return torch.cdist(a, b, p=2)
109
+
110
+
111
+ # ============================================================
112
def get_ang(a, b, c, eps=1e-4):
    """Compute the planar angle at b[i] for every triple (a[i],b[i],c[i]).

    Parameters
    ----------
    a,b,c : pytorch tensors of shape [batch,nres,3]
        Cartesian coordinates of three sets of atoms
    eps : float
        added to each vector length before dividing, to avoid division by zero
    Returns
    -------
    ang : pytorch tensor of shape [batch,nres]
        angles in radians; the cosine is clamped to [-0.999, 0.999] before
        acos, so results never reach exactly 0 or pi
    """
    to_a = a - b
    to_c = c - b
    unit_a = to_a / (torch.norm(to_a, dim=-1, keepdim=True) + eps)
    unit_c = to_c / (torch.norm(to_c, dim=-1, keepdim=True) + eps)
    cosine = (unit_a * unit_c).sum(dim=-1)

    return torch.acos(cosine.clamp(-0.999, 0.999))
132
+
133
+
134
+ # ============================================================
135
def get_dih(a, b, c, d, eps=1e-4):
    """Compute the dihedral angle for every quadruple (a[i],b[i],c[i],d[i]).

    Parameters
    ----------
    a,b,c,d : pytorch tensors of shape [batch,nres,3]
        Cartesian coordinates of four sets of atoms
    eps : float
        added to the b->c bond length before dividing, and to both atan2
        arguments
    Returns
    -------
    dih : pytorch tensor of shape [batch,nres]
        dihedral angles in radians, measured about the b->c axis
    """
    bond_ba = a - b
    bond_bc = c - b
    bond_cd = d - c

    axis = bond_bc / (torch.norm(bond_bc, dim=-1, keepdim=True) + eps)

    # remove the components along the central axis
    perp_ba = bond_ba - (bond_ba * axis).sum(dim=-1, keepdim=True) * axis
    perp_cd = bond_cd - (bond_cd * axis).sum(dim=-1, keepdim=True) * axis

    cos_part = (perp_ba * perp_cd).sum(dim=-1)
    sin_part = (torch.cross(axis, perp_ba, dim=-1) * perp_cd).sum(dim=-1)

    return torch.atan2(sin_part + eps, cos_part + eps)
161
+
162
+
163
+ # ============================================================
164
+
165
+
166
def generate_Cbeta(N, Ca, C):
    """Place a Cb position from the backbone N, Ca and C coordinates.

    Cb is Ca plus a fixed linear combination of the N->Ca bond, the Ca->C
    bond and their cross product. All inputs share a trailing dimension of
    size 3 and the result has the same shape.
    """
    n_to_ca = Ca - N
    ca_to_c = C - Ca
    normal = torch.cross(n_to_ca, ca_to_c, dim=-1)
    # fd: coefficients match the sidechain generator (=Rosetta params);
    # an earlier set was -0.58273431, 0.56802827, -0.54067466
    return -0.57910144 * normal + 0.5689693 * n_to_ca - 0.5441217 * ca_to_c + Ca
176
+
177
+
178
+ # ============================================================
179
+
180
+
181
def xyz_to_c6d(xyz, params=PARAMS):
    """Convert Cartesian backbone coordinates into 2D distance and
    orientation maps.

    Parameters
    ----------
    xyz : pytorch tensor of shape [batch,nres,3,3]
        Cartesian coordinates of backbone N,Ca,C atoms
    params : dict
        reads "USE_CB" (Cb-Cb instead of Ca-Ca distances) and "DMAX"
        (distance cutoff)
    Returns
    -------
    c6d : pytorch tensor of shape [batch,nres,nres,4]
        stacked dist,omega,theta,phi 2D maps. Pairs at or beyond DMAX, and
        the diagonal, carry distance 999.9 and zero angles.
    """
    n_batch, n_res = xyz.shape[0], xyz.shape[1]

    # three anchor atoms, plus a Cb rebuilt from them
    N, Ca, C = xyz[:, :, 0], xyz[:, :, 1], xyz[:, :, 2]
    Cb = generate_Cbeta(N, Ca, C)

    anchor = Cb if params["USE_CB"] else Ca
    dist = get_pair_dist(anchor, anchor)
    dist[torch.isnan(dist)] = 999.9

    # 6d coordinates order: (dist,omega,theta,phi)
    c6d = torch.zeros([n_batch, n_res, n_res, 4], dtype=xyz.dtype, device=xyz.device)
    # the large diagonal offset keeps self-pairs out of the contact set below
    c6d[..., 0] = dist + 999.9 * torch.eye(n_res, device=xyz.device)[None, ...]

    # angles are only evaluated for pairs closer than DMAX
    b, i, j = torch.where(c6d[..., 0] < params["DMAX"])
    c6d[b, i, j, 1] = get_dih(Ca[b, i], Cb[b, i], Cb[b, j], Ca[b, j])
    c6d[b, i, j, 2] = get_dih(N[b, i], Ca[b, i], Cb[b, i], Cb[b, j])
    c6d[b, i, j, 3] = get_ang(Ca[b, i], Cb[b, i], Cb[b, j])

    # fix long-range distances
    c6d[..., 0][c6d[..., 0] >= params["DMAX"]] = 999.9

    return torch.nan_to_num(c6d)
229
+
230
+
231
def xyz_to_t2d(xyz_t, mask, has_rotation=None, params=PARAMS):
    """Convert template Cartesian coordinates into 2D distance and
    orientation features.

    Parameters
    ----------
    xyz_t : pytorch tensor of shape [batch,templ,nres,natoms,3]
        template coordinates; only the first three atoms (backbone N,Ca,C)
        of each residue are used
    mask : pytorch tensor [batch,templ,nres,nres]
        indicates whether valid residue pairs or not
    has_rotation : pytorch tensor [nres] or [batch,templ,nres], optional
        indicates whether a node has a rotation or not (eg atoms do not have
        rotations). Orientation features of a pair are zeroed unless both
        members have a rotation.
    params : dict
        binning configuration forwarded to xyz_to_c6d and dist_to_onehot
    Returns
    -------
    t2d : pytorch tensor of shape [batch,templ,nres,nres,D]
        concatenation of the one-hot distance bins
        (DBINS1 + DBINS2 + 1 channels), sin/cos of omega,theta,phi
        (6 channels) and the mask (1 channel)
    """
    B, T, L = xyz_t.shape[:3]
    c6d = xyz_to_c6d(xyz_t[:, :, :, :3].view(B * T, L, 3, 3), params=params)
    c6d = c6d.view(B, T, L, L, 4)

    # dist to one-hot encoded
    mask = mask[..., None]
    dist = dist_to_onehot(c6d[..., 0], params) * mask
    angles = c6d[..., 1:]
    orien = (
        torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1) * mask
    )  # (B, T, L, L, 6)

    if has_rotation is not None:
        has_rotation = has_rotation.bool()
        # pair (i, j) keeps its orientation only when both i and j rotate
        has_rotation_2d = has_rotation[..., None, :] & has_rotation[..., None]
        # A broadcasting fill accepts both a per-residue [nres] flag and the
        # full [batch,templ,nres] flag; indexing orien with the mask only
        # worked for the former.
        orien = orien.masked_fill(~has_rotation_2d[..., None], 0.0)

    return torch.cat((dist, orien, mask), dim=-1)
265
+
266
+
267
def xyz_to_bbtor(xyz, params=PARAMS):
    """Compute binned backbone phi/psi torsions.

    Parameters
    ----------
    xyz : pytorch tensor of shape [batch,nres,natoms,3]
        coordinates whose first three atoms per residue are N,Ca,C
    params : dict
        reads "ABINS", the number of angular bins per 2*pi
    Returns
    -------
    pytorch long tensor of shape [batch,nres,2]
        bin indices of (phi, psi) for each residue
    """
    N, Ca, C = xyz[:, :, 0], xyz[:, :, 1], xyz[:, :, 2]

    # neighbouring residues supply the fourth atom of each torsion
    phi = get_dih(torch.roll(C, 1, dims=1), N, Ca, C)
    psi = get_dih(N, Ca, C, torch.roll(N, -1, dims=1))

    # roll wraps around, so the first phi and the last psi are not real
    # torsions; reset them to zero
    phi[:, 0] = 0.0
    psi[:, -1] = 0.0

    astep = 2.0 * np.pi / params["ABINS"]

    def to_bin(angle):
        return torch.round((angle + np.pi - astep / 2) / astep)

    return torch.stack([to_bin(phi), to_bin(psi)], dim=-1).long()
286
+
287
+
288
+ # ============================================================
289
def dist_to_onehot(dist, params=PARAMS):
    """One-hot encode binned distances.

    Distances are binned with dist_to_bins and expanded to a float one-hot
    tensor with DBINS1 + DBINS2 + 1 classes along a new trailing dimension.
    """
    n_classes = params["DBINS1"] + params["DBINS2"] + 1
    bin_index = dist_to_bins(dist, params)
    return torch.nn.functional.one_hot(bin_index, num_classes=n_classes).float()
295
+
296
+
297
+ # ============================================================
298
def dist_to_bins(dist, params=PARAMS):
    """Bin 2D distance maps.

    Parameters
    ----------
    dist : pytorch tensor
        distances; NaN entries are treated as 999.9. The input tensor is
        left unmodified.
    params : dict
        reads "DMIN", "DMID", "DMAX", "DBINS1" and "DBINS2"
    Returns
    -------
    db : pytorch long tensor with the shape of dist
        bin indices in [0, DBINS1 + DBINS2]. DBINS1 equally spaced upper
        edges cover (DMIN, DMID] and DBINS2 cover (DMID, DMAX]; anything
        above DMAX falls into the final index DBINS1 + DBINS2.
    """
    # Replace NaNs out of place: callers pass views of their own tensors
    # (e.g. c6d[..., 0]) and should not have them overwritten.
    dist = torch.where(torch.isnan(dist), torch.full_like(dist, 999.9), dist)

    dstep1 = (params["DMID"] - params["DMIN"]) / params["DBINS1"]
    dstep2 = (params["DMAX"] - params["DMID"]) / params["DBINS2"]
    dbins = torch.cat(
        [
            torch.linspace(
                params["DMIN"] + dstep1,
                params["DMID"],
                params["DBINS1"],
                dtype=dist.dtype,
                device=dist.device,
            ),
            torch.linspace(
                params["DMID"] + dstep2,
                params["DMAX"],
                params["DBINS2"],
                dtype=dist.dtype,
                device=dist.device,
            ),
        ]
    )
    db = torch.bucketize(dist.contiguous(), dbins).long()

    return db
324
+
325
+
326
+ # ============================================================
327
def c6d_to_bins(c6d, same_chain, negative=False, params=PARAMS):
    """Bin 2D distance and orientation maps.

    Parameters
    ----------
    c6d : pytorch tensor of shape [...,4]
        stacked dist,omega,theta,phi maps
    same_chain : pytorch tensor broadcastable to c6d[..., 0]
        only read when negative is True: pairs where it is False are forced
        into the no-contact bins
    negative : bool
        whether to apply the same_chain override
    params : dict
        reads "DBINS1", "DBINS2" and "ABINS" (plus whatever dist_to_bins
        reads); the dict is not modified
    Returns
    -------
    pytorch long tensor of shape [...,4]
        bin indices for (dist, omega, theta, phi)
    """
    db = dist_to_bins(c6d[..., 0], params)

    astep = 2.0 * np.pi / params["ABINS"]
    ob = torch.round((c6d[..., 1] + np.pi - astep / 2) / astep)
    tb = torch.round((c6d[..., 2] + np.pi - astep / 2) / astep)
    pb = torch.round((c6d[..., 3] - astep / 2) / astep)

    # Index of the no-contact distance bin. Kept in a local variable rather
    # than written into `params`, which by default is the shared module-level
    # dict.
    no_contact_bin = params["DBINS1"] + params["DBINS2"]
    abins = params["ABINS"]

    # synchronize no-contact bins
    no_contact = db == no_contact_bin
    ob[no_contact] = abins
    tb[no_contact] = abins
    pb[no_contact] = abins // 2

    if negative:
        keep = same_chain.bool()
        db = torch.where(keep, db.long(), no_contact_bin)
        ob = torch.where(keep, ob.long(), abins)
        tb = torch.where(keep, tb.long(), abins)
        pb = torch.where(keep, pb.long(), abins // 2)

    return torch.stack([db, ob, tb, pb], dim=-1).long()
350
+
351
+
352
def standardize_dihedral_retain_first(a, b, c, d):
    """Return the smaller of (a, b, c, d) and (a, c, b, d).

    The first and last members stay in place; only the two middle members
    may be swapped, using ordinary tuple ordering to decide.
    """
    as_given = (a, b, c, d)
    middle_swapped = (a, c, b, d)
    return min(as_given, middle_swapped)