rcsb-embedding-model 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rcsb-embedding-model might be problematic. Click here for more details.
- rcsb_embedding_model/cli/args_utils.py +0 -2
- rcsb_embedding_model/cli/inference.py +164 -42
- rcsb_embedding_model/dataset/esm_prot_from_chain.py +102 -0
- rcsb_embedding_model/dataset/esm_prot_from_structure.py +63 -0
- rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py +68 -0
- rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py +94 -0
- rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py +43 -0
- rcsb_embedding_model/inference/assembly_inferece.py +53 -0
- rcsb_embedding_model/inference/chain_inference.py +12 -8
- rcsb_embedding_model/inference/esm_inference.py +18 -8
- rcsb_embedding_model/inference/structure_inference.py +61 -0
- rcsb_embedding_model/modules/structure_module.py +27 -0
- rcsb_embedding_model/rcsb_structure_embedding.py +7 -8
- rcsb_embedding_model/types/api_types.py +27 -5
- rcsb_embedding_model/utils/data.py +30 -0
- rcsb_embedding_model/utils/structure_parser.py +43 -13
- rcsb_embedding_model/utils/structure_provider.py +27 -0
- rcsb_embedding_model-0.0.8.dist-info/METADATA +129 -0
- rcsb_embedding_model-0.0.8.dist-info/RECORD +29 -0
- rcsb_embedding_model/dataset/esm_prot_from_csv.py +0 -91
- rcsb_embedding_model/dataset/residue_embedding_from_csv.py +0 -32
- rcsb_embedding_model-0.0.6.dist-info/METADATA +0 -117
- rcsb_embedding_model-0.0.6.dist-info/RECORD +0 -22
- {rcsb_embedding_model-0.0.6.dist-info → rcsb_embedding_model-0.0.8.dist-info}/WHEEL +0 -0
- {rcsb_embedding_model-0.0.6.dist-info → rcsb_embedding_model-0.0.8.dist-info}/entry_points.txt +0 -0
- {rcsb_embedding_model-0.0.6.dist-info → rcsb_embedding_model-0.0.8.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
|
|
3
|
+
from rcsb_embedding_model.dataset.resdiue_assembly_embedding_from_structure import ResidueAssemblyDatasetFromStructure
|
|
4
|
+
from rcsb_embedding_model.dataset.residue_assembly_embedding_from_tensor_file import ResidueAssemblyEmbeddingFromTensorFile
|
|
5
|
+
from rcsb_embedding_model.types.api_types import FileOrStreamTuple, SrcLocation, Accelerator, Devices, OptionalPath, EmbeddingPath, StructureLocation, StructureFormat, SrcAssemblyFrom
|
|
6
|
+
from rcsb_embedding_model.inference.chain_inference import predict as chain_predict
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def predict(
        src_stream: FileOrStreamTuple,
        res_embedding_location: EmbeddingPath,
        src_location: SrcLocation = SrcLocation.local,
        src_from: SrcAssemblyFrom = SrcAssemblyFrom.assembly,
        structure_location: StructureLocation = StructureLocation.local,
        structure_format: StructureFormat = StructureFormat.mmcif,
        min_res_n: int = 0,
        max_res_n: int = sys.maxsize,
        batch_size: int = 1,
        num_workers: int = 0,
        num_nodes: int = 1,
        accelerator: Accelerator = Accelerator.auto,
        devices: Devices = 'auto',
        out_path: OptionalPath = None
):
    """Build the assembly dataset chosen by ``src_from`` and run chain inference on it.

    ``SrcAssemblyFrom.assembly`` selects ``ResidueAssemblyEmbeddingFromTensorFile``;
    any other value selects ``ResidueAssemblyDatasetFromStructure``. Both dataset
    classes receive the same keyword arguments. The resulting dataset is handed
    to ``chain_predict`` together with the runtime options, and whatever
    ``chain_predict`` returns is returned unchanged.
    """
    # Pick the dataset class first; only the selected class is instantiated.
    if src_from == SrcAssemblyFrom.assembly:
        dataset_cls = ResidueAssemblyEmbeddingFromTensorFile
    else:
        dataset_cls = ResidueAssemblyDatasetFromStructure

    assembly_dataset = dataset_cls(
        src_stream=src_stream,
        res_embedding_location=res_embedding_location,
        src_location=src_location,
        structure_location=structure_location,
        structure_format=structure_format,
        min_res_n=min_res_n,
        max_res_n=max_res_n
    )

    return chain_predict(
        src_stream=src_stream,
        src_location=src_location,
        batch_size=batch_size,
        num_workers=num_workers,
        num_nodes=num_nodes,
        accelerator=accelerator,
        devices=devices,
        out_path=out_path,
        inference_set=assembly_dataset
    )
|
|
@@ -1,26 +1,30 @@
|
|
|
1
1
|
from torch.utils.data import DataLoader
|
|
2
2
|
from lightning import Trainer
|
|
3
|
-
from typer import FileText
|
|
4
3
|
|
|
5
|
-
from rcsb_embedding_model.dataset.
|
|
4
|
+
from rcsb_embedding_model.dataset.residue_embedding_from_tensor_file import ResidueEmbeddingFromTensorFile
|
|
6
5
|
from rcsb_embedding_model.modules.chain_module import ChainModule
|
|
7
|
-
from rcsb_embedding_model.types.api_types import Accelerator, Devices, OptionalPath
|
|
6
|
+
from rcsb_embedding_model.types.api_types import Accelerator, Devices, OptionalPath, FileOrStreamTuple, SrcLocation
|
|
8
7
|
from rcsb_embedding_model.utils.data import collate_seq_embeddings
|
|
9
8
|
from rcsb_embedding_model.writer.batch_writer import CsvBatchWriter
|
|
10
9
|
|
|
11
10
|
|
|
12
11
|
def predict(
|
|
13
|
-
|
|
12
|
+
src_stream: FileOrStreamTuple,
|
|
13
|
+
src_location: SrcLocation = SrcLocation.local,
|
|
14
14
|
batch_size: int = 1,
|
|
15
15
|
num_workers: int = 0,
|
|
16
16
|
num_nodes: int = 1,
|
|
17
17
|
accelerator: Accelerator = Accelerator.auto,
|
|
18
18
|
devices: Devices = 'auto',
|
|
19
|
-
out_path: OptionalPath = None
|
|
19
|
+
out_path: OptionalPath = None,
|
|
20
|
+
inference_set=None
|
|
20
21
|
):
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
22
|
+
|
|
23
|
+
if inference_set is None:
|
|
24
|
+
inference_set = ResidueEmbeddingFromTensorFile(
|
|
25
|
+
src_stream=src_stream,
|
|
26
|
+
src_location=src_location
|
|
27
|
+
)
|
|
24
28
|
|
|
25
29
|
inference_dataloader = DataLoader(
|
|
26
30
|
dataset=inference_set,
|
|
@@ -1,17 +1,20 @@
|
|
|
1
1
|
from torch.utils.data import DataLoader
|
|
2
2
|
from lightning import Trainer
|
|
3
|
-
from typer import FileText
|
|
4
3
|
|
|
5
|
-
from rcsb_embedding_model.dataset.
|
|
4
|
+
from rcsb_embedding_model.dataset.esm_prot_from_structure import EsmProtFromStructure
|
|
5
|
+
from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
|
|
6
6
|
from rcsb_embedding_model.modules.esm_module import EsmModule
|
|
7
|
-
from rcsb_embedding_model.types.api_types import
|
|
7
|
+
from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, SrcProteinFrom, FileOrStreamTuple, SrcLocation
|
|
8
8
|
from rcsb_embedding_model.writer.batch_writer import TensorBatchWriter
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
def predict(
|
|
12
|
-
|
|
12
|
+
src_stream: FileOrStreamTuple,
|
|
13
13
|
src_location: SrcLocation = SrcLocation.local,
|
|
14
|
-
|
|
14
|
+
src_from: SrcProteinFrom = SrcProteinFrom.chain,
|
|
15
|
+
structure_location: StructureLocation = StructureLocation.local,
|
|
16
|
+
structure_format: StructureFormat = StructureFormat.mmcif,
|
|
17
|
+
min_res_n: int = 0,
|
|
15
18
|
batch_size: int = 1,
|
|
16
19
|
num_workers: int = 0,
|
|
17
20
|
num_nodes: int = 1,
|
|
@@ -20,10 +23,17 @@ def predict(
|
|
|
20
23
|
out_path: OptionalPath = None
|
|
21
24
|
):
|
|
22
25
|
|
|
23
|
-
inference_set =
|
|
24
|
-
|
|
26
|
+
inference_set = EsmProtFromChain(
|
|
27
|
+
src_stream=src_stream,
|
|
25
28
|
src_location=src_location,
|
|
26
|
-
|
|
29
|
+
structure_location=structure_location,
|
|
30
|
+
structure_format=structure_format
|
|
31
|
+
) if src_from == SrcProteinFrom.chain else EsmProtFromStructure(
|
|
32
|
+
src_stream=src_stream,
|
|
33
|
+
src_location=src_location,
|
|
34
|
+
structure_location=structure_location,
|
|
35
|
+
structure_format=structure_format,
|
|
36
|
+
min_res_n=min_res_n
|
|
27
37
|
)
|
|
28
38
|
|
|
29
39
|
inference_dataloader = DataLoader(
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from torch.utils.data import DataLoader
|
|
2
|
+
from lightning import Trainer
|
|
3
|
+
|
|
4
|
+
from rcsb_embedding_model.dataset.esm_prot_from_structure import EsmProtFromStructure
|
|
5
|
+
from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
|
|
6
|
+
from rcsb_embedding_model.modules.structure_module import StructureModule
|
|
7
|
+
from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, SrcProteinFrom, FileOrStreamTuple, SrcLocation
|
|
8
|
+
from rcsb_embedding_model.writer.batch_writer import DataFrameStorage
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _identity_collate(batch):
    """Return the list of dataset samples unchanged (no stacking or padding)."""
    return batch


def predict(
        src_stream: FileOrStreamTuple,
        src_location: SrcLocation = SrcLocation.local,
        src_from: SrcProteinFrom = SrcProteinFrom.chain,
        structure_location: StructureLocation = StructureLocation.local,
        structure_format: StructureFormat = StructureFormat.mmcif,
        min_res_n: int = 0,
        batch_size: int = 1,
        num_workers: int = 0,
        num_nodes: int = 1,
        accelerator: Accelerator = Accelerator.auto,
        devices: Devices = 'auto',
        out_path: OptionalPath = None,
        out_df_name: str = None
):
    """Run ``StructureModule`` over a protein dataset and return the predictions.

    The dataset is ``EsmProtFromChain`` when ``src_from`` is
    ``SrcProteinFrom.chain`` and ``EsmProtFromStructure`` otherwise; only the
    latter receives ``min_res_n``. Samples are passed to the module as a plain
    list per batch (identity collate).

    A ``DataFrameStorage`` writer callback is attached only when both
    ``out_path`` and ``out_df_name`` are provided; otherwise the trainer runs
    without callbacks.

    Returns:
        The value returned by ``Trainer.predict``.
    """
    if src_from == SrcProteinFrom.chain:
        inference_set = EsmProtFromChain(
            src_stream=src_stream,
            src_location=src_location,
            structure_location=structure_location,
            structure_format=structure_format
        )
    else:
        inference_set = EsmProtFromStructure(
            src_stream=src_stream,
            src_location=src_location,
            structure_location=structure_location,
            structure_format=structure_format,
            min_res_n=min_res_n
        )

    inference_dataloader = DataLoader(
        dataset=inference_set,
        batch_size=batch_size,
        num_workers=num_workers,
        # A module-level function instead of a lambda: worker processes started
        # with the "spawn" method must pickle collate_fn, and lambdas cannot be
        # pickled, which breaks num_workers > 0 on those platforms.
        collate_fn=_identity_collate
    )

    module = StructureModule()
    inference_writer = None
    if out_path is not None and out_df_name is not None:
        inference_writer = DataFrameStorage(out_path, out_df_name)
    trainer = Trainer(
        callbacks=[inference_writer] if inference_writer is not None else None,
        num_nodes=num_nodes,
        accelerator=accelerator,
        devices=devices
    )

    prediction = trainer.predict(
        module,
        inference_dataloader
    )

    return prediction
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from esm.sdk.api import SamplingConfig
|
|
2
|
+
from lightning import LightningModule
|
|
3
|
+
|
|
4
|
+
from rcsb_embedding_model.utils.data import collate_seq_embeddings
|
|
5
|
+
from rcsb_embedding_model.utils.model import get_residue_model, get_aggregator_model
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class StructureModule(LightningModule):
    """Lightning module chaining the residue model and the aggregator model.

    Each prediction batch is an iterable of ``(esm_prot, name)`` pairs.
    """

    def __init__(self):
        super().__init__()
        self.esm3 = get_residue_model(self.device)
        self.aggregator = get_aggregator_model(device=self.device)

    def _per_residue_embedding(self, esm_prot):
        """Encode one protein and return its ``per_residue_embedding``."""
        encoded_prot = self.esm3.encode(esm_prot)
        sampling = SamplingConfig(return_per_residue_embeddings=True)
        return self.esm3.forward_and_sample(encoded_prot, sampling).per_residue_embedding

    def predict_step(self, prot_batch, batch_idx):
        """Return ``(aggregator output, tuple of names)`` for one batch."""
        names = []
        residue_embeddings = []
        for esm_prot, name in prot_batch:
            residue_embeddings.append(self._per_residue_embedding(esm_prot))
            names.append(name)
        # collate_seq_embeddings yields the batched embedding and its mask,
        # which are passed to the aggregator in that order.
        batch_embedding, batch_mask = collate_seq_embeddings(residue_embeddings)
        aggregated = self.aggregator(batch_embedding, batch_mask)
        return aggregated, tuple(names)
|
|
@@ -2,9 +2,8 @@ import torch
|
|
|
2
2
|
from biotite.structure import get_residues, chain_iter, filter_amino_acids
|
|
3
3
|
from esm.sdk.api import ESMProtein, SamplingConfig
|
|
4
4
|
from esm.utils.structure.protein_chain import ProteinChain
|
|
5
|
-
from huggingface_hub import hf_hub_download
|
|
6
5
|
|
|
7
|
-
from rcsb_embedding_model.types.api_types import StreamSrc,
|
|
6
|
+
from rcsb_embedding_model.types.api_types import StreamSrc, StructureFormat
|
|
8
7
|
from rcsb_embedding_model.utils.model import get_aggregator_model, get_residue_model
|
|
9
8
|
from rcsb_embedding_model.utils.structure_parser import get_structure_from_src
|
|
10
9
|
|
|
@@ -42,23 +41,23 @@ class RcsbStructureEmbedding:
|
|
|
42
41
|
|
|
43
42
|
def structure_embedding(
|
|
44
43
|
self,
|
|
45
|
-
|
|
46
|
-
|
|
44
|
+
src_structure: StreamSrc,
|
|
45
|
+
structure_format: StructureFormat = StructureFormat.mmcif,
|
|
47
46
|
chain_id: str = None,
|
|
48
47
|
assembly_id: str = None
|
|
49
48
|
):
|
|
50
|
-
res_embedding = self.residue_embedding(
|
|
49
|
+
res_embedding = self.residue_embedding(src_structure, structure_format, chain_id, assembly_id)
|
|
51
50
|
return self.aggregator_embedding(res_embedding)
|
|
52
51
|
|
|
53
52
|
def residue_embedding(
|
|
54
53
|
self,
|
|
55
|
-
|
|
56
|
-
|
|
54
|
+
src_structure: StreamSrc,
|
|
55
|
+
structure_format: StructureFormat = StructureFormat.mmcif,
|
|
57
56
|
chain_id: str = None,
|
|
58
57
|
assembly_id: str = None
|
|
59
58
|
):
|
|
60
59
|
self.__check_residue_embedding()
|
|
61
|
-
structure = get_structure_from_src(
|
|
60
|
+
structure = get_structure_from_src(src_structure, structure_format, chain_id, assembly_id)
|
|
62
61
|
embedding_ch = []
|
|
63
62
|
for atom_ch in chain_iter(structure):
|
|
64
63
|
atom_res = atom_ch[filter_amino_acids(atom_ch)]
|
|
@@ -1,16 +1,23 @@
|
|
|
1
1
|
from enum import Enum
|
|
2
|
-
from os import PathLike
|
|
3
2
|
from typing import NewType, Union, IO, Tuple, List, Optional
|
|
4
3
|
|
|
5
|
-
|
|
6
|
-
|
|
4
|
+
from typer import FileText
|
|
5
|
+
|
|
6
|
+
StreamSrc = NewType('StreamSrc', Union[FileText, IO])
|
|
7
|
+
StreamTuple = NewType('StreamTuple', Union[
|
|
8
|
+
Tuple[str, StreamSrc, str, str],
|
|
9
|
+
Tuple[str, StreamSrc, str],
|
|
10
|
+
Tuple[str, str]
|
|
11
|
+
])
|
|
12
|
+
FileOrStreamTuple = NewType('FileOrStreamTuple', Union[FileText, StreamTuple])
|
|
7
13
|
|
|
8
14
|
Devices = NewType('Devices', Union[int, List[int], "auto"])
|
|
9
15
|
|
|
10
|
-
|
|
16
|
+
EmbeddingPath = Union[str, FileText]
|
|
17
|
+
OptionalPath = NewType('OptionalPath', Optional[FileText])
|
|
11
18
|
|
|
12
19
|
|
|
13
|
-
class
|
|
20
|
+
class StructureFormat(str, Enum):
|
|
14
21
|
pdb = "pdb"
|
|
15
22
|
mmcif = "mmcif"
|
|
16
23
|
bciff = "binarycif"
|
|
@@ -25,5 +32,20 @@ class Accelerator(str, Enum):
|
|
|
25
32
|
|
|
26
33
|
|
|
27
34
|
class SrcLocation(str, Enum):
|
|
35
|
+
local = "local"
|
|
36
|
+
stream = "stream"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class StructureLocation(str, Enum):
|
|
28
40
|
local = "local"
|
|
29
41
|
remote = "remote"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class SrcProteinFrom(str, Enum):
    """String-valued options for the ``src_from`` argument of the protein datasets."""

    # Each member's value equals its name.
    chain = "chain"
    structure = "structure"
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class SrcAssemblyFrom(str, Enum):
    """String-valued options for the ``src_from`` argument of the assembly datasets."""

    # Each member's value equals its name.
    assembly = "assembly"
    structure = "structure"
|
|
@@ -44,4 +44,34 @@ def stringio_from_url(url):
|
|
|
44
44
|
print(f"Error fetching URL: {e}")
|
|
45
45
|
return None
|
|
46
46
|
|
|
47
|
+
def concatenate_tensors(file_list, max_residues, dim=0):
    """
    Concatenates a list of tensors stored in individual files along a specified dimension.

    Files are loaded onto the CPU in the given order. Loading stops as soon as
    the running total of ``tensor.shape[0]`` exceeds ``max_residues``; the
    tensor that crosses the limit is still included in the result. Files that
    cannot be loaded (or whose content has no first dimension) are skipped and
    reported through a logged warning.

    Args:
        file_list (list of str): List of file paths to tensor files.
        max_residues (int): Maximum number of residues allowed in the assembly
        dim (int): The dimension along which to concatenate the tensors. Default is 0.

    Returns:
        torch.Tensor: The concatenated tensor.

    Raises:
        ValueError: If no tensor could be loaded from ``file_list``.
    """
    # Local import so the block does not depend on a file-level logging import.
    import logging
    logger = logging.getLogger(__name__)

    tensors = []
    total_residues = 0
    for file in file_list:
        try:
            # NOTE(review): torch.load unpickles file content; only use it on
            # trusted embedding files.
            tensor = torch.load(
                file,
                map_location=torch.device('cpu')
            )
            n_residues = tensor.shape[0]
        except Exception as e:
            # Best-effort by design: one unreadable file must not abort the
            # whole assembly, but the failure is no longer silent.
            logger.warning("Skipping tensor file %s: %s", file, e)
            continue
        total_residues += n_residues
        tensors.append(tensor)
        if total_residues > max_residues:
            break
    if not tensors:
        raise ValueError("No valid tensors were loaded to concatenate.")
    return torch.cat(tensors, dim=dim)
|
|
@@ -1,32 +1,62 @@
|
|
|
1
|
-
|
|
2
|
-
from biotite.structure.io.pdb import PDBFile, get_structure as get_pdb_structure, get_assembly as get_pdb_assembly
|
|
3
|
-
from biotite.structure.io.pdbx import CIFFile, get_structure, get_assembly, BinaryCIFFile
|
|
1
|
+
from biotite.structure import filter_amino_acids, chain_iter, get_chains, get_residues, AtomArray
|
|
2
|
+
from biotite.structure.io.pdb import PDBFile, get_structure as get_pdb_structure, get_assembly as get_pdb_assembly, list_assemblies as list_pdb_assemblies
|
|
3
|
+
from biotite.structure.io.pdbx import CIFFile, get_structure, get_assembly, BinaryCIFFile, list_assemblies
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
def get_structure_from_src(
|
|
7
|
-
|
|
8
|
-
|
|
7
|
+
src_structure,
|
|
8
|
+
structure_format="mmcif",
|
|
9
9
|
chain_id=None,
|
|
10
10
|
assembly_id=None
|
|
11
11
|
):
|
|
12
|
-
if
|
|
13
|
-
pdb_file = PDBFile.read(
|
|
12
|
+
if structure_format == "pdb":
|
|
13
|
+
pdb_file = PDBFile.read(src_structure)
|
|
14
14
|
structure = __get_pdb_structure(pdb_file, assembly_id)
|
|
15
|
-
elif
|
|
16
|
-
cif_file = CIFFile.read(
|
|
15
|
+
elif structure_format == "mmcif":
|
|
16
|
+
cif_file = CIFFile.read(src_structure)
|
|
17
17
|
structure = __get_structure(cif_file, assembly_id)
|
|
18
|
-
elif
|
|
19
|
-
cif_file = BinaryCIFFile.read(
|
|
18
|
+
elif structure_format == "binarycif":
|
|
19
|
+
cif_file = BinaryCIFFile.read(src_structure)
|
|
20
20
|
structure = __get_structure(cif_file, assembly_id)
|
|
21
21
|
else:
|
|
22
|
-
raise RuntimeError(f"Unknown file format {
|
|
22
|
+
raise RuntimeError(f"Unknown file format {structure_format}")
|
|
23
23
|
|
|
24
24
|
if chain_id is not None:
|
|
25
|
-
|
|
25
|
+
return structure[structure.chain_id == chain_id]
|
|
26
26
|
|
|
27
27
|
return structure
|
|
28
28
|
|
|
29
29
|
|
|
30
|
+
def get_protein_chains(structure, min_res_n=0):
    """Return the IDs of chains that contain more than ``min_res_n`` amino-acid residues.

    Each chain yielded by ``chain_iter`` is reduced to its amino-acid atoms;
    chains left without any atoms are ignored.

    Args:
        structure: Atom array to split into chains.
        min_res_n (int): A chain is kept only when its amino-acid residue
            count is strictly greater than this value.

    Returns:
        tuple of str: Chain IDs in iteration order.
    """
    chain_ids = []
    for atom_ch in chain_iter(structure):
        atom_res = atom_ch[filter_amino_acids(atom_ch)]
        if len(atom_res) == 0:
            continue
        # get_residues returns an (ids, names) pair, so the residue count is
        # the length of the first array; len() of the pair itself is always 2.
        residue_ids, _ = get_residues(atom_res)
        if len(residue_ids) > min_res_n:
            chain_ids.append(str(get_chains(atom_res)[0]))
    return tuple(chain_ids)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def get_assemblies(src_structure, structure_format="mmcif"):
    """Return the assembly identifiers listed in a structure source as a tuple.

    Args:
        src_structure: Source handed to the reader selected by ``structure_format``.
        structure_format: One of ``"pdb"``, ``"mmcif"`` or ``"binarycif"``.

    Raises:
        RuntimeError: If ``structure_format`` is none of the supported values.
    """
    if structure_format == "pdb":
        assemblies = list_pdb_assemblies(PDBFile.read(src_structure))
    elif structure_format == "mmcif":
        assemblies = list_assemblies(CIFFile.read(src_structure))
    elif structure_format == "binarycif":
        assemblies = list_assemblies(BinaryCIFFile.read(src_structure))
    else:
        raise RuntimeError(f"Unknown file format {structure_format}")
    # Iterating a mapping yields its keys, so this matches an explicit .keys().
    return tuple(assemblies)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def rename_atom_ch(atom_ch, ch="A"):
    """Build a new ``AtomArray`` from ``atom_ch`` with every atom's ``chain_id`` set to ``ch``.

    Atoms are visited in order; each one gets its ``chain_id`` overwritten and
    is then stored at the same position of a freshly allocated array of the
    same length.
    """
    relabeled = AtomArray(len(atom_ch))
    for position, atom in enumerate(atom_ch):
        atom.chain_id = ch
        relabeled[position] = atom
    return relabeled
|
|
58
|
+
|
|
59
|
+
|
|
30
60
|
def __get_pdb_structure(pdb_file, assembly_id=None):
|
|
31
61
|
return get_pdb_structure(
|
|
32
62
|
pdb_file,
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from rcsb_embedding_model.utils.structure_parser import get_structure_from_src
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class StructureProvider:
    """Parse structures through ``get_structure_from_src`` and cache the most recent one.

    Only the last parsed structure is kept. It is reused while consecutive
    requests ask for the same source name, format and assembly, so several
    chains of one structure can be extracted without re-parsing.
    """

    def __init__(self):
        self.__src_name = None
        # (src_name, structure_format, assembly_id) of the cached structure.
        self.__cache_key = None
        self.__structure = None

    def get_structure(
            self,
            src_name,
            src_structure,
            structure_format="mmcif",
            chain_id=None,
            assembly_id=None
    ):
        """Return the (possibly cached) structure, optionally restricted to one chain.

        Args:
            src_name: Identifier of the source; part of the cache key.
            src_structure: Source passed to ``get_structure_from_src`` on a cache miss.
            structure_format: Format passed to ``get_structure_from_src``.
            chain_id: When not ``None``, only atoms whose ``chain_id`` equals
                this value are returned.
            assembly_id: Assembly passed to ``get_structure_from_src``.
        """
        # The key includes format and assembly so that asking for a different
        # assembly of the same source does not return a stale structure.
        cache_key = (src_name, structure_format, assembly_id)
        if cache_key != self.__cache_key:
            structure = get_structure_from_src(
                src_structure=src_structure,
                structure_format=structure_format,
                assembly_id=assembly_id
            )
            # Update the cache only after a successful parse; a failure must
            # not leave the previous structure stored under the new name.
            self.__structure = structure
            self.__src_name = src_name
            self.__cache_key = cache_key
        if chain_id is not None:
            return self.__structure[self.__structure.chain_id == chain_id]
        return self.__structure
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: rcsb-embedding-model
|
|
3
|
+
Version: 0.0.8
|
|
4
|
+
Summary: Protein Embedding Model for Structure Search
|
|
5
|
+
Project-URL: Homepage, https://github.com/rcsb/rcsb-embedding-model
|
|
6
|
+
Project-URL: Issues, https://github.com/rcsb/rcsb-embedding-model/issues
|
|
7
|
+
Author-email: Joan Segura <joan.segura@rcsb.org>
|
|
8
|
+
License-Expression: BSD-3-Clause
|
|
9
|
+
License-File: LICENSE.md
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Requires-Python: >=3.10
|
|
13
|
+
Requires-Dist: esm>=3.2.0
|
|
14
|
+
Requires-Dist: lightning>=2.5.0
|
|
15
|
+
Requires-Dist: torch>=2.2.0
|
|
16
|
+
Requires-Dist: typer>=0.15.0
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
|
|
19
|
+
# RCSB Embedding Model
|
|
20
|
+
|
|
21
|
+
**Version** 0.0.8
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
## Overview
|
|
25
|
+
|
|
26
|
+
RCSB Embedding Model is a neural network architecture designed to encode macromolecular 3D structures into fixed-length vector embeddings for efficient large-scale structure similarity search.
|
|
27
|
+
|
|
28
|
+
Preprint: [Multi-scale structural similarity embedding search across entire proteomes](https://www.biorxiv.org/content/10.1101/2025.02.28.640875v1).
|
|
29
|
+
|
|
30
|
+
A web-based implementation using this model for structure similarity search is available at [rcsb-embedding-search](http://embedding-search.rcsb.org).
|
|
31
|
+
|
|
32
|
+
If you are interested in training the model with a new dataset, visit the [rcsb-embedding-search repository](https://github.com/bioinsilico/rcsb-embedding-search), which provides scripts and documentation for training.
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
## Features
|
|
36
|
+
|
|
37
|
+
- **Residue-level embeddings** computed using the ESM3 protein language model
|
|
38
|
+
- **Structure-level embeddings** aggregated via a transformer-based aggregator network
|
|
39
|
+
- **Command-line interface** implemented with Typer for high-throughput inference workflows
|
|
40
|
+
- **Python API** for interactive embedding computation and integration into analysis pipelines
|
|
41
|
+
- **High-performance inference** leveraging PyTorch Lightning, with multi-node and multi-GPU support
|
|
42
|
+
|
|
43
|
+
---
|
|
44
|
+
|
|
45
|
+
## Installation
|
|
46
|
+
|
|
47
|
+
pip install rcsb-embedding-model
|
|
48
|
+
|
|
49
|
+
**Requirements:**
|
|
50
|
+
|
|
51
|
+
- Python ≥ 3.10
|
|
52
|
+
- ESM ≥ 3.2.0
|
|
53
|
+
- PyTorch ≥ 2.2.0
|
|
54
|
+
- Lightning ≥ 2.5.0
|
|
55
|
+
- Typer ≥ 0.15.0
|
|
56
|
+
|
|
57
|
+
---
|
|
58
|
+
|
|
59
|
+
## Quick Start
|
|
60
|
+
|
|
61
|
+
### CLI
|
|
62
|
+
|
|
63
|
+
# 1. Compute residue embeddings: Calculate residue level embeddings of protein structures using ESM3. Predictions are stored as torch tensor files.
|
|
64
|
+
inference residue-embedding --src-file data/structures.csv --output-path results/residue_embeddings --structure-format mmcif --batch-size 8 --devices auto
|
|
65
|
+
|
|
66
|
+
# 2. Compute structure embeddings: Calculate single-chain protein embeddings from structural files. Predictions are stored in a single pandas DataFrame file.
|
|
67
|
+
inference structure-embedding --src-file data/structures.csv --output-path results/residue_embeddings --out-df-name df-res-embeddings --batch-size 4 --devices 0 --devices 1
|
|
68
|
+
|
|
69
|
+
# 3. Compute chain embeddings: Calculate single-chain protein embeddings from residue level embeddings stored as torch tensor files. Predictions are stored as csv files.
|
|
70
|
+
inference chain-embedding --src-file data/structures.csv --output-path results/chain_embeddings --batch-size 4
|
|
71
|
+
|
|
72
|
+
# 4. Compute assembly embeddings: Calculate assembly embeddings from residue level embeddings stored as torch tensor files. Predictions are stored as csv files.
|
|
73
|
+
inference assembly-embedding --src-file data/structures.csv --res-embedding-location results/residue_embeddings --output-path results/assembly_embeddings
|
|
74
|
+
|
|
75
|
+
### Python API
|
|
76
|
+
|
|
77
|
+
from rcsb_embedding_model import RcsbStructureEmbedding
|
|
78
|
+
|
|
79
|
+
model = RcsbStructureEmbedding()
|
|
80
|
+
|
|
81
|
+
# Compute per-residue embeddings
|
|
82
|
+
res_emb = model.residue_embedding(
|
|
83
|
+
src_structure="examples/1abc.cif",
|
|
84
|
+
structure_format="mmcif",
|
|
85
|
+
chain_id="A"
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
# Aggregate to structure-level embedding
|
|
89
|
+
struct_emb = model.aggregator_embedding(res_emb)
|
|
90
|
+
|
|
91
|
+
See the examples and tests directories for more use cases.
|
|
92
|
+
|
|
93
|
+
---
|
|
94
|
+
|
|
95
|
+
## Model Architecture
|
|
96
|
+
|
|
97
|
+
The embedding model is trained to predict structural similarity by approximating TM-scores using cosine distances between embeddings. It consists of two main components:
|
|
98
|
+
|
|
99
|
+
- **Protein Language Model (PLM)**: Computes residue-level embeddings from a given 3D structure.
|
|
100
|
+
- **Residue Embedding Aggregator**: A transformer-based neural network that aggregates these residue-level embeddings into a single vector.
|
|
101
|
+
|
|
102
|
+

|
|
103
|
+
|
|
104
|
+
### **Protein Language Model (PLM)**
|
|
105
|
+
Residue-wise embeddings of protein structures are computed using the [ESM3](https://www.evolutionaryscale.ai/) generative protein language model.
|
|
106
|
+
|
|
107
|
+
### **Residue Embedding Aggregator**
|
|
108
|
+
The aggregation component consists of six transformer encoder layers, each with a 3,072-neuron feedforward layer and ReLU activations. After processing through these layers, a summation pooling operation is applied, followed by 12 fully connected residual layers that refine the embeddings into a single 1,536-dimensional vector.
|
|
109
|
+
|
|
110
|
+
---
|
|
111
|
+
|
|
112
|
+
## Development
|
|
113
|
+
|
|
114
|
+
git clone https://github.com/rcsb/rcsb-embedding-model.git
|
|
115
|
+
cd rcsb-embedding-model
|
|
116
|
+
pip install -e .
|
|
117
|
+
pytest
|
|
118
|
+
|
|
119
|
+
---
|
|
120
|
+
|
|
121
|
+
## Citation
|
|
122
|
+
|
|
123
|
+
Segura, J., Bittrich, S., et al. (2025). *Multi-scale structural similarity embedding search across entire proteomes*. bioRxiv. (Preprint: https://www.biorxiv.org/content/10.1101/2025.02.28.640875v1)
|
|
124
|
+
|
|
125
|
+
---
|
|
126
|
+
|
|
127
|
+
## License
|
|
128
|
+
|
|
129
|
+
This project is licensed under the BSD 3-Clause License. See [LICENSE.md](LICENSE.md) for details.
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
rcsb_embedding_model/__init__.py,sha256=r3gLdeBIXkQEQA_K6QcRPO-TtYuAQSutk6pXRUE_nas,120
|
|
2
|
+
rcsb_embedding_model/rcsb_structure_embedding.py,sha256=dKp9hXQO0JAnO4SEfjJ_mG_jHu3UxAPguv6jkOjp-BI,4487
|
|
3
|
+
rcsb_embedding_model/cli/args_utils.py,sha256=7nP2q8pL5dWK_U7opxtWmoFcYVwasky6elHk-dASFaI,165
|
|
4
|
+
rcsb_embedding_model/cli/inference.py,sha256=KPZLqznbxZE_CBCGigUGg7yOfGsi8ID4aWMTExniRj4,11464
|
|
5
|
+
rcsb_embedding_model/dataset/esm_prot_from_chain.py,sha256=dBD2N0Y-GoN6p3z2yLnOvv6JGn-skAxwgbOYhXKDngc,3487
|
|
6
|
+
rcsb_embedding_model/dataset/esm_prot_from_structure.py,sha256=kOqgHfHjiym5InaAgpgMmHBgCAPEqW88PCoHHQy0ROI,2490
|
|
7
|
+
rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py,sha256=d8C7HRJBZWuOKhPQpihv1koT4aIvyt5QN2yndC2ABuE,2842
|
|
8
|
+
rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py,sha256=KXiohnPjjfZEFbPZQ46HGE8eEYWrVX8bfbTz4zPlo7o,3451
|
|
9
|
+
rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py,sha256=cOxT--Spkel10JJCeGlqgLXN5vNCZzPdfSxDgUSrdPI,1268
|
|
10
|
+
rcsb_embedding_model/inference/assembly_inferece.py,sha256=MPssN5bsOqOU-LGwa6AKX99cv5LD43Mnbaqhuuww1Tw,2165
|
|
11
|
+
rcsb_embedding_model/inference/chain_inference.py,sha256=R9gi0MZ_HaM3v9c433W_5w4suse4nJmy4SgUTHJVZLg,1713
|
|
12
|
+
rcsb_embedding_model/inference/esm_inference.py,sha256=oVN4r9_6V8TS0pYoNn7GR92Xo0Zn7eBsnt_OfDSaH6g,2126
|
|
13
|
+
rcsb_embedding_model/inference/structure_inference.py,sha256=QIUEo8eEc-kTSYKGdlX2rxT74huw4ZAw6U8Px9kYajE,2216
|
|
14
|
+
rcsb_embedding_model/model/layers.py,sha256=lhKaWC4gTS_T5lHOP0mgnnP8nKTPEOm4MrjhESA4hE8,743
|
|
15
|
+
rcsb_embedding_model/model/residue_embedding_aggregator.py,sha256=k3UW63Ax8DtjCMdD3O5xNxtyAu28l2n3-Ab6nS0atm0,1967
|
|
16
|
+
rcsb_embedding_model/modules/chain_module.py,sha256=sDSPXJmWuU2C3lt1NorlbUVWZvRSLzumPdFQk01h3VI,403
|
|
17
|
+
rcsb_embedding_model/modules/esm_module.py,sha256=CTHGOATXiarqZsBsZ8oxGJBj20A73186Slpr0EzMJsE,770
|
|
18
|
+
rcsb_embedding_model/modules/structure_module.py,sha256=dEtDNdWo1j2sSDa0JiOHQfEfQzIWqSLEKpvOX0GrXZ4,1048
|
|
19
|
+
rcsb_embedding_model/types/api_types.py,sha256=3sPh33yb3Ya9r3O5vuiTfhb1WyFuhQWCQmewSbqEyG0,1076
|
|
20
|
+
rcsb_embedding_model/utils/data.py,sha256=x6ca_bVdBXEAp9ugCi1rVEQ-G5nGTFKpzDKqZKpkFBE,2933
|
|
21
|
+
rcsb_embedding_model/utils/model.py,sha256=rpZa-gfm3cEtbBd7UXMHrZv3x6f0AC8TJT3gtrSxr5I,852
|
|
22
|
+
rcsb_embedding_model/utils/structure_parser.py,sha256=jat4SCtPHYMZ6JJR-T7lPQoMbT_E8CwYSGDNSZjG86U,2697
|
|
23
|
+
rcsb_embedding_model/utils/structure_provider.py,sha256=eWtxjkPpmRfmil_DKR1J6miaXR3lQ28DF5O0qrqSgGA,786
|
|
24
|
+
rcsb_embedding_model/writer/batch_writer.py,sha256=ekgzFZyoKpcnZ3IDP9hfOWBpuHxUQ31P35ViDAi-Edw,2843
|
|
25
|
+
rcsb_embedding_model-0.0.8.dist-info/METADATA,sha256=XvNb99X9GWdMEdz-A_o-ngTxlfiWrr8KMvjKh_rk3x0,5366
|
|
26
|
+
rcsb_embedding_model-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
27
|
+
rcsb_embedding_model-0.0.8.dist-info/entry_points.txt,sha256=MK11jTIEmaV-x4CkPX5IymDaVs7Ky_f2xxU8BJVZ_9Q,69
|
|
28
|
+
rcsb_embedding_model-0.0.8.dist-info/licenses/LICENSE.md,sha256=oUaHiKgfBkChth_Sm67WemEvatO1U0Go8LHjaskXY0w,1522
|
|
29
|
+
rcsb_embedding_model-0.0.8.dist-info/RECORD,,
|