rcsb-embedding-model 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rcsb-embedding-model might be problematic. Click here for more details.
- rcsb_embedding_model/cli/inference.py +139 -11
- rcsb_embedding_model/dataset/esm_prot_from_chain.py +1 -1
- rcsb_embedding_model/dataset/esm_prot_from_structure.py +1 -1
- rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py +2 -2
- rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py +2 -2
- rcsb_embedding_model/dataset/residue_embedding_from_structure.py +1 -1
- rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py +1 -1
- rcsb_embedding_model/inference/assembly_inferece.py +7 -2
- rcsb_embedding_model/inference/chain_inference.py +8 -6
- rcsb_embedding_model/inference/esm_inference.py +9 -5
- rcsb_embedding_model/inference/structure_inference.py +9 -7
- rcsb_embedding_model/types/api_types.py +5 -1
- rcsb_embedding_model/writer/batch_writer.py +18 -0
- {rcsb_embedding_model-0.0.15.dist-info → rcsb_embedding_model-0.0.17.dist-info}/METADATA +2 -2
- {rcsb_embedding_model-0.0.15.dist-info → rcsb_embedding_model-0.0.17.dist-info}/RECORD +18 -18
- {rcsb_embedding_model-0.0.15.dist-info → rcsb_embedding_model-0.0.17.dist-info}/WHEEL +0 -0
- {rcsb_embedding_model-0.0.15.dist-info → rcsb_embedding_model-0.0.17.dist-info}/entry_points.txt +0 -0
- {rcsb_embedding_model-0.0.15.dist-info → rcsb_embedding_model-0.0.17.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -5,7 +5,7 @@ import typer
|
|
|
5
5
|
|
|
6
6
|
from rcsb_embedding_model.cli.args_utils import arg_devices
|
|
7
7
|
from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, SrcLocation, SrcProteinFrom, \
|
|
8
|
-
StructureLocation, SrcAssemblyFrom, SrcTensorFrom
|
|
8
|
+
StructureLocation, SrcAssemblyFrom, SrcTensorFrom, OutFormat
|
|
9
9
|
|
|
10
10
|
app = typer.Typer(
|
|
11
11
|
add_completion=False
|
|
@@ -31,6 +31,12 @@ def residue_embedding(
|
|
|
31
31
|
resolve_path=True,
|
|
32
32
|
help='Output path to store predictions. Embeddings are stored as torch tensor files.'
|
|
33
33
|
)],
|
|
34
|
+
output_format: Annotated[OutFormat, typer.Option(
|
|
35
|
+
help='Format of the output. Options: separated (predictions are stored in single files) or grouped (predictions are stored in a single JSON file).'
|
|
36
|
+
)] = OutFormat.separated,
|
|
37
|
+
output_name: Annotated[str, typer.Option(
|
|
38
|
+
help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
|
|
39
|
+
)] = 'inference',
|
|
34
40
|
src_from: Annotated[SrcProteinFrom, typer.Option(
|
|
35
41
|
help='Use specific chains or all chains in a structure.'
|
|
36
42
|
)] = SrcProteinFrom.chain,
|
|
@@ -62,7 +68,7 @@ def residue_embedding(
|
|
|
62
68
|
from rcsb_embedding_model.inference.esm_inference import predict
|
|
63
69
|
predict(
|
|
64
70
|
src_stream=src_file,
|
|
65
|
-
src_location=SrcLocation.
|
|
71
|
+
src_location=SrcLocation.file,
|
|
66
72
|
src_from=src_from,
|
|
67
73
|
structure_location=structure_location,
|
|
68
74
|
structure_format=structure_format,
|
|
@@ -72,6 +78,8 @@ def residue_embedding(
|
|
|
72
78
|
num_nodes=num_nodes,
|
|
73
79
|
accelerator=accelerator,
|
|
74
80
|
devices=arg_devices(devices),
|
|
81
|
+
out_format=output_format,
|
|
82
|
+
out_name=output_name,
|
|
75
83
|
out_path=output_path
|
|
76
84
|
)
|
|
77
85
|
|
|
@@ -95,9 +103,9 @@ def structure_embedding(
|
|
|
95
103
|
resolve_path=True,
|
|
96
104
|
help='Output path to store predictions. Embeddings are stored as a single DataFrame file (see out-df-name).'
|
|
97
105
|
)],
|
|
98
|
-
|
|
99
|
-
help='File name
|
|
100
|
-
)],
|
|
106
|
+
output_name: Annotated[str, typer.Option(
|
|
107
|
+
help='File name for storing embeddings as a single JSON file.'
|
|
108
|
+
)] = 'inference',
|
|
101
109
|
src_from: Annotated[SrcProteinFrom, typer.Option(
|
|
102
110
|
help='Use specific chains or all chains in a structure.'
|
|
103
111
|
)] = SrcProteinFrom.chain,
|
|
@@ -129,7 +137,7 @@ def structure_embedding(
|
|
|
129
137
|
from rcsb_embedding_model.inference.structure_inference import predict
|
|
130
138
|
predict(
|
|
131
139
|
src_stream=src_file,
|
|
132
|
-
src_location=SrcLocation.
|
|
140
|
+
src_location=SrcLocation.file,
|
|
133
141
|
src_from=src_from,
|
|
134
142
|
structure_location=structure_location,
|
|
135
143
|
structure_format=structure_format,
|
|
@@ -140,7 +148,7 @@ def structure_embedding(
|
|
|
140
148
|
accelerator=accelerator,
|
|
141
149
|
devices=arg_devices(devices),
|
|
142
150
|
out_path=output_path,
|
|
143
|
-
|
|
151
|
+
out_name=output_name
|
|
144
152
|
)
|
|
145
153
|
|
|
146
154
|
|
|
@@ -163,6 +171,12 @@ def chain_embedding(
|
|
|
163
171
|
resolve_path=True,
|
|
164
172
|
help='Output path to store predictions. Embeddings are stored as csv files.'
|
|
165
173
|
)],
|
|
174
|
+
output_format: Annotated[OutFormat, typer.Option(
|
|
175
|
+
help='Format of the output. Options: separated (predictions are stored in single files) or grouped (predictions are stored in a single JSON file).'
|
|
176
|
+
)] = OutFormat.separated,
|
|
177
|
+
output_name: Annotated[str, typer.Option(
|
|
178
|
+
help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
|
|
179
|
+
)] = 'inference',
|
|
166
180
|
res_embedding_location: Annotated[typer.FileText, typer.Option(
|
|
167
181
|
exists=True,
|
|
168
182
|
file_okay=False,
|
|
@@ -202,7 +216,7 @@ def chain_embedding(
|
|
|
202
216
|
predict(
|
|
203
217
|
src_stream=src_file,
|
|
204
218
|
res_embedding_location=res_embedding_location,
|
|
205
|
-
src_location=SrcLocation.
|
|
219
|
+
src_location=SrcLocation.file,
|
|
206
220
|
src_from=src_from,
|
|
207
221
|
structure_location=structure_location,
|
|
208
222
|
structure_format=structure_format,
|
|
@@ -212,7 +226,9 @@ def chain_embedding(
|
|
|
212
226
|
num_nodes=num_nodes,
|
|
213
227
|
accelerator=accelerator,
|
|
214
228
|
devices=arg_devices(devices),
|
|
215
|
-
out_path=output_path
|
|
229
|
+
out_path=output_path,
|
|
230
|
+
out_format=output_format,
|
|
231
|
+
out_name=output_name
|
|
216
232
|
)
|
|
217
233
|
|
|
218
234
|
@app.command(
|
|
@@ -241,6 +257,12 @@ def assembly_embedding(
|
|
|
241
257
|
resolve_path=True,
|
|
242
258
|
help='Output path to store predictions. Embeddings are stored as csv files.'
|
|
243
259
|
)],
|
|
260
|
+
output_format: Annotated[OutFormat, typer.Option(
|
|
261
|
+
help='Format of the output. Options: separated (predictions are stored in single files) or grouped (predictions are stored in a single JSON file).'
|
|
262
|
+
)] = OutFormat.separated,
|
|
263
|
+
output_name: Annotated[str, typer.Option(
|
|
264
|
+
help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
|
|
265
|
+
)] = 'inference',
|
|
244
266
|
src_from: Annotated[SrcAssemblyFrom, typer.Option(
|
|
245
267
|
help='Use specific assembly or all assemblies in a structure.'
|
|
246
268
|
)] = SrcAssemblyFrom.assembly,
|
|
@@ -276,7 +298,7 @@ def assembly_embedding(
|
|
|
276
298
|
predict(
|
|
277
299
|
src_stream=src_file,
|
|
278
300
|
res_embedding_location=res_embedding_location,
|
|
279
|
-
src_location=SrcLocation.
|
|
301
|
+
src_location=SrcLocation.file,
|
|
280
302
|
src_from=src_from,
|
|
281
303
|
structure_location=structure_location,
|
|
282
304
|
structure_format=structure_format,
|
|
@@ -287,7 +309,113 @@ def assembly_embedding(
|
|
|
287
309
|
num_nodes=num_nodes,
|
|
288
310
|
accelerator=accelerator,
|
|
289
311
|
devices=arg_devices(devices),
|
|
290
|
-
out_path=output_path
|
|
312
|
+
out_path=output_path,
|
|
313
|
+
out_format=output_format,
|
|
314
|
+
out_name=output_name
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
@app.command(
|
|
318
|
+
name="complete-embedding",
|
|
319
|
+
help="Calculate chain and assembly embeddings from structural files. Predictions are stored as csv files."
|
|
320
|
+
)
|
|
321
|
+
def complete_embedding(
|
|
322
|
+
src_file: Annotated[typer.FileText, typer.Option(
|
|
323
|
+
exists=True,
|
|
324
|
+
file_okay=True,
|
|
325
|
+
dir_okay=False,
|
|
326
|
+
resolve_path=True,
|
|
327
|
+
help='CSV file 3 columns: Structure Name | Structure File Path or URL | Chain Id (asym_i for cif files. This field is required if src-from=chain) | Output Embedding Name.'
|
|
328
|
+
)],
|
|
329
|
+
output_path: Annotated[typer.FileText, typer.Option(
|
|
330
|
+
exists=True,
|
|
331
|
+
file_okay=False,
|
|
332
|
+
dir_okay=True,
|
|
333
|
+
resolve_path=True,
|
|
334
|
+
help='Output path to store predictions. Embeddings are stored as a single DataFrame file (see output_name).'
|
|
335
|
+
)],
|
|
336
|
+
esm_output_path: Annotated[typer.FileText, typer.Option(
|
|
337
|
+
exists=True,
|
|
338
|
+
file_okay=False,
|
|
339
|
+
dir_okay=True,
|
|
340
|
+
resolve_path=True,
|
|
341
|
+
help='Output path to store ESM predictions.'
|
|
342
|
+
)],
|
|
343
|
+
output_format: Annotated[OutFormat, typer.Option(
|
|
344
|
+
help='Format of the output. Options: separated (predictions are stored in single files) or grouped (predictions are stored in a single JSON file).'
|
|
345
|
+
)] = OutFormat.separated,
|
|
346
|
+
output_name: Annotated[str, typer.Option(
|
|
347
|
+
help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
|
|
348
|
+
)] = 'inference',
|
|
349
|
+
structure_location: Annotated[StructureLocation, typer.Option(
|
|
350
|
+
help='Structure file location.'
|
|
351
|
+
)] = StructureLocation.local,
|
|
352
|
+
structure_format: Annotated[StructureFormat, typer.Option(
|
|
353
|
+
help='Structure file format.'
|
|
354
|
+
)] = StructureFormat.mmcif,
|
|
355
|
+
min_res_n: Annotated[int, typer.Option(
|
|
356
|
+
help='When using all chains in a structure, consider only chains with more than <min_res_n> residues.'
|
|
357
|
+
)] = 0,
|
|
358
|
+
batch_size: Annotated[int, typer.Option(
|
|
359
|
+
help='Number of samples processed together in one iteration.'
|
|
360
|
+
)] = 1,
|
|
361
|
+
num_workers: Annotated[int, typer.Option(
|
|
362
|
+
help='Number of subprocesses to use for data loading.'
|
|
363
|
+
)] = 0,
|
|
364
|
+
num_nodes: Annotated[int, typer.Option(
|
|
365
|
+
help='Number of nodes to use for inference.'
|
|
366
|
+
)] = 1,
|
|
367
|
+
accelerator: Annotated[Accelerator, typer.Option(
|
|
368
|
+
help='Device used for inference.'
|
|
369
|
+
)] = Accelerator.auto,
|
|
370
|
+
devices: Annotated[List[str], typer.Option(
|
|
371
|
+
help='The devices to use. Can be set to a positive number or "auto". Repeat this argument to indicate multiple indices of devices. "auto" for automatic selection based on the chosen accelerator.'
|
|
372
|
+
)] = tuple(['auto'])
|
|
373
|
+
):
|
|
374
|
+
residue_embedding(
|
|
375
|
+
src_file=src_file,
|
|
376
|
+
src_from=SrcProteinFrom.structure,
|
|
377
|
+
output_path=esm_output_path,
|
|
378
|
+
output_format=OutFormat.separated,
|
|
379
|
+
structure_location=structure_location,
|
|
380
|
+
structure_format=structure_format,
|
|
381
|
+
min_res_n=min_res_n,
|
|
382
|
+
batch_size=batch_size,
|
|
383
|
+
num_workers=num_workers,
|
|
384
|
+
num_nodes=num_nodes,
|
|
385
|
+
accelerator=accelerator,
|
|
386
|
+
devices=devices,
|
|
387
|
+
)
|
|
388
|
+
chain_embedding(
|
|
389
|
+
src_file=src_file,
|
|
390
|
+
src_from=SrcTensorFrom.structure,
|
|
391
|
+
output_path=output_path,
|
|
392
|
+
output_format=output_format,
|
|
393
|
+
output_name=f"{output_name}-chain",
|
|
394
|
+
res_embedding_location=esm_output_path,
|
|
395
|
+
structure_location=structure_location,
|
|
396
|
+
structure_format=structure_format,
|
|
397
|
+
min_res_n=min_res_n,
|
|
398
|
+
batch_size=batch_size,
|
|
399
|
+
num_workers=num_workers,
|
|
400
|
+
num_nodes=num_nodes,
|
|
401
|
+
accelerator=accelerator,
|
|
402
|
+
devices=devices
|
|
403
|
+
)
|
|
404
|
+
assembly_embedding(
|
|
405
|
+
src_file=src_file,
|
|
406
|
+
src_from=SrcAssemblyFrom.structure,
|
|
407
|
+
output_path=output_path,
|
|
408
|
+
output_format=output_format,
|
|
409
|
+
output_name=f"{output_name}-assembly",
|
|
410
|
+
res_embedding_location=esm_output_path,
|
|
411
|
+
structure_location=structure_location,
|
|
412
|
+
structure_format=structure_format,
|
|
413
|
+
min_res_n=min_res_n,
|
|
414
|
+
batch_size=batch_size,
|
|
415
|
+
num_workers=num_workers,
|
|
416
|
+
num_nodes=num_nodes,
|
|
417
|
+
accelerator=accelerator,
|
|
418
|
+
devices=devices
|
|
291
419
|
)
|
|
292
420
|
|
|
293
421
|
|
|
@@ -27,7 +27,7 @@ class EsmProtFromChain(Dataset):
|
|
|
27
27
|
def __init__(
|
|
28
28
|
self,
|
|
29
29
|
src_stream,
|
|
30
|
-
src_location=SrcLocation.
|
|
30
|
+
src_location=SrcLocation.file,
|
|
31
31
|
structure_location=StructureLocation.local,
|
|
32
32
|
structure_format=StructureFormat.mmcif,
|
|
33
33
|
structure_provider=StructureProvider()
|
|
@@ -19,7 +19,7 @@ class EsmProtFromStructure(EsmProtFromChain):
|
|
|
19
19
|
def __init__(
|
|
20
20
|
self,
|
|
21
21
|
src_stream,
|
|
22
|
-
src_location=SrcLocation.
|
|
22
|
+
src_location=SrcLocation.file,
|
|
23
23
|
structure_location=StructureLocation.local,
|
|
24
24
|
structure_format=StructureFormat.mmcif,
|
|
25
25
|
min_res_n=0,
|
|
@@ -21,7 +21,7 @@ class ResidueAssemblyDatasetFromStructure(ResidueAssemblyEmbeddingFromTensorFile
|
|
|
21
21
|
self,
|
|
22
22
|
src_stream,
|
|
23
23
|
res_embedding_location,
|
|
24
|
-
src_location=SrcLocation.
|
|
24
|
+
src_location=SrcLocation.file,
|
|
25
25
|
structure_location=StructureLocation.local,
|
|
26
26
|
structure_format=StructureFormat.mmcif,
|
|
27
27
|
min_res_n=0,
|
|
@@ -63,6 +63,6 @@ class ResidueAssemblyDatasetFromStructure(ResidueAssemblyEmbeddingFromTensorFile
|
|
|
63
63
|
structure = stringio_from_url(src_structure) if self.structure_location == StructureLocation.remote else src_structure
|
|
64
64
|
item_name = row[ResidueAssemblyDatasetFromStructure.ITEM_NAME_ATTR]
|
|
65
65
|
for assembly_id in get_assemblies(structure=structure, structure_format=self.structure_format):
|
|
66
|
-
assemblies.append((src_name, src_structure, str(assembly_id), f"{item_name}
|
|
66
|
+
assemblies.append((src_name, src_structure, str(assembly_id), f"{item_name}-{assembly_id}"))
|
|
67
67
|
|
|
68
68
|
return tuple(assemblies)
|
|
@@ -22,7 +22,7 @@ class ResidueAssemblyEmbeddingFromTensorFile(Dataset):
|
|
|
22
22
|
self,
|
|
23
23
|
src_stream,
|
|
24
24
|
res_embedding_location,
|
|
25
|
-
src_location=SrcLocation.
|
|
25
|
+
src_location=SrcLocation.file,
|
|
26
26
|
structure_location=StructureLocation.local,
|
|
27
27
|
structure_format=StructureFormat.mmcif,
|
|
28
28
|
min_res_n=0,
|
|
@@ -79,7 +79,7 @@ if __name__ == "__main__":
|
|
|
79
79
|
dataset = ResidueAssemblyEmbeddingFromTensorFile(
|
|
80
80
|
src_stream="/Users/joan/tmp/assembly-test.csv",
|
|
81
81
|
res_embedding_location="/Users/joan/tmp",
|
|
82
|
-
src_location=SrcLocation.
|
|
82
|
+
src_location=SrcLocation.file,
|
|
83
83
|
structure_location=StructureLocation.local,
|
|
84
84
|
structure_format=StructureFormat.mmcif
|
|
85
85
|
)
|
|
@@ -21,7 +21,7 @@ class ResidueEmbeddingFromStructure(ResidueEmbeddingFromTensorFile):
|
|
|
21
21
|
self,
|
|
22
22
|
src_stream,
|
|
23
23
|
res_embedding_location,
|
|
24
|
-
src_location=SrcLocation.
|
|
24
|
+
src_location=SrcLocation.file,
|
|
25
25
|
structure_location=StructureLocation.local,
|
|
26
26
|
structure_format=StructureFormat.mmcif,
|
|
27
27
|
min_res_n=0,
|
|
@@ -2,14 +2,15 @@ import sys
|
|
|
2
2
|
|
|
3
3
|
from rcsb_embedding_model.dataset.resdiue_assembly_embedding_from_structure import ResidueAssemblyDatasetFromStructure
|
|
4
4
|
from rcsb_embedding_model.dataset.residue_assembly_embedding_from_tensor_file import ResidueAssemblyEmbeddingFromTensorFile
|
|
5
|
-
from rcsb_embedding_model.types.api_types import FileOrStreamTuple, SrcLocation, Accelerator, Devices, OptionalPath,
|
|
5
|
+
from rcsb_embedding_model.types.api_types import FileOrStreamTuple, SrcLocation, Accelerator, Devices, OptionalPath, \
|
|
6
|
+
EmbeddingPath, StructureLocation, StructureFormat, SrcAssemblyFrom, OutFormat
|
|
6
7
|
from rcsb_embedding_model.inference.chain_inference import predict as chain_predict
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
def predict(
|
|
10
11
|
src_stream: FileOrStreamTuple,
|
|
11
12
|
res_embedding_location: EmbeddingPath,
|
|
12
|
-
src_location: SrcLocation = SrcLocation.
|
|
13
|
+
src_location: SrcLocation = SrcLocation.file,
|
|
13
14
|
src_from: SrcAssemblyFrom = SrcAssemblyFrom.assembly,
|
|
14
15
|
structure_location: StructureLocation = StructureLocation.local,
|
|
15
16
|
structure_format: StructureFormat = StructureFormat.mmcif,
|
|
@@ -20,6 +21,8 @@ def predict(
|
|
|
20
21
|
num_nodes: int = 1,
|
|
21
22
|
accelerator: Accelerator = Accelerator.auto,
|
|
22
23
|
devices: Devices = 'auto',
|
|
24
|
+
out_format: OutFormat = OutFormat.separated,
|
|
25
|
+
out_name: str = 'inference',
|
|
23
26
|
out_path: OptionalPath = None
|
|
24
27
|
):
|
|
25
28
|
inference_set = ResidueAssemblyEmbeddingFromTensorFile(
|
|
@@ -48,6 +51,8 @@ def predict(
|
|
|
48
51
|
num_nodes=num_nodes,
|
|
49
52
|
accelerator=accelerator,
|
|
50
53
|
devices=devices,
|
|
54
|
+
out_format=out_format,
|
|
55
|
+
out_name=out_name,
|
|
51
56
|
out_path=out_path,
|
|
52
57
|
inference_set=inference_set
|
|
53
58
|
)
|
|
@@ -5,15 +5,15 @@ from rcsb_embedding_model.dataset.residue_embedding_from_structure import Residu
|
|
|
5
5
|
from rcsb_embedding_model.dataset.residue_embedding_from_tensor_file import ResidueEmbeddingFromTensorFile
|
|
6
6
|
from rcsb_embedding_model.modules.chain_module import ChainModule
|
|
7
7
|
from rcsb_embedding_model.types.api_types import Accelerator, Devices, OptionalPath, FileOrStreamTuple, SrcLocation, \
|
|
8
|
-
SrcTensorFrom, StructureLocation, StructureFormat
|
|
8
|
+
SrcTensorFrom, StructureLocation, StructureFormat, OutFormat
|
|
9
9
|
from rcsb_embedding_model.utils.data import collate_seq_embeddings
|
|
10
|
-
from rcsb_embedding_model.writer.batch_writer import CsvBatchWriter
|
|
10
|
+
from rcsb_embedding_model.writer.batch_writer import CsvBatchWriter, JsonStorage
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
def predict(
|
|
14
14
|
src_stream: FileOrStreamTuple,
|
|
15
15
|
res_embedding_location: OptionalPath = None,
|
|
16
|
-
src_location: SrcLocation = SrcLocation.
|
|
16
|
+
src_location: SrcLocation = SrcLocation.file,
|
|
17
17
|
src_from: SrcTensorFrom = SrcTensorFrom.file,
|
|
18
18
|
structure_location: StructureLocation = StructureLocation.local,
|
|
19
19
|
structure_format: StructureFormat = StructureFormat.mmcif,
|
|
@@ -23,6 +23,8 @@ def predict(
|
|
|
23
23
|
num_nodes: int = 1,
|
|
24
24
|
accelerator: Accelerator = Accelerator.auto,
|
|
25
25
|
devices: Devices = 'auto',
|
|
26
|
+
out_format: OutFormat = OutFormat.separated,
|
|
27
|
+
out_name: str = 'inference',
|
|
26
28
|
out_path: OptionalPath = None,
|
|
27
29
|
inference_set=None
|
|
28
30
|
):
|
|
@@ -51,13 +53,13 @@ def predict(
|
|
|
51
53
|
)
|
|
52
54
|
|
|
53
55
|
module = ChainModule()
|
|
54
|
-
|
|
55
|
-
inference_writer = CsvBatchWriter(out_path) if out_path is not None else None
|
|
56
|
+
inference_writer = (JsonStorage(out_path, out_name) if out_format == OutFormat.grouped else CsvBatchWriter(out_path)) if out_path is not None else None
|
|
56
57
|
trainer = Trainer(
|
|
57
58
|
callbacks=[inference_writer] if inference_writer is not None else None,
|
|
58
59
|
num_nodes=num_nodes,
|
|
59
60
|
accelerator=accelerator,
|
|
60
|
-
devices=devices
|
|
61
|
+
devices=devices,
|
|
62
|
+
logger=False
|
|
61
63
|
)
|
|
62
64
|
|
|
63
65
|
prediction = trainer.predict(
|
|
@@ -4,13 +4,14 @@ from lightning import Trainer
|
|
|
4
4
|
from rcsb_embedding_model.dataset.esm_prot_from_structure import EsmProtFromStructure
|
|
5
5
|
from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
|
|
6
6
|
from rcsb_embedding_model.modules.esm_module import EsmModule
|
|
7
|
-
from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation,
|
|
8
|
-
|
|
7
|
+
from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, \
|
|
8
|
+
SrcProteinFrom, FileOrStreamTuple, SrcLocation, OutFormat
|
|
9
|
+
from rcsb_embedding_model.writer.batch_writer import TensorBatchWriter, JsonStorage
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
def predict(
|
|
12
13
|
src_stream: FileOrStreamTuple,
|
|
13
|
-
src_location: SrcLocation = SrcLocation.
|
|
14
|
+
src_location: SrcLocation = SrcLocation.file,
|
|
14
15
|
src_from: SrcProteinFrom = SrcProteinFrom.chain,
|
|
15
16
|
structure_location: StructureLocation = StructureLocation.local,
|
|
16
17
|
structure_format: StructureFormat = StructureFormat.mmcif,
|
|
@@ -20,6 +21,8 @@ def predict(
|
|
|
20
21
|
num_nodes: int = 1,
|
|
21
22
|
accelerator: Accelerator = Accelerator.auto,
|
|
22
23
|
devices: Devices = 'auto',
|
|
24
|
+
out_format: OutFormat = OutFormat.separated,
|
|
25
|
+
out_name: str = 'inference',
|
|
23
26
|
out_path: OptionalPath = None
|
|
24
27
|
):
|
|
25
28
|
|
|
@@ -44,12 +47,13 @@ def predict(
|
|
|
44
47
|
)
|
|
45
48
|
|
|
46
49
|
module = EsmModule()
|
|
47
|
-
inference_writer = TensorBatchWriter(out_path) if out_path is not None else None
|
|
50
|
+
inference_writer = (JsonStorage(out_path, out_name) if out_format == OutFormat.grouped else TensorBatchWriter(out_path)) if out_path is not None else None
|
|
48
51
|
trainer = Trainer(
|
|
49
52
|
callbacks=[inference_writer] if inference_writer is not None else None,
|
|
50
53
|
num_nodes=num_nodes,
|
|
51
54
|
accelerator=accelerator,
|
|
52
|
-
devices=devices
|
|
55
|
+
devices=devices,
|
|
56
|
+
logger=False
|
|
53
57
|
)
|
|
54
58
|
|
|
55
59
|
prediction = trainer.predict(
|
|
@@ -4,13 +4,14 @@ from lightning import Trainer
|
|
|
4
4
|
from rcsb_embedding_model.dataset.esm_prot_from_structure import EsmProtFromStructure
|
|
5
5
|
from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
|
|
6
6
|
from rcsb_embedding_model.modules.structure_module import StructureModule
|
|
7
|
-
from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation,
|
|
8
|
-
|
|
7
|
+
from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, \
|
|
8
|
+
SrcProteinFrom, FileOrStreamTuple, SrcLocation
|
|
9
|
+
from rcsb_embedding_model.writer.batch_writer import JsonStorage
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
def predict(
|
|
12
13
|
src_stream: FileOrStreamTuple,
|
|
13
|
-
src_location: SrcLocation = SrcLocation.
|
|
14
|
+
src_location: SrcLocation = SrcLocation.file,
|
|
14
15
|
src_from: SrcProteinFrom = SrcProteinFrom.chain,
|
|
15
16
|
structure_location: StructureLocation = StructureLocation.local,
|
|
16
17
|
structure_format: StructureFormat = StructureFormat.mmcif,
|
|
@@ -20,8 +21,8 @@ def predict(
|
|
|
20
21
|
num_nodes: int = 1,
|
|
21
22
|
accelerator: Accelerator = Accelerator.auto,
|
|
22
23
|
devices: Devices = 'auto',
|
|
23
|
-
|
|
24
|
-
|
|
24
|
+
out_name: str = 'inference',
|
|
25
|
+
out_path: OptionalPath = None
|
|
25
26
|
):
|
|
26
27
|
|
|
27
28
|
inference_set = EsmProtFromChain(
|
|
@@ -45,12 +46,13 @@ def predict(
|
|
|
45
46
|
)
|
|
46
47
|
|
|
47
48
|
module = StructureModule()
|
|
48
|
-
inference_writer =
|
|
49
|
+
inference_writer = JsonStorage(out_path, out_name) if out_path is not None and out_name is not None else None
|
|
49
50
|
trainer = Trainer(
|
|
50
51
|
callbacks=[inference_writer] if inference_writer is not None else None,
|
|
51
52
|
num_nodes=num_nodes,
|
|
52
53
|
accelerator=accelerator,
|
|
53
|
-
devices=devices
|
|
54
|
+
devices=devices,
|
|
55
|
+
logger=False
|
|
54
56
|
)
|
|
55
57
|
|
|
56
58
|
prediction = trainer.predict(
|
|
@@ -32,7 +32,7 @@ class Accelerator(str, Enum):
|
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
class SrcLocation(str, Enum):
|
|
35
|
-
|
|
35
|
+
file = "file"
|
|
36
36
|
stream = "stream"
|
|
37
37
|
|
|
38
38
|
|
|
@@ -54,3 +54,7 @@ class SrcAssemblyFrom(str, Enum):
|
|
|
54
54
|
class SrcTensorFrom(str, Enum):
|
|
55
55
|
file = "file"
|
|
56
56
|
structure = "structure"
|
|
57
|
+
|
|
58
|
+
class OutFormat(str, Enum):
|
|
59
|
+
separated = "separated"
|
|
60
|
+
grouped = "grouped"
|
|
@@ -111,3 +111,21 @@ class DataFrameStorage(CoreBatchWriter, ABC):
|
|
|
111
111
|
f"{self.out_path}/{self.df_id}.pkl.gz",
|
|
112
112
|
compression='gzip'
|
|
113
113
|
)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class JsonStorage(DataFrameStorage, ABC):
|
|
117
|
+
def __init__(
|
|
118
|
+
self,
|
|
119
|
+
output_path,
|
|
120
|
+
df_id,
|
|
121
|
+
postfix="pkl",
|
|
122
|
+
write_interval="batch"
|
|
123
|
+
):
|
|
124
|
+
super().__init__(output_path, df_id, postfix, write_interval)
|
|
125
|
+
|
|
126
|
+
def on_predict_end(self, trainer, pl_module):
|
|
127
|
+
self.embedding.to_json(
|
|
128
|
+
f"{self.out_path}/{self.df_id}.json.gz",
|
|
129
|
+
orient='records',
|
|
130
|
+
compression='gzip'
|
|
131
|
+
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: rcsb-embedding-model
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.17
|
|
4
4
|
Summary: Protein Embedding Model for Structure Search
|
|
5
5
|
Project-URL: Homepage, https://github.com/rcsb/rcsb-embedding-model
|
|
6
6
|
Project-URL: Issues, https://github.com/rcsb/rcsb-embedding-model/issues
|
|
@@ -18,7 +18,7 @@ Description-Content-Type: text/markdown
|
|
|
18
18
|
|
|
19
19
|
# RCSB Embedding Model
|
|
20
20
|
|
|
21
|
-
**Version** 0.0.
|
|
21
|
+
**Version** 0.0.17
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
## Overview
|
|
@@ -1,30 +1,30 @@
|
|
|
1
1
|
rcsb_embedding_model/__init__.py,sha256=r3gLdeBIXkQEQA_K6QcRPO-TtYuAQSutk6pXRUE_nas,120
|
|
2
2
|
rcsb_embedding_model/rcsb_structure_embedding.py,sha256=dKp9hXQO0JAnO4SEfjJ_mG_jHu3UxAPguv6jkOjp-BI,4487
|
|
3
3
|
rcsb_embedding_model/cli/args_utils.py,sha256=7nP2q8pL5dWK_U7opxtWmoFcYVwasky6elHk-dASFaI,165
|
|
4
|
-
rcsb_embedding_model/cli/inference.py,sha256=
|
|
5
|
-
rcsb_embedding_model/dataset/esm_prot_from_chain.py,sha256=
|
|
6
|
-
rcsb_embedding_model/dataset/esm_prot_from_structure.py,sha256=
|
|
7
|
-
rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py,sha256=
|
|
8
|
-
rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py,sha256=
|
|
9
|
-
rcsb_embedding_model/dataset/residue_embedding_from_structure.py,sha256=
|
|
10
|
-
rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py,sha256=
|
|
11
|
-
rcsb_embedding_model/inference/assembly_inferece.py,sha256=
|
|
12
|
-
rcsb_embedding_model/inference/chain_inference.py,sha256=
|
|
13
|
-
rcsb_embedding_model/inference/esm_inference.py,sha256=
|
|
14
|
-
rcsb_embedding_model/inference/structure_inference.py,sha256=
|
|
4
|
+
rcsb_embedding_model/cli/inference.py,sha256=XmnRwygWYQkPqeJi4I1H2jjo24IxXzt_EihdYZ7LLqA,18696
|
|
5
|
+
rcsb_embedding_model/dataset/esm_prot_from_chain.py,sha256=u6vu_2CN6vaYAk6kpvHAOgHuEHjXJl3fukMk-tDr_6E,3486
|
|
6
|
+
rcsb_embedding_model/dataset/esm_prot_from_structure.py,sha256=TeITPdi1uc3qLQ-Pgn807oH6eM0LYv-67RE50ZT4dLI,2551
|
|
7
|
+
rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py,sha256=worRiNqOJRjyr693TaillsS65bdTdGOoHfwyT9yE1O4,2866
|
|
8
|
+
rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py,sha256=JG4rrhziIUtdTmbuTbMbEYHrvlda4m5VWvdJXe_Sv3c,3449
|
|
9
|
+
rcsb_embedding_model/dataset/residue_embedding_from_structure.py,sha256=dxfUNcVmdl8LrtQf1UJQ4E79e7R9LRsL0fjsq2GJQRk,2796
|
|
10
|
+
rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py,sha256=ehHQuLI2TrE5l4_4n6p3e30i17O1pXW92KOCn7bGtcg,1274
|
|
11
|
+
rcsb_embedding_model/inference/assembly_inferece.py,sha256=8fPJjEXy1WsM5XB5U7KfdO5-Du6nEsawsaAjmWoXA9I,2329
|
|
12
|
+
rcsb_embedding_model/inference/chain_inference.py,sha256=zTV_glkoErSYjVy0xfDRNtT8bVS0NGBnaNSUqp-CnoY,2700
|
|
13
|
+
rcsb_embedding_model/inference/esm_inference.py,sha256=3ny9vvHDSI7jpybDfMVXour52qiZ_av-2SL6h2yygEI,2341
|
|
14
|
+
rcsb_embedding_model/inference/structure_inference.py,sha256=lqbDBPSea8IoNyQXl83OcfXgLq4hmbD1DNvAwjetiPc,2231
|
|
15
15
|
rcsb_embedding_model/model/layers.py,sha256=lhKaWC4gTS_T5lHOP0mgnnP8nKTPEOm4MrjhESA4hE8,743
|
|
16
16
|
rcsb_embedding_model/model/residue_embedding_aggregator.py,sha256=k3UW63Ax8DtjCMdD3O5xNxtyAu28l2n3-Ab6nS0atm0,1967
|
|
17
17
|
rcsb_embedding_model/modules/chain_module.py,sha256=sDSPXJmWuU2C3lt1NorlbUVWZvRSLzumPdFQk01h3VI,403
|
|
18
18
|
rcsb_embedding_model/modules/esm_module.py,sha256=CTHGOATXiarqZsBsZ8oxGJBj20A73186Slpr0EzMJsE,770
|
|
19
19
|
rcsb_embedding_model/modules/structure_module.py,sha256=dEtDNdWo1j2sSDa0JiOHQfEfQzIWqSLEKpvOX0GrXZ4,1048
|
|
20
|
-
rcsb_embedding_model/types/api_types.py,sha256=
|
|
20
|
+
rcsb_embedding_model/types/api_types.py,sha256=SCwALwvEb0KRKaoWKbuN7JyfOH-1whsI0Z4ki41dht8,1235
|
|
21
21
|
rcsb_embedding_model/utils/data.py,sha256=ODz6GG6IAhgAlLh3tcIP6-JVHX8Bb_-E745Lvc_oR84,2934
|
|
22
22
|
rcsb_embedding_model/utils/model.py,sha256=rpZa-gfm3cEtbBd7UXMHrZv3x6f0AC8TJT3gtrSxr5I,852
|
|
23
23
|
rcsb_embedding_model/utils/structure_parser.py,sha256=IWMQ8brlEMe6_ND-DBESOli8vlqHxladTssjbM9RSKw,2751
|
|
24
24
|
rcsb_embedding_model/utils/structure_provider.py,sha256=eWtxjkPpmRfmil_DKR1J6miaXR3lQ28DF5O0qrqSgGA,786
|
|
25
|
-
rcsb_embedding_model/writer/batch_writer.py,sha256=
|
|
26
|
-
rcsb_embedding_model-0.0.
|
|
27
|
-
rcsb_embedding_model-0.0.
|
|
28
|
-
rcsb_embedding_model-0.0.
|
|
29
|
-
rcsb_embedding_model-0.0.
|
|
30
|
-
rcsb_embedding_model-0.0.
|
|
25
|
+
rcsb_embedding_model/writer/batch_writer.py,sha256=rTFNasB0Xp4-XCNTXKeEWZxSrb7lvZytoRldJUWn9Jg,3312
|
|
26
|
+
rcsb_embedding_model-0.0.17.dist-info/METADATA,sha256=ZIV-WqJGsSmZd79Ks455CTxSgN2J2JSD7vqr3XPx3nE,5368
|
|
27
|
+
rcsb_embedding_model-0.0.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
28
|
+
rcsb_embedding_model-0.0.17.dist-info/entry_points.txt,sha256=MK11jTIEmaV-x4CkPX5IymDaVs7Ky_f2xxU8BJVZ_9Q,69
|
|
29
|
+
rcsb_embedding_model-0.0.17.dist-info/licenses/LICENSE.md,sha256=oUaHiKgfBkChth_Sm67WemEvatO1U0Go8LHjaskXY0w,1522
|
|
30
|
+
rcsb_embedding_model-0.0.17.dist-info/RECORD,,
|
|
File without changes
|
{rcsb_embedding_model-0.0.15.dist-info → rcsb_embedding_model-0.0.17.dist-info}/entry_points.txt
RENAMED
|
File without changes
|
{rcsb_embedding_model-0.0.15.dist-info → rcsb_embedding_model-0.0.17.dist-info}/licenses/LICENSE.md
RENAMED
|
File without changes
|