rc-foundry 0.1.7__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/trainers/fabric.py +18 -2
- foundry/utils/components.py +3 -3
- foundry/utils/ddp.py +16 -13
- foundry/utils/logging.py +1 -1
- foundry/utils/xpu/__init__.py +27 -0
- foundry/utils/xpu/single_xpu_strategy.py +47 -0
- foundry/utils/xpu/xpu_accelerator.py +91 -0
- foundry/utils/xpu/xpu_precision.py +72 -0
- foundry/version.py +2 -2
- mpnn/inference_engines/mpnn.py +6 -2
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.10.dist-info}/METADATA +16 -2
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.10.dist-info}/RECORD +36 -30
- rf3/cli.py +13 -4
- rf3/configs/inference.yaml +5 -0
- rf3/configs/train.yaml +5 -0
- rf3/configs/trainer/xpu.yaml +6 -0
- rf3/configs/validate.yaml +5 -0
- rf3/inference.py +3 -1
- rfd3/configs/dev.yaml +1 -0
- rfd3/configs/inference.yaml +1 -0
- rfd3/configs/train.yaml +2 -1
- rfd3/configs/trainer/xpu.yaml +6 -0
- rfd3/configs/validate.yaml +1 -0
- rfd3/engine.py +25 -10
- rfd3/inference/datasets.py +1 -1
- rfd3/inference/input_parsing.py +35 -2
- rfd3/inference/symmetry/checks.py +12 -4
- rfd3/inference/symmetry/symmetry_utils.py +5 -6
- rfd3/model/inference_sampler.py +14 -2
- rfd3/model/layers/block_utils.py +33 -33
- rfd3/model/layers/chunked_pairwise.py +84 -82
- rfd3/transforms/hbonds_hbplus.py +52 -49
- rfd3/utils/inference.py +7 -30
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.10.dist-info}/WHEEL +0 -0
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.10.dist-info}/entry_points.txt +0 -0
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.10.dist-info}/licenses/LICENSE.md +0 -0
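Note: judging by the filenames, the headline addition in 0.1.10 is Intel XPU support: a new foundry/utils/xpu package (accelerator, single-device strategy, precision plugin) plus trainer/xpu.yaml configs for rf3 and rfd3, wired into the fabric trainer and DDP utilities. A device-selection sketch of what such a trainer config typically toggles (illustrative only, not the package's code; guarded because `torch.xpu` only exists on recent PyTorch builds):

```python
import torch

# Prefer the XPU backend when present, as an xpu trainer config would.
use_xpu = getattr(torch, "xpu", None) is not None and torch.xpu.is_available()
device = torch.device("xpu" if use_xpu else "cpu")
print(f"training device: {device}")
```

The expanded file diffs below cover the smaller fixes shipping alongside it.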
rf3/configs/validate.yaml
CHANGED
@@ -2,6 +2,11 @@
 # ^ The "package" determines where the content of the config is placed in the output config
 # For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
 
+hydra:
+  searchpath:
+    - pkg://rf3.configs
+    - pkg://configs
+
 # NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
 defaults:
   - callbacks: default
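Note: the new `hydra.searchpath` block lets this config resolve groups from the installed `rf3.configs` package (and a top-level `configs` package) rather than only the local config directory. A minimal sketch of composing against a `pkg://` searchpath with Hydra's compose API (assumes hydra-core >= 1.2; an illustration, not the package's CLI):

```python
from hydra import compose, initialize_config_module

# Configs shipped inside the rf3.configs package become discoverable,
# which is what the pkg://rf3.configs searchpath entry provides.
with initialize_config_module(config_module="rf3.configs", version_base=None):
    cfg = compose(config_name="validate")
    print(list(cfg.keys()))
```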
rf3/inference.py
CHANGED
@@ -16,7 +16,9 @@ rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
 load_dotenv(override=True)
 
-_config_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs")
+_config_path = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs"
+)
 
 
 @hydra.main(
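Note: this is a pure reflow of the `_config_path` expression; the three nested `dirname` calls still resolve the `configs` directory three levels above the module. A quick illustration with a hypothetical layout:

```python
import os

def config_dir(module_file: str) -> str:
    # Each dirname() strips one level: .../src/rf3/inference.py -> .../src/rf3 -> .../src -> repo root
    return os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(module_file))), "configs"
    )

print(config_dir("/repo/src/rf3/inference.py"))  # -> /repo/configs
```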
rfd3/configs/dev.yaml
CHANGED
rfd3/configs/inference.yaml
CHANGED
rfd3/configs/train.yaml
CHANGED
rfd3/configs/validate.yaml
CHANGED
rfd3/engine.py
CHANGED
@@ -5,7 +5,7 @@ import time
 from dataclasses import dataclass, field
 from os import PathLike
 from pathlib import Path
-from typing import
+from typing import Dict, List, Optional
 
 import torch
 import yaml
@@ -21,11 +21,13 @@ from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS
 from rfd3.inference.datasets import (
     assemble_distributed_inference_loader_from_json,
 )
-from rfd3.inference.input_parsing import
+from rfd3.inference.input_parsing import (
+    DesignInputSpecification,
+    ensure_input_is_abspath,
+)
 from rfd3.model.inference_sampler import SampleDiffusionConfig
 from rfd3.utils.inference import (
     ensure_inference_sampler_matches_design_spec,
-    ensure_input_is_abspath,
 )
 from rfd3.utils.io import (
     CIF_LIKE_EXTENSIONS,
@@ -46,9 +48,8 @@ class RFD3InferenceConfig:
     diffusion_batch_size: int = 16
 
     # RFD3 specific
-    skip_existing: bool = False
-    json_keys_subset: Optional[List[str]] = None
     skip_existing: bool = True
+    json_keys_subset: Optional[List[str]] = None
     specification: Optional[dict] = field(default_factory=dict)
     inference_sampler: SampleDiffusionConfig | dict = field(default_factory=dict)
 
@@ -214,6 +215,9 @@ class RFD3InferenceEngine(BaseInferenceEngine):
             inputs=inputs,
             n_batches=n_batches,
         )
+        if len(design_specifications) == 0:
+            ranked_logger.info("No design specifications to run. Skipping.")
+            return None
         ensure_inference_sampler_matches_design_spec(
             design_specifications, self.inference_sampler_overrides
         )
@@ -376,28 +380,39 @@ class RFD3InferenceEngine(BaseInferenceEngine):
 
     def _multiply_specifications(
         self, inputs: Dict[str, dict | DesignInputSpecification], n_batches=None
-    ) -> Dict[str,
+    ) -> Dict[str, dict | DesignInputSpecification]:
        # Find existing example IDS in output directory
         if exists(self.out_dir):
-            existing_example_ids = set(
+            existing_example_ids_ = set(
                 extract_example_id_from_path(path, CIF_LIKE_EXTENSIONS)
                 for path in find_files_with_extension(self.out_dir, CIF_LIKE_EXTENSIONS)
             )
+            existing_example_ids = set(
+                [
+                    "_model_".join(eid.split("_model_")[:-1])
+                    for eid in existing_example_ids_
+                ]
+            )
             ranked_logger.info(
-                f"Found {len(existing_example_ids)} existing example IDs in the output directory."
+                f"Found {len(existing_example_ids)} existing example IDs in the output directory ({len(existing_example_ids_)} total)."
             )
 
         # Based on inputs, construct the specifications to loop through
         design_specifications = {}
         for prefix, example_spec in inputs.items():
             # Record task name in the specification
-            example_spec
+            if isinstance(example_spec, DesignInputSpecification):
+                example_spec.extra = example_spec.extra or {}
+                example_spec.extra["task_name"] = prefix
+            else:
+                if "extra" not in example_spec:
+                    example_spec["extra"] = {}
+                example_spec["extra"]["task_name"] = prefix
 
             # ... Create n_batches for example
             for batch_id in range((n_batches) if exists(n_batches) else 1):
                 # ... Example ID
                 example_id = f"{prefix}_{batch_id}" if exists(n_batches) else prefix
-
                 if (
                     self.skip_existing
                     and exists(self.out_dir)
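Note: two behavioral changes here. The duplicated `skip_existing` field is collapsed (the surviving default is `True`), and `_multiply_specifications` now groups existing outputs by stripping the `_model_N` suffix before deciding what to skip, so multi-model outputs of one design count as a single existing example. A toy run of the grouping expression (file stems invented):

```python
# Stems as produced per written model; grouping collapses them per example.
existing_example_ids_ = {"taskA_0_model_0", "taskA_0_model_1", "taskB_0_model_0"}
existing_example_ids = {
    "_model_".join(eid.split("_model_")[:-1]) for eid in existing_example_ids_
}
print(sorted(existing_example_ids))  # ['taskA_0', 'taskB_0']
```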
rfd3/inference/datasets.py
CHANGED
@@ -14,8 +14,8 @@ from atomworks.ml.transforms.base import Compose, Transform
 from omegaconf import DictConfig, OmegaConf
 from rfd3.inference.input_parsing import (
     DesignInputSpecification,
+    ensure_input_is_abspath,
 )
-from rfd3.utils.inference import ensure_input_is_abspath
 from torch.utils.data import (
     DataLoader,
     SequentialSampler,
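Note: `ensure_input_is_abspath` is now imported from `rfd3.inference.input_parsing` (where the function moves, see the next file) and the `rfd3.utils.inference` import is dropped. Downstream code should update accordingly:

```python
# 0.1.10 import location; the rfd3.utils.inference path no longer provides this.
from rfd3.inference.input_parsing import (
    DesignInputSpecification,
    ensure_input_is_abspath,
)
```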
rfd3/inference/input_parsing.py
CHANGED
@@ -5,6 +5,7 @@ import os
 import time
 import warnings
 from contextlib import contextmanager
+from os import PathLike
 from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
@@ -127,8 +128,10 @@ class DesignInputSpecification(BaseModel):
     # Motif selection from input file
     contig: Optional[InputSelection] = Field(None, description="Contig specification string (e.g. 'A1-10,B1-5')")
     unindex: Optional[InputSelection] = Field(None,
-        description="Unindexed components
-        "
+        description="Unindexed components selection. Components to fix in the generated structure without specifying sequence index. "\
+        "Components must not overlap with `contig` argument. "\
+        "E.g. 'A15-20,B6-10' or dict. We recommend specifying unindexed residues as a contig string, "\
+        "then using select_fixed_atoms will subset the atoms to the specified atoms")
     # Extra args:
     length: Optional[str] = Field(None, description="Length range as 'min-max' or int. Constrains length of contig if provided")
     ligand: Optional[str] = Field(None, description="Ligand name or index to include in design.")
@@ -1121,3 +1124,33 @@ def accumulate_components(
     if atom_array_accum.bonds is None:
         atom_array_accum.bonds = BondList(atom_array_accum.array_length())
     return atom_array_accum
+
+
+def ensure_input_is_abspath(args: Dict[str, Any], path: PathLike | None):
+    """
+    Ensures the input source is an absolute path if exists, if not it will convert
+
+    args:
+        args: Inference specification for atom array
+        path: None or file to which the input is relative to.
+    """
+    if isinstance(args, str):
+        raise ValueError(
+            "Expected args to be a dictionary, got a string: {}. If you are using an input JSON ensure it contains dictionaries of arguments".format(
+                args
+            )
+        )
+    if "input" not in args or not exists(args["input"]):
+        return args
+    input = str(args["input"])
+    if not os.path.isabs(input):
+        if path is None:
+            raise ValueError(
+                "Input path is relative, but no base path was provided to resolve it against."
+            )
+        input = os.path.abspath(os.path.join(os.path.dirname(str(path)), input))
+        logger.info(
+            f"Input source path is relative, converted to absolute path: {input}"
+        )
+    args["input"] = input
+    return args
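Note: the relocated `ensure_input_is_abspath` resolves a relative `"input"` entry against the directory of the file the spec came from, raises on string (non-dict) arguments, and raises if a relative input arrives with `path=None`. A usage sketch with hypothetical paths:

```python
# Hypothetical paths, mirroring the helper's contract as added above.
spec = {"input": "inputs/motif.cif"}
spec = ensure_input_is_abspath(spec, path="/data/run/spec.json")
assert spec["input"] == "/data/run/inputs/motif.cif"
```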
rfd3/inference/symmetry/checks.py
CHANGED
@@ -24,7 +24,16 @@ def check_symmetry_config(
     assert sym_conf.id, "symmetry_id is required. e.g. {'id': 'C2'}"
     # if unsym motif is provided, check that each motif name is in the atom array
 
+    is_motif_atom = get_motif_features(atom_array)["is_motif_atom"]
     is_unsym_motif = np.zeros(atom_array.shape[0], dtype=bool)
+
+    if not is_motif_atom.any():
+        sym_conf.is_symmetric_motif = None
+        ranked_logger.warning(
+            "No motifs found in atom array. Setting is_symmetric_motif to None."
+        )
+        return sym_conf
+
     if sym_conf.is_unsym_motif:
         assert (
             src_atom_array is not None
@@ -36,21 +45,20 @@ def check_symmetry_config(
         if (sm and n not in sm.split(",")) and (n not in atom_array.src_component):
             raise ValueError(f"Unsym motif {n} not found in atom_array")
 
-    is_motif_token = get_motif_features(atom_array)["is_motif_token"]
     if (
-        is_motif_token[~is_unsym_motif].any()
+        is_motif_atom[~is_unsym_motif].any()
         and not sym_conf.is_symmetric_motif
         and not has_dist_cond
     ):
         raise ValueError(
-            "Asymmetric motif inputs
-            "Use atomwise_fixed_dist to constrain the distance between the motif atoms."
+            "Asymmetric motif inputs are not supported yet. Please provide a symmetric motif."
         )
 
     if partial and not sym_conf.is_symmetric_motif:
         raise ValueError(
             "Partial diffusion with symmetry is only supported for symmetric inputs."
         )
+    return sym_conf
 
 
 def check_atom_array_is_symmetric(atom_array):
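Note: `check_symmetry_config` now switches from token-level to atom-level motif features and returns early, with `is_symmetric_motif` normalized to `None`, when the structure carries no motif atoms; the remaining checks only run for motif-bearing inputs. A condensed sketch of the new flow (stand-in types, not the package's signatures):

```python
import numpy as np

def check_config(conf: dict, is_motif_atom: np.ndarray) -> dict:
    if not is_motif_atom.any():
        conf["is_symmetric_motif"] = None  # normalize, skip motif checks
        return conf
    # ... unsym-motif and asymmetric-motif validation runs only past this point ...
    return conf

print(check_config({"id": "C2"}, np.zeros(8, dtype=bool)))
# {'id': 'C2', 'is_symmetric_motif': None}
```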
rfd3/inference/symmetry/symmetry_utils.py
CHANGED
@@ -83,7 +83,7 @@ def make_symmetric_atom_array(
     if not isinstance(sym_conf, SymmetryConfig):
         sym_conf = convery_sym_conf_to_symmetry_config(sym_conf)
 
-    check_symmetry_config(
+    sym_conf = check_symmetry_config(
         asu_atom_array, sym_conf, sm, has_dist_cond, src_atom_array=src_atom_array
     )
     # Adding utility annotations to the asu atom array
@@ -99,12 +99,11 @@ def make_symmetric_atom_array(
         assert (
             src_atom_array is not None
         ), "Source atom array must be provided for symmetric motifs"
-        # if symmetric motif is provided, get the frames from the src atom array.
         frames = get_symmetry_frames_from_atom_array(src_atom_array, frames)
-
-        #
-
-        "
+    else:
+        # At this point, asym case would have been caught by the check_symmetry_config function.
+        ranked_logger.info(
+            "No motifs found in atom array. Generating unconditional symmetric proteins."
         )
 
     # Add symmetry annotations to the asu atom array
rfd3/model/inference_sampler.py
CHANGED
@@ -1,11 +1,15 @@
 import inspect
+import time
 from dataclasses import dataclass
 from typing import Any, Literal
 
 import torch
 from jaxtyping import Float
+from rfd3.inference.symmetry.symmetry_utils import apply_symmetry_to_xyz_atomwise
+from rfd3.model.cfg_utils import strip_X
 
 from foundry.common import exists
+from foundry.utils.alignment import weighted_rigid_align
 from foundry.utils.ddp import RankedLogger
 from foundry.utils.rotation_augmentation import (
     rot_vec_mul,
@@ -110,14 +114,16 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
             )
             # Fallback to smallest available step
             noise_schedule_original = self._construct_inference_noise_schedule(
-                device=
+                device=device
             )
             noise_schedule = noise_schedule_original[-1:]  # Just use the final step
             ranked_logger.info(
                 f"Using fallback: final step with t={noise_schedule[0].item():.6f}"
             )
+        else:
+            noise_schedule = t_hat
 
-        return
+        return noise_schedule
 
     def _get_initial_structure(
         self,
@@ -221,6 +227,7 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
         # Handle chunked mode vs standard mode
         if "chunked_pairwise_embedder" in initializer_outputs:
             # Chunked mode: explicitly provide P_LL=None
+            tic = time.time()
             chunked_embedder = initializer_outputs[
                 "chunked_pairwise_embedder"
             ]  # Don't pop, just get
@@ -238,6 +245,8 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
                 initializer_outputs=other_outputs,
                 **other_outputs,
             )
+            toc = time.time()
+            ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
         else:
             # Standard mode: P_LL is included in initializer_outputs
             outs = diffusion_module(
@@ -445,6 +454,7 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
         # Handle chunked mode vs standard mode (same as default sampler)
         if "chunked_pairwise_embedder" in initializer_outputs:
             # Chunked mode: explicitly provide P_LL=None
+            tic = time.time()
             chunked_embedder = initializer_outputs[
                 "chunked_pairwise_embedder"
             ]  # Don't pop, just get
@@ -462,6 +472,8 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
                 initializer_outputs=other_outputs,
                 **other_outputs,
             )
+            toc = time.time()
+            ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
         else:
             # Standard mode: P_LL is included in initializer_outputs
             outs = diffusion_module(
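Note: the substantive fix is in the noise-schedule helper: the old code ended with a bare `return` (yielding `None`) whenever the fallback branch was not taken; 0.1.10 assigns `noise_schedule = t_hat` in an `else` and returns it explicitly. The `tic`/`toc` additions just log wall-clock time for the chunked embedder path. A compressed sketch of the fixed control flow (names illustrative, not the class's full logic):

```python
import time

def pick_noise_schedule(t_hat, fallback_schedule):
    if fallback_schedule is not None:
        noise_schedule = fallback_schedule[-1:]  # final step only
    else:
        noise_schedule = t_hat  # previously this path fell through and returned None
    return noise_schedule

tic = time.time()
schedule = pick_noise_schedule([0.10, 0.05], None)
print(f"picked {schedule} in {time.time() - tic:.6f}s")  # timing, as the diff now logs
```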
rfd3/model/layers/block_utils.py
CHANGED
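Note: most of this file's diff restores shape annotations in docstrings and comments, renaming the batch dimension from `D` to `B`; the one real rewrite is the inter-chain loop in `get_sparse_attention_indices_with_inter_chain`, which now iterates explicitly over batches and unique chains. For orientation, a toy run of the match-then-scatter pattern that `scatter_add_pair_features` implements (tiny shapes, made-up values):

```python
import torch

B, L, k, a, c = 1, 2, 3, 1, 4
P_LK_indices = torch.tensor([[[5, 7, 9], [2, 4, 6]]])  # (B, L, k) key atom indices
P_LA_indices = torch.tensor([[[7], [4]]])              # (B, L, a) extra-feature indices
P_LK_tgt = torch.zeros(B, L, k, c)                     # key features to scatter into
P_LA_src = torch.ones(B, L, a, c)                      # features to add

# For each query, locate the key slot holding the matching atom index ...
matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2)  # (B, L, a, k)
k_idx = matches.long().argmax(dim=-1)                               # (B, L, a)
# ... then add the source features into exactly that slot.
scatter_idx = k_idx.unsqueeze(-1).expand(-1, -1, -1, c)             # (B, L, a, c)
out = P_LK_tgt.scatter_add(dim=2, index=scatter_idx, src=P_LA_src)
print(out[0, 0])  # slot 1 (atom 7) is all ones; slots 0 and 2 stay zero
```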
@@ -118,14 +118,14 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
 
     Parameters
     ----------
-    P_LK_indices : (
+    P_LK_indices : (B, L, k) LongTensor
         Key indices | P_LK_indices[d, i, k] = global atom index for which atom i attends to.
-    P_LK : (
+    P_LK : (B, L, k, c) FloatTensor
         Key features to scatter add into
 
-    P_LA_indices : (
+    P_LA_indices : (B, L, a) LongTensor
         Additional feature indices to scatter into P_LK.
-    P_LA : (
+    P_LA : (B, L, a, c) FloatTensor
         Features corresponding to P_LA.
 
     Both index tensors contain indices representing D batch dim,
@@ -135,42 +135,42 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
 
     """
     # Handle case when indices and P_LA don't have batch dimensions
-    D, L, k = P_LK_indices.shape
+    B, L, k = P_LK_indices.shape
     if P_LA_indices.ndim == 2:
-        P_LA_indices = P_LA_indices.unsqueeze(0).expand(
+        P_LA_indices = P_LA_indices.unsqueeze(0).expand(B, -1, -1)
     if P_LA_src.ndim == 3:
-        P_LA_src = P_LA_src.unsqueeze(0).expand(
+        P_LA_src = P_LA_src.unsqueeze(0).expand(B, -1, -1)
     assert (
         P_LA_src.shape[-1] == P_LK_tgt.shape[-1]
     ), "Channel dims do not match, got: {} vs {}".format(
         P_LA_src.shape[-1], P_LK_tgt.shape[-1]
     )
 
-    matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2)  # (
+    matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2)  # (B, L, a, k)
     if not torch.all(matches.sum(dim=(-1, -2)) >= 1):
         raise ValueError("Found multiple scatter indices for some atoms")
     elif not torch.all(matches.sum(dim=-1) <= 1):
         raise ValueError("Did not find a scatter index for every atom")
-    k_indices = matches.long().argmax(dim=-1)  # (
+    k_indices = matches.long().argmax(dim=-1)  # (B, L, a)
     scatter_indices = k_indices.unsqueeze(-1).expand(
         -1, -1, -1, P_LK_tgt.shape[-1]
-    )  # (
+    )  # (B, L, a, c)
     P_LK_tgt = P_LK_tgt.scatter_add(dim=2, index=scatter_indices, src=P_LA_src)
     return P_LK_tgt
 
 
 def _batched_gather(values: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
     """
-    values : (
-    idx : (
-    returns: (
+    values : (B, L, C)
+    idx : (B, L, k)
+    returns: (B, L, k, C)
     """
-    D, L, C = values.shape
+    B, L, C = values.shape
     k = idx.shape[-1]
 
-    # (
+    # (B, L, 1, C) → stride-0 along k → (B, L, k, C)
     src = values.unsqueeze(2).expand(-1, -1, k, -1)
-    idx = idx.unsqueeze(-1).expand(-1, -1, -1, C)  # (
+    idx = idx.unsqueeze(-1).expand(-1, -1, -1, C)  # (B, L, k, C)
 
     return torch.gather(src, 1, idx)  # dim=1 is the L-axis
 
@@ -196,7 +196,7 @@ def create_attention_indices(
     X_L = torch.randn(
         (1, L, 3), device=device, dtype=torch.float
     )  # [L, 3] - random
-    D_LL = torch.cdist(X_L, X_L, p=2)  # [
+    D_LL = torch.cdist(X_L, X_L, p=2)  # [B, L, L] - pairwise atom distances
 
     # Create attention indices using neighbour distances
     base_mask = ~f["unindexing_pair_mask"][
@@ -231,7 +231,7 @@
         k_max=k_actual,
         chain_id=chain_ids,
         base_mask=base_mask,
-    )  # [
+    )  # [B, L, k] | indices[b, i, j] = atom index for atom i to j-th attn query
 
     return attn_indices
 
@@ -245,7 +245,7 @@ def get_sparse_attention_indices_with_inter_chain(
 
     Args:
         tok_idx: atom to token mapping
-        D_LL: pairwise distances [
+        D_LL: pairwise distances [B, L, L]
         n_seq_neighbours: number of sequence neighbors
         k_intra: number of intra-chain attention keys
         k_inter: number of inter-chain attention keys
@@ -253,29 +253,29 @@
         base_mask: base mask for valid pairs
 
     Returns:
-        attn_indices: [
+        attn_indices: [B, L, k_total] where k_total = k_intra + k_inter
     """
-    D, L, _ = D_LL.shape
+    B, L, _ = D_LL.shape
 
     # Get regular intra-chain indices (limited to k_intra)
     intra_indices = get_sparse_attention_indices(
         tok_idx, D_LL, n_seq_neighbours, k_intra, chain_id, base_mask
-    )  # [
+    )  # [B, L, k_intra]
 
     # Get inter-chain indices for clash avoidance
-    inter_indices = torch.zeros(
-
-    for
-        for
-            query_chain = chain_id[
+    inter_indices = torch.zeros(B, L, k_inter, dtype=torch.long, device=D_LL.device)
+    unique_chains = torch.unique(chain_id)
+    for b in range(B):
+        for c in unique_chains:
+            query_chain = chain_id[c]
 
             # Find atoms from different chains
-            other_chain_mask = (chain_id != query_chain) & base_mask[
+            other_chain_mask = (chain_id != query_chain) & base_mask[c, :]
             other_chain_atoms = torch.where(other_chain_mask)[0]
 
             if len(other_chain_atoms) > 0:
                 # Get distances to other chains
-                distances_to_other = D_LL[
+                distances_to_other = D_LL[b, c, other_chain_atoms]
 
                 # Select k_inter closest atoms from other chains
                 n_select = min(k_inter, len(other_chain_atoms))
@@ -283,23 +283,23 @@
                 selected_atoms = other_chain_atoms[closest_idx]
 
                 # Fill inter-chain indices
-                inter_indices[
+                inter_indices[b, c, :n_select] = selected_atoms
                 # Pad with random atoms if needed
                 if n_select < k_inter:
                     padding = torch.randint(
                         0, L, (k_inter - n_select,), device=D_LL.device
                     )
-                    inter_indices[
+                    inter_indices[b, c, n_select:] = padding
             else:
                 # No other chains found, fill with random indices
-                inter_indices[
+                inter_indices[b, c, :] = torch.randint(
                     0, L, (k_inter,), device=D_LL.device
                 )
 
     # Combine intra and inter chain indices
     combined_indices = torch.cat(
         [intra_indices, inter_indices], dim=-1
-    )  # [
+    )  # [B, L, k_total]
 
     return combined_indices
 