rcsb-embedding-model 0.0.21__tar.gz → 0.0.22__tar.gz

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.

Potentially problematic release.


This version of rcsb-embedding-model has been flagged as potentially problematic. See the registry listing for this release for more details.

Files changed (46)
  1. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/PKG-INFO +2 -2
  2. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/README.md +1 -1
  3. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/pyproject.toml +1 -1
  4. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/cli/inference.py +44 -22
  5. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/inference/chain_inference.py +6 -1
  6. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/inference/esm_inference.py +6 -1
  7. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/inference/structure_inference.py +8 -1
  8. rcsb_embedding_model-0.0.22/src/rcsb_embedding_model/modules/chain_module.py +19 -0
  9. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/modules/esm_module.py +7 -4
  10. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/modules/structure_module.py +9 -4
  11. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/utils/model.py +2 -0
  12. rcsb_embedding_model-0.0.21/src/rcsb_embedding_model/modules/chain_module.py +0 -16
  13. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/.github/workflows/_workflow-docker.yaml +0 -0
  14. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/.github/workflows/publish.yaml +0 -0
  15. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/.gitignore +0 -0
  16. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/Dockerfile +0 -0
  17. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/LICENSE.md +0 -0
  18. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/assets/embedding-model-architecture.png +0 -0
  19. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/examples/esm_embeddings.py +0 -0
  20. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/__init__.py +0 -0
  21. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/cli/args_utils.py +0 -0
  22. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/dataset/esm_prot_from_chain.py +0 -0
  23. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/dataset/esm_prot_from_structure.py +0 -0
  24. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py +0 -0
  25. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py +0 -0
  26. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/dataset/residue_embedding_from_structure.py +0 -0
  27. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py +0 -0
  28. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/inference/assembly_inferece.py +0 -0
  29. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/model/layers.py +0 -0
  30. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/model/residue_embedding_aggregator.py +0 -0
  31. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/rcsb_structure_embedding.py +0 -0
  32. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/types/api_types.py +0 -0
  33. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/utils/data.py +0 -0
  34. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/utils/structure_parser.py +0 -0
  35. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/utils/structure_provider.py +0 -0
  36. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/writer/batch_writer.py +0 -0
  37. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/tests/resources/embeddings/1acb.A.pt +0 -0
  38. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/tests/resources/embeddings/1acb.B.pt +0 -0
  39. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/tests/resources/embeddings/2uzi.A.pt +0 -0
  40. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/tests/resources/embeddings/2uzi.B.pt +0 -0
  41. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/tests/resources/embeddings/2uzi.C.pt +0 -0
  42. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/tests/resources/pdb/1acb.cif +0 -0
  43. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/tests/resources/pdb/2uzi.cif +0 -0
  44. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/tests/test_embedding_model.py +0 -0
  45. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/tests/test_inference.py +0 -0
  46. {rcsb_embedding_model-0.0.21 → rcsb_embedding_model-0.0.22}/tests/test_remote_inference.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rcsb-embedding-model
3
- Version: 0.0.21
3
+ Version: 0.0.22
4
4
  Summary: Protein Embedding Model for Structure Search
5
5
  Project-URL: Homepage, https://github.com/rcsb/rcsb-embedding-model
6
6
  Project-URL: Issues, https://github.com/rcsb/rcsb-embedding-model/issues
@@ -17,7 +17,7 @@ Description-Content-Type: text/markdown
17
17
 
18
18
  # RCSB Embedding Model
19
19
 
20
- **Version** 0.0.21
20
+ **Version** 0.0.22
21
21
 
22
22
 
23
23
  ## Overview
@@ -1,6 +1,6 @@
1
1
  # RCSB Embedding Model
2
2
 
3
- **Version** 0.0.21
3
+ **Version** 0.0.22
4
4
 
5
5
 
6
6
  ## Overview
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "rcsb-embedding-model"
3
- version = "0.0.21"
3
+ version = "0.0.22"
4
4
  authors = [
5
5
  { name="Joan Segura", email="joan.segura@rcsb.org" },
6
6
  ]
@@ -325,26 +325,36 @@ def complete_embedding(
325
325
  resolve_path=True,
326
326
  help='CSV file 4 columns: Structure Name | Structure File Path or URL (switch structure-location) | Assembly Id | Output embedding name.'
327
327
  )],
328
- output_path: Annotated[typer.FileText, typer.Option(
328
+ output_res_path: Annotated[typer.FileText, typer.Option(
329
329
  exists=True,
330
330
  file_okay=False,
331
331
  dir_okay=True,
332
332
  resolve_path=True,
333
- help='Output path to store predictions. Embeddings are stored as a single DataFrame file (see output_name).'
333
+ help='Output path to store residue embeddings. Residue embeddings are stored in separated files'
334
334
  )],
335
- res_embedding_location: Annotated[typer.FileText, typer.Option(
335
+ output_chain_path: Annotated[typer.FileText, typer.Option(
336
+ exists=True,
337
+ file_okay=False,
338
+ dir_okay=True,
339
+ resolve_path=True,
340
+ help='Output path to store chain embeddings. Embeddings are stored as a single JSON file (see output_chain_name).'
341
+ )],
342
+ output_assembly_path: Annotated[typer.FileText, typer.Option(
336
343
  exists=True,
337
344
  file_okay=False,
338
345
  dir_okay=True,
339
346
  resolve_path=True,
340
- help='Output path to store ESM predictions.'
347
+ help='Output path to store assembly embeddings. Embeddings are stored as a single JSON file (see output_assembly_name).'
341
348
  )],
342
349
  output_format: Annotated[OutFormat, typer.Option(
343
350
  help='Format of the output. Options: separated (predictions are stored in single files) or grouped (predictions are stored in a single JSON file).'
344
351
  )] = OutFormat.separated,
345
- output_name: Annotated[str, typer.Option(
346
- help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
347
- )] = 'inference',
352
+ output_chain_name: Annotated[str, typer.Option(
353
+ help='File name for storing chain embeddings as a single JSON file. Used when output-format=grouped.'
354
+ )] = 'chain-inference',
355
+ output_assembly_name: Annotated[str, typer.Option(
356
+ help='File name for storing chain embeddings as a single JSON file. Used when output-format=grouped.'
357
+ )] = 'chain-inference',
348
358
  structure_location: Annotated[StructureLocation, typer.Option(
349
359
  help='Structure file location.'
350
360
  )] = StructureLocation.local,
@@ -354,10 +364,22 @@ def complete_embedding(
354
364
  min_res_n: Annotated[int, typer.Option(
355
365
  help='When using all chains in a structure, consider only chains with more than <min_res_n> residues.'
356
366
  )] = 0,
357
- batch_size: Annotated[int, typer.Option(
367
+ batch_size_res: Annotated[int, typer.Option(
358
368
  help='Number of samples processed together in one iteration.'
359
369
  )] = 1,
360
- num_workers: Annotated[int, typer.Option(
370
+ num_workers_res: Annotated[int, typer.Option(
371
+ help='Number of subprocesses to use for data loading.'
372
+ )] = 0,
373
+ batch_size_chain: Annotated[int, typer.Option(
374
+ help='Number of samples processed together in one iteration.'
375
+ )] = 1,
376
+ num_workers_chain: Annotated[int, typer.Option(
377
+ help='Number of subprocesses to use for data loading.'
378
+ )] = 0,
379
+ batch_size_assembly: Annotated[int, typer.Option(
380
+ help='Number of samples processed together in one iteration.'
381
+ )] = 1,
382
+ num_workers_assembly: Annotated[int, typer.Option(
361
383
  help='Number of subprocesses to use for data loading.'
362
384
  )] = 0,
363
385
  num_nodes: Annotated[int, typer.Option(
@@ -372,43 +394,43 @@ def complete_embedding(
372
394
  ):
373
395
  residue_embedding(
374
396
  src_file=src_chain_file,
375
- output_path=res_embedding_location,
397
+ output_path=output_res_path,
376
398
  output_format=OutFormat.separated,
377
399
  structure_location=structure_location,
378
400
  structure_format=structure_format,
379
401
  min_res_n=min_res_n,
380
- batch_size=batch_size,
381
- num_workers=num_workers,
402
+ batch_size=batch_size_res,
403
+ num_workers=num_workers_res,
382
404
  num_nodes=num_nodes,
383
405
  accelerator=accelerator,
384
406
  devices=devices,
385
407
  )
386
408
  chain_embedding(
387
409
  src_file=src_chain_file,
388
- output_path=output_path,
410
+ output_path=output_chain_path,
389
411
  output_format=output_format,
390
- output_name=f"{output_name}-chain",
391
- res_embedding_location=res_embedding_location,
412
+ output_name=output_chain_name,
413
+ res_embedding_location=output_res_path,
392
414
  structure_location=structure_location,
393
415
  structure_format=structure_format,
394
416
  min_res_n=min_res_n,
395
- batch_size=batch_size,
396
- num_workers=num_workers,
417
+ batch_size=batch_size_chain,
418
+ num_workers=num_workers_chain,
397
419
  num_nodes=num_nodes,
398
420
  accelerator=accelerator,
399
421
  devices=devices
400
422
  )
401
423
  assembly_embedding(
402
424
  src_file=src_assembly_file,
403
- output_path=output_path,
425
+ output_path=output_assembly_path,
404
426
  output_format=output_format,
405
- output_name=f"{output_name}-assembly",
406
- res_embedding_location=res_embedding_location,
427
+ output_name=output_assembly_name,
428
+ res_embedding_location=output_res_path,
407
429
  structure_location=structure_location,
408
430
  structure_format=structure_format,
409
431
  min_res_n=min_res_n,
410
- batch_size=batch_size,
411
- num_workers=num_workers,
432
+ batch_size=batch_size_assembly,
433
+ num_workers=num_workers_assembly,
412
434
  num_nodes=num_nodes,
413
435
  accelerator=accelerator,
414
436
  devices=devices
@@ -7,6 +7,7 @@ from rcsb_embedding_model.modules.chain_module import ChainModule
7
7
  from rcsb_embedding_model.types.api_types import Accelerator, Devices, OptionalPath, FileOrStreamTuple, SrcLocation, \
8
8
  SrcTensorFrom, StructureLocation, StructureFormat, OutFormat
9
9
  from rcsb_embedding_model.utils.data import collate_seq_embeddings
10
+ from rcsb_embedding_model.utils.model import get_aggregator_model
10
11
  from rcsb_embedding_model.writer.batch_writer import CsvBatchWriter, JsonStorage
11
12
 
12
13
 
@@ -52,13 +53,17 @@ def predict(
52
53
  )
53
54
  )
54
55
 
55
- module = ChainModule()
56
+ aggregator_model = get_aggregator_model()
57
+ module = ChainModule(
58
+ model=aggregator_model
59
+ )
56
60
  inference_writer = (JsonStorage(out_path, out_name) if out_format == OutFormat.grouped else CsvBatchWriter(out_path)) if out_path is not None else None
57
61
  trainer = Trainer(
58
62
  callbacks=[inference_writer] if inference_writer is not None else None,
59
63
  num_nodes=num_nodes,
60
64
  accelerator=accelerator,
61
65
  devices=devices,
66
+ strategy="ddp",
62
67
  logger=False
63
68
  )
64
69
 
@@ -6,6 +6,7 @@ from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
6
6
  from rcsb_embedding_model.modules.esm_module import EsmModule
7
7
  from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, \
8
8
  SrcProteinFrom, FileOrStreamTuple, SrcLocation, OutFormat
9
+ from rcsb_embedding_model.utils.model import get_residue_model
9
10
  from rcsb_embedding_model.writer.batch_writer import TensorBatchWriter, JsonStorage
10
11
 
11
12
 
@@ -46,13 +47,17 @@ def predict(
46
47
  collate_fn=lambda _: _
47
48
  )
48
49
 
49
- module = EsmModule()
50
+ esm_model = get_residue_model()
51
+ module = EsmModule(
52
+ model=esm_model
53
+ )
50
54
  inference_writer = (JsonStorage(out_path, out_name) if out_format == OutFormat.grouped else TensorBatchWriter(out_path)) if out_path is not None else None
51
55
  trainer = Trainer(
52
56
  callbacks=[inference_writer] if inference_writer is not None else None,
53
57
  num_nodes=num_nodes,
54
58
  accelerator=accelerator,
55
59
  devices=devices,
60
+ strategy="ddp",
56
61
  logger=False
57
62
  )
58
63
 
@@ -6,6 +6,7 @@ from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
6
6
  from rcsb_embedding_model.modules.structure_module import StructureModule
7
7
  from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, \
8
8
  SrcProteinFrom, FileOrStreamTuple, SrcLocation
9
+ from rcsb_embedding_model.utils.model import get_residue_model, get_aggregator_model
9
10
  from rcsb_embedding_model.writer.batch_writer import JsonStorage
10
11
 
11
12
 
@@ -45,13 +46,19 @@ def predict(
45
46
  collate_fn=lambda _: _
46
47
  )
47
48
 
48
- module = StructureModule()
49
+ res_model = get_residue_model()
50
+ aggregator_model = get_aggregator_model()
51
+ module = StructureModule(
52
+ res_model=res_model,
53
+ aggregator_model=aggregator_model
54
+ )
49
55
  inference_writer = JsonStorage(out_path, out_name) if out_path is not None and out_name is not None else None
50
56
  trainer = Trainer(
51
57
  callbacks=[inference_writer] if inference_writer is not None else None,
52
58
  num_nodes=num_nodes,
53
59
  accelerator=accelerator,
54
60
  devices=devices,
61
+ strategy="ddp",
55
62
  logger=False
56
63
  )
57
64
 
@@ -0,0 +1,19 @@
1
+ import logging
2
+
3
+ from lightning import LightningModule
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ class ChainModule(LightningModule):
8
+
9
+ def __init__(
10
+ self,
11
+ model
12
+ ):
13
+ super().__init__()
14
+ logger.info(f"Using device: {self.device}")
15
+ self.aggregator = model
16
+
17
+ def predict_step(self, batch, batch_idx):
18
+ (x, x_mask), dom_id = batch
19
+ return self.aggregator(x, x_mask), dom_id
@@ -1,16 +1,19 @@
1
+ import logging
2
+
1
3
  from esm.sdk.api import SamplingConfig
2
4
  from lightning import LightningModule
3
5
 
4
- from rcsb_embedding_model.utils.model import get_residue_model
5
-
6
+ logger = logging.getLogger(__name__)
6
7
 
7
8
  class EsmModule(LightningModule):
8
9
 
9
10
  def __init__(
10
- self
11
+ self,
12
+ model
11
13
  ):
12
14
  super().__init__()
13
- self.esm3 = get_residue_model(self.device)
15
+ logger.info(f"Using device: {self.device}")
16
+ self.esm3 = model
14
17
 
15
18
  def predict_step(self, prot_batch, batch_idx):
16
19
  return tuple([self.__compute_embeddings(esm_prot) for esm_prot, name in prot_batch]), tuple([name for esm_prot, name in prot_batch])
@@ -1,18 +1,23 @@
1
+ import logging
2
+
1
3
  from esm.sdk.api import SamplingConfig
2
4
  from lightning import LightningModule
3
5
 
4
6
  from rcsb_embedding_model.utils.data import collate_seq_embeddings
5
- from rcsb_embedding_model.utils.model import get_residue_model, get_aggregator_model
6
7
 
8
+ logger = logging.getLogger(__name__)
7
9
 
8
10
  class StructureModule(LightningModule):
9
11
 
10
12
  def __init__(
11
- self
13
+ self,
14
+ res_model,
15
+ aggregator_model
12
16
  ):
13
17
  super().__init__()
14
- self.esm3 = get_residue_model(self.device)
15
- self.aggregator = get_aggregator_model(device=self.device)
18
+ logger.info(f"Using device: {self.device}")
19
+ self.esm3 = res_model
20
+ self.aggregator = aggregator_model
16
21
 
17
22
  def predict_step(self, prot_batch, batch_idx):
18
23
  prot_embeddings = []
@@ -16,6 +16,8 @@ def get_aggregator_model(device=None):
16
16
  filename=FILE_NAME,
17
17
  revision=REVISION
18
18
  )
19
+ if device is None:
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
21
  weights = torch.load(model_path, weights_only=True, map_location=device)
20
22
  aggregator_model = ResidueEmbeddingAggregator()
21
23
  aggregator_model.load_state_dict(weights)
@@ -1,16 +0,0 @@
1
- from lightning import LightningModule
2
-
3
- from rcsb_embedding_model.utils.model import get_aggregator_model
4
-
5
-
6
- class ChainModule(LightningModule):
7
-
8
- def __init__(
9
- self
10
- ):
11
- super().__init__()
12
- self.model = get_aggregator_model(device=self.device)
13
-
14
- def predict_step(self, batch, batch_idx):
15
- (x, x_mask), dom_id = batch
16
- return self.model(x, x_mask), dom_id