rc-foundry 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,7 @@
2
2
 
3
3
  hydra:
4
4
  searchpath:
5
+ - pkg://rfd3.configs
5
6
  - pkg://configs
6
7
 
7
8
  defaults:
rfd3/engine.py CHANGED
@@ -5,7 +5,7 @@ import time
5
5
  from dataclasses import dataclass, field
6
6
  from os import PathLike
7
7
  from pathlib import Path
8
- from typing import Any, Dict, List, Optional
8
+ from typing import Dict, List, Optional
9
9
 
10
10
  import torch
11
11
  import yaml
@@ -48,9 +48,8 @@ class RFD3InferenceConfig:
48
48
  diffusion_batch_size: int = 16
49
49
 
50
50
  # RFD3 specific
51
- skip_existing: bool = False
52
- json_keys_subset: Optional[List[str]] = None
53
51
  skip_existing: bool = True
52
+ json_keys_subset: Optional[List[str]] = None
54
53
  specification: Optional[dict] = field(default_factory=dict)
55
54
  inference_sampler: SampleDiffusionConfig | dict = field(default_factory=dict)
56
55
 
@@ -216,6 +215,9 @@ class RFD3InferenceEngine(BaseInferenceEngine):
216
215
  inputs=inputs,
217
216
  n_batches=n_batches,
218
217
  )
218
+ if len(design_specifications) == 0:
219
+ ranked_logger.info("No design specifications to run. Skipping.")
220
+ return None
219
221
  ensure_inference_sampler_matches_design_spec(
220
222
  design_specifications, self.inference_sampler_overrides
221
223
  )
@@ -378,15 +380,21 @@ class RFD3InferenceEngine(BaseInferenceEngine):
378
380
 
379
381
  def _multiply_specifications(
380
382
  self, inputs: Dict[str, dict | DesignInputSpecification], n_batches=None
381
- ) -> Dict[str, Dict[str, Any]]:
383
+ ) -> Dict[str, dict | DesignInputSpecification]:
382
384
  # Find existing example IDS in output directory
383
385
  if exists(self.out_dir):
384
- existing_example_ids = set(
386
+ existing_example_ids_ = set(
385
387
  extract_example_id_from_path(path, CIF_LIKE_EXTENSIONS)
386
388
  for path in find_files_with_extension(self.out_dir, CIF_LIKE_EXTENSIONS)
387
389
  )
390
+ existing_example_ids = set(
391
+ [
392
+ "_model_".join(eid.split("_model_")[:-1])
393
+ for eid in existing_example_ids_
394
+ ]
395
+ )
388
396
  ranked_logger.info(
389
- f"Found {len(existing_example_ids)} existing example IDs in the output directory."
397
+ f"Found {len(existing_example_ids)} existing example IDs in the output directory ({len(existing_example_ids_)} total)."
390
398
  )
391
399
 
392
400
  # Based on inputs, construct the specifications to loop through
@@ -405,7 +413,6 @@ class RFD3InferenceEngine(BaseInferenceEngine):
405
413
  for batch_id in range((n_batches) if exists(n_batches) else 1):
406
414
  # ... Example ID
407
415
  example_id = f"{prefix}_{batch_id}" if exists(n_batches) else prefix
408
-
409
416
  if (
410
417
  self.skip_existing
411
418
  and exists(self.out_dir)
@@ -128,8 +128,10 @@ class DesignInputSpecification(BaseModel):
128
128
  # Motif selection from input file
129
129
  contig: Optional[InputSelection] = Field(None, description="Contig specification string (e.g. 'A1-10,B1-5')")
130
130
  unindex: Optional[InputSelection] = Field(None,
131
- description="Unindexed components string (components must not overlap with contig). "\
132
- "E.g. 'A15-20,B6-10' or dict. We recommend specifying")
131
+ description="Unindexed components selection. Components to fix in the generated structure without specifying sequence index. "\
132
+ "Components must not overlap with `contig` argument. "\
133
+ "E.g. 'A15-20,B6-10' or dict. We recommend specifying unindexed residues as a contig string, "\
134
+ "then using select_fixed_atoms will subset the atoms to the specified atoms")
133
135
  # Extra args:
134
136
  length: Optional[str] = Field(None, description="Length range as 'min-max' or int. Constrains length of contig if provided")
135
137
  ligand: Optional[str] = Field(None, description="Ligand name or index to include in design.")
@@ -1,5 +1,3 @@
1
- import string
2
-
3
1
  import numpy as np
4
2
  from rfd3.inference.symmetry.frames import (
5
3
  decompose_symmetry_frame,
@@ -9,68 +7,6 @@ from rfd3.inference.symmetry.frames import (
9
7
  FIXED_TRANSFORM_ID = -1
10
8
  FIXED_ENTITY_ID = -1
11
9
 
12
- # Alphabet for chain ID generation (uppercase letters only, per wwPDB convention)
13
- _CHAIN_ALPHABET = string.ascii_uppercase
14
-
15
-
16
- def index_to_chain_id(index: int) -> str:
17
- """
18
- Convert a zero-based index to a chain ID following wwPDB convention.
19
-
20
- The naming follows the wwPDB-assigned chain ID system:
21
- - 0-25: A-Z (single letter)
22
- - 26-701: AA-ZZ (double letter)
23
- - 702-18277: AAA-ZZZ (triple letter)
24
- - And so on...
25
-
26
- This is similar to Excel column naming (A, B, ..., Z, AA, AB, ...).
27
-
28
- Arguments:
29
- index: zero-based index (0 -> 'A', 25 -> 'Z', 26 -> 'AA', etc.)
30
- Returns:
31
- chain_id: string chain identifier
32
- """
33
- if index < 0:
34
- raise ValueError(f"Chain index must be non-negative, got {index}")
35
-
36
- result = ""
37
- remaining = index
38
-
39
- # Convert to bijective base-26 (like Excel columns)
40
- while True:
41
- result = _CHAIN_ALPHABET[remaining % 26] + result
42
- remaining = remaining // 26 - 1
43
- if remaining < 0:
44
- break
45
-
46
- return result
47
-
48
-
49
- def chain_id_to_index(chain_id: str) -> int:
50
- """
51
- Convert a chain ID back to a zero-based index.
52
-
53
- Inverse of index_to_chain_id.
54
-
55
- Arguments:
56
- chain_id: string chain identifier (e.g., 'A', 'Z', 'AA', 'AB')
57
- Returns:
58
- index: zero-based index
59
- """
60
- if not chain_id or not all(c in _CHAIN_ALPHABET for c in chain_id):
61
- raise ValueError(f"Invalid chain ID: {chain_id}")
62
-
63
- # Offset for all shorter chain IDs (26 + 26^2 + ... + 26^(len-1))
64
- offset = sum(26**k for k in range(1, len(chain_id)))
65
-
66
- # Value within the current length group (standard base-26)
67
- value = 0
68
- for char in chain_id:
69
- value = value * 26 + _CHAIN_ALPHABET.index(char)
70
-
71
- return offset + value
72
-
73
-
74
10
  ########################################################
75
11
  # Symmetry annotations
76
12
  ########################################################
@@ -311,13 +247,11 @@ def reset_chain_ids(atom_array, start_id):
311
247
  Reset the chain ids and pn_unit_iids of an atom array to start from the given id.
312
248
  Arguments:
313
249
  atom_array: atom array with chain_ids and pn_unit_iids annotated
314
- start_id: starting chain ID (e.g., 'A')
315
250
  """
316
251
  chain_ids = np.unique(atom_array.chain_id)
317
- start_index = chain_id_to_index(start_id)
318
- for i, old_id in enumerate(chain_ids):
319
- new_id = index_to_chain_id(start_index + i)
320
- atom_array.chain_id[atom_array.chain_id == old_id] = new_id
252
+ new_chain_range = range(ord(start_id), ord(start_id) + len(chain_ids))
253
+ for new_id, old_id in zip(new_chain_range, chain_ids):
254
+ atom_array.chain_id[atom_array.chain_id == old_id] = chr(new_id)
321
255
  atom_array.pn_unit_iid = atom_array.chain_id
322
256
  return atom_array
323
257
 
@@ -325,18 +259,15 @@ def reset_chain_ids(atom_array, start_id):
325
259
  def reannotate_chain_ids(atom_array, offset, multiplier=0):
326
260
  """
327
261
  Reannotate the chain ids and pn_unit_iids of an atom array.
328
-
329
- Uses wwPDB-style chain IDs (A-Z, AA-ZZ, AAA-ZZZ, ...) to support
330
- any number of chains.
331
-
332
262
  Arguments:
333
263
  atom_array: protein atom array with chain_ids and pn_unit_iids annotated
334
- offset: offset to add to the chain ids (typically num_chains in ASU)
335
- multiplier: multiplier for the offset (typically transform index)
264
+ offset: offset to add to the chain ids
265
+ multiplier: multiplier to add to the chain ids
336
266
  """
337
- chain_ids_indices = np.array([chain_id_to_index(c) for c in atom_array.chain_id])
338
- new_indices = chain_ids_indices + offset * multiplier
339
- chain_ids = np.array([index_to_chain_id(idx) for idx in new_indices], dtype="U4")
267
+ chain_ids_int = (
268
+ np.array([ord(c) for c in atom_array.chain_id]) + offset * multiplier
269
+ )
270
+ chain_ids = np.array([chr(id) for id in chain_ids_int], dtype=str)
340
271
  atom_array.chain_id = chain_ids
341
272
  atom_array.pn_unit_iid = chain_ids
342
273
  return atom_array
@@ -24,12 +24,6 @@ def get_symmetry_frames_from_symmetry_id(symmetry_id):
24
24
  elif symmetry_id.lower().startswith("d"):
25
25
  order = int(symmetry_id[1:])
26
26
  frames = get_dihedral_frames(order)
27
- elif symmetry_id.lower() == "t":
28
- frames = get_tetrahedral_frames()
29
- elif symmetry_id.lower() == "o":
30
- frames = get_octahedral_frames()
31
- elif symmetry_id.lower() == "i":
32
- frames = get_icosahedral_frames()
33
27
  elif symmetry_id.lower() == "input_defined":
34
28
  assert (
35
29
  sym_conf.symmetry_file is not None
@@ -286,248 +280,6 @@ def get_dihedral_frames(order):
286
280
  return frames
287
281
 
288
282
 
289
- def get_tetrahedral_frames():
290
- """
291
- Get tetrahedral frames (T symmetry group, 12 elements).
292
- Returns:
293
- frames: list of rotation matrices
294
- """
295
-
296
- frames = []
297
-
298
- # Identity
299
- frames.append((np.eye(3), np.array([0, 0, 0])))
300
-
301
- # 8 rotations by ±120° around body diagonals (±1, ±1, ±1)
302
- diagonals = [
303
- np.array([1, 1, 1]),
304
- np.array([1, -1, -1]),
305
- np.array([-1, 1, -1]),
306
- np.array([-1, -1, 1]),
307
- ]
308
- for d in diagonals:
309
- axis = d / np.linalg.norm(d)
310
- for angle in [2 * np.pi / 3, 4 * np.pi / 3]:
311
- R = _rotation_matrix_from_axis_angle(axis, angle)
312
- frames.append((R, np.array([0, 0, 0])))
313
-
314
- # 3 rotations by 180° around coordinate axes
315
- for axis in [np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])]:
316
- R = _rotation_matrix_from_axis_angle(axis, np.pi)
317
- frames.append((R, np.array([0, 0, 0])))
318
-
319
- return frames
320
-
321
-
322
- def get_octahedral_frames():
323
- """
324
- Get octahedral frames (O symmetry group, 24 elements).
325
- The axes are computed from the geometry of a cube with vertices at (±1, ±1, ±1).
326
- Returns:
327
- frames: list of rotation matrices
328
- """
329
-
330
- frames = []
331
-
332
- # 8 vertices of the cube
333
- vertices = []
334
- for s1 in [1, -1]:
335
- for s2 in [1, -1]:
336
- for s3 in [1, -1]:
337
- vertices.append(np.array([s1, s2, s3]))
338
- vertices = np.array(vertices)
339
-
340
- # 6 face centers of the cube (4-fold axes pass through these)
341
- face_centers = [
342
- np.array([1, 0, 0]),
343
- np.array([-1, 0, 0]),
344
- np.array([0, 1, 0]),
345
- np.array([0, -1, 0]),
346
- np.array([0, 0, 1]),
347
- np.array([0, 0, -1]),
348
- ]
349
-
350
- # Find edges (pairs of vertices differing in exactly one coordinate)
351
- edges = []
352
- for i in range(len(vertices)):
353
- for j in range(i + 1, len(vertices)):
354
- diff = np.abs(vertices[i] - vertices[j])
355
- if np.sum(diff > 0) == 1: # Differ in exactly one coordinate
356
- edges.append((i, j))
357
-
358
- # Helper to get unique axis (normalize direction to avoid duplicates)
359
- def normalize_axis(v):
360
- axis = v / np.linalg.norm(v)
361
- for c in axis:
362
- if abs(c) > 1e-10:
363
- if c < 0:
364
- axis = -axis
365
- break
366
- return tuple(np.round(axis, 10))
367
-
368
- # Identity
369
- frames.append((np.eye(3), np.array([0, 0, 0])))
370
-
371
- # 4-fold axes (through opposite face centers) - 3 axes
372
- # Each gives rotations at 90°, 180°, 270° (we skip 0° = identity)
373
- fourfold_axes_set = set()
374
- for fc in face_centers:
375
- axis_tuple = normalize_axis(fc)
376
- fourfold_axes_set.add(axis_tuple)
377
-
378
- for axis_tuple in fourfold_axes_set:
379
- axis = np.array(axis_tuple)
380
- for k in [1, 2, 3]: # 90°, 180°, 270°
381
- angle = np.pi * k / 2
382
- R = _rotation_matrix_from_axis_angle(axis, angle)
383
- frames.append((R, np.array([0, 0, 0])))
384
-
385
- # 3-fold axes (through opposite vertices) - 4 axes
386
- # Each gives rotations at 120°, 240°
387
- threefold_axes_set = set()
388
- for v in vertices:
389
- axis_tuple = normalize_axis(v)
390
- threefold_axes_set.add(axis_tuple)
391
-
392
- for axis_tuple in threefold_axes_set:
393
- axis = np.array(axis_tuple)
394
- for angle in [2 * np.pi / 3, 4 * np.pi / 3]:
395
- R = _rotation_matrix_from_axis_angle(axis, angle)
396
- frames.append((R, np.array([0, 0, 0])))
397
-
398
- # 2-fold axes (through opposite edge midpoints) - 6 axes
399
- # Each gives 1 rotation at 180°
400
- twofold_axes_set = set()
401
- for i, j in edges:
402
- midpoint = (vertices[i] + vertices[j]) / 2
403
- axis_tuple = normalize_axis(midpoint)
404
- twofold_axes_set.add(axis_tuple)
405
-
406
- for axis_tuple in twofold_axes_set:
407
- axis = np.array(axis_tuple)
408
- R = _rotation_matrix_from_axis_angle(axis, np.pi)
409
- frames.append((R, np.array([0, 0, 0])))
410
-
411
- return frames
412
-
413
-
414
- def get_icosahedral_frames():
415
- """
416
- Get icosahedral frames (I symmetry group, 60 elements).
417
- The axes are computed from the geometry of a regular icosahedron with
418
- vertices at (0, ±1, ±φ), (±1, ±φ, 0), (±φ, 0, ±1) where φ is the golden ratio.
419
- Returns:
420
- frames: list of rotation matrices
421
- """
422
-
423
- frames = []
424
-
425
- # Golden ratio
426
- phi = (1 + np.sqrt(5)) / 2
427
-
428
- # 12 vertices of the icosahedron
429
- vertices = []
430
- for s1 in [1, -1]:
431
- for s2 in [1, -1]:
432
- vertices.append(np.array([0, s1 * 1, s2 * phi]))
433
- vertices.append(np.array([s1 * 1, s2 * phi, 0]))
434
- vertices.append(np.array([s2 * phi, 0, s1 * 1]))
435
- vertices = np.array(vertices)
436
-
437
- # Find edges (pairs of vertices at distance 2)
438
- edges = []
439
- for i in range(len(vertices)):
440
- for j in range(i + 1, len(vertices)):
441
- dist_sq = np.sum((vertices[i] - vertices[j]) ** 2)
442
- if np.isclose(dist_sq, 4.0):
443
- edges.append((i, j))
444
-
445
- # Find faces (triangles of mutually adjacent vertices)
446
- edge_set = set(edges)
447
- faces = []
448
- for i in range(len(vertices)):
449
- for j in range(i + 1, len(vertices)):
450
- for k in range(j + 1, len(vertices)):
451
- if (i, j) in edge_set and (j, k) in edge_set and (i, k) in edge_set:
452
- faces.append((i, j, k))
453
-
454
- # Helper to get unique axis (normalize direction to avoid duplicates)
455
- def normalize_axis(v):
456
- axis = v / np.linalg.norm(v)
457
- # Make first significant component positive to avoid duplicate opposite axes
458
- for c in axis:
459
- if abs(c) > 1e-10:
460
- if c < 0:
461
- axis = -axis
462
- break
463
- return tuple(np.round(axis, 10))
464
-
465
- # Identity
466
- frames.append((np.eye(3), np.array([0, 0, 0])))
467
-
468
- # 5-fold axes (through opposite vertices) - 6 axes, 4 rotations each = 24
469
- fivefold_axes_set = set()
470
- for v in vertices:
471
- axis_tuple = normalize_axis(v)
472
- fivefold_axes_set.add(axis_tuple)
473
-
474
- for axis_tuple in fivefold_axes_set:
475
- axis = np.array(axis_tuple)
476
- for k in [1, 2, 3, 4]:
477
- angle = 2 * np.pi * k / 5
478
- R = _rotation_matrix_from_axis_angle(axis, angle)
479
- frames.append((R, np.array([0, 0, 0])))
480
-
481
- # 3-fold axes (through opposite face centers) - 10 axes, 2 rotations each = 20
482
- threefold_axes_set = set()
483
- for i, j, k in faces:
484
- center = (vertices[i] + vertices[j] + vertices[k]) / 3
485
- axis_tuple = normalize_axis(center)
486
- threefold_axes_set.add(axis_tuple)
487
-
488
- for axis_tuple in threefold_axes_set:
489
- axis = np.array(axis_tuple)
490
- for angle in [2 * np.pi / 3, 4 * np.pi / 3]:
491
- R = _rotation_matrix_from_axis_angle(axis, angle)
492
- frames.append((R, np.array([0, 0, 0])))
493
-
494
- # 2-fold axes (through opposite edge midpoints) - 15 axes, 1 rotation each = 15
495
- twofold_axes_set = set()
496
- for i, j in edges:
497
- midpoint = (vertices[i] + vertices[j]) / 2
498
- axis_tuple = normalize_axis(midpoint)
499
- twofold_axes_set.add(axis_tuple)
500
-
501
- for axis_tuple in twofold_axes_set:
502
- axis = np.array(axis_tuple)
503
- R = _rotation_matrix_from_axis_angle(axis, np.pi)
504
- frames.append((R, np.array([0, 0, 0])))
505
-
506
- return frames
507
-
508
-
509
- def _rotation_matrix_from_axis_angle(axis, angle):
510
- """
511
- Compute a rotation matrix from an axis and angle using Rodrigues' formula.
512
- Arguments:
513
- axis: unit vector of the rotation axis
514
- angle: rotation angle in radians
515
- Returns:
516
- R: 3x3 rotation matrix
517
- """
518
-
519
- axis = axis / np.linalg.norm(axis)
520
- K = np.array(
521
- [
522
- [0, -axis[2], axis[1]],
523
- [axis[2], 0, -axis[0]],
524
- [-axis[1], axis[0], 0],
525
- ]
526
- )
527
- R = np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K)
528
- return R
529
-
530
-
531
283
  def get_frames_from_file(file_path):
532
284
  raise NotImplementedError("Input defined symmetry not implemented")
533
285
 
@@ -45,7 +45,7 @@ class SymmetryConfig(BaseModel):
45
45
  )
46
46
  id: Optional[str] = Field(
47
47
  None,
48
- description="Symmetry group ID. Supported types: Cyclic (C), Dihedral (D), Tetrahedral (T), Octahedral (O), Icosahedral (I). e.g. 'C3', 'D2', 'T', 'O', 'I'.",
48
+ description="Symmetry group ID. e.g. 'C3', 'D2'. Only C and D symmetry types are supported currently.",
49
49
  )
50
50
  is_unsym_motif: Optional[str] = Field(
51
51
  None,
@@ -83,7 +83,7 @@ def make_symmetric_atom_array(
83
83
  if not isinstance(sym_conf, SymmetryConfig):
84
84
  sym_conf = convery_sym_conf_to_symmetry_config(sym_conf)
85
85
 
86
- check_symmetry_config(
86
+ sym_conf = check_symmetry_config(
87
87
  asu_atom_array, sym_conf, sm, has_dist_cond, src_atom_array=src_atom_array
88
88
  )
89
89
  # Adding utility annotations to the asu atom array
@@ -99,7 +99,6 @@ def make_symmetric_atom_array(
99
99
  assert (
100
100
  src_atom_array is not None
101
101
  ), "Source atom array must be provided for symmetric motifs"
102
- # if symmetric motif is provided, get the frames from the src atom array.
103
102
  frames = get_symmetry_frames_from_atom_array(src_atom_array, frames)
104
103
  else:
105
104
  # At this point, asym case would have been caught by the check_symmetry_config function.
@@ -120,8 +120,10 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
120
120
  ranked_logger.info(
121
121
  f"Using fallback: final step with t={noise_schedule[0].item():.6f}"
122
122
  )
123
+ else:
124
+ noise_schedule = t_hat
123
125
 
124
- return t_hat
126
+ return noise_schedule
125
127
 
126
128
  def _get_initial_structure(
127
129
  self,
@@ -1,6 +1,7 @@
1
1
  import os
2
2
  import string
3
3
  import subprocess
4
+ import tempfile
4
5
  from datetime import datetime
5
6
  from typing import Any, Tuple
6
7
 
@@ -66,10 +67,6 @@ def calculate_hbonds(
66
67
  cutoff_HA_dist: float = 3,
67
68
  cutoff_DA_distance: float = 3.5,
68
69
  ) -> Tuple[np.ndarray, np.ndarray, AtomArray]:
69
- dtstr = datetime.now().strftime("%Y%m%d%H%M%S")
70
- pdb_path = f"{dtstr}_{np.random.randint(10000)}.pdb"
71
- atom_array, nan_mask, chain_map = save_atomarray_to_pdb(atom_array, pdb_path)
72
-
73
70
  hbplus_exe = os.environ.get("HBPLUS_PATH")
74
71
 
75
72
  if hbplus_exe is None or hbplus_exe == "":
@@ -78,49 +75,57 @@ def calculate_hbonds(
78
75
  "Please set it to the path of the hbplus executable in order to calculate hydrogen bonds."
79
76
  )
80
77
 
81
- subprocess.call(
82
- [
83
- hbplus_exe,
84
- "-h",
85
- str(cutoff_HA_dist),
86
- "-d",
87
- str(cutoff_DA_distance),
88
- pdb_path,
89
- pdb_path,
90
- ],
91
- stdout=subprocess.DEVNULL,
92
- stderr=subprocess.DEVNULL,
93
- )
94
-
95
- HB = open(pdb_path.replace("pdb", "hb2"), "r").readlines()
96
- hbonds = []
97
- for i in range(8, len(HB)):
98
- d_chain = HB[i][0]
99
- d_resi = str(int(HB[i][1:5].strip()))
100
- d_resn = HB[i][6:9].strip()
101
- d_ins = HB[i][5].replace("-", " ")
102
- d_atom = HB[i][9:13].strip()
103
- a_chain = HB[i][14]
104
- a_resi = str(int(HB[i][15:19].strip()))
105
- a_ins = HB[i][19].replace("-", " ")
106
- a_resn = HB[i][20:23].strip()
107
- a_atom = HB[i][23:27].strip()
108
- dist = float(HB[i][27:32].strip())
109
-
110
- items = {
111
- "d_chain": chain_map[d_chain],
112
- "d_resi": d_resi,
113
- "d_resn": d_resn,
114
- "d_ins": d_ins,
115
- "d_atom": d_atom,
116
- "a_chain": chain_map[a_chain],
117
- "a_resi": a_resi,
118
- "a_resn": a_resn,
119
- "a_ins": a_ins,
120
- "a_atom": a_atom,
121
- "dist": dist,
122
- }
123
- hbonds.append(items)
78
+ with tempfile.TemporaryDirectory() as tmpdir:
79
+ dtstr = datetime.now().strftime("%Y%m%d%H%M%S")
80
+ pdb_filename = f"{dtstr}_{np.random.randint(10000)}.pdb"
81
+ pdb_path = os.path.join(tmpdir, pdb_filename)
82
+ atom_array, _, chain_map = save_atomarray_to_pdb(atom_array, pdb_path)
83
+
84
+ subprocess.call(
85
+ [
86
+ hbplus_exe,
87
+ "-h",
88
+ str(cutoff_HA_dist),
89
+ "-d",
90
+ str(cutoff_DA_distance),
91
+ pdb_path,
92
+ pdb_path,
93
+ ],
94
+ stdout=subprocess.DEVNULL,
95
+ stderr=subprocess.DEVNULL,
96
+ )
97
+
98
+ hb2_path = pdb_path.replace(".pdb", ".hb2")
99
+ with open(hb2_path, "r") as hb_file:
100
+ HB = hb_file.readlines()
101
+ hbonds = []
102
+ for i in range(8, len(HB)):
103
+ d_chain = HB[i][0]
104
+ d_resi = str(int(HB[i][1:5].strip()))
105
+ d_resn = HB[i][6:9].strip()
106
+ d_ins = HB[i][5].replace("-", " ")
107
+ d_atom = HB[i][9:13].strip()
108
+ a_chain = HB[i][14]
109
+ a_resi = str(int(HB[i][15:19].strip()))
110
+ a_ins = HB[i][19].replace("-", " ")
111
+ a_resn = HB[i][20:23].strip()
112
+ a_atom = HB[i][23:27].strip()
113
+ dist = float(HB[i][27:32].strip())
114
+
115
+ items = {
116
+ "d_chain": chain_map[d_chain],
117
+ "d_resi": d_resi,
118
+ "d_resn": d_resn,
119
+ "d_ins": d_ins,
120
+ "d_atom": d_atom,
121
+ "a_chain": chain_map[a_chain],
122
+ "a_resi": a_resi,
123
+ "a_resn": a_resn,
124
+ "a_ins": a_ins,
125
+ "a_atom": a_atom,
126
+ "dist": dist,
127
+ }
128
+ hbonds.append(items)
124
129
 
125
130
  donor_array = np.zeros(len(atom_array))
126
131
  acceptor_array = np.zeros(len(atom_array))
@@ -162,8 +167,6 @@ def calculate_hbonds(
162
167
  donor_array[donor_mask] = 1
163
168
  acceptor_array[acceptor_mask] = 1
164
169
 
165
- os.remove(pdb_path)
166
- os.remove(pdb_path.replace("pdb", "hb2"))
167
170
  atom_array.set_annotation("active_donor", donor_array)
168
171
  atom_array.set_annotation("active_acceptor", acceptor_array)
169
172
 
@@ -60,22 +60,13 @@ class AddSymmetryFeats(Transform):
60
60
  )
61
61
  TIDs = torch.from_numpy(atom_array.get_annotation("sym_transform_id"))
62
62
 
63
- # Get unique transforms by TID (more robust than unique_consecutive on each array)
64
- unique_TIDs, inverse_indices = torch.unique(TIDs, return_inverse=True)
65
-
66
- # Get the first occurrence of each unique TID
67
- first_occurrence = torch.zeros(len(unique_TIDs), dtype=torch.long)
68
- for i in range(len(TIDs)):
69
- tid_idx = inverse_indices[i]
70
- if first_occurrence[tid_idx] == 0 or i < first_occurrence[tid_idx]:
71
- first_occurrence[tid_idx] = i
72
-
73
- # Extract Ori, X, Y for each unique transform
74
- Oris = Oris[first_occurrence]
75
- Xs = Xs[first_occurrence]
76
- Ys = Ys[first_occurrence]
77
- TIDs = unique_TIDs
78
-
63
+ Oris = torch.unique_consecutive(Oris, dim=0)
64
+ Xs = torch.unique_consecutive(Xs, dim=0)
65
+ Ys = torch.unique_consecutive(Ys, dim=0)
66
+ TIDs = torch.unique_consecutive(TIDs, dim=0)
67
+ # the case in which there is only rotation (no translation), Ori = [0,0,0]
68
+ if len(Oris) == 1 and (Oris == 0).all():
69
+ Oris = Oris.repeat(len(Xs), 1)
79
70
  Rs, Ts = framecoords_to_RTs(Oris, Xs, Ys)
80
71
 
81
72
  for R, T, transform_id in zip(Rs, Ts, TIDs):