rcsb-embedding-model 0.0.10__tar.gz → 0.0.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rcsb-embedding-model might be problematic. Click here for more details.

Files changed (43):
  1. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/PKG-INFO +2 -2
  2. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/README.md +1 -1
  3. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/pyproject.toml +1 -1
  4. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/cli/inference.py +31 -7
  5. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/dataset/esm_prot_from_structure.py +1 -1
  6. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py +3 -3
  7. rcsb_embedding_model-0.0.12/src/rcsb_embedding_model/dataset/residue_embedding_from_structure.py +66 -0
  8. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py +2 -2
  9. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/inference/chain_inference.py +15 -1
  10. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/types/api_types.py +5 -0
  11. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/utils/data.py +1 -0
  12. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/utils/structure_parser.py +4 -4
  13. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/test_inference.py +24 -1
  14. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/test_remote_inference.py +21 -1
  15. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/.gitignore +0 -0
  16. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/LICENSE.md +0 -0
  17. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/assets/embedding-model-architecture.png +0 -0
  18. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/examples/esm_embeddings.py +0 -0
  19. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/__init__.py +0 -0
  20. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/cli/args_utils.py +0 -0
  21. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/dataset/esm_prot_from_chain.py +0 -0
  22. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py +0 -0
  23. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/inference/assembly_inferece.py +0 -0
  24. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/inference/esm_inference.py +0 -0
  25. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/inference/structure_inference.py +0 -0
  26. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/model/layers.py +0 -0
  27. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/model/residue_embedding_aggregator.py +0 -0
  28. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/modules/chain_module.py +0 -0
  29. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/modules/esm_module.py +0 -0
  30. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/modules/structure_module.py +0 -0
  31. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/rcsb_structure_embedding.py +0 -0
  32. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/utils/model.py +0 -0
  33. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/utils/structure_provider.py +0 -0
  34. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/src/rcsb_embedding_model/writer/batch_writer.py +0 -0
  35. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/embeddings/1acb.A.pt +0 -0
  36. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/embeddings/1acb.B.pt +0 -0
  37. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/embeddings/2uzi.A.pt +0 -0
  38. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/embeddings/2uzi.B.pt +0 -0
  39. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/embeddings/2uzi.C.pt +0 -0
  40. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/esm-from-chain-inference.csv +0 -0
  41. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/pdb/1acb.cif +0 -0
  42. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/resources/pdb/2uzi.cif +0 -0
  43. {rcsb_embedding_model-0.0.10 → rcsb_embedding_model-0.0.12}/tests/test_embedding_model.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rcsb-embedding-model
3
- Version: 0.0.10
3
+ Version: 0.0.12
4
4
  Summary: Protein Embedding Model for Structure Search
5
5
  Project-URL: Homepage, https://github.com/rcsb/rcsb-embedding-model
6
6
  Project-URL: Issues, https://github.com/rcsb/rcsb-embedding-model/issues
@@ -18,7 +18,7 @@ Description-Content-Type: text/markdown
18
18
 
19
19
  # RCSB Embedding Model
20
20
 
21
- **Version** 0.0.10
21
+ **Version** 0.0.12
22
22
 
23
23
 
24
24
  ## Overview
@@ -1,6 +1,6 @@
1
1
  # RCSB Embedding Model
2
2
 
3
- **Version** 0.0.10
3
+ **Version** 0.0.12
4
4
 
5
5
 
6
6
  ## Overview
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "rcsb-embedding-model"
3
- version = "0.0.10"
3
+ version = "0.0.12"
4
4
  authors = [
5
5
  { name="Joan Segura", email="joan.segura@rcsb.org" },
6
6
  ]
@@ -5,7 +5,7 @@ import typer
5
5
 
6
6
  from rcsb_embedding_model.cli.args_utils import arg_devices
7
7
  from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, SrcLocation, SrcProteinFrom, \
8
- StructureLocation, SrcAssemblyFrom
8
+ StructureLocation, SrcAssemblyFrom, SrcTensorFrom
9
9
 
10
10
  app = typer.Typer(
11
11
  add_completion=False
@@ -22,7 +22,7 @@ def residue_embedding(
22
22
  file_okay=True,
23
23
  dir_okay=False,
24
24
  resolve_path=True,
25
- help='CSV file 4 (or 3) columns: Structure Name | Structure File Path | Chain Id (asym_i for cif files. This field is required if src-from=chain) | Output Embedding Name.'
25
+ help='CSV file 4 (or 3) columns: Structure Name | Structure File Path or URL (switch structure-location) | Chain Id (asym_i for cif files. This field is required if src-from=chain) | Output Embedding Name.'
26
26
  )],
27
27
  output_path: Annotated[typer.FileText, typer.Option(
28
28
  exists=True,
@@ -86,7 +86,7 @@ def structure_embedding(
86
86
  file_okay=True,
87
87
  dir_okay=False,
88
88
  resolve_path=True,
89
- help='CSV file 4 (or 3) columns: Structure Name | Structure File Path | Chain Id (asym_i for cif files. This field is required if src-from=chain) | Output Embedding Name.'
89
+ help='CSV file 4 (or 3) columns: Structure Name | Structure File Path or URL (switch structure-location) | Chain Id (asym_i for cif files. This field is required if src-from=chain) | Output Embedding Name.'
90
90
  )],
91
91
  output_path: Annotated[typer.FileText, typer.Option(
92
92
  exists=True,
@@ -102,7 +102,7 @@ def structure_embedding(
102
102
  help='Use specific chains or all chains in a structure.'
103
103
  )] = SrcProteinFrom.chain,
104
104
  structure_location: Annotated[StructureLocation, typer.Option(
105
- help='Source input location.'
105
+ help='Structure file location.'
106
106
  )] = StructureLocation.local,
107
107
  structure_format: Annotated[StructureFormat, typer.Option(
108
108
  help='Structure file format.'
@@ -154,7 +154,7 @@ def chain_embedding(
154
154
  file_okay=True,
155
155
  dir_okay=False,
156
156
  resolve_path=True,
157
- help='CSV file 2 columns: Residue embedding torch tensor file | Output embedding name.'
157
+ help='Option 1 (src-from=file) - CSV file 2 columns: Residue Embedding Torch Tensor File | Output Embedding Name. Option 2 (src-from=structure) - CSV file 3 columns: Structure Name | Structure File Path or URL (switch structure-location) | Output Embedding Name.'
158
158
  )],
159
159
  output_path: Annotated[typer.FileText, typer.Option(
160
160
  exists=True,
@@ -163,6 +163,25 @@ def chain_embedding(
163
163
  resolve_path=True,
164
164
  help='Output path to store predictions. Embeddings are stored as csv files.'
165
165
  )],
166
+ res_embedding_location: Annotated[typer.FileText, typer.Option(
167
+ exists=True,
168
+ file_okay=False,
169
+ dir_okay=True,
170
+ resolve_path=True,
171
+ help='Path where residue level embeddings are located. This argument is required if src-from=structure.'
172
+ )] = None,
173
+ src_from: Annotated[SrcTensorFrom, typer.Option(
174
+ help='Use file names or all chains in a structure.'
175
+ )] = SrcTensorFrom.file,
176
+ structure_location: Annotated[StructureLocation, typer.Option(
177
+ help='Structure file location.'
178
+ )] = StructureLocation.local,
179
+ structure_format: Annotated[StructureFormat, typer.Option(
180
+ help='Structure file format.'
181
+ )] = StructureFormat.mmcif,
182
+ min_res_n: Annotated[int, typer.Option(
183
+ help='When using all chains in a structure, consider only chains with more than <min_res_n> residues.'
184
+ )] = 0,
166
185
  batch_size: Annotated[int, typer.Option(
167
186
  help='Number of samples processed together in one iteration.'
168
187
  )] = 1,
@@ -182,7 +201,12 @@ def chain_embedding(
182
201
  from rcsb_embedding_model.inference.chain_inference import predict
183
202
  predict(
184
203
  src_stream=src_file,
204
+ res_embedding_location=res_embedding_location,
185
205
  src_location=SrcLocation.local,
206
+ src_from=src_from,
207
+ structure_location=structure_location,
208
+ structure_format=structure_format,
209
+ min_res_n=min_res_n,
186
210
  batch_size=batch_size,
187
211
  num_workers=num_workers,
188
212
  num_nodes=num_nodes,
@@ -201,7 +225,7 @@ def assembly_embedding(
201
225
  file_okay=True,
202
226
  dir_okay=False,
203
227
  resolve_path=True,
204
- help='CSV file 4 columns: Structure Name | Structure File Path | Assembly Id | Output embedding name.'
228
+ help='CSV file 4 columns: Structure Name | Structure File Path or URL (switch structure-location) | Assembly Id | Output embedding name.'
205
229
  )],
206
230
  res_embedding_location: Annotated[typer.FileText, typer.Option(
207
231
  exists=True,
@@ -221,7 +245,7 @@ def assembly_embedding(
221
245
  help='Use specific assembly or all assemblies in a structure.'
222
246
  )] = SrcAssemblyFrom.assembly,
223
247
  structure_location: Annotated[StructureLocation, typer.Option(
224
- help='Source input location.'
248
+ help='Structure file location.'
225
249
  )] = StructureLocation.local,
226
250
  structure_format: Annotated[StructureFormat, typer.Option(
227
251
  help='Structure file format.'
@@ -43,7 +43,7 @@ class EsmProtFromStructure(EsmProtFromChain):
43
43
  for idx, row in (pd.DataFrame(
44
44
  src_stream,
45
45
  dtype=str,
46
- columns=self.COLUMNS
46
+ columns=EsmProtFromStructure.COLUMNS
47
47
  ) if self.src_location == SrcLocation.stream else pd.read_csv(
48
48
  src_stream,
49
49
  header=None,
@@ -50,7 +50,7 @@ class ResidueAssemblyDatasetFromStructure(ResidueAssemblyEmbeddingFromTensorFile
50
50
  for idx, row in (pd.DataFrame(
51
51
  src_stream,
52
52
  dtype=str,
53
- columns=self.COLUMNS
53
+ columns=ResidueAssemblyDatasetFromStructure.COLUMNS
54
54
  ) if self.src_location == SrcLocation.stream else pd.read_csv(
55
55
  src_stream,
56
56
  header=None,
@@ -60,9 +60,9 @@ class ResidueAssemblyDatasetFromStructure(ResidueAssemblyEmbeddingFromTensorFile
60
60
  )).iterrows():
61
61
  src_name = row[ResidueAssemblyDatasetFromStructure.STREAM_NAME_ATTR]
62
62
  src_structure = row[ResidueAssemblyDatasetFromStructure.STREAM_ATTR]
63
- src_structure = stringio_from_url(src_structure) if self.structure_location == StructureLocation.remote else src_structure
63
+ structure = stringio_from_url(src_structure) if self.structure_location == StructureLocation.remote else src_structure
64
64
  item_name = row[ResidueAssemblyDatasetFromStructure.ITEM_NAME_ATTR]
65
- for assembly_id in get_assemblies(src_structure=src_structure, structure_format=self.structure_format):
65
+ for assembly_id in get_assemblies(structure=structure, structure_format=self.structure_format):
66
66
  assemblies.append((src_name, src_structure, str(assembly_id), f"{item_name}.{assembly_id}"))
67
67
 
68
68
  return tuple(assemblies)
@@ -0,0 +1,66 @@
1
+ import os
2
+
3
+ import pandas as pd
4
+
5
+ from rcsb_embedding_model.dataset.residue_embedding_from_tensor_file import ResidueEmbeddingFromTensorFile
6
+ from rcsb_embedding_model.types.api_types import SrcLocation, StructureLocation, StructureFormat
7
+ from rcsb_embedding_model.utils.data import stringio_from_url
8
+ from rcsb_embedding_model.utils.structure_parser import get_protein_chains
9
+ from rcsb_embedding_model.utils.structure_provider import StructureProvider
10
+
11
+
12
class ResidueEmbeddingFromStructure(ResidueEmbeddingFromTensorFile):
    """Chain-level dataset built by expanding whole structures into their protein chains.

    Each input row has three columns: structure name | structure file path or URL |
    output item-name prefix. Every protein chain returned by ``get_protein_chains``
    (filtered by ``min_res_n``) becomes one dataset item of the form
    ``(<res_embedding_location>/<structure name>.<chain>.pt, <item prefix>.<chain>)``.
    Those pairs are handed to ``ResidueEmbeddingFromTensorFile`` as an in-memory
    stream, so the residue-embedding tensor files must already exist with that
    naming scheme.
    """

    STREAM_NAME_ATTR = 'stream_name'
    STREAM_ATTR = 'stream'
    ITEM_NAME_ATTR = 'item_name'

    COLUMNS = [STREAM_NAME_ATTR, STREAM_ATTR, ITEM_NAME_ATTR]

    def __init__(
            self,
            src_stream,
            res_embedding_location,
            src_location=SrcLocation.local,
            structure_location=StructureLocation.local,
            structure_format=StructureFormat.mmcif,
            min_res_n=0,
            structure_provider=None
    ):
        """
        :param src_stream: CSV file path (``src_location=local``) or a sequence of
            3-tuples (``src_location=stream``) with the columns described on the class.
        :param res_embedding_location: existing directory holding the
            ``<structure name>.<chain>.pt`` residue-embedding files.
        :param src_location: how ``src_stream`` is provided (file or in-memory rows).
        :param structure_location: ``remote`` means the second column is a URL to fetch.
        :param structure_format: format passed to the structure provider.
        :param min_res_n: forwarded to ``get_protein_chains`` to filter short chains.
        :param structure_provider: object exposing
            ``get_structure(src_name=..., src_structure=..., structure_format=...)``.
            When omitted, a new ``StructureProvider`` is created for this dataset.
        :raises ValueError: if ``res_embedding_location`` is None.
        :raises FileNotFoundError: if ``res_embedding_location`` is not an existing directory.
        :raises RuntimeError: if a remote structure cannot be fetched.
        """
        # predict() defaults res_embedding_location to None; fail with a clear
        # message instead of a TypeError from os.path.isdir(None).
        if res_embedding_location is None:
            raise ValueError("res_embedding_location is required when chains are taken from structures")
        if not os.path.isdir(res_embedding_location):
            raise FileNotFoundError(f"Folder {res_embedding_location} does not exist")
        self.res_embedding_location = res_embedding_location
        self.src_location = src_location
        self.structure_location = structure_location
        self.structure_format = structure_format
        self.min_res_n = min_res_n
        # A StructureProvider() default in the signature would be built once at
        # import time and shared by every dataset; create one per instance instead.
        self.__structure_provider = structure_provider if structure_provider is not None else StructureProvider()
        # The attributes above must be set before this call: __get_chains reads them.
        super().__init__(
            src_stream=self.__get_chains(src_stream),
            src_location=SrcLocation.stream
        )

    def __read_rows(self, src_stream):
        """Load the 3-column input (in-memory rows or headerless CSV) as a string DataFrame."""
        if self.src_location == SrcLocation.stream:
            return pd.DataFrame(
                src_stream,
                dtype=str,
                columns=ResidueEmbeddingFromStructure.COLUMNS
            )
        return pd.read_csv(
            src_stream,
            header=None,
            index_col=None,
            dtype=str,
            names=ResidueEmbeddingFromStructure.COLUMNS
        )

    def __open_structure(self, src_name, src_structure):
        """Return the structure source, downloading it first when it is remote."""
        if self.structure_location != StructureLocation.remote:
            return src_structure
        structure = stringio_from_url(src_structure)
        # stringio_from_url reports fetch errors by returning None; surface that
        # here rather than passing None on to the structure parser.
        if structure is None:
            raise RuntimeError(f"Unable to fetch structure {src_name} from {src_structure}")
        return structure

    def __get_chains(self, src_stream):
        """Expand every input structure into (tensor file path, item name) pairs, one per protein chain."""
        chains = []
        for _, row in self.__read_rows(src_stream).iterrows():
            src_name = row[ResidueEmbeddingFromStructure.STREAM_NAME_ATTR]
            src_structure = row[ResidueEmbeddingFromStructure.STREAM_ATTR]
            item_name = row[ResidueEmbeddingFromStructure.ITEM_NAME_ATTR]
            structure = self.__structure_provider.get_structure(
                src_name=src_name,
                src_structure=self.__open_structure(src_name, src_structure),
                structure_format=self.structure_format
            )
            chains.extend(
                (os.path.join(self.res_embedding_location, f"{src_name}.{ch}.pt"), f"{item_name}.{ch}")
                for ch in get_protein_chains(structure, self.min_res_n)
            )
        return tuple(chains)
@@ -2,7 +2,7 @@ import pandas as pd
2
2
  import torch
3
3
  from torch.utils.data import Dataset
4
4
 
5
- from rcsb_embedding_model.types.api_types import StructureLocation, SrcLocation
5
+ from rcsb_embedding_model.types.api_types import SrcLocation
6
6
 
7
7
 
8
8
  class ResidueEmbeddingFromTensorFile(Dataset):
@@ -26,7 +26,7 @@ class ResidueEmbeddingFromTensorFile(Dataset):
26
26
  self.data = pd.DataFrame(
27
27
  src_stream,
28
28
  dtype=str,
29
- columns=self.COLUMNS
29
+ columns=ResidueEmbeddingFromTensorFile.COLUMNS
30
30
  ) if self.src_location == SrcLocation.stream else pd.read_csv(
31
31
  src_stream,
32
32
  header=None,
@@ -1,16 +1,23 @@
1
1
  from torch.utils.data import DataLoader
2
2
  from lightning import Trainer
3
3
 
4
+ from rcsb_embedding_model.dataset.residue_embedding_from_structure import ResidueEmbeddingFromStructure
4
5
  from rcsb_embedding_model.dataset.residue_embedding_from_tensor_file import ResidueEmbeddingFromTensorFile
5
6
  from rcsb_embedding_model.modules.chain_module import ChainModule
6
- from rcsb_embedding_model.types.api_types import Accelerator, Devices, OptionalPath, FileOrStreamTuple, SrcLocation
7
+ from rcsb_embedding_model.types.api_types import Accelerator, Devices, OptionalPath, FileOrStreamTuple, SrcLocation, \
8
+ SrcTensorFrom, StructureLocation, StructureFormat
7
9
  from rcsb_embedding_model.utils.data import collate_seq_embeddings
8
10
  from rcsb_embedding_model.writer.batch_writer import CsvBatchWriter
9
11
 
10
12
 
11
13
  def predict(
12
14
  src_stream: FileOrStreamTuple,
15
+ res_embedding_location: OptionalPath = None,
13
16
  src_location: SrcLocation = SrcLocation.local,
17
+ src_from: SrcTensorFrom = SrcTensorFrom.file,
18
+ structure_location: StructureLocation = StructureLocation.local,
19
+ structure_format: StructureFormat = StructureFormat.mmcif,
20
+ min_res_n: int = 0,
14
21
  batch_size: int = 1,
15
22
  num_workers: int = 0,
16
23
  num_nodes: int = 1,
@@ -24,6 +31,13 @@ def predict(
24
31
  inference_set = ResidueEmbeddingFromTensorFile(
25
32
  src_stream=src_stream,
26
33
  src_location=src_location
34
+ ) if src_from == SrcTensorFrom.file else ResidueEmbeddingFromStructure(
35
+ src_stream=src_stream,
36
+ res_embedding_location=res_embedding_location,
37
+ src_location=src_location,
38
+ structure_location=structure_location,
39
+ structure_format=structure_format,
40
+ min_res_n=min_res_n
27
41
  )
28
42
 
29
43
  inference_dataloader = DataLoader(
@@ -49,3 +49,8 @@ class SrcProteinFrom(str, Enum):
49
49
class SrcAssemblyFrom(str, Enum):
    """Which assemblies to embed: the specific assembly named in the input row, or all assemblies in a structure."""
    assembly = "assembly"
    structure = "structure"


class SrcTensorFrom(str, Enum):
    """How chain-level inference obtains its residue-embedding tensors.

    ``file``: input rows name residue-embedding tensor files directly.
    ``structure``: input rows name structures; each protein chain is mapped to a
    tensor file inside a residue-embedding directory.
    """
    file = "file"
    structure = "structure"
@@ -44,6 +44,7 @@ def stringio_from_url(url):
44
44
  print(f"Error fetching URL: {e}")
45
45
  return None
46
46
 
47
+
47
48
  def concatenate_tensors(file_list, max_residues, dim=0):
48
49
  """
49
50
  Concatenates a list of tensors stored in individual files along a specified dimension.
@@ -37,13 +37,13 @@ def get_protein_chains(structure, min_res_n=0):
37
37
  return tuple(chain_ids)
38
38
 
39
39
 
40
def get_assemblies(structure, structure_format="mmcif"):
    """List the assembly identifiers declared in a structure file.

    :param structure: file path or file-like object accepted by the format's reader.
    :param structure_format: one of ``"pdb"``, ``"mmcif"`` or ``"binarycif"``.
    :return: tuple of assembly identifiers.
    :raises RuntimeError: for any other ``structure_format``.
    """
    if structure_format == "pdb":
        pdb_file = PDBFile.read(structure)
        return tuple(list_pdb_assemblies(pdb_file))
    if structure_format == "mmcif":
        cif_file = CIFFile.read(structure)
        return tuple(list_assemblies(cif_file).keys())
    if structure_format == "binarycif":
        bcif_file = BinaryCIFFile.read(structure)
        return tuple(list_assemblies(bcif_file))
    raise RuntimeError(f"Unknown file format {structure_format}")
49
49
 
@@ -2,7 +2,7 @@ import os
2
2
  import unittest
3
3
 
4
4
  from rcsb_embedding_model.types.api_types import StructureLocation, Accelerator, SrcProteinFrom, SrcLocation, \
5
- StructureFormat, SrcAssemblyFrom
5
+ StructureFormat, SrcAssemblyFrom, SrcTensorFrom
6
6
 
7
7
 
8
8
  class TestInference(unittest.TestCase):
@@ -72,6 +72,29 @@ class TestInference(unittest.TestCase):
72
72
  self.assertEqual(tuple(chain_embeddings[3][0][0].shape), (1536,))
73
73
  self.assertEqual(tuple(chain_embeddings[4][0][0].shape), (1536,))
74
74
 
75
+ def test_chain_inference_from_structure(self):
76
+ from rcsb_embedding_model.inference.chain_inference import predict
77
+ chain_embeddings = predict(
78
+ src_stream=[
79
+ ("1acb", f"{self.__test_path}/resources/pdb/1acb.cif", "1acb"),
80
+ ("2uzi", f"{self.__test_path}/resources/pdb/2uzi.cif", "2uzi"),
81
+ ],
82
+ res_embedding_location=f"{self.__test_path}/resources/embeddings",
83
+ src_location=SrcLocation.stream,
84
+ src_from=SrcTensorFrom.structure,
85
+ structure_location=StructureLocation.local,
86
+ structure_format=StructureFormat.mmcif,
87
+ min_res_n=0,
88
+ accelerator=Accelerator.cpu
89
+ )
90
+
91
+ self.assertEqual(len(chain_embeddings), 5)
92
+ self.assertEqual(tuple(chain_embeddings[0][0][0].shape), (1536,))
93
+ self.assertEqual(tuple(chain_embeddings[1][0][0].shape), (1536,))
94
+ self.assertEqual(tuple(chain_embeddings[2][0][0].shape), (1536,))
95
+ self.assertEqual(tuple(chain_embeddings[3][0][0].shape), (1536,))
96
+ self.assertEqual(tuple(chain_embeddings[4][0][0].shape), (1536,))
97
+
75
98
  def test_structure_inference_from_chain(self):
76
99
  from rcsb_embedding_model.inference.structure_inference import predict
77
100
 
@@ -1,11 +1,14 @@
1
+ import os.path
1
2
  import unittest
2
3
 
3
4
  from rcsb_embedding_model.types.api_types import SrcLocation, SrcProteinFrom, StructureLocation, StructureFormat, \
4
- Accelerator
5
+ Accelerator, SrcAssemblyFrom
5
6
 
6
7
 
7
8
  class TestRemoteInference(unittest.TestCase):
8
9
 
10
+ __test_path = os.path.dirname(__file__)
11
+
9
12
  def test_esm_inference_from_structure(self):
10
13
  from rcsb_embedding_model.inference.esm_inference import predict
11
14
 
@@ -26,3 +29,20 @@ class TestRemoteInference(unittest.TestCase):
26
29
  for idx, shape in enumerate(shapes):
27
30
  self.assertEqual(tuple(esm_embeddings[idx][0][0].shape), shape)
28
31
 
32
+ def test_assembly_inference_from_structure(self):
33
+ from rcsb_embedding_model.inference.assembly_inferece import predict
34
+
35
+ assembly_embeddings = predict(
36
+ src_stream=[
37
+ ("1acb", "https://files.rcsb.org/download/1acb.cif", "1acb"),
38
+ ("2uzi", "https://files.rcsb.org/download/2uzi.cif", "2uzi")
39
+ ],
40
+ res_embedding_location=f"{self.__test_path}/resources/embeddings",
41
+ src_location=SrcLocation.stream,
42
+ src_from=SrcAssemblyFrom.structure,
43
+ structure_location=StructureLocation.remote,
44
+ structure_format=StructureFormat.mmcif,
45
+ accelerator=Accelerator.cpu
46
+ )
47
+
48
+ self.assertEqual(len(assembly_embeddings), 2)