rcsb-embedding-model 0.0.17-py3-none-any.whl → 0.0.18-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rcsb-embedding-model might be problematic; see the package's registry page for more details.

@@ -6,6 +6,7 @@ import typer
6
6
  from rcsb_embedding_model.cli.args_utils import arg_devices
7
7
  from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, SrcLocation, SrcProteinFrom, \
8
8
  StructureLocation, SrcAssemblyFrom, SrcTensorFrom, OutFormat
9
+ from rcsb_embedding_model.utils.data import adapt_csv_to_embedding_chain_stream
9
10
 
10
11
  app = typer.Typer(
11
12
  add_completion=False
@@ -22,7 +23,7 @@ def residue_embedding(
22
23
  file_okay=True,
23
24
  dir_okay=False,
24
25
  resolve_path=True,
25
- help='CSV file 4 (or 3) columns: Structure Name | Structure File Path or URL (switch structure-location) | Chain Id (asym_i for cif files. This field is required if src-from=chain) | Output Embedding Name.'
26
+ help='CSV file 4 columns: Structure Name | Structure File Path or URL (switch structure-location) | Chain Id (asym_i for cif files) | Output Embedding Name.'
26
27
  )],
27
28
  output_path: Annotated[typer.FileText, typer.Option(
28
29
  exists=True,
@@ -37,9 +38,6 @@ def residue_embedding(
37
38
  output_name: Annotated[str, typer.Option(
38
39
  help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
39
40
  )] = 'inference',
40
- src_from: Annotated[SrcProteinFrom, typer.Option(
41
- help='Use specific chains or all chains in a structure.'
42
- )] = SrcProteinFrom.chain,
43
41
  structure_location: Annotated[StructureLocation, typer.Option(
44
42
  help='Structure file location.'
45
43
  )] = StructureLocation.local,
@@ -69,7 +67,7 @@ def residue_embedding(
69
67
  predict(
70
68
  src_stream=src_file,
71
69
  src_location=SrcLocation.file,
72
- src_from=src_from,
70
+ src_from=SrcProteinFrom.chain,
73
71
  structure_location=structure_location,
74
72
  structure_format=structure_format,
75
73
  min_res_n=min_res_n,
@@ -94,7 +92,7 @@ def structure_embedding(
94
92
  file_okay=True,
95
93
  dir_okay=False,
96
94
  resolve_path=True,
97
- help='CSV file 4 (or 3) columns: Structure Name | Structure File Path or URL (switch structure-location) | Chain Id (asym_i for cif files. This field is required if src-from=chain) | Output Embedding Name.'
95
+ help='CSV file 4 columns: Structure Name | Structure File Path or URL (switch structure-location) | Chain Id (asym_i for cif files) | Output Embedding Name.'
98
96
  )],
99
97
  output_path: Annotated[typer.FileText, typer.Option(
100
98
  exists=True,
@@ -106,9 +104,6 @@ def structure_embedding(
106
104
  output_name: Annotated[str, typer.Option(
107
105
  help='File name for storing embeddings as a single JSON file.'
108
106
  )] = 'inference',
109
- src_from: Annotated[SrcProteinFrom, typer.Option(
110
- help='Use specific chains or all chains in a structure.'
111
- )] = SrcProteinFrom.chain,
112
107
  structure_location: Annotated[StructureLocation, typer.Option(
113
108
  help='Structure file location.'
114
109
  )] = StructureLocation.local,
@@ -138,7 +133,7 @@ def structure_embedding(
138
133
  predict(
139
134
  src_stream=src_file,
140
135
  src_location=SrcLocation.file,
141
- src_from=src_from,
136
+ src_from=SrcProteinFrom.chain,
142
137
  structure_location=structure_location,
143
138
  structure_format=structure_format,
144
139
  min_res_n=min_res_n,
@@ -162,7 +157,7 @@ def chain_embedding(
162
157
  file_okay=True,
163
158
  dir_okay=False,
164
159
  resolve_path=True,
165
- help='Option 1 (src-from=file) - CSV file 2 columns: Residue Embedding Torch Tensor File | Output Embedding Name. Option 2 (src-from=structure) - CSV file 3 columns: Structure Name | Structure File Path or URL (switch structure-location) | Output Embedding Name.'
160
+ help='CSV file 4 columns: Structure Name | Structure File Path or URL (switch structure-location) | Chain Id (asym_i for cif files) | Output Embedding Name.'
166
161
  )],
167
162
  output_path: Annotated[typer.FileText, typer.Option(
168
163
  exists=True,
@@ -171,22 +166,19 @@ def chain_embedding(
171
166
  resolve_path=True,
172
167
  help='Output path to store predictions. Embeddings are stored as csv files.'
173
168
  )],
169
+ res_embedding_location: Annotated[typer.FileText, typer.Option(
170
+ exists=True,
171
+ file_okay=False,
172
+ dir_okay=True,
173
+ resolve_path=True,
174
+ help='Path where residue level embeddings are located.'
175
+ )],
174
176
  output_format: Annotated[OutFormat, typer.Option(
175
177
  help='Format of the output. Options: separated (predictions are stored in single files) or grouped (predictions are stored in a single JSON file).'
176
178
  )] = OutFormat.separated,
177
179
  output_name: Annotated[str, typer.Option(
178
180
  help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
179
181
  )] = 'inference',
180
- res_embedding_location: Annotated[typer.FileText, typer.Option(
181
- exists=True,
182
- file_okay=False,
183
- dir_okay=True,
184
- resolve_path=True,
185
- help='Path where residue level embeddings are located. This argument is required if src-from=structure.'
186
- )] = None,
187
- src_from: Annotated[SrcTensorFrom, typer.Option(
188
- help='Use file names or all chains in a structure.'
189
- )] = SrcTensorFrom.file,
190
182
  structure_location: Annotated[StructureLocation, typer.Option(
191
183
  help='Structure file location.'
192
184
  )] = StructureLocation.local,
@@ -214,10 +206,10 @@ def chain_embedding(
214
206
  ):
215
207
  from rcsb_embedding_model.inference.chain_inference import predict
216
208
  predict(
217
- src_stream=src_file,
209
+ src_stream=adapt_csv_to_embedding_chain_stream(src_file, res_embedding_location),
218
210
  res_embedding_location=res_embedding_location,
219
- src_location=SrcLocation.file,
220
- src_from=src_from,
211
+ src_location=SrcLocation.stream,
212
+ src_from=SrcTensorFrom.file,
221
213
  structure_location=structure_location,
222
214
  structure_format=structure_format,
223
215
  min_res_n=min_res_n,
@@ -263,9 +255,6 @@ def assembly_embedding(
263
255
  output_name: Annotated[str, typer.Option(
264
256
  help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
265
257
  )] = 'inference',
266
- src_from: Annotated[SrcAssemblyFrom, typer.Option(
267
- help='Use specific assembly or all assemblies in a structure.'
268
- )] = SrcAssemblyFrom.assembly,
269
258
  structure_location: Annotated[StructureLocation, typer.Option(
270
259
  help='Structure file location.'
271
260
  )] = StructureLocation.local,
@@ -299,7 +288,7 @@ def assembly_embedding(
299
288
  src_stream=src_file,
300
289
  res_embedding_location=res_embedding_location,
301
290
  src_location=SrcLocation.file,
302
- src_from=src_from,
291
+ src_from=SrcAssemblyFrom.assembly,
303
292
  structure_location=structure_location,
304
293
  structure_format=structure_format,
305
294
  min_res_n=min_res_n,
@@ -319,12 +308,19 @@ def assembly_embedding(
319
308
  help="Calculate chain and assembly embeddings from structural files. Predictions are stored as csv files."
320
309
  )
321
310
  def complete_embedding(
322
- src_file: Annotated[typer.FileText, typer.Option(
311
+ src_chain_file: Annotated[typer.FileText, typer.Option(
323
312
  exists=True,
324
313
  file_okay=True,
325
314
  dir_okay=False,
326
315
  resolve_path=True,
327
- help='CSV file 3 columns: Structure Name | Structure File Path or URL | Chain Id (asym_i for cif files. This field is required if src-from=chain) | Output Embedding Name.'
316
+ help='CSV file 4 columns: Structure Name | Structure File Path or URL (switch structure-location) | Chain Id (asym_i for cif files) | Output Embedding Name.'
317
+ )],
318
+ src_assembly_file: Annotated[typer.FileText, typer.Option(
319
+ exists=True,
320
+ file_okay=True,
321
+ dir_okay=False,
322
+ resolve_path=True,
323
+ help='CSV file 4 columns: Structure Name | Structure File Path or URL (switch structure-location) | Assembly Id | Output embedding name.'
328
324
  )],
329
325
  output_path: Annotated[typer.FileText, typer.Option(
330
326
  exists=True,
@@ -333,7 +329,7 @@ def complete_embedding(
333
329
  resolve_path=True,
334
330
  help='Output path to store predictions. Embeddings are stored as a single DataFrame file (see output_name).'
335
331
  )],
336
- esm_output_path: Annotated[typer.FileText, typer.Option(
332
+ res_embedding_location: Annotated[typer.FileText, typer.Option(
337
333
  exists=True,
338
334
  file_okay=False,
339
335
  dir_okay=True,
@@ -372,9 +368,8 @@ def complete_embedding(
372
368
  )] = tuple(['auto'])
373
369
  ):
374
370
  residue_embedding(
375
- src_file=src_file,
376
- src_from=SrcProteinFrom.structure,
377
- output_path=esm_output_path,
371
+ src_file=src_chain_file,
372
+ output_path=res_embedding_location,
378
373
  output_format=OutFormat.separated,
379
374
  structure_location=structure_location,
380
375
  structure_format=structure_format,
@@ -386,12 +381,11 @@ def complete_embedding(
386
381
  devices=devices,
387
382
  )
388
383
  chain_embedding(
389
- src_file=src_file,
390
- src_from=SrcTensorFrom.structure,
384
+ src_file=src_chain_file,
391
385
  output_path=output_path,
392
386
  output_format=output_format,
393
387
  output_name=f"{output_name}-chain",
394
- res_embedding_location=esm_output_path,
388
+ res_embedding_location=res_embedding_location,
395
389
  structure_location=structure_location,
396
390
  structure_format=structure_format,
397
391
  min_res_n=min_res_n,
@@ -402,12 +396,11 @@ def complete_embedding(
402
396
  devices=devices
403
397
  )
404
398
  assembly_embedding(
405
- src_file=src_file,
406
- src_from=SrcAssemblyFrom.structure,
399
+ src_file=src_assembly_file,
407
400
  output_path=output_path,
408
401
  output_format=output_format,
409
402
  output_name=f"{output_name}-assembly",
410
- res_embedding_location=esm_output_path,
403
+ res_embedding_location=res_embedding_location,
411
404
  structure_location=structure_location,
412
405
  structure_format=structure_format,
413
406
  min_res_n=min_res_n,
@@ -70,6 +70,7 @@ class EsmProtFromChain(Dataset):
70
70
  for atom_ch in chain_iter(structure):
71
71
  protein_chain = ProteinChain.from_atomarray(rename_atom_ch(atom_ch))
72
72
  return ESMProtein.from_protein_chain(protein_chain), item_name
73
+ return None
73
74
 
74
75
 
75
76
  if __name__ == '__main__':
@@ -1,4 +1,5 @@
1
1
  from esm.sdk.api import SamplingConfig
2
+ from esm.sdk import batch_executor
2
3
  from lightning import LightningModule
3
4
 
4
5
  from rcsb_embedding_model.utils.model import get_residue_model
@@ -14,11 +15,13 @@ class EsmModule(LightningModule):
14
15
 
15
16
  def predict_step(self, prot_batch, batch_idx):
16
17
  prot_embeddings = []
17
- prot_names = []
18
- for esm_prot, name in prot_batch:
19
- embeddings = self.esm3.forward_and_sample(
18
+ def __batch_embedding(esm_prot):
19
+ return self.esm3.forward_and_sample(
20
20
  self.esm3.encode(esm_prot), SamplingConfig(return_per_residue_embeddings=True)
21
21
  ).per_residue_embedding
22
- prot_embeddings.append(embeddings)
23
- prot_names.append(name)
24
- return tuple(prot_embeddings), tuple(prot_names)
22
+ with batch_executor() as executor:
23
+ prot_embeddings = executor.execute_batch(
24
+ user_func=__batch_embedding,
25
+ esm_prot=[esm_prot for esm_prot, name in prot_batch]
26
+ )
27
+ return tuple(prot_embeddings), tuple([name for esm_prot, name in prot_batch])
@@ -1,3 +1,4 @@
1
+ import os
1
2
  from io import StringIO
2
3
 
3
4
  import requests
@@ -76,3 +77,9 @@ def concatenate_tensors(file_list, max_residues, dim=0):
76
77
  return tensor_cat
77
78
  else:
78
79
  raise ValueError("No valid tensors were loaded to concatenate.")
80
+
81
+ def adapt_csv_to_embedding_chain_stream(src_file, res_embedding_location):
82
+ def __parse_row(row):
83
+ r = row.split(",")
84
+ return os.path.join(res_embedding_location, f"{r[0]}.{r[2]}.pt"), f"{r[0]}.{r[2]}"
85
+ return tuple([__parse_row(r) for r in open(src_file)])
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rcsb-embedding-model
3
- Version: 0.0.17
3
+ Version: 0.0.18
4
4
  Summary: Protein Embedding Model for Structure Search
5
5
  Project-URL: Homepage, https://github.com/rcsb/rcsb-embedding-model
6
6
  Project-URL: Issues, https://github.com/rcsb/rcsb-embedding-model/issues
@@ -12,13 +12,12 @@ Classifier: Programming Language :: Python :: 3
12
12
  Requires-Python: >=3.10
13
13
  Requires-Dist: esm>=3.2.0
14
14
  Requires-Dist: lightning>=2.5.0
15
- Requires-Dist: torch>=2.2.0
16
15
  Requires-Dist: typer>=0.15.0
17
16
  Description-Content-Type: text/markdown
18
17
 
19
18
  # RCSB Embedding Model
20
19
 
21
- **Version** 0.0.17
20
+ **Version** 0.0.18
22
21
 
23
22
 
24
23
  ## Overview
@@ -48,11 +47,10 @@ If you are interested in training the model with a new dataset, visit the [rcsb-
48
47
 
49
48
  **Requirements:**
50
49
 
51
- - Python ≥ 3.10
52
- - ESM 3.2.0
53
- - PyTorch ≥ 2.2.0
54
- - Lightning2.5.0
55
- - Typer ≥ 0.15.0
50
+ - Python ≥ 3.10
51
+ - ESM == 3.1.1
52
+ - Lightning ≥ 2.5.0
53
+ - Typer0.15.0
56
54
 
57
55
  ---
58
56
 
@@ -1,8 +1,8 @@
1
1
  rcsb_embedding_model/__init__.py,sha256=r3gLdeBIXkQEQA_K6QcRPO-TtYuAQSutk6pXRUE_nas,120
2
2
  rcsb_embedding_model/rcsb_structure_embedding.py,sha256=dKp9hXQO0JAnO4SEfjJ_mG_jHu3UxAPguv6jkOjp-BI,4487
3
3
  rcsb_embedding_model/cli/args_utils.py,sha256=7nP2q8pL5dWK_U7opxtWmoFcYVwasky6elHk-dASFaI,165
4
- rcsb_embedding_model/cli/inference.py,sha256=XmnRwygWYQkPqeJi4I1H2jjo24IxXzt_EihdYZ7LLqA,18696
5
- rcsb_embedding_model/dataset/esm_prot_from_chain.py,sha256=u6vu_2CN6vaYAk6kpvHAOgHuEHjXJl3fukMk-tDr_6E,3486
4
+ rcsb_embedding_model/cli/inference.py,sha256=PE36a1d6nfhNsuqCCJbos2JpZE0oCJmIf2mNw7Nz8GI,18231
5
+ rcsb_embedding_model/dataset/esm_prot_from_chain.py,sha256=3hWo2nWunFZNTfYCTiPvVoJlkWQbRmvlehFw-6B4z6A,3506
6
6
  rcsb_embedding_model/dataset/esm_prot_from_structure.py,sha256=TeITPdi1uc3qLQ-Pgn807oH6eM0LYv-67RE50ZT4dLI,2551
7
7
  rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py,sha256=worRiNqOJRjyr693TaillsS65bdTdGOoHfwyT9yE1O4,2866
8
8
  rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py,sha256=JG4rrhziIUtdTmbuTbMbEYHrvlda4m5VWvdJXe_Sv3c,3449
@@ -15,16 +15,16 @@ rcsb_embedding_model/inference/structure_inference.py,sha256=lqbDBPSea8IoNyQXl83
15
15
  rcsb_embedding_model/model/layers.py,sha256=lhKaWC4gTS_T5lHOP0mgnnP8nKTPEOm4MrjhESA4hE8,743
16
16
  rcsb_embedding_model/model/residue_embedding_aggregator.py,sha256=k3UW63Ax8DtjCMdD3O5xNxtyAu28l2n3-Ab6nS0atm0,1967
17
17
  rcsb_embedding_model/modules/chain_module.py,sha256=sDSPXJmWuU2C3lt1NorlbUVWZvRSLzumPdFQk01h3VI,403
18
- rcsb_embedding_model/modules/esm_module.py,sha256=CTHGOATXiarqZsBsZ8oxGJBj20A73186Slpr0EzMJsE,770
18
+ rcsb_embedding_model/modules/esm_module.py,sha256=4IQgrNQlGThxl0PhobVzyp7N3FcyAbvek_KxJozGImQ,945
19
19
  rcsb_embedding_model/modules/structure_module.py,sha256=dEtDNdWo1j2sSDa0JiOHQfEfQzIWqSLEKpvOX0GrXZ4,1048
20
20
  rcsb_embedding_model/types/api_types.py,sha256=SCwALwvEb0KRKaoWKbuN7JyfOH-1whsI0Z4ki41dht8,1235
21
- rcsb_embedding_model/utils/data.py,sha256=ODz6GG6IAhgAlLh3tcIP6-JVHX8Bb_-E745Lvc_oR84,2934
21
+ rcsb_embedding_model/utils/data.py,sha256=FVb6tzoX4SrJf3Fr6UFbxZJQsUr9xp5RbkK6nqXhcuQ,3222
22
22
  rcsb_embedding_model/utils/model.py,sha256=rpZa-gfm3cEtbBd7UXMHrZv3x6f0AC8TJT3gtrSxr5I,852
23
23
  rcsb_embedding_model/utils/structure_parser.py,sha256=IWMQ8brlEMe6_ND-DBESOli8vlqHxladTssjbM9RSKw,2751
24
24
  rcsb_embedding_model/utils/structure_provider.py,sha256=eWtxjkPpmRfmil_DKR1J6miaXR3lQ28DF5O0qrqSgGA,786
25
25
  rcsb_embedding_model/writer/batch_writer.py,sha256=rTFNasB0Xp4-XCNTXKeEWZxSrb7lvZytoRldJUWn9Jg,3312
26
- rcsb_embedding_model-0.0.17.dist-info/METADATA,sha256=ZIV-WqJGsSmZd79Ks455CTxSgN2J2JSD7vqr3XPx3nE,5368
27
- rcsb_embedding_model-0.0.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
28
- rcsb_embedding_model-0.0.17.dist-info/entry_points.txt,sha256=MK11jTIEmaV-x4CkPX5IymDaVs7Ky_f2xxU8BJVZ_9Q,69
29
- rcsb_embedding_model-0.0.17.dist-info/licenses/LICENSE.md,sha256=oUaHiKgfBkChth_Sm67WemEvatO1U0Go8LHjaskXY0w,1522
30
- rcsb_embedding_model-0.0.17.dist-info/RECORD,,
26
+ rcsb_embedding_model-0.0.18.dist-info/METADATA,sha256=PzSnwGeAeUbYxhpRBgEiZZdj2bGdLrT8QAy0uB_BxNQ,5310
27
+ rcsb_embedding_model-0.0.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
28
+ rcsb_embedding_model-0.0.18.dist-info/entry_points.txt,sha256=MK11jTIEmaV-x4CkPX5IymDaVs7Ky_f2xxU8BJVZ_9Q,69
29
+ rcsb_embedding_model-0.0.18.dist-info/licenses/LICENSE.md,sha256=oUaHiKgfBkChth_Sm67WemEvatO1U0Go8LHjaskXY0w,1522
30
+ rcsb_embedding_model-0.0.18.dist-info/RECORD,,