rcsb-embedding-model 0.0.10__tar.gz → 0.0.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rcsb-embedding-model might be problematic. Click here for more details.
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/PKG-INFO +2 -2
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/README.md +1 -1
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/pyproject.toml +1 -1
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/cli/inference.py +31 -7
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/dataset/esm_prot_from_structure.py +1 -1
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py +3 -3
- rcsb_embedding_model-0.0.12/src/rcsb_embedding_model/dataset/residue_embedding_from_structure.py +66 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py +2 -2
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/inference/chain_inference.py +15 -1
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/types/api_types.py +5 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/utils/data.py +1 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/utils/structure_parser.py +4 -4
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/test_inference.py +24 -1
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/test_remote_inference.py +21 -1
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/.gitignore +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/LICENSE.md +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/assets/embedding-model-architecture.png +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/examples/esm_embeddings.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/__init__.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/cli/args_utils.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/dataset/esm_prot_from_chain.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/inference/assembly_inferece.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/inference/esm_inference.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/inference/structure_inference.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/model/layers.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/model/residue_embedding_aggregator.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/modules/chain_module.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/modules/esm_module.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/modules/structure_module.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/rcsb_structure_embedding.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/utils/model.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/utils/structure_provider.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/writer/batch_writer.py +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/embeddings/1acb.A.pt +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/embeddings/1acb.B.pt +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/embeddings/2uzi.A.pt +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/embeddings/2uzi.B.pt +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/embeddings/2uzi.C.pt +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/esm-from-chain-inference.csv +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/pdb/1acb.cif +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/pdb/2uzi.cif +0 -0
- {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/test_embedding_model.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: rcsb-embedding-model
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.12
|
|
4
4
|
Summary: Protein Embedding Model for Structure Search
|
|
5
5
|
Project-URL: Homepage, https://github.com/rcsb/rcsb-embedding-model
|
|
6
6
|
Project-URL: Issues, https://github.com/rcsb/rcsb-embedding-model/issues
|
|
@@ -18,7 +18,7 @@ Description-Content-Type: text/markdown
|
|
|
18
18
|
|
|
19
19
|
# RCSB Embedding Model
|
|
20
20
|
|
|
21
|
-
**Version** 0.0.
|
|
21
|
+
**Version** 0.0.12
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
## Overview
|
|
@@ -5,7 +5,7 @@ import typer
|
|
|
5
5
|
|
|
6
6
|
from rcsb_embedding_model.cli.args_utils import arg_devices
|
|
7
7
|
from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, SrcLocation, SrcProteinFrom, \
|
|
8
|
-
StructureLocation, SrcAssemblyFrom
|
|
8
|
+
StructureLocation, SrcAssemblyFrom, SrcTensorFrom
|
|
9
9
|
|
|
10
10
|
app = typer.Typer(
|
|
11
11
|
add_completion=False
|
|
@@ -22,7 +22,7 @@ def residue_embedding(
|
|
|
22
22
|
file_okay=True,
|
|
23
23
|
dir_okay=False,
|
|
24
24
|
resolve_path=True,
|
|
25
|
-
help='CSV file 4 (or 3) columns: Structure Name | Structure File Path | Chain Id (asym_i for cif files. This field is required if src-from=chain) | Output Embedding Name.'
|
|
25
|
+
help='CSV file 4 (or 3) columns: Structure Name | Structure File Path or URL (switch structure-location) | Chain Id (asym_i for cif files. This field is required if src-from=chain) | Output Embedding Name.'
|
|
26
26
|
)],
|
|
27
27
|
output_path: Annotated[typer.FileText, typer.Option(
|
|
28
28
|
exists=True,
|
|
@@ -86,7 +86,7 @@ def structure_embedding(
|
|
|
86
86
|
file_okay=True,
|
|
87
87
|
dir_okay=False,
|
|
88
88
|
resolve_path=True,
|
|
89
|
-
help='CSV file 4 (or 3) columns: Structure Name | Structure File Path | Chain Id (asym_i for cif files. This field is required if src-from=chain) | Output Embedding Name.'
|
|
89
|
+
help='CSV file 4 (or 3) columns: Structure Name | Structure File Path or URL (switch structure-location) | Chain Id (asym_i for cif files. This field is required if src-from=chain) | Output Embedding Name.'
|
|
90
90
|
)],
|
|
91
91
|
output_path: Annotated[typer.FileText, typer.Option(
|
|
92
92
|
exists=True,
|
|
@@ -102,7 +102,7 @@ def structure_embedding(
|
|
|
102
102
|
help='Use specific chains or all chains in a structure.'
|
|
103
103
|
)] = SrcProteinFrom.chain,
|
|
104
104
|
structure_location: Annotated[StructureLocation, typer.Option(
|
|
105
|
-
help='
|
|
105
|
+
help='Structure file location.'
|
|
106
106
|
)] = StructureLocation.local,
|
|
107
107
|
structure_format: Annotated[StructureFormat, typer.Option(
|
|
108
108
|
help='Structure file format.'
|
|
@@ -154,7 +154,7 @@ def chain_embedding(
|
|
|
154
154
|
file_okay=True,
|
|
155
155
|
dir_okay=False,
|
|
156
156
|
resolve_path=True,
|
|
157
|
-
help='CSV file 2 columns: Residue
|
|
157
|
+
help='Option 1 (src-from=file) - CSV file 2 columns: Residue Embedding Torch Tensor File | Output Embedding Name. Option 2 (src-from=structure) - CSV file 3 columns: Structure Name | Structure File Path or URL (switch structure-location) | Output Embedding Name.'
|
|
158
158
|
)],
|
|
159
159
|
output_path: Annotated[typer.FileText, typer.Option(
|
|
160
160
|
exists=True,
|
|
@@ -163,6 +163,25 @@ def chain_embedding(
|
|
|
163
163
|
resolve_path=True,
|
|
164
164
|
help='Output path to store predictions. Embeddings are stored as csv files.'
|
|
165
165
|
)],
|
|
166
|
+
res_embedding_location: Annotated[typer.FileText, typer.Option(
|
|
167
|
+
exists=True,
|
|
168
|
+
file_okay=False,
|
|
169
|
+
dir_okay=True,
|
|
170
|
+
resolve_path=True,
|
|
171
|
+
help='Path where residue level embeddings are located. This argument is required if src-from=structure.'
|
|
172
|
+
)] = None,
|
|
173
|
+
src_from: Annotated[SrcTensorFrom, typer.Option(
|
|
174
|
+
help='Use file names or all chains in a structure.'
|
|
175
|
+
)] = SrcTensorFrom.file,
|
|
176
|
+
structure_location: Annotated[StructureLocation, typer.Option(
|
|
177
|
+
help='Structure file location.'
|
|
178
|
+
)] = StructureLocation.local,
|
|
179
|
+
structure_format: Annotated[StructureFormat, typer.Option(
|
|
180
|
+
help='Structure file format.'
|
|
181
|
+
)] = StructureFormat.mmcif,
|
|
182
|
+
min_res_n: Annotated[int, typer.Option(
|
|
183
|
+
help='When using all chains in a structure, consider only chains with more than <min_res_n> residues.'
|
|
184
|
+
)] = 0,
|
|
166
185
|
batch_size: Annotated[int, typer.Option(
|
|
167
186
|
help='Number of samples processed together in one iteration.'
|
|
168
187
|
)] = 1,
|
|
@@ -182,7 +201,12 @@ def chain_embedding(
|
|
|
182
201
|
from rcsb_embedding_model.inference.chain_inference import predict
|
|
183
202
|
predict(
|
|
184
203
|
src_stream=src_file,
|
|
204
|
+
res_embedding_location=res_embedding_location,
|
|
185
205
|
src_location=SrcLocation.local,
|
|
206
|
+
src_from=src_from,
|
|
207
|
+
structure_location=structure_location,
|
|
208
|
+
structure_format=structure_format,
|
|
209
|
+
min_res_n=min_res_n,
|
|
186
210
|
batch_size=batch_size,
|
|
187
211
|
num_workers=num_workers,
|
|
188
212
|
num_nodes=num_nodes,
|
|
@@ -201,7 +225,7 @@ def assembly_embedding(
|
|
|
201
225
|
file_okay=True,
|
|
202
226
|
dir_okay=False,
|
|
203
227
|
resolve_path=True,
|
|
204
|
-
help='CSV file 4 columns: Structure Name | Structure File Path | Assembly Id | Output embedding name.'
|
|
228
|
+
help='CSV file 4 columns: Structure Name | Structure File Path or URL (switch structure-location) | Assembly Id | Output embedding name.'
|
|
205
229
|
)],
|
|
206
230
|
res_embedding_location: Annotated[typer.FileText, typer.Option(
|
|
207
231
|
exists=True,
|
|
@@ -221,7 +245,7 @@ def assembly_embedding(
|
|
|
221
245
|
help='Use specific assembly or all assemblies in a structure.'
|
|
222
246
|
)] = SrcAssemblyFrom.assembly,
|
|
223
247
|
structure_location: Annotated[StructureLocation, typer.Option(
|
|
224
|
-
help='
|
|
248
|
+
help='Structure file location.'
|
|
225
249
|
)] = StructureLocation.local,
|
|
226
250
|
structure_format: Annotated[StructureFormat, typer.Option(
|
|
227
251
|
help='Structure file format.'
|
|
@@ -43,7 +43,7 @@ class EsmProtFromStructure(EsmProtFromChain):
|
|
|
43
43
|
for idx, row in (pd.DataFrame(
|
|
44
44
|
src_stream,
|
|
45
45
|
dtype=str,
|
|
46
|
-
columns=
|
|
46
|
+
columns=EsmProtFromStructure.COLUMNS
|
|
47
47
|
) if self.src_location == SrcLocation.stream else pd.read_csv(
|
|
48
48
|
src_stream,
|
|
49
49
|
header=None,
|
|
@@ -50,7 +50,7 @@ class ResidueAssemblyDatasetFromStructure(ResidueAssemblyEmbeddingFromTensorFile
|
|
|
50
50
|
for idx, row in (pd.DataFrame(
|
|
51
51
|
src_stream,
|
|
52
52
|
dtype=str,
|
|
53
|
-
columns=
|
|
53
|
+
columns=ResidueAssemblyDatasetFromStructure.COLUMNS
|
|
54
54
|
) if self.src_location == SrcLocation.stream else pd.read_csv(
|
|
55
55
|
src_stream,
|
|
56
56
|
header=None,
|
|
@@ -60,9 +60,9 @@ class ResidueAssemblyDatasetFromStructure(ResidueAssemblyEmbeddingFromTensorFile
|
|
|
60
60
|
)).iterrows():
|
|
61
61
|
src_name = row[ResidueAssemblyDatasetFromStructure.STREAM_NAME_ATTR]
|
|
62
62
|
src_structure = row[ResidueAssemblyDatasetFromStructure.STREAM_ATTR]
|
|
63
|
-
|
|
63
|
+
structure = stringio_from_url(src_structure) if self.structure_location == StructureLocation.remote else src_structure
|
|
64
64
|
item_name = row[ResidueAssemblyDatasetFromStructure.ITEM_NAME_ATTR]
|
|
65
|
-
for assembly_id in get_assemblies(
|
|
65
|
+
for assembly_id in get_assemblies(structure=structure, structure_format=self.structure_format):
|
|
66
66
|
assemblies.append((src_name, src_structure, str(assembly_id), f"{item_name}.{assembly_id}"))
|
|
67
67
|
|
|
68
68
|
return tuple(assemblies)
|
rcsb_embedding_model-0.0.12/src/rcsb_embedding_model/dataset/residue_embedding_from_structure.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
|
|
5
|
+
from rcsb_embedding_model.dataset.residue_embedding_from_tensor_file import ResidueEmbeddingFromTensorFile
|
|
6
|
+
from rcsb_embedding_model.types.api_types import SrcLocation, StructureLocation, StructureFormat
|
|
7
|
+
from rcsb_embedding_model.utils.data import stringio_from_url
|
|
8
|
+
from rcsb_embedding_model.utils.structure_parser import get_protein_chains
|
|
9
|
+
from rcsb_embedding_model.utils.structure_provider import StructureProvider
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ResidueEmbeddingFromStructure(ResidueEmbeddingFromTensorFile):
|
|
13
|
+
|
|
14
|
+
STREAM_NAME_ATTR = 'stream_name'
|
|
15
|
+
STREAM_ATTR = 'stream'
|
|
16
|
+
ITEM_NAME_ATTR = 'item_name'
|
|
17
|
+
|
|
18
|
+
COLUMNS = [STREAM_NAME_ATTR, STREAM_ATTR, ITEM_NAME_ATTR]
|
|
19
|
+
|
|
20
|
+
def __init__(
|
|
21
|
+
self,
|
|
22
|
+
src_stream,
|
|
23
|
+
res_embedding_location,
|
|
24
|
+
src_location=SrcLocation.local,
|
|
25
|
+
structure_location=StructureLocation.local,
|
|
26
|
+
structure_format=StructureFormat.mmcif,
|
|
27
|
+
min_res_n=0,
|
|
28
|
+
structure_provider=StructureProvider()
|
|
29
|
+
):
|
|
30
|
+
if not os.path.isdir(res_embedding_location):
|
|
31
|
+
raise FileNotFoundError(f"Folder {res_embedding_location} does not exist")
|
|
32
|
+
self.res_embedding_location = res_embedding_location
|
|
33
|
+
self.src_location = src_location
|
|
34
|
+
self.structure_location = structure_location
|
|
35
|
+
self.structure_format = structure_format
|
|
36
|
+
self.min_res_n = min_res_n
|
|
37
|
+
self.__structure_provider = structure_provider
|
|
38
|
+
super().__init__(
|
|
39
|
+
src_stream=self.__get_chains(src_stream),
|
|
40
|
+
src_location=SrcLocation.stream
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
def __get_chains(self, src_stream):
|
|
44
|
+
chains = []
|
|
45
|
+
for idx, row in (pd.DataFrame(
|
|
46
|
+
src_stream,
|
|
47
|
+
dtype=str,
|
|
48
|
+
columns=ResidueEmbeddingFromStructure.COLUMNS
|
|
49
|
+
) if self.src_location == SrcLocation.stream else pd.read_csv(
|
|
50
|
+
src_stream,
|
|
51
|
+
header=None,
|
|
52
|
+
index_col=None,
|
|
53
|
+
dtype=str,
|
|
54
|
+
names=ResidueEmbeddingFromStructure.COLUMNS
|
|
55
|
+
)).iterrows():
|
|
56
|
+
src_name = row[ResidueEmbeddingFromStructure.STREAM_NAME_ATTR]
|
|
57
|
+
src_structure = row[ResidueEmbeddingFromStructure.STREAM_ATTR]
|
|
58
|
+
item_name = row[ResidueEmbeddingFromStructure.ITEM_NAME_ATTR]
|
|
59
|
+
structure = self.__structure_provider.get_structure(
|
|
60
|
+
src_name=src_name,
|
|
61
|
+
src_structure=stringio_from_url(src_structure) if self.structure_location == StructureLocation.remote else src_structure,
|
|
62
|
+
structure_format=self.structure_format
|
|
63
|
+
)
|
|
64
|
+
for ch in get_protein_chains(structure, self.min_res_n):
|
|
65
|
+
chains.append((os.path.join(self.res_embedding_location, f"{src_name}.{ch}.pt"), f"{item_name}.{ch}"))
|
|
66
|
+
return tuple(chains)
|
|
@@ -2,7 +2,7 @@ import pandas as pd
|
|
|
2
2
|
import torch
|
|
3
3
|
from torch.utils.data import Dataset
|
|
4
4
|
|
|
5
|
-
from rcsb_embedding_model.types.api_types import
|
|
5
|
+
from rcsb_embedding_model.types.api_types import SrcLocation
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class ResidueEmbeddingFromTensorFile(Dataset):
|
|
@@ -26,7 +26,7 @@ class ResidueEmbeddingFromTensorFile(Dataset):
|
|
|
26
26
|
self.data = pd.DataFrame(
|
|
27
27
|
src_stream,
|
|
28
28
|
dtype=str,
|
|
29
|
-
columns=
|
|
29
|
+
columns=ResidueEmbeddingFromTensorFile.COLUMNS
|
|
30
30
|
) if self.src_location == SrcLocation.stream else pd.read_csv(
|
|
31
31
|
src_stream,
|
|
32
32
|
header=None,
|
|
@@ -1,16 +1,23 @@
|
|
|
1
1
|
from torch.utils.data import DataLoader
|
|
2
2
|
from lightning import Trainer
|
|
3
3
|
|
|
4
|
+
from rcsb_embedding_model.dataset.residue_embedding_from_structure import ResidueEmbeddingFromStructure
|
|
4
5
|
from rcsb_embedding_model.dataset.residue_embedding_from_tensor_file import ResidueEmbeddingFromTensorFile
|
|
5
6
|
from rcsb_embedding_model.modules.chain_module import ChainModule
|
|
6
|
-
from rcsb_embedding_model.types.api_types import Accelerator, Devices, OptionalPath, FileOrStreamTuple, SrcLocation
|
|
7
|
+
from rcsb_embedding_model.types.api_types import Accelerator, Devices, OptionalPath, FileOrStreamTuple, SrcLocation, \
|
|
8
|
+
SrcTensorFrom, StructureLocation, StructureFormat
|
|
7
9
|
from rcsb_embedding_model.utils.data import collate_seq_embeddings
|
|
8
10
|
from rcsb_embedding_model.writer.batch_writer import CsvBatchWriter
|
|
9
11
|
|
|
10
12
|
|
|
11
13
|
def predict(
|
|
12
14
|
src_stream: FileOrStreamTuple,
|
|
15
|
+
res_embedding_location: OptionalPath = None,
|
|
13
16
|
src_location: SrcLocation = SrcLocation.local,
|
|
17
|
+
src_from: SrcTensorFrom = SrcTensorFrom.file,
|
|
18
|
+
structure_location: StructureLocation = StructureLocation.local,
|
|
19
|
+
structure_format: StructureFormat = StructureFormat.mmcif,
|
|
20
|
+
min_res_n: int = 0,
|
|
14
21
|
batch_size: int = 1,
|
|
15
22
|
num_workers: int = 0,
|
|
16
23
|
num_nodes: int = 1,
|
|
@@ -24,6 +31,13 @@ def predict(
|
|
|
24
31
|
inference_set = ResidueEmbeddingFromTensorFile(
|
|
25
32
|
src_stream=src_stream,
|
|
26
33
|
src_location=src_location
|
|
34
|
+
) if src_from == SrcTensorFrom.file else ResidueEmbeddingFromStructure(
|
|
35
|
+
src_stream=src_stream,
|
|
36
|
+
res_embedding_location=res_embedding_location,
|
|
37
|
+
src_location=src_location,
|
|
38
|
+
structure_location=structure_location,
|
|
39
|
+
structure_format=structure_format,
|
|
40
|
+
min_res_n=min_res_n
|
|
27
41
|
)
|
|
28
42
|
|
|
29
43
|
inference_dataloader = DataLoader(
|
|
@@ -37,13 +37,13 @@ def get_protein_chains(structure, min_res_n=0):
|
|
|
37
37
|
return tuple(chain_ids)
|
|
38
38
|
|
|
39
39
|
|
|
40
|
-
def get_assemblies(
|
|
40
|
+
def get_assemblies(structure, structure_format="mmcif"):
|
|
41
41
|
if structure_format == "pdb":
|
|
42
|
-
return tuple(list_pdb_assemblies(PDBFile.read(
|
|
42
|
+
return tuple(list_pdb_assemblies(PDBFile.read(structure)))
|
|
43
43
|
elif structure_format == "mmcif":
|
|
44
|
-
return tuple(list_assemblies(CIFFile.read(
|
|
44
|
+
return tuple(list_assemblies(CIFFile.read(structure)).keys())
|
|
45
45
|
elif structure_format == "binarycif":
|
|
46
|
-
return tuple(list_assemblies(BinaryCIFFile.read(
|
|
46
|
+
return tuple(list_assemblies(BinaryCIFFile.read(structure)))
|
|
47
47
|
else:
|
|
48
48
|
raise RuntimeError(f"Unknown file format {structure_format}")
|
|
49
49
|
|
|
@@ -2,7 +2,7 @@ import os
|
|
|
2
2
|
import unittest
|
|
3
3
|
|
|
4
4
|
from rcsb_embedding_model.types.api_types import StructureLocation, Accelerator, SrcProteinFrom, SrcLocation, \
|
|
5
|
-
StructureFormat, SrcAssemblyFrom
|
|
5
|
+
StructureFormat, SrcAssemblyFrom, SrcTensorFrom
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class TestInference(unittest.TestCase):
|
|
@@ -72,6 +72,29 @@ class TestInference(unittest.TestCase):
|
|
|
72
72
|
self.assertEqual(tuple(chain_embeddings[3][0][0].shape), (1536,))
|
|
73
73
|
self.assertEqual(tuple(chain_embeddings[4][0][0].shape), (1536,))
|
|
74
74
|
|
|
75
|
+
def test_chain_inference_from_structure(self):
|
|
76
|
+
from rcsb_embedding_model.inference.chain_inference import predict
|
|
77
|
+
chain_embeddings = predict(
|
|
78
|
+
src_stream=[
|
|
79
|
+
("1acb", f"{self.__test_path}/resources/pdb/1acb.cif", "1acb"),
|
|
80
|
+
("2uzi", f"{self.__test_path}/resources/pdb/2uzi.cif", "2uzi"),
|
|
81
|
+
],
|
|
82
|
+
res_embedding_location=f"{self.__test_path}/resources/embeddings",
|
|
83
|
+
src_location=SrcLocation.stream,
|
|
84
|
+
src_from=SrcTensorFrom.structure,
|
|
85
|
+
structure_location=StructureLocation.local,
|
|
86
|
+
structure_format=StructureFormat.mmcif,
|
|
87
|
+
min_res_n=0,
|
|
88
|
+
accelerator=Accelerator.cpu
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
self.assertEqual(len(chain_embeddings), 5)
|
|
92
|
+
self.assertEqual(tuple(chain_embeddings[0][0][0].shape), (1536,))
|
|
93
|
+
self.assertEqual(tuple(chain_embeddings[1][0][0].shape), (1536,))
|
|
94
|
+
self.assertEqual(tuple(chain_embeddings[2][0][0].shape), (1536,))
|
|
95
|
+
self.assertEqual(tuple(chain_embeddings[3][0][0].shape), (1536,))
|
|
96
|
+
self.assertEqual(tuple(chain_embeddings[4][0][0].shape), (1536,))
|
|
97
|
+
|
|
75
98
|
def test_structure_inference_from_chain(self):
|
|
76
99
|
from rcsb_embedding_model.inference.structure_inference import predict
|
|
77
100
|
|
|
@@ -1,11 +1,14 @@
|
|
|
1
|
+
import os.path
|
|
1
2
|
import unittest
|
|
2
3
|
|
|
3
4
|
from rcsb_embedding_model.types.api_types import SrcLocation, SrcProteinFrom, StructureLocation, StructureFormat, \
|
|
4
|
-
Accelerator
|
|
5
|
+
Accelerator, SrcAssemblyFrom
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
class TestRemoteInference(unittest.TestCase):
|
|
8
9
|
|
|
10
|
+
__test_path = os.path.dirname(__file__)
|
|
11
|
+
|
|
9
12
|
def test_esm_inference_from_structure(self):
|
|
10
13
|
from rcsb_embedding_model.inference.esm_inference import predict
|
|
11
14
|
|
|
@@ -26,3 +29,20 @@ class TestRemoteInference(unittest.TestCase):
|
|
|
26
29
|
for idx, shape in enumerate(shapes):
|
|
27
30
|
self.assertEqual(tuple(esm_embeddings[idx][0][0].shape), shape)
|
|
28
31
|
|
|
32
|
+
def test_assembly_inference_from_structure(self):
|
|
33
|
+
from rcsb_embedding_model.inference.assembly_inferece import predict
|
|
34
|
+
|
|
35
|
+
assembly_embeddings = predict(
|
|
36
|
+
src_stream=[
|
|
37
|
+
("1acb", "https://files.rcsb.org/download/1acb.cif", "1acb"),
|
|
38
|
+
("2uzi", "https://files.rcsb.org/download/2uzi.cif", "2uzi")
|
|
39
|
+
],
|
|
40
|
+
res_embedding_location=f"{self.__test_path}/resources/embeddings",
|
|
41
|
+
src_location=SrcLocation.stream,
|
|
42
|
+
src_from=SrcAssemblyFrom.structure,
|
|
43
|
+
structure_location=StructureLocation.remote,
|
|
44
|
+
structure_format=StructureFormat.mmcif,
|
|
45
|
+
accelerator=Accelerator.cpu
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
self.assertEqual(len(assembly_embeddings), 2)
|
|
File without changes
|
|
File without changes
|
{rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/assets/embedding-model-architecture.png
RENAMED
|
File without changes
|
|
File without changes
|
{rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/model/layers.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/utils/model.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/embeddings/1acb.A.pt
RENAMED
|
File without changes
|
{rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/embeddings/1acb.B.pt
RENAMED
|
File without changes
|
{rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/embeddings/2uzi.A.pt
RENAMED
|
File without changes
|
{rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/embeddings/2uzi.B.pt
RENAMED
|
File without changes
|
{rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/embeddings/2uzi.C.pt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|