dayhoff-tools 1.1.38__py3-none-any.whl → 1.1.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -400,425 +400,3 @@ class BoltzPredictor(Processor):
              f"Boltz prediction completed successfully. Output in {expected_output_dir}"
          )
          return expected_output_dir
-
-
- class MMSeqsProfileProcessor(Processor):
-     """Processor for running MMseqs2 profile searches.
-
-     This class wraps the MMseqs2 workflow to perform a profile-based search
-     against a target database using a query FASTA.
-     """
-
-     def __init__(
-         self,
-         query_fasta_path_in_image: str,
-         num_threads: int = 8,
-         mmseqs_args: dict | None = None,
-     ):
-         """Initialize the MMSeqsProfileProcessor.
-
-         Args:
-             query_fasta_path_in_image: Path to the query FASTA file. This path is expected
-                 to be accessible within the execution environment (e.g.,
-                 packaged in a Docker image).
-             num_threads: Number of threads to use for MMseqs2 commands.
-             mmseqs_args: A dictionary of additional MMseqs2 parameters.
-                 Expected keys: "memory_limit_gb", "evalue", "sensitivity",
-                 "max_seqs_search", "min_seq_id_cluster", "max_seqs_profile_msa".
-                 Defaults are used if not provided.
-         """
-         if not Path(query_fasta_path_in_image).is_file():
-             raise FileNotFoundError(
-                 f"Query FASTA file not found at: {query_fasta_path_in_image}"
-             )
-         self.query_fasta_path = query_fasta_path_in_image
-         self.num_threads = str(num_threads)  # MMseqs2 expects string for threads
-
-         default_mmseqs_args = {
-             "memory_limit_gb": "25",
-             "evalue": "10",
-             "sensitivity": "7.5",
-             "max_seqs_search": "300",
-             "min_seq_id_cluster": "0.8",
-             "max_seqs_profile_msa": "1000",
-         }
-         if mmseqs_args:
-             self.mmseqs_args = {**default_mmseqs_args, **mmseqs_args}
-         else:
-             self.mmseqs_args = default_mmseqs_args
-
-         # Log dayhoff-tools version
-         from dayhoff_tools import __version__
-
-         logger.info(f"dayhoff-tools version: {__version__}")
-         logger.info(
-             f"MMSeqsProfileProcessor initialized with query: {self.query_fasta_path}"
-         )
-         logger.info(f"MMSeqs args: {self.mmseqs_args}")
-         logger.info(f"Num threads: {self.num_threads}")
-
-     def _run_mmseqs_command(
-         self, command_parts: list[str], step_description: str, work_dir: Path
-     ):
-         """Runs an MMseqs2 command and logs its execution.
-
-         Args:
-             command_parts: A list of strings representing the command and its arguments.
-             step_description: A human-readable description of the MMseqs2 step.
-             work_dir: The working directory for the command.
-
-         Raises:
-             subprocess.CalledProcessError: If the MMseqs2 command returns a non-zero exit code.
-         """
-         full_command = " ".join(command_parts)
-         logger.info(f"Running MMseqs2 step in {work_dir}: {step_description}")
-         logger.info(f"Command: {full_command}")
-         try:
-             process = subprocess.run(
-                 command_parts,
-                 check=True,
-                 stdout=subprocess.PIPE,
-                 stderr=subprocess.PIPE,
-                 text=True,
-                 cwd=work_dir,  # Run command in the specified working directory
-             )
-             if process.stdout:
-                 logger.info(f"MMseqs2 stdout: {process.stdout.strip()}")
-             if process.stderr:  # MMseqs often outputs informational messages to stderr
-                 logger.info(f"MMseqs2 stderr: {process.stderr.strip()}")
-             logger.info(f"MMseqs2 step '{step_description}' completed successfully.")
-         except subprocess.CalledProcessError as e:
-             logger.error(f"MMseqs2 step '{step_description}' failed in {work_dir}.")
-             if e.stdout:
-                 logger.error(f"MMseqs2 stdout: {e.stdout.strip()}")
-             if e.stderr:
-                 logger.error(f"MMseqs2 stderr: {e.stderr.strip()}")
-             raise
-
-     def run(self, input_file: str) -> str:
-         """Run an MMseqs2 profile search.
-
-         The input_file is the target FASTA; the query FASTA is provided
-         during initialization. The method creates an output directory named
-         after the target stem, which contains the result files named with
-         that stem: {target_stem}.csv (tabular hits) and {target_stem}.txt
-         (unique hit IDs).
-
-         Args:
-             input_file: Path to the input target FASTA file.
-
-         Returns:
-             Path to the output directory (e.g., {target_stem}) containing
-             the result files.
-
-         Raises:
-             subprocess.CalledProcessError: If any MMseqs2 command fails.
-             FileNotFoundError: If the input_file is not found.
-         """
-         if not Path(input_file).is_file():
-             raise FileNotFoundError(f"Input target FASTA file not found: {input_file}")
-
-         input_file_path = Path(input_file).resolve()  # Ensure absolute path
-         target_fasta_filename = input_file_path.name
-         target_fasta_stem = input_file_path.stem  # Get stem for naming
-
-         # Create a unique base directory for this run's outputs and temp files
-         # This directory will be returned and subsequently uploaded by the Operator
-         run_base_dir_name = f"{target_fasta_stem}"  # Use stem as the dir name
-         run_base_dir = Path(run_base_dir_name).resolve()
-         run_base_dir.mkdir(parents=True, exist_ok=True)
-         logger.info(f"Created run base directory: {run_base_dir}")
-
-         # Define local paths within the run_base_dir
-         local_target_file = run_base_dir / target_fasta_filename
-         # Copy the target file into the run directory to keep inputs and outputs together
-         shutil.copy(input_file_path, local_target_file)
-         logger.info(f"Copied target file {input_file_path} to {local_target_file}")
-
-         # Query file is already specified by self.query_fasta_path (path in image)
-         local_query_file = Path(self.query_fasta_path).resolve()
-
-         # Temporary directory for MMseqs2 intermediate files, created inside run_base_dir
-         mmseqs_temp_dir = run_base_dir / "mmseqs_tmp"
-         mmseqs_temp_dir.mkdir(parents=True, exist_ok=True)
-         logger.info(f"Created MMseqs2 temporary directory: {mmseqs_temp_dir}")
-
-         # Define INTERMEDIATE output file paths within mmseqs_temp_dir
-         intermediate_results_m8_file = mmseqs_temp_dir / "results.m8"
-         intermediate_results_as_csv_file = mmseqs_temp_dir / "results_as.csv"
-
-         # Define FINAL output file paths within run_base_dir, using target stem
-         final_results_csv_file = run_base_dir / f"{target_fasta_stem}.csv"
-         final_hits_txt_file = run_base_dir / f"{target_fasta_stem}.txt"
-
-         # --- MMseqs2 Workflow Paths (intermediate files in mmseqs_temp_dir) ---
-         query_db = mmseqs_temp_dir / "queryDB"
-         target_db = mmseqs_temp_dir / "targetDB"
-         # Ensure local_target_file is used for creating targetDB
-         target_db_input_file = local_target_file
-
-         query_db_cluster = mmseqs_temp_dir / "queryDB_cluster"
-         query_db_rep = mmseqs_temp_dir / "queryDB_rep"
-         aln_db = mmseqs_temp_dir / "alnDB"
-         profile_db = mmseqs_temp_dir / "profileDB"
-         result_db = mmseqs_temp_dir / "resultDB"
-
-         try:
-             # 1. Create query database
-             self._run_mmseqs_command(
-                 ["mmseqs", "createdb", str(local_query_file), str(query_db)],
-                 "Create query DB",
-                 run_base_dir,  # Working directory for the command
-             )
-
-             # 2. Create target database
-             self._run_mmseqs_command(
-                 ["mmseqs", "createdb", str(target_db_input_file), str(target_db)],
-                 "Create target DB",
-                 run_base_dir,
-             )
-
-             # 3. Cluster query sequences
-             self._run_mmseqs_command(
-                 [
-                     "mmseqs",
-                     "cluster",
-                     str(query_db),
-                     str(query_db_cluster),
-                     str(
-                         mmseqs_temp_dir / "tmp_cluster"
-                     ),  # MMseqs needs a temp dir for cluster
-                     "--min-seq-id",
-                     self.mmseqs_args["min_seq_id_cluster"],
-                     "--threads",
-                     self.num_threads,
-                 ],
-                 "Cluster query sequences",
-                 run_base_dir,
-             )
-
-             # 4. Create representative set from query clusters
-             self._run_mmseqs_command(
-                 [
-                     "mmseqs",
-                     "createsubdb",
-                     str(query_db_cluster),
-                     str(query_db),
-                     str(query_db_rep),
-                 ],
-                 "Create representative query set",
-                 run_base_dir,
-             )
-
-             # 5. Create MSA for profile generation
-             self._run_mmseqs_command(
-                 [
-                     "mmseqs",
-                     "search",
-                     str(query_db_rep),
-                     str(query_db),  # Search representative against full query DB
-                     str(aln_db),
-                     str(mmseqs_temp_dir / "tmp_search_msa"),  # Temp for this search
-                     "--max-seqs",
-                     self.mmseqs_args["max_seqs_profile_msa"],
-                     "--threads",
-                     self.num_threads,
-                 ],
-                 "Create MSA for profile",
-                 run_base_dir,
-             )
-
-             # 6. Create profile database
-             self._run_mmseqs_command(
-                 [
-                     "mmseqs",
-                     "result2profile",
-                     str(query_db_rep),  # Use query_db_rep as input for profile
-                     str(query_db),  # Full query DB as second arg
-                     str(aln_db),
-                     str(profile_db),
-                     "--threads",  # Added threads option
-                     self.num_threads,
-                 ],
-                 "Create profile DB",
-                 run_base_dir,
-             )
-
-             # 7. Perform profile search
-             self._run_mmseqs_command(
-                 [
-                     "mmseqs",
-                     "search",
-                     str(profile_db),
-                     str(target_db),
-                     str(result_db),
-                     str(mmseqs_temp_dir / "tmp_search_profile"),  # Temp for this search
-                     "--split-memory-limit",
-                     f"{self.mmseqs_args['memory_limit_gb']}G",
-                     "-e",
-                     self.mmseqs_args["evalue"],
-                     "--max-seqs",
-                     self.mmseqs_args["max_seqs_search"],
-                     "--threads",
-                     self.num_threads,
-                     "-s",
-                     self.mmseqs_args["sensitivity"],
-                 ],
-                 "Perform profile search",
-                 run_base_dir,
-             )
-
-             # 8. Convert results to tabular format (M8) -> to intermediate file
-             self._run_mmseqs_command(
-                 [
-                     "mmseqs",
-                     "convertalis",
-                     str(profile_db),  # Query DB used for search (profileDB)
-                     str(target_db),
-                     str(result_db),
-                     str(intermediate_results_m8_file),  # Output M8 file to temp dir
-                     "--threads",
-                     self.num_threads,
-                 ],
-                 "Convert results to M8",
-                 run_base_dir,
-             )
-
-             # 8.5 Convert M8 to CSV with headers
-             logger.info(
-                 f"Converting M8 results to CSV: {intermediate_results_m8_file} -> {intermediate_results_as_csv_file}"
-             )
-             csv_headers = [
-                 "query_id",
-                 "target_id",
-                 "percent_identity",
-                 "alignment_length",
-                 "mismatches",
-                 "gap_openings",
-                 "query_start",
-                 "query_end",
-                 "target_start",
-                 "target_end",
-                 "e_value",
-                 "bit_score",
-             ]
-             try:
-                 if not intermediate_results_m8_file.exists():
-                     logger.warning(
-                         f"M8 results file {intermediate_results_m8_file} not found. CSV will be empty."
-                     )
-                     # Create an empty CSV with headers if M8 is missing
-                     with open(
-                         intermediate_results_as_csv_file, "w", newline=""
-                     ) as csvfile:
-                         writer = csv.writer(csvfile)
-                         writer.writerow(csv_headers)
-                 else:
-                     with (
-                         open(intermediate_results_m8_file, "r") as m8file,
-                         open(
-                             intermediate_results_as_csv_file, "w", newline=""
-                         ) as csvfile,
-                     ):
-                         writer = csv.writer(csvfile)
-                         writer.writerow(csv_headers)
-                         for line in m8file:
-                             writer.writerow(line.strip().split("\t"))
-             except Exception as e:
-                 logger.error(f"Error converting M8 to CSV: {e}", exc_info=True)
-                 # Ensure an empty csv is created on error to prevent downstream issues
-                 if not intermediate_results_as_csv_file.exists():
-                     with open(
-                         intermediate_results_as_csv_file, "w", newline=""
-                     ) as csvfile:
-                         writer = csv.writer(csvfile)
-                         writer.writerow(csv_headers)  # write headers even on error
-
-             # 9. Extract hit sequence IDs from M8 results for the TXT file
-             hit_sequence_ids = set()
-             logger.info(
-                 f"Extracting hit IDs from {intermediate_results_m8_file} for TXT output."
-             )
-             try:
-                 if intermediate_results_m8_file.exists():
-                     with open(intermediate_results_m8_file, "r") as m8_file:
-                         for line in m8_file:
-                             if line.strip():  # Check if line is not empty
-                                 columns = line.strip().split("\t")
-                                 if len(columns) >= 2:
-                                     hit_sequence_ids.add(
-                                         columns[1]
-                                     )  # Add target_accession
-                     logger.info(
-                         f"Found {len(hit_sequence_ids)} unique hit IDs in M8 file."
-                     )
-                 else:
-                     logger.warning(
-                         f"Intermediate M8 file {intermediate_results_m8_file} not found. Hit TXT file will be empty."
-                     )
-             except Exception as e:
-                 logger.error(
-                     f"Error reading M8 file {intermediate_results_m8_file} for hit ID extraction: {e}",
-                     exc_info=True,
-                 )
-                 # Proceed even if M8 reading fails, TXT will be empty
-
-             # 10. Write the set of hit sequence IDs to the final .txt file
-             logger.info(
-                 f"Writing {len(hit_sequence_ids)} hit sequence IDs to {final_hits_txt_file}"
-             )
-             try:
-                 with open(final_hits_txt_file, "w") as txt_out:
-                     # Sort IDs for consistent output
-                     for seq_id in sorted(list(hit_sequence_ids)):
-                         txt_out.write(f"{seq_id}\n")
-                 logger.info(f"Successfully wrote hit IDs to {final_hits_txt_file}")
-             except Exception as e:
-                 logger.error(
-                     f"Failed to write hit IDs to {final_hits_txt_file}: {e}",
-                     exc_info=True,
-                 )
-                 # Ensure the file exists even if writing fails
-                 if not final_hits_txt_file.exists():
-                     final_hits_txt_file.touch()
-
-             logger.info(
-                 f"PROCESSOR: MMseqs2 workflow and CSV/TXT generation completed successfully. Intermediate outputs in {mmseqs_temp_dir}"
-             )
-
-             # Move and rename final output files from mmseqs_temp_dir to run_base_dir
-             if intermediate_results_as_csv_file.exists():
-                 shutil.move(
-                     str(intermediate_results_as_csv_file), str(final_results_csv_file)
-                 )
-                 logger.info(
-                     f"Moved and renamed M8 results to CSV: {final_results_csv_file}"
-                 )
-             else:
-                 logger.warning(
-                     f"Intermediate CSV file {intermediate_results_as_csv_file} not found. Creating empty target CSV file."
-                 )
-                 final_results_csv_file.touch()  # Create empty file in run_base_dir if not found
-
-             logger.info(
-                 f"MMSeqsProfileProcessor run completed for {input_file}. Output CSV: {final_results_csv_file}"
-             )
-
-         except Exception as e:
-             logger.error(
-                 f"MMSeqsProfileProcessor failed for {input_file}: {e}", exc_info=True
-             )
-             raise
-         finally:
-             # --- Cleanup --- #
-             logger.info(f"Cleaning up temporary directory: {mmseqs_temp_dir}")
-             if mmseqs_temp_dir.exists():
-                 shutil.rmtree(mmseqs_temp_dir)
-             if local_target_file.exists() and local_target_file != Path(input_file):
-                 logger.info(
-                     f"Cleaning up local copy of target file: {local_target_file}"
-                 )
-                 local_target_file.unlink()
-             logger.info("MMSeqsProfileProcessor cleanup finished.")
-
-         return str(run_base_dir)  # Return the path to the directory containing outputs
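
The hunk above removes MMSeqsProfileProcessor from dayhoff_tools/deployment/processors.py in 1.1.40. For reference, a minimal usage sketch of the removed class, reconstructed from the code shown above (the file paths here are hypothetical, and the import no longer resolves in 1.1.40):

    from dayhoff_tools.deployment.processors import MMSeqsProfileProcessor

    processor = MMSeqsProfileProcessor(
        query_fasta_path_in_image="/app/queries.fasta",  # hypothetical path baked into the image
        num_threads=8,
        mmseqs_args={"evalue": "1e-5"},  # overrides are merged over the defaults
    )
    output_dir = processor.run("targets.fasta")
    # output_dir is "targets/", containing targets.csv (tabular hits) and targets.txt (unique hit IDs)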
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: dayhoff-tools
- Version: 1.1.38
+ Version: 1.1.40
  Summary: Common tools for all the repos at Dayhoff Labs
  Author: Daniel Martin-Alarcon
  Author-email: dma@dayhofflabs.com
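
A quick way to confirm which of these versions is installed locally (a sketch using only the standard library):

    from importlib.metadata import metadata, version

    print(version("dayhoff-tools"))              # e.g. "1.1.40"
    print(metadata("dayhoff-tools")["Summary"])  # "Common tools for all the repos at Dayhoff Labs"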
@@ -11,9 +11,8 @@ dayhoff_tools/deployment/deploy_aws.py,sha256=jQyQ0fbm2793jEHFO84lr5tNqiOpdBg6U0
  dayhoff_tools/deployment/deploy_gcp.py,sha256=xgaOVsUDmP6wSEMYNkm1yRNcVskfdz80qJtCulkBIAM,8860
  dayhoff_tools/deployment/deploy_utils.py,sha256=StFwbqnr2_FWiKVg3xnJF4kagTHzndqqDkpaIOaAn_4,26027
  dayhoff_tools/deployment/job_runner.py,sha256=hljvFpH2Bw96uYyUup5Ths72PZRL_X27KxlYzBMgguo,5086
- dayhoff_tools/deployment/processors.py,sha256=A7zvF47TfCkuLTCvaqZmk1M9ZgZcv6CAoXZCV6rEXuE,34599
+ dayhoff_tools/deployment/processors.py,sha256=f4L52ekx_zYirl8C4WfavxtOioyD-c34TdTJVDoLpWs,16572
  dayhoff_tools/deployment/swarm.py,sha256=MGcS2_x4RNFtnVjWlU_SwNfhICz8NlGYr9cYBK4ZKDA,21688
- dayhoff_tools/embedders.py,sha256=svP_ksm3FdyVZ8i8R9R5uoGu2qI_hVQ_eztG0drXkN8,36477
  dayhoff_tools/fasta.py,sha256=_kA2Cpiy7JAGbBqLrjElkzbcUD_p-nO2d5Aj1LVmOvc,50509
  dayhoff_tools/file_ops.py,sha256=JlGowvr-CUJFidV-4g_JmhUTN9bsYuaxtqKmnKomm-Q,8506
  dayhoff_tools/h5.py,sha256=j1nxxaiHsMidVX_XwB33P1Pz9d7K8ZKiDZwJWQUUQSY,21158
@@ -26,7 +25,7 @@ dayhoff_tools/intake/uniprot.py,sha256=BZYJQF63OtPcBBnQ7_P9gulxzJtqyorgyuDiPeOJq
  dayhoff_tools/logs.py,sha256=DKdeP0k0kliRcilwvX0mUB2eipO5BdWUeHwh-VnsICs,838
  dayhoff_tools/sqlite.py,sha256=jV55ikF8VpTfeQqqlHSbY8OgfyfHj8zgHNpZjBLos_E,18672
  dayhoff_tools/warehouse.py,sha256=TqV8nex1AluNaL4JuXH5zuu9P7qmE89lSo6f_oViy6U,14965
- dayhoff_tools-1.1.38.dist-info/METADATA,sha256=nDSK0SHTOMdieTxWDLScNArXB4g5TLAocONnt4xD89k,2843
- dayhoff_tools-1.1.38.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- dayhoff_tools-1.1.38.dist-info/entry_points.txt,sha256=iAf4jteNqW3cJm6CO6czLxjW3vxYKsyGLZ8WGmxamSc,49
- dayhoff_tools-1.1.38.dist-info/RECORD,,
+ dayhoff_tools-1.1.40.dist-info/METADATA,sha256=BVFDhkYLynnht5-XLaT8k7HjOk-L2eBQVLljVvb7pbc,2843
+ dayhoff_tools-1.1.40.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ dayhoff_tools-1.1.40.dist-info/entry_points.txt,sha256=iAf4jteNqW3cJm6CO6czLxjW3vxYKsyGLZ8WGmxamSc,49
+ dayhoff_tools-1.1.40.dist-info/RECORD,,
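
The RECORD changes above show dayhoff_tools/embedders.py dropped from the wheel and the processors.py entry shrinking from 34599 to 16572 bytes. Each RECORD row is path,algorithm=digest,size, where the digest is an unpadded urlsafe-base64 hash of the file contents. A minimal sketch (a hypothetical helper, not part of dayhoff-tools) for verifying these entries with the standard library:

    import base64
    import csv
    import hashlib
    import zipfile

    def verify_record(wheel_path: str) -> None:
        with zipfile.ZipFile(wheel_path) as wheel:
            record_name = next(n for n in wheel.namelist() if n.endswith(".dist-info/RECORD"))
            for path, digest, _size in csv.reader(wheel.read(record_name).decode().splitlines()):
                if not digest:  # the RECORD file itself carries no hash
                    continue
                algorithm, _, expected = digest.partition("=")
                actual = (
                    base64.urlsafe_b64encode(hashlib.new(algorithm, wheel.read(path)).digest())
                    .rstrip(b"=")
                    .decode()
                )
                assert actual == expected, f"hash mismatch for {path}"

    verify_record("dayhoff_tools-1.1.40-py3-none-any.whl")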
@@ -1,892 +0,0 @@
- import logging
- import os
- import time
- from typing import Dict, List, Literal, Optional, Tuple, cast
-
- import h5py
- import numpy as np
- import pandas as pd
- import torch
- import torch.utils.data
- from dayhoff_tools.deployment.processors import Processor
- from dayhoff_tools.fasta import (
-     clean_noncanonical_fasta,
-     clean_noncanonical_fasta_to_dict,
- )
- from esm import FastaBatchedDataset, pretrained
- from transformers import T5EncoderModel, T5Tokenizer
-
- logger = logging.getLogger(__name__)
-
-
- class ESMEmbedder(Processor):
-     """A processor that calculates ESM embeddings for a file of protein sequences.
-     Embeddings come from the last layer, and can be per-protein or per-residue."""
-
-     def __init__(
-         self,
-         # PyTorch model file OR name of pretrained model to download
-         # https://github.com/facebookresearch/esm?tab=readme-ov-file#available
-         model_name: Literal[
-             "esm2_t33_650M_UR50D",  # ESM2 version of the size used in CLEAN
-             "esm2_t6_8M_UR50D",  # Smallest
-             "esm1b_t33_650M_UR50S",  # Same as CLEAN
-             "esm2_t36_3B_UR50D",  # 2nd largest
-             "esm2_t48_15B_UR50D",  # Largest
-         ],
-         # Whether to return per-protein or per-residue embeddings.
-         embedding_level: Literal["protein", "residue"],
-         # Maximum batch size
-         toks_per_batch: int = 4096,
-         # Truncate sequences longer than the given value
-         truncation_seq_length: int = 1022,
-     ):
-         super().__init__()
-         self.model_name = model_name
-         self.toks_per_batch = toks_per_batch
-         self.embedding_level = embedding_level
-         self.truncation_seq_length = truncation_seq_length
-
-         # Instance variables set by other methods below:
-         # self.model, self.alphabet, self.len_batches, self.data_loader, self.dataset_base_name
-
-     def _load_model(self):
-         """Download pre-trained model and load onto device"""
-         self.model, self.alphabet = pretrained.load_model_and_alphabet(self.model_name)
-         self.model.eval()
-         if torch.cuda.is_available():
-             self.model.cuda()
-             logger.info("Transferred model to GPU.")
-         else:
-             logger.info("GPU not available. Running model on CPU.")
-
-     def _load_dataset(self, fasta_file: str) -> None:
-         """Load FASTA file into batched dataset and dataloader"""
-         if not fasta_file.endswith(".fasta"):
-             raise ValueError("Input file must have .fasta extension.")
-
-         self.dataset_base_name = fasta_file.replace(".fasta", "")
-         clean_fasta_file = fasta_file.replace(".fasta", "_clean.fasta")
-         clean_noncanonical_fasta(input_path=fasta_file, output_path=clean_fasta_file)
-
-         dataset = FastaBatchedDataset.from_file(clean_fasta_file)
-         logger.info("Read %s and loaded %s sequences.", fasta_file, len(dataset))
-
-         batches = dataset.get_batch_indices(self.toks_per_batch, extra_toks_per_seq=1)
-         self.len_batches = len(batches)
-         self.data_loader = torch.utils.data.DataLoader(
-             dataset,  # type: ignore
-             collate_fn=self.alphabet.get_batch_converter(self.truncation_seq_length),
-             batch_sampler=batches,
-         )
-         os.remove(clean_fasta_file)
-
-     def embed_fasta(self) -> str:
-         """Calculate embeddings from the FASTA file and return the path to the .h5 file of results.
-
-         Writes the H5 file with one dataset per protein (ID plus embedding vector,
-         where the vector is 2D if it was calculated per-residue).
-         """
-         output_path = self.dataset_base_name + ".h5"
-         with h5py.File(output_path, "w") as h5_file, torch.no_grad():
-             start_time = time.time()
-             logger.info(
-                 f"Calculating per-{self.embedding_level} embeddings. This dataset contains {self.len_batches} batches."
-             )
-             total_batches = self.len_batches
-
-             for batch_idx, (labels, sequences, toks) in enumerate(
-                 self.data_loader, start=1
-             ):
-                 if batch_idx % 10 == 0:
-                     elapsed_time = time.time() - start_time
-                     time_left = elapsed_time * (total_batches - batch_idx) / batch_idx
-                     logger.info(
-                         f"{self.dataset_base_name} | Batch {batch_idx}/{total_batches} | Elapsed time {elapsed_time / 60:.0f} min | Time left {time_left / 60:.0f} min, or {time_left / 3_600:.2f} hours"
-                     )
-
-                 if torch.cuda.is_available():
-                     toks = toks.to(device="cuda", non_blocking=True)
-                 out = self.model(
-                     toks,
-                     repr_layers=[
-                         cast(int, self.model.num_layers)
-                     ],  # Get the last layer
-                     return_contacts=False,
-                 )
-                 out["logits"].to(device="cpu")
-                 representations = {
-                     layer: t.to(device="cpu")
-                     for layer, t in out["representations"].items()
-                 }
-                 for label_idx, label_full in enumerate(labels):
-                     label = label_full.split()[
-                         0
-                     ]  # Shorten the label to only the first word
-                     truncate_len = min(
-                         self.truncation_seq_length, len(sequences[label_idx])
-                     )
-
-                     full_embeds = list(representations.items())[0][1]
-                     if self.embedding_level == "protein":
-                         protein_embeds = (
-                             full_embeds[label_idx, 1 : truncate_len + 1].mean(0).clone()
-                         )
-                         h5_file.create_dataset(label, data=protein_embeds)
-                     elif self.embedding_level == "residue":
-                         residue_embeds = full_embeds[
-                             label_idx, 1 : truncate_len + 1
-                         ].clone()
-                         h5_file.create_dataset(label, data=residue_embeds)
-                     else:
-                         raise NotImplementedError(
-                             f"Embedding level {self.embedding_level} not implemented."
-                         )
-
-         logger.info(f"Saved embeddings to {output_path}")
-         return output_path
-
-     def run(self, input_file: str, output_file: Optional[str] = None) -> str:
-         self._load_model()
-         self._load_dataset(input_file)
-         embedding_file = self.embed_fasta()
-
-         # If embedding per-protein, reformat the H5 file.
-         # By default, write to the same filename; otherwise to output_file.
-         reformatted_embedding_file = output_file or embedding_file
-         if self.embedding_level == "protein":
-             print("Reformatting H5 file...")
-             formatter = H5Reformatter()
-             formatter.run2(
-                 input_file=embedding_file,
-                 output_file=reformatted_embedding_file,
-             )
-
-         return reformatted_embedding_file
-
-
- class H5Reformatter(Processor):
-     """A processor that reformats per-protein T5 embeddings.
-
-     Old format (input): 1 dataset per protein, with the ID as key and the embedding as value.
-     New format (output): 2 datasets for the whole file, one of all protein IDs and one of all
-     the embeddings together."""
-
-     def __init__(self):
-         super().__init__()
-         # Set the device to GPU if available, otherwise CPU
-         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-         logger.info("Device is: %s", self.device)
-
-     def embedding_file_to_df(self, file_name: str) -> pd.DataFrame:
-         with h5py.File(file_name, "r") as f:
-             gene_names = list(f.keys())
-             Xg = [f[key][()] for key in gene_names]  # type:ignore
-             return pd.DataFrame(np.asmatrix(Xg), index=gene_names)  # type:ignore
-
-     def write_df_to_h5(self, df: pd.DataFrame, filename: str, description: str) -> None:
-         """
-         Write a DataFrame to an HDF5 file, separating row IDs and vectors.
-
-         Parameters:
-         - df: pandas DataFrame, where the index contains the IDs and
-           the columns contain the vector components.
-         - filename: String, the path to the output HDF5 file.
-         - description: String stored in the file's "description" attribute.
-         """
-         df.index = df.index.astype(
-             str
-         )  # Ensure the index is of a string type for the row IDs
-         vectors = df.values
-         ids = df.index.to_numpy(dtype=str)
-
-         with h5py.File(filename, "w") as h5f:
-             h5f.create_dataset("vectors", data=vectors)
-             dt = h5py.special_dtype(
-                 vlen=str
-             )  # Use variable-length strings to accommodate any ID size
-             h5f.create_dataset("ids", data=ids.astype("S"), dtype=dt)
-
-             # add the attributes
-             h5f.attrs["description"] = description
-             h5f.attrs["num_vecs"] = vectors.shape[0]
-             h5f.attrs["vec_dim"] = vectors.shape[1]
-
-     def run(self, input_file: str) -> str:
-         """Load an H5 file as a DataFrame, delete the file, and then export
-         the dataframe as an H5 file in the new format."""
-         df = self.embedding_file_to_df(input_file)
-         os.remove(input_file)
-         new_file_description = "Embeddings formatted for global ID and Vector tables."
-         self.write_df_to_h5(df, input_file, new_file_description)
-
-         return input_file
-
-     def run2(self, input_file: str, output_file: str):
-         """Load an H5 file as a DataFrame, then export it to output_file
-         in the new format (unlike run, the input file is not deleted)."""
-         df = self.embedding_file_to_df(input_file)
-         new_file_description = "Embeddings formatted for global ID and Vector tables."
-         self.write_df_to_h5(df, output_file, new_file_description)
-
-
- class Embedder(Processor):
-     """Base class for protein sequence embedders with optimized memory management.
-
-     This class provides the core functionality for embedding protein sequences using
-     transformer models, with built-in memory management and batch processing capabilities.
-     It handles sequences of different sizes appropriately, processing large sequences
-     individually and smaller sequences in batches for efficiency.
-
-     Memory Management Features:
-     - Periodic cleanup of GPU memory to prevent fragmentation
-     - Separate handling of large and small sequences to optimize memory usage
-     - Batch size limits based on total residues to prevent OOM errors
-     - Configurable cleanup frequency to balance performance and memory usage
-     - Empirically tested sequence length limits (5000-5500 residues depending on model)
-
-     Memory Fragmentation Prevention:
-     - Large sequences (>2500 residues) are processed individually to maintain contiguous memory blocks
-     - Small sequences are batched to efficiently use memory fragments
-     - Forced cleanup after processing large proteins
-     - Memory cleanup after every N sequences (configurable)
-     - Aggressive garbage collection settings for CUDA memory allocator
-
-     Memory Usage Patterns:
-     - Base memory: 2.4-4.8GB (model dependent)
-     - Peak memory: 12-15GB during large sequence processing
-     - Fragmentation ratio maintained above 0.92 for efficient memory use
-     - Maximum sequence length determined by model:
-         * T5: ~5000 residues
-         * ProstT5: ~5500 residues
-
-     Implementation Details:
-     - Uses PyTorch's CUDA memory allocator with optimized settings
-     - Configurable thresholds for large protein handling
-     - Automatic batch size adjustment based on sequence lengths
-     - Optional chunking for sequences exceeding maximum length
-     - Detailed memory statistics logging for monitoring
-
-     Note:
-         Memory limits are hardware-dependent. The above values are based on testing
-         with a 16GB GPU (such as NVIDIA T4). Adjust parameters based on available GPU memory.
-     """
-
-     def __init__(
-         self,
-         model,
-         tokenizer,
-         max_seq_length: int = 5000,
-         large_protein_threshold: int = 2500,
-         batch_residue_limit: int = 5000,
-         cleanup_frequency: int = 100,
-         skip_long_proteins: bool = False,
-     ):
-         """Initialize the Embedder with model and memory management parameters.
-
-         Args:
-             model: The transformer model to use for embeddings
-             tokenizer: The tokenizer matching the model
-             max_seq_length: Maximum sequence length before chunking or skipping.
-                 Also used as chunk size when processing long sequences.
-             large_protein_threshold: Sequences longer than this are processed individually
-             batch_residue_limit: Maximum total residues allowed in a batch
-             cleanup_frequency: Number of sequences to process before performing memory cleanup
-             skip_long_proteins: If True, skip proteins longer than max_seq_length.
-                 If False, process them in chunks of max_seq_length size.
-
-         Note:
-             The class automatically configures CUDA memory allocation if GPU is available.
-         """
-         # Configure memory allocator for CUDA
-         if torch.cuda.is_available():
-             os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
-                 "max_split_size_mb:128,"  # Smaller allocation chunks
-                 "garbage_collection_threshold:0.8"  # More aggressive GC
-             )
-
-         self.max_seq_length = max_seq_length
-         self.large_protein_threshold = large_protein_threshold
-         self.batch_residue_limit = batch_residue_limit
-         self.cleanup_frequency = cleanup_frequency
-         self.skip_long_proteins = skip_long_proteins
-         self.processed_count = 0
-
-         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-         self.model = model
-         self.tokenizer = tokenizer
-
-         self.model.to(self.device)
-         self.model.eval()
-
-     def get_embeddings(self, seqs: Dict[str, str]) -> Dict[str, np.ndarray]:
-         """Process sequences and generate embeddings with memory management.
-
-         This method handles the core embedding logic, including:
-         - Handling sequences that exceed maximum length (skip or chunk)
-         - Splitting sequences into large and small batches
-         - Periodic memory cleanup
-         - Batch processing for efficiency
-
-         Args:
-             seqs: Dictionary mapping sequence IDs to their amino acid sequences
-
-         Returns:
-             Dictionary mapping sequence IDs to their embedding vectors
-
-         Note:
-             Long sequences are either skipped or processed in chunks based on skip_long_proteins
-         """
-         results: Dict[str, np.ndarray] = {}
-         try:
-             # Initialize progress tracking
-             total_sequences = len(seqs)
-             processed_sequences = 0
-             start_time = time.time()
-
-             logger.info(f"Starting embedding of {total_sequences} sequences")
-
-             # Handle sequences based on length
-             long_seqs = {
-                 id: seq for id, seq in seqs.items() if len(seq) > self.max_seq_length
-             }
-             valid_seqs = {
-                 id: seq for id, seq in seqs.items() if len(seq) <= self.max_seq_length
-             }
-
-             if long_seqs:
-                 if self.skip_long_proteins:
-                     logger.warning(
-                         f"Skipping {len(long_seqs)} sequences exceeding max length {self.max_seq_length}: {', '.join(long_seqs.keys())}"
-                     )
-                 else:
-                     logger.info(
-                         f"Processing {len(long_seqs)} long sequences in chunks: {', '.join(long_seqs.keys())}"
-                     )
-                     for i, (seq_id, seq) in enumerate(long_seqs.items(), 1):
-                         logger.info(
-                             f"Embedding long sequence {i}/{len(long_seqs)}: {seq_id}"
-                         )
-                         results[seq_id] = self.embed_big_prot(seq_id, seq)
-                         self.cleanup_memory()
-
-                         # Update progress
-                         processed_sequences += 1
-                         elapsed_time = time.time() - start_time
-                         remaining_sequences = total_sequences - processed_sequences
-                         avg_time_per_seq = (
-                             elapsed_time / processed_sequences
-                             if processed_sequences > 0
-                             else 0
-                         )
-                         estimated_time_left = avg_time_per_seq * remaining_sequences
-
-                         logger.info(
-                             f"Progress: {processed_sequences}/{total_sequences} sequences ({processed_sequences/total_sequences*100:.1f}%) | "
-                             f"Elapsed: {elapsed_time/60:.1f} min | "
-                             f"Est. remaining: {estimated_time_left/60:.1f} min"
-                         )
-
-             # Split remaining sequences based on size
-             large_seqs = {
-                 id: seq
-                 for id, seq in valid_seqs.items()
-                 if len(seq) > self.large_protein_threshold
-             }
-             small_seqs = {
-                 id: seq
-                 for id, seq in valid_seqs.items()
-                 if len(seq) <= self.large_protein_threshold
-             }
-
-             logger.info(
-                 f"Split into {len(large_seqs)} large and {len(small_seqs)} small sequences"
-             )
-
-             # Process large sequences individually
-             for i, (seq_id, seq) in enumerate(large_seqs.items(), 1):
-                 logger.info(
-                     f"Processing large sequence {i}/{len(large_seqs)}: {seq_id}"
-                 )
-                 batch = [(seq_id, seq, len(seq))]
-                 results.update(self.embed_batch(batch))
-                 self.cleanup_memory()  # Cleanup after each large sequence
-
-                 # Update progress
-                 processed_sequences += 1
-                 elapsed_time = time.time() - start_time
-                 remaining_sequences = total_sequences - processed_sequences
-                 avg_time_per_seq = (
-                     elapsed_time / processed_sequences if processed_sequences > 0 else 0
-                 )
-                 estimated_time_left = avg_time_per_seq * remaining_sequences
-
-                 logger.info(
-                     f"Progress: {processed_sequences}/{total_sequences} sequences ({processed_sequences/total_sequences*100:.1f}%) | "
-                     f"Elapsed: {elapsed_time/60:.1f} min | "
-                     f"Est. remaining: {estimated_time_left/60:.1f} min"
-                 )
-
-             # Process small sequences in batches
-             current_batch: List[Tuple[str, str, int]] = []
-             current_size = 0
-             small_batch_count = 0
-             total_small_batches = (
-                 sum(len(seq) for seq in small_seqs.values())
-                 + self.batch_residue_limit
-                 - 1
-             ) // self.batch_residue_limit
-
-             # Sort sequences by length in descending order (reduces unnecessary padding --> speeds up embedding)
-             small_seqs_sorted = sorted(
-                 small_seqs.items(), key=lambda x: len(x[1]), reverse=True
-             )
-
-             for seq_id, seq in small_seqs_sorted:
-                 seq_len = len(seq)
-
-                 if current_size + seq_len > self.batch_residue_limit:
-                     if current_batch:
-                         small_batch_count += 1
-                         logger.info(
-                             f"Processing small batch {small_batch_count}/{total_small_batches} with {len(current_batch)} sequences"
-                         )
-                         batch_results = self.embed_batch(current_batch)
-                         results.update(batch_results)
-                         self.cleanup_memory()
-
-                         # Update progress
-                         processed_sequences += len(current_batch)
-                         elapsed_time = time.time() - start_time
-                         remaining_sequences = total_sequences - processed_sequences
-                         avg_time_per_seq = (
-                             elapsed_time / processed_sequences
-                             if processed_sequences > 0
-                             else 0
-                         )
-                         estimated_time_left = avg_time_per_seq * remaining_sequences
-
-                         logger.info(
-                             f"Progress: {processed_sequences}/{total_sequences} sequences ({processed_sequences/total_sequences*100:.1f}%) | "
-                             f"Elapsed: {elapsed_time/60:.1f} min | "
-                             f"Est. remaining: {estimated_time_left/60:.1f} min"
-                         )
-                     current_batch = []
-                     current_size = 0
-
-                 current_batch.append((seq_id, seq, seq_len))
-                 current_size += seq_len
-
-             # Process remaining batch
-             if current_batch:
-                 small_batch_count += 1
-                 logger.info(
-                     f"Processing final small batch {small_batch_count}/{total_small_batches} with {len(current_batch)} sequences"
-                 )
-                 batch_results = self.embed_batch(current_batch)
-                 results.update(batch_results)
-
-                 # Update final progress
-                 processed_sequences += len(current_batch)
-                 elapsed_time = time.time() - start_time
-
-             logger.info(
-                 f"Completed embedding {processed_sequences}/{total_sequences} sequences in {elapsed_time/60:.1f} minutes"
-             )
-
-             return results
-
-         finally:
-             self.cleanup_memory(deep=True)
-
-     def cleanup_memory(self, deep: bool = False):
-         """Perform memory cleanup operations.
-
-         Args:
-             deep: If True, performs aggressive cleanup including model transfer
-                 and garbage collection. Takes longer but frees more memory.
-
-         Note:
-             Regular cleanup is performed based on cleanup_frequency.
-             Deep cleanup is more thorough but takes longer.
-         """
-         self.processed_count += 1
-
-         if deep or self.processed_count % self.cleanup_frequency == 0:
-             logger.info(
-                 f"Performing memory cleanup after {self.processed_count} sequences"
-             )
-             if torch.cuda.is_available():
-                 before_mem = torch.cuda.memory_allocated() / 1e9
-
-                 torch.cuda.empty_cache()
-                 if deep:
-                     self.model = self.model.cpu()
-                     torch.cuda.empty_cache()
-                     self.model = self.model.to(self.device)
-
-                 after_mem = torch.cuda.memory_allocated() / 1e9
-                 logger.info(
-                     f"Memory cleaned up: {before_mem:.2f}GB -> {after_mem:.2f}GB"
-                 )
-
-             if deep:
-                 import gc
-
-                 gc.collect()
-
-     def run(self, input_file, output_file=None):
-         """
-         Run the embedding process on the input file.
-
-         Args:
-             input_file (str): Path to the input FASTA file.
-             output_file (str, optional): Path to the output H5 file. If not provided,
-                 it will be generated from the input file name.
-
-         Returns:
-             str: Path to the output H5 file containing the embeddings.
-         """
-         logger.info(f"Loading sequences from {input_file}")
-         start_time = time.time()
-         sequences = clean_noncanonical_fasta_to_dict(input_file)
-         load_time = time.time() - start_time
-         logger.info(
-             f"Loaded {len(sequences)} sequences from {input_file} in {load_time:.2f} seconds"
-         )
-
-         logger.info(f"Starting embedding process for {len(sequences)} sequences")
-         embed_start_time = time.time()
-         embeddings = self.get_embeddings(sequences)
-         embed_time = time.time() - embed_start_time
-         logger.info(
-             f"Completed embedding {len(embeddings)} sequences in {embed_time/60:.2f} minutes"
-         )
-
-         if output_file is None:
-             output_file = input_file.replace(".fasta", ".h5")
-
-         logger.info(f"Saving embeddings to {output_file}")
-         save_start_time = time.time()
-         self.save_to_h5(output_file, embeddings)
-         save_time = time.time() - save_start_time
-         logger.info(
-             f"Saved {len(embeddings)} embeddings to {output_file} in {save_time:.2f} seconds"
-         )
-
-         total_time = time.time() - start_time
-         logger.info(f"Total processing time: {total_time/60:.2f} minutes")
-
-         return output_file
-
-     def save_to_h5(self, output_file: str, embeddings: Dict[str, np.ndarray]) -> None:
-         """
-         Save protein embeddings to an HDF5 file.
-
-         Args:
-             output_file (str): Path to save the embeddings.
-             embeddings (Dict[str, np.ndarray]): Dictionary of embeddings.
-
-         The method creates an H5 file with two datasets:
-         - 'ids': contains protein IDs as variable-length strings
-         - 'vectors': contains embedding vectors as float32 arrays
-         """
-         # Convert the embeddings dictionary to lists for ids and vectors
-         ids = list(embeddings.keys())
-         vectors = np.array(list(embeddings.values()), dtype=np.float32)
-
-         # Create the HDF5 file, with datasets for vectors and IDs
-         with h5py.File(output_file, "w") as h5f:
-             # Create the 'vectors' dataset
-             h5f.create_dataset("vectors", data=vectors)
-
-             # Create the 'ids' dataset with variable-length strings
-             dt = h5py.special_dtype(vlen=str)
-             h5f.create_dataset("ids", data=ids, dtype=dt)
-
-             # Add the attributes
-             h5f.attrs["num_vecs"] = len(embeddings)
-             h5f.attrs["vec_dim"] = vectors.shape[1] if vectors.size > 0 else 0
-
-     def embed_big_prot(self, seq_id: str, sequence: str) -> np.ndarray:
-         """Embed a large protein sequence by chunking it and averaging the embeddings.
-
-         Args:
-             seq_id: The identifier for the protein sequence
-             sequence: The protein sequence to embed
-
-         Returns:
-             np.ndarray: The averaged embedding for the entire sequence
-
-         Note:
-             This method processes the sequence in chunks of size max_seq_length
-             and averages the resulting embeddings.
-         """
-         if not isinstance(sequence, str):
-             raise TypeError("Sequence must be a string.")
-
-         if not sequence:
-             raise ValueError("Sequence cannot be empty.")
-
-         if self.max_seq_length <= 0:
-             raise ValueError("max_seq_length must be greater than 0.")
-
-         # Create chunks of the sequence using max_seq_length
-         chunks: List[Tuple[str, str, int]] = [
-             (
-                 seq_id,
-                 sequence[i : i + self.max_seq_length],
-                 min(self.max_seq_length, len(sequence) - i),
-             )
-             for i in range(0, len(sequence), self.max_seq_length)
-         ]
-
-         logger.info(
-             f"Processing {seq_id} in {len(chunks)} chunks (total length: {len(sequence)})"
-         )
-
-         # Embed each chunk
-         chunk_embeddings = []
-         for i, chunk in enumerate(chunks, 1):
-             logger.info(
-                 f"Processing chunk {i}/{len(chunks)} for {seq_id} (length: {chunk[2]})"
-             )
-             chunk_start_time = time.time()
-             result = self.embed_batch([chunk])
-             chunk_embeddings.append(result[seq_id])
-             chunk_time = time.time() - chunk_start_time
-             logger.info(
-                 f"Processed chunk {i}/{len(chunks)} for {seq_id} in {chunk_time:.2f} seconds"
-             )
-
-         # Average the embeddings
-         average_embedding = np.mean(chunk_embeddings, axis=0)
-         logger.info(f"Completed processing {seq_id} (averaged {len(chunks)} chunks)")
-
-         return average_embedding
-
-     def embed_batch(self, batch: List[Tuple[str, str, int]]) -> Dict[str, np.ndarray]:
-         """
-         Generate embeddings for a batch of sequences.
-
-         Args:
-             batch: A list of tuples, each containing (sequence_id, sequence, sequence_length)
-
-         Returns:
-             A dictionary mapping sequence IDs to their embeddings as numpy arrays
-         """
-         if not batch:
-             raise ValueError(
-                 "Cannot embed an empty batch. Please provide at least one sequence."
-             )
-
-         sequence_ids, sequences, sequence_lengths = zip(*batch)
-
-         # Prepare sequences for tokenization
-         tokenizer_input = self.prepare_tokenizer_input(sequences)
-
-         # Tokenize sequences
-         encoded_input = self.tokenizer.batch_encode_plus(
-             tokenizer_input,
-             add_special_tokens=True,
-             padding="longest",
-             return_tensors="pt",
-         )
-
-         # Move tensors to the appropriate device
-         input_ids = encoded_input["input_ids"].to(self.device)
-         attention_mask = encoded_input["attention_mask"].to(self.device)
-
-         # Generate embeddings
-         with torch.no_grad():
-             embedding_output = self.model(
-                 input_ids, attention_mask=attention_mask
-             ).last_hidden_state
-
-         # Process embeddings for each sequence
-         embeddings = {}
-         for idx, (seq_id, seq_len) in enumerate(zip(sequence_ids, sequence_lengths)):
-             # Extract embedding for the sequence
-             seq_embedding = self.extract_sequence_embedding(
-                 embedding_output[idx], seq_len
-             )
-
-             # Calculate mean embedding and convert to numpy array
-             mean_embedding = seq_embedding.mean(dim=0).detach().cpu().numpy().squeeze()
-
-             embeddings[seq_id] = mean_embedding
-
-         return embeddings
-
-     def prepare_tokenizer_input(self, sequences: List[str]) -> List[str]:
-         """Prepare sequences for tokenization."""
-         raise NotImplementedError
-
-     def extract_sequence_embedding(
-         self, embedding: torch.Tensor, seq_len: int
-     ) -> torch.Tensor:
-         """Extract the relevant part of the embedding for a sequence."""
-         raise NotImplementedError
-
-
- class ProstT5Embedder(Embedder):
-     """Protein sequence embedder using the ProstT5 model.
-
-     This class implements protein sequence embedding using the ProstT5 model,
-     which is specifically trained for protein structure prediction tasks.
-     It includes memory-efficient processing and automatic precision selection
-     based on available hardware.
-
-     Memory management features are inherited from the base Embedder class:
-     - Periodic cleanup of GPU memory
-     - Separate handling of large and small sequences
-     - Batch size limits based on total residues
-     - Configurable cleanup frequency
-     """
-
-     def __init__(
-         self,
-         max_seq_length: int = 5000,
-         large_protein_threshold: int = 2500,
-         batch_residue_limit: int = 5000,
-         cleanup_frequency: int = 100,
-         skip_long_proteins: bool = False,
-     ):
-         """Initialize ProstT5Embedder with memory management parameters.
-
-         Args:
-             max_seq_length: Maximum sequence length before chunking or skipping.
-                 Also used as chunk size when processing long sequences.
-             large_protein_threshold: Sequences longer than this are processed individually
-             batch_residue_limit: Maximum total residues in a batch
-             cleanup_frequency: Frequency of memory cleanup operations
-             skip_long_proteins: If True, skip proteins longer than max_seq_length
-
-         Note:
-             The model automatically selects half precision (float16) when running on GPU
-             and full precision (float32) when running on CPU.
-         """
-         tokenizer = T5Tokenizer.from_pretrained(
-             "Rostlab/ProstT5", do_lower_case=False, legacy=True
-         )
-         model = T5EncoderModel.from_pretrained("Rostlab/ProstT5")
-
-         super().__init__(
-             model,
-             tokenizer,
-             max_seq_length,
-             large_protein_threshold,
-             batch_residue_limit,
-             cleanup_frequency,
-             skip_long_proteins,
-         )
-
-         # Set precision based on device
-         self.model = (
-             self.model.half() if torch.cuda.is_available() else self.model.float()
-         )
-
-     def prepare_tokenizer_input(self, sequences: List[str]) -> List[str]:
-         """Prepare sequences for ProstT5 tokenization.
-
-         Args:
-             sequences: List of amino acid sequences
-
-         Returns:
-             List of sequences with ProstT5-specific formatting, including the
-             <AA2fold> prefix and space-separated residues.
-         """
-         return [f"<AA2fold> {' '.join(seq)}" for seq in sequences]
-
-     def extract_sequence_embedding(
-         self, embedding: torch.Tensor, seq_len: int
-     ) -> torch.Tensor:
-         """Extract relevant embeddings for a sequence.
-
-         Args:
-             embedding: Raw embedding tensor from the model
-             seq_len: Length of the original sequence
-
-         Returns:
-             Tensor containing only the relevant sequence embeddings,
-             excluding special tokens. For ProstT5, we skip the first token
-             (corresponding to <AA2fold>) and take the next seq_len tokens.
-         """
-         return embedding[1 : seq_len + 1]
-
-
- class T5Embedder(Embedder):
-     """Protein sequence embedder using the T5 transformer model.
-
-     This class implements protein sequence embedding using the T5 model from Rostlab,
-     specifically designed for protein sequences. It includes memory-efficient processing
-     of both large and small sequences.
-
-     The model used is 'Rostlab/prot_t5_xl_half_uniref50-enc', which was trained on
-     UniRef50 sequences and provides state-of-the-art protein embeddings.
-
-     Memory management features are inherited from the base Embedder class:
-     - Periodic cleanup of GPU memory
-     - Separate handling of large and small sequences
-     - Batch size limits based on total residues
-     - Configurable cleanup frequency
-     """
-
-     def __init__(
-         self,
-         max_seq_length: int = 5000,
-         large_protein_threshold: int = 2500,
-         batch_residue_limit: int = 5000,
-         cleanup_frequency: int = 100,
-         skip_long_proteins: bool = False,
-     ):
-         """Initialize T5Embedder with memory management parameters.
-
-         Args:
-             max_seq_length: Maximum sequence length before chunking or skipping.
-                 Also used as chunk size when processing long sequences.
-             large_protein_threshold: Sequences longer than this are processed individually
-             batch_residue_limit: Maximum total residues in a batch
-             cleanup_frequency: Frequency of memory cleanup operations
-             skip_long_proteins: If True, skip proteins longer than max_seq_length
-
-         Note:
-             The model automatically handles memory management and batch processing
-             based on sequence sizes and available resources.
-         """
-         tokenizer = T5Tokenizer.from_pretrained(
-             "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
-         )
-         model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
-
-         super().__init__(
-             model,
-             tokenizer,
-             max_seq_length,
-             large_protein_threshold,
-             batch_residue_limit,
-             cleanup_frequency,
-             skip_long_proteins,
-         )
-
-     def prepare_tokenizer_input(self, sequences: List[str]) -> List[str]:
-         """Prepare sequences for T5 tokenization.
-
-         Args:
-             sequences: List of amino acid sequences
-
-         Returns:
-             List of space-separated sequences ready for tokenization
-         """
-         return [" ".join(seq) for seq in sequences]
-
-     def extract_sequence_embedding(
-         self, embedding: torch.Tensor, seq_len: int
-     ) -> torch.Tensor:
-         """Extract relevant embeddings for a sequence.
-
-         Args:
-             embedding: Raw embedding tensor from the model
-             seq_len: Length of the original sequence
-
-         Returns:
-             Tensor containing only the relevant sequence embeddings
-         """
-         return embedding[:seq_len]
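
This final hunk removes dayhoff_tools/embedders.py entirely in 1.1.40. For reference, a minimal usage sketch of the removed module, based on the API shown above (the FASTA path is hypothetical, and the imports only resolve in 1.1.38 and earlier, with torch, h5py, transformers, and fair-esm installed):

    from dayhoff_tools.embedders import ESMEmbedder, T5Embedder

    # Per-protein ESM2 embeddings; writes proteins.h5 in the reformatted layout
    esm_embedder = ESMEmbedder(model_name="esm2_t33_650M_UR50D", embedding_level="protein")
    esm_h5 = esm_embedder.run("proteins.fasta")

    # T5 embeddings with the memory-management knobs documented above
    t5_embedder = T5Embedder(batch_residue_limit=5000, skip_long_proteins=True)
    t5_h5 = t5_embedder.run("proteins.fasta")
    # Both H5 files expose "ids" and "vectors" datasets for downstream use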