rc-foundry 0.1.6__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,12 +10,13 @@ def get_symmetry_frames_from_symmetry_id(symmetry_id):
10
10
  Returns:
11
11
  frames: list of rotation matrices
12
12
  """
13
+ from rfd3.inference.symmetry.symmetry_utils import SymmetryConfig
13
14
 
14
15
  # Get frames from symmetry id
15
16
  sym_conf = {}
16
- if isinstance(symmetry_id, dict):
17
+ if isinstance(symmetry_id, SymmetryConfig):
17
18
  sym_conf = symmetry_id
18
- symmetry_id = symmetry_id.get("id")
19
+ symmetry_id = symmetry_id.id
19
20
 
20
21
  if symmetry_id.lower().startswith("c"):
21
22
  order = int(symmetry_id[1:])
@@ -23,11 +24,17 @@ def get_symmetry_frames_from_symmetry_id(symmetry_id):
23
24
  elif symmetry_id.lower().startswith("d"):
24
25
  order = int(symmetry_id[1:])
25
26
  frames = get_dihedral_frames(order)
27
+ elif symmetry_id.lower() == "t":
28
+ frames = get_tetrahedral_frames()
29
+ elif symmetry_id.lower() == "o":
30
+ frames = get_octahedral_frames()
31
+ elif symmetry_id.lower() == "i":
32
+ frames = get_icosahedral_frames()
26
33
  elif symmetry_id.lower() == "input_defined":
27
34
  assert (
28
- "symmetry_file" in sym_conf
35
+ sym_conf.symmetry_file is not None
29
36
  ), "symmetry_file is required for input_defined symmetry"
30
- frames = get_frames_from_file(sym_conf.get("symmetry_file"))
37
+ frames = get_frames_from_file(sym_conf.symmetry_file)
31
38
  else:
32
39
  raise ValueError(f"Symmetry id {symmetry_id} not supported")
33
40
 
@@ -120,7 +127,9 @@ def get_symmetry_frames_from_atom_array(src_atom_array, input_frames):
120
127
  computed_frames = [(R, np.array([0, 0, 0])) for R in Rs]
121
128
 
122
129
  # check that the computed frames match the input frames
123
- check_input_frames_match_symmetry_frames(computed_frames, input_frames)
130
+ check_input_frames_match_symmetry_frames(
131
+ computed_frames, input_frames, nids_by_entity
132
+ )
124
133
 
125
134
  return computed_frames
126
135
 
@@ -277,6 +286,248 @@ def get_dihedral_frames(order):
277
286
  return frames
278
287
 
279
288
 
289
+ def get_tetrahedral_frames():
290
+ """
291
+ Get tetrahedral frames (T symmetry group, 12 elements).
292
+ Returns:
293
+ frames: list of rotation matrices
294
+ """
295
+
296
+ frames = []
297
+
298
+ # Identity
299
+ frames.append((np.eye(3), np.array([0, 0, 0])))
300
+
301
+ # 8 rotations by ±120° around body diagonals (±1, ±1, ±1)
302
+ diagonals = [
303
+ np.array([1, 1, 1]),
304
+ np.array([1, -1, -1]),
305
+ np.array([-1, 1, -1]),
306
+ np.array([-1, -1, 1]),
307
+ ]
308
+ for d in diagonals:
309
+ axis = d / np.linalg.norm(d)
310
+ for angle in [2 * np.pi / 3, 4 * np.pi / 3]:
311
+ R = _rotation_matrix_from_axis_angle(axis, angle)
312
+ frames.append((R, np.array([0, 0, 0])))
313
+
314
+ # 3 rotations by 180° around coordinate axes
315
+ for axis in [np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])]:
316
+ R = _rotation_matrix_from_axis_angle(axis, np.pi)
317
+ frames.append((R, np.array([0, 0, 0])))
318
+
319
+ return frames
320
+
321
+
322
+ def get_octahedral_frames():
323
+ """
324
+ Get octahedral frames (O symmetry group, 24 elements).
325
+ The axes are computed from the geometry of a cube with vertices at (±1, ±1, ±1).
326
+ Returns:
327
+ frames: list of rotation matrices
328
+ """
329
+
330
+ frames = []
331
+
332
+ # 8 vertices of the cube
333
+ vertices = []
334
+ for s1 in [1, -1]:
335
+ for s2 in [1, -1]:
336
+ for s3 in [1, -1]:
337
+ vertices.append(np.array([s1, s2, s3]))
338
+ vertices = np.array(vertices)
339
+
340
+ # 6 face centers of the cube (4-fold axes pass through these)
341
+ face_centers = [
342
+ np.array([1, 0, 0]),
343
+ np.array([-1, 0, 0]),
344
+ np.array([0, 1, 0]),
345
+ np.array([0, -1, 0]),
346
+ np.array([0, 0, 1]),
347
+ np.array([0, 0, -1]),
348
+ ]
349
+
350
+ # Find edges (pairs of vertices differing in exactly one coordinate)
351
+ edges = []
352
+ for i in range(len(vertices)):
353
+ for j in range(i + 1, len(vertices)):
354
+ diff = np.abs(vertices[i] - vertices[j])
355
+ if np.sum(diff > 0) == 1: # Differ in exactly one coordinate
356
+ edges.append((i, j))
357
+
358
+ # Helper to get unique axis (normalize direction to avoid duplicates)
359
+ def normalize_axis(v):
360
+ axis = v / np.linalg.norm(v)
361
+ for c in axis:
362
+ if abs(c) > 1e-10:
363
+ if c < 0:
364
+ axis = -axis
365
+ break
366
+ return tuple(np.round(axis, 10))
367
+
368
+ # Identity
369
+ frames.append((np.eye(3), np.array([0, 0, 0])))
370
+
371
+ # 4-fold axes (through opposite face centers) - 3 axes
372
+ # Each gives rotations at 90°, 180°, 270° (we skip 0° = identity)
373
+ fourfold_axes_set = set()
374
+ for fc in face_centers:
375
+ axis_tuple = normalize_axis(fc)
376
+ fourfold_axes_set.add(axis_tuple)
377
+
378
+ for axis_tuple in fourfold_axes_set:
379
+ axis = np.array(axis_tuple)
380
+ for k in [1, 2, 3]: # 90°, 180°, 270°
381
+ angle = np.pi * k / 2
382
+ R = _rotation_matrix_from_axis_angle(axis, angle)
383
+ frames.append((R, np.array([0, 0, 0])))
384
+
385
+ # 3-fold axes (through opposite vertices) - 4 axes
386
+ # Each gives rotations at 120°, 240°
387
+ threefold_axes_set = set()
388
+ for v in vertices:
389
+ axis_tuple = normalize_axis(v)
390
+ threefold_axes_set.add(axis_tuple)
391
+
392
+ for axis_tuple in threefold_axes_set:
393
+ axis = np.array(axis_tuple)
394
+ for angle in [2 * np.pi / 3, 4 * np.pi / 3]:
395
+ R = _rotation_matrix_from_axis_angle(axis, angle)
396
+ frames.append((R, np.array([0, 0, 0])))
397
+
398
+ # 2-fold axes (through opposite edge midpoints) - 6 axes
399
+ # Each gives 1 rotation at 180°
400
+ twofold_axes_set = set()
401
+ for i, j in edges:
402
+ midpoint = (vertices[i] + vertices[j]) / 2
403
+ axis_tuple = normalize_axis(midpoint)
404
+ twofold_axes_set.add(axis_tuple)
405
+
406
+ for axis_tuple in twofold_axes_set:
407
+ axis = np.array(axis_tuple)
408
+ R = _rotation_matrix_from_axis_angle(axis, np.pi)
409
+ frames.append((R, np.array([0, 0, 0])))
410
+
411
+ return frames
412
+
413
+
414
+ def get_icosahedral_frames():
415
+ """
416
+ Get icosahedral frames (I symmetry group, 60 elements).
417
+ The axes are computed from the geometry of a regular icosahedron with
418
+ vertices at (0, ±1, ±φ), (±1, ±φ, 0), (±φ, 0, ±1) where φ is the golden ratio.
419
+ Returns:
420
+ frames: list of rotation matrices
421
+ """
422
+
423
+ frames = []
424
+
425
+ # Golden ratio
426
+ phi = (1 + np.sqrt(5)) / 2
427
+
428
+ # 12 vertices of the icosahedron
429
+ vertices = []
430
+ for s1 in [1, -1]:
431
+ for s2 in [1, -1]:
432
+ vertices.append(np.array([0, s1 * 1, s2 * phi]))
433
+ vertices.append(np.array([s1 * 1, s2 * phi, 0]))
434
+ vertices.append(np.array([s2 * phi, 0, s1 * 1]))
435
+ vertices = np.array(vertices)
436
+
437
+ # Find edges (pairs of vertices at distance 2)
438
+ edges = []
439
+ for i in range(len(vertices)):
440
+ for j in range(i + 1, len(vertices)):
441
+ dist_sq = np.sum((vertices[i] - vertices[j]) ** 2)
442
+ if np.isclose(dist_sq, 4.0):
443
+ edges.append((i, j))
444
+
445
+ # Find faces (triangles of mutually adjacent vertices)
446
+ edge_set = set(edges)
447
+ faces = []
448
+ for i in range(len(vertices)):
449
+ for j in range(i + 1, len(vertices)):
450
+ for k in range(j + 1, len(vertices)):
451
+ if (i, j) in edge_set and (j, k) in edge_set and (i, k) in edge_set:
452
+ faces.append((i, j, k))
453
+
454
+ # Helper to get unique axis (normalize direction to avoid duplicates)
455
+ def normalize_axis(v):
456
+ axis = v / np.linalg.norm(v)
457
+ # Make first significant component positive to avoid duplicate opposite axes
458
+ for c in axis:
459
+ if abs(c) > 1e-10:
460
+ if c < 0:
461
+ axis = -axis
462
+ break
463
+ return tuple(np.round(axis, 10))
464
+
465
+ # Identity
466
+ frames.append((np.eye(3), np.array([0, 0, 0])))
467
+
468
+ # 5-fold axes (through opposite vertices) - 6 axes, 4 rotations each = 24
469
+ fivefold_axes_set = set()
470
+ for v in vertices:
471
+ axis_tuple = normalize_axis(v)
472
+ fivefold_axes_set.add(axis_tuple)
473
+
474
+ for axis_tuple in fivefold_axes_set:
475
+ axis = np.array(axis_tuple)
476
+ for k in [1, 2, 3, 4]:
477
+ angle = 2 * np.pi * k / 5
478
+ R = _rotation_matrix_from_axis_angle(axis, angle)
479
+ frames.append((R, np.array([0, 0, 0])))
480
+
481
+ # 3-fold axes (through opposite face centers) - 10 axes, 2 rotations each = 20
482
+ threefold_axes_set = set()
483
+ for i, j, k in faces:
484
+ center = (vertices[i] + vertices[j] + vertices[k]) / 3
485
+ axis_tuple = normalize_axis(center)
486
+ threefold_axes_set.add(axis_tuple)
487
+
488
+ for axis_tuple in threefold_axes_set:
489
+ axis = np.array(axis_tuple)
490
+ for angle in [2 * np.pi / 3, 4 * np.pi / 3]:
491
+ R = _rotation_matrix_from_axis_angle(axis, angle)
492
+ frames.append((R, np.array([0, 0, 0])))
493
+
494
+ # 2-fold axes (through opposite edge midpoints) - 15 axes, 1 rotation each = 15
495
+ twofold_axes_set = set()
496
+ for i, j in edges:
497
+ midpoint = (vertices[i] + vertices[j]) / 2
498
+ axis_tuple = normalize_axis(midpoint)
499
+ twofold_axes_set.add(axis_tuple)
500
+
501
+ for axis_tuple in twofold_axes_set:
502
+ axis = np.array(axis_tuple)
503
+ R = _rotation_matrix_from_axis_angle(axis, np.pi)
504
+ frames.append((R, np.array([0, 0, 0])))
505
+
506
+ return frames
507
+
508
+
509
+ def _rotation_matrix_from_axis_angle(axis, angle):
510
+ """
511
+ Compute a rotation matrix from an axis and angle using Rodrigues' formula.
512
+ Arguments:
513
+ axis: unit vector of the rotation axis
514
+ angle: rotation angle in radians
515
+ Returns:
516
+ R: 3x3 rotation matrix
517
+ """
518
+
519
+ axis = axis / np.linalg.norm(axis)
520
+ K = np.array(
521
+ [
522
+ [0, -axis[2], axis[1]],
523
+ [axis[2], 0, -axis[0]],
524
+ [-axis[1], axis[0], 0],
525
+ ]
526
+ )
527
+ R = np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K)
528
+ return R
529
+
530
+
280
531
  def get_frames_from_file(file_path):
281
532
  raise NotImplementedError("Input defined symmetry not implemented")
282
533
 
@@ -39,18 +39,36 @@ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
39
39
 
40
40
 
41
41
class SymmetryConfig(BaseModel):
    """
    Symmetry settings consumed by the symmetry utilities.

    Unknown keys are accepted (``extra="allow"``) and exposed as attributes.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="allow",
    )
    id: Optional[str] = Field(
        None,
        description="Symmetry group ID. Supported types: Cyclic (C), Dihedral (D), Tetrahedral (T), Octahedral (O), Icosahedral (I). e.g. 'C3', 'D2', 'T', 'O', 'I'.",
    )
    # Implicit string concatenation (instead of backslash continuations inside
    # the literal) keeps source indentation out of the description text.
    is_unsym_motif: Optional[str] = Field(
        None,
        description=(
            "Comma separated list of contig/ligand names that should not be "
            "symmetrized such as DNA strands. e.g. 'HEM' or 'Y1-11,Z16-25'"
        ),
    )
    is_symmetric_motif: bool = Field(
        True,
        description=(
            "If True, the input motifs are expected to be already symmetric and "
            "won't be symmetrized. If False, all input motifs are expected to be "
            "ASU and will be symmetrized."
        ),
    )
    # Declared explicitly so `sym_conf.symmetry_file` evaluates to None (rather
    # than raising AttributeError) when the key was not supplied.
    symmetry_file: Optional[str] = Field(
        None,
        description="Path of the file defining the symmetry frames; required when id is 'input_defined'.",
    )
60
+
61
+
62
def convery_sym_conf_to_symmetry_config(sym_conf: SymmetryConfig | dict):
    """
    Build a SymmetryConfig from a mapping of symmetry settings.

    Args:
        sym_conf: mapping of SymmetryConfig fields (e.g. ``{"id": "C3"}``).
            An existing SymmetryConfig instance is returned unchanged.

    Returns:
        SymmetryConfig built from (or identical to) ``sym_conf``.

    Raises:
        TypeError: if ``sym_conf`` is None.
    """
    # A pydantic model is not a mapping, so `**sym_conf` would fail on it.
    if isinstance(sym_conf, SymmetryConfig):
        return sym_conf
    if sym_conf is None:
        raise TypeError("sym_conf must be a mapping or SymmetryConfig, got None")
    return SymmetryConfig(**sym_conf)


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
convert_sym_conf_to_symmetry_config = convery_sym_conf_to_symmetry_config
50
64
 
51
65
 
52
66
  def make_symmetric_atom_array(
53
- asu_atom_array, sym_conf: SymmetryConfig, sm=None, has_2d=False, src_atom_array=None
67
+ asu_atom_array,
68
+ sym_conf: SymmetryConfig | dict,
69
+ sm=None,
70
+ has_dist_cond=False,
71
+ src_atom_array=None,
54
72
  ):
55
73
  """
56
74
  apply symmetry to an atom array.
@@ -58,41 +76,35 @@ def make_symmetric_atom_array(
58
76
  asu_atom_array: atom array of the asymmetric unit
59
77
  sym_conf: symmetry configuration (dict, "id" key is required)
60
78
  sm: optional small molecule names (str, comma separated)
61
- has_2d: whether to add 2d entity annotations
79
+ has_dist_cond: whether to add 2d entity annotations
62
80
  Returns:
63
81
  new_asu_atom_array: atom array with symmetry applied
64
82
  """
65
- sym_conf = (
66
- sym_conf.model_dump()
67
- ) # TODO: JB: remove this line to keep as symmetry config for cleaner syntax(?)
68
- ranked_logger.info(f"Symmetry Configs: {sym_conf}")
83
+ if not isinstance(sym_conf, SymmetryConfig):
84
+ sym_conf = convery_sym_conf_to_symmetry_config(sym_conf)
69
85
 
70
- # Making sure that the symmetry config is valid
71
86
  check_symmetry_config(
72
- asu_atom_array,
73
- sym_conf,
74
- sm,
75
- has_dist_cond=has_2d,
76
- src_atom_array=src_atom_array,
87
+ asu_atom_array, sym_conf, sm, has_dist_cond, src_atom_array=src_atom_array
77
88
  )
78
89
  # Adding utility annotations to the asu atom array
79
90
  asu_atom_array = _add_util_annotations(asu_atom_array, sym_conf, sm)
80
91
 
81
- if has_2d: # NB: this will only work for asymmetric motifs at the moment - need to add functionality for symmetric motifs
92
+ if has_dist_cond: # NB: this will only work for asymmetric motifs at the moment - need to add functionality for symmetric motifs
82
93
  asu_atom_array = add_2d_entity_annotations(asu_atom_array)
83
94
 
84
95
  frames = get_symmetry_frames_from_symmetry_id(sym_conf)
85
96
 
86
97
  # If the motif is symmetric, we get the frames instead from the source atom array.
87
- if sym_conf.get("is_symmetric_motif"):
98
+ if sym_conf.is_symmetric_motif:
88
99
  assert (
89
100
  src_atom_array is not None
90
101
  ), "Source atom array must be provided for symmetric motifs"
91
- # if symmetric motif is provided, get the frames from the src atom array
102
+ # if symmetric motif is provided, get the frames from the src atom array.
92
103
  frames = get_symmetry_frames_from_atom_array(src_atom_array, frames)
93
104
  else:
94
- raise NotImplementedError(
95
- "Asymmetric motif inputs are not implemented yet. please symmetrize the motif."
105
+ # At this point, asym case would have been caught by the check_symmetry_config function.
106
+ ranked_logger.info(
107
+ "No motifs found in atom array. Generating unconditional symmetric proteins."
96
108
  )
97
109
 
98
110
  # Add symmetry annotations to the asu atom array
@@ -101,7 +113,7 @@ def make_symmetric_atom_array(
101
113
  # Extracting all things at this moment that we will not want to symmetrize.
102
114
  # This includes: 1) unsym motifs, 2) ligands
103
115
  unsym_atom_arrays = []
104
- if sym_conf.get("is_unsym_motif"):
116
+ if sym_conf.is_unsym_motif:
105
117
  # unsym_motif_atom_array = get_unsym_motif(asu_atom_array, asu_atom_array._is_unsym_motif)
106
118
  # Now remove the unsym motifs from the asu atom array
107
119
  unsym_atom_arrays.append(asu_atom_array[asu_atom_array._is_unsym_motif])
@@ -128,7 +140,7 @@ def make_symmetric_atom_array(
128
140
  symmetrized_atom_array = struc.concatenate(symmetry_unit_list)
129
141
 
130
142
  # add 2D conditioning annotations
131
- if has_2d:
143
+ if has_dist_cond:
132
144
  symmetrized_atom_array = reannotate_2d_conditions(symmetrized_atom_array)
133
145
 
134
146
  # set all motifs to not have any symmetrization applied to them
@@ -183,7 +195,7 @@ def make_symmetric_atom_array_for_partial_diffusion(atom_array, sym_conf):
183
195
  frames = get_symmetry_frames_from_symmetry_id(sym_conf)
184
196
 
185
197
  # Add symmetry ID
186
- symmetry_ids = np.full(n, sym_conf.get("id"), dtype="U6")
198
+ symmetry_ids = np.full(n, sym_conf.id, dtype="U6")
187
199
  atom_array.set_annotation("symmetry_id", symmetry_ids)
188
200
 
189
201
  # Initialize transform annotations (use same format as original system)
@@ -244,7 +256,7 @@ def _add_util_annotations(asu_atom_array, sym_conf, sm):
244
256
  """
245
257
  n = asu_atom_array.shape[0]
246
258
  is_motif = get_motif_features(asu_atom_array)["is_motif_atom"].astype(np.bool_)
247
- is_sm = np.zeros(asu_atom_array.shape[0], dtype=bool)
259
+ is_sm = np.zeros(n, dtype=bool)
248
260
  is_asu = np.ones(n, dtype=bool)
249
261
  is_unsym_motif = np.zeros(n, dtype=bool)
250
262
 
@@ -257,8 +269,8 @@ def _add_util_annotations(asu_atom_array, sym_conf, sm):
257
269
  )
258
270
 
259
271
  # assign unsym motifs
260
- if sym_conf.get("is_unsym_motif"):
261
- unsym_motif_names = sym_conf["is_unsym_motif"].split(",")
272
+ if sym_conf.is_unsym_motif:
273
+ unsym_motif_names = sym_conf.is_unsym_motif.split(",")
262
274
  unsym_motif_names = expand_contig_unsym_motif(unsym_motif_names)
263
275
  is_unsym_motif = get_unsym_motif_mask(asu_atom_array, unsym_motif_names)
264
276
 
@@ -361,38 +373,4 @@ def apply_symmetry_to_xyz_atomwise(X_L, sym_feats, partial_diffusion=False):
361
373
  "blc,cd->bld", asu_xyz, sym_transforms[target_id][0].to(asu_xyz.dtype)
362
374
  ) + sym_transforms[target_id][1].to(asu_xyz.dtype)
363
375
 
364
- # Log inter-chain distances for debugging - use actual chain annotations
365
- if sym_X_L.shape[1] > 100: # Only for large structures
366
- # Use symmetry entity annotations to find different chains
367
- sym_entity_id = sym_feats["sym_entity_id"]
368
- unique_entities = torch.unique(sym_entity_id)
369
-
370
- if len(unique_entities) >= 2:
371
- # Get atoms from first two different entities
372
- entity_0_mask = sym_entity_id == unique_entities[0]
373
- entity_1_mask = sym_entity_id == unique_entities[1]
374
-
375
- if entity_0_mask.sum() > 0 and entity_1_mask.sum() > 0:
376
- entity_0_atoms = sym_X_L[0, entity_0_mask, :]
377
- entity_1_atoms = sym_X_L[0, entity_1_mask, :]
378
-
379
- # Sample subset to avoid memory issues
380
- entity_0_sample = entity_0_atoms[: min(50, entity_0_atoms.shape[0]), :]
381
- entity_1_sample = entity_1_atoms[: min(50, entity_1_atoms.shape[0]), :]
382
-
383
- min_distance = (
384
- torch.cdist(entity_0_sample, entity_1_sample).min().item()
385
- )
386
- ranked_logger.info(
387
- f"Min inter-chain distance after symmetry: {min_distance:.2f} Å"
388
- )
389
-
390
- # Also log the centers of each entity
391
- entity_0_center = entity_0_atoms.mean(dim=0)
392
- entity_1_center = entity_1_atoms.mean(dim=0)
393
- center_distance = torch.norm(entity_0_center - entity_1_center).item()
394
- ranked_logger.info(
395
- f"Distance between chain centers: {center_distance:.2f} Å"
396
- )
397
-
398
376
  return sym_X_L
@@ -1,11 +1,15 @@
1
1
  import inspect
2
+ import time
2
3
  from dataclasses import dataclass
3
4
  from typing import Any, Literal
4
5
 
5
6
  import torch
6
7
  from jaxtyping import Float
8
+ from rfd3.inference.symmetry.symmetry_utils import apply_symmetry_to_xyz_atomwise
9
+ from rfd3.model.cfg_utils import strip_X
7
10
 
8
11
  from foundry.common import exists
12
+ from foundry.utils.alignment import weighted_rigid_align
9
13
  from foundry.utils.ddp import RankedLogger
10
14
  from foundry.utils.rotation_augmentation import (
11
15
  rot_vec_mul,
@@ -110,7 +114,7 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
110
114
  )
111
115
  # Fallback to smallest available step
112
116
  noise_schedule_original = self._construct_inference_noise_schedule(
113
- device=coord_atom_lvl_to_be_noised.device
117
+ device=device
114
118
  )
115
119
  noise_schedule = noise_schedule_original[-1:] # Just use the final step
116
120
  ranked_logger.info(
@@ -221,6 +225,7 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
221
225
  # Handle chunked mode vs standard mode
222
226
  if "chunked_pairwise_embedder" in initializer_outputs:
223
227
  # Chunked mode: explicitly provide P_LL=None
228
+ tic = time.time()
224
229
  chunked_embedder = initializer_outputs[
225
230
  "chunked_pairwise_embedder"
226
231
  ] # Don't pop, just get
@@ -238,6 +243,8 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
238
243
  initializer_outputs=other_outputs,
239
244
  **other_outputs,
240
245
  )
246
+ toc = time.time()
247
+ ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
241
248
  else:
242
249
  # Standard mode: P_LL is included in initializer_outputs
243
250
  outs = diffusion_module(
@@ -445,6 +452,7 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
445
452
  # Handle chunked mode vs standard mode (same as default sampler)
446
453
  if "chunked_pairwise_embedder" in initializer_outputs:
447
454
  # Chunked mode: explicitly provide P_LL=None
455
+ tic = time.time()
448
456
  chunked_embedder = initializer_outputs[
449
457
  "chunked_pairwise_embedder"
450
458
  ] # Don't pop, just get
@@ -462,6 +470,8 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
462
470
  initializer_outputs=other_outputs,
463
471
  **other_outputs,
464
472
  )
473
+ toc = time.time()
474
+ ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
465
475
  else:
466
476
  # Standard mode: P_LL is included in initializer_outputs
467
477
  outs = diffusion_module(