rc-foundry 0.1.6__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/inference_engines/checkpoint_registry.py +58 -11
- foundry/utils/alignment.py +10 -2
- foundry/utils/ddp.py +1 -1
- foundry/utils/logging.py +1 -1
- foundry/version.py +2 -2
- foundry_cli/download_checkpoints.py +66 -66
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.9.dist-info}/METADATA +30 -21
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.9.dist-info}/RECORD +31 -31
- rf3/cli.py +13 -4
- rf3/inference.py +3 -1
- rfd3/configs/datasets/train/pdb/af3_train_interface.yaml +1 -1
- rfd3/configs/inference_engine/rfdiffusion3.yaml +2 -2
- rfd3/configs/model/samplers/symmetry.yaml +1 -1
- rfd3/engine.py +28 -12
- rfd3/inference/datasets.py +1 -1
- rfd3/inference/input_parsing.py +32 -1
- rfd3/inference/legacy_input_parsing.py +17 -1
- rfd3/inference/parsing.py +1 -0
- rfd3/inference/symmetry/atom_array.py +78 -13
- rfd3/inference/symmetry/checks.py +62 -29
- rfd3/inference/symmetry/frames.py +256 -5
- rfd3/inference/symmetry/symmetry_utils.py +39 -61
- rfd3/model/inference_sampler.py +11 -1
- rfd3/model/layers/block_utils.py +33 -33
- rfd3/model/layers/chunked_pairwise.py +84 -82
- rfd3/run_inference.py +3 -1
- rfd3/transforms/symmetry.py +16 -7
- rfd3/utils/inference.py +21 -22
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.9.dist-info}/WHEEL +0 -0
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.9.dist-info}/entry_points.txt +0 -0
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.9.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -10,12 +10,13 @@ def get_symmetry_frames_from_symmetry_id(symmetry_id):
|
|
|
10
10
|
Returns:
|
|
11
11
|
frames: list of rotation matrices
|
|
12
12
|
"""
|
|
13
|
+
from rfd3.inference.symmetry.symmetry_utils import SymmetryConfig
|
|
13
14
|
|
|
14
15
|
# Get frames from symmetry id
|
|
15
16
|
sym_conf = {}
|
|
16
|
-
if isinstance(symmetry_id,
|
|
17
|
+
if isinstance(symmetry_id, SymmetryConfig):
|
|
17
18
|
sym_conf = symmetry_id
|
|
18
|
-
symmetry_id = symmetry_id.
|
|
19
|
+
symmetry_id = symmetry_id.id
|
|
19
20
|
|
|
20
21
|
if symmetry_id.lower().startswith("c"):
|
|
21
22
|
order = int(symmetry_id[1:])
|
|
@@ -23,11 +24,17 @@ def get_symmetry_frames_from_symmetry_id(symmetry_id):
|
|
|
23
24
|
elif symmetry_id.lower().startswith("d"):
|
|
24
25
|
order = int(symmetry_id[1:])
|
|
25
26
|
frames = get_dihedral_frames(order)
|
|
27
|
+
elif symmetry_id.lower() == "t":
|
|
28
|
+
frames = get_tetrahedral_frames()
|
|
29
|
+
elif symmetry_id.lower() == "o":
|
|
30
|
+
frames = get_octahedral_frames()
|
|
31
|
+
elif symmetry_id.lower() == "i":
|
|
32
|
+
frames = get_icosahedral_frames()
|
|
26
33
|
elif symmetry_id.lower() == "input_defined":
|
|
27
34
|
assert (
|
|
28
|
-
|
|
35
|
+
sym_conf.symmetry_file is not None
|
|
29
36
|
), "symmetry_file is required for input_defined symmetry"
|
|
30
|
-
frames = get_frames_from_file(sym_conf.
|
|
37
|
+
frames = get_frames_from_file(sym_conf.symmetry_file)
|
|
31
38
|
else:
|
|
32
39
|
raise ValueError(f"Symmetry id {symmetry_id} not supported")
|
|
33
40
|
|
|
@@ -120,7 +127,9 @@ def get_symmetry_frames_from_atom_array(src_atom_array, input_frames):
|
|
|
120
127
|
computed_frames = [(R, np.array([0, 0, 0])) for R in Rs]
|
|
121
128
|
|
|
122
129
|
# check that the computed frames match the input frames
|
|
123
|
-
check_input_frames_match_symmetry_frames(
|
|
130
|
+
check_input_frames_match_symmetry_frames(
|
|
131
|
+
computed_frames, input_frames, nids_by_entity
|
|
132
|
+
)
|
|
124
133
|
|
|
125
134
|
return computed_frames
|
|
126
135
|
|
|
@@ -277,6 +286,248 @@ def get_dihedral_frames(order):
|
|
|
277
286
|
return frames
|
|
278
287
|
|
|
279
288
|
|
|
289
|
+
def get_tetrahedral_frames():
    """
    Build the frames of the tetrahedral rotation group T (12 elements).

    Frames are emitted in a fixed order:
      1. the identity;
      2. 120° and 240° turns about each cube body diagonal
         (1, 1, 1), (1, -1, -1), (-1, 1, -1), (-1, -1, 1)  -> 8 rotations;
      3. 180° turns about the x, y and z axes               -> 3 rotations.

    Returns:
        frames: list of (3x3 rotation matrix, zero translation) tuples
    """
    body_diagonals = np.array(
        [
            [1, 1, 1],
            [1, -1, -1],
            [-1, 1, -1],
            [-1, -1, 1],
        ]
    )
    threefold_angles = (2 * np.pi / 3, 4 * np.pi / 3)

    frames = [(np.eye(3), np.array([0, 0, 0]))]

    # 3-fold axes: each body diagonal contributes its two non-trivial turns.
    for diagonal in body_diagonals:
        unit_axis = diagonal / np.linalg.norm(diagonal)
        frames.extend(
            (_rotation_matrix_from_axis_angle(unit_axis, theta), np.array([0, 0, 0]))
            for theta in threefold_angles
        )

    # 2-fold axes: one half-turn about each coordinate axis.
    frames.extend(
        (_rotation_matrix_from_axis_angle(basis_vector, np.pi), np.array([0, 0, 0]))
        for basis_vector in np.eye(3, dtype=int)
    )
    return frames
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def get_octahedral_frames():
    """
    Build the frames of the octahedral rotation group O (24 elements).

    Rotation axes come from a cube with vertices at (±1, ±1, ±1):
      - 4-fold axes through opposite face centres (3 axes x 3 turns),
      - 3-fold axes through opposite vertices     (4 axes x 2 turns),
      - 2-fold axes through opposite edge midpoints (6 axes x 1 turn).
    Frames are ordered: identity, 4-fold turns, 3-fold turns, 2-fold turns.

    Returns:
        frames: list of (3x3 rotation matrix, zero translation) tuples
    """

    def canonical_axis(vec):
        # Unit vector whose first non-negligible component is positive, rounded
        # so that v and -v map to the same hashable key (deduplicates axes).
        unit = vec / np.linalg.norm(vec)
        leading = next((comp for comp in unit if abs(comp) > 1e-10), None)
        if leading is not None and leading < 0:
            unit = -unit
        return tuple(np.round(unit, 10))

    # 8 cube corners; the sign of x varies slowest.
    corners = np.array(
        [[sx, sy, sz] for sx in (1, -1) for sy in (1, -1) for sz in (1, -1)]
    )
    # 6 face centres in the order +x, -x, +y, -y, +z, -z.
    face_centres = [
        sign * basis for basis in np.eye(3, dtype=int) for sign in (1, -1)
    ]
    # 12 cube edges: corner pairs that differ in exactly one coordinate.
    edge_pairs = [
        (a, b)
        for a in range(len(corners))
        for b in range(a + 1, len(corners))
        if np.sum(np.abs(corners[a] - corners[b]) > 0) == 1
    ]

    # Sets collapse opposite directions onto a single axis key.
    fourfold_keys = {canonical_axis(centre) for centre in face_centres}
    threefold_keys = {canonical_axis(corner) for corner in corners}
    twofold_keys = {
        canonical_axis((corners[a] + corners[b]) / 2) for a, b in edge_pairs
    }

    frames = [(np.eye(3), np.array([0, 0, 0]))]
    axis_groups = (
        (fourfold_keys, [np.pi * k / 2 for k in (1, 2, 3)]),
        (threefold_keys, [2 * np.pi / 3, 4 * np.pi / 3]),
        (twofold_keys, [np.pi]),
    )
    for keys, angles in axis_groups:
        for key in keys:
            direction = np.array(key)
            for theta in angles:
                rotation = _rotation_matrix_from_axis_angle(direction, theta)
                frames.append((rotation, np.array([0, 0, 0])))
    return frames
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
def get_icosahedral_frames():
    """
    Build the frames of the icosahedral rotation group I (60 elements).

    Rotation axes come from a regular icosahedron with vertices at
    (0, ±1, ±φ), (±1, ±φ, 0), (±φ, 0, ±1), where φ is the golden ratio:
      - 5-fold axes through opposite vertices      (6 axes x 4 turns = 24),
      - 3-fold axes through opposite face centres  (10 axes x 2 turns = 20),
      - 2-fold axes through opposite edge midpoints (15 axes x 1 turn = 15).
    Frames are ordered: identity, 5-fold turns, 3-fold turns, 2-fold turns.

    Returns:
        frames: list of (3x3 rotation matrix, zero translation) tuples
    """

    def canonical_axis(vec):
        # Unit vector whose first non-negligible component is positive, rounded
        # so that v and -v map to the same hashable key (deduplicates axes).
        unit = vec / np.linalg.norm(vec)
        leading = next((comp for comp in unit if abs(comp) > 1e-10), None)
        if leading is not None and leading < 0:
            unit = -unit
        return tuple(np.round(unit, 10))

    phi = (1 + np.sqrt(5)) / 2  # golden ratio

    # 12 icosahedron vertices, three cyclic coordinate patterns per sign pair.
    corners = np.array(
        [
            point
            for s1 in (1, -1)
            for s2 in (1, -1)
            for point in (
                (0, s1 * 1, s2 * phi),
                (s1 * 1, s2 * phi, 0),
                (s2 * phi, 0, s1 * 1),
            )
        ]
    )
    n_corners = len(corners)

    # 30 edges: vertex pairs at distance 2 (the edge length of this embedding).
    edge_pairs = [
        (p, q)
        for p in range(n_corners)
        for q in range(p + 1, n_corners)
        if np.isclose(np.sum((corners[p] - corners[q]) ** 2), 4.0)
    ]

    # 20 faces: vertex triples that are pairwise connected by edges.
    adjacent = set(edge_pairs)
    faces = [
        (p, q, r)
        for p in range(n_corners)
        for q in range(p + 1, n_corners)
        for r in range(q + 1, n_corners)
        if {(p, q), (q, r), (p, r)} <= adjacent
    ]

    # Sets collapse opposite directions onto a single axis key.
    fivefold_keys = {canonical_axis(corner) for corner in corners}
    threefold_keys = {
        canonical_axis((corners[p] + corners[q] + corners[r]) / 3)
        for p, q, r in faces
    }
    twofold_keys = {
        canonical_axis((corners[p] + corners[q]) / 2) for p, q in edge_pairs
    }

    frames = [(np.eye(3), np.array([0, 0, 0]))]
    axis_groups = (
        (fivefold_keys, [2 * np.pi * k / 5 for k in (1, 2, 3, 4)]),
        (threefold_keys, [2 * np.pi / 3, 4 * np.pi / 3]),
        (twofold_keys, [np.pi]),
    )
    for keys, angles in axis_groups:
        for key in keys:
            direction = np.array(key)
            for theta in angles:
                rotation = _rotation_matrix_from_axis_angle(direction, theta)
                frames.append((rotation, np.array([0, 0, 0])))
    return frames
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
def _rotation_matrix_from_axis_angle(axis, angle):
|
|
510
|
+
"""
|
|
511
|
+
Compute a rotation matrix from an axis and angle using Rodrigues' formula.
|
|
512
|
+
Arguments:
|
|
513
|
+
axis: unit vector of the rotation axis
|
|
514
|
+
angle: rotation angle in radians
|
|
515
|
+
Returns:
|
|
516
|
+
R: 3x3 rotation matrix
|
|
517
|
+
"""
|
|
518
|
+
|
|
519
|
+
axis = axis / np.linalg.norm(axis)
|
|
520
|
+
K = np.array(
|
|
521
|
+
[
|
|
522
|
+
[0, -axis[2], axis[1]],
|
|
523
|
+
[axis[2], 0, -axis[0]],
|
|
524
|
+
[-axis[1], axis[0], 0],
|
|
525
|
+
]
|
|
526
|
+
)
|
|
527
|
+
R = np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K)
|
|
528
|
+
return R
|
|
529
|
+
|
|
530
|
+
|
|
280
531
|
def get_frames_from_file(file_path):
    """
    Placeholder for reading symmetry frames from ``file_path``.

    Input-defined symmetry is not supported yet, so this always raises.

    Raises:
        NotImplementedError: unconditionally
    """
    message = "Input defined symmetry not implemented"
    raise NotImplementedError(message)
|
|
282
533
|
|
|
@@ -39,18 +39,36 @@ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
|
39
39
|
|
|
40
40
|
|
|
41
41
|
class SymmetryConfig(BaseModel):
    """
    Symmetry options for symmetric design.

    Extra keys are accepted (``extra="allow"``); for example, the
    input-defined symmetry path reads a ``symmetry_file`` attribute that is not
    declared below.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="allow",
    )

    # Symmetry group string; the frame builder dispatches on it
    # case-insensitively (C<n>, D<n>, T, O, I, or "input_defined").
    id: Optional[str] = Field(
        None,
        description="Symmetry group ID. Supported types: Cyclic (C), Dihedral (D), Tetrahedral (T), Octahedral (O), Icosahedral (I). e.g. 'C3', 'D2', 'T', 'O', 'I'.",
    )
    # Comma-separated names; split on "," before building the unsym-motif mask.
    is_unsym_motif: Optional[str] = Field(
        None,
        description=(
            "Comma separated list of contig/ligand names that should not be "
            "symmetrized such as DNA strands. e.g. 'HEM' or 'Y1-11,Z16-25'"
        ),
    )
    # When True, frames are derived from the source atom array, which must
    # then be supplied.
    is_symmetric_motif: bool = Field(
        True,
        description=(
            "If True, the input motifs are expected to be already symmetric "
            "and won't be symmetrized. If False, all input motifs are expected "
            "to be ASU and will be symmetrized."
        ),
    )
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def convery_sym_conf_to_symmetry_config(sym_conf: "SymmetryConfig | dict"):
    """
    Convert a symmetry configuration mapping into a ``SymmetryConfig``.

    The (misspelled) name is kept for backward compatibility with callers.

    Arguments:
        sym_conf: mapping of SymmetryConfig field names to values, or an
            existing ``SymmetryConfig`` instance, which is returned unchanged
            (a model instance cannot be ``**``-unpacked like a mapping)
    Returns:
        SymmetryConfig instance
    """
    if isinstance(sym_conf, SymmetryConfig):
        return sym_conf
    return SymmetryConfig(**sym_conf)
|
|
50
64
|
|
|
51
65
|
|
|
52
66
|
def make_symmetric_atom_array(
|
|
53
|
-
asu_atom_array,
|
|
67
|
+
asu_atom_array,
|
|
68
|
+
sym_conf: SymmetryConfig | dict,
|
|
69
|
+
sm=None,
|
|
70
|
+
has_dist_cond=False,
|
|
71
|
+
src_atom_array=None,
|
|
54
72
|
):
|
|
55
73
|
"""
|
|
56
74
|
apply symmetry to an atom array.
|
|
@@ -58,41 +76,35 @@ def make_symmetric_atom_array(
|
|
|
58
76
|
asu_atom_array: atom array of the asymmetric unit
|
|
59
77
|
sym_conf: symmetry configuration (dict, "id" key is required)
|
|
60
78
|
sm: optional small molecule names (str, comma separated)
|
|
61
|
-
|
|
79
|
+
has_dist_cond: whether to add 2d entity annotations
|
|
62
80
|
Returns:
|
|
63
81
|
new_asu_atom_array: atom array with symmetry applied
|
|
64
82
|
"""
|
|
65
|
-
|
|
66
|
-
sym_conf
|
|
67
|
-
) # TODO: JB: remove this line to keep as symmetry config for cleaner syntax(?)
|
|
68
|
-
ranked_logger.info(f"Symmetry Configs: {sym_conf}")
|
|
83
|
+
if not isinstance(sym_conf, SymmetryConfig):
|
|
84
|
+
sym_conf = convery_sym_conf_to_symmetry_config(sym_conf)
|
|
69
85
|
|
|
70
|
-
# Making sure that the symmetry config is valid
|
|
71
86
|
check_symmetry_config(
|
|
72
|
-
asu_atom_array,
|
|
73
|
-
sym_conf,
|
|
74
|
-
sm,
|
|
75
|
-
has_dist_cond=has_2d,
|
|
76
|
-
src_atom_array=src_atom_array,
|
|
87
|
+
asu_atom_array, sym_conf, sm, has_dist_cond, src_atom_array=src_atom_array
|
|
77
88
|
)
|
|
78
89
|
# Adding utility annotations to the asu atom array
|
|
79
90
|
asu_atom_array = _add_util_annotations(asu_atom_array, sym_conf, sm)
|
|
80
91
|
|
|
81
|
-
if
|
|
92
|
+
if has_dist_cond: # NB: this will only work for asymmetric motifs at the moment - need to add functionality for symmetric motifs
|
|
82
93
|
asu_atom_array = add_2d_entity_annotations(asu_atom_array)
|
|
83
94
|
|
|
84
95
|
frames = get_symmetry_frames_from_symmetry_id(sym_conf)
|
|
85
96
|
|
|
86
97
|
# If the motif is symmetric, we get the frames instead from the source atom array.
|
|
87
|
-
if sym_conf.
|
|
98
|
+
if sym_conf.is_symmetric_motif:
|
|
88
99
|
assert (
|
|
89
100
|
src_atom_array is not None
|
|
90
101
|
), "Source atom array must be provided for symmetric motifs"
|
|
91
|
-
# if symmetric motif is provided, get the frames from the src atom array
|
|
102
|
+
# if symmetric motif is provided, get the frames from the src atom array.
|
|
92
103
|
frames = get_symmetry_frames_from_atom_array(src_atom_array, frames)
|
|
93
104
|
else:
|
|
94
|
-
|
|
95
|
-
|
|
105
|
+
# At this point, asym case would have been caught by the check_symmetry_config function.
|
|
106
|
+
ranked_logger.info(
|
|
107
|
+
"No motifs found in atom array. Generating unconditional symmetric proteins."
|
|
96
108
|
)
|
|
97
109
|
|
|
98
110
|
# Add symmetry annotations to the asu atom array
|
|
@@ -101,7 +113,7 @@ def make_symmetric_atom_array(
|
|
|
101
113
|
# Extracting all things at this moment that we will not want to symmetrize.
|
|
102
114
|
# This includes: 1) unsym motifs, 2) ligands
|
|
103
115
|
unsym_atom_arrays = []
|
|
104
|
-
if sym_conf.
|
|
116
|
+
if sym_conf.is_unsym_motif:
|
|
105
117
|
# unsym_motif_atom_array = get_unsym_motif(asu_atom_array, asu_atom_array._is_unsym_motif)
|
|
106
118
|
# Now remove the unsym motifs from the asu atom array
|
|
107
119
|
unsym_atom_arrays.append(asu_atom_array[asu_atom_array._is_unsym_motif])
|
|
@@ -128,7 +140,7 @@ def make_symmetric_atom_array(
|
|
|
128
140
|
symmetrized_atom_array = struc.concatenate(symmetry_unit_list)
|
|
129
141
|
|
|
130
142
|
# add 2D conditioning annotations
|
|
131
|
-
if
|
|
143
|
+
if has_dist_cond:
|
|
132
144
|
symmetrized_atom_array = reannotate_2d_conditions(symmetrized_atom_array)
|
|
133
145
|
|
|
134
146
|
# set all motifs to not have any symmetrization applied to them
|
|
@@ -183,7 +195,7 @@ def make_symmetric_atom_array_for_partial_diffusion(atom_array, sym_conf):
|
|
|
183
195
|
frames = get_symmetry_frames_from_symmetry_id(sym_conf)
|
|
184
196
|
|
|
185
197
|
# Add symmetry ID
|
|
186
|
-
symmetry_ids = np.full(n, sym_conf.
|
|
198
|
+
symmetry_ids = np.full(n, sym_conf.id, dtype="U6")
|
|
187
199
|
atom_array.set_annotation("symmetry_id", symmetry_ids)
|
|
188
200
|
|
|
189
201
|
# Initialize transform annotations (use same format as original system)
|
|
@@ -244,7 +256,7 @@ def _add_util_annotations(asu_atom_array, sym_conf, sm):
|
|
|
244
256
|
"""
|
|
245
257
|
n = asu_atom_array.shape[0]
|
|
246
258
|
is_motif = get_motif_features(asu_atom_array)["is_motif_atom"].astype(np.bool_)
|
|
247
|
-
is_sm = np.zeros(
|
|
259
|
+
is_sm = np.zeros(n, dtype=bool)
|
|
248
260
|
is_asu = np.ones(n, dtype=bool)
|
|
249
261
|
is_unsym_motif = np.zeros(n, dtype=bool)
|
|
250
262
|
|
|
@@ -257,8 +269,8 @@ def _add_util_annotations(asu_atom_array, sym_conf, sm):
|
|
|
257
269
|
)
|
|
258
270
|
|
|
259
271
|
# assign unsym motifs
|
|
260
|
-
if sym_conf.
|
|
261
|
-
unsym_motif_names = sym_conf
|
|
272
|
+
if sym_conf.is_unsym_motif:
|
|
273
|
+
unsym_motif_names = sym_conf.is_unsym_motif.split(",")
|
|
262
274
|
unsym_motif_names = expand_contig_unsym_motif(unsym_motif_names)
|
|
263
275
|
is_unsym_motif = get_unsym_motif_mask(asu_atom_array, unsym_motif_names)
|
|
264
276
|
|
|
@@ -361,38 +373,4 @@ def apply_symmetry_to_xyz_atomwise(X_L, sym_feats, partial_diffusion=False):
|
|
|
361
373
|
"blc,cd->bld", asu_xyz, sym_transforms[target_id][0].to(asu_xyz.dtype)
|
|
362
374
|
) + sym_transforms[target_id][1].to(asu_xyz.dtype)
|
|
363
375
|
|
|
364
|
-
# Log inter-chain distances for debugging - use actual chain annotations
|
|
365
|
-
if sym_X_L.shape[1] > 100: # Only for large structures
|
|
366
|
-
# Use symmetry entity annotations to find different chains
|
|
367
|
-
sym_entity_id = sym_feats["sym_entity_id"]
|
|
368
|
-
unique_entities = torch.unique(sym_entity_id)
|
|
369
|
-
|
|
370
|
-
if len(unique_entities) >= 2:
|
|
371
|
-
# Get atoms from first two different entities
|
|
372
|
-
entity_0_mask = sym_entity_id == unique_entities[0]
|
|
373
|
-
entity_1_mask = sym_entity_id == unique_entities[1]
|
|
374
|
-
|
|
375
|
-
if entity_0_mask.sum() > 0 and entity_1_mask.sum() > 0:
|
|
376
|
-
entity_0_atoms = sym_X_L[0, entity_0_mask, :]
|
|
377
|
-
entity_1_atoms = sym_X_L[0, entity_1_mask, :]
|
|
378
|
-
|
|
379
|
-
# Sample subset to avoid memory issues
|
|
380
|
-
entity_0_sample = entity_0_atoms[: min(50, entity_0_atoms.shape[0]), :]
|
|
381
|
-
entity_1_sample = entity_1_atoms[: min(50, entity_1_atoms.shape[0]), :]
|
|
382
|
-
|
|
383
|
-
min_distance = (
|
|
384
|
-
torch.cdist(entity_0_sample, entity_1_sample).min().item()
|
|
385
|
-
)
|
|
386
|
-
ranked_logger.info(
|
|
387
|
-
f"Min inter-chain distance after symmetry: {min_distance:.2f} Å"
|
|
388
|
-
)
|
|
389
|
-
|
|
390
|
-
# Also log the centers of each entity
|
|
391
|
-
entity_0_center = entity_0_atoms.mean(dim=0)
|
|
392
|
-
entity_1_center = entity_1_atoms.mean(dim=0)
|
|
393
|
-
center_distance = torch.norm(entity_0_center - entity_1_center).item()
|
|
394
|
-
ranked_logger.info(
|
|
395
|
-
f"Distance between chain centers: {center_distance:.2f} Å"
|
|
396
|
-
)
|
|
397
|
-
|
|
398
376
|
return sym_X_L
|
rfd3/model/inference_sampler.py
CHANGED
|
@@ -1,11 +1,15 @@
|
|
|
1
1
|
import inspect
|
|
2
|
+
import time
|
|
2
3
|
from dataclasses import dataclass
|
|
3
4
|
from typing import Any, Literal
|
|
4
5
|
|
|
5
6
|
import torch
|
|
6
7
|
from jaxtyping import Float
|
|
8
|
+
from rfd3.inference.symmetry.symmetry_utils import apply_symmetry_to_xyz_atomwise
|
|
9
|
+
from rfd3.model.cfg_utils import strip_X
|
|
7
10
|
|
|
8
11
|
from foundry.common import exists
|
|
12
|
+
from foundry.utils.alignment import weighted_rigid_align
|
|
9
13
|
from foundry.utils.ddp import RankedLogger
|
|
10
14
|
from foundry.utils.rotation_augmentation import (
|
|
11
15
|
rot_vec_mul,
|
|
@@ -110,7 +114,7 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
|
|
|
110
114
|
)
|
|
111
115
|
# Fallback to smallest available step
|
|
112
116
|
noise_schedule_original = self._construct_inference_noise_schedule(
|
|
113
|
-
device=
|
|
117
|
+
device=device
|
|
114
118
|
)
|
|
115
119
|
noise_schedule = noise_schedule_original[-1:] # Just use the final step
|
|
116
120
|
ranked_logger.info(
|
|
@@ -221,6 +225,7 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
|
|
|
221
225
|
# Handle chunked mode vs standard mode
|
|
222
226
|
if "chunked_pairwise_embedder" in initializer_outputs:
|
|
223
227
|
# Chunked mode: explicitly provide P_LL=None
|
|
228
|
+
tic = time.time()
|
|
224
229
|
chunked_embedder = initializer_outputs[
|
|
225
230
|
"chunked_pairwise_embedder"
|
|
226
231
|
] # Don't pop, just get
|
|
@@ -238,6 +243,8 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
|
|
|
238
243
|
initializer_outputs=other_outputs,
|
|
239
244
|
**other_outputs,
|
|
240
245
|
)
|
|
246
|
+
toc = time.time()
|
|
247
|
+
ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
|
|
241
248
|
else:
|
|
242
249
|
# Standard mode: P_LL is included in initializer_outputs
|
|
243
250
|
outs = diffusion_module(
|
|
@@ -445,6 +452,7 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
|
|
|
445
452
|
# Handle chunked mode vs standard mode (same as default sampler)
|
|
446
453
|
if "chunked_pairwise_embedder" in initializer_outputs:
|
|
447
454
|
# Chunked mode: explicitly provide P_LL=None
|
|
455
|
+
tic = time.time()
|
|
448
456
|
chunked_embedder = initializer_outputs[
|
|
449
457
|
"chunked_pairwise_embedder"
|
|
450
458
|
] # Don't pop, just get
|
|
@@ -462,6 +470,8 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
|
|
|
462
470
|
initializer_outputs=other_outputs,
|
|
463
471
|
**other_outputs,
|
|
464
472
|
)
|
|
473
|
+
toc = time.time()
|
|
474
|
+
ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
|
|
465
475
|
else:
|
|
466
476
|
# Standard mode: P_LL is included in initializer_outputs
|
|
467
477
|
outs = diffusion_module(
|