rcsb-embedding-model 0.0.35__py3-none-any.whl → 0.0.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rcsb-embedding-model might be problematic. Click here for more details.

@@ -6,7 +6,7 @@ import typer
6
6
  from rcsb_embedding_model import __version__
7
7
  from rcsb_embedding_model.cli.args_utils import arg_devices
8
8
  from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, SrcLocation, SrcProteinFrom, \
9
- StructureLocation, SrcAssemblyFrom, SrcTensorFrom, OutFormat
9
+ SrcAssemblyFrom, SrcTensorFrom, OutFormat
10
10
  from rcsb_embedding_model.utils.data import adapt_csv_to_embedding_chain_stream
11
11
 
12
12
  import os
@@ -42,9 +42,6 @@ def residue_embedding(
42
42
  output_name: Annotated[str, typer.Option(
43
43
  help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
44
44
  )] = 'inference',
45
- structure_location: Annotated[StructureLocation, typer.Option(
46
- help='Structure file location.'
47
- )] = StructureLocation.local,
48
45
  structure_format: Annotated[StructureFormat, typer.Option(
49
46
  help='Structure file format.'
50
47
  )] = StructureFormat.mmcif,
@@ -72,7 +69,6 @@ def residue_embedding(
72
69
  src_stream=src_file,
73
70
  src_location=SrcLocation.file,
74
71
  src_from=SrcProteinFrom.chain,
75
- structure_location=structure_location,
76
72
  structure_format=structure_format,
77
73
  min_res_n=min_res_n,
78
74
  batch_size=batch_size,
@@ -108,9 +104,6 @@ def structure_embedding(
108
104
  output_name: Annotated[str, typer.Option(
109
105
  help='File name for storing embeddings as a single JSON file.'
110
106
  )] = 'inference',
111
- structure_location: Annotated[StructureLocation, typer.Option(
112
- help='Structure file location.'
113
- )] = StructureLocation.local,
114
107
  structure_format: Annotated[StructureFormat, typer.Option(
115
108
  help='Structure file format.'
116
109
  )] = StructureFormat.mmcif,
@@ -138,7 +131,6 @@ def structure_embedding(
138
131
  src_stream=src_file,
139
132
  src_location=SrcLocation.file,
140
133
  src_from=SrcProteinFrom.chain,
141
- structure_location=structure_location,
142
134
  structure_format=structure_format,
143
135
  min_res_n=min_res_n,
144
136
  batch_size=batch_size,
@@ -183,9 +175,6 @@ def chain_embedding(
183
175
  output_name: Annotated[str, typer.Option(
184
176
  help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
185
177
  )] = 'inference',
186
- structure_location: Annotated[StructureLocation, typer.Option(
187
- help='Structure file location.'
188
- )] = StructureLocation.local,
189
178
  structure_format: Annotated[StructureFormat, typer.Option(
190
179
  help='Structure file format.'
191
180
  )] = StructureFormat.mmcif,
@@ -214,7 +203,6 @@ def chain_embedding(
214
203
  res_embedding_location=res_embedding_location,
215
204
  src_location=SrcLocation.stream,
216
205
  src_from=SrcTensorFrom.file,
217
- structure_location=structure_location,
218
206
  structure_format=structure_format,
219
207
  min_res_n=min_res_n,
220
208
  batch_size=batch_size,
@@ -259,9 +247,6 @@ def assembly_embedding(
259
247
  output_name: Annotated[str, typer.Option(
260
248
  help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
261
249
  )] = 'inference',
262
- structure_location: Annotated[StructureLocation, typer.Option(
263
- help='Structure file location.'
264
- )] = StructureLocation.local,
265
250
  structure_format: Annotated[StructureFormat, typer.Option(
266
251
  help='Structure file format.'
267
252
  )] = StructureFormat.mmcif,
@@ -293,7 +278,6 @@ def assembly_embedding(
293
278
  res_embedding_location=res_embedding_location,
294
279
  src_location=SrcLocation.file,
295
280
  src_from=SrcAssemblyFrom.assembly,
296
- structure_location=structure_location,
297
281
  structure_format=structure_format,
298
282
  min_res_n=min_res_n,
299
283
  max_res_n=max_res_n,
@@ -356,9 +340,6 @@ def complete_embedding(
356
340
  output_assembly_name: Annotated[str, typer.Option(
357
341
  help='File name for storing chain embeddings as a single JSON file. Used when output-format=grouped.'
358
342
  )] = 'chain-inference',
359
- structure_location: Annotated[StructureLocation, typer.Option(
360
- help='Structure file location.'
361
- )] = StructureLocation.local,
362
343
  structure_format: Annotated[StructureFormat, typer.Option(
363
344
  help='Structure file format.'
364
345
  )] = StructureFormat.mmcif,
@@ -397,7 +378,6 @@ def complete_embedding(
397
378
  src_file=src_chain_file,
398
379
  output_path=output_res_path,
399
380
  output_format=OutFormat.separated,
400
- structure_location=structure_location,
401
381
  structure_format=structure_format,
402
382
  min_res_n=min_res_n,
403
383
  batch_size=batch_size_res,
@@ -412,7 +392,6 @@ def complete_embedding(
412
392
  output_format=output_format,
413
393
  output_name=output_chain_name,
414
394
  res_embedding_location=output_res_path,
415
- structure_location=structure_location,
416
395
  structure_format=structure_format,
417
396
  min_res_n=min_res_n,
418
397
  batch_size=batch_size_chain,
@@ -427,7 +406,6 @@ def complete_embedding(
427
406
  output_format=output_format,
428
407
  output_name=output_assembly_name,
429
408
  res_embedding_location=output_res_path,
430
- structure_location=structure_location,
431
409
  structure_format=structure_format,
432
410
  min_res_n=min_res_n,
433
411
  batch_size=batch_size_assembly,
@@ -9,6 +9,7 @@ from esm.utils.structure.protein_chain import ProteinChain
9
9
  from torch.utils.data import Dataset, DataLoader
10
10
  import pandas as pd
11
11
 
12
+ from rcsb_embedding_model.dataset.untils import get_structure_location
12
13
  from rcsb_embedding_model.types.api_types import StructureFormat, StructureLocation, SrcLocation
13
14
  from rcsb_embedding_model.utils.data import stringio_from_url
14
15
  from rcsb_embedding_model.utils.structure_parser import rename_atom_attr,filter_residues
@@ -28,14 +29,12 @@ class EsmProtFromChain(Dataset):
28
29
  self,
29
30
  src_stream,
30
31
  src_location=SrcLocation.file,
31
- structure_location=StructureLocation.local,
32
32
  structure_format=StructureFormat.mmcif,
33
33
  structure_provider=StructureProvider()
34
34
  ):
35
35
  super().__init__()
36
36
  self.__structure_provider = structure_provider
37
37
  self.src_location = src_location
38
- self.structure_location = structure_location
39
38
  self.structure_format = structure_format
40
39
  self.data = pd.DataFrame()
41
40
  self.__load_stream(src_stream)
@@ -65,7 +64,7 @@ class EsmProtFromChain(Dataset):
65
64
  item_name = self.data.iloc[idx][EsmProtFromChain.ITEM_NAME_ATTR]
66
65
  structure = self.__structure_provider.get_structure(
67
66
  src_name=src_name,
68
- src_structure=stringio_from_url(src_structure) if self.structure_location == StructureLocation.remote else src_structure,
67
+ src_structure=stringio_from_url(src_structure) if get_structure_location(src_structure) == StructureLocation.remote else src_structure,
69
68
  structure_format=self.structure_format,
70
69
  chain_id=chain_id
71
70
  )
@@ -96,7 +95,6 @@ if __name__ == '__main__':
96
95
  dataset = EsmProtFromChain(
97
96
  src_stream=args.file_list,
98
97
  src_location=SrcLocation.file,
99
- structure_location=StructureLocation.remote,
100
98
  structure_format=StructureFormat.bciff,
101
99
  )
102
100
 
@@ -2,6 +2,7 @@
2
2
  import pandas as pd
3
3
 
4
4
  from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
5
+ from rcsb_embedding_model.dataset.untils import get_structure_location
5
6
  from rcsb_embedding_model.types.api_types import StructureLocation, StructureFormat, SrcLocation
6
7
  from rcsb_embedding_model.utils.data import stringio_from_url
7
8
  from rcsb_embedding_model.utils.structure_parser import get_protein_chains
@@ -20,20 +21,17 @@ class EsmProtFromStructure(EsmProtFromChain):
20
21
  self,
21
22
  src_stream,
22
23
  src_location=SrcLocation.file,
23
- structure_location=StructureLocation.local,
24
24
  structure_format=StructureFormat.mmcif,
25
25
  min_res_n=0,
26
26
  structure_provider=StructureProvider()
27
27
  ):
28
28
  self.min_res_n = min_res_n
29
29
  self.src_location = src_location
30
- self.structure_location = structure_location
31
30
  self.structure_format = structure_format
32
31
  self.__structure_provider = structure_provider
33
32
  super().__init__(
34
33
  src_stream=self.__get_chains(src_stream),
35
34
  src_location=SrcLocation.stream,
36
- structure_location=structure_location,
37
35
  structure_format=structure_format,
38
36
  structure_provider=structure_provider
39
37
  )
@@ -58,7 +56,7 @@ class EsmProtFromStructure(EsmProtFromChain):
58
56
  item_name = row[EsmProtFromStructure.ITEM_NAME_ATTR]
59
57
  structure = self.__structure_provider.get_structure(
60
58
  src_name=src_name,
61
- src_structure=stringio_from_url(src_structure) if self.structure_location == StructureLocation.remote else src_structure,
59
+ src_structure=stringio_from_url(src_structure) if get_structure_location(src_structure) == StructureLocation.remote else src_structure,
62
60
  structure_format=self.structure_format
63
61
  )
64
62
  for ch in get_protein_chains(structure, self.min_res_n):
@@ -3,6 +3,7 @@ import sys
3
3
  import pandas as pd
4
4
 
5
5
  from rcsb_embedding_model.dataset.residue_assembly_embedding_from_tensor_file import ResidueAssemblyEmbeddingFromTensorFile
6
+ from rcsb_embedding_model.dataset.untils import get_structure_location
6
7
  from rcsb_embedding_model.types.api_types import SrcLocation, StructureLocation, StructureFormat
7
8
  from rcsb_embedding_model.utils.data import stringio_from_url
8
9
  from rcsb_embedding_model.utils.structure_parser import get_assemblies
@@ -22,14 +23,12 @@ class ResidueAssemblyDatasetFromStructure(ResidueAssemblyEmbeddingFromTensorFile
22
23
  src_stream,
23
24
  res_embedding_location,
24
25
  src_location=SrcLocation.file,
25
- structure_location=StructureLocation.local,
26
26
  structure_format=StructureFormat.mmcif,
27
27
  min_res_n=0,
28
28
  max_res_n=sys.maxsize,
29
29
  structure_provider=StructureProvider()
30
30
  ):
31
31
  self.src_location = src_location
32
- self.structure_location = structure_location
33
32
  self.structure_format = structure_format
34
33
  self.min_res_n = min_res_n
35
34
  self.max_res_n = max_res_n
@@ -37,7 +36,6 @@ class ResidueAssemblyDatasetFromStructure(ResidueAssemblyEmbeddingFromTensorFile
37
36
  src_stream=self.__get_assemblies(src_stream),
38
37
  res_embedding_location=res_embedding_location,
39
38
  src_location=SrcLocation.stream,
40
- structure_location=structure_location,
41
39
  structure_format=structure_format,
42
40
  min_res_n=min_res_n,
43
41
  max_res_n=max_res_n,
@@ -61,7 +59,7 @@ class ResidueAssemblyDatasetFromStructure(ResidueAssemblyEmbeddingFromTensorFile
61
59
  for idx, row in data.iterrows():
62
60
  src_name = row[ResidueAssemblyDatasetFromStructure.STREAM_NAME_ATTR]
63
61
  src_structure = row[ResidueAssemblyDatasetFromStructure.STREAM_ATTR]
64
- structure = stringio_from_url(src_structure) if self.structure_location == StructureLocation.remote else src_structure
62
+ structure = stringio_from_url(src_structure) if get_structure_location(src_structure) == StructureLocation.remote else src_structure
65
63
  item_name = row[ResidueAssemblyDatasetFromStructure.ITEM_NAME_ATTR]
66
64
  for assembly_id in get_assemblies(structure=structure, structure_format=self.structure_format):
67
65
  assemblies.append((src_name, src_structure, str(assembly_id), f"{item_name}-{assembly_id}"))
@@ -4,6 +4,7 @@ import sys
4
4
  import pandas as pd
5
5
  from torch.utils.data import Dataset, DataLoader
6
6
 
7
+ from rcsb_embedding_model.dataset.untils import get_structure_location
7
8
  from rcsb_embedding_model.types.api_types import StructureLocation, StructureFormat, SrcLocation
8
9
  from rcsb_embedding_model.utils.data import stringio_from_url, concatenate_tensors
9
10
  from rcsb_embedding_model.utils.structure_parser import get_protein_chains
@@ -24,7 +25,6 @@ class ResidueAssemblyEmbeddingFromTensorFile(Dataset):
24
25
  src_stream,
25
26
  res_embedding_location,
26
27
  src_location=SrcLocation.file,
27
- structure_location=StructureLocation.local,
28
28
  structure_format=StructureFormat.mmcif,
29
29
  min_res_n=0,
30
30
  max_res_n=sys.maxsize,
@@ -33,7 +33,6 @@ class ResidueAssemblyEmbeddingFromTensorFile(Dataset):
33
33
  super().__init__()
34
34
  self.res_embedding_location = res_embedding_location
35
35
  self.src_location = src_location
36
- self.structure_location = structure_location
37
36
  self.structure_format = structure_format
38
37
  self.min_res_n = min_res_n
39
38
  self.max_res_n = max_res_n
@@ -65,7 +64,7 @@ class ResidueAssemblyEmbeddingFromTensorFile(Dataset):
65
64
  item_name = self.data.iloc[idx][ResidueAssemblyEmbeddingFromTensorFile.ITEM_NAME_ATTR]
66
65
  structure = self.__structure_provider.get_structure(
67
66
  src_name=src_name,
68
- src_structure=stringio_from_url(src_structure) if self.structure_location == StructureLocation.remote else src_structure,
67
+ src_structure=stringio_from_url(src_structure) if get_structure_location(src_structure) == StructureLocation.remote else src_structure,
69
68
  structure_format=self.structure_format,
70
69
  assembly_id=assembly_id
71
70
  )
@@ -86,7 +85,6 @@ if __name__ == "__main__":
86
85
  src_stream=args.file_list,
87
86
  res_embedding_location=args.res_embeddings_path,
88
87
  src_location=SrcLocation.file,
89
- structure_location=StructureLocation.remote,
90
88
  structure_format=StructureFormat.bciff
91
89
  )
92
90
 
@@ -3,6 +3,7 @@ import os
3
3
  import pandas as pd
4
4
 
5
5
  from rcsb_embedding_model.dataset.residue_embedding_from_tensor_file import ResidueEmbeddingFromTensorFile
6
+ from rcsb_embedding_model.dataset.untils import get_structure_location
6
7
  from rcsb_embedding_model.types.api_types import SrcLocation, StructureLocation, StructureFormat
7
8
  from rcsb_embedding_model.utils.data import stringio_from_url
8
9
  from rcsb_embedding_model.utils.structure_parser import get_protein_chains
@@ -22,7 +23,6 @@ class ResidueEmbeddingFromStructure(ResidueEmbeddingFromTensorFile):
22
23
  src_stream,
23
24
  res_embedding_location,
24
25
  src_location=SrcLocation.file,
25
- structure_location=StructureLocation.local,
26
26
  structure_format=StructureFormat.mmcif,
27
27
  min_res_n=0,
28
28
  structure_provider=StructureProvider()
@@ -31,7 +31,6 @@ class ResidueEmbeddingFromStructure(ResidueEmbeddingFromTensorFile):
31
31
  raise FileNotFoundError(f"Folder {res_embedding_location} does not exist")
32
32
  self.res_embedding_location = res_embedding_location
33
33
  self.src_location = src_location
34
- self.structure_location = structure_location
35
34
  self.structure_format = structure_format
36
35
  self.min_res_n = min_res_n
37
36
  self.__structure_provider = structure_provider
@@ -60,7 +59,7 @@ class ResidueEmbeddingFromStructure(ResidueEmbeddingFromTensorFile):
60
59
  item_name = row[ResidueEmbeddingFromStructure.ITEM_NAME_ATTR]
61
60
  structure = self.__structure_provider.get_structure(
62
61
  src_name=src_name,
63
- src_structure=stringio_from_url(src_structure) if self.structure_location == StructureLocation.remote else src_structure,
62
+ src_structure=stringio_from_url(src_structure) if get_structure_location(src_structure) == StructureLocation.remote else src_structure,
64
63
  structure_format=self.structure_format
65
64
  )
66
65
  for ch in get_protein_chains(structure, self.min_res_n):
@@ -0,0 +1,4 @@
1
+
2
+ from rcsb_embedding_model.dataset.untils.utils import get_structure_location
3
+
4
+ __all__ = ["get_structure_location"]
@@ -0,0 +1,17 @@
1
+ import os
2
+ from urllib.parse import urlparse
3
+ from rcsb_embedding_model.types.api_types import StructureLocation
4
+
5
+
6
+ def get_structure_location(s: str) -> str:
7
+ # First, attempt to parse as URL
8
+ parsed = urlparse(s)
9
+ if parsed.scheme.lower() in {'http', 'https', 'ftp'} and parsed.netloc:
10
+ return StructureLocation.remote
11
+
12
+ # Next, test for an existing file or directory
13
+ if os.path.exists(s):
14
+ return StructureLocation.local
15
+
16
+ # Neither URL nor existing file
17
+ raise ValueError(f"Structure file source is neither a recognized URL nor file: {s!r}")
@@ -3,7 +3,7 @@ import sys
3
3
  from rcsb_embedding_model.dataset.resdiue_assembly_embedding_from_structure import ResidueAssemblyDatasetFromStructure
4
4
  from rcsb_embedding_model.dataset.residue_assembly_embedding_from_tensor_file import ResidueAssemblyEmbeddingFromTensorFile
5
5
  from rcsb_embedding_model.types.api_types import FileOrStreamTuple, SrcLocation, Accelerator, Devices, OptionalPath, \
6
- EmbeddingPath, StructureLocation, StructureFormat, SrcAssemblyFrom, OutFormat
6
+ EmbeddingPath, StructureFormat, SrcAssemblyFrom, OutFormat
7
7
  from rcsb_embedding_model.inference.chain_inference import predict as chain_predict
8
8
 
9
9
 
@@ -12,7 +12,6 @@ def predict(
12
12
  res_embedding_location: EmbeddingPath,
13
13
  src_location: SrcLocation = SrcLocation.file,
14
14
  src_from: SrcAssemblyFrom = SrcAssemblyFrom.assembly,
15
- structure_location: StructureLocation = StructureLocation.local,
16
15
  structure_format: StructureFormat = StructureFormat.mmcif,
17
16
  min_res_n: int = 0,
18
17
  max_res_n: int = sys.maxsize,
@@ -29,7 +28,6 @@ def predict(
29
28
  src_stream=src_stream,
30
29
  res_embedding_location=res_embedding_location,
31
30
  src_location=src_location,
32
- structure_location=structure_location,
33
31
  structure_format=structure_format,
34
32
  min_res_n=min_res_n,
35
33
  max_res_n=max_res_n
@@ -37,7 +35,6 @@ def predict(
37
35
  src_stream=src_stream,
38
36
  res_embedding_location=res_embedding_location,
39
37
  src_location=src_location,
40
- structure_location=structure_location,
41
38
  structure_format=structure_format,
42
39
  min_res_n=min_res_n,
43
40
  max_res_n=max_res_n
@@ -6,7 +6,7 @@ from rcsb_embedding_model.dataset.residue_embedding_from_structure import Residu
6
6
  from rcsb_embedding_model.dataset.residue_embedding_from_tensor_file import ResidueEmbeddingFromTensorFile
7
7
  from rcsb_embedding_model.modules.chain_module import ChainModule
8
8
  from rcsb_embedding_model.types.api_types import Accelerator, Devices, OptionalPath, FileOrStreamTuple, SrcLocation, \
9
- SrcTensorFrom, StructureLocation, StructureFormat, OutFormat
9
+ SrcTensorFrom, StructureFormat, OutFormat
10
10
  from rcsb_embedding_model.utils.data import collate_seq_embeddings
11
11
  from rcsb_embedding_model.utils.model import get_aggregator_model
12
12
  from rcsb_embedding_model.writer.batch_writer import CsvBatchWriter, JsonStorage
@@ -17,7 +17,6 @@ def predict(
17
17
  res_embedding_location: OptionalPath = None,
18
18
  src_location: SrcLocation = SrcLocation.file,
19
19
  src_from: SrcTensorFrom = SrcTensorFrom.file,
20
- structure_location: StructureLocation = StructureLocation.local,
21
20
  structure_format: StructureFormat = StructureFormat.mmcif,
22
21
  min_res_n: int = 0,
23
22
  batch_size: int = 1,
@@ -39,7 +38,6 @@ def predict(
39
38
  src_stream=src_stream,
40
39
  res_embedding_location=res_embedding_location,
41
40
  src_location=src_location,
42
- structure_location=structure_location,
43
41
  structure_format=structure_format,
44
42
  min_res_n=min_res_n
45
43
  )
@@ -5,7 +5,7 @@ from lightning import Trainer
5
5
  from rcsb_embedding_model.dataset.esm_prot_from_structure import EsmProtFromStructure
6
6
  from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
7
7
  from rcsb_embedding_model.modules.esm_module import EsmModule
8
- from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, \
8
+ from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, \
9
9
  SrcProteinFrom, FileOrStreamTuple, SrcLocation, OutFormat
10
10
  from rcsb_embedding_model.utils.model import get_residue_model
11
11
  from rcsb_embedding_model.writer.batch_writer import TensorBatchWriter, JsonStorage
@@ -15,7 +15,6 @@ def predict(
15
15
  src_stream: FileOrStreamTuple,
16
16
  src_location: SrcLocation = SrcLocation.file,
17
17
  src_from: SrcProteinFrom = SrcProteinFrom.chain,
18
- structure_location: StructureLocation = StructureLocation.local,
19
18
  structure_format: StructureFormat = StructureFormat.mmcif,
20
19
  min_res_n: int = 0,
21
20
  batch_size: int = 1,
@@ -31,12 +30,10 @@ def predict(
31
30
  inference_set = EsmProtFromChain(
32
31
  src_stream=src_stream,
33
32
  src_location=src_location,
34
- structure_location=structure_location,
35
33
  structure_format=structure_format
36
34
  ) if src_from == SrcProteinFrom.chain else EsmProtFromStructure(
37
35
  src_stream=src_stream,
38
36
  src_location=src_location,
39
- structure_location=structure_location,
40
37
  structure_format=structure_format,
41
38
  min_res_n=min_res_n
42
39
  )
@@ -5,7 +5,7 @@ from lightning import Trainer
5
5
  from rcsb_embedding_model.dataset.esm_prot_from_structure import EsmProtFromStructure
6
6
  from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
7
7
  from rcsb_embedding_model.modules.structure_module import StructureModule
8
- from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, \
8
+ from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, \
9
9
  SrcProteinFrom, FileOrStreamTuple, SrcLocation
10
10
  from rcsb_embedding_model.utils.model import get_residue_model, get_aggregator_model
11
11
  from rcsb_embedding_model.writer.batch_writer import JsonStorage
@@ -15,7 +15,6 @@ def predict(
15
15
  src_stream: FileOrStreamTuple,
16
16
  src_location: SrcLocation = SrcLocation.file,
17
17
  src_from: SrcProteinFrom = SrcProteinFrom.chain,
18
- structure_location: StructureLocation = StructureLocation.local,
19
18
  structure_format: StructureFormat = StructureFormat.mmcif,
20
19
  min_res_n: int = 0,
21
20
  batch_size: int = 1,
@@ -30,12 +29,10 @@ def predict(
30
29
  inference_set = EsmProtFromChain(
31
30
  src_stream=src_stream,
32
31
  src_location=src_location,
33
- structure_location=structure_location,
34
32
  structure_format=structure_format
35
33
  ) if src_from == SrcProteinFrom.chain else EsmProtFromStructure(
36
34
  src_stream=src_stream,
37
35
  src_location=src_location,
38
- structure_location=structure_location,
39
36
  structure_format=structure_format,
40
37
  min_res_n=min_res_n
41
38
  )
@@ -95,4 +95,4 @@ def adapt_csv_to_embedding_chain_stream(src_file, res_embedding_location):
95
95
  def __parse_row(row):
96
96
  r = row.split(",")
97
97
  return os.path.join(res_embedding_location, f"{r[0]}.{r[2]}.pt"), f"{r[0]}.{r[2]}"
98
- return tuple([__parse_row(r.strip()) for r in open(src_file)])
98
+ return tuple([__parse_row(r.strip()) for r in open(src_file) if len(r.split(",")) > 2])
@@ -0,0 +1,65 @@
1
+ from pathlib import Path
2
+ import torch
3
+
4
+ from esm.models.esm3 import ESM3
5
+ from esm.models.vqvae import StructureTokenEncoder
6
+ from esm.tokenization import TokenizerCollection, EsmSequenceTokenizer, StructureTokenizer, SecondaryStructureTokenizer, \
7
+ SASADiscretizingTokenizer, InterProQuantizedTokenizer, ResidueAnnotationsTokenizer
8
+
9
+ from huggingface_hub import snapshot_download
10
+
11
def data_root():
    """Return the ``rcsb/rcsb-esm`` Hugging Face snapshot directory as a ``Path``.

    The directory is obtained from ``huggingface_hub.snapshot_download``; the
    weight files and tokenizer data used by this module are resolved relative
    to it.
    """
    return Path(snapshot_download(repo_id="rcsb/rcsb-esm"))
14
+
15
+
16
def structure_encoder(device: torch.device | str | None = "cpu"):
    """Build the ESM3 structure-token encoder and load its pretrained weights.

    Weights are read from ``data/weights/esm3_structure_encoder_v0.pth`` inside
    the snapshot directory returned by ``data_root()``.

    Args:
        device: Device on which the model is created and to which the weights
            are mapped. ``None`` is treated as ``"cpu"``.

    Returns:
        The ``StructureTokenEncoder`` in eval mode with weights loaded.
    """
    # torch.device(None) is rejected by PyTorch, and the previous loader
    # (ESM3.from_pretrained) accepted device=None, so normalise it to CPU.
    if device is None:
        device = "cpu"
    with torch.device(device):
        model = StructureTokenEncoder(
            d_model=1024, n_heads=1, v_heads=128, n_layers=2, d_out=128, n_codes=4096
        ).eval()
    # The checkpoint is a plain state_dict downloaded from the network;
    # weights_only=True refuses arbitrary pickled objects.
    state_dict = torch.load(
        data_root() / "data/weights/esm3_structure_encoder_v0.pth",
        map_location=device,
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    return model
26
+
27
+
28
+
29
def get_model_tokenizers():
    """Assemble the ``TokenizerCollection`` used to build the ESM3 model.

    The collection contains the sequence, structure, secondary-structure (ss8),
    SASA and InterPro function tokenizers from ``esm.tokenization``. It also
    includes a residue-annotation tokenizer whose CSV is resolved against
    ``data_root()`` rather than the tokenizer's own default location.
    """

    class CustomAnnotationsTokenizer(ResidueAnnotationsTokenizer):
        """``ResidueAnnotationsTokenizer`` whose CSV defaults to ``data_root() / C.RESID_CSV``."""

        def __init__(self, csv_path: str | None = None, max_annotations: int = 16):
            from esm.utils.constants import esm3 as C
            # NOTE(review): "none" is presumably a placeholder so the parent
            # constructor does not resolve its own default CSV; the real path
            # is assigned right after.
            super().__init__("none", max_annotations)
            self.csv_path = (
                csv_path if csv_path is not None else str(data_root() / C.RESID_CSV)
            )

    # Built in the same order as the TokenizerCollection fields.
    sequence_tok = EsmSequenceTokenizer()
    structure_tok = StructureTokenizer()
    ss8_tok = SecondaryStructureTokenizer(kind="ss8")
    sasa_tok = SASADiscretizingTokenizer()
    function_tok = InterProQuantizedTokenizer()
    annotations_tok = CustomAnnotationsTokenizer()

    return TokenizerCollection(
        sequence=sequence_tok,
        structure=structure_tok,
        secondary_structure=ss8_tok,
        sasa=sasa_tok,
        function=function_tok,
        residue_annotations=annotations_tok,
    )
47
+
48
+
49
def esm_open(device: torch.device | str | None = "cpu"):
    """Build the ESM3 small open model and load its weights from the RCSB snapshot.

    Weights are read from ``data/weights/esm3_sm_open_v1.pth`` inside the
    snapshot directory returned by ``data_root()``.

    Args:
        device: Device on which the model is created and to which the weights
            are mapped. ``None`` is treated as ``"cpu"``, matching the previous
            ``ESM3.from_pretrained(ESM3_OPEN_SMALL, device)`` call path, which
            ``get_residue_model(device=None)`` relies on.

    Returns:
        The ``ESM3`` model in eval mode with weights loaded.
    """
    # torch.device(None) is rejected by PyTorch; keep accepting None as before.
    if device is None:
        device = "cpu"
    # NOTE(review): ESM3.from_pretrained (used before this loader) is believed
    # to cast the model to bfloat16 on non-CPU devices; this loader keeps the
    # checkpoint dtype. Confirm this is intended.
    with torch.device(device):
        model = ESM3(
            d_model=1536,
            n_heads=24,
            v_heads=256,
            n_layers=48,
            # Factory for the structure-token encoder (presumably invoked lazily
            # by ESM3 with the model's device; confirm against esm's ESM3 class).
            structure_encoder_fn=structure_encoder,
            # Identity placeholders: no decoder weights are loaded here, so
            # structure/function decoding is presumably unsupported by this model.
            structure_decoder_fn=lambda x: x,
            function_decoder_fn=lambda x: x,
            tokenizers=get_model_tokenizers(),
        ).eval()
    # The checkpoint is a plain state_dict downloaded from the network;
    # weights_only=True refuses arbitrary pickled objects.
    state_dict = torch.load(
        data_root() / "data/weights/esm3_sm_open_v1.pth",
        map_location=device,
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    return model
@@ -1,9 +1,8 @@
1
1
  import torch
2
- from esm.models.esm3 import ESM3
3
- from esm.utils.constants.models import ESM3_OPEN_SMALL
4
- from huggingface_hub import hf_hub_download
5
2
 
3
+ from huggingface_hub import hf_hub_download
6
4
  from rcsb_embedding_model.model.residue_embedding_aggregator import ResidueEmbeddingAggregator
5
+ from rcsb_embedding_model.utils.esm.loaders import esm_open
7
6
 
8
7
  REPO_ID = "rcsb/rcsb-embedding-model"
9
8
  FILE_NAME = "rcsb-embedding-model.pt"
@@ -25,7 +24,5 @@ def get_aggregator_model(device=None):
25
24
 
26
25
 
27
26
  def get_residue_model(device=None):
28
- return ESM3.from_pretrained(
29
- ESM3_OPEN_SMALL,
30
- device
31
- )
27
+ return esm_open(device)
28
+
@@ -1,16 +1,23 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rcsb-embedding-model
3
- Version: 0.0.35
3
+ Version: 0.0.37
4
4
  Summary: Protein Embedding Model for Structure Search
5
5
  Project-URL: Homepage, https://github.com/rcsb/rcsb-embedding-model
6
6
  Project-URL: Issues, https://github.com/rcsb/rcsb-embedding-model/issues
7
7
  Author-email: Joan Segura <joan.segura@rcsb.org>
8
- License-Expression: BSD-3-Clause
8
+ License: # Cambrian Non-Commercial License Agreement
9
+
10
+ This project is licensed under the EvolutionaryScale Cambrian Non-Commercial License Agreement.
11
+ See: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
9
12
  License-File: LICENSE.md
10
13
  Classifier: Operating System :: OS Independent
11
14
  Classifier: Programming Language :: Python :: 3
12
- Requires-Python: >=3.10
15
+ Requires-Python: >=3.11
16
+ Requires-Dist: biotite>=1.5.0
13
17
  Requires-Dist: esm>=3.2.0
18
+ Requires-Dist: hf-xet>=1.1.10
19
+ Requires-Dist: httpx>=0.28.1
20
+ Requires-Dist: huggingface-hub>=0.30.2
14
21
  Requires-Dist: importlib-metadata>=8.7.0
15
22
  Requires-Dist: lightning>=2.5.0
16
23
  Requires-Dist: typer>=0.15.0
@@ -18,7 +25,7 @@ Description-Content-Type: text/markdown
18
25
 
19
26
  # RCSB Embedding Model
20
27
 
21
- **Version** 0.0.26
28
+ **Version** 0.0.37
22
29
 
23
30
 
24
31
  ## Overview
@@ -125,4 +132,5 @@ Segura, J., Bittrich, S., et al. (2024). *Multi-scale structural similarity embe
125
132
 
126
133
  ## License
127
134
 
128
- This project is licensed under the BSD 3-Clause License. See [LICENSE.md](LICENSE.md) for details.
135
+ This project uses the EvolutionaryScale ESM-3 model and is distributed under the
136
+ [Cambrian Non-Commercial License Agreement](https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement).
@@ -0,0 +1,33 @@
1
+ rcsb_embedding_model/__init__.py,sha256=7YfYO-V-u__19eAZfQ3t5Gf2qrhd_gwQB8rHO0J0puw,306
2
+ rcsb_embedding_model/rcsb_structure_embedding.py,sha256=dKp9hXQO0JAnO4SEfjJ_mG_jHu3UxAPguv6jkOjp-BI,4487
3
+ rcsb_embedding_model/cli/args_utils.py,sha256=7nP2q8pL5dWK_U7opxtWmoFcYVwasky6elHk-dASFaI,165
4
+ rcsb_embedding_model/cli/inference.py,sha256=cXYaais4A3rVAkiucMdJxrYVxezKti8hL3DogBU0_2c,18788
5
+ rcsb_embedding_model/dataset/esm_prot_from_chain.py,sha256=_DYWLDEc492nhUdFRAQjwh0romF9iMwydFNi43-r0TY,4345
6
+ rcsb_embedding_model/dataset/esm_prot_from_structure.py,sha256=VU9BxNUApZ-pus_vmFGEU4eplcCH0fO7KBdic6X_NOM,2546
7
+ rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py,sha256=9iO7ZUcxl0TIBiwNieqjZFfnM7-7V3pl5abYiLzIY0I,2794
8
+ rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py,sha256=6bMjb0hfNbrTOqstnUVHbegw0xeUo7s6INnRsvP7V3I,3663
9
+ rcsb_embedding_model/dataset/residue_embedding_from_structure.py,sha256=tFHiXqGceZjAoYfVkeXG3sa2mz0gd5XBfm9EpJswcWI,2830
10
+ rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py,sha256=4OPaw55yGKHjY2iPpCnemcfwfmTZ4j5VrGQ2oIMQw6A,1343
11
+ rcsb_embedding_model/dataset/untils/__init__.py,sha256=O3WOukwvaKJvHUTALD3eYNHRacJo8o5BW7-ZulLZ65g,116
12
+ rcsb_embedding_model/dataset/untils/utils.py,sha256=SPiQ9aO2WLictO4R2JiNlo2ChhlANNMeIhbN0kq11kQ,578
13
+ rcsb_embedding_model/inference/assembly_inferece.py,sha256=b-mAfOJOO-s6gilOedZpaM90OTbhm_RQVqh2zKFG4dQ,2143
14
+ rcsb_embedding_model/inference/chain_inference.py,sha256=0HkV4EnLwg4ttQhf-xwOuSksZwEYDEChnHU4_A0xUXM,2782
15
+ rcsb_embedding_model/inference/esm_inference.py,sha256=nmHJYfSGjEqRPgb3l9s5fqtlyzdbAsiPz-OxHXBTgcI,2360
16
+ rcsb_embedding_model/inference/structure_inference.py,sha256=b44mY7VcCbjbtB35Mi9EhZoM18yyMaF579MKmzwB564,2405
17
+ rcsb_embedding_model/model/layers.py,sha256=lhKaWC4gTS_T5lHOP0mgnnP8nKTPEOm4MrjhESA4hE8,743
18
+ rcsb_embedding_model/model/residue_embedding_aggregator.py,sha256=k3UW63Ax8DtjCMdD3O5xNxtyAu28l2n3-Ab6nS0atm0,1967
19
+ rcsb_embedding_model/modules/chain_module.py,sha256=KsZw2uagO4rpAKWv6ivqEMxIEzgtfQFliHV_vX8kqtc,435
20
+ rcsb_embedding_model/modules/esm_module.py,sha256=otJRbCb319nCCob_4E1W_UClhkex9eDqcCyzWQO-vIs,740
21
+ rcsb_embedding_model/modules/structure_module.py,sha256=4js02XzKvhc_G26ELsGhJ9SCi_wlvtVolObxfWt3BhE,1077
22
+ rcsb_embedding_model/types/api_types.py,sha256=SCwALwvEb0KRKaoWKbuN7JyfOH-1whsI0Z4ki41dht8,1235
23
+ rcsb_embedding_model/utils/data.py,sha256=p7sbskLPBFtpZ-XM18wFY5Kei02Xso4wTWYTqHxJvVw,3841
24
+ rcsb_embedding_model/utils/model.py,sha256=Xi6bSUsB2-IsQS9610gXnbAEvYlK2V7eJC-cDE-JBTA,875
25
+ rcsb_embedding_model/utils/structure_parser.py,sha256=fSIbq_a_aEigCWY_1dUcW9d9Law0ZDOcZAxJlZL0Rt8,3377
26
+ rcsb_embedding_model/utils/structure_provider.py,sha256=eWtxjkPpmRfmil_DKR1J6miaXR3lQ28DF5O0qrqSgGA,786
27
+ rcsb_embedding_model/utils/esm/loaders.py,sha256=V7CADr7RReoztYmBQb2tjA8RBQIwFEjxBcocKAB_ea4,2221
28
+ rcsb_embedding_model/writer/batch_writer.py,sha256=rTFNasB0Xp4-XCNTXKeEWZxSrb7lvZytoRldJUWn9Jg,3312
29
+ rcsb_embedding_model-0.0.37.dist-info/METADATA,sha256=s_as4M_J_P6Pkcca8eWsRSx7FHBAV1Z-PMI_rnhFZ0A,5820
30
+ rcsb_embedding_model-0.0.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
31
+ rcsb_embedding_model-0.0.37.dist-info/entry_points.txt,sha256=MK11jTIEmaV-x4CkPX5IymDaVs7Ky_f2xxU8BJVZ_9Q,69
32
+ rcsb_embedding_model-0.0.37.dist-info/licenses/LICENSE.md,sha256=XyzxQe9PLJQlOmOOrqwmBaAfo0PAenOQ5NsgnApuVH4,230
33
+ rcsb_embedding_model-0.0.37.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ # Cambrian Non-Commercial License Agreement
2
+
3
+ This project is licensed under the EvolutionaryScale Cambrian Non-Commercial License Agreement.
4
+ See: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
@@ -1,30 +0,0 @@
1
- rcsb_embedding_model/__init__.py,sha256=7YfYO-V-u__19eAZfQ3t5Gf2qrhd_gwQB8rHO0J0puw,306
2
- rcsb_embedding_model/rcsb_structure_embedding.py,sha256=dKp9hXQO0JAnO4SEfjJ_mG_jHu3UxAPguv6jkOjp-BI,4487
3
- rcsb_embedding_model/cli/args_utils.py,sha256=7nP2q8pL5dWK_U7opxtWmoFcYVwasky6elHk-dASFaI,165
4
- rcsb_embedding_model/cli/inference.py,sha256=67_Tr3LWeA3T4KS5mkjq6tw77Ypy0R8IwMxEG2FwVqQ,19901
5
- rcsb_embedding_model/dataset/esm_prot_from_chain.py,sha256=sLbBapgchxciq4RgwHkw9yoNokGlOv2Z5PSaiWV5G64,4418
6
- rcsb_embedding_model/dataset/esm_prot_from_structure.py,sha256=3HzXCCc-UqmZNbJaeXHyUsSIZZxMc2erbxAPGIxSmfE,2621
7
- rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py,sha256=69h1VkrIXesHZi1cG3BOMMytSDeRzcBBP0_Z3Xz3dM8,2869
8
- rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py,sha256=Hd9oH-IVgY6d7Dxy5VfiwHvSaK-Wwhk6ccUBgOwl0TU,3740
9
- rcsb_embedding_model/dataset/residue_embedding_from_structure.py,sha256=1jmeEcCK41cAi2ZnqQkd667NWCAIGS3k6jGDF-WxtTk,2854
10
- rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py,sha256=4OPaw55yGKHjY2iPpCnemcfwfmTZ4j5VrGQ2oIMQw6A,1343
11
- rcsb_embedding_model/inference/assembly_inferece.py,sha256=8fPJjEXy1WsM5XB5U7KfdO5-Du6nEsawsaAjmWoXA9I,2329
12
- rcsb_embedding_model/inference/chain_inference.py,sha256=6f5wVzjtRtHU3BPMTe5k3nH_Nl440Am8BL8h1vmK1jI,2925
13
- rcsb_embedding_model/inference/esm_inference.py,sha256=rn6H43D8BYzMZbMu7UPsLYg2dgERmmpci5weNItrG5Q,2546
14
- rcsb_embedding_model/inference/structure_inference.py,sha256=0wqCW5wee_UQ8WJo9KG6SBHmosdNRzoJYEm7rMn4veA,2591
15
- rcsb_embedding_model/model/layers.py,sha256=lhKaWC4gTS_T5lHOP0mgnnP8nKTPEOm4MrjhESA4hE8,743
16
- rcsb_embedding_model/model/residue_embedding_aggregator.py,sha256=k3UW63Ax8DtjCMdD3O5xNxtyAu28l2n3-Ab6nS0atm0,1967
17
- rcsb_embedding_model/modules/chain_module.py,sha256=KsZw2uagO4rpAKWv6ivqEMxIEzgtfQFliHV_vX8kqtc,435
18
- rcsb_embedding_model/modules/esm_module.py,sha256=otJRbCb319nCCob_4E1W_UClhkex9eDqcCyzWQO-vIs,740
19
- rcsb_embedding_model/modules/structure_module.py,sha256=4js02XzKvhc_G26ELsGhJ9SCi_wlvtVolObxfWt3BhE,1077
20
- rcsb_embedding_model/types/api_types.py,sha256=SCwALwvEb0KRKaoWKbuN7JyfOH-1whsI0Z4ki41dht8,1235
21
- rcsb_embedding_model/utils/data.py,sha256=ThrcYycIizsV_Ycn6PPxF12JRr1m2K-v8TsIaVqx10A,3816
22
- rcsb_embedding_model/utils/model.py,sha256=xr3p02ohOgJ5UInwdIupN68Oq4yvNFhxobZRacS1adg,953
23
- rcsb_embedding_model/utils/structure_parser.py,sha256=fSIbq_a_aEigCWY_1dUcW9d9Law0ZDOcZAxJlZL0Rt8,3377
24
- rcsb_embedding_model/utils/structure_provider.py,sha256=eWtxjkPpmRfmil_DKR1J6miaXR3lQ28DF5O0qrqSgGA,786
25
- rcsb_embedding_model/writer/batch_writer.py,sha256=rTFNasB0Xp4-XCNTXKeEWZxSrb7lvZytoRldJUWn9Jg,3312
26
- rcsb_embedding_model-0.0.35.dist-info/METADATA,sha256=h5uREe5bIKpY4o-ZUzXF9tObRb_eewm9GJ44vgijdig,5351
27
- rcsb_embedding_model-0.0.35.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
28
- rcsb_embedding_model-0.0.35.dist-info/entry_points.txt,sha256=MK11jTIEmaV-x4CkPX5IymDaVs7Ky_f2xxU8BJVZ_9Q,69
29
- rcsb_embedding_model-0.0.35.dist-info/licenses/LICENSE.md,sha256=oUaHiKgfBkChth_Sm67WemEvatO1U0Go8LHjaskXY0w,1522
30
- rcsb_embedding_model-0.0.35.dist-info/RECORD,,
@@ -1,28 +0,0 @@
1
- BSD 3-Clause License
2
-
3
- Copyright (c) 2024, RCSB Protein Data Bank, UC San Diego
4
-
5
- Redistribution and use in source and binary forms, with or without
6
- modification, are permitted provided that the following conditions are met:
7
-
8
- 1. Redistributions of source code must retain the above copyright notice, this
9
- list of conditions and the following disclaimer.
10
-
11
- 2. Redistributions in binary form must reproduce the above copyright notice,
12
- this list of conditions and the following disclaimer in the documentation
13
- and/or other materials provided with the distribution.
14
-
15
- 3. Neither the name of the copyright holder nor the names of its
16
- contributors may be used to endorse or promote products derived from
17
- this software without specific prior written permission.
18
-
19
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.