rcsb-embedding-model 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (26)
  1. rcsb_embedding_model/cli/args_utils.py +0 -2
  2. rcsb_embedding_model/cli/inference.py +164 -42
  3. rcsb_embedding_model/dataset/esm_prot_from_chain.py +102 -0
  4. rcsb_embedding_model/dataset/esm_prot_from_structure.py +63 -0
  5. rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py +68 -0
  6. rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py +94 -0
  7. rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py +43 -0
  8. rcsb_embedding_model/inference/assembly_inferece.py +53 -0
  9. rcsb_embedding_model/inference/chain_inference.py +12 -8
  10. rcsb_embedding_model/inference/esm_inference.py +18 -8
  11. rcsb_embedding_model/inference/structure_inference.py +61 -0
  12. rcsb_embedding_model/modules/structure_module.py +27 -0
  13. rcsb_embedding_model/rcsb_structure_embedding.py +7 -8
  14. rcsb_embedding_model/types/api_types.py +27 -5
  15. rcsb_embedding_model/utils/data.py +30 -0
  16. rcsb_embedding_model/utils/structure_parser.py +43 -13
  17. rcsb_embedding_model/utils/structure_provider.py +27 -0
  18. rcsb_embedding_model-0.0.8.dist-info/METADATA +129 -0
  19. rcsb_embedding_model-0.0.8.dist-info/RECORD +29 -0
  20. rcsb_embedding_model/dataset/esm_prot_from_csv.py +0 -91
  21. rcsb_embedding_model/dataset/residue_embedding_from_csv.py +0 -32
  22. rcsb_embedding_model-0.0.6.dist-info/METADATA +0 -117
  23. rcsb_embedding_model-0.0.6.dist-info/RECORD +0 -22
  24. {rcsb_embedding_model-0.0.6.dist-info → rcsb_embedding_model-0.0.8.dist-info}/WHEEL +0 -0
  25. {rcsb_embedding_model-0.0.6.dist-info → rcsb_embedding_model-0.0.8.dist-info}/entry_points.txt +0 -0
  26. {rcsb_embedding_model-0.0.6.dist-info → rcsb_embedding_model-0.0.8.dist-info}/licenses/LICENSE.md +0 -0

rcsb_embedding_model/inference/assembly_inferece.py
@@ -0,0 +1,53 @@
+ import sys
+
+ from rcsb_embedding_model.dataset.resdiue_assembly_embedding_from_structure import ResidueAssemblyDatasetFromStructure
+ from rcsb_embedding_model.dataset.residue_assembly_embedding_from_tensor_file import ResidueAssemblyEmbeddingFromTensorFile
+ from rcsb_embedding_model.types.api_types import FileOrStreamTuple, SrcLocation, Accelerator, Devices, OptionalPath, EmbeddingPath, StructureLocation, StructureFormat, SrcAssemblyFrom
+ from rcsb_embedding_model.inference.chain_inference import predict as chain_predict
+
+
+ def predict(
+         src_stream: FileOrStreamTuple,
+         res_embedding_location: EmbeddingPath,
+         src_location: SrcLocation = SrcLocation.local,
+         src_from: SrcAssemblyFrom = SrcAssemblyFrom.assembly,
+         structure_location: StructureLocation = StructureLocation.local,
+         structure_format: StructureFormat = StructureFormat.mmcif,
+         min_res_n: int = 0,
+         max_res_n: int = sys.maxsize,
+         batch_size: int = 1,
+         num_workers: int = 0,
+         num_nodes: int = 1,
+         accelerator: Accelerator = Accelerator.auto,
+         devices: Devices = 'auto',
+         out_path: OptionalPath = None
+ ):
+     inference_set = ResidueAssemblyEmbeddingFromTensorFile(
+         src_stream=src_stream,
+         res_embedding_location=res_embedding_location,
+         src_location=src_location,
+         structure_location=structure_location,
+         structure_format=structure_format,
+         min_res_n=min_res_n,
+         max_res_n=max_res_n
+     ) if src_from == SrcAssemblyFrom.assembly else ResidueAssemblyDatasetFromStructure(
+         src_stream=src_stream,
+         res_embedding_location=res_embedding_location,
+         src_location=src_location,
+         structure_location=structure_location,
+         structure_format=structure_format,
+         min_res_n=min_res_n,
+         max_res_n=max_res_n
+     )
+
+     return chain_predict(
+         src_stream=src_stream,
+         src_location=src_location,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         num_nodes=num_nodes,
+         accelerator=accelerator,
+         devices=devices,
+         out_path=out_path,
+         inference_set=inference_set
+     )
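
The new assembly entry point builds one of two datasets depending on `src_from` and then delegates to the chain-level `predict`. A minimal sketch of calling it from Python, assuming a local CSV listing the assemblies and a directory of per-residue tensor files; whether `src_stream` takes a path or an open handle, and all paths shown, are assumptions rather than something stated in this diff:

    # Hypothetical invocation of the new assembly-level predict(); inputs are illustrative.
    from rcsb_embedding_model.inference.assembly_inferece import predict
    from rcsb_embedding_model.types.api_types import SrcAssemblyFrom

    embeddings = predict(
        src_stream=open("data/assemblies.csv"),               # assumed: CSV of assemblies to embed
        res_embedding_location="results/residue_embeddings",  # per-residue tensors from a previous run
        src_from=SrcAssemblyFrom.assembly,
        batch_size=4,
        out_path="results/assembly_embeddings"
    )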

rcsb_embedding_model/inference/chain_inference.py
@@ -1,26 +1,30 @@
  from torch.utils.data import DataLoader
  from lightning import Trainer
- from typer import FileText

- from rcsb_embedding_model.dataset.residue_embedding_from_csv import ResidueEmbeddingFromCSV
+ from rcsb_embedding_model.dataset.residue_embedding_from_tensor_file import ResidueEmbeddingFromTensorFile
  from rcsb_embedding_model.modules.chain_module import ChainModule
- from rcsb_embedding_model.types.api_types import Accelerator, Devices, OptionalPath
+ from rcsb_embedding_model.types.api_types import Accelerator, Devices, OptionalPath, FileOrStreamTuple, SrcLocation
  from rcsb_embedding_model.utils.data import collate_seq_embeddings
  from rcsb_embedding_model.writer.batch_writer import CsvBatchWriter


  def predict(
-         csv_file: FileText,
+         src_stream: FileOrStreamTuple,
+         src_location: SrcLocation = SrcLocation.local,
          batch_size: int = 1,
          num_workers: int = 0,
          num_nodes: int = 1,
          accelerator: Accelerator = Accelerator.auto,
          devices: Devices = 'auto',
-         out_path: OptionalPath = None
+         out_path: OptionalPath = None,
+         inference_set=None
  ):
-     inference_set = ResidueEmbeddingFromCSV(
-         csv_file=csv_file
-     )
+
+     if inference_set is None:
+         inference_set = ResidueEmbeddingFromTensorFile(
+             src_stream=src_stream,
+             src_location=src_location
+         )

      inference_dataloader = DataLoader(
          dataset=inference_set,

rcsb_embedding_model/inference/esm_inference.py
@@ -1,17 +1,20 @@
  from torch.utils.data import DataLoader
  from lightning import Trainer
- from typer import FileText

- from rcsb_embedding_model.dataset.esm_prot_from_csv import EsmProtFromCsv
+ from rcsb_embedding_model.dataset.esm_prot_from_structure import EsmProtFromStructure
+ from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
  from rcsb_embedding_model.modules.esm_module import EsmModule
- from rcsb_embedding_model.types.api_types import SrcFormat, Accelerator, Devices, OptionalPath, SrcLocation
+ from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, SrcProteinFrom, FileOrStreamTuple, SrcLocation
  from rcsb_embedding_model.writer.batch_writer import TensorBatchWriter


  def predict(
-         csv_file: FileText,
+         src_stream: FileOrStreamTuple,
          src_location: SrcLocation = SrcLocation.local,
-         src_format: SrcFormat = SrcFormat.mmcif,
+         src_from: SrcProteinFrom = SrcProteinFrom.chain,
+         structure_location: StructureLocation = StructureLocation.local,
+         structure_format: StructureFormat = StructureFormat.mmcif,
+         min_res_n: int = 0,
          batch_size: int = 1,
          num_workers: int = 0,
          num_nodes: int = 1,
@@ -20,10 +23,17 @@ def predict(
          out_path: OptionalPath = None
  ):

-     inference_set = EsmProtFromCsv(
-         csv_file=csv_file,
+     inference_set = EsmProtFromChain(
+         src_stream=src_stream,
          src_location=src_location,
-         src_format=src_format
+         structure_location=structure_location,
+         structure_format=structure_format
+     ) if src_from == SrcProteinFrom.chain else EsmProtFromStructure(
+         src_stream=src_stream,
+         src_location=src_location,
+         structure_location=structure_location,
+         structure_format=structure_format,
+         min_res_n=min_res_n
      )

      inference_dataloader = DataLoader(

rcsb_embedding_model/inference/structure_inference.py
@@ -0,0 +1,61 @@
+ from torch.utils.data import DataLoader
+ from lightning import Trainer
+
+ from rcsb_embedding_model.dataset.esm_prot_from_structure import EsmProtFromStructure
+ from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
+ from rcsb_embedding_model.modules.structure_module import StructureModule
+ from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, SrcProteinFrom, FileOrStreamTuple, SrcLocation
+ from rcsb_embedding_model.writer.batch_writer import DataFrameStorage
+
+
+ def predict(
+         src_stream: FileOrStreamTuple,
+         src_location: SrcLocation = SrcLocation.local,
+         src_from: SrcProteinFrom = SrcProteinFrom.chain,
+         structure_location: StructureLocation = StructureLocation.local,
+         structure_format: StructureFormat = StructureFormat.mmcif,
+         min_res_n: int = 0,
+         batch_size: int = 1,
+         num_workers: int = 0,
+         num_nodes: int = 1,
+         accelerator: Accelerator = Accelerator.auto,
+         devices: Devices = 'auto',
+         out_path: OptionalPath = None,
+         out_df_name: str = None
+ ):
+
+     inference_set = EsmProtFromChain(
+         src_stream=src_stream,
+         src_location=src_location,
+         structure_location=structure_location,
+         structure_format=structure_format
+     ) if src_from == SrcProteinFrom.chain else EsmProtFromStructure(
+         src_stream=src_stream,
+         src_location=src_location,
+         structure_location=structure_location,
+         structure_format=structure_format,
+         min_res_n=min_res_n
+     )
+
+     inference_dataloader = DataLoader(
+         dataset=inference_set,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         collate_fn=lambda _: _
+     )
+
+     module = StructureModule()
+     inference_writer = DataFrameStorage(out_path, out_df_name) if out_path is not None and out_df_name is not None else None
+     trainer = Trainer(
+         callbacks=[inference_writer] if inference_writer is not None else None,
+         num_nodes=num_nodes,
+         accelerator=accelerator,
+         devices=devices
+     )
+
+     prediction = trainer.predict(
+         module,
+         inference_dataloader
+     )
+
+     return prediction
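
A similar sketch for the new structure-level `predict`, which embeds chains directly from structure files and only persists results through `DataFrameStorage` when both `out_path` and `out_df_name` are supplied. The input values below are illustrative, and the assumption that `src_stream` accepts a local CSV listing is not stated in this diff:

    from rcsb_embedding_model.inference.structure_inference import predict
    from rcsb_embedding_model.types.api_types import SrcProteinFrom

    predictions = predict(
        src_stream=open("data/structures.csv"),   # assumed: CSV listing structures/chains to embed
        src_from=SrcProteinFrom.chain,
        batch_size=4,
        out_path="results/structure_embeddings",
        out_df_name="df-structure-embeddings"
    )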

rcsb_embedding_model/modules/structure_module.py
@@ -0,0 +1,27 @@
+ from esm.sdk.api import SamplingConfig
+ from lightning import LightningModule
+
+ from rcsb_embedding_model.utils.data import collate_seq_embeddings
+ from rcsb_embedding_model.utils.model import get_residue_model, get_aggregator_model
+
+
+ class StructureModule(LightningModule):
+
+     def __init__(
+             self
+     ):
+         super().__init__()
+         self.esm3 = get_residue_model(self.device)
+         self.aggregator = get_aggregator_model(device=self.device)
+
+     def predict_step(self, prot_batch, batch_idx):
+         prot_embeddings = []
+         prot_names = []
+         for esm_prot, name in prot_batch:
+             embeddings = self.esm3.forward_and_sample(
+                 self.esm3.encode(esm_prot), SamplingConfig(return_per_residue_embeddings=True)
+             ).per_residue_embedding
+             prot_embeddings.append(embeddings)
+             prot_names.append(name)
+         res_batch_embedding, res_batch_mask = collate_seq_embeddings(prot_embeddings)
+         return self.aggregator(res_batch_embedding, res_batch_mask), tuple(prot_names)
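
`collate_seq_embeddings` itself is not shown in this diff; from its use here (and in `rcsb_structure_embedding.py`) it appears to pad a list of variable-length per-residue embedding tensors into a batch and return the padded tensor plus a boolean mask. A rough sketch under that assumption — the packaged implementation in `utils/data.py` may differ in padding value and mask convention:

    import torch

    def collate_seq_embeddings(seq_embeddings):
        # Pad [n_res, dim] tensors to [batch, max_res, dim]; mask is True at valid residues.
        # Sketch only; not the packaged implementation.
        max_len = max(t.shape[0] for t in seq_embeddings)
        dim = seq_embeddings[0].shape[-1]
        batch = torch.zeros(len(seq_embeddings), max_len, dim, dtype=seq_embeddings[0].dtype)
        mask = torch.zeros(len(seq_embeddings), max_len, dtype=torch.bool)
        for i, t in enumerate(seq_embeddings):
            batch[i, :t.shape[0]] = t
            mask[i, :t.shape[0]] = True
        return batch, mask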

rcsb_embedding_model/rcsb_structure_embedding.py
@@ -2,9 +2,8 @@ import torch
  from biotite.structure import get_residues, chain_iter, filter_amino_acids
  from esm.sdk.api import ESMProtein, SamplingConfig
  from esm.utils.structure.protein_chain import ProteinChain
- from huggingface_hub import hf_hub_download

- from rcsb_embedding_model.types.api_types import StreamSrc, SrcFormat
+ from rcsb_embedding_model.types.api_types import StreamSrc, StructureFormat
  from rcsb_embedding_model.utils.model import get_aggregator_model, get_residue_model
  from rcsb_embedding_model.utils.structure_parser import get_structure_from_src

@@ -42,23 +41,23 @@ class RcsbStructureEmbedding:

      def structure_embedding(
              self,
-             structure_src: StreamSrc,
-             src_format: SrcFormat = SrcFormat.mmcif,
+             src_structure: StreamSrc,
+             structure_format: StructureFormat = StructureFormat.mmcif,
              chain_id: str = None,
              assembly_id: str = None
      ):
-         res_embedding = self.residue_embedding(structure_src, src_format, chain_id, assembly_id)
+         res_embedding = self.residue_embedding(src_structure, structure_format, chain_id, assembly_id)
          return self.aggregator_embedding(res_embedding)

      def residue_embedding(
              self,
-             structure_src: StreamSrc,
-             src_format: SrcFormat = SrcFormat.mmcif,
+             src_structure: StreamSrc,
+             structure_format: StructureFormat = StructureFormat.mmcif,
              chain_id: str = None,
              assembly_id: str = None
      ):
          self.__check_residue_embedding()
-         structure = get_structure_from_src(structure_src, src_format, chain_id, assembly_id)
+         structure = get_structure_from_src(src_structure, structure_format, chain_id, assembly_id)
          embedding_ch = []
          for atom_ch in chain_iter(structure):
              atom_res = atom_ch[filter_amino_acids(atom_ch)]

rcsb_embedding_model/types/api_types.py
@@ -1,16 +1,23 @@
  from enum import Enum
- from os import PathLike
  from typing import NewType, Union, IO, Tuple, List, Optional

- StreamSrc = NewType('StreamSrc', Union[PathLike, IO])
- StreamTuple = NewType('StreamTuple', Tuple[StreamSrc, str, str])
+ from typer import FileText
+
+ StreamSrc = NewType('StreamSrc', Union[FileText, IO])
+ StreamTuple = NewType('StreamTuple', Union[
+     Tuple[str, StreamSrc, str, str],
+     Tuple[str, StreamSrc, str],
+     Tuple[str, str]
+ ])
+ FileOrStreamTuple = NewType('FileOrStreamTuple', Union[FileText, StreamTuple])

  Devices = NewType('Devices', Union[int, List[int], "auto"])

- OptionalPath = NewType('OptionalPath', Optional[PathLike])
+ EmbeddingPath = Union[str, FileText]
+ OptionalPath = NewType('OptionalPath', Optional[FileText])


- class SrcFormat(str, Enum):
+ class StructureFormat(str, Enum):
      pdb = "pdb"
      mmcif = "mmcif"
      bciff = "binarycif"
@@ -25,5 +32,20 @@ class Accelerator(str, Enum):


  class SrcLocation(str, Enum):
+     local = "local"
+     stream = "stream"
+
+
+ class StructureLocation(str, Enum):
      local = "local"
      remote = "remote"
+
+
+ class SrcProteinFrom(str, Enum):
+     chain = "chain"
+     structure = "structure"
+
+
+ class SrcAssemblyFrom(str, Enum):
+     assembly = "assembly"
+     structure = "structure"

rcsb_embedding_model/utils/data.py
@@ -44,4 +44,34 @@ def stringio_from_url(url):
          print(f"Error fetching URL: {e}")
          return None

+ def concatenate_tensors(file_list, max_residues, dim=0):
+     """
+     Concatenates a list of tensors stored in individual files along a specified dimension.
+
+     Args:
+         file_list (list of str): List of file paths to tensor files.
+         max_residues (int): Maximum number of residues allowed in the assembly
+         dim (int): The dimension along which to concatenate the tensors. Default is 0.

+     Returns:
+         torch.Tensor: The concatenated tensor.
+     """
+     tensors = []
+     total_residues = 0
+     for file in file_list:
+         try:
+             tensor = torch.load(
+                 file,
+                 map_location=torch.device('cpu')
+             )
+             total_residues += tensor.shape[0]
+             tensors.append(tensor)
+         except Exception as e:
+             continue
+         if total_residues > max_residues:
+             break
+     if tensors and len(tensors) > 0:
+         tensor_cat = torch.cat(tensors, dim=dim)
+         return tensor_cat
+     else:
+         raise ValueError("No valid tensors were loaded to concatenate.")
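
Usage sketch for the new `concatenate_tensors` helper: given per-chain residue-embedding files saved with `torch.save` (the file names below are hypothetical), it loads them onto CPU, silently skips unreadable files, and stops loading further files once the running residue count exceeds `max_residues`:

    from rcsb_embedding_model.utils.data import concatenate_tensors

    chain_files = [
        "results/residue_embeddings/1abc.A.pt",   # hypothetical [n_res, dim] tensors
        "results/residue_embeddings/1abc.B.pt",
    ]
    assembly_tensor = concatenate_tensors(chain_files, max_residues=10000)
    print(assembly_tensor.shape)   # (total_residues, embedding_dim)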

rcsb_embedding_model/utils/structure_parser.py
@@ -1,32 +1,62 @@
-
- from biotite.structure.io.pdb import PDBFile, get_structure as get_pdb_structure, get_assembly as get_pdb_assembly
- from biotite.structure.io.pdbx import CIFFile, get_structure, get_assembly, BinaryCIFFile
+ from biotite.structure import filter_amino_acids, chain_iter, get_chains, get_residues, AtomArray
+ from biotite.structure.io.pdb import PDBFile, get_structure as get_pdb_structure, get_assembly as get_pdb_assembly, list_assemblies as list_pdb_assemblies
+ from biotite.structure.io.pdbx import CIFFile, get_structure, get_assembly, BinaryCIFFile, list_assemblies


  def get_structure_from_src(
-         structure_src,
-         src_format="mmcif",
+         src_structure,
+         structure_format="mmcif",
          chain_id=None,
          assembly_id=None
  ):
-     if src_format == "pdb":
-         pdb_file = PDBFile.read(structure_src)
+     if structure_format == "pdb":
+         pdb_file = PDBFile.read(src_structure)
          structure = __get_pdb_structure(pdb_file, assembly_id)
-     elif src_format == "mmcif":
-         cif_file = CIFFile.read(structure_src)
+     elif structure_format == "mmcif":
+         cif_file = CIFFile.read(src_structure)
          structure = __get_structure(cif_file, assembly_id)
-     elif src_format == "binarycif":
-         cif_file = BinaryCIFFile.read(structure_src)
+     elif structure_format == "binarycif":
+         cif_file = BinaryCIFFile.read(src_structure)
          structure = __get_structure(cif_file, assembly_id)
      else:
-         raise RuntimeError(f"Unknown file format {src_format}")
+         raise RuntimeError(f"Unknown file format {structure_format}")

      if chain_id is not None:
-         structure = structure[structure.chain_id == chain_id]
+         return structure[structure.chain_id == chain_id]

      return structure


+ def get_protein_chains(structure, min_res_n=0):
+     chain_ids = []
+     for atom_ch in chain_iter(structure):
+         atom_res = atom_ch[filter_amino_acids(atom_ch)]
+         if len(atom_res) > 0 and len(get_residues(atom_res)) > min_res_n:
+             chain_ids.append(str(get_chains(atom_res)[0]))
+     return tuple(chain_ids)
+
+
+ def get_assemblies(src_structure, structure_format="mmcif"):
+     if structure_format == "pdb":
+         return tuple(list_pdb_assemblies(PDBFile.read(src_structure)))
+     elif structure_format == "mmcif":
+         return tuple(list_assemblies(CIFFile.read(src_structure)).keys())
+     elif structure_format == "binarycif":
+         return tuple(list_assemblies(BinaryCIFFile.read(src_structure)))
+     else:
+         raise RuntimeError(f"Unknown file format {structure_format}")
+
+
+ def rename_atom_ch(atom_ch, ch="A"):
+     renamed_atom_ch = AtomArray(len(atom_ch))
+     n = 0
+     for atom in atom_ch:
+         atom.chain_id = ch
+         renamed_atom_ch[n] = atom
+         n += 1
+     return renamed_atom_ch
+
+
  def __get_pdb_structure(pdb_file, assembly_id=None):
      return get_pdb_structure(
          pdb_file,
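
The new parser helpers make it possible to enumerate protein chains and assemblies before running inference. A brief sketch, assuming a local mmCIF file (the path is illustrative):

    from rcsb_embedding_model.utils.structure_parser import (
        get_structure_from_src,
        get_protein_chains,
        get_assemblies,
    )

    structure = get_structure_from_src("examples/1abc.cif", structure_format="mmcif")
    print(get_protein_chains(structure, min_res_n=10))   # e.g. ('A', 'B')
    print(get_assemblies("examples/1abc.cif"))           # assembly ids declared in the file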

rcsb_embedding_model/utils/structure_provider.py
@@ -0,0 +1,27 @@
+ from rcsb_embedding_model.utils.structure_parser import get_structure_from_src
+
+
+ class StructureProvider:
+
+     def __init__(self):
+         self.__src_name = None
+         self.__structure = None
+
+     def get_structure(
+             self,
+             src_name,
+             src_structure,
+             structure_format="mmcif",
+             chain_id=None,
+             assembly_id=None
+     ):
+         if src_name != self.__src_name:
+             self.__src_name = src_name
+             self.__structure = get_structure_from_src(
+                 src_structure=src_structure,
+                 structure_format=structure_format,
+                 assembly_id=assembly_id
+             )
+         if chain_id is not None:
+             return self.__structure[self.__structure.chain_id == chain_id]
+         return self.__structure
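
`StructureProvider` caches the most recently parsed structure keyed by `src_name`, so consecutive per-chain requests against the same file skip re-parsing. A usage sketch (the file name is illustrative):

    from rcsb_embedding_model.utils.structure_provider import StructureProvider

    provider = StructureProvider()
    # First call parses the file; the second reuses the cached structure because src_name matches.
    chain_a = provider.get_structure("1abc", "examples/1abc.cif", structure_format="mmcif", chain_id="A")
    chain_b = provider.get_structure("1abc", "examples/1abc.cif", structure_format="mmcif", chain_id="B")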

rcsb_embedding_model-0.0.8.dist-info/METADATA
@@ -0,0 +1,129 @@
+ Metadata-Version: 2.4
+ Name: rcsb-embedding-model
+ Version: 0.0.8
+ Summary: Protein Embedding Model for Structure Search
+ Project-URL: Homepage, https://github.com/rcsb/rcsb-embedding-model
+ Project-URL: Issues, https://github.com/rcsb/rcsb-embedding-model/issues
+ Author-email: Joan Segura <joan.segura@rcsb.org>
+ License-Expression: BSD-3-Clause
+ License-File: LICENSE.md
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Requires-Python: >=3.10
+ Requires-Dist: esm>=3.2.0
+ Requires-Dist: lightning>=2.5.0
+ Requires-Dist: torch>=2.2.0
+ Requires-Dist: typer>=0.15.0
+ Description-Content-Type: text/markdown
+
+ # RCSB Embedding Model
+
+ **Version** 0.0.8
+
+
+ ## Overview
+
+ RCSB Embedding Model is a neural network architecture designed to encode macromolecular 3D structures into fixed-length vector embeddings for efficient large-scale structure similarity search.
+
+ Preprint: [Multi-scale structural similarity embedding search across entire proteomes](https://www.biorxiv.org/content/10.1101/2025.02.28.640875v1).
+
+ A web-based implementation using this model for structure similarity search is available at [rcsb-embedding-search](http://embedding-search.rcsb.org).
+
+ If you are interested in training the model with a new dataset, visit the [rcsb-embedding-search repository](https://github.com/bioinsilico/rcsb-embedding-search), which provides scripts and documentation for training.
+
+
+ ## Features
+
+ - **Residue-level embeddings** computed using the ESM3 protein language model
+ - **Structure-level embeddings** aggregated via a transformer-based aggregator network
+ - **Command-line interface** implemented with Typer for high-throughput inference workflows
+ - **Python API** for interactive embedding computation and integration into analysis pipelines
+ - **High-performance inference** leveraging PyTorch Lightning, with multi-node and multi-GPU support
+
+ ---
+
+ ## Installation
+
+     pip install rcsb-embedding-model
+
+ **Requirements:**
+
+ - Python ≥ 3.10
+ - ESM ≥ 3.2.0
+ - PyTorch ≥ 2.2.0
+ - Lightning ≥ 2.5.0
+ - Typer ≥ 0.15.0
+
+ ---
+
+ ## Quick Start
+
+ ### CLI
+
+     # 1. Compute residue embeddings: Calculate residue-level embeddings of protein structures using ESM3. Predictions are stored as torch tensor files.
+     inference residue-embedding --src-file data/structures.csv --output-path results/residue_embeddings --structure-format mmcif --batch-size 8 --devices auto
+
+     # 2. Compute structure embeddings: Calculate single-chain protein embeddings from structural files. Predictions are stored in a single pandas DataFrame file.
+     inference structure-embedding --src-file data/structures.csv --output-path results/residue_embeddings --out-df-name df-res-embeddings --batch-size 4 --devices 0 --devices 1
+
+     # 3. Compute chain embeddings: Calculate single-chain protein embeddings from residue-level embeddings stored as torch tensor files. Predictions are stored as csv files.
+     inference chain-embedding --src-file data/structures.csv --output-path results/chain_embeddings --batch-size 4
+
+     # 4. Compute assembly embeddings: Calculate assembly embeddings from residue-level embeddings stored as torch tensor files. Predictions are stored as csv files.
+     inference assembly-embedding --src-file data/structures.csv --res-embedding-location results/residue_embeddings --output-path results/assembly_embeddings
+
+ ### Python API
+
+     from rcsb_embedding_model import RcsbStructureEmbedding
+
+     model = RcsbStructureEmbedding()
+
+     # Compute per-residue embeddings
+     res_emb = model.residue_embedding(
+         src_structure="examples/1abc.cif",
+         src_format="mmcif",
+         chain_id="A"
+     )
+
+     # Aggregate to structure-level embedding
+     struct_emb = model.aggregator_embedding(res_emb)
+
+ See the examples and tests directories for more use cases.
+
+ ---
+
+ ## Model Architecture
+
+ The embedding model is trained to predict structural similarity by approximating TM-scores using cosine distances between embeddings. It consists of two main components:
+
+ - **Protein Language Model (PLM)**: Computes residue-level embeddings from a given 3D structure.
+ - **Residue Embedding Aggregator**: A transformer-based neural network that aggregates these residue-level embeddings into a single vector.
+
+ ![Embedding model architecture](assets/embedding-model-architecture.png)
+
+ ### **Protein Language Model (PLM)**
+ Residue-wise embeddings of protein structures are computed using the [ESM3](https://www.evolutionaryscale.ai/) generative protein language model.
+
+ ### **Residue Embedding Aggregator**
+ The aggregation component consists of six transformer encoder layers, each with a 3,072-neuron feedforward layer and ReLU activations. After processing through these layers, a summation pooling operation is applied, followed by 12 fully connected residual layers that refine the embeddings into a single 1,536-dimensional vector.
+
+ ---
+
+ ## Development
+
+     git clone https://github.com/rcsb/rcsb-embedding-model.git
+     cd rcsb-embedding-model
+     pip install -e .
+     pytest
+
+ ---
+
+ ## Citation
+
+ Segura, J., Bittrich, S., et al. (2024). *Multi-scale structural similarity embedding search across entire proteomes*. bioRxiv. (Preprint: https://www.biorxiv.org/content/10.1101/2025.02.28.640875v1)
+
+ ---
+
+ ## License
+
+ This project is licensed under the BSD 3-Clause License. See [LICENSE.md](LICENSE.md) for details.
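
The Residue Embedding Aggregator described in the METADATA above (six transformer encoder layers with 3,072-neuron feedforward blocks and ReLU, summation pooling, then 12 fully connected residual layers yielding a 1,536-dimensional vector) can be sketched in PyTorch as follows. This illustrates the written description only, not the packaged `model/residue_embedding_aggregator.py`; the embedding width and the number of attention heads are assumptions:

    import torch
    from torch import nn

    class ResidueEmbeddingAggregatorSketch(nn.Module):
        # Sketch of the architecture described in the README; hyperparameters are assumed.
        def __init__(self, dim=1536, n_encoder_layers=6, n_res_blocks=12, n_heads=8):
            super().__init__()
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=dim, nhead=n_heads, dim_feedforward=3072,
                activation="relu", batch_first=True
            )
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_encoder_layers)
            self.res_blocks = nn.ModuleList(
                [nn.Sequential(nn.Linear(dim, dim), nn.ReLU()) for _ in range(n_res_blocks)]
            )

        def forward(self, x, mask):
            # x: [batch, max_res, dim]; mask: True at valid residue positions.
            h = self.encoder(x, src_key_padding_mask=~mask)
            pooled = (h * mask.unsqueeze(-1)).sum(dim=1)   # summation pooling over residues
            for block in self.res_blocks:
                pooled = pooled + block(pooled)            # residual fully connected layers
            return pooled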

rcsb_embedding_model-0.0.8.dist-info/RECORD
@@ -0,0 +1,29 @@
+ rcsb_embedding_model/__init__.py,sha256=r3gLdeBIXkQEQA_K6QcRPO-TtYuAQSutk6pXRUE_nas,120
+ rcsb_embedding_model/rcsb_structure_embedding.py,sha256=dKp9hXQO0JAnO4SEfjJ_mG_jHu3UxAPguv6jkOjp-BI,4487
+ rcsb_embedding_model/cli/args_utils.py,sha256=7nP2q8pL5dWK_U7opxtWmoFcYVwasky6elHk-dASFaI,165
+ rcsb_embedding_model/cli/inference.py,sha256=KPZLqznbxZE_CBCGigUGg7yOfGsi8ID4aWMTExniRj4,11464
+ rcsb_embedding_model/dataset/esm_prot_from_chain.py,sha256=dBD2N0Y-GoN6p3z2yLnOvv6JGn-skAxwgbOYhXKDngc,3487
+ rcsb_embedding_model/dataset/esm_prot_from_structure.py,sha256=kOqgHfHjiym5InaAgpgMmHBgCAPEqW88PCoHHQy0ROI,2490
+ rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py,sha256=d8C7HRJBZWuOKhPQpihv1koT4aIvyt5QN2yndC2ABuE,2842
+ rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py,sha256=KXiohnPjjfZEFbPZQ46HGE8eEYWrVX8bfbTz4zPlo7o,3451
+ rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py,sha256=cOxT--Spkel10JJCeGlqgLXN5vNCZzPdfSxDgUSrdPI,1268
+ rcsb_embedding_model/inference/assembly_inferece.py,sha256=MPssN5bsOqOU-LGwa6AKX99cv5LD43Mnbaqhuuww1Tw,2165
+ rcsb_embedding_model/inference/chain_inference.py,sha256=R9gi0MZ_HaM3v9c433W_5w4suse4nJmy4SgUTHJVZLg,1713
+ rcsb_embedding_model/inference/esm_inference.py,sha256=oVN4r9_6V8TS0pYoNn7GR92Xo0Zn7eBsnt_OfDSaH6g,2126
+ rcsb_embedding_model/inference/structure_inference.py,sha256=QIUEo8eEc-kTSYKGdlX2rxT74huw4ZAw6U8Px9kYajE,2216
+ rcsb_embedding_model/model/layers.py,sha256=lhKaWC4gTS_T5lHOP0mgnnP8nKTPEOm4MrjhESA4hE8,743
+ rcsb_embedding_model/model/residue_embedding_aggregator.py,sha256=k3UW63Ax8DtjCMdD3O5xNxtyAu28l2n3-Ab6nS0atm0,1967
+ rcsb_embedding_model/modules/chain_module.py,sha256=sDSPXJmWuU2C3lt1NorlbUVWZvRSLzumPdFQk01h3VI,403
+ rcsb_embedding_model/modules/esm_module.py,sha256=CTHGOATXiarqZsBsZ8oxGJBj20A73186Slpr0EzMJsE,770
+ rcsb_embedding_model/modules/structure_module.py,sha256=dEtDNdWo1j2sSDa0JiOHQfEfQzIWqSLEKpvOX0GrXZ4,1048
+ rcsb_embedding_model/types/api_types.py,sha256=3sPh33yb3Ya9r3O5vuiTfhb1WyFuhQWCQmewSbqEyG0,1076
+ rcsb_embedding_model/utils/data.py,sha256=x6ca_bVdBXEAp9ugCi1rVEQ-G5nGTFKpzDKqZKpkFBE,2933
+ rcsb_embedding_model/utils/model.py,sha256=rpZa-gfm3cEtbBd7UXMHrZv3x6f0AC8TJT3gtrSxr5I,852
+ rcsb_embedding_model/utils/structure_parser.py,sha256=jat4SCtPHYMZ6JJR-T7lPQoMbT_E8CwYSGDNSZjG86U,2697
+ rcsb_embedding_model/utils/structure_provider.py,sha256=eWtxjkPpmRfmil_DKR1J6miaXR3lQ28DF5O0qrqSgGA,786
+ rcsb_embedding_model/writer/batch_writer.py,sha256=ekgzFZyoKpcnZ3IDP9hfOWBpuHxUQ31P35ViDAi-Edw,2843
+ rcsb_embedding_model-0.0.8.dist-info/METADATA,sha256=XvNb99X9GWdMEdz-A_o-ngTxlfiWrr8KMvjKh_rk3x0,5366
+ rcsb_embedding_model-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ rcsb_embedding_model-0.0.8.dist-info/entry_points.txt,sha256=MK11jTIEmaV-x4CkPX5IymDaVs7Ky_f2xxU8BJVZ_9Q,69
+ rcsb_embedding_model-0.0.8.dist-info/licenses/LICENSE.md,sha256=oUaHiKgfBkChth_Sm67WemEvatO1U0Go8LHjaskXY0w,1522
+ rcsb_embedding_model-0.0.8.dist-info/RECORD,,