rcsb-embedding-model 0.0.16__tar.gz → 0.0.18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rcsb-embedding-model might be problematic.

Files changed (45)
  1. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/PKG-INFO +6 -8
  2. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/README.md +5 -6
  3. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/pyproject.toml +1 -2
  4. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/cli/inference.py +154 -33
  5. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/dataset/esm_prot_from_chain.py +2 -1
  6. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/dataset/esm_prot_from_structure.py +1 -1
  7. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py +2 -2
  8. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py +2 -2
  9. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/dataset/residue_embedding_from_structure.py +1 -1
  10. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py +1 -1
  11. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/inference/assembly_inferece.py +7 -2
  12. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/inference/chain_inference.py +8 -6
  13. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/inference/esm_inference.py +9 -5
  14. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/inference/structure_inference.py +9 -7
  15. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/modules/esm_module.py +9 -6
  16. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/types/api_types.py +5 -1
  17. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/utils/data.py +7 -0
  18. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/writer/batch_writer.py +18 -0
  19. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/.github/workflows/_workflow-docker.yaml +0 -0
  20. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/.github/workflows/publish.yaml +0 -0
  21. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/.gitignore +0 -0
  22. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/Dockerfile +0 -0
  23. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/LICENSE.md +0 -0
  24. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/assets/embedding-model-architecture.png +0 -0
  25. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/examples/esm_embeddings.py +0 -0
  26. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/__init__.py +0 -0
  27. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/cli/args_utils.py +0 -0
  28. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/model/layers.py +0 -0
  29. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/model/residue_embedding_aggregator.py +0 -0
  30. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/modules/chain_module.py +0 -0
  31. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/modules/structure_module.py +0 -0
  32. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/rcsb_structure_embedding.py +0 -0
  33. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/utils/model.py +0 -0
  34. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/utils/structure_parser.py +0 -0
  35. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/src/rcsb_embedding_model/utils/structure_provider.py +0 -0
  36. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/tests/resources/embeddings/1acb.A.pt +0 -0
  37. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/tests/resources/embeddings/1acb.B.pt +0 -0
  38. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/tests/resources/embeddings/2uzi.A.pt +0 -0
  39. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/tests/resources/embeddings/2uzi.B.pt +0 -0
  40. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/tests/resources/embeddings/2uzi.C.pt +0 -0
  41. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/tests/resources/pdb/1acb.cif +0 -0
  42. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/tests/resources/pdb/2uzi.cif +0 -0
  43. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/tests/test_embedding_model.py +0 -0
  44. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/tests/test_inference.py +0 -0
  45. {rcsb_embedding_model-0.0.16 → rcsb_embedding_model-0.0.18}/tests/test_remote_inference.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rcsb-embedding-model
-Version: 0.0.16
+Version: 0.0.18
 Summary: Protein Embedding Model for Structure Search
 Project-URL: Homepage, https://github.com/rcsb/rcsb-embedding-model
 Project-URL: Issues, https://github.com/rcsb/rcsb-embedding-model/issues
@@ -12,13 +12,12 @@ Classifier: Programming Language :: Python :: 3
 Requires-Python: >=3.10
 Requires-Dist: esm>=3.2.0
 Requires-Dist: lightning>=2.5.0
-Requires-Dist: torch>=2.2.0
 Requires-Dist: typer>=0.15.0
 Description-Content-Type: text/markdown

 # RCSB Embedding Model

-**Version** 0.0.16
+**Version** 0.0.18


 ## Overview
@@ -48,11 +47,10 @@ If you are interested in training the model with a new dataset, visit the [rcsb-

 **Requirements:**

-- Python ≥ 3.10
-- ESM 3.2.0
-- PyTorch ≥ 2.2.0
-- Lightning ≥ 2.5.0
-- Typer ≥ 0.15.0
+- Python ≥ 3.10
+- ESM == 3.1.1
+- Lightning ≥ 2.5.0
+- Typer ≥ 0.15.0

 ---


README.md
@@ -1,6 +1,6 @@
 # RCSB Embedding Model

-**Version** 0.0.16
+**Version** 0.0.18


 ## Overview
@@ -30,11 +30,10 @@ If you are interested in training the model with a new dataset, visit the [rcsb-

 **Requirements:**

-- Python ≥ 3.10
-- ESM 3.2.0
-- PyTorch ≥ 2.2.0
-- Lightning ≥ 2.5.0
-- Typer ≥ 0.15.0
+- Python ≥ 3.10
+- ESM == 3.1.1
+- Lightning ≥ 2.5.0
+- Typer ≥ 0.15.0

 ---


pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "rcsb-embedding-model"
-version = "0.0.16"
+version = "0.0.18"
 authors = [
   { name="Joan Segura", email="joan.segura@rcsb.org" },
 ]
@@ -15,7 +15,6 @@ license = "BSD-3-Clause"
 license-files = ["LICEN[CS]E*"]
 dependencies=[
     "esm >= 3.2.0",
-    "torch >= 2.2.0",
     "lightning >= 2.5.0",
     "typer >= 0.15.0"
 ]

src/rcsb_embedding_model/cli/inference.py
@@ -5,7 +5,8 @@ import typer

 from rcsb_embedding_model.cli.args_utils import arg_devices
 from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, SrcLocation, SrcProteinFrom, \
-    StructureLocation, SrcAssemblyFrom, SrcTensorFrom
+    StructureLocation, SrcAssemblyFrom, SrcTensorFrom, OutFormat
+from rcsb_embedding_model.utils.data import adapt_csv_to_embedding_chain_stream

 app = typer.Typer(
     add_completion=False
@@ -22,7 +23,7 @@ def residue_embedding(
         file_okay=True,
         dir_okay=False,
         resolve_path=True,
-        help='CSV file 4 (or 3) columns: Structure Name | Structure File Path or URL (switch structure-location) | Chain Id (asym_i for cif files. This field is required if src-from=chain) | Output Embedding Name.'
+        help='CSV file 4 columns: Structure Name | Structure File Path or URL (switch structure-location) | Chain Id (asym_i for cif files) | Output Embedding Name.'
     )],
     output_path: Annotated[typer.FileText, typer.Option(
         exists=True,
@@ -31,9 +32,12 @@ def residue_embedding(
         resolve_path=True,
         help='Output path to store predictions. Embeddings are stored as torch tensor files.'
     )],
-    src_from: Annotated[SrcProteinFrom, typer.Option(
-        help='Use specific chains or all chains in a structure.'
-    )] = SrcProteinFrom.chain,
+    output_format: Annotated[OutFormat, typer.Option(
+        help='Format of the output. Options: separated (predictions are stored in single files) or grouped (predictions are stored in a single JSON file).'
+    )] = OutFormat.separated,
+    output_name: Annotated[str, typer.Option(
+        help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
+    )] = 'inference',
     structure_location: Annotated[StructureLocation, typer.Option(
         help='Structure file location.'
     )] = StructureLocation.local,
@@ -62,8 +66,8 @@ def residue_embedding(
     from rcsb_embedding_model.inference.esm_inference import predict
     predict(
         src_stream=src_file,
-        src_location=SrcLocation.local,
-        src_from=src_from,
+        src_location=SrcLocation.file,
+        src_from=SrcProteinFrom.chain,
         structure_location=structure_location,
         structure_format=structure_format,
         min_res_n=min_res_n,
@@ -72,6 +76,8 @@
         num_nodes=num_nodes,
         accelerator=accelerator,
         devices=arg_devices(devices),
+        out_format=output_format,
+        out_name=output_name,
         out_path=output_path
     )

@@ -86,7 +92,7 @@ def structure_embedding(
         file_okay=True,
         dir_okay=False,
         resolve_path=True,
-        help='CSV file 4 (or 3) columns: Structure Name | Structure File Path or URL (switch structure-location) | Chain Id (asym_i for cif files. This field is required if src-from=chain) | Output Embedding Name.'
+        help='CSV file 4 columns: Structure Name | Structure File Path or URL (switch structure-location) | Chain Id (asym_i for cif files) | Output Embedding Name.'
     )],
     output_path: Annotated[typer.FileText, typer.Option(
         exists=True,
@@ -95,12 +101,9 @@ def structure_embedding(
         resolve_path=True,
         help='Output path to store predictions. Embeddings are stored as a single DataFrame file (see out-df-name).'
     )],
-    out_df_name: Annotated[str, typer.Option(
-        help='File name (without extension) for storing embeddings as a pandas DataFrame pickle (.pkl). The DataFrame contains 2 columns: Id | Embedding'
-    )],
-    src_from: Annotated[SrcProteinFrom, typer.Option(
-        help='Use specific chains or all chains in a structure.'
-    )] = SrcProteinFrom.chain,
+    output_name: Annotated[str, typer.Option(
+        help='File name for storing embeddings as a single JSON file.'
+    )] = 'inference',
     structure_location: Annotated[StructureLocation, typer.Option(
         help='Structure file location.'
     )] = StructureLocation.local,
@@ -129,8 +132,8 @@ def structure_embedding(
     from rcsb_embedding_model.inference.structure_inference import predict
     predict(
         src_stream=src_file,
-        src_location=SrcLocation.local,
-        src_from=src_from,
+        src_location=SrcLocation.file,
+        src_from=SrcProteinFrom.chain,
         structure_location=structure_location,
         structure_format=structure_format,
         min_res_n=min_res_n,
@@ -140,7 +143,7 @@ def structure_embedding(
         accelerator=accelerator,
         devices=arg_devices(devices),
         out_path=output_path,
-        out_df_name=out_df_name
+        out_name=output_name
     )


@@ -154,7 +157,7 @@ def chain_embedding(
         file_okay=True,
         dir_okay=False,
         resolve_path=True,
-        help='Option 1 (src-from=file) - CSV file 2 columns: Residue Embedding Torch Tensor File | Output Embedding Name. Option 2 (src-from=structure) - CSV file 3 columns: Structure Name | Structure File Path or URL (switch structure-location) | Output Embedding Name.'
+        help='CSV file 4 columns: Structure Name | Structure File Path or URL (switch structure-location) | Chain Id (asym_i for cif files) | Output Embedding Name.'
     )],
     output_path: Annotated[typer.FileText, typer.Option(
         exists=True,
@@ -168,11 +171,14 @@ def chain_embedding(
         file_okay=False,
         dir_okay=True,
         resolve_path=True,
-        help='Path where residue level embeddings are located. This argument is required if src-from=structure.'
-    )] = None,
-    src_from: Annotated[SrcTensorFrom, typer.Option(
-        help='Use file names or all chains in a structure.'
-    )] = SrcTensorFrom.file,
+        help='Path where residue level embeddings are located.'
+    )],
+    output_format: Annotated[OutFormat, typer.Option(
+        help='Format of the output. Options: separated (predictions are stored in single files) or grouped (predictions are stored in a single JSON file).'
+    )] = OutFormat.separated,
+    output_name: Annotated[str, typer.Option(
+        help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
+    )] = 'inference',
     structure_location: Annotated[StructureLocation, typer.Option(
         help='Structure file location.'
     )] = StructureLocation.local,
@@ -200,10 +206,10 @@ def chain_embedding(
 ):
     from rcsb_embedding_model.inference.chain_inference import predict
     predict(
-        src_stream=src_file,
+        src_stream=adapt_csv_to_embedding_chain_stream(src_file, res_embedding_location),
         res_embedding_location=res_embedding_location,
-        src_location=SrcLocation.local,
-        src_from=src_from,
+        src_location=SrcLocation.stream,
+        src_from=SrcTensorFrom.file,
        structure_location=structure_location,
         structure_format=structure_format,
         min_res_n=min_res_n,
@@ -212,7 +218,9 @@ def chain_embedding(
         num_nodes=num_nodes,
         accelerator=accelerator,
         devices=arg_devices(devices),
-        out_path=output_path
+        out_path=output_path,
+        out_format=output_format,
+        out_name=output_name
     )

 @app.command(
@@ -241,9 +249,12 @@ def assembly_embedding(
         resolve_path=True,
         help='Output path to store predictions. Embeddings are stored as csv files.'
     )],
-    src_from: Annotated[SrcAssemblyFrom, typer.Option(
-        help='Use specific assembly or all assemblies in a structure.'
-    )] = SrcAssemblyFrom.assembly,
+    output_format: Annotated[OutFormat, typer.Option(
+        help='Format of the output. Options: separated (predictions are stored in single files) or grouped (predictions are stored in a single JSON file).'
+    )] = OutFormat.separated,
+    output_name: Annotated[str, typer.Option(
+        help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
+    )] = 'inference',
     structure_location: Annotated[StructureLocation, typer.Option(
         help='Structure file location.'
     )] = StructureLocation.local,
@@ -276,8 +287,8 @@ def assembly_embedding(
     predict(
         src_stream=src_file,
         res_embedding_location=res_embedding_location,
-        src_location=SrcLocation.local,
-        src_from=src_from,
+        src_location=SrcLocation.file,
+        src_from=SrcAssemblyFrom.assembly,
         structure_location=structure_location,
         structure_format=structure_format,
         min_res_n=min_res_n,
@@ -287,7 +298,117 @@ def assembly_embedding(
         num_nodes=num_nodes,
         accelerator=accelerator,
         devices=arg_devices(devices),
-        out_path=output_path
+        out_path=output_path,
+        out_format=output_format,
+        out_name=output_name
+    )
+
+@app.command(
+    name="complete-embedding",
+    help="Calculate chain and assembly embeddings from structural files. Predictions are stored as csv files."
+)
+def complete_embedding(
+    src_chain_file: Annotated[typer.FileText, typer.Option(
+        exists=True,
+        file_okay=True,
+        dir_okay=False,
+        resolve_path=True,
+        help='CSV file 4 columns: Structure Name | Structure File Path or URL (switch structure-location) | Chain Id (asym_i for cif files) | Output Embedding Name.'
+    )],
+    src_assembly_file: Annotated[typer.FileText, typer.Option(
+        exists=True,
+        file_okay=True,
+        dir_okay=False,
+        resolve_path=True,
+        help='CSV file 4 columns: Structure Name | Structure File Path or URL (switch structure-location) | Assembly Id | Output embedding name.'
+    )],
+    output_path: Annotated[typer.FileText, typer.Option(
+        exists=True,
+        file_okay=False,
+        dir_okay=True,
+        resolve_path=True,
+        help='Output path to store predictions. Embeddings are stored as a single DataFrame file (see output_name).'
+    )],
+    res_embedding_location: Annotated[typer.FileText, typer.Option(
+        exists=True,
+        file_okay=False,
+        dir_okay=True,
+        resolve_path=True,
+        help='Output path to store ESM predictions.'
+    )],
+    output_format: Annotated[OutFormat, typer.Option(
+        help='Format of the output. Options: separated (predictions are stored in single files) or grouped (predictions are stored in a single JSON file).'
+    )] = OutFormat.separated,
+    output_name: Annotated[str, typer.Option(
+        help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
+    )] = 'inference',
+    structure_location: Annotated[StructureLocation, typer.Option(
+        help='Structure file location.'
+    )] = StructureLocation.local,
+    structure_format: Annotated[StructureFormat, typer.Option(
+        help='Structure file format.'
+    )] = StructureFormat.mmcif,
+    min_res_n: Annotated[int, typer.Option(
+        help='When using all chains in a structure, consider only chains with more than <min_res_n> residues.'
+    )] = 0,
+    batch_size: Annotated[int, typer.Option(
+        help='Number of samples processed together in one iteration.'
+    )] = 1,
+    num_workers: Annotated[int, typer.Option(
+        help='Number of subprocesses to use for data loading.'
+    )] = 0,
+    num_nodes: Annotated[int, typer.Option(
+        help='Number of nodes to use for inference.'
+    )] = 1,
+    accelerator: Annotated[Accelerator, typer.Option(
+        help='Device used for inference.'
+    )] = Accelerator.auto,
+    devices: Annotated[List[str], typer.Option(
+        help='The devices to use. Can be set to a positive number or "auto". Repeat this argument to indicate multiple indices of devices. "auto" for automatic selection based on the chosen accelerator.'
+    )] = tuple(['auto'])
+):
+    residue_embedding(
+        src_file=src_chain_file,
+        output_path=res_embedding_location,
+        output_format=OutFormat.separated,
+        structure_location=structure_location,
+        structure_format=structure_format,
+        min_res_n=min_res_n,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        num_nodes=num_nodes,
+        accelerator=accelerator,
+        devices=devices,
+    )
+    chain_embedding(
+        src_file=src_chain_file,
+        output_path=output_path,
+        output_format=output_format,
+        output_name=f"{output_name}-chain",
+        res_embedding_location=res_embedding_location,
+        structure_location=structure_location,
+        structure_format=structure_format,
+        min_res_n=min_res_n,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        num_nodes=num_nodes,
+        accelerator=accelerator,
+        devices=devices
+    )
+    assembly_embedding(
+        src_file=src_assembly_file,
+        output_path=output_path,
+        output_format=output_format,
+        output_name=f"{output_name}-assembly",
+        res_embedding_location=res_embedding_location,
+        structure_location=structure_location,
+        structure_format=structure_format,
+        min_res_n=min_res_n,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        num_nodes=num_nodes,
+        accelerator=accelerator,
+        devices=devices
     )


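For orientation, a minimal sketch of driving the reworked residue-embedding path from Python instead of the CLI. The keyword names and enum values come from the hunks above; the CSV path, structure entries, and output directory are hypothetical.

    # chains.csv (hypothetical), 4 columns:
    # Structure Name | Structure File Path or URL | Chain Id | Output Embedding Name
    # 1acb,tests/resources/pdb/1acb.cif,A,1acb.A
    from rcsb_embedding_model.inference.esm_inference import predict
    from rcsb_embedding_model.types.api_types import (
        Accelerator, OutFormat, SrcLocation, SrcProteinFrom, StructureFormat, StructureLocation
    )

    predict(
        src_stream='chains.csv',                     # hypothetical CSV path
        src_location=SrcLocation.file,               # the CLI now fixes this to file
        src_from=SrcProteinFrom.chain,
        structure_location=StructureLocation.local,
        structure_format=StructureFormat.mmcif,
        accelerator=Accelerator.auto,
        out_format=OutFormat.grouped,                # grouped -> one JSON file named <out_name>.json.gz
        out_name='inference',
        out_path='embeddings'                        # hypothetical output directory
    )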

src/rcsb_embedding_model/dataset/esm_prot_from_chain.py
@@ -27,7 +27,7 @@ class EsmProtFromChain(Dataset):
     def __init__(
             self,
             src_stream,
-            src_location=SrcLocation.local,
+            src_location=SrcLocation.file,
             structure_location=StructureLocation.local,
             structure_format=StructureFormat.mmcif,
             structure_provider=StructureProvider()
@@ -70,6 +70,7 @@ class EsmProtFromChain(Dataset):
         for atom_ch in chain_iter(structure):
             protein_chain = ProteinChain.from_atomarray(rename_atom_ch(atom_ch))
             return ESMProtein.from_protein_chain(protein_chain), item_name
+        return None


 if __name__ == '__main__':

src/rcsb_embedding_model/dataset/esm_prot_from_structure.py
@@ -19,7 +19,7 @@ class EsmProtFromStructure(EsmProtFromChain):
     def __init__(
             self,
             src_stream,
-            src_location=SrcLocation.local,
+            src_location=SrcLocation.file,
             structure_location=StructureLocation.local,
             structure_format=StructureFormat.mmcif,
             min_res_n=0,

src/rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py
@@ -21,7 +21,7 @@ class ResidueAssemblyDatasetFromStructure(ResidueAssemblyEmbeddingFromTensorFile
             self,
             src_stream,
             res_embedding_location,
-            src_location=SrcLocation.local,
+            src_location=SrcLocation.file,
             structure_location=StructureLocation.local,
             structure_format=StructureFormat.mmcif,
             min_res_n=0,
@@ -63,6 +63,6 @@ class ResidueAssemblyDatasetFromStructure(ResidueAssemblyEmbeddingFromTensorFile
         structure = stringio_from_url(src_structure) if self.structure_location == StructureLocation.remote else src_structure
         item_name = row[ResidueAssemblyDatasetFromStructure.ITEM_NAME_ATTR]
         for assembly_id in get_assemblies(structure=structure, structure_format=self.structure_format):
-            assemblies.append((src_name, src_structure, str(assembly_id), f"{item_name}.{assembly_id}"))
+            assemblies.append((src_name, src_structure, str(assembly_id), f"{item_name}-{assembly_id}"))

         return tuple(assemblies)

src/rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py
@@ -22,7 +22,7 @@ class ResidueAssemblyEmbeddingFromTensorFile(Dataset):
             self,
             src_stream,
             res_embedding_location,
-            src_location=SrcLocation.local,
+            src_location=SrcLocation.file,
             structure_location=StructureLocation.local,
             structure_format=StructureFormat.mmcif,
             min_res_n=0,
@@ -79,7 +79,7 @@ if __name__ == "__main__":
     dataset = ResidueAssemblyEmbeddingFromTensorFile(
         src_stream="/Users/joan/tmp/assembly-test.csv",
         res_embedding_location="/Users/joan/tmp",
-        src_location=SrcLocation.local,
+        src_location=SrcLocation.file,
         structure_location=StructureLocation.local,
         structure_format=StructureFormat.mmcif
     )

src/rcsb_embedding_model/dataset/residue_embedding_from_structure.py
@@ -21,7 +21,7 @@ class ResidueEmbeddingFromStructure(ResidueEmbeddingFromTensorFile):
             self,
             src_stream,
             res_embedding_location,
-            src_location=SrcLocation.local,
+            src_location=SrcLocation.file,
             structure_location=StructureLocation.local,
             structure_format=StructureFormat.mmcif,
             min_res_n=0,

src/rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py
@@ -15,7 +15,7 @@ class ResidueEmbeddingFromTensorFile(Dataset):
     def __init__(
             self,
             src_stream,
-            src_location=SrcLocation.local
+            src_location=SrcLocation.file
     ):
         super().__init__()
         self.src_location = src_location

src/rcsb_embedding_model/inference/assembly_inferece.py
@@ -2,14 +2,15 @@ import sys

 from rcsb_embedding_model.dataset.resdiue_assembly_embedding_from_structure import ResidueAssemblyDatasetFromStructure
 from rcsb_embedding_model.dataset.residue_assembly_embedding_from_tensor_file import ResidueAssemblyEmbeddingFromTensorFile
-from rcsb_embedding_model.types.api_types import FileOrStreamTuple, SrcLocation, Accelerator, Devices, OptionalPath, EmbeddingPath, StructureLocation, StructureFormat, SrcAssemblyFrom
+from rcsb_embedding_model.types.api_types import FileOrStreamTuple, SrcLocation, Accelerator, Devices, OptionalPath, \
+    EmbeddingPath, StructureLocation, StructureFormat, SrcAssemblyFrom, OutFormat
 from rcsb_embedding_model.inference.chain_inference import predict as chain_predict


 def predict(
         src_stream: FileOrStreamTuple,
         res_embedding_location: EmbeddingPath,
-        src_location: SrcLocation = SrcLocation.local,
+        src_location: SrcLocation = SrcLocation.file,
         src_from: SrcAssemblyFrom = SrcAssemblyFrom.assembly,
         structure_location: StructureLocation = StructureLocation.local,
         structure_format: StructureFormat = StructureFormat.mmcif,
@@ -20,6 +21,8 @@ def predict(
         num_nodes: int = 1,
         accelerator: Accelerator = Accelerator.auto,
         devices: Devices = 'auto',
+        out_format: OutFormat = OutFormat.separated,
+        out_name: str = 'inference',
         out_path: OptionalPath = None
 ):
     inference_set = ResidueAssemblyEmbeddingFromTensorFile(
@@ -48,6 +51,8 @@ def predict(
         num_nodes=num_nodes,
         accelerator=accelerator,
         devices=devices,
+        out_format=out_format,
+        out_name=out_name,
         out_path=out_path,
         inference_set=inference_set
     )

src/rcsb_embedding_model/inference/chain_inference.py
@@ -5,15 +5,15 @@ from rcsb_embedding_model.dataset.residue_embedding_from_structure import Residu
 from rcsb_embedding_model.dataset.residue_embedding_from_tensor_file import ResidueEmbeddingFromTensorFile
 from rcsb_embedding_model.modules.chain_module import ChainModule
 from rcsb_embedding_model.types.api_types import Accelerator, Devices, OptionalPath, FileOrStreamTuple, SrcLocation, \
-    SrcTensorFrom, StructureLocation, StructureFormat
+    SrcTensorFrom, StructureLocation, StructureFormat, OutFormat
 from rcsb_embedding_model.utils.data import collate_seq_embeddings
-from rcsb_embedding_model.writer.batch_writer import CsvBatchWriter
+from rcsb_embedding_model.writer.batch_writer import CsvBatchWriter, JsonStorage


 def predict(
         src_stream: FileOrStreamTuple,
         res_embedding_location: OptionalPath = None,
-        src_location: SrcLocation = SrcLocation.local,
+        src_location: SrcLocation = SrcLocation.file,
         src_from: SrcTensorFrom = SrcTensorFrom.file,
         structure_location: StructureLocation = StructureLocation.local,
         structure_format: StructureFormat = StructureFormat.mmcif,
@@ -23,6 +23,8 @@ def predict(
         num_nodes: int = 1,
         accelerator: Accelerator = Accelerator.auto,
         devices: Devices = 'auto',
+        out_format: OutFormat = OutFormat.separated,
+        out_name: str = 'inference',
         out_path: OptionalPath = None,
         inference_set=None
 ):
@@ -51,13 +53,13 @@
     )

     module = ChainModule()
-
-    inference_writer = CsvBatchWriter(out_path) if out_path is not None else None
+    inference_writer = (JsonStorage(out_path, out_name) if out_format == OutFormat.grouped else CsvBatchWriter(out_path)) if out_path is not None else None
     trainer = Trainer(
         callbacks=[inference_writer] if inference_writer is not None else None,
         num_nodes=num_nodes,
         accelerator=accelerator,
-        devices=devices
+        devices=devices,
+        logger=False
     )

     prediction = trainer.predict(
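A companion sketch of the stream-based call path the chain-embedding command now uses: residue-level tensors written in the previous step are located with adapt_csv_to_embedding_chain_stream and handed to predict as a stream. Paths and file names are hypothetical.

    from rcsb_embedding_model.inference.chain_inference import predict
    from rcsb_embedding_model.types.api_types import OutFormat, SrcLocation, SrcTensorFrom
    from rcsb_embedding_model.utils.data import adapt_csv_to_embedding_chain_stream

    # Yields (tensor_file_path, item_name) pairs from the same 4-column CSV used above.
    src_stream = adapt_csv_to_embedding_chain_stream('chains.csv', 'residue-embeddings')

    predict(
        src_stream=src_stream,
        res_embedding_location='residue-embeddings',
        src_location=SrcLocation.stream,
        src_from=SrcTensorFrom.file,
        out_format=OutFormat.grouped,
        out_name='inference-chain',
        out_path='embeddings'
    )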

src/rcsb_embedding_model/inference/esm_inference.py
@@ -4,13 +4,14 @@ from lightning import Trainer
 from rcsb_embedding_model.dataset.esm_prot_from_structure import EsmProtFromStructure
 from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
 from rcsb_embedding_model.modules.esm_module import EsmModule
-from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, SrcProteinFrom, FileOrStreamTuple, SrcLocation
-from rcsb_embedding_model.writer.batch_writer import TensorBatchWriter
+from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, \
+    SrcProteinFrom, FileOrStreamTuple, SrcLocation, OutFormat
+from rcsb_embedding_model.writer.batch_writer import TensorBatchWriter, JsonStorage


 def predict(
         src_stream: FileOrStreamTuple,
-        src_location: SrcLocation = SrcLocation.local,
+        src_location: SrcLocation = SrcLocation.file,
         src_from: SrcProteinFrom = SrcProteinFrom.chain,
         structure_location: StructureLocation = StructureLocation.local,
         structure_format: StructureFormat = StructureFormat.mmcif,
@@ -20,6 +21,8 @@ def predict(
         num_nodes: int = 1,
         accelerator: Accelerator = Accelerator.auto,
         devices: Devices = 'auto',
+        out_format: OutFormat = OutFormat.separated,
+        out_name: str = 'inference',
         out_path: OptionalPath = None
 ):

@@ -44,12 +47,13 @@
     )

     module = EsmModule()
-    inference_writer = TensorBatchWriter(out_path) if out_path is not None else None
+    inference_writer = (JsonStorage(out_path, out_name) if out_format == OutFormat.grouped else TensorBatchWriter(out_path)) if out_path is not None else None
     trainer = Trainer(
         callbacks=[inference_writer] if inference_writer is not None else None,
         num_nodes=num_nodes,
         accelerator=accelerator,
-        devices=devices
+        devices=devices,
+        logger=False
     )

     prediction = trainer.predict(

src/rcsb_embedding_model/inference/structure_inference.py
@@ -4,13 +4,14 @@ from lightning import Trainer
 from rcsb_embedding_model.dataset.esm_prot_from_structure import EsmProtFromStructure
 from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
 from rcsb_embedding_model.modules.structure_module import StructureModule
-from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, SrcProteinFrom, FileOrStreamTuple, SrcLocation
-from rcsb_embedding_model.writer.batch_writer import DataFrameStorage
+from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, \
+    SrcProteinFrom, FileOrStreamTuple, SrcLocation
+from rcsb_embedding_model.writer.batch_writer import JsonStorage


 def predict(
         src_stream: FileOrStreamTuple,
-        src_location: SrcLocation = SrcLocation.local,
+        src_location: SrcLocation = SrcLocation.file,
         src_from: SrcProteinFrom = SrcProteinFrom.chain,
         structure_location: StructureLocation = StructureLocation.local,
         structure_format: StructureFormat = StructureFormat.mmcif,
@@ -20,8 +21,8 @@ def predict(
         num_nodes: int = 1,
         accelerator: Accelerator = Accelerator.auto,
         devices: Devices = 'auto',
-        out_path: OptionalPath = None,
-        out_df_name: str = None
+        out_name: str = 'inference',
+        out_path: OptionalPath = None
 ):

     inference_set = EsmProtFromChain(
@@ -45,12 +46,13 @@
     )

     module = StructureModule()
-    inference_writer = DataFrameStorage(out_path, out_df_name) if out_path is not None and out_df_name is not None else None
+    inference_writer = JsonStorage(out_path, out_name) if out_path is not None and out_name is not None else None
     trainer = Trainer(
         callbacks=[inference_writer] if inference_writer is not None else None,
         num_nodes=num_nodes,
         accelerator=accelerator,
-        devices=devices
+        devices=devices,
+        logger=False
     )

     prediction = trainer.predict(

src/rcsb_embedding_model/modules/esm_module.py
@@ -1,4 +1,5 @@
 from esm.sdk.api import SamplingConfig
+from esm.sdk import batch_executor
 from lightning import LightningModule

 from rcsb_embedding_model.utils.model import get_residue_model
@@ -14,11 +15,13 @@ class EsmModule(LightningModule):

     def predict_step(self, prot_batch, batch_idx):
         prot_embeddings = []
-        prot_names = []
-        for esm_prot, name in prot_batch:
-            embeddings = self.esm3.forward_and_sample(
+        def __batch_embedding(esm_prot):
+            return self.esm3.forward_and_sample(
                 self.esm3.encode(esm_prot), SamplingConfig(return_per_residue_embeddings=True)
             ).per_residue_embedding
-            prot_embeddings.append(embeddings)
-            prot_names.append(name)
-        return tuple(prot_embeddings), tuple(prot_names)
+        with batch_executor() as executor:
+            prot_embeddings = executor.execute_batch(
+                user_func=__batch_embedding,
+                esm_prot=[esm_prot for esm_prot, name in prot_batch]
+            )
+        return tuple(prot_embeddings), tuple([name for esm_prot, name in prot_batch])
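Assuming the esm.sdk batch_executor behaves as it is used above, the refactored predict_step is functionally close to the per-protein loop it replaces. A sequential sketch of the same computation, for comparison only (the helper name is hypothetical):

    from esm.sdk.api import SamplingConfig

    def embed_sequentially(module, prot_batch):
        # module: an EsmModule instance; prot_batch: iterable of (ESMProtein, name) pairs.
        embeddings = tuple(
            module.esm3.forward_and_sample(
                module.esm3.encode(esm_prot),
                SamplingConfig(return_per_residue_embeddings=True)
            ).per_residue_embedding
            for esm_prot, _ in prot_batch
        )
        names = tuple(name for _, name in prot_batch)
        return embeddings, names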

src/rcsb_embedding_model/types/api_types.py
@@ -32,7 +32,7 @@ class Accelerator(str, Enum):


 class SrcLocation(str, Enum):
-    local = "local"
+    file = "file"
     stream = "stream"


@@ -54,3 +54,7 @@ class SrcAssemblyFrom(str, Enum):
 class SrcTensorFrom(str, Enum):
     file = "file"
     structure = "structure"
+
+class OutFormat(str, Enum):
+    separated = "separated"
+    grouped = "grouped"
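Because OutFormat subclasses str and Enum like the existing option types, Typer can map the literal CLI strings onto its members; a small illustration:

    from rcsb_embedding_model.types.api_types import OutFormat, SrcLocation

    assert OutFormat("grouped") is OutFormat.grouped              # CLI value round-trips to the member
    assert OutFormat.separated.value == "separated"
    assert [m.value for m in SrcLocation] == ["file", "stream"]   # "local" was renamed to "file"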

src/rcsb_embedding_model/utils/data.py
@@ -1,3 +1,4 @@
+import os
 from io import StringIO

 import requests
@@ -76,3 +76,9 @@ def concatenate_tensors(file_list, max_residues, dim=0)
         return tensor_cat
     else:
         raise ValueError("No valid tensors were loaded to concatenate.")
+
+def adapt_csv_to_embedding_chain_stream(src_file, res_embedding_location):
+    def __parse_row(row):
+        r = row.split(",")
+        return os.path.join(res_embedding_location, f"{r[0]}.{r[2]}.pt"), f"{r[0]}.{r[2]}"
+    return tuple([__parse_row(r) for r in open(src_file)])
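To make the new helper concrete, a small sketch of what adapt_csv_to_embedding_chain_stream returns for a hypothetical two-row CSV; the tensor file names follow the {structure}.{chain}.pt pattern built in the function body above.

    import tempfile

    from rcsb_embedding_model.utils.data import adapt_csv_to_embedding_chain_stream

    # Hypothetical 4-column rows: Structure Name, Path/URL, Chain Id, Output Embedding Name.
    rows = "1acb,tests/resources/pdb/1acb.cif,A,1acb.A\n2uzi,tests/resources/pdb/2uzi.cif,C,2uzi.C\n"
    with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as fh:
        fh.write(rows)

    stream = adapt_csv_to_embedding_chain_stream(fh.name, "/tmp/residue-embeddings")
    # (('/tmp/residue-embeddings/1acb.A.pt', '1acb.A'),
    #  ('/tmp/residue-embeddings/2uzi.C.pt', '2uzi.C'))
    print(stream)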

src/rcsb_embedding_model/writer/batch_writer.py
@@ -111,3 +111,21 @@ class DataFrameStorage(CoreBatchWriter, ABC):
             f"{self.out_path}/{self.df_id}.pkl.gz",
             compression='gzip'
         )
+
+
+class JsonStorage(DataFrameStorage, ABC):
+    def __init__(
+            self,
+            output_path,
+            df_id,
+            postfix="pkl",
+            write_interval="batch"
+    ):
+        super().__init__(output_path, df_id, postfix, write_interval)
+
+    def on_predict_end(self, trainer, pl_module):
+        self.embedding.to_json(
+            f"{self.out_path}/{self.df_id}.json.gz",
+            orient='records',
+            compression='gzip'
+        )
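The grouped output produced by JsonStorage is a gzip-compressed, record-oriented JSON file, so it should be readable back into a DataFrame with pandas; a minimal sketch assuming the default out_name of 'inference':

    import pandas as pd

    # JsonStorage.on_predict_end writes {out_path}/{df_id}.json.gz with orient='records'.
    df = pd.read_json("embeddings/inference.json.gz", orient="records", compression="gzip")
    print(df.head())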