rcsb-embedding-model 0.0.20__tar.gz → 0.0.22__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rcsb-embedding-model might be problematic. Click here for more details.

Files changed (46)
  1. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/PKG-INFO +3 -3
  2. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/README.md +2 -2
  3. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/pyproject.toml +1 -1
  4. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/cli/inference.py +47 -22
  5. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/inference/chain_inference.py +6 -1
  6. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/inference/esm_inference.py +6 -1
  7. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/inference/structure_inference.py +8 -1
  8. rcsb_embedding_model-0.0.22/src/rcsb_embedding_model/modules/chain_module.py +19 -0
  9. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/modules/esm_module.py +7 -4
  10. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/modules/structure_module.py +9 -4
  11. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/utils/model.py +2 -0
  12. rcsb_embedding_model-0.0.20/src/rcsb_embedding_model/modules/chain_module.py +0 -16
  13. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/.github/workflows/_workflow-docker.yaml +0 -0
  14. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/.github/workflows/publish.yaml +0 -0
  15. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/.gitignore +0 -0
  16. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/Dockerfile +0 -0
  17. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/LICENSE.md +0 -0
  18. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/assets/embedding-model-architecture.png +0 -0
  19. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/examples/esm_embeddings.py +0 -0
  20. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/__init__.py +0 -0
  21. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/cli/args_utils.py +0 -0
  22. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/dataset/esm_prot_from_chain.py +0 -0
  23. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/dataset/esm_prot_from_structure.py +0 -0
  24. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/dataset/resdiue_assembly_embedding_from_structure.py +0 -0
  25. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/dataset/residue_assembly_embedding_from_tensor_file.py +0 -0
  26. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/dataset/residue_embedding_from_structure.py +0 -0
  27. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/dataset/residue_embedding_from_tensor_file.py +0 -0
  28. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/inference/assembly_inferece.py +0 -0
  29. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/model/layers.py +0 -0
  30. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/model/residue_embedding_aggregator.py +0 -0
  31. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/rcsb_structure_embedding.py +0 -0
  32. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/types/api_types.py +0 -0
  33. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/utils/data.py +0 -0
  34. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/utils/structure_parser.py +0 -0
  35. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/utils/structure_provider.py +0 -0
  36. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/src/rcsb_embedding_model/writer/batch_writer.py +0 -0
  37. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/tests/resources/embeddings/1acb.A.pt +0 -0
  38. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/tests/resources/embeddings/1acb.B.pt +0 -0
  39. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/tests/resources/embeddings/2uzi.A.pt +0 -0
  40. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/tests/resources/embeddings/2uzi.B.pt +0 -0
  41. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/tests/resources/embeddings/2uzi.C.pt +0 -0
  42. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/tests/resources/pdb/1acb.cif +0 -0
  43. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/tests/resources/pdb/2uzi.cif +0 -0
  44. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/tests/test_embedding_model.py +0 -0
  45. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/tests/test_inference.py +0 -0
  46. {rcsb_embedding_model-0.0.20 → rcsb_embedding_model-0.0.22}/tests/test_remote_inference.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rcsb-embedding-model
3
- Version: 0.0.20
3
+ Version: 0.0.22
4
4
  Summary: Protein Embedding Model for Structure Search
5
5
  Project-URL: Homepage, https://github.com/rcsb/rcsb-embedding-model
6
6
  Project-URL: Issues, https://github.com/rcsb/rcsb-embedding-model/issues
@@ -17,7 +17,7 @@ Description-Content-Type: text/markdown
17
17
 
18
18
  # RCSB Embedding Model
19
19
 
20
- **Version** 0.0.20
20
+ **Version** 0.0.22
21
21
 
22
22
 
23
23
  ## Overview
@@ -48,7 +48,7 @@ If you are interested in training the model with a new dataset, visit the [rcsb-
48
48
  **Requirements:**
49
49
 
50
50
  - Python ≥ 3.10
51
- - ESM == 3.1.1
51
+ - ESM >= 3.2.0
52
52
  - Lightning ≥ 2.5.0
53
53
  - Typer ≥ 0.15.0
54
54
 
@@ -1,6 +1,6 @@
1
1
  # RCSB Embedding Model
2
2
 
3
- **Version** 0.0.20
3
+ **Version** 0.0.22
4
4
 
5
5
 
6
6
  ## Overview
@@ -31,7 +31,7 @@ If you are interested in training the model with a new dataset, visit the [rcsb-
31
31
  **Requirements:**
32
32
 
33
33
  - Python ≥ 3.10
34
- - ESM == 3.1.1
34
+ - ESM >= 3.2.0
35
35
  - Lightning ≥ 2.5.0
36
36
  - Typer ≥ 0.15.0
37
37
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "rcsb-embedding-model"
3
- version = "0.0.20"
3
+ version = "0.0.22"
4
4
  authors = [
5
5
  { name="Joan Segura", email="joan.segura@rcsb.org" },
6
6
  ]
@@ -8,6 +8,9 @@ from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, S
8
8
  StructureLocation, SrcAssemblyFrom, SrcTensorFrom, OutFormat
9
9
  from rcsb_embedding_model.utils.data import adapt_csv_to_embedding_chain_stream
10
10
 
11
+ import os
12
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
13
+
11
14
  app = typer.Typer(
12
15
  add_completion=False
13
16
  )
@@ -322,26 +325,36 @@ def complete_embedding(
322
325
  resolve_path=True,
323
326
  help='CSV file 4 columns: Structure Name | Structure File Path or URL (switch structure-location) | Assembly Id | Output embedding name.'
324
327
  )],
325
- output_path: Annotated[typer.FileText, typer.Option(
328
+ output_res_path: Annotated[typer.FileText, typer.Option(
326
329
  exists=True,
327
330
  file_okay=False,
328
331
  dir_okay=True,
329
332
  resolve_path=True,
330
- help='Output path to store predictions. Embeddings are stored as a single DataFrame file (see output_name).'
333
+ help='Output path to store residue embeddings. Residue embeddings are stored in separated files'
331
334
  )],
332
- res_embedding_location: Annotated[typer.FileText, typer.Option(
335
+ output_chain_path: Annotated[typer.FileText, typer.Option(
336
+ exists=True,
337
+ file_okay=False,
338
+ dir_okay=True,
339
+ resolve_path=True,
340
+ help='Output path to store chain embeddings. Embeddings are stored as a single JSON file (see output_chain_name).'
341
+ )],
342
+ output_assembly_path: Annotated[typer.FileText, typer.Option(
333
343
  exists=True,
334
344
  file_okay=False,
335
345
  dir_okay=True,
336
346
  resolve_path=True,
337
- help='Output path to store ESM predictions.'
347
+ help='Output path to store assembly embeddings. Embeddings are stored as a single JSON file (see output_assembly_name).'
338
348
  )],
339
349
  output_format: Annotated[OutFormat, typer.Option(
340
350
  help='Format of the output. Options: separated (predictions are stored in single files) or grouped (predictions are stored in a single JSON file).'
341
351
  )] = OutFormat.separated,
342
- output_name: Annotated[str, typer.Option(
343
- help='File name for storing embeddings as a single JSON file. Used when output-format=grouped.'
344
- )] = 'inference',
352
+ output_chain_name: Annotated[str, typer.Option(
353
+ help='File name for storing chain embeddings as a single JSON file. Used when output-format=grouped.'
354
+ )] = 'chain-inference',
355
+ output_assembly_name: Annotated[str, typer.Option(
356
+ help='File name for storing chain embeddings as a single JSON file. Used when output-format=grouped.'
357
+ )] = 'chain-inference',
345
358
  structure_location: Annotated[StructureLocation, typer.Option(
346
359
  help='Structure file location.'
347
360
  )] = StructureLocation.local,
@@ -351,10 +364,22 @@ def complete_embedding(
351
364
  min_res_n: Annotated[int, typer.Option(
352
365
  help='When using all chains in a structure, consider only chains with more than <min_res_n> residues.'
353
366
  )] = 0,
354
- batch_size: Annotated[int, typer.Option(
367
+ batch_size_res: Annotated[int, typer.Option(
355
368
  help='Number of samples processed together in one iteration.'
356
369
  )] = 1,
357
- num_workers: Annotated[int, typer.Option(
370
+ num_workers_res: Annotated[int, typer.Option(
371
+ help='Number of subprocesses to use for data loading.'
372
+ )] = 0,
373
+ batch_size_chain: Annotated[int, typer.Option(
374
+ help='Number of samples processed together in one iteration.'
375
+ )] = 1,
376
+ num_workers_chain: Annotated[int, typer.Option(
377
+ help='Number of subprocesses to use for data loading.'
378
+ )] = 0,
379
+ batch_size_assembly: Annotated[int, typer.Option(
380
+ help='Number of samples processed together in one iteration.'
381
+ )] = 1,
382
+ num_workers_assembly: Annotated[int, typer.Option(
358
383
  help='Number of subprocesses to use for data loading.'
359
384
  )] = 0,
360
385
  num_nodes: Annotated[int, typer.Option(
@@ -369,43 +394,43 @@ def complete_embedding(
369
394
  ):
370
395
  residue_embedding(
371
396
  src_file=src_chain_file,
372
- output_path=res_embedding_location,
397
+ output_path=output_res_path,
373
398
  output_format=OutFormat.separated,
374
399
  structure_location=structure_location,
375
400
  structure_format=structure_format,
376
401
  min_res_n=min_res_n,
377
- batch_size=batch_size,
378
- num_workers=num_workers,
402
+ batch_size=batch_size_res,
403
+ num_workers=num_workers_res,
379
404
  num_nodes=num_nodes,
380
405
  accelerator=accelerator,
381
406
  devices=devices,
382
407
  )
383
408
  chain_embedding(
384
409
  src_file=src_chain_file,
385
- output_path=output_path,
410
+ output_path=output_chain_path,
386
411
  output_format=output_format,
387
- output_name=f"{output_name}-chain",
388
- res_embedding_location=res_embedding_location,
412
+ output_name=output_chain_name,
413
+ res_embedding_location=output_res_path,
389
414
  structure_location=structure_location,
390
415
  structure_format=structure_format,
391
416
  min_res_n=min_res_n,
392
- batch_size=batch_size,
393
- num_workers=num_workers,
417
+ batch_size=batch_size_chain,
418
+ num_workers=num_workers_chain,
394
419
  num_nodes=num_nodes,
395
420
  accelerator=accelerator,
396
421
  devices=devices
397
422
  )
398
423
  assembly_embedding(
399
424
  src_file=src_assembly_file,
400
- output_path=output_path,
425
+ output_path=output_assembly_path,
401
426
  output_format=output_format,
402
- output_name=f"{output_name}-assembly",
403
- res_embedding_location=res_embedding_location,
427
+ output_name=output_assembly_name,
428
+ res_embedding_location=output_res_path,
404
429
  structure_location=structure_location,
405
430
  structure_format=structure_format,
406
431
  min_res_n=min_res_n,
407
- batch_size=batch_size,
408
- num_workers=num_workers,
432
+ batch_size=batch_size_assembly,
433
+ num_workers=num_workers_assembly,
409
434
  num_nodes=num_nodes,
410
435
  accelerator=accelerator,
411
436
  devices=devices
@@ -7,6 +7,7 @@ from rcsb_embedding_model.modules.chain_module import ChainModule
7
7
  from rcsb_embedding_model.types.api_types import Accelerator, Devices, OptionalPath, FileOrStreamTuple, SrcLocation, \
8
8
  SrcTensorFrom, StructureLocation, StructureFormat, OutFormat
9
9
  from rcsb_embedding_model.utils.data import collate_seq_embeddings
10
+ from rcsb_embedding_model.utils.model import get_aggregator_model
10
11
  from rcsb_embedding_model.writer.batch_writer import CsvBatchWriter, JsonStorage
11
12
 
12
13
 
@@ -52,13 +53,17 @@ def predict(
52
53
  )
53
54
  )
54
55
 
55
- module = ChainModule()
56
+ aggregator_model = get_aggregator_model()
57
+ module = ChainModule(
58
+ model=aggregator_model
59
+ )
56
60
  inference_writer = (JsonStorage(out_path, out_name) if out_format == OutFormat.grouped else CsvBatchWriter(out_path)) if out_path is not None else None
57
61
  trainer = Trainer(
58
62
  callbacks=[inference_writer] if inference_writer is not None else None,
59
63
  num_nodes=num_nodes,
60
64
  accelerator=accelerator,
61
65
  devices=devices,
66
+ strategy="ddp",
62
67
  logger=False
63
68
  )
64
69
 
@@ -6,6 +6,7 @@ from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
6
6
  from rcsb_embedding_model.modules.esm_module import EsmModule
7
7
  from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, \
8
8
  SrcProteinFrom, FileOrStreamTuple, SrcLocation, OutFormat
9
+ from rcsb_embedding_model.utils.model import get_residue_model
9
10
  from rcsb_embedding_model.writer.batch_writer import TensorBatchWriter, JsonStorage
10
11
 
11
12
 
@@ -46,13 +47,17 @@ def predict(
46
47
  collate_fn=lambda _: _
47
48
  )
48
49
 
49
- module = EsmModule()
50
+ esm_model = get_residue_model()
51
+ module = EsmModule(
52
+ model=esm_model
53
+ )
50
54
  inference_writer = (JsonStorage(out_path, out_name) if out_format == OutFormat.grouped else TensorBatchWriter(out_path)) if out_path is not None else None
51
55
  trainer = Trainer(
52
56
  callbacks=[inference_writer] if inference_writer is not None else None,
53
57
  num_nodes=num_nodes,
54
58
  accelerator=accelerator,
55
59
  devices=devices,
60
+ strategy="ddp",
56
61
  logger=False
57
62
  )
58
63
 
@@ -6,6 +6,7 @@ from rcsb_embedding_model.dataset.esm_prot_from_chain import EsmProtFromChain
6
6
  from rcsb_embedding_model.modules.structure_module import StructureModule
7
7
  from rcsb_embedding_model.types.api_types import StructureFormat, Accelerator, Devices, OptionalPath, StructureLocation, \
8
8
  SrcProteinFrom, FileOrStreamTuple, SrcLocation
9
+ from rcsb_embedding_model.utils.model import get_residue_model, get_aggregator_model
9
10
  from rcsb_embedding_model.writer.batch_writer import JsonStorage
10
11
 
11
12
 
@@ -45,13 +46,19 @@ def predict(
45
46
  collate_fn=lambda _: _
46
47
  )
47
48
 
48
- module = StructureModule()
49
+ res_model = get_residue_model()
50
+ aggregator_model = get_aggregator_model()
51
+ module = StructureModule(
52
+ res_model=res_model,
53
+ aggregator_model=aggregator_model
54
+ )
49
55
  inference_writer = JsonStorage(out_path, out_name) if out_path is not None and out_name is not None else None
50
56
  trainer = Trainer(
51
57
  callbacks=[inference_writer] if inference_writer is not None else None,
52
58
  num_nodes=num_nodes,
53
59
  accelerator=accelerator,
54
60
  devices=devices,
61
+ strategy="ddp",
55
62
  logger=False
56
63
  )
57
64
 
@@ -0,0 +1,19 @@
1
+ import logging
2
+
3
+ from lightning import LightningModule
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ class ChainModule(LightningModule):
8
+
9
+ def __init__(
10
+ self,
11
+ model
12
+ ):
13
+ super().__init__()
14
+ logger.info(f"Using device: {self.device}")
15
+ self.aggregator = model
16
+
17
+ def predict_step(self, batch, batch_idx):
18
+ (x, x_mask), dom_id = batch
19
+ return self.aggregator(x, x_mask), dom_id
@@ -1,16 +1,19 @@
1
+ import logging
2
+
1
3
  from esm.sdk.api import SamplingConfig
2
4
  from lightning import LightningModule
3
5
 
4
- from rcsb_embedding_model.utils.model import get_residue_model
5
-
6
+ logger = logging.getLogger(__name__)
6
7
 
7
8
  class EsmModule(LightningModule):
8
9
 
9
10
  def __init__(
10
- self
11
+ self,
12
+ model
11
13
  ):
12
14
  super().__init__()
13
- self.esm3 = get_residue_model(self.device)
15
+ logger.info(f"Using device: {self.device}")
16
+ self.esm3 = model
14
17
 
15
18
  def predict_step(self, prot_batch, batch_idx):
16
19
  return tuple([self.__compute_embeddings(esm_prot) for esm_prot, name in prot_batch]), tuple([name for esm_prot, name in prot_batch])
@@ -1,18 +1,23 @@
1
+ import logging
2
+
1
3
  from esm.sdk.api import SamplingConfig
2
4
  from lightning import LightningModule
3
5
 
4
6
  from rcsb_embedding_model.utils.data import collate_seq_embeddings
5
- from rcsb_embedding_model.utils.model import get_residue_model, get_aggregator_model
6
7
 
8
+ logger = logging.getLogger(__name__)
7
9
 
8
10
  class StructureModule(LightningModule):
9
11
 
10
12
  def __init__(
11
- self
13
+ self,
14
+ res_model,
15
+ aggregator_model
12
16
  ):
13
17
  super().__init__()
14
- self.esm3 = get_residue_model(self.device)
15
- self.aggregator = get_aggregator_model(device=self.device)
18
+ logger.info(f"Using device: {self.device}")
19
+ self.esm3 = res_model
20
+ self.aggregator = aggregator_model
16
21
 
17
22
  def predict_step(self, prot_batch, batch_idx):
18
23
  prot_embeddings = []
@@ -16,6 +16,8 @@ def get_aggregator_model(device=None):
16
16
  filename=FILE_NAME,
17
17
  revision=REVISION
18
18
  )
19
+ if device is None:
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
21
  weights = torch.load(model_path, weights_only=True, map_location=device)
20
22
  aggregator_model = ResidueEmbeddingAggregator()
21
23
  aggregator_model.load_state_dict(weights)
@@ -1,16 +0,0 @@
1
- from lightning import LightningModule
2
-
3
- from rcsb_embedding_model.utils.model import get_aggregator_model
4
-
5
-
6
- class ChainModule(LightningModule):
7
-
8
- def __init__(
9
- self
10
- ):
11
- super().__init__()
12
- self.model = get_aggregator_model(device=self.device)
13
-
14
- def predict_step(self, batch, batch_idx):
15
- (x, x_mask), dom_id = batch
16
- return self.model(x, x_mask), dom_id