rc-foundry 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,6 +24,12 @@ def get_symmetry_frames_from_symmetry_id(symmetry_id):
24
24
  elif symmetry_id.lower().startswith("d"):
25
25
  order = int(symmetry_id[1:])
26
26
  frames = get_dihedral_frames(order)
27
+ elif symmetry_id.lower() == "t":
28
+ frames = get_tetrahedral_frames()
29
+ elif symmetry_id.lower() == "o":
30
+ frames = get_octahedral_frames()
31
+ elif symmetry_id.lower() == "i":
32
+ frames = get_icosahedral_frames()
27
33
  elif symmetry_id.lower() == "input_defined":
28
34
  assert (
29
35
  sym_conf.symmetry_file is not None
@@ -280,6 +286,248 @@ def get_dihedral_frames(order):
280
286
  return frames
281
287
 
282
288
 
289
def get_tetrahedral_frames():
    """
    Get tetrahedral frames (T symmetry group, 12 elements).

    Every rotation of a tetrahedron inscribed in the cube (±1, ±1, ±1) is a
    cyclic permutation matrix combined with a sign flip of an even number of
    axes, so the group is built here in closed form with entries that are
    exactly 0 or ±1 (no floating-point noise from trigonometric evaluation).

    Frames are emitted in the order: identity; for each body diagonal
    (1, 1, 1), (1, -1, -1), (-1, 1, -1), (-1, -1, 1) the right-handed rotations
    by 120° then 240°; then the 180° rotations about the x, y and z axes.

    Returns:
        frames: list of 12 ``(R, t)`` tuples, where ``R`` is a 3x3 float
            rotation matrix and ``t`` is the zero translation vector.
    """

    # Right-handed 120° rotation about (1, 1, 1): it maps x -> y -> z -> x.
    cyclic_120 = np.array(
        [
            [0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
        ]
    )
    cyclic_240 = cyclic_120 @ cyclic_120

    # Identity
    frames = [(np.eye(3), np.array([0, 0, 0]))]

    # 8 rotations by 120° / 240° about the body diagonals. Each diagonal d has
    # an even number of negative signs, so S = diag(d) is itself a proper
    # rotation (det +1) and S @ R @ S rotates about S @ (1, 1, 1) = d by the
    # same angle and in the same sense as R rotates about (1, 1, 1).
    for diagonal in ([1, 1, 1], [1, -1, -1], [-1, 1, -1], [-1, -1, 1]):
        sign = np.diag(np.array(diagonal, dtype=float))
        for base in (cyclic_120, cyclic_240):
            frames.append((sign @ base @ sign, np.array([0, 0, 0])))

    # 3 rotations by 180° about the coordinate axes: keep one axis, flip the
    # other two.
    for kept_axis in range(3):
        flips = -np.ones(3)
        flips[kept_axis] = 1.0
        frames.append((np.diag(flips), np.array([0, 0, 0])))

    return frames
320
+
321
+
322
def get_octahedral_frames():
    """
    Get octahedral frames (O symmetry group, 24 elements).

    The axes are computed from the geometry of a cube with vertices at
    (±1, ±1, ±1):

    - 3 four-fold axes through opposite face centres, rotations of 90°, 180°
      and 270° each (9 frames);
    - 4 three-fold axes through opposite vertices, rotations of 120° and 240°
      each (8 frames);
    - 6 two-fold axes through opposite edge midpoints, one 180° rotation each
      (6 frames);

    plus the identity, giving 24 frames in total.

    Returns:
        frames: list of 24 ``(R, t)`` tuples, where ``R`` is a 3x3 rotation
            matrix and ``t`` is the zero translation vector.
    Raises:
        RuntimeError: if de-duplicating the axes does not produce exactly 24
            frames (guards the rounding-based axis de-duplication below).
    """

    frames = []

    # 8 vertices of the cube
    vertices = np.array(
        [[s1, s2, s3] for s1 in (1, -1) for s2 in (1, -1) for s3 in (1, -1)]
    )

    # 6 face centers of the cube (4-fold axes pass through these)
    face_centers = [
        np.array([1, 0, 0]),
        np.array([-1, 0, 0]),
        np.array([0, 1, 0]),
        np.array([0, -1, 0]),
        np.array([0, 0, 1]),
        np.array([0, 0, -1]),
    ]

    # Cube edges: pairs of vertices that differ in exactly one coordinate
    n_vertices = len(vertices)
    edges = [
        (i, j)
        for i in range(n_vertices)
        for j in range(i + 1, n_vertices)
        if np.sum(np.abs(vertices[i] - vertices[j]) > 0) == 1
    ]

    def normalize_axis(v):
        # Canonical hashable key for the line through the origin along v:
        # unit length, first non-negligible component made positive so that
        # v and -v map to the same key, rounded to 10 decimals so float noise
        # does not create spurious duplicates.
        axis = v / np.linalg.norm(v)
        for c in axis:
            if abs(c) > 1e-10:
                if c < 0:
                    axis = -axis
                break
        return tuple(np.round(axis, 10))

    def add_rotations(axes, angles):
        # Append one frame per (axis, angle) pair, in set-iteration order.
        for axis_tuple in axes:
            axis = np.array(axis_tuple)
            for angle in angles:
                R = _rotation_matrix_from_axis_angle(axis, angle)
                frames.append((R, np.array([0, 0, 0])))

    # Identity
    frames.append((np.eye(3), np.array([0, 0, 0])))

    # 4-fold axes (through opposite face centers): 90°, 180°, 270°
    fourfold_axes = {normalize_axis(fc) for fc in face_centers}
    add_rotations(fourfold_axes, [np.pi * k / 2 for k in (1, 2, 3)])

    # 3-fold axes (through opposite vertices): 120°, 240°
    threefold_axes = {normalize_axis(v) for v in vertices}
    add_rotations(threefold_axes, [2 * np.pi / 3, 4 * np.pi / 3])

    # 2-fold axes (through opposite edge midpoints): 180°
    twofold_axes = {
        normalize_axis((vertices[i] + vertices[j]) / 2) for i, j in edges
    }
    add_rotations(twofold_axes, [np.pi])

    if len(frames) != 24:
        raise RuntimeError(
            f"Expected 24 octahedral frames, got {len(frames)}; "
            "axis de-duplication produced the wrong number of axes."
        )
    return frames
412
+
413
+
414
def get_icosahedral_frames():
    """
    Get icosahedral frames (I symmetry group, 60 elements).

    The axes are computed from the geometry of a regular icosahedron with
    vertices at (0, ±1, ±φ), (±1, ±φ, 0), (±φ, 0, ±1), where φ is the golden
    ratio (these coordinates give edge length 2):

    - 6 five-fold axes through opposite vertices, rotations of 72°, 144°,
      216° and 288° each (24 frames);
    - 10 three-fold axes through opposite face centres, rotations of 120° and
      240° each (20 frames);
    - 15 two-fold axes through opposite edge midpoints, one 180° rotation
      each (15 frames);

    plus the identity, giving 60 frames in total.

    Returns:
        frames: list of 60 ``(R, t)`` tuples, where ``R`` is a 3x3 rotation
            matrix and ``t`` is the zero translation vector.
    Raises:
        RuntimeError: if de-duplicating the axes does not produce exactly 60
            frames (guards the rounding-based axis de-duplication below).
    """

    frames = []

    # Golden ratio
    phi = (1 + np.sqrt(5)) / 2

    # 12 vertices of the icosahedron
    vertices = []
    for s1 in (1, -1):
        for s2 in (1, -1):
            vertices.append(np.array([0, s1 * 1, s2 * phi]))
            vertices.append(np.array([s1 * 1, s2 * phi, 0]))
            vertices.append(np.array([s2 * phi, 0, s1 * 1]))
    vertices = np.array(vertices)
    n_vertices = len(vertices)

    # Edges: vertex pairs at distance 2 (the edge length for these coordinates)
    edges = [
        (i, j)
        for i in range(n_vertices)
        for j in range(i + 1, n_vertices)
        if np.isclose(np.sum((vertices[i] - vertices[j]) ** 2), 4.0)
    ]

    # Faces: triangles of mutually adjacent vertices
    edge_set = set(edges)
    faces = [
        (i, j, k)
        for i in range(n_vertices)
        for j in range(i + 1, n_vertices)
        for k in range(j + 1, n_vertices)
        if (i, j) in edge_set and (j, k) in edge_set and (i, k) in edge_set
    ]

    def normalize_axis(v):
        # Canonical hashable key for the line through the origin along v:
        # unit length, first non-negligible component made positive so that
        # v and -v map to the same key, rounded to 10 decimals so float noise
        # does not create spurious duplicates.
        axis = v / np.linalg.norm(v)
        for c in axis:
            if abs(c) > 1e-10:
                if c < 0:
                    axis = -axis
                break
        return tuple(np.round(axis, 10))

    def add_rotations(axes, angles):
        # Append one frame per (axis, angle) pair, in set-iteration order.
        for axis_tuple in axes:
            axis = np.array(axis_tuple)
            for angle in angles:
                R = _rotation_matrix_from_axis_angle(axis, angle)
                frames.append((R, np.array([0, 0, 0])))

    # Identity
    frames.append((np.eye(3), np.array([0, 0, 0])))

    # 5-fold axes (through opposite vertices): 72°, 144°, 216°, 288°
    fivefold_axes = {normalize_axis(v) for v in vertices}
    add_rotations(fivefold_axes, [2 * np.pi * k / 5 for k in (1, 2, 3, 4)])

    # 3-fold axes (through opposite face centers): 120°, 240°
    threefold_axes = {
        normalize_axis((vertices[i] + vertices[j] + vertices[k]) / 3)
        for i, j, k in faces
    }
    add_rotations(threefold_axes, [2 * np.pi / 3, 4 * np.pi / 3])

    # 2-fold axes (through opposite edge midpoints): 180°
    twofold_axes = {
        normalize_axis((vertices[i] + vertices[j]) / 2) for i, j in edges
    }
    add_rotations(twofold_axes, [np.pi])

    if len(frames) != 60:
        raise RuntimeError(
            f"Expected 60 icosahedral frames, got {len(frames)}; "
            "axis de-duplication produced the wrong number of axes."
        )
    return frames
507
+
508
+
509
def _rotation_matrix_from_axis_angle(axis, angle):
    """
    Compute a rotation matrix from an axis and angle using Rodrigues' formula.

    R = I + sin(angle) * K + (1 - cos(angle)) * K @ K, where K is the
    skew-symmetric cross-product matrix of the normalised axis.

    Arguments:
        axis: length-3 rotation axis (array-like). It need not be unit
            length; it is normalised here.
        angle: rotation angle in radians (right-handed about ``axis``).
    Returns:
        R: 3x3 rotation matrix (float64).
    Raises:
        ValueError: if ``axis`` is not a 3-vector, or has a zero or
            non-finite norm (which would otherwise silently yield a NaN
            matrix).
    """

    axis = np.asarray(axis, dtype=float)
    if axis.shape != (3,):
        raise ValueError(f"axis must be a 3-vector, got shape {axis.shape}")
    norm = np.linalg.norm(axis)
    if not np.isfinite(norm) or norm == 0.0:
        raise ValueError(f"axis must have a finite, non-zero norm, got {axis}")

    x, y, z = axis / norm
    K = np.array(
        [
            [0.0, -z, y],
            [z, 0.0, -x],
            [-y, x, 0.0],
        ]
    )
    return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K)
529
+
530
+
283
531
  def get_frames_from_file(file_path):
284
532
  raise NotImplementedError("Input defined symmetry not implemented")
285
533
 
@@ -45,7 +45,7 @@ class SymmetryConfig(BaseModel):
45
45
  )
46
46
  id: Optional[str] = Field(
47
47
  None,
48
- description="Symmetry group ID. e.g. 'C3', 'D2'. Only C and D symmetry types are supported currently.",
48
+ description="Symmetry group ID. Supported types: Cyclic (C), Dihedral (D), Tetrahedral (T), Octahedral (O), Icosahedral (I). e.g. 'C3', 'D2', 'T', 'O', 'I'.",
49
49
  )
50
50
  is_unsym_motif: Optional[str] = Field(
51
51
  None,
@@ -101,10 +101,10 @@ def make_symmetric_atom_array(
101
101
  ), "Source atom array must be provided for symmetric motifs"
102
102
  # if symmetric motif is provided, get the frames from the src atom array.
103
103
  frames = get_symmetry_frames_from_atom_array(src_atom_array, frames)
104
- elif (asu_atom_array._is_motif[~asu_atom_array._is_unsym_motif]).any():
105
- # if the motifs that's not unsym motifs are present.
106
- raise NotImplementedError(
107
- "Asymmetric motif inputs are not implemented yet. please symmetrize the motif."
104
+ else:
105
+ # At this point, asym case would have been caught by the check_symmetry_config function.
106
+ ranked_logger.info(
107
+ "No motifs found in atom array. Generating unconditional symmetric proteins."
108
108
  )
109
109
 
110
110
  # Add symmetry annotations to the asu atom array
@@ -1,11 +1,15 @@
1
1
  import inspect
2
+ import time
2
3
  from dataclasses import dataclass
3
4
  from typing import Any, Literal
4
5
 
5
6
  import torch
6
7
  from jaxtyping import Float
8
+ from rfd3.inference.symmetry.symmetry_utils import apply_symmetry_to_xyz_atomwise
9
+ from rfd3.model.cfg_utils import strip_X
7
10
 
8
11
  from foundry.common import exists
12
+ from foundry.utils.alignment import weighted_rigid_align
9
13
  from foundry.utils.ddp import RankedLogger
10
14
  from foundry.utils.rotation_augmentation import (
11
15
  rot_vec_mul,
@@ -110,7 +114,7 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
110
114
  )
111
115
  # Fallback to smallest available step
112
116
  noise_schedule_original = self._construct_inference_noise_schedule(
113
- device=coord_atom_lvl_to_be_noised.device
117
+ device=device
114
118
  )
115
119
  noise_schedule = noise_schedule_original[-1:] # Just use the final step
116
120
  ranked_logger.info(
@@ -221,6 +225,7 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
221
225
  # Handle chunked mode vs standard mode
222
226
  if "chunked_pairwise_embedder" in initializer_outputs:
223
227
  # Chunked mode: explicitly provide P_LL=None
228
+ tic = time.time()
224
229
  chunked_embedder = initializer_outputs[
225
230
  "chunked_pairwise_embedder"
226
231
  ] # Don't pop, just get
@@ -238,6 +243,8 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
238
243
  initializer_outputs=other_outputs,
239
244
  **other_outputs,
240
245
  )
246
+ toc = time.time()
247
+ ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
241
248
  else:
242
249
  # Standard mode: P_LL is included in initializer_outputs
243
250
  outs = diffusion_module(
@@ -445,6 +452,7 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
445
452
  # Handle chunked mode vs standard mode (same as default sampler)
446
453
  if "chunked_pairwise_embedder" in initializer_outputs:
447
454
  # Chunked mode: explicitly provide P_LL=None
455
+ tic = time.time()
448
456
  chunked_embedder = initializer_outputs[
449
457
  "chunked_pairwise_embedder"
450
458
  ] # Don't pop, just get
@@ -462,6 +470,8 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
462
470
  initializer_outputs=other_outputs,
463
471
  **other_outputs,
464
472
  )
473
+ toc = time.time()
474
+ ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
465
475
  else:
466
476
  # Standard mode: P_LL is included in initializer_outputs
467
477
  outs = diffusion_module(
@@ -118,14 +118,14 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
118
118
 
119
119
  Parameters
120
120
  ----------
121
- P_LK_indices : (D, L, k) LongTensor
121
+ P_LK_indices : (B, L, k) LongTensor
122
122
  Key indices | P_LK_indices[d, i, k] = global atom index for which atom i attends to.
123
- P_LK : (D, L, k, c) FloatTensor
123
+ P_LK : (B, L, k, c) FloatTensor
124
124
  Key features to scatter add into
125
125
 
126
- P_LA_indices : (D, L, a) LongTensor
126
+ P_LA_indices : (B, L, a) LongTensor
127
127
  Additional feature indices to scatter into P_LK.
128
- P_LA : (D, L, a, c) FloatTensor
128
+ P_LA : (B, L, a, c) FloatTensor
129
129
  Features corresponding to P_LA.
130
130
 
131
131
  Both index tensors contain indices representing D batch dim,
@@ -135,42 +135,42 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
135
135
 
136
136
  """
137
137
  # Handle case when indices and P_LA don't have batch dimensions
138
- D, L, k = P_LK_indices.shape
138
+ B, L, k = P_LK_indices.shape
139
139
  if P_LA_indices.ndim == 2:
140
- P_LA_indices = P_LA_indices.unsqueeze(0).expand(D, -1, -1)
140
+ P_LA_indices = P_LA_indices.unsqueeze(0).expand(B, -1, -1)
141
141
  if P_LA_src.ndim == 3:
142
- P_LA_src = P_LA_src.unsqueeze(0).expand(D, -1, -1)
142
+ P_LA_src = P_LA_src.unsqueeze(0).expand(B, -1, -1)
143
143
  assert (
144
144
  P_LA_src.shape[-1] == P_LK_tgt.shape[-1]
145
145
  ), "Channel dims do not match, got: {} vs {}".format(
146
146
  P_LA_src.shape[-1], P_LK_tgt.shape[-1]
147
147
  )
148
148
 
149
- matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2) # (D, L, a, k)
149
+ matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2) # (B, L, a, k)
150
150
  if not torch.all(matches.sum(dim=(-1, -2)) >= 1):
151
151
  raise ValueError("Found multiple scatter indices for some atoms")
152
152
  elif not torch.all(matches.sum(dim=-1) <= 1):
153
153
  raise ValueError("Did not find a scatter index for every atom")
154
- k_indices = matches.long().argmax(dim=-1) # (D, L, a)
154
+ k_indices = matches.long().argmax(dim=-1) # (B, L, a)
155
155
  scatter_indices = k_indices.unsqueeze(-1).expand(
156
156
  -1, -1, -1, P_LK_tgt.shape[-1]
157
- ) # (D, L, a, c)
157
+ ) # (B, L, a, c)
158
158
  P_LK_tgt = P_LK_tgt.scatter_add(dim=2, index=scatter_indices, src=P_LA_src)
159
159
  return P_LK_tgt
160
160
 
161
161
 
162
162
  def _batched_gather(values: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
163
163
  """
164
- values : (D, L, C)
165
- idx : (D, L, k)
166
- returns: (D, L, k, C)
164
+ values : (B, L, C)
165
+ idx : (B, L, k)
166
+ returns: (B, L, k, C)
167
167
  """
168
- D, L, C = values.shape
168
+ B, L, C = values.shape
169
169
  k = idx.shape[-1]
170
170
 
171
- # (D, L, 1, C) → stride-0 along k → (D, L, k, C)
171
+ # (B, L, 1, C) → stride-0 along k → (B, L, k, C)
172
172
  src = values.unsqueeze(2).expand(-1, -1, k, -1)
173
- idx = idx.unsqueeze(-1).expand(-1, -1, -1, C) # (D, L, k, C)
173
+ idx = idx.unsqueeze(-1).expand(-1, -1, -1, C) # (B, L, k, C)
174
174
 
175
175
  return torch.gather(src, 1, idx) # dim=1 is the L-axis
176
176
 
@@ -196,7 +196,7 @@ def create_attention_indices(
196
196
  X_L = torch.randn(
197
197
  (1, L, 3), device=device, dtype=torch.float
198
198
  ) # [L, 3] - random
199
- D_LL = torch.cdist(X_L, X_L, p=2) # [D, L, L] - pairwise atom distances
199
+ D_LL = torch.cdist(X_L, X_L, p=2) # [B, L, L] - pairwise atom distances
200
200
 
201
201
  # Create attention indices using neighbour distances
202
202
  base_mask = ~f["unindexing_pair_mask"][
@@ -231,7 +231,7 @@ def create_attention_indices(
231
231
  k_max=k_actual,
232
232
  chain_id=chain_ids,
233
233
  base_mask=base_mask,
234
- ) # [D, L, k] | indices[b, i, j] = atom index for atom i to j-th attn query
234
+ ) # [B, L, k] | indices[b, i, j] = atom index for atom i to j-th attn query
235
235
 
236
236
  return attn_indices
237
237
 
@@ -245,7 +245,7 @@ def get_sparse_attention_indices_with_inter_chain(
245
245
 
246
246
  Args:
247
247
  tok_idx: atom to token mapping
248
- D_LL: pairwise distances [D, L, L]
248
+ D_LL: pairwise distances [B, L, L]
249
249
  n_seq_neighbours: number of sequence neighbors
250
250
  k_intra: number of intra-chain attention keys
251
251
  k_inter: number of inter-chain attention keys
@@ -253,29 +253,29 @@ def get_sparse_attention_indices_with_inter_chain(
253
253
  base_mask: base mask for valid pairs
254
254
 
255
255
  Returns:
256
- attn_indices: [D, L, k_total] where k_total = k_intra + k_inter
256
+ attn_indices: [B, L, k_total] where k_total = k_intra + k_inter
257
257
  """
258
- D, L, _ = D_LL.shape
258
+ B, L, _ = D_LL.shape
259
259
 
260
260
  # Get regular intra-chain indices (limited to k_intra)
261
261
  intra_indices = get_sparse_attention_indices(
262
262
  tok_idx, D_LL, n_seq_neighbours, k_intra, chain_id, base_mask
263
- ) # [D, L, k_intra]
263
+ ) # [B, L, k_intra]
264
264
 
265
265
  # Get inter-chain indices for clash avoidance
266
- inter_indices = torch.zeros(D, L, k_inter, dtype=torch.long, device=D_LL.device)
267
-
268
- for d in range(D):
269
- for l in range(L):
270
- query_chain = chain_id[l]
266
+ inter_indices = torch.zeros(B, L, k_inter, dtype=torch.long, device=D_LL.device)
267
+ unique_chains = torch.unique(chain_id)
268
+ for b in range(B):
269
+ for c in unique_chains:
270
+ query_chain = chain_id[c]
271
271
 
272
272
  # Find atoms from different chains
273
- other_chain_mask = (chain_id != query_chain) & base_mask[l, :]
273
+ other_chain_mask = (chain_id != query_chain) & base_mask[c, :]
274
274
  other_chain_atoms = torch.where(other_chain_mask)[0]
275
275
 
276
276
  if len(other_chain_atoms) > 0:
277
277
  # Get distances to other chains
278
- distances_to_other = D_LL[d, l, other_chain_atoms]
278
+ distances_to_other = D_LL[b, c, other_chain_atoms]
279
279
 
280
280
  # Select k_inter closest atoms from other chains
281
281
  n_select = min(k_inter, len(other_chain_atoms))
@@ -283,23 +283,23 @@ def get_sparse_attention_indices_with_inter_chain(
283
283
  selected_atoms = other_chain_atoms[closest_idx]
284
284
 
285
285
  # Fill inter-chain indices
286
- inter_indices[d, l, :n_select] = selected_atoms
286
+ inter_indices[b, c, :n_select] = selected_atoms
287
287
  # Pad with random atoms if needed
288
288
  if n_select < k_inter:
289
289
  padding = torch.randint(
290
290
  0, L, (k_inter - n_select,), device=D_LL.device
291
291
  )
292
- inter_indices[d, l, n_select:] = padding
292
+ inter_indices[b, c, n_select:] = padding
293
293
  else:
294
294
  # No other chains found, fill with random indices
295
- inter_indices[d, l, :] = torch.randint(
295
+ inter_indices[b, c, :] = torch.randint(
296
296
  0, L, (k_inter,), device=D_LL.device
297
297
  )
298
298
 
299
299
  # Combine intra and inter chain indices
300
300
  combined_indices = torch.cat(
301
301
  [intra_indices, inter_indices], dim=-1
302
- ) # [D, L, k_total]
302
+ ) # [B, L, k_total]
303
303
 
304
304
  return combined_indices
305
305