rc-foundry 0.1.7__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. foundry/trainers/fabric.py +18 -2
  2. foundry/utils/components.py +3 -3
  3. foundry/utils/ddp.py +16 -13
  4. foundry/utils/logging.py +1 -1
  5. foundry/utils/xpu/__init__.py +27 -0
  6. foundry/utils/xpu/single_xpu_strategy.py +47 -0
  7. foundry/utils/xpu/xpu_accelerator.py +91 -0
  8. foundry/utils/xpu/xpu_precision.py +72 -0
  9. foundry/version.py +2 -2
  10. mpnn/inference_engines/mpnn.py +6 -2
  11. {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.10.dist-info}/METADATA +16 -2
  12. {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.10.dist-info}/RECORD +36 -30
  13. rf3/cli.py +13 -4
  14. rf3/configs/inference.yaml +5 -0
  15. rf3/configs/train.yaml +5 -0
  16. rf3/configs/trainer/xpu.yaml +6 -0
  17. rf3/configs/validate.yaml +5 -0
  18. rf3/inference.py +3 -1
  19. rfd3/configs/dev.yaml +1 -0
  20. rfd3/configs/inference.yaml +1 -0
  21. rfd3/configs/train.yaml +2 -1
  22. rfd3/configs/trainer/xpu.yaml +6 -0
  23. rfd3/configs/validate.yaml +1 -0
  24. rfd3/engine.py +25 -10
  25. rfd3/inference/datasets.py +1 -1
  26. rfd3/inference/input_parsing.py +35 -2
  27. rfd3/inference/symmetry/checks.py +12 -4
  28. rfd3/inference/symmetry/symmetry_utils.py +5 -6
  29. rfd3/model/inference_sampler.py +14 -2
  30. rfd3/model/layers/block_utils.py +33 -33
  31. rfd3/model/layers/chunked_pairwise.py +84 -82
  32. rfd3/transforms/hbonds_hbplus.py +52 -49
  33. rfd3/utils/inference.py +7 -30
  34. {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.10.dist-info}/WHEEL +0 -0
  35. {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.10.dist-info}/entry_points.txt +0 -0
  36. {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.10.dist-info}/licenses/LICENSE.md +0 -0
rf3/configs/validate.yaml CHANGED
@@ -2,6 +2,11 @@
 # ^ The "package" determines where the content of the config is placed in the output config
 # For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
 
+hydra:
+  searchpath:
+    - pkg://rf3.configs
+    - pkg://configs
+
 # NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
 defaults:
   - callbacks: default
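Note: `pkg://` searchpath entries resolve to importable Python packages, so `pkg://rf3.configs` lets Hydra find config groups shipped inside the installed wheel rather than requiring a source checkout. A rough illustration, assuming `rf3.configs` is importable:

    # Roughly where Hydra looks for the pkg://rf3.configs searchpath entry,
    # assuming rf3.configs is importable from the installed wheel.
    import importlib.resources
    print(importlib.resources.files("rf3.configs"))  # e.g. .../site-packages/rf3/configs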
rf3/inference.py CHANGED
@@ -16,7 +16,9 @@ rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
 load_dotenv(override=True)
 
-_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs")
+_config_path = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs"
+)
 
 
 @hydra.main(
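Note: the config path is now derived from the module's own location instead of the `PROJECT_ROOT` environment variable, so it also resolves when the package is installed from a wheel. The mechanics, spelled out with an invented path:

    # Each os.path.dirname() climbs one directory level from this file.
    import os

    f = "/repo/models/rf3/inference.py"  # invented example location
    root = os.path.dirname(os.path.dirname(os.path.dirname(f)))  # "/repo"
    print(os.path.join(root, "configs"))  # "/repo/configs"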
rfd3/configs/dev.yaml CHANGED
@@ -2,6 +2,7 @@
 # Inference engine config for development
 hydra:
   searchpath:
+    - pkg://rfd3.configs
     - pkg://configs
 
 defaults:
rfd3/configs/inference.yaml CHANGED
@@ -2,6 +2,7 @@
 
 hydra:
   searchpath:
+    - pkg://rfd3.configs
     - pkg://configs
 
 defaults:
rfd3/configs/train.yaml CHANGED
@@ -1,6 +1,7 @@
 # @package _global_
 hydra:
   searchpath:
+    - pkg://rfd3.configs
     - pkg://configs
 
 defaults:
@@ -25,4 +26,4 @@ ckpt_path: null
 
 # Placeholders
 name: aa_design
-tags: [aa_design]
+tags: [aa_design]
rfd3/configs/trainer/xpu.yaml ADDED
@@ -0,0 +1,6 @@
+strategy: xpu_single
+
+accelerator: xpu
+devices_per_node: 1
+num_nodes: 1
+
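Note: this new config group pairs with the XPU accelerator, strategy, and precision plugins added under `foundry/utils/xpu/` (files 5-8 above). A hypothetical way to select it through Hydra's compose API; the config module name and the `trainer` placement are assumptions, not confirmed by the diff:

    # Hypothetical: select the new trainer group via Hydra's compose API.
    from hydra import compose, initialize_config_module

    with initialize_config_module(config_module="rfd3.configs", version_base=None):
        cfg = compose(config_name="train", overrides=["trainer=xpu"])
        # assuming the group lands at cfg.trainer: cfg.trainer.accelerator == "xpu"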
rfd3/configs/validate.yaml CHANGED
@@ -2,6 +2,7 @@
 
 hydra:
   searchpath:
+    - pkg://rfd3.configs
     - pkg://configs
 
 defaults:
rfd3/engine.py CHANGED
@@ -5,7 +5,7 @@ import time
 from dataclasses import dataclass, field
 from os import PathLike
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Dict, List, Optional
 
 import torch
 import yaml
@@ -21,11 +21,13 @@ from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS
 from rfd3.inference.datasets import (
     assemble_distributed_inference_loader_from_json,
 )
-from rfd3.inference.input_parsing import DesignInputSpecification
+from rfd3.inference.input_parsing import (
+    DesignInputSpecification,
+    ensure_input_is_abspath,
+)
 from rfd3.model.inference_sampler import SampleDiffusionConfig
 from rfd3.utils.inference import (
     ensure_inference_sampler_matches_design_spec,
-    ensure_input_is_abspath,
 )
 from rfd3.utils.io import (
     CIF_LIKE_EXTENSIONS,
@@ -46,9 +48,8 @@ class RFD3InferenceConfig:
     diffusion_batch_size: int = 16
 
     # RFD3 specific
-    skip_existing: bool = False
-    json_keys_subset: Optional[List[str]] = None
     skip_existing: bool = True
+    json_keys_subset: Optional[List[str]] = None
     specification: Optional[dict] = field(default_factory=dict)
     inference_sampler: SampleDiffusionConfig | dict = field(default_factory=dict)
@@ -214,6 +215,9 @@ class RFD3InferenceEngine(BaseInferenceEngine):
             inputs=inputs,
             n_batches=n_batches,
         )
+        if len(design_specifications) == 0:
+            ranked_logger.info("No design specifications to run. Skipping.")
+            return None
         ensure_inference_sampler_matches_design_spec(
             design_specifications, self.inference_sampler_overrides
         )
@@ -376,28 +380,39 @@ class RFD3InferenceEngine(BaseInferenceEngine):
 
     def _multiply_specifications(
         self, inputs: Dict[str, dict | DesignInputSpecification], n_batches=None
-    ) -> Dict[str, Dict[str, Any]]:
+    ) -> Dict[str, dict | DesignInputSpecification]:
        # Find existing example IDS in output directory
        if exists(self.out_dir):
-            existing_example_ids = set(
+            existing_example_ids_ = set(
                extract_example_id_from_path(path, CIF_LIKE_EXTENSIONS)
                for path in find_files_with_extension(self.out_dir, CIF_LIKE_EXTENSIONS)
            )
+            existing_example_ids = set(
+                [
+                    "_model_".join(eid.split("_model_")[:-1])
+                    for eid in existing_example_ids_
+                ]
+            )
            ranked_logger.info(
-                f"Found {len(existing_example_ids)} existing example IDs in the output directory."
+                f"Found {len(existing_example_ids)} existing example IDs in the output directory ({len(existing_example_ids_)} total)."
            )
 
        # Based on inputs, construct the specifications to loop through
        design_specifications = {}
        for prefix, example_spec in inputs.items():
            # Record task name in the specification
-            example_spec["extra"]["task_name"] = prefix
+            if isinstance(example_spec, DesignInputSpecification):
+                example_spec.extra = example_spec.extra or {}
+                example_spec.extra["task_name"] = prefix
+            else:
+                if "extra" not in example_spec:
+                    example_spec["extra"] = {}
+                example_spec["extra"]["task_name"] = prefix
 
            # ... Create n_batches for example
            for batch_id in range((n_batches) if exists(n_batches) else 1):
                # ... Example ID
                example_id = f"{prefix}_{batch_id}" if exists(n_batches) else prefix
-
                if (
                    self.skip_existing
                    and exists(self.out_dir)
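Note: output structures evidently carry a `_model_<n>` suffix per sample, so existing-output detection now strips that suffix and compares batch-level IDs instead of raw file stems. The stripping, standalone (example ID invented):

    # IDs like "<task>_<batch>_model_<n>" collapse to "<task>_<batch>".
    eid = "binder_design_3_model_0"  # invented example
    base = "_model_".join(eid.split("_model_")[:-1])
    assert base == "binder_design_3"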
rfd3/inference/datasets.py CHANGED
@@ -14,8 +14,8 @@ from atomworks.ml.transforms.base import Compose, Transform
 from omegaconf import DictConfig, OmegaConf
 from rfd3.inference.input_parsing import (
     DesignInputSpecification,
+    ensure_input_is_abspath,
 )
-from rfd3.utils.inference import ensure_input_is_abspath
 from torch.utils.data import (
     DataLoader,
     SequentialSampler,
rfd3/inference/input_parsing.py CHANGED
@@ -5,6 +5,7 @@ import os
 import time
 import warnings
 from contextlib import contextmanager
+from os import PathLike
 from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
@@ -127,8 +128,10 @@ class DesignInputSpecification(BaseModel):
     # Motif selection from input file
     contig: Optional[InputSelection] = Field(None, description="Contig specification string (e.g. 'A1-10,B1-5')")
     unindex: Optional[InputSelection] = Field(None,
-        description="Unindexed components string (components must not overlap with contig). "\
-        "E.g. 'A15-20,B6-10' or dict. We recommend specifying")
+        description="Unindexed components selection. Components to fix in the generated structure without specifying sequence index. "\
+        "Components must not overlap with `contig` argument. "\
+        "E.g. 'A15-20,B6-10' or dict. We recommend specifying unindexed residues as a contig string, "\
+        "then using select_fixed_atoms will subset the atoms to the specified atoms")
     # Extra args:
     length: Optional[str] = Field(None, description="Length range as 'min-max' or int. Constrains length of contig if provided")
     ligand: Optional[str] = Field(None, description="Ligand name or index to include in design.")
@@ -1121,3 +1124,33 @@ def accumulate_components(
     if atom_array_accum.bonds is None:
         atom_array_accum.bonds = BondList(atom_array_accum.array_length())
     return atom_array_accum
+
+
+def ensure_input_is_abspath(args: Dict[str, Any], path: PathLike | None):
+    """
+    Ensures the input source is an absolute path if exists, if not it will convert
+
+    args:
+        args: Inference specification for atom array
+        path: None or file to which the input is relative to.
+    """
+    if isinstance(args, str):
+        raise ValueError(
+            "Expected args to be a dictionary, got a string: {}. If you are using an input JSON ensure it contains dictionaries of arguments".format(
+                args
+            )
+        )
+    if "input" not in args or not exists(args["input"]):
+        return args
+    input = str(args["input"])
+    if not os.path.isabs(input):
+        if path is None:
+            raise ValueError(
+                "Input path is relative, but no base path was provided to resolve it against."
+            )
+        input = os.path.abspath(os.path.join(os.path.dirname(str(path)), input))
+        logger.info(
+            f"Input source path is relative, converted to absolute path: {input}"
+        )
+    args["input"] = input
+    return args
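Note: `exists` here appears to be the not-None helper used elsewhere in this diff (`foundry.common.exists`), so specs without an `input` key pass through untouched. A usage sketch with invented paths:

    # Hypothetical usage: a relative "input" is resolved against the
    # directory of the base file passed as `path`.
    spec = {"input": "motif.cif"}
    spec = ensure_input_is_abspath(spec, path="/data/run1/design.json")
    # spec["input"] -> "/data/run1/motif.cif"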
rfd3/inference/symmetry/checks.py CHANGED
@@ -24,7 +24,16 @@ def check_symmetry_config(
     assert sym_conf.id, "symmetry_id is required. e.g. {'id': 'C2'}"
     # if unsym motif is provided, check that each motif name is in the atom array
 
+    is_motif_atom = get_motif_features(atom_array)["is_motif_atom"]
     is_unsym_motif = np.zeros(atom_array.shape[0], dtype=bool)
+
+    if not is_motif_atom.any():
+        sym_conf.is_symmetric_motif = None
+        ranked_logger.warning(
+            "No motifs found in atom array. Setting is_symmetric_motif to None."
+        )
+        return sym_conf
+
     if sym_conf.is_unsym_motif:
         assert (
             src_atom_array is not None
@@ -36,21 +45,20 @@ def check_symmetry_config(
         if (sm and n not in sm.split(",")) and (n not in atom_array.src_component):
             raise ValueError(f"Unsym motif {n} not found in atom_array")
 
-    is_motif_token = get_motif_features(atom_array)["is_motif_token"]
     if (
-        is_motif_token[~is_unsym_motif].any()
+        is_motif_atom[~is_unsym_motif].any()
         and not sym_conf.is_symmetric_motif
         and not has_dist_cond
     ):
         raise ValueError(
-            "Asymmetric motif inputs should be distance constrained."
-            "Use atomwise_fixed_dist to constrain the distance between the motif atoms."
+            "Asymmetric motif inputs are not supported yet. Please provide a symmetric motif."
        )
 
     if partial and not sym_conf.is_symmetric_motif:
         raise ValueError(
             "Partial diffusion with symmetry is only supported for symmetric inputs."
         )
+    return sym_conf
 
 
 def check_atom_array_is_symmetric(atom_array):
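Note: since `check_symmetry_config` now returns the (possibly mutated) config, callers must rebind the result or lose the `is_symmetric_motif` reset performed on the no-motif path; the `symmetry_utils.py` hunk below makes exactly that change:

    # Capture the returned config rather than calling for side effects only.
    sym_conf = check_symmetry_config(
        asu_atom_array, sym_conf, sm, has_dist_cond, src_atom_array=src_atom_array
    )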
rfd3/inference/symmetry/symmetry_utils.py CHANGED
@@ -83,7 +83,7 @@ def make_symmetric_atom_array(
     if not isinstance(sym_conf, SymmetryConfig):
         sym_conf = convery_sym_conf_to_symmetry_config(sym_conf)
 
-    check_symmetry_config(
+    sym_conf = check_symmetry_config(
         asu_atom_array, sym_conf, sm, has_dist_cond, src_atom_array=src_atom_array
     )
     # Adding utility annotations to the asu atom array
@@ -99,12 +99,11 @@ def make_symmetric_atom_array(
         assert (
             src_atom_array is not None
         ), "Source atom array must be provided for symmetric motifs"
-        # if symmetric motif is provided, get the frames from the src atom array.
         frames = get_symmetry_frames_from_atom_array(src_atom_array, frames)
-    elif (asu_atom_array._is_motif[~asu_atom_array._is_unsym_motif]).any():
-        # if the motifs that's not unsym motifs are present.
-        raise NotImplementedError(
-            "Asymmetric motif inputs are not implemented yet. please symmetrize the motif."
+    else:
+        # At this point, asym case would have been caught by the check_symmetry_config function.
+        ranked_logger.info(
+            "No motifs found in atom array. Generating unconditional symmetric proteins."
         )
 
     # Add symmetry annotations to the asu atom array
rfd3/model/inference_sampler.py CHANGED
@@ -1,11 +1,15 @@
 import inspect
+import time
 from dataclasses import dataclass
 from typing import Any, Literal
 
 import torch
 from jaxtyping import Float
+from rfd3.inference.symmetry.symmetry_utils import apply_symmetry_to_xyz_atomwise
+from rfd3.model.cfg_utils import strip_X
 
 from foundry.common import exists
+from foundry.utils.alignment import weighted_rigid_align
 from foundry.utils.ddp import RankedLogger
 from foundry.utils.rotation_augmentation import (
     rot_vec_mul,
@@ -110,14 +114,16 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
             )
             # Fallback to smallest available step
             noise_schedule_original = self._construct_inference_noise_schedule(
-                device=coord_atom_lvl_to_be_noised.device
+                device=device
             )
             noise_schedule = noise_schedule_original[-1:]  # Just use the final step
             ranked_logger.info(
                 f"Using fallback: final step with t={noise_schedule[0].item():.6f}"
             )
+        else:
+            noise_schedule = t_hat
 
-        return t_hat
+        return noise_schedule
 
     def _get_initial_structure(
         self,
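Note: the functional fix here is that the method previously built a fallback `noise_schedule` and then returned `t_hat` unconditionally, discarding it. A simplified sketch of the corrected flow; the guard condition is not visible in the hunk, so `t_hat is None` below is illustrative only:

    def _pick_schedule(t_hat, full_schedule):
        # Simplified: return the fallback (final step only) when t_hat is
        # unusable, otherwise pass t_hat through unchanged.
        if t_hat is None:  # illustrative guard; real condition not shown
            noise_schedule = full_schedule[-1:]
        else:
            noise_schedule = t_hat
        return noise_schedule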
@@ -221,6 +227,7 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
             # Handle chunked mode vs standard mode
             if "chunked_pairwise_embedder" in initializer_outputs:
                 # Chunked mode: explicitly provide P_LL=None
+                tic = time.time()
                 chunked_embedder = initializer_outputs[
                     "chunked_pairwise_embedder"
                 ]  # Don't pop, just get
@@ -238,6 +245,8 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
                     initializer_outputs=other_outputs,
                     **other_outputs,
                 )
+                toc = time.time()
+                ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
             else:
                 # Standard mode: P_LL is included in initializer_outputs
                 outs = diffusion_module(
@@ -445,6 +454,7 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
             # Handle chunked mode vs standard mode (same as default sampler)
             if "chunked_pairwise_embedder" in initializer_outputs:
                 # Chunked mode: explicitly provide P_LL=None
+                tic = time.time()
                 chunked_embedder = initializer_outputs[
                     "chunked_pairwise_embedder"
                 ]  # Don't pop, just get
@@ -462,6 +472,8 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
                     initializer_outputs=other_outputs,
                     **other_outputs,
                 )
+                toc = time.time()
+                ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
             else:
                 # Standard mode: P_LL is included in initializer_outputs
                 outs = diffusion_module(
rfd3/model/layers/block_utils.py CHANGED
@@ -118,14 +118,14 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
 
     Parameters
     ----------
-    P_LK_indices : (D, L, k) LongTensor
+    P_LK_indices : (B, L, k) LongTensor
         Key indices | P_LK_indices[d, i, k] = global atom index for which atom i attends to.
-    P_LK : (D, L, k, c) FloatTensor
+    P_LK : (B, L, k, c) FloatTensor
         Key features to scatter add into
 
-    P_LA_indices : (D, L, a) LongTensor
+    P_LA_indices : (B, L, a) LongTensor
         Additional feature indices to scatter into P_LK.
-    P_LA : (D, L, a, c) FloatTensor
+    P_LA : (B, L, a, c) FloatTensor
         Features corresponding to P_LA.
 
     Both index tensors contain indices representing D batch dim,
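Aside: the index-matching contract documented above can be exercised standalone. A toy sketch (tiny invented shapes) of the matching and scatter-add that the function body in the next hunk performs:

    import torch

    B, L, k, a, c = 1, 2, 3, 2, 1
    P_LK_indices = torch.tensor([[[0, 1, 2], [3, 4, 5]]])  # (B, L, k)
    P_LA_indices = torch.tensor([[[1, 2], [4, 5]]])        # (B, L, a)
    P_LK_tgt = torch.zeros(B, L, k, c)
    P_LA_src = torch.ones(B, L, a, c)

    # Route each P_LA column to the key slot whose global atom index matches.
    matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2)  # (B, L, a, k)
    k_idx = matches.long().argmax(dim=-1)                               # (B, L, a)
    scatter_idx = k_idx.unsqueeze(-1).expand(-1, -1, -1, c)
    out = P_LK_tgt.scatter_add(dim=2, index=scatter_idx, src=P_LA_src)
    # out[0, 0] == [[0.], [1.], [1.]]: slots 1 and 2 received the additions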
@@ -135,42 +135,42 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
 
     """
     # Handle case when indices and P_LA don't have batch dimensions
-    D, L, k = P_LK_indices.shape
+    B, L, k = P_LK_indices.shape
     if P_LA_indices.ndim == 2:
-        P_LA_indices = P_LA_indices.unsqueeze(0).expand(D, -1, -1)
+        P_LA_indices = P_LA_indices.unsqueeze(0).expand(B, -1, -1)
     if P_LA_src.ndim == 3:
-        P_LA_src = P_LA_src.unsqueeze(0).expand(D, -1, -1)
+        P_LA_src = P_LA_src.unsqueeze(0).expand(B, -1, -1)
     assert (
         P_LA_src.shape[-1] == P_LK_tgt.shape[-1]
     ), "Channel dims do not match, got: {} vs {}".format(
         P_LA_src.shape[-1], P_LK_tgt.shape[-1]
     )
 
-    matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2)  # (D, L, a, k)
+    matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2)  # (B, L, a, k)
     if not torch.all(matches.sum(dim=(-1, -2)) >= 1):
         raise ValueError("Found multiple scatter indices for some atoms")
     elif not torch.all(matches.sum(dim=-1) <= 1):
         raise ValueError("Did not find a scatter index for every atom")
-    k_indices = matches.long().argmax(dim=-1)  # (D, L, a)
+    k_indices = matches.long().argmax(dim=-1)  # (B, L, a)
     scatter_indices = k_indices.unsqueeze(-1).expand(
         -1, -1, -1, P_LK_tgt.shape[-1]
-    )  # (D, L, a, c)
+    )  # (B, L, a, c)
     P_LK_tgt = P_LK_tgt.scatter_add(dim=2, index=scatter_indices, src=P_LA_src)
     return P_LK_tgt
 
 
 def _batched_gather(values: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
     """
-    values : (D, L, C)
-    idx    : (D, L, k)
-    returns: (D, L, k, C)
+    values : (B, L, C)
+    idx    : (B, L, k)
+    returns: (B, L, k, C)
     """
-    D, L, C = values.shape
+    B, L, C = values.shape
     k = idx.shape[-1]
 
-    # (D, L, 1, C) → stride-0 along k → (D, L, k, C)
+    # (B, L, 1, C) → stride-0 along k → (B, L, k, C)
     src = values.unsqueeze(2).expand(-1, -1, k, -1)
-    idx = idx.unsqueeze(-1).expand(-1, -1, -1, C)  # (D, L, k, C)
+    idx = idx.unsqueeze(-1).expand(-1, -1, -1, C)  # (B, L, k, C)
 
     return torch.gather(src, 1, idx)  # dim=1 is the L-axis
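Aside: a quick standalone check of `_batched_gather`'s contract, `out[b, i, j] == values[b, idx[b, i, j]]` (toy shapes):

    import torch

    values = torch.arange(6.0).reshape(1, 3, 2)     # (B=1, L=3, C=2)
    idx = torch.tensor([[[2, 0], [1, 1], [0, 2]]])  # (B=1, L=3, k=2)

    src = values.unsqueeze(2).expand(-1, -1, 2, -1)  # stride-0 along k
    out = torch.gather(src, 1, idx.unsqueeze(-1).expand(-1, -1, -1, 2))
    assert torch.equal(out[0, 0, 0], values[0, 2])   # row 2 gathered for (i=0, j=0)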
 
@@ -196,7 +196,7 @@ def create_attention_indices(
     X_L = torch.randn(
         (1, L, 3), device=device, dtype=torch.float
     )  # [L, 3] - random
-    D_LL = torch.cdist(X_L, X_L, p=2)  # [D, L, L] - pairwise atom distances
+    D_LL = torch.cdist(X_L, X_L, p=2)  # [B, L, L] - pairwise atom distances
 
     # Create attention indices using neighbour distances
     base_mask = ~f["unindexing_pair_mask"][
@@ -231,7 +231,7 @@ def create_attention_indices(
         k_max=k_actual,
         chain_id=chain_ids,
         base_mask=base_mask,
-    )  # [D, L, k] | indices[b, i, j] = atom index for atom i to j-th attn query
+    )  # [B, L, k] | indices[b, i, j] = atom index for atom i to j-th attn query
 
     return attn_indices
 
@@ -245,7 +245,7 @@ def get_sparse_attention_indices_with_inter_chain(
 
     Args:
         tok_idx: atom to token mapping
-        D_LL: pairwise distances [D, L, L]
+        D_LL: pairwise distances [B, L, L]
         n_seq_neighbours: number of sequence neighbors
         k_intra: number of intra-chain attention keys
         k_inter: number of inter-chain attention keys
@@ -253,29 +253,29 @@ def get_sparse_attention_indices_with_inter_chain(
         base_mask: base mask for valid pairs
 
     Returns:
-        attn_indices: [D, L, k_total] where k_total = k_intra + k_inter
+        attn_indices: [B, L, k_total] where k_total = k_intra + k_inter
     """
-    D, L, _ = D_LL.shape
+    B, L, _ = D_LL.shape
 
     # Get regular intra-chain indices (limited to k_intra)
     intra_indices = get_sparse_attention_indices(
         tok_idx, D_LL, n_seq_neighbours, k_intra, chain_id, base_mask
-    )  # [D, L, k_intra]
+    )  # [B, L, k_intra]
 
     # Get inter-chain indices for clash avoidance
-    inter_indices = torch.zeros(D, L, k_inter, dtype=torch.long, device=D_LL.device)
-
-    for d in range(D):
-        for l in range(L):
-            query_chain = chain_id[l]
+    inter_indices = torch.zeros(B, L, k_inter, dtype=torch.long, device=D_LL.device)
+    unique_chains = torch.unique(chain_id)
+    for b in range(B):
+        for c in unique_chains:
+            query_chain = chain_id[c]
 
             # Find atoms from different chains
-            other_chain_mask = (chain_id != query_chain) & base_mask[l, :]
+            other_chain_mask = (chain_id != query_chain) & base_mask[c, :]
             other_chain_atoms = torch.where(other_chain_mask)[0]
 
             if len(other_chain_atoms) > 0:
                 # Get distances to other chains
-                distances_to_other = D_LL[d, l, other_chain_atoms]
+                distances_to_other = D_LL[b, c, other_chain_atoms]
 
                 # Select k_inter closest atoms from other chains
                 n_select = min(k_inter, len(other_chain_atoms))
@@ -283,23 +283,23 @@ def get_sparse_attention_indices_with_inter_chain(
                 selected_atoms = other_chain_atoms[closest_idx]
 
                 # Fill inter-chain indices
-                inter_indices[d, l, :n_select] = selected_atoms
+                inter_indices[b, c, :n_select] = selected_atoms
                 # Pad with random atoms if needed
                 if n_select < k_inter:
                     padding = torch.randint(
                         0, L, (k_inter - n_select,), device=D_LL.device
                     )
-                    inter_indices[d, l, n_select:] = padding
+                    inter_indices[b, c, n_select:] = padding
             else:
                 # No other chains found, fill with random indices
-                inter_indices[d, l, :] = torch.randint(
+                inter_indices[b, c, :] = torch.randint(
                     0, L, (k_inter,), device=D_LL.device
                 )
 
     # Combine intra and inter chain indices
     combined_indices = torch.cat(
         [intra_indices, inter_indices], dim=-1
-    )  # [D, L, k_total]
+    )  # [B, L, k_total]
 
     return combined_indices
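Note: the diff elides how `closest_idx` is computed between the last two hunks. A plausible standalone sketch of that selection step; the `topk(..., largest=False)` call is an assumption, not shown in the diff:

    import torch

    D_row = torch.tensor([0.5, 3.0, 1.2, 9.0])      # distances from one query atom
    other_chain_atoms = torch.tensor([0, 1, 2, 3])  # candidates on other chains
    n_select = min(2, len(other_chain_atoms))
    closest_idx = torch.topk(D_row, n_select, largest=False).indices
    selected_atoms = other_chain_atoms[closest_idx]  # tensor([0, 2])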