rc-foundry 0.1.9__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/trainers/fabric.py +18 -2
- foundry/utils/components.py +3 -3
- foundry/utils/ddp.py +15 -12
- foundry/utils/xpu/__init__.py +27 -0
- foundry/utils/xpu/single_xpu_strategy.py +47 -0
- foundry/utils/xpu/xpu_accelerator.py +91 -0
- foundry/utils/xpu/xpu_precision.py +72 -0
- foundry/version.py +2 -2
- mpnn/inference_engines/mpnn.py +6 -2
- {rc_foundry-0.1.9.dist-info → rc_foundry-0.1.10.dist-info}/METADATA +11 -1
- {rc_foundry-0.1.9.dist-info → rc_foundry-0.1.10.dist-info}/RECORD +32 -26
- rf3/configs/inference.yaml +5 -0
- rf3/configs/train.yaml +5 -0
- rf3/configs/trainer/xpu.yaml +6 -0
- rf3/configs/validate.yaml +5 -0
- rfd3/configs/dev.yaml +1 -0
- rfd3/configs/inference.yaml +1 -0
- rfd3/configs/train.yaml +2 -1
- rfd3/configs/trainer/xpu.yaml +6 -0
- rfd3/configs/validate.yaml +1 -0
- rfd3/engine.py +14 -7
- rfd3/inference/input_parsing.py +4 -2
- rfd3/inference/symmetry/atom_array.py +9 -78
- rfd3/inference/symmetry/frames.py +0 -248
- rfd3/inference/symmetry/symmetry_utils.py +2 -3
- rfd3/model/inference_sampler.py +3 -1
- rfd3/transforms/hbonds_hbplus.py +52 -49
- rfd3/transforms/symmetry.py +7 -16
- rfd3/utils/inference.py +7 -6
- {rc_foundry-0.1.9.dist-info → rc_foundry-0.1.10.dist-info}/WHEEL +0 -0
- {rc_foundry-0.1.9.dist-info → rc_foundry-0.1.10.dist-info}/entry_points.txt +0 -0
- {rc_foundry-0.1.9.dist-info → rc_foundry-0.1.10.dist-info}/licenses/LICENSE.md +0 -0
rfd3/engine.py
CHANGED
|
@@ -5,7 +5,7 @@ import time
|
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
6
|
from os import PathLike
|
|
7
7
|
from pathlib import Path
|
|
8
|
-
from typing import
|
|
8
|
+
from typing import Dict, List, Optional
|
|
9
9
|
|
|
10
10
|
import torch
|
|
11
11
|
import yaml
|
|
@@ -48,9 +48,8 @@ class RFD3InferenceConfig:
|
|
|
48
48
|
diffusion_batch_size: int = 16
|
|
49
49
|
|
|
50
50
|
# RFD3 specific
|
|
51
|
-
skip_existing: bool = False
|
|
52
|
-
json_keys_subset: Optional[List[str]] = None
|
|
53
51
|
skip_existing: bool = True
|
|
52
|
+
json_keys_subset: Optional[List[str]] = None
|
|
54
53
|
specification: Optional[dict] = field(default_factory=dict)
|
|
55
54
|
inference_sampler: SampleDiffusionConfig | dict = field(default_factory=dict)
|
|
56
55
|
|
|
@@ -216,6 +215,9 @@ class RFD3InferenceEngine(BaseInferenceEngine):
|
|
|
216
215
|
inputs=inputs,
|
|
217
216
|
n_batches=n_batches,
|
|
218
217
|
)
|
|
218
|
+
if len(design_specifications) == 0:
|
|
219
|
+
ranked_logger.info("No design specifications to run. Skipping.")
|
|
220
|
+
return None
|
|
219
221
|
ensure_inference_sampler_matches_design_spec(
|
|
220
222
|
design_specifications, self.inference_sampler_overrides
|
|
221
223
|
)
|
|
@@ -378,15 +380,21 @@ class RFD3InferenceEngine(BaseInferenceEngine):
|
|
|
378
380
|
|
|
379
381
|
def _multiply_specifications(
|
|
380
382
|
self, inputs: Dict[str, dict | DesignInputSpecification], n_batches=None
|
|
381
|
-
) -> Dict[str,
|
|
383
|
+
) -> Dict[str, dict | DesignInputSpecification]:
|
|
382
384
|
# Find existing example IDS in output directory
|
|
383
385
|
if exists(self.out_dir):
|
|
384
|
-
|
|
386
|
+
existing_example_ids_ = set(
|
|
385
387
|
extract_example_id_from_path(path, CIF_LIKE_EXTENSIONS)
|
|
386
388
|
for path in find_files_with_extension(self.out_dir, CIF_LIKE_EXTENSIONS)
|
|
387
389
|
)
|
|
390
|
+
existing_example_ids = set(
|
|
391
|
+
[
|
|
392
|
+
"_model_".join(eid.split("_model_")[:-1])
|
|
393
|
+
for eid in existing_example_ids_
|
|
394
|
+
]
|
|
395
|
+
)
|
|
388
396
|
ranked_logger.info(
|
|
389
|
-
f"Found {len(existing_example_ids)} existing example IDs in the output directory."
|
|
397
|
+
f"Found {len(existing_example_ids)} existing example IDs in the output directory ({len(existing_example_ids_)} total)."
|
|
390
398
|
)
|
|
391
399
|
|
|
392
400
|
# Based on inputs, construct the specifications to loop through
|
|
@@ -405,7 +413,6 @@ class RFD3InferenceEngine(BaseInferenceEngine):
|
|
|
405
413
|
for batch_id in range((n_batches) if exists(n_batches) else 1):
|
|
406
414
|
# ... Example ID
|
|
407
415
|
example_id = f"{prefix}_{batch_id}" if exists(n_batches) else prefix
|
|
408
|
-
|
|
409
416
|
if (
|
|
410
417
|
self.skip_existing
|
|
411
418
|
and exists(self.out_dir)
|
rfd3/inference/input_parsing.py
CHANGED
|
@@ -128,8 +128,10 @@ class DesignInputSpecification(BaseModel):
|
|
|
128
128
|
# Motif selection from input file
|
|
129
129
|
contig: Optional[InputSelection] = Field(None, description="Contig specification string (e.g. 'A1-10,B1-5')")
|
|
130
130
|
unindex: Optional[InputSelection] = Field(None,
|
|
131
|
-
description="Unindexed components
|
|
132
|
-
"
|
|
131
|
+
description="Unindexed components selection. Components to fix in the generated structure without specifying sequence index. "\
|
|
132
|
+
"Components must not overlap with `contig` argument. "\
|
|
133
|
+
"E.g. 'A15-20,B6-10' or dict. We recommend specifying unindexed residues as a contig string, "\
|
|
134
|
+
"then using select_fixed_atoms will subset the atoms to the specified atoms")
|
|
133
135
|
# Extra args:
|
|
134
136
|
length: Optional[str] = Field(None, description="Length range as 'min-max' or int. Constrains length of contig if provided")
|
|
135
137
|
ligand: Optional[str] = Field(None, description="Ligand name or index to include in design.")
|
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
import string
|
|
2
|
-
|
|
3
1
|
import numpy as np
|
|
4
2
|
from rfd3.inference.symmetry.frames import (
|
|
5
3
|
decompose_symmetry_frame,
|
|
@@ -9,68 +7,6 @@ from rfd3.inference.symmetry.frames import (
|
|
|
9
7
|
FIXED_TRANSFORM_ID = -1
|
|
10
8
|
FIXED_ENTITY_ID = -1
|
|
11
9
|
|
|
12
|
-
# Alphabet for chain ID generation (uppercase letters only, per wwPDB convention)
|
|
13
|
-
_CHAIN_ALPHABET = string.ascii_uppercase
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def index_to_chain_id(index: int) -> str:
|
|
17
|
-
"""
|
|
18
|
-
Convert a zero-based index to a chain ID following wwPDB convention.
|
|
19
|
-
|
|
20
|
-
The naming follows the wwPDB-assigned chain ID system:
|
|
21
|
-
- 0-25: A-Z (single letter)
|
|
22
|
-
- 26-701: AA-ZZ (double letter)
|
|
23
|
-
- 702-18277: AAA-ZZZ (triple letter)
|
|
24
|
-
- And so on...
|
|
25
|
-
|
|
26
|
-
This is similar to Excel column naming (A, B, ..., Z, AA, AB, ...).
|
|
27
|
-
|
|
28
|
-
Arguments:
|
|
29
|
-
index: zero-based index (0 -> 'A', 25 -> 'Z', 26 -> 'AA', etc.)
|
|
30
|
-
Returns:
|
|
31
|
-
chain_id: string chain identifier
|
|
32
|
-
"""
|
|
33
|
-
if index < 0:
|
|
34
|
-
raise ValueError(f"Chain index must be non-negative, got {index}")
|
|
35
|
-
|
|
36
|
-
result = ""
|
|
37
|
-
remaining = index
|
|
38
|
-
|
|
39
|
-
# Convert to bijective base-26 (like Excel columns)
|
|
40
|
-
while True:
|
|
41
|
-
result = _CHAIN_ALPHABET[remaining % 26] + result
|
|
42
|
-
remaining = remaining // 26 - 1
|
|
43
|
-
if remaining < 0:
|
|
44
|
-
break
|
|
45
|
-
|
|
46
|
-
return result
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
def chain_id_to_index(chain_id: str) -> int:
|
|
50
|
-
"""
|
|
51
|
-
Convert a chain ID back to a zero-based index.
|
|
52
|
-
|
|
53
|
-
Inverse of index_to_chain_id.
|
|
54
|
-
|
|
55
|
-
Arguments:
|
|
56
|
-
chain_id: string chain identifier (e.g., 'A', 'Z', 'AA', 'AB')
|
|
57
|
-
Returns:
|
|
58
|
-
index: zero-based index
|
|
59
|
-
"""
|
|
60
|
-
if not chain_id or not all(c in _CHAIN_ALPHABET for c in chain_id):
|
|
61
|
-
raise ValueError(f"Invalid chain ID: {chain_id}")
|
|
62
|
-
|
|
63
|
-
# Offset for all shorter chain IDs (26 + 26^2 + ... + 26^(len-1))
|
|
64
|
-
offset = sum(26**k for k in range(1, len(chain_id)))
|
|
65
|
-
|
|
66
|
-
# Value within the current length group (standard base-26)
|
|
67
|
-
value = 0
|
|
68
|
-
for char in chain_id:
|
|
69
|
-
value = value * 26 + _CHAIN_ALPHABET.index(char)
|
|
70
|
-
|
|
71
|
-
return offset + value
|
|
72
|
-
|
|
73
|
-
|
|
74
10
|
########################################################
|
|
75
11
|
# Symmetry annotations
|
|
76
12
|
########################################################
|
|
@@ -311,13 +247,11 @@ def reset_chain_ids(atom_array, start_id):
|
|
|
311
247
|
Reset the chain ids and pn_unit_iids of an atom array to start from the given id.
|
|
312
248
|
Arguments:
|
|
313
249
|
atom_array: atom array with chain_ids and pn_unit_iids annotated
|
|
314
|
-
start_id: starting chain ID (e.g., 'A')
|
|
315
250
|
"""
|
|
316
251
|
chain_ids = np.unique(atom_array.chain_id)
|
|
317
|
-
|
|
318
|
-
for
|
|
319
|
-
|
|
320
|
-
atom_array.chain_id[atom_array.chain_id == old_id] = new_id
|
|
252
|
+
new_chain_range = range(ord(start_id), ord(start_id) + len(chain_ids))
|
|
253
|
+
for new_id, old_id in zip(new_chain_range, chain_ids):
|
|
254
|
+
atom_array.chain_id[atom_array.chain_id == old_id] = chr(new_id)
|
|
321
255
|
atom_array.pn_unit_iid = atom_array.chain_id
|
|
322
256
|
return atom_array
|
|
323
257
|
|
|
@@ -325,18 +259,15 @@ def reset_chain_ids(atom_array, start_id):
|
|
|
325
259
|
def reannotate_chain_ids(atom_array, offset, multiplier=0):
|
|
326
260
|
"""
|
|
327
261
|
Reannotate the chain ids and pn_unit_iids of an atom array.
|
|
328
|
-
|
|
329
|
-
Uses wwPDB-style chain IDs (A-Z, AA-ZZ, AAA-ZZZ, ...) to support
|
|
330
|
-
any number of chains.
|
|
331
|
-
|
|
332
262
|
Arguments:
|
|
333
263
|
atom_array: protein atom array with chain_ids and pn_unit_iids annotated
|
|
334
|
-
offset: offset to add to the chain ids
|
|
335
|
-
multiplier: multiplier
|
|
264
|
+
offset: offset to add to the chain ids
|
|
265
|
+
multiplier: multiplier to add to the chain ids
|
|
336
266
|
"""
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
267
|
+
chain_ids_int = (
|
|
268
|
+
np.array([ord(c) for c in atom_array.chain_id]) + offset * multiplier
|
|
269
|
+
)
|
|
270
|
+
chain_ids = np.array([chr(id) for id in chain_ids_int], dtype=str)
|
|
340
271
|
atom_array.chain_id = chain_ids
|
|
341
272
|
atom_array.pn_unit_iid = chain_ids
|
|
342
273
|
return atom_array
|
|
@@ -24,12 +24,6 @@ def get_symmetry_frames_from_symmetry_id(symmetry_id):
|
|
|
24
24
|
elif symmetry_id.lower().startswith("d"):
|
|
25
25
|
order = int(symmetry_id[1:])
|
|
26
26
|
frames = get_dihedral_frames(order)
|
|
27
|
-
elif symmetry_id.lower() == "t":
|
|
28
|
-
frames = get_tetrahedral_frames()
|
|
29
|
-
elif symmetry_id.lower() == "o":
|
|
30
|
-
frames = get_octahedral_frames()
|
|
31
|
-
elif symmetry_id.lower() == "i":
|
|
32
|
-
frames = get_icosahedral_frames()
|
|
33
27
|
elif symmetry_id.lower() == "input_defined":
|
|
34
28
|
assert (
|
|
35
29
|
sym_conf.symmetry_file is not None
|
|
@@ -286,248 +280,6 @@ def get_dihedral_frames(order):
|
|
|
286
280
|
return frames
|
|
287
281
|
|
|
288
282
|
|
|
289
|
-
def get_tetrahedral_frames():
|
|
290
|
-
"""
|
|
291
|
-
Get tetrahedral frames (T symmetry group, 12 elements).
|
|
292
|
-
Returns:
|
|
293
|
-
frames: list of rotation matrices
|
|
294
|
-
"""
|
|
295
|
-
|
|
296
|
-
frames = []
|
|
297
|
-
|
|
298
|
-
# Identity
|
|
299
|
-
frames.append((np.eye(3), np.array([0, 0, 0])))
|
|
300
|
-
|
|
301
|
-
# 8 rotations by ±120° around body diagonals (±1, ±1, ±1)
|
|
302
|
-
diagonals = [
|
|
303
|
-
np.array([1, 1, 1]),
|
|
304
|
-
np.array([1, -1, -1]),
|
|
305
|
-
np.array([-1, 1, -1]),
|
|
306
|
-
np.array([-1, -1, 1]),
|
|
307
|
-
]
|
|
308
|
-
for d in diagonals:
|
|
309
|
-
axis = d / np.linalg.norm(d)
|
|
310
|
-
for angle in [2 * np.pi / 3, 4 * np.pi / 3]:
|
|
311
|
-
R = _rotation_matrix_from_axis_angle(axis, angle)
|
|
312
|
-
frames.append((R, np.array([0, 0, 0])))
|
|
313
|
-
|
|
314
|
-
# 3 rotations by 180° around coordinate axes
|
|
315
|
-
for axis in [np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])]:
|
|
316
|
-
R = _rotation_matrix_from_axis_angle(axis, np.pi)
|
|
317
|
-
frames.append((R, np.array([0, 0, 0])))
|
|
318
|
-
|
|
319
|
-
return frames
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
def get_octahedral_frames():
|
|
323
|
-
"""
|
|
324
|
-
Get octahedral frames (O symmetry group, 24 elements).
|
|
325
|
-
The axes are computed from the geometry of a cube with vertices at (±1, ±1, ±1).
|
|
326
|
-
Returns:
|
|
327
|
-
frames: list of rotation matrices
|
|
328
|
-
"""
|
|
329
|
-
|
|
330
|
-
frames = []
|
|
331
|
-
|
|
332
|
-
# 8 vertices of the cube
|
|
333
|
-
vertices = []
|
|
334
|
-
for s1 in [1, -1]:
|
|
335
|
-
for s2 in [1, -1]:
|
|
336
|
-
for s3 in [1, -1]:
|
|
337
|
-
vertices.append(np.array([s1, s2, s3]))
|
|
338
|
-
vertices = np.array(vertices)
|
|
339
|
-
|
|
340
|
-
# 6 face centers of the cube (4-fold axes pass through these)
|
|
341
|
-
face_centers = [
|
|
342
|
-
np.array([1, 0, 0]),
|
|
343
|
-
np.array([-1, 0, 0]),
|
|
344
|
-
np.array([0, 1, 0]),
|
|
345
|
-
np.array([0, -1, 0]),
|
|
346
|
-
np.array([0, 0, 1]),
|
|
347
|
-
np.array([0, 0, -1]),
|
|
348
|
-
]
|
|
349
|
-
|
|
350
|
-
# Find edges (pairs of vertices differing in exactly one coordinate)
|
|
351
|
-
edges = []
|
|
352
|
-
for i in range(len(vertices)):
|
|
353
|
-
for j in range(i + 1, len(vertices)):
|
|
354
|
-
diff = np.abs(vertices[i] - vertices[j])
|
|
355
|
-
if np.sum(diff > 0) == 1: # Differ in exactly one coordinate
|
|
356
|
-
edges.append((i, j))
|
|
357
|
-
|
|
358
|
-
# Helper to get unique axis (normalize direction to avoid duplicates)
|
|
359
|
-
def normalize_axis(v):
|
|
360
|
-
axis = v / np.linalg.norm(v)
|
|
361
|
-
for c in axis:
|
|
362
|
-
if abs(c) > 1e-10:
|
|
363
|
-
if c < 0:
|
|
364
|
-
axis = -axis
|
|
365
|
-
break
|
|
366
|
-
return tuple(np.round(axis, 10))
|
|
367
|
-
|
|
368
|
-
# Identity
|
|
369
|
-
frames.append((np.eye(3), np.array([0, 0, 0])))
|
|
370
|
-
|
|
371
|
-
# 4-fold axes (through opposite face centers) - 3 axes
|
|
372
|
-
# Each gives rotations at 90°, 180°, 270° (we skip 0° = identity)
|
|
373
|
-
fourfold_axes_set = set()
|
|
374
|
-
for fc in face_centers:
|
|
375
|
-
axis_tuple = normalize_axis(fc)
|
|
376
|
-
fourfold_axes_set.add(axis_tuple)
|
|
377
|
-
|
|
378
|
-
for axis_tuple in fourfold_axes_set:
|
|
379
|
-
axis = np.array(axis_tuple)
|
|
380
|
-
for k in [1, 2, 3]: # 90°, 180°, 270°
|
|
381
|
-
angle = np.pi * k / 2
|
|
382
|
-
R = _rotation_matrix_from_axis_angle(axis, angle)
|
|
383
|
-
frames.append((R, np.array([0, 0, 0])))
|
|
384
|
-
|
|
385
|
-
# 3-fold axes (through opposite vertices) - 4 axes
|
|
386
|
-
# Each gives rotations at 120°, 240°
|
|
387
|
-
threefold_axes_set = set()
|
|
388
|
-
for v in vertices:
|
|
389
|
-
axis_tuple = normalize_axis(v)
|
|
390
|
-
threefold_axes_set.add(axis_tuple)
|
|
391
|
-
|
|
392
|
-
for axis_tuple in threefold_axes_set:
|
|
393
|
-
axis = np.array(axis_tuple)
|
|
394
|
-
for angle in [2 * np.pi / 3, 4 * np.pi / 3]:
|
|
395
|
-
R = _rotation_matrix_from_axis_angle(axis, angle)
|
|
396
|
-
frames.append((R, np.array([0, 0, 0])))
|
|
397
|
-
|
|
398
|
-
# 2-fold axes (through opposite edge midpoints) - 6 axes
|
|
399
|
-
# Each gives 1 rotation at 180°
|
|
400
|
-
twofold_axes_set = set()
|
|
401
|
-
for i, j in edges:
|
|
402
|
-
midpoint = (vertices[i] + vertices[j]) / 2
|
|
403
|
-
axis_tuple = normalize_axis(midpoint)
|
|
404
|
-
twofold_axes_set.add(axis_tuple)
|
|
405
|
-
|
|
406
|
-
for axis_tuple in twofold_axes_set:
|
|
407
|
-
axis = np.array(axis_tuple)
|
|
408
|
-
R = _rotation_matrix_from_axis_angle(axis, np.pi)
|
|
409
|
-
frames.append((R, np.array([0, 0, 0])))
|
|
410
|
-
|
|
411
|
-
return frames
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
def get_icosahedral_frames():
|
|
415
|
-
"""
|
|
416
|
-
Get icosahedral frames (I symmetry group, 60 elements).
|
|
417
|
-
The axes are computed from the geometry of a regular icosahedron with
|
|
418
|
-
vertices at (0, ±1, ±φ), (±1, ±φ, 0), (±φ, 0, ±1) where φ is the golden ratio.
|
|
419
|
-
Returns:
|
|
420
|
-
frames: list of rotation matrices
|
|
421
|
-
"""
|
|
422
|
-
|
|
423
|
-
frames = []
|
|
424
|
-
|
|
425
|
-
# Golden ratio
|
|
426
|
-
phi = (1 + np.sqrt(5)) / 2
|
|
427
|
-
|
|
428
|
-
# 12 vertices of the icosahedron
|
|
429
|
-
vertices = []
|
|
430
|
-
for s1 in [1, -1]:
|
|
431
|
-
for s2 in [1, -1]:
|
|
432
|
-
vertices.append(np.array([0, s1 * 1, s2 * phi]))
|
|
433
|
-
vertices.append(np.array([s1 * 1, s2 * phi, 0]))
|
|
434
|
-
vertices.append(np.array([s2 * phi, 0, s1 * 1]))
|
|
435
|
-
vertices = np.array(vertices)
|
|
436
|
-
|
|
437
|
-
# Find edges (pairs of vertices at distance 2)
|
|
438
|
-
edges = []
|
|
439
|
-
for i in range(len(vertices)):
|
|
440
|
-
for j in range(i + 1, len(vertices)):
|
|
441
|
-
dist_sq = np.sum((vertices[i] - vertices[j]) ** 2)
|
|
442
|
-
if np.isclose(dist_sq, 4.0):
|
|
443
|
-
edges.append((i, j))
|
|
444
|
-
|
|
445
|
-
# Find faces (triangles of mutually adjacent vertices)
|
|
446
|
-
edge_set = set(edges)
|
|
447
|
-
faces = []
|
|
448
|
-
for i in range(len(vertices)):
|
|
449
|
-
for j in range(i + 1, len(vertices)):
|
|
450
|
-
for k in range(j + 1, len(vertices)):
|
|
451
|
-
if (i, j) in edge_set and (j, k) in edge_set and (i, k) in edge_set:
|
|
452
|
-
faces.append((i, j, k))
|
|
453
|
-
|
|
454
|
-
# Helper to get unique axis (normalize direction to avoid duplicates)
|
|
455
|
-
def normalize_axis(v):
|
|
456
|
-
axis = v / np.linalg.norm(v)
|
|
457
|
-
# Make first significant component positive to avoid duplicate opposite axes
|
|
458
|
-
for c in axis:
|
|
459
|
-
if abs(c) > 1e-10:
|
|
460
|
-
if c < 0:
|
|
461
|
-
axis = -axis
|
|
462
|
-
break
|
|
463
|
-
return tuple(np.round(axis, 10))
|
|
464
|
-
|
|
465
|
-
# Identity
|
|
466
|
-
frames.append((np.eye(3), np.array([0, 0, 0])))
|
|
467
|
-
|
|
468
|
-
# 5-fold axes (through opposite vertices) - 6 axes, 4 rotations each = 24
|
|
469
|
-
fivefold_axes_set = set()
|
|
470
|
-
for v in vertices:
|
|
471
|
-
axis_tuple = normalize_axis(v)
|
|
472
|
-
fivefold_axes_set.add(axis_tuple)
|
|
473
|
-
|
|
474
|
-
for axis_tuple in fivefold_axes_set:
|
|
475
|
-
axis = np.array(axis_tuple)
|
|
476
|
-
for k in [1, 2, 3, 4]:
|
|
477
|
-
angle = 2 * np.pi * k / 5
|
|
478
|
-
R = _rotation_matrix_from_axis_angle(axis, angle)
|
|
479
|
-
frames.append((R, np.array([0, 0, 0])))
|
|
480
|
-
|
|
481
|
-
# 3-fold axes (through opposite face centers) - 10 axes, 2 rotations each = 20
|
|
482
|
-
threefold_axes_set = set()
|
|
483
|
-
for i, j, k in faces:
|
|
484
|
-
center = (vertices[i] + vertices[j] + vertices[k]) / 3
|
|
485
|
-
axis_tuple = normalize_axis(center)
|
|
486
|
-
threefold_axes_set.add(axis_tuple)
|
|
487
|
-
|
|
488
|
-
for axis_tuple in threefold_axes_set:
|
|
489
|
-
axis = np.array(axis_tuple)
|
|
490
|
-
for angle in [2 * np.pi / 3, 4 * np.pi / 3]:
|
|
491
|
-
R = _rotation_matrix_from_axis_angle(axis, angle)
|
|
492
|
-
frames.append((R, np.array([0, 0, 0])))
|
|
493
|
-
|
|
494
|
-
# 2-fold axes (through opposite edge midpoints) - 15 axes, 1 rotation each = 15
|
|
495
|
-
twofold_axes_set = set()
|
|
496
|
-
for i, j in edges:
|
|
497
|
-
midpoint = (vertices[i] + vertices[j]) / 2
|
|
498
|
-
axis_tuple = normalize_axis(midpoint)
|
|
499
|
-
twofold_axes_set.add(axis_tuple)
|
|
500
|
-
|
|
501
|
-
for axis_tuple in twofold_axes_set:
|
|
502
|
-
axis = np.array(axis_tuple)
|
|
503
|
-
R = _rotation_matrix_from_axis_angle(axis, np.pi)
|
|
504
|
-
frames.append((R, np.array([0, 0, 0])))
|
|
505
|
-
|
|
506
|
-
return frames
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
def _rotation_matrix_from_axis_angle(axis, angle):
|
|
510
|
-
"""
|
|
511
|
-
Compute a rotation matrix from an axis and angle using Rodrigues' formula.
|
|
512
|
-
Arguments:
|
|
513
|
-
axis: unit vector of the rotation axis
|
|
514
|
-
angle: rotation angle in radians
|
|
515
|
-
Returns:
|
|
516
|
-
R: 3x3 rotation matrix
|
|
517
|
-
"""
|
|
518
|
-
|
|
519
|
-
axis = axis / np.linalg.norm(axis)
|
|
520
|
-
K = np.array(
|
|
521
|
-
[
|
|
522
|
-
[0, -axis[2], axis[1]],
|
|
523
|
-
[axis[2], 0, -axis[0]],
|
|
524
|
-
[-axis[1], axis[0], 0],
|
|
525
|
-
]
|
|
526
|
-
)
|
|
527
|
-
R = np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K)
|
|
528
|
-
return R
|
|
529
|
-
|
|
530
|
-
|
|
531
283
|
def get_frames_from_file(file_path):
|
|
532
284
|
raise NotImplementedError("Input defined symmetry not implemented")
|
|
533
285
|
|
|
@@ -45,7 +45,7 @@ class SymmetryConfig(BaseModel):
|
|
|
45
45
|
)
|
|
46
46
|
id: Optional[str] = Field(
|
|
47
47
|
None,
|
|
48
|
-
description="Symmetry group ID.
|
|
48
|
+
description="Symmetry group ID. e.g. 'C3', 'D2'. Only C and D symmetry types are supported currently.",
|
|
49
49
|
)
|
|
50
50
|
is_unsym_motif: Optional[str] = Field(
|
|
51
51
|
None,
|
|
@@ -83,7 +83,7 @@ def make_symmetric_atom_array(
|
|
|
83
83
|
if not isinstance(sym_conf, SymmetryConfig):
|
|
84
84
|
sym_conf = convery_sym_conf_to_symmetry_config(sym_conf)
|
|
85
85
|
|
|
86
|
-
check_symmetry_config(
|
|
86
|
+
sym_conf = check_symmetry_config(
|
|
87
87
|
asu_atom_array, sym_conf, sm, has_dist_cond, src_atom_array=src_atom_array
|
|
88
88
|
)
|
|
89
89
|
# Adding utility annotations to the asu atom array
|
|
@@ -99,7 +99,6 @@ def make_symmetric_atom_array(
|
|
|
99
99
|
assert (
|
|
100
100
|
src_atom_array is not None
|
|
101
101
|
), "Source atom array must be provided for symmetric motifs"
|
|
102
|
-
# if symmetric motif is provided, get the frames from the src atom array.
|
|
103
102
|
frames = get_symmetry_frames_from_atom_array(src_atom_array, frames)
|
|
104
103
|
else:
|
|
105
104
|
# At this point, asym case would have been caught by the check_symmetry_config function.
|
rfd3/model/inference_sampler.py
CHANGED
|
@@ -120,8 +120,10 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
|
|
|
120
120
|
ranked_logger.info(
|
|
121
121
|
f"Using fallback: final step with t={noise_schedule[0].item():.6f}"
|
|
122
122
|
)
|
|
123
|
+
else:
|
|
124
|
+
noise_schedule = t_hat
|
|
123
125
|
|
|
124
|
-
return
|
|
126
|
+
return noise_schedule
|
|
125
127
|
|
|
126
128
|
def _get_initial_structure(
|
|
127
129
|
self,
|
rfd3/transforms/hbonds_hbplus.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import string
|
|
3
3
|
import subprocess
|
|
4
|
+
import tempfile
|
|
4
5
|
from datetime import datetime
|
|
5
6
|
from typing import Any, Tuple
|
|
6
7
|
|
|
@@ -66,10 +67,6 @@ def calculate_hbonds(
|
|
|
66
67
|
cutoff_HA_dist: float = 3,
|
|
67
68
|
cutoff_DA_distance: float = 3.5,
|
|
68
69
|
) -> Tuple[np.ndarray, np.ndarray, AtomArray]:
|
|
69
|
-
dtstr = datetime.now().strftime("%Y%m%d%H%M%S")
|
|
70
|
-
pdb_path = f"{dtstr}_{np.random.randint(10000)}.pdb"
|
|
71
|
-
atom_array, nan_mask, chain_map = save_atomarray_to_pdb(atom_array, pdb_path)
|
|
72
|
-
|
|
73
70
|
hbplus_exe = os.environ.get("HBPLUS_PATH")
|
|
74
71
|
|
|
75
72
|
if hbplus_exe is None or hbplus_exe == "":
|
|
@@ -78,49 +75,57 @@ def calculate_hbonds(
|
|
|
78
75
|
"Please set it to the path of the hbplus executable in order to calculate hydrogen bonds."
|
|
79
76
|
)
|
|
80
77
|
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
"
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
78
|
+
with tempfile.TemporaryDirectory() as tmpdir:
|
|
79
|
+
dtstr = datetime.now().strftime("%Y%m%d%H%M%S")
|
|
80
|
+
pdb_filename = f"{dtstr}_{np.random.randint(10000)}.pdb"
|
|
81
|
+
pdb_path = os.path.join(tmpdir, pdb_filename)
|
|
82
|
+
atom_array, _, chain_map = save_atomarray_to_pdb(atom_array, pdb_path)
|
|
83
|
+
|
|
84
|
+
subprocess.call(
|
|
85
|
+
[
|
|
86
|
+
hbplus_exe,
|
|
87
|
+
"-h",
|
|
88
|
+
str(cutoff_HA_dist),
|
|
89
|
+
"-d",
|
|
90
|
+
str(cutoff_DA_distance),
|
|
91
|
+
pdb_path,
|
|
92
|
+
pdb_path,
|
|
93
|
+
],
|
|
94
|
+
stdout=subprocess.DEVNULL,
|
|
95
|
+
stderr=subprocess.DEVNULL,
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
hb2_path = pdb_path.replace(".pdb", ".hb2")
|
|
99
|
+
with open(hb2_path, "r") as hb_file:
|
|
100
|
+
HB = hb_file.readlines()
|
|
101
|
+
hbonds = []
|
|
102
|
+
for i in range(8, len(HB)):
|
|
103
|
+
d_chain = HB[i][0]
|
|
104
|
+
d_resi = str(int(HB[i][1:5].strip()))
|
|
105
|
+
d_resn = HB[i][6:9].strip()
|
|
106
|
+
d_ins = HB[i][5].replace("-", " ")
|
|
107
|
+
d_atom = HB[i][9:13].strip()
|
|
108
|
+
a_chain = HB[i][14]
|
|
109
|
+
a_resi = str(int(HB[i][15:19].strip()))
|
|
110
|
+
a_ins = HB[i][19].replace("-", " ")
|
|
111
|
+
a_resn = HB[i][20:23].strip()
|
|
112
|
+
a_atom = HB[i][23:27].strip()
|
|
113
|
+
dist = float(HB[i][27:32].strip())
|
|
114
|
+
|
|
115
|
+
items = {
|
|
116
|
+
"d_chain": chain_map[d_chain],
|
|
117
|
+
"d_resi": d_resi,
|
|
118
|
+
"d_resn": d_resn,
|
|
119
|
+
"d_ins": d_ins,
|
|
120
|
+
"d_atom": d_atom,
|
|
121
|
+
"a_chain": chain_map[a_chain],
|
|
122
|
+
"a_resi": a_resi,
|
|
123
|
+
"a_resn": a_resn,
|
|
124
|
+
"a_ins": a_ins,
|
|
125
|
+
"a_atom": a_atom,
|
|
126
|
+
"dist": dist,
|
|
127
|
+
}
|
|
128
|
+
hbonds.append(items)
|
|
124
129
|
|
|
125
130
|
donor_array = np.zeros(len(atom_array))
|
|
126
131
|
acceptor_array = np.zeros(len(atom_array))
|
|
@@ -162,8 +167,6 @@ def calculate_hbonds(
|
|
|
162
167
|
donor_array[donor_mask] = 1
|
|
163
168
|
acceptor_array[acceptor_mask] = 1
|
|
164
169
|
|
|
165
|
-
os.remove(pdb_path)
|
|
166
|
-
os.remove(pdb_path.replace("pdb", "hb2"))
|
|
167
170
|
atom_array.set_annotation("active_donor", donor_array)
|
|
168
171
|
atom_array.set_annotation("active_acceptor", acceptor_array)
|
|
169
172
|
|
rfd3/transforms/symmetry.py
CHANGED
|
@@ -60,22 +60,13 @@ class AddSymmetryFeats(Transform):
|
|
|
60
60
|
)
|
|
61
61
|
TIDs = torch.from_numpy(atom_array.get_annotation("sym_transform_id"))
|
|
62
62
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
if first_occurrence[tid_idx] == 0 or i < first_occurrence[tid_idx]:
|
|
71
|
-
first_occurrence[tid_idx] = i
|
|
72
|
-
|
|
73
|
-
# Extract Ori, X, Y for each unique transform
|
|
74
|
-
Oris = Oris[first_occurrence]
|
|
75
|
-
Xs = Xs[first_occurrence]
|
|
76
|
-
Ys = Ys[first_occurrence]
|
|
77
|
-
TIDs = unique_TIDs
|
|
78
|
-
|
|
63
|
+
Oris = torch.unique_consecutive(Oris, dim=0)
|
|
64
|
+
Xs = torch.unique_consecutive(Xs, dim=0)
|
|
65
|
+
Ys = torch.unique_consecutive(Ys, dim=0)
|
|
66
|
+
TIDs = torch.unique_consecutive(TIDs, dim=0)
|
|
67
|
+
# the case in which there is only rotation (no translation), Ori = [0,0,0]
|
|
68
|
+
if len(Oris) == 1 and (Oris == 0).all():
|
|
69
|
+
Oris = Oris.repeat(len(Xs), 1)
|
|
79
70
|
Rs, Ts = framecoords_to_RTs(Oris, Xs, Ys)
|
|
80
71
|
|
|
81
72
|
for R, T, transform_id in zip(Rs, Ts, TIDs):
|
rfd3/utils/inference.py
CHANGED
|
@@ -373,12 +373,13 @@ def ensure_inference_sampler_matches_design_spec(
|
|
|
373
373
|
design_spec: Design specification dictionary
|
|
374
374
|
inference_sampler: Inference sampler dictionary
|
|
375
375
|
"""
|
|
376
|
-
has_symmetry_specification = [
|
|
377
|
-
|
|
378
|
-
if
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
376
|
+
has_symmetry_specification = []
|
|
377
|
+
for item in design_spec.values():
|
|
378
|
+
if hasattr(item, "symmetry"):
|
|
379
|
+
has_symmetry = item.symmetry is not None
|
|
380
|
+
else:
|
|
381
|
+
has_symmetry = "symmetry" in item and item.get("symmetry") is not None
|
|
382
|
+
has_symmetry_specification.append(has_symmetry)
|
|
382
383
|
if any(has_symmetry_specification):
|
|
383
384
|
if (
|
|
384
385
|
inference_sampler is None
|
|
File without changes
|
|
File without changes
|