rcsb-embedding-model 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rcsb-embedding-model might be problematic. Click here for more details.

@@ -5,7 +5,7 @@ import typer
5
5
 
6
6
  from rcsb_embedding_model.cli.args_utils import arg_devices
7
7
  from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, SrcLocation, SrcProteinFrom, \
8
- StructureLocation, SrcAssemblyFrom, SrcTensorFrom
8
+ StructureLocation, SrcAssemblyFrom, SrcTensorFrom, OutFormat
9
9
 
10
10
  app = typer.Typer(
11
11
  add_completion=False
@@ -31,6 +31,12 @@ def residue_embedding(
31
31
  resolve_path=True,
32
32
  help='Output path to store predictions. Embeddings are stored as torch tensor files.'
33
33
  )],
34
+ output_format: Annotated[OutFormat, typer.Option(
35
+ help='Format of the output. Options: separated (predictions are stored in single files) or grouped (predictions are stored in a single JSON file).'
36
+ )] = OutFormat.separated,
37
+ output_name: Annotated[str, typer.Option(
38
+ help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
39
+ )] = 'inference',
34
40
  src_from: Annotated[SrcProteinFrom, typer.Option(
35
41
  help='Use specific chains or all chains in a structure.'
36
42
  )] = SrcProteinFrom.chain,
@@ -62,7 +68,7 @@ def residue_embedding(
62
68
  from rcsb_embedding_model.inference.esm_inference import predict
63
69
  predict(
64
70
  src_stream=src_file,
65
- src_location=SrcLocation.local,
71
+ src_location=SrcLocation.file,
66
72
  src_from=src_from,
67
73
  structure_location=structure_location,
68
74
  structure_format=structure_format,
@@ -72,6 +78,8 @@ def residue_embedding(
72
78
  num_nodes=num_nodes,
73
79
  accelerator=accelerator,
74
80
  devices=arg_devices(devices),
81
+ out_format=output_format,
82
+ out_name=output_name,
75
83
  out_path=output_path
76
84
  )
77
85
 
@@ -95,9 +103,9 @@ def structure_embedding(
95
103
  resolve_path=True,
96
104
  help='Output path to store predictions. Embeddings are stored as a single DataFrame file (see out-df-name).'
97
105
  )],
98
- out_df_name: Annotated[str, typer.Option(
99
- help='File name (without extension) for storing embeddings as a pandas DataFrame pickle (.pkl). The DataFrame contains 2 columns: Id | Embedding'
100
- )],
106
+ output_name: Annotated[str, typer.Option(
107
+ help='File name for storing embeddings as a single JSON file.'
108
+ )] = 'inference',
101
109
  src_from: Annotated[SrcProteinFrom, typer.Option(
102
110
  help='Use specific chains or all chains in a structure.'
103
111
  )] = SrcProteinFrom.chain,
@@ -129,7 +137,7 @@ def structure_embedding(
129
137
  from rcsb_embedding_model.inference.structure_inference import predict
130
138
  predict(
131
139
  src_stream=src_file,
132
- src_location=SrcLocation.local,
140
+ src_location=SrcLocation.file,
133
141
  src_from=src_from,
134
142
  structure_location=structure_location,
135
143
  structure_format=structure_format,
@@ -140,7 +148,7 @@ def structure_embedding(
140
148
  accelerator=accelerator,
141
149
  devices=arg_devices(devices),
142
150
  out_path=output_path,
143
- out_df_name=out_df_name
151
+ out_name=output_name
144
152
  )
145
153
 
146
154
 
@@ -163,6 +171,12 @@ def chain_embedding(
163
171
  resolve_path=True,
164
172
  help='Output path to store predictions. Embeddings are stored as csv files.'
165
173
  )],
174
+ output_format: Annotated[OutFormat, typer.Option(
175
+ help='Format of the output. Options: separated (predictions are stored in single files) or grouped (predictions are stored in a single JSON file).'
176
+ )] = OutFormat.separated,
177
+ output_name: Annotated[str, typer.Option(
178
+ help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
179
+ )] = 'inference',
166
180
  res_embedding_location: Annotated[typer.FileText, typer.Option(
167
181
  exists=True,
168
182
  file_okay=False,
@@ -202,7 +216,7 @@ def chain_embedding(
202
216
  predict(
203
217
  src_stream=src_file,
204
218
  res_embedding_location=res_embedding_location,
205
- src_location=SrcLocation.local,
219
+ src_location=SrcLocation.file,
206
220
  src_from=src_from,
207
221
  structure_location=structure_location,
208
222
  structure_format=structure_format,
@@ -212,7 +226,9 @@ def chain_embedding(
212
226
  num_nodes=num_nodes,
213
227
  accelerator=accelerator,
214
228
  devices=arg_devices(devices),
215
- out_path=output_path
229
+ out_path=output_path,
230
+ out_format=output_format,
231
+ out_name=output_name
216
232
  )
217
233
 
218
234
  @app.command(
@@ -241,6 +257,12 @@ def assembly_embedding(
241
257
  resolve_path=True,
242
258
  help='Output path to store predictions. Embeddings are stored as csv files.'
243
259
  )],
260
+ output_format: Annotated[OutFormat, typer.Option(
261
+ help='Format of the output. Options: separated (predictions are stored in single files) or grouped (predictions are stored in a single JSON file).'
262
+ )] = OutFormat.separated,
263
+ output_name: Annotated[str, typer.Option(
264
+ help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
265
+ )] = 'inference',
244
266
  src_from: Annotated[SrcAssemblyFrom, typer.Option(
245
267
  help='Use specific assembly or all assemblies in a structure.'
246
268
  )] = SrcAssemblyFrom.assembly,
@@ -276,7 +298,7 @@ def assembly_embedding(
276
298
  predict(
277
299
  src_stream=src_file,
278
300
  res_embedding_location=res_embedding_location,
279
- src_location=SrcLocation.local,
301
+ src_location=SrcLocation.file,
280
302
  src_from=src_from,
281
303
  structure_location=structure_location,
282
304
  structure_format=structure_format,
@@ -287,7 +309,113 @@ def assembly_embedding(
287
309
  num_nodes=num_nodes,
288
310
  accelerator=accelerator,
289
311
  devices=arg_devices(devices),
290
- out_path=output_path
312
+ out_path=output_path,
313
+ out_format=output_format,
314
+ out_name=output_name
315
+ )
316
+
317
+ @app.command(
318
+ name="complete-embedding",
319
+ help="Calculate chain and assembly embeddings from structural files. Predictions are stored as csv files."
320
+ )
321
+ def complete_embedding(
322
+ src_file: Annotated[typer.FileText, typer.Option(
323
+ exists=True,
324
+ file_okay=True,
325
+ dir_okay=False,
326
+ resolve_path=True,
327
+ help='CSV file 3 columns: Structure Name | Structure File Path or URL | Chain Id (asym_i for cif files. This field is required if src-from=chain) | Output Embedding Name.'
328
+ )],
329
+ output_path: Annotated[typer.FileText, typer.Option(
330
+ exists=True,
331
+ file_okay=False,
332
+ dir_okay=True,
333
+ resolve_path=True,
334
+ help='Output path to store predictions. Embeddings are stored as a single DataFrame file (see output_name).'
335
+ )],
336
+ esm_output_path: Annotated[typer.FileText, typer.Option(
337
+ exists=True,
338
+ file_okay=False,
339
+ dir_okay=True,
340
+ resolve_path=True,
341
+ help='Output path to store ESM predictions.'
342
+ )],
343
+ output_format: Annotated[OutFormat, typer.Option(
344
+ help='Format of the output. Options: separated (predictions are stored in single files) or grouped (predictions are stored in a single JSON file).'
345
+ )] = OutFormat.separated,
346
+ output_name: Annotated[str, typer.Option(
347
+ help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
348
+ )] = 'inference',
349
+ structure_location: Annotated[StructureLocation, typer.Option(
350
+ help='Structure file location.'
351
+ )] = StructureLocation.local,
352
+ structure_format: Annotated[StructureFormat, typer.Option(
353
+ help='Structure file format.'
354
+ )] = StructureFormat.mmcif,
355
+ min_res_n: Annotated[int, typer.Option(
356
+ help='When using all chains in a structure, consider only chains with more than <min_res_n> residues.'
357
+ )] = 0,
358
+ batch_size: Annotated[int, typer.Option(
359
+ help='Number of samples processed together in one iteration.'
360
+ )] = 1,
361
+ num_workers: Annotated[int, typer.Option(
362
+ help='Number of subprocesses to use for data loading.'
363
+ )] = 0,
364
+ num_nodes: Annotated[int, typer.Option(
365
+ help='Number of nodes to use for inference.'
366
+ )] = 1,
367
+ accelerator: Annotated[Accelerator, typer.Option(
368
+ help='Device used for inference.'
369
+ )] = Accelerator.auto,
370
+ devices: Annotated[List[str], typer.Option(
371
+ help='The devices to use. Can be set to a positive number or "auto". Repeat this argument to indicate multiple indices of devices. "auto" for automatic selection based on the chosen accelerator.'
372
+ )] = tuple(['auto'])
373
+ ):
374
+ residue_embedding(
375
+ src_file=src_file,
376
+ src_from=SrcProteinFrom.structure,
377
+ output_path=esm_output_path,
378
+ output_format=OutFormat.separated,
379
+ structure_location=structure_location,
380
+ structure_format=structure_format,
381
+ min_res_n=min_res_n,
382
+ batch_size=batch_size,
383
+ num_workers=num_workers,
384
+ num_nodes=num_nodes,
385
+ accelerator=accelerator,
386
+ devices=devices,
387
+ )
388
+ chain_embedding(
389
+ src_file=src_file,
390
+ src_from=SrcTensorFrom.structure,
391
+ output_path=output_path,
392
+ output_format=output_format,
393
+ output_name=f"{output_name}-chain",
394
+ res_embedding_location=esm_output_path,
395
+ structure_location=structure_location,
396
+ structure_format=structure_format,
397
+ min_res_n=min_res_n,
398
+ batch_size=batch_size,
399
+ num_workers=num_workers,
400
+ num_nodes=num_nodes,
401
+ accelerator=accelerator,
402
+ devices=devices
403
+ )
404
+ assembly_embedding(
405
+ src_file=src_file,
406
+ src_from=SrcAssemblyFrom.structure,
407
+ output_path=output_path,
408
+ output_format=output_format,
409
+ output_name=f"{output_name}-assembly",
410
+ res_embedding_location=esm_output_path,
411
+ structure_location=structure_location,
412
+ structure_format=structure_format,
413
+ min_res_n=min_res_n,
414
+ batch_size=batch_size,
415
+ num_workers=num_workers,
416
+ num_nodes=num_nodes,
417
+ accelerator=accelerator,
418
+ devices=devices
291
419
  )
292
420
 
293
421
 
@@ -27,7 +27,7 @@ class EsmProtFromChain(Dataset):
27
27
  def __init__(
28
28
  self,
29
29
  src_stream,
30
- src_location=SrcLocation.local,
30
+ src_location=SrcLocation.file,
31
31
  structure_location=StructureLocation.local,
32
32
  structure_format=StructureFormat.mmcif,
33
33
  structure_provider=StructureProvider()
@@ -19,7 +19,7 @@ class EsmProtFromStructure(EsmProtFromChain):
19
19
  def __init__(
20
20
  self,
21
21
  src_stream,
22
- src_location=SrcLocation.local,
22
+ src_location=SrcLocation.file,
23
23
  structure_location=StructureLocation.local,
24
24
  structure_format=StructureFormat.mmcif,
25
25
  min_res_n=0,
@@ -21,7 +21,7 @@ class ResidueAssemblyDatasetFromStructure(ResidueAssemblyEmbeddingFromTensorFile
21
21
  self,
22
22
  src_stream,
23
23
  res_embedding_location,
24
- src_location=SrcLocation.local,
24
+ src_location=SrcLocation.file,
25
25
  structure_location=StructureLocation.local,
26
26
  structure_format=StructureFormat.mmcif,
27
27
  min_res_n=0,
@@ -63,6 +63,6 @@ class ResidueAssemblyDatasetFromStructure(ResidueAssemblyEmbeddingFromTensorFile
63
63
  structure = stringio_from_url(src_structure) if self.structure_location == StructureLocation.remote else src_structure
64
64
  item_name = row[ResidueAssemblyDatasetFromStructure.ITEM_NAME_ATTR]
65
65
  for assembly_id in get_assemblies(structure=structure, structure_format=self.structure_format):
66
- assemblies.append((src_name, src_structure, str(assembly_id), f"{item_name}.{assembly_id}"))
66
+ assemblies.append((src_name, src_structure, str(assembly_id), f"{item_name}-{assembly_id}"))
67
67
 
68
68
  return tuple(assemblies)
@@ -22,7 +22,7 @@ class ResidueAssemblyEmbeddingFromTensorFile(Dataset):
22
22
  self,
23
23
  src_stream,
24
24
  res_embedding_location,
25
- src_location=SrcLocation.local,
25
+ src_location=SrcLocation.file,
26
26
  structure_location=StructureLocation.local,
27
27
  structure_format=StructureFormat.mmcif,
28
28
  min_res_n=0,
@@ -79,7 +79,7 @@ if __name__ == "__main__":
79
79
  dataset = ResidueAssemblyEmbeddingFromTensorFile(
80
80
  src_stream="/Users/joan/tmp/assembly-test.csv",
81
81
  res_embedding_location="/Users/joan/tmp",
82
- src_location=SrcLocation.local,
82
+ src_location=SrcLocation.file,
83
83
  structure_location=StructureLocation.local,
84
84
  structure_format=StructureFormat.mmcif
85
85
  )
@@ -21,7 +21,7 @@ class ResidueEmbeddingFromStructure(ResidueEmbeddingFromTensorFile):
21
21
  self,
22
22
  src_stream,
23
23
  res_embedding_location,
24
- src_location=SrcLocation.local,
24
+ src_location=SrcLocation.file,
25
25
  structure_location=StructureLocation.local,
26
26
  structure_format=StructureFormat.mmcif,
27
27
  min_res_n=0,
@@ -15,7 +15,7 @@ class ResidueEmbeddingFromTensorFile(Dataset):
15
15
  def __init__(
16
16
  self,
17
17
  src_stream,
18
- src_location=SrcLocation.local
18
+ src_location=SrcLocation.file
19
19
  ):
20
20
  super().__init__()
21
21
  self.src_location = src_location
@@ -2,14 +2,15 @@ import sys
2
2
 
3
3
  from rcsb_embedding_model.dataset.resdiue_assembly_embedding_from_structure import ResidueAssemblyDatasetFromStructure
4
4
  from rcsb_embedding_model.dataset.residue_assembly_embedding_from_tensor_file import ResidueAssemblyEmbeddingFromTensorFile
5
- from rcsb_embedding_model.types.api_types import FileOrStreamTuple, SrcLocation, Accelerator, Devices, OptionalPath, EmbeddingPath, StructureLocation, StructureFormat, SrcAssemblyFrom
5
+ from rcsb_embedding_model.types.api_types import FileOrStreamTuple, SrcLocation, Accelerator, Devices, OptionalPath, \
6
+ EmbeddingPath, StructureLocation, StructureFormat, SrcAssemblyFrom, OutFormat
6
7
  from rcsb_embedding_model.inference.chain_inference import predict as chain_predict
7
8
 
8
9
 
9
10
  def predict(
10
11
  src_stream: FileOrStreamTuple,
11
12
  res_embedding_location: EmbeddingPath,
12
- src_location: SrcLocation = SrcLocation.local,
13
+ src_location: SrcLocation = SrcLocation.file,
13
14
  src_from: SrcAssemblyFrom = SrcAssemblyFrom.assembly,
14
15
  structure_location: StructureLocation = StructureLocation.local,
15
16
  structure_format: StructureFormat = StructureFormat.mmcif,
@@ -20,6 +21,8 @@ def predict(
20
21
  num_nodes: int = 1,
21
22
  accelerator: Accelerator = Accelerator.auto,
22
23
  devices: Devices = 'auto',
24
+ out_format: OutFormat = OutFormat.separated,
25
+ out_name: str = 'inference',
23
26
  out_path: OptionalPath = None
24
27
  ):
25
28
  inference_set = ResidueAssemblyEmbeddingFromTensorFile(
@@ -48,6 +51,8 @@ def predict(
48
51
  num_nodes=num_nodes,
49
52
  accelerator=accelerator,
50
53
  devices=devices,
54
+ out_format=out_format,
55
+ out_name=out_name,
51
56
  out_path=out_path,
52
57
  inference_set=inference_set
53
58
  )
@@ -5,15 +5,15 @@ from rcsb_embedding_model.dataset.residue_embedding_from_structure import Residu
5
5
  from rcsb_embedding_model.dataset.residue_embedding_from_tensor_file import ResidueEmbeddingFromTensorFile
6
6
  from rcsb_embedding_model.modules.chain_module import ChainModule
7
7
  from rcsb_embedding_model.types.api_types import Accelerator, Devices, OptionalPath, FileOrStreamTuple, SrcLocation, \
8
- SrcTensorFrom, StructureLocation, StructureFormat
8
+ SrcTensorFrom, StructureLocation, StructureFormat, OutFormat
9
9
  from rcsb_embedding_model.utils.data import collate_seq_embeddings
10
- from rcsb_embedding_model.writer.batch_writer import CsvBatchWriter
10
+ from rcsb_embedding_model.writer.batch_writer import CsvBatchWriter, JsonStorage
11
11
 
12
12
 
13
13
  def predict(
14
14
  src_stream: FileOrStreamTuple,
15
15
  res_embedding_location: OptionalPath = None,
16
- src_location: SrcLocation = SrcLocation.local,
16
+ src_location: SrcLocation = SrcLocation.file,
17
17
  src_from: SrcTensorFrom = SrcTensorFrom.file,
18
18
  structure_location: StructureLocation = StructureLocation.local,
19
19
  structure_format: StructureFormat = StructureFormat.mmcif,
@@ -23,6 +23,8 @@ def predict(
23
23
  num_nodes: int = 1,
24
24
  accelerator: Accelerator = Accelerator.auto,
25
25
  devices: Devices = 'auto',
26
+ out_format: OutFormat = OutFormat.separated,
27
+ out_name: str = 'inference',
26
28
  out_path: OptionalPath = None,
27
29
  inference_set=None
28
30
  ):
@@ -51,13 +53,13 @@ def predict(
51
53
  )
52
54
 
53
55
  module = ChainModule()
54
-
55
- inference_writer = CsvBatchWriter(out_path) if out_path is not None else None
56
+ inference_writer = (JsonStorage(out_path, out_name) if out_format == OutFormat.grouped else CsvBatchWriter(out_path)) if out_path is not None else None
56
57
  trainer = Trainer(
57
58
  callbacks=[inference_writer] if inference_writer is not None else None,
58
59
  num_nodes=num_nodes,
59
60
  accelerator=accelerator,
60
- devices=devices
61
+ devices=devices,
62
+ logger=False
61
63
  )
62
64
 
63
65
  prediction = trainer.predict(
@@ -4,13 +4,14 @@ from lightning import Trainer
4
4
  from rcsb_embedding_model.dataset.esm_prot_from_structure import EsmProtFromStructure
5
5
  from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
6
6
  from rcsb_embedding_model.modules.esm_module import EsmModule
7
- from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, SrcProteinFrom, FileOrStreamTuple, SrcLocation
8
- from rcsb_embedding_model.writer.batch_writer import TensorBatchWriter
7
+ from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, \
8
+ SrcProteinFrom, FileOrStreamTuple, SrcLocation, OutFormat
9
+ from rcsb_embedding_model.writer.batch_writer import TensorBatchWriter, JsonStorage
9
10
 
10
11
 
11
12
  def predict(
12
13
  src_stream: FileOrStreamTuple,
13
- src_location: SrcLocation = SrcLocation.local,
14
+ src_location: SrcLocation = SrcLocation.file,
14
15
  src_from: SrcProteinFrom = SrcProteinFrom.chain,
15
16
  structure_location: StructureLocation = StructureLocation.local,
16
17
  structure_format: StructureFormat = StructureFormat.mmcif,
@@ -20,6 +21,8 @@ def predict(
20
21
  num_nodes: int = 1,
21
22
  accelerator: Accelerator = Accelerator.auto,
22
23
  devices: Devices = 'auto',
24
+ out_format: OutFormat = OutFormat.separated,
25
+ out_name: str = 'inference',
23
26
  out_path: OptionalPath = None
24
27
  ):
25
28
 
@@ -44,12 +47,13 @@ def predict(
44
47
  )
45
48
 
46
49
  module = EsmModule()
47
- inference_writer = TensorBatchWriter(out_path) if out_path is not None else None
50
+ inference_writer = (JsonStorage(out_path, out_name) if out_format == OutFormat.grouped else TensorBatchWriter(out_path)) if out_path is not None else None
48
51
  trainer = Trainer(
49
52
  callbacks=[inference_writer] if inference_writer is not None else None,
50
53
  num_nodes=num_nodes,
51
54
  accelerator=accelerator,
52
- devices=devices
55
+ devices=devices,
56
+ logger=False
53
57
  )
54
58
 
55
59
  prediction = trainer.predict(
@@ -4,13 +4,14 @@ from lightning import Trainer
4
4
  from rcsb_embedding_model.dataset.esm_prot_from_structure import EsmProtFromStructure
5
5
  from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
6
6
  from rcsb_embedding_model.modules.structure_module import StructureModule
7
- from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, SrcProteinFrom, FileOrStreamTuple, SrcLocation
8
- from rcsb_embedding_model.writer.batch_writer import DataFrameStorage
7
+ from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, \
8
+ SrcProteinFrom, FileOrStreamTuple, SrcLocation
9
+ from rcsb_embedding_model.writer.batch_writer import JsonStorage
9
10
 
10
11
 
11
12
  def predict(
12
13
  src_stream: FileOrStreamTuple,
13
- src_location: SrcLocation = SrcLocation.local,
14
+ src_location: SrcLocation = SrcLocation.file,
14
15
  src_from: SrcProteinFrom = SrcProteinFrom.chain,
15
16
  structure_location: StructureLocation = StructureLocation.local,
16
17
  structure_format: StructureFormat = StructureFormat.mmcif,
@@ -20,8 +21,8 @@ def predict(
20
21
  num_nodes: int = 1,
21
22
  accelerator: Accelerator = Accelerator.auto,
22
23
  devices: Devices = 'auto',
23
- out_path: OptionalPath = None,
24
- out_df_name: str = None
24
+ out_name: str = 'inference',
25
+ out_path: OptionalPath = None
25
26
  ):
26
27
 
27
28
  inference_set = EsmProtFromChain(
@@ -45,12 +46,13 @@ def predict(
45
46
  )
46
47
 
47
48
  module = StructureModule()
48
- inference_writer = DataFrameStorage(out_path, out_df_name) if out_path is not None and out_df_name is not None else None
49
+ inference_writer = JsonStorage(out_path, out_name) if out_path is not None and out_name is not None else None
49
50
  trainer = Trainer(
50
51
  callbacks=[inference_writer] if inference_writer is not None else None,
51
52
  num_nodes=num_nodes,
52
53
  accelerator=accelerator,
53
- devices=devices
54
+ devices=devices,
55
+ logger=False
54
56
  )
55
57
 
56
58
  prediction = trainer.predict(
@@ -32,7 +32,7 @@ class Accelerator(str, Enum):
32
32
 
33
33
 
34
34
  class SrcLocation(str, Enum):
35
- local = "local"
35
+ file = "file"
36
36
  stream = "stream"
37
37
 
38
38
 
@@ -54,3 +54,7 @@ class SrcAssemblyFrom(str, Enum):
54
54
  class SrcTensorFrom(str, Enum):
55
55
  file = "file"
56
56
  structure = "structure"
57
+
58
+ class OutFormat(str, Enum):
59
+ separated = "separated"
60
+ grouped = "grouped"
@@ -111,3 +111,21 @@ class DataFrameStorage(CoreBatchWriter, ABC):
111
111
  f"{self.out_path}/{self.df_id}.pkl.gz",
112
112
  compression='gzip'
113
113
  )
114
+
115
+
116
+ class JsonStorage(DataFrameStorage, ABC):
117
+ def __init__(
118
+ self,
119
+ output_path,
120
+ df_id,
121
+ postfix="pkl",
122
+ write_interval="batch"
123
+ ):
124
+ super().__init__(output_path, df_id, postfix, write_interval)
125
+
126
+ def on_predict_end(self, trainer, pl_module):
127
+ self.embedding.to_json(
128
+ f"{self.out_path}/{self.df_id}.json.gz",
129
+ orient='records',
130
+ compression='gzip'
131
+ )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rcsb-embedding-model
3
- Version: 0.0.15
3
+ Version: 0.0.17
4
4
  Summary: Protein Embedding Model for Structure Search
5
5
  Project-URL: Homepage, https://github.com/rcsb/rcsb-embedding-model
6
6
  Project-URL: Issues, https://github.com/rcsb/rcsb-embedding-model/issues
@@ -18,7 +18,7 @@ Description-Content-Type: text/markdown
18
18
 
19
19
  # RCSB Embedding Model
20
20
 
21
- **Version** 0.0.15
21
+ **Version** 0.0.17
22
22
 
23
23
 
24
24
  ## Overview
@@ -1,30 +1,30 @@
1
1
  rcsb_embedding_model/__init__.py,sha256=r3gLdeBIXkQEQA_K6QcRPO-TtYuAQSutk6pXRUE_nas,120
2
2
  rcsb_embedding_model/rcsb_structure_embedding.py,sha256=dKp9hXQO0JAnO4SEfjJ_mG_jHu3UxAPguv6jkOjp-BI,4487
3
3
  rcsb_embedding_model/cli/args_utils.py,sha256=7nP2q8pL5dWK_U7opxtWmoFcYVwasky6elHk-dASFaI,165
4
- rcsb_embedding_model/cli/inference.py,sha256=0DZHw4QeAi2f6xdfoEPzYb_gQhCWc_IPA1QgnckcUIg,12916
5
- rcsb_embedding_model/dataset/esm_prot_from_chain.py,sha256=dBD2N0Y-GoN6p3z2yLnOvv6JGn-skAxwgbOYhXKDngc,3487
6
- rcsb_embedding_model/dataset/esm_prot_from_structure.py,sha256=9IvurGr7PGjfAABoGoMlG08zn6mC6iVAjgExGSrDVdQ,2552
7
- rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py,sha256=10NUHnjTE5xSXPFVTfeuL8MpOhqk-f3ZIG7EbWR49B4,2867
8
- rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py,sha256=KXiohnPjjfZEFbPZQ46HGE8eEYWrVX8bfbTz4zPlo7o,3451
9
- rcsb_embedding_model/dataset/residue_embedding_from_structure.py,sha256=9MfgKvFAxYr9RU8kwvHnEZBH35gukx8hRPeoBXfyNXo,2797
10
- rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py,sha256=mDCqJrpnu2GXmp75zOPTH8ogL3GWDqc3iEH62JuyHVs,1275
11
- rcsb_embedding_model/inference/assembly_inferece.py,sha256=MPssN5bsOqOU-LGwa6AKX99cv5LD43Mnbaqhuuww1Tw,2165
12
- rcsb_embedding_model/inference/chain_inference.py,sha256=N92Wfu-UNkhmSlQ0153BA1idECj1NgEcl35Zis9Q2js,2492
13
- rcsb_embedding_model/inference/esm_inference.py,sha256=oVN4r9_6V8TS0pYoNn7GR92Xo0Zn7eBsnt_OfDSaH6g,2126
14
- rcsb_embedding_model/inference/structure_inference.py,sha256=QIUEo8eEc-kTSYKGdlX2rxT74huw4ZAw6U8Px9kYajE,2216
4
+ rcsb_embedding_model/cli/inference.py,sha256=XmnRwygWYQkPqeJi4I1H2jjo24IxXzt_EihdYZ7LLqA,18696
5
+ rcsb_embedding_model/dataset/esm_prot_from_chain.py,sha256=u6vu_2CN6vaYAk6kpvHAOgHuEHjXJl3fukMk-tDr_6E,3486
6
+ rcsb_embedding_model/dataset/esm_prot_from_structure.py,sha256=TeITPdi1uc3qLQ-Pgn807oH6eM0LYv-67RE50ZT4dLI,2551
7
+ rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py,sha256=worRiNqOJRjyr693TaillsS65bdTdGOoHfwyT9yE1O4,2866
8
+ rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py,sha256=JG4rrhziIUtdTmbuTbMbEYHrvlda4m5VWvdJXe_Sv3c,3449
9
+ rcsb_embedding_model/dataset/residue_embedding_from_structure.py,sha256=dxfUNcVmdl8LrtQf1UJQ4E79e7R9LRsL0fjsq2GJQRk,2796
10
+ rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py,sha256=ehHQuLI2TrE5l4_4n6p3e30i17O1pXW92KOCn7bGtcg,1274
11
+ rcsb_embedding_model/inference/assembly_inferece.py,sha256=8fPJjEXy1WsM5XB5U7KfdO5-Du6nEsawsaAjmWoXA9I,2329
12
+ rcsb_embedding_model/inference/chain_inference.py,sha256=zTV_glkoErSYjVy0xfDRNtT8bVS0NGBnaNSUqp-CnoY,2700
13
+ rcsb_embedding_model/inference/esm_inference.py,sha256=3ny9vvHDSI7jpybDfMVXour52qiZ_av-2SL6h2yygEI,2341
14
+ rcsb_embedding_model/inference/structure_inference.py,sha256=lqbDBPSea8IoNyQXl83OcfXgLq4hmbD1DNvAwjetiPc,2231
15
15
  rcsb_embedding_model/model/layers.py,sha256=lhKaWC4gTS_T5lHOP0mgnnP8nKTPEOm4MrjhESA4hE8,743
16
16
  rcsb_embedding_model/model/residue_embedding_aggregator.py,sha256=k3UW63Ax8DtjCMdD3O5xNxtyAu28l2n3-Ab6nS0atm0,1967
17
17
  rcsb_embedding_model/modules/chain_module.py,sha256=sDSPXJmWuU2C3lt1NorlbUVWZvRSLzumPdFQk01h3VI,403
18
18
  rcsb_embedding_model/modules/esm_module.py,sha256=CTHGOATXiarqZsBsZ8oxGJBj20A73186Slpr0EzMJsE,770
19
19
  rcsb_embedding_model/modules/structure_module.py,sha256=dEtDNdWo1j2sSDa0JiOHQfEfQzIWqSLEKpvOX0GrXZ4,1048
20
- rcsb_embedding_model/types/api_types.py,sha256=JSHd5Rq7dm6uWNzy1UZnLkWKxfjsKB7gRRTCSqS4r7c,1156
20
+ rcsb_embedding_model/types/api_types.py,sha256=SCwALwvEb0KRKaoWKbuN7JyfOH-1whsI0Z4ki41dht8,1235
21
21
  rcsb_embedding_model/utils/data.py,sha256=ODz6GG6IAhgAlLh3tcIP6-JVHX8Bb_-E745Lvc_oR84,2934
22
22
  rcsb_embedding_model/utils/model.py,sha256=rpZa-gfm3cEtbBd7UXMHrZv3x6f0AC8TJT3gtrSxr5I,852
23
23
  rcsb_embedding_model/utils/structure_parser.py,sha256=IWMQ8brlEMe6_ND-DBESOli8vlqHxladTssjbM9RSKw,2751
24
24
  rcsb_embedding_model/utils/structure_provider.py,sha256=eWtxjkPpmRfmil_DKR1J6miaXR3lQ28DF5O0qrqSgGA,786
25
- rcsb_embedding_model/writer/batch_writer.py,sha256=ekgzFZyoKpcnZ3IDP9hfOWBpuHxUQ31P35ViDAi-Edw,2843
26
- rcsb_embedding_model-0.0.15.dist-info/METADATA,sha256=jkv3kF6L-VO34k8C6K-PCsU1BxTPxHglvXdsrGjjrSU,5368
27
- rcsb_embedding_model-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
28
- rcsb_embedding_model-0.0.15.dist-info/entry_points.txt,sha256=MK11jTIEmaV-x4CkPX5IymDaVs7Ky_f2xxU8BJVZ_9Q,69
29
- rcsb_embedding_model-0.0.15.dist-info/licenses/LICENSE.md,sha256=oUaHiKgfBkChth_Sm67WemEvatO1U0Go8LHjaskXY0w,1522
30
- rcsb_embedding_model-0.0.15.dist-info/RECORD,,
25
+ rcsb_embedding_model/writer/batch_writer.py,sha256=rTFNasB0Xp4-XCNTXKeEWZxSrb7lvZytoRldJUWn9Jg,3312
26
+ rcsb_embedding_model-0.0.17.dist-info/METADATA,sha256=ZIV-WqJGsSmZd79Ks455CTxSgN2J2JSD7vqr3XPx3nE,5368
27
+ rcsb_embedding_model-0.0.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
28
+ rcsb_embedding_model-0.0.17.dist-info/entry_points.txt,sha256=MK11jTIEmaV-x4CkPX5IymDaVs7Ky_f2xxU8BJVZ_9Q,69
29
+ rcsb_embedding_model-0.0.17.dist-info/licenses/LICENSE.md,sha256=oUaHiKgfBkChth_Sm67WemEvatO1U0Go8LHjaskXY0w,1522
30
+ rcsb_embedding_model-0.0.17.dist-info/RECORD,,