rc-foundry 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/utils/ddp.py +1 -1
- foundry/utils/logging.py +1 -1
- foundry/version.py +2 -2
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.9.dist-info}/METADATA +6 -2
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.9.dist-info}/RECORD +22 -22
- rf3/cli.py +13 -4
- rf3/inference.py +3 -1
- rfd3/engine.py +11 -3
- rfd3/inference/datasets.py +1 -1
- rfd3/inference/input_parsing.py +31 -0
- rfd3/inference/symmetry/atom_array.py +78 -9
- rfd3/inference/symmetry/checks.py +12 -4
- rfd3/inference/symmetry/frames.py +248 -0
- rfd3/inference/symmetry/symmetry_utils.py +5 -5
- rfd3/model/inference_sampler.py +11 -1
- rfd3/model/layers/block_utils.py +33 -33
- rfd3/model/layers/chunked_pairwise.py +84 -82
- rfd3/transforms/symmetry.py +16 -7
- rfd3/utils/inference.py +4 -28
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.9.dist-info}/WHEEL +0 -0
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.9.dist-info}/entry_points.txt +0 -0
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.9.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -24,6 +24,12 @@ def get_symmetry_frames_from_symmetry_id(symmetry_id):
|
|
|
24
24
|
elif symmetry_id.lower().startswith("d"):
|
|
25
25
|
order = int(symmetry_id[1:])
|
|
26
26
|
frames = get_dihedral_frames(order)
|
|
27
|
+
elif symmetry_id.lower() == "t":
|
|
28
|
+
frames = get_tetrahedral_frames()
|
|
29
|
+
elif symmetry_id.lower() == "o":
|
|
30
|
+
frames = get_octahedral_frames()
|
|
31
|
+
elif symmetry_id.lower() == "i":
|
|
32
|
+
frames = get_icosahedral_frames()
|
|
27
33
|
elif symmetry_id.lower() == "input_defined":
|
|
28
34
|
assert (
|
|
29
35
|
sym_conf.symmetry_file is not None
|
|
@@ -280,6 +286,248 @@ def get_dihedral_frames(order):
|
|
|
280
286
|
return frames
|
|
281
287
|
|
|
282
288
|
|
|
289
|
+
def get_tetrahedral_frames():
    """
    Build the 12 proper rotations of the tetrahedral point group (T).

    The elements are emitted in this order: the identity, the eight
    120°/240° rotations about the four body diagonals, and finally the
    three 180° rotations about the x, y and z axes.

    Returns:
        frames: list of (R, t) tuples, where R is a 3x3 rotation matrix and
            t is an all-zero length-3 translation vector.
    """

    def as_frame(rotation):
        # Every group element is a pure rotation; pair it with a fresh
        # zero translation so no two frames share the same array object.
        return (rotation, np.array([0, 0, 0]))

    # Identity element
    frames = [as_frame(np.eye(3))]

    # 3-fold axes: the four body diagonals of the cube (±1, ±1, ±1)
    body_diagonals = (
        (1, 1, 1),
        (1, -1, -1),
        (-1, 1, -1),
        (-1, -1, 1),
    )
    third_turns = (2 * np.pi / 3, 4 * np.pi / 3)
    for corner in body_diagonals:
        direction = np.array(corner)
        unit = direction / np.linalg.norm(direction)
        frames.extend(
            as_frame(_rotation_matrix_from_axis_angle(unit, turn))
            for turn in third_turns
        )

    # 2-fold axes: half turns about each coordinate axis
    coordinate_axes = (
        np.array([1, 0, 0]),
        np.array([0, 1, 0]),
        np.array([0, 0, 1]),
    )
    frames.extend(
        as_frame(_rotation_matrix_from_axis_angle(unit, np.pi))
        for unit in coordinate_axes
    )

    return frames
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def get_octahedral_frames():
    """
    Build the 24 proper rotations of the octahedral point group (O).

    All rotation axes are derived from a cube with vertices at (±1, ±1, ±1):
    4-fold axes pass through face centers, 3-fold axes through vertices and
    2-fold axes through edge midpoints. Opposite directions are collapsed
    into a single axis before the rotations are generated.

    Returns:
        frames: list of (R, t) tuples, where R is a 3x3 rotation matrix and
            t is an all-zero length-3 translation vector. The identity is
            always the first entry.
    """
    signs = (1, -1)

    # 8 cube vertices
    vertices = np.array(
        [[sx, sy, sz] for sx in signs for sy in signs for sz in signs]
    )

    # 6 face centers: +/- each coordinate axis
    face_centers = [
        sign * unit for unit in np.eye(3, dtype=int) for sign in signs
    ]

    # Cube edges join vertices that differ in exactly one coordinate
    n_vertices = len(vertices)
    edges = [
        (a, b)
        for a in range(n_vertices)
        for b in range(a + 1, n_vertices)
        if np.sum(np.abs(vertices[a] - vertices[b]) > 0) == 1
    ]

    def canonical_axis(vec):
        # Unit vector whose first non-negligible component is positive, so
        # that v and -v map to the same hashable key.
        unit = vec / np.linalg.norm(vec)
        leading = next((comp for comp in unit if abs(comp) > 1e-10), None)
        if leading is not None and leading < 0:
            unit = -unit
        return tuple(np.round(unit, 10))

    def rotations_about(axis_keys, angles):
        # One frame per (axis, angle) pair, axis-major.
        return [
            (
                _rotation_matrix_from_axis_angle(np.array(key), angle),
                np.array([0, 0, 0]),
            )
            for key in axis_keys
            for angle in angles
        ]

    # Identity
    frames = [(np.eye(3), np.array([0, 0, 0]))]

    # 3 four-fold axes: 90°, 180°, 270° each (0° would repeat the identity)
    fourfold_axes = {canonical_axis(center) for center in face_centers}
    quarter_turns = [np.pi * step / 2 for step in (1, 2, 3)]
    frames.extend(rotations_about(fourfold_axes, quarter_turns))

    # 4 three-fold axes: 120°, 240° each
    threefold_axes = {canonical_axis(vertex) for vertex in vertices}
    frames.extend(
        rotations_about(threefold_axes, [2 * np.pi / 3, 4 * np.pi / 3])
    )

    # 6 two-fold axes: a single 180° rotation each
    twofold_axes = {
        canonical_axis((vertices[a] + vertices[b]) / 2) for a, b in edges
    }
    frames.extend(rotations_about(twofold_axes, [np.pi]))

    return frames
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
def get_icosahedral_frames():
    """
    Build the 60 proper rotations of the icosahedral point group (I).

    All rotation axes are derived from a regular icosahedron with vertices
    at (0, ±1, ±φ), (±1, ±φ, 0), (±φ, 0, ±1), φ being the golden ratio:
    5-fold axes pass through vertices, 3-fold axes through face centers and
    2-fold axes through edge midpoints. Opposite directions are collapsed
    into a single axis before the rotations are generated.

    Returns:
        frames: list of (R, t) tuples, where R is a 3x3 rotation matrix and
            t is an all-zero length-3 translation vector. The identity is
            always the first entry.
    """
    # Golden ratio
    phi = (1 + np.sqrt(5)) / 2

    # 12 vertices: cyclic permutations of (0, ±1, ±φ)
    corner_list = []
    for s1 in (1, -1):
        for s2 in (1, -1):
            corner_list += [
                np.array([0, s1 * 1, s2 * phi]),
                np.array([s1 * 1, s2 * phi, 0]),
                np.array([s2 * phi, 0, s1 * 1]),
            ]
    vertices = np.array(corner_list)
    n_vertices = len(vertices)

    # Edges: vertex pairs at distance 2 (squared distance 4)
    edges = [
        (a, b)
        for a in range(n_vertices)
        for b in range(a + 1, n_vertices)
        if np.isclose(np.sum((vertices[a] - vertices[b]) ** 2), 4.0)
    ]

    # Faces: triangles whose three sides are all edges
    edge_set = set(edges)
    faces = [
        (a, b, c)
        for a in range(n_vertices)
        for b in range(a + 1, n_vertices)
        for c in range(b + 1, n_vertices)
        if (a, b) in edge_set and (b, c) in edge_set and (a, c) in edge_set
    ]

    def canonical_axis(vec):
        # Unit vector whose first non-negligible component is positive, so
        # that v and -v map to the same hashable key.
        unit = vec / np.linalg.norm(vec)
        leading = next((comp for comp in unit if abs(comp) > 1e-10), None)
        if leading is not None and leading < 0:
            unit = -unit
        return tuple(np.round(unit, 10))

    def rotations_about(axis_keys, angles):
        # One frame per (axis, angle) pair, axis-major.
        return [
            (
                _rotation_matrix_from_axis_angle(np.array(key), angle),
                np.array([0, 0, 0]),
            )
            for key in axis_keys
            for angle in angles
        ]

    # Identity
    frames = [(np.eye(3), np.array([0, 0, 0]))]

    # 6 five-fold axes x 4 rotations = 24
    fivefold_axes = {canonical_axis(vertex) for vertex in vertices}
    fifth_turns = [2 * np.pi * step / 5 for step in (1, 2, 3, 4)]
    frames.extend(rotations_about(fivefold_axes, fifth_turns))

    # 10 three-fold axes x 2 rotations = 20
    threefold_axes = {
        canonical_axis((vertices[a] + vertices[b] + vertices[c]) / 3)
        for a, b, c in faces
    }
    frames.extend(
        rotations_about(threefold_axes, [2 * np.pi / 3, 4 * np.pi / 3])
    )

    # 15 two-fold axes x 1 rotation = 15
    twofold_axes = {
        canonical_axis((vertices[a] + vertices[b]) / 2) for a, b in edges
    }
    frames.extend(rotations_about(twofold_axes, [np.pi]))

    return frames
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
def _rotation_matrix_from_axis_angle(axis, angle):
    """
    Compute a rotation matrix from an axis and angle using Rodrigues' formula.

    Arguments:
        axis: length-3 array-like giving the rotation axis direction. It does
            not need to be unit length; it is normalized here.
        angle: rotation angle in radians
    Returns:
        R: 3x3 rotation matrix
    Raises:
        ValueError: if the axis has zero or non-finite length, since no
            rotation axis can be derived from it.
    """
    # Accept lists/tuples as well as integer arrays, and work in floats.
    axis = np.asarray(axis, dtype=float)
    norm = np.linalg.norm(axis)
    if not np.isfinite(norm) or norm == 0.0:
        # Dividing by this norm would silently produce a NaN-filled matrix.
        raise ValueError(
            f"Rotation axis must have finite, non-zero length, got {axis!r}"
        )
    axis = axis / norm

    # Cross-product (skew-symmetric) matrix of the unit axis
    K = np.array(
        [
            [0, -axis[2], axis[1]],
            [axis[2], 0, -axis[0]],
            [-axis[1], axis[0], 0],
        ]
    )
    # Rodrigues: R = I + sin(theta) K + (1 - cos(theta)) K^2
    R = np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K)
    return R
|
|
529
|
+
|
|
530
|
+
|
|
283
531
|
def get_frames_from_file(file_path):
|
|
284
532
|
raise NotImplementedError("Input defined symmetry not implemented")
|
|
285
533
|
|
|
@@ -45,7 +45,7 @@ class SymmetryConfig(BaseModel):
|
|
|
45
45
|
)
|
|
46
46
|
id: Optional[str] = Field(
|
|
47
47
|
None,
|
|
48
|
-
description="Symmetry group ID. e.g. 'C3', 'D2'
|
|
48
|
+
description="Symmetry group ID. Supported types: Cyclic (C), Dihedral (D), Tetrahedral (T), Octahedral (O), Icosahedral (I). e.g. 'C3', 'D2', 'T', 'O', 'I'.",
|
|
49
49
|
)
|
|
50
50
|
is_unsym_motif: Optional[str] = Field(
|
|
51
51
|
None,
|
|
@@ -101,10 +101,10 @@ def make_symmetric_atom_array(
|
|
|
101
101
|
), "Source atom array must be provided for symmetric motifs"
|
|
102
102
|
# if symmetric motif is provided, get the frames from the src atom array.
|
|
103
103
|
frames = get_symmetry_frames_from_atom_array(src_atom_array, frames)
|
|
104
|
-
|
|
105
|
-
#
|
|
106
|
-
|
|
107
|
-
"
|
|
104
|
+
else:
|
|
105
|
+
# At this point, asym case would have been caught by the check_symmetry_config function.
|
|
106
|
+
ranked_logger.info(
|
|
107
|
+
"No motifs found in atom array. Generating unconditional symmetric proteins."
|
|
108
108
|
)
|
|
109
109
|
|
|
110
110
|
# Add symmetry annotations to the asu atom array
|
rfd3/model/inference_sampler.py
CHANGED
|
@@ -1,11 +1,15 @@
|
|
|
1
1
|
import inspect
|
|
2
|
+
import time
|
|
2
3
|
from dataclasses import dataclass
|
|
3
4
|
from typing import Any, Literal
|
|
4
5
|
|
|
5
6
|
import torch
|
|
6
7
|
from jaxtyping import Float
|
|
8
|
+
from rfd3.inference.symmetry.symmetry_utils import apply_symmetry_to_xyz_atomwise
|
|
9
|
+
from rfd3.model.cfg_utils import strip_X
|
|
7
10
|
|
|
8
11
|
from foundry.common import exists
|
|
12
|
+
from foundry.utils.alignment import weighted_rigid_align
|
|
9
13
|
from foundry.utils.ddp import RankedLogger
|
|
10
14
|
from foundry.utils.rotation_augmentation import (
|
|
11
15
|
rot_vec_mul,
|
|
@@ -110,7 +114,7 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
|
|
|
110
114
|
)
|
|
111
115
|
# Fallback to smallest available step
|
|
112
116
|
noise_schedule_original = self._construct_inference_noise_schedule(
|
|
113
|
-
device=
|
|
117
|
+
device=device
|
|
114
118
|
)
|
|
115
119
|
noise_schedule = noise_schedule_original[-1:] # Just use the final step
|
|
116
120
|
ranked_logger.info(
|
|
@@ -221,6 +225,7 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
|
|
|
221
225
|
# Handle chunked mode vs standard mode
|
|
222
226
|
if "chunked_pairwise_embedder" in initializer_outputs:
|
|
223
227
|
# Chunked mode: explicitly provide P_LL=None
|
|
228
|
+
tic = time.time()
|
|
224
229
|
chunked_embedder = initializer_outputs[
|
|
225
230
|
"chunked_pairwise_embedder"
|
|
226
231
|
] # Don't pop, just get
|
|
@@ -238,6 +243,8 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
|
|
|
238
243
|
initializer_outputs=other_outputs,
|
|
239
244
|
**other_outputs,
|
|
240
245
|
)
|
|
246
|
+
toc = time.time()
|
|
247
|
+
ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
|
|
241
248
|
else:
|
|
242
249
|
# Standard mode: P_LL is included in initializer_outputs
|
|
243
250
|
outs = diffusion_module(
|
|
@@ -445,6 +452,7 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
|
|
|
445
452
|
# Handle chunked mode vs standard mode (same as default sampler)
|
|
446
453
|
if "chunked_pairwise_embedder" in initializer_outputs:
|
|
447
454
|
# Chunked mode: explicitly provide P_LL=None
|
|
455
|
+
tic = time.time()
|
|
448
456
|
chunked_embedder = initializer_outputs[
|
|
449
457
|
"chunked_pairwise_embedder"
|
|
450
458
|
] # Don't pop, just get
|
|
@@ -462,6 +470,8 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
|
|
|
462
470
|
initializer_outputs=other_outputs,
|
|
463
471
|
**other_outputs,
|
|
464
472
|
)
|
|
473
|
+
toc = time.time()
|
|
474
|
+
ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
|
|
465
475
|
else:
|
|
466
476
|
# Standard mode: P_LL is included in initializer_outputs
|
|
467
477
|
outs = diffusion_module(
|
rfd3/model/layers/block_utils.py
CHANGED
|
@@ -118,14 +118,14 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
|
|
|
118
118
|
|
|
119
119
|
Parameters
|
|
120
120
|
----------
|
|
121
|
-
P_LK_indices : (
|
|
121
|
+
P_LK_indices : (B, L, k) LongTensor
|
|
122
122
|
Key indices | P_LK_indices[d, i, k] = global atom index for which atom i attends to.
|
|
123
|
-
P_LK : (
|
|
123
|
+
P_LK : (B, L, k, c) FloatTensor
|
|
124
124
|
Key features to scatter add into
|
|
125
125
|
|
|
126
|
-
P_LA_indices : (
|
|
126
|
+
P_LA_indices : (B, L, a) LongTensor
|
|
127
127
|
Additional feature indices to scatter into P_LK.
|
|
128
|
-
P_LA : (
|
|
128
|
+
P_LA : (B, L, a, c) FloatTensor
|
|
129
129
|
Features corresponding to P_LA.
|
|
130
130
|
|
|
131
131
|
Both index tensors contain indices representing D batch dim,
|
|
@@ -135,42 +135,42 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
|
|
|
135
135
|
|
|
136
136
|
"""
|
|
137
137
|
# Handle case when indices and P_LA don't have batch dimensions
|
|
138
|
-
|
|
138
|
+
B, L, k = P_LK_indices.shape
|
|
139
139
|
if P_LA_indices.ndim == 2:
|
|
140
|
-
P_LA_indices = P_LA_indices.unsqueeze(0).expand(
|
|
140
|
+
P_LA_indices = P_LA_indices.unsqueeze(0).expand(B, -1, -1)
|
|
141
141
|
if P_LA_src.ndim == 3:
|
|
142
|
-
P_LA_src = P_LA_src.unsqueeze(0).expand(
|
|
142
|
+
P_LA_src = P_LA_src.unsqueeze(0).expand(B, -1, -1)
|
|
143
143
|
assert (
|
|
144
144
|
P_LA_src.shape[-1] == P_LK_tgt.shape[-1]
|
|
145
145
|
), "Channel dims do not match, got: {} vs {}".format(
|
|
146
146
|
P_LA_src.shape[-1], P_LK_tgt.shape[-1]
|
|
147
147
|
)
|
|
148
148
|
|
|
149
|
-
matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2) # (
|
|
149
|
+
matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2) # (B, L, a, k)
|
|
150
150
|
if not torch.all(matches.sum(dim=(-1, -2)) >= 1):
|
|
151
151
|
raise ValueError("Found multiple scatter indices for some atoms")
|
|
152
152
|
elif not torch.all(matches.sum(dim=-1) <= 1):
|
|
153
153
|
raise ValueError("Did not find a scatter index for every atom")
|
|
154
|
-
k_indices = matches.long().argmax(dim=-1) # (
|
|
154
|
+
k_indices = matches.long().argmax(dim=-1) # (B, L, a)
|
|
155
155
|
scatter_indices = k_indices.unsqueeze(-1).expand(
|
|
156
156
|
-1, -1, -1, P_LK_tgt.shape[-1]
|
|
157
|
-
) # (
|
|
157
|
+
) # (B, L, a, c)
|
|
158
158
|
P_LK_tgt = P_LK_tgt.scatter_add(dim=2, index=scatter_indices, src=P_LA_src)
|
|
159
159
|
return P_LK_tgt
|
|
160
160
|
|
|
161
161
|
|
|
162
162
|
def _batched_gather(values: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    Look up a feature vector for every entry of a per-position index tensor.

    values : (B, L, C)
    idx    : (B, L, k)
    returns: (B, L, k, C), with out[b, i, j, :] = values[b, idx[b, i, j], :]

    Raises:
        ValueError: if ``values`` is not 3-dimensional.
    """
    if values.ndim != 3:
        raise ValueError(
            f"values must have shape (B, L, C), got {tuple(values.shape)}"
        )
    C = values.shape[-1]
    k = idx.shape[-1]

    # (B, L, 1, C) → expanded view along k → (B, L, k, C)
    src = values.unsqueeze(2).expand(-1, -1, k, -1)
    idx = idx.unsqueeze(-1).expand(-1, -1, -1, C)  # (B, L, k, C)

    return torch.gather(src, 1, idx)  # dim=1 is the L-axis
|
|
176
176
|
|
|
@@ -196,7 +196,7 @@ def create_attention_indices(
|
|
|
196
196
|
X_L = torch.randn(
|
|
197
197
|
(1, L, 3), device=device, dtype=torch.float
|
|
198
198
|
) # [L, 3] - random
|
|
199
|
-
D_LL = torch.cdist(X_L, X_L, p=2) # [
|
|
199
|
+
D_LL = torch.cdist(X_L, X_L, p=2) # [B, L, L] - pairwise atom distances
|
|
200
200
|
|
|
201
201
|
# Create attention indices using neighbour distances
|
|
202
202
|
base_mask = ~f["unindexing_pair_mask"][
|
|
@@ -231,7 +231,7 @@ def create_attention_indices(
|
|
|
231
231
|
k_max=k_actual,
|
|
232
232
|
chain_id=chain_ids,
|
|
233
233
|
base_mask=base_mask,
|
|
234
|
-
) # [
|
|
234
|
+
) # [B, L, k] | indices[b, i, j] = atom index for atom i to j-th attn query
|
|
235
235
|
|
|
236
236
|
return attn_indices
|
|
237
237
|
|
|
@@ -245,7 +245,7 @@ def get_sparse_attention_indices_with_inter_chain(
|
|
|
245
245
|
|
|
246
246
|
Args:
|
|
247
247
|
tok_idx: atom to token mapping
|
|
248
|
-
D_LL: pairwise distances [
|
|
248
|
+
D_LL: pairwise distances [B, L, L]
|
|
249
249
|
n_seq_neighbours: number of sequence neighbors
|
|
250
250
|
k_intra: number of intra-chain attention keys
|
|
251
251
|
k_inter: number of inter-chain attention keys
|
|
@@ -253,29 +253,29 @@ def get_sparse_attention_indices_with_inter_chain(
|
|
|
253
253
|
base_mask: base mask for valid pairs
|
|
254
254
|
|
|
255
255
|
Returns:
|
|
256
|
-
attn_indices: [
|
|
256
|
+
attn_indices: [B, L, k_total] where k_total = k_intra + k_inter
|
|
257
257
|
"""
|
|
258
|
-
|
|
258
|
+
B, L, _ = D_LL.shape
|
|
259
259
|
|
|
260
260
|
# Get regular intra-chain indices (limited to k_intra)
|
|
261
261
|
intra_indices = get_sparse_attention_indices(
|
|
262
262
|
tok_idx, D_LL, n_seq_neighbours, k_intra, chain_id, base_mask
|
|
263
|
-
) # [
|
|
263
|
+
) # [B, L, k_intra]
|
|
264
264
|
|
|
265
265
|
# Get inter-chain indices for clash avoidance
|
|
266
|
-
inter_indices = torch.zeros(
|
|
267
|
-
|
|
268
|
-
for
|
|
269
|
-
for
|
|
270
|
-
query_chain = chain_id[
|
|
266
|
+
inter_indices = torch.zeros(B, L, k_inter, dtype=torch.long, device=D_LL.device)
|
|
267
|
+
unique_chains = torch.unique(chain_id)
|
|
268
|
+
for b in range(B):
|
|
269
|
+
for c in unique_chains:
|
|
270
|
+
query_chain = chain_id[c]
|
|
271
271
|
|
|
272
272
|
# Find atoms from different chains
|
|
273
|
-
other_chain_mask = (chain_id != query_chain) & base_mask[
|
|
273
|
+
other_chain_mask = (chain_id != query_chain) & base_mask[c, :]
|
|
274
274
|
other_chain_atoms = torch.where(other_chain_mask)[0]
|
|
275
275
|
|
|
276
276
|
if len(other_chain_atoms) > 0:
|
|
277
277
|
# Get distances to other chains
|
|
278
|
-
distances_to_other = D_LL[
|
|
278
|
+
distances_to_other = D_LL[b, c, other_chain_atoms]
|
|
279
279
|
|
|
280
280
|
# Select k_inter closest atoms from other chains
|
|
281
281
|
n_select = min(k_inter, len(other_chain_atoms))
|
|
@@ -283,23 +283,23 @@ def get_sparse_attention_indices_with_inter_chain(
|
|
|
283
283
|
selected_atoms = other_chain_atoms[closest_idx]
|
|
284
284
|
|
|
285
285
|
# Fill inter-chain indices
|
|
286
|
-
inter_indices[
|
|
286
|
+
inter_indices[b, c, :n_select] = selected_atoms
|
|
287
287
|
# Pad with random atoms if needed
|
|
288
288
|
if n_select < k_inter:
|
|
289
289
|
padding = torch.randint(
|
|
290
290
|
0, L, (k_inter - n_select,), device=D_LL.device
|
|
291
291
|
)
|
|
292
|
-
inter_indices[
|
|
292
|
+
inter_indices[b, c, n_select:] = padding
|
|
293
293
|
else:
|
|
294
294
|
# No other chains found, fill with random indices
|
|
295
|
-
inter_indices[
|
|
295
|
+
inter_indices[b, c, :] = torch.randint(
|
|
296
296
|
0, L, (k_inter,), device=D_LL.device
|
|
297
297
|
)
|
|
298
298
|
|
|
299
299
|
# Combine intra and inter chain indices
|
|
300
300
|
combined_indices = torch.cat(
|
|
301
301
|
[intra_indices, inter_indices], dim=-1
|
|
302
|
-
) # [
|
|
302
|
+
) # [B, L, k_total]
|
|
303
303
|
|
|
304
304
|
return combined_indices
|
|
305
305
|
|