dayhoff-tools 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,893 @@
1
+ import logging
2
+ import os
3
+ import time
4
+ from abc import ABC, abstractmethod
5
+ from typing import Dict, List, Literal, Optional, Tuple, cast
6
+
7
+ import h5py
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ import torch.utils.data
12
+ from dayhoff_tools.deployment.processors import Processor
13
+ from dayhoff_tools.fasta import (
14
+ clean_noncanonical_fasta,
15
+ clean_noncanonical_fasta_to_dict,
16
+ )
17
+ from esm import FastaBatchedDataset, pretrained
18
+ from transformers import T5EncoderModel, T5Tokenizer
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class ESMEmbedder(Processor):
24
+ """A processor that calculates ESM embeddings for a file of protein sequences.
25
+ Embeddings come from the last layer, and can be per-protein or per-residue."""
26
+
27
+ def __init__(
28
+ self,
29
+ # Name of a pretrained ESM model to download; see the list at
30
+ # https://github.com/facebookresearch/esm?tab=readme-ov-file#available
31
+ model_name: Literal[
32
+ "esm2_t33_650M_UR50D", # ESM2 version of the size used in CLEAN
33
+ "esm2_t6_8M_UR50D", # Smallest
34
+ "esm1b_t33_650M_UR50S", # Same as CLEAN
35
+ "esm2_t36_3B_UR50D", # 2nd largest
36
+ "esm2_t48_15B_UR50D", # Largest
37
+ ],
38
+ # Whether to return per-protein or per-residue embeddings.
39
+ embedding_level: Literal["protein", "residue"],
40
+ # Maximum number of tokens per batch
41
+ toks_per_batch: int = 4096,
42
+ # Truncate sequences longer than the given value
43
+ truncation_seq_length: int = 1022,
44
+ ):
45
+ super().__init__()
46
+ self.model_name = model_name
47
+ self.toks_per_batch = toks_per_batch
48
+ self.embedding_level = embedding_level
49
+ self.truncation_seq_length = truncation_seq_length
50
+
51
+ # Instance variables set by other methods below:
52
+ # self.model, self.alphabet, self.len_batches, self.data_loader, self.dataset_base_name
53
+
54
+ def _load_model(self):
55
+ """Download pre-trained model and load onto device"""
56
+ self.model, self.alphabet = pretrained.load_model_and_alphabet(self.model_name)
57
+ self.model.eval()
58
+ if torch.cuda.is_available():
59
+ self.model.cuda()
60
+ logger.info("Transferred model to GPU.")
61
+ else:
62
+ logger.info("GPU not available. Running model on CPU.")
63
+
64
+ def _load_dataset(self, fasta_file: str) -> None:
65
+ """Load FASTA file into batched dataset and dataloader"""
66
+ if not fasta_file.endswith(".fasta"):
67
+ raise ValueError("Input file must have .fasta extension.")
68
+
69
+ self.dataset_base_name = fasta_file.replace(".fasta", "")
70
+ clean_fasta_file = fasta_file.replace(".fasta", "_clean.fasta")
71
+ clean_noncanonical_fasta(input_path=fasta_file, output_path=clean_fasta_file)
72
+
73
+ dataset = FastaBatchedDataset.from_file(clean_fasta_file)
74
+ logger.info("Read %s and loaded %s sequences.", fasta_file, len(dataset))
75
+
76
+ batches = dataset.get_batch_indices(self.toks_per_batch, extra_toks_per_seq=1)
77
+ self.len_batches = len(batches)
78
+ self.data_loader = torch.utils.data.DataLoader(
79
+ dataset, # type: ignore
80
+ collate_fn=self.alphabet.get_batch_converter(self.truncation_seq_length),
81
+ batch_sampler=batches,
82
+ )
83
+ os.remove(clean_fasta_file)
84
+
85
+ def embed_fasta(self) -> str:
86
+ """Calculate embeddings from the FASTA file, return the path to the .h5 file of results
87
+ Write the H5 file with one dataset per protein (id plus embedding vector, where the vector
88
+ is 2D if it was calculated per-residue).
89
+ """
90
+ output_path = self.dataset_base_name + ".h5"
91
+ with h5py.File(output_path, "w") as h5_file, torch.no_grad():
92
+ start_time = time.time()
93
+ logger.info(
94
+ f"Calculating per-{self.embedding_level} embeddings. This dataset contains {self.len_batches} batches."
95
+ )
96
+ total_batches = self.len_batches
97
+
98
+ for batch_idx, (labels, sequences, toks) in enumerate(
99
+ self.data_loader, start=1
100
+ ):
101
+ if batch_idx % 10 == 0:
102
+ elapsed_time = time.time() - start_time
103
+ time_left = elapsed_time * (total_batches - batch_idx) / batch_idx
104
+ logger.info(
105
+ f"{self.dataset_base_name} | Batch {batch_idx}/{total_batches} | Elapsed time {elapsed_time / 60:.0f} min | Time left {time_left / 60:.0f} min, or {time_left / 3_600:.2f} hours"
106
+ )
107
+
108
+ if torch.cuda.is_available():
109
+ toks = toks.to(device="cuda", non_blocking=True)
110
+ out = self.model(
111
+ toks,
112
+ repr_layers=[
113
+ cast(int, self.model.num_layers)
114
+ ], # Get the last layer
115
+ return_contacts=False,
116
+ )
117
+ out["logits"].to(device="cpu")
118
+ representations = {
119
+ layer: t.to(device="cpu")
120
+ for layer, t in out["representations"].items()
121
+ }
122
+ for label_idx, label_full in enumerate(labels):
123
+ label = label_full.split()[
124
+ 0
125
+ ] # Shorten the label to only the first word
126
+ truncate_len = min(
127
+ self.truncation_seq_length, len(sequences[label_idx])
128
+ )
129
+
130
+ full_embeds = list(representations.items())[0][1]
131
+ if self.embedding_level == "protein":
132
+ protein_embeds = (
133
+ full_embeds[label_idx, 1 : truncate_len + 1].mean(0).clone()
134
+ )
135
+ h5_file.create_dataset(label, data=protein_embeds)
136
+ elif self.embedding_level == "residue":
137
+ residue_embeds = full_embeds[
138
+ label_idx, 1 : truncate_len + 1
139
+ ].clone()
140
+ h5_file.create_dataset(label, data=residue_embeds)
141
+ else:
142
+ raise NotImplementedError(
143
+ f"Embedding level {self.embedding_level} not implemented."
144
+ )
145
+
146
+ logger.info(f"Saved embeddings to {output_path}")
147
+ return output_path
148
+
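For reference, a short sketch of reading the intermediate H5 layout produced by embed_fasta, where each protein is its own dataset keyed by ID (the file path is hypothetical):

import h5py

with h5py.File("proteins.h5", "r") as f:
    for protein_id in list(f.keys())[:3]:
        vec = f[protein_id][()]  # 1D vector per protein, or 2D (residues x dim) if per-residue
        print(protein_id, vec.shape)
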
149
+ def run(self, input_file: str, output_file: Optional[str] = None) -> str:
150
+ self._load_model()
151
+ self._load_dataset(input_file)
152
+ embedding_file = self.embed_fasta()
153
+
154
+ # If embedding per-protein, reformat the H5 file.
155
+ # By default, write to the same filename; otherwise to output_file.
156
+ reformatted_embedding_file = output_file or embedding_file
157
+ if self.embedding_level == "protein":
158
+ logger.info("Reformatting H5 file...")
159
+ formatter = H5Reformatter()
160
+ formatter.run2(
161
+ input_file=embedding_file,
162
+ output_file=reformatted_embedding_file,
163
+ )
164
+ elif output_file and output_file != embedding_file:
+ # Per-residue embeddings are not reformatted; move them to the requested path.
+ os.rename(embedding_file, output_file)
+
165
+ return reformatted_embedding_file
166
+
167
+
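A minimal usage sketch for the class above (the FASTA path is illustrative and not part of the package):

embedder = ESMEmbedder(
    model_name="esm2_t6_8M_UR50D",  # smallest model, convenient for a smoke test
    embedding_level="protein",
)
output_h5 = embedder.run("proteins.fasta")  # per-protein means, reformatted into ids/vectors datasets
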
168
+ class H5Reformatter(Processor):
169
+ """A processor that reformats per-protein T5 embeddings
170
+ Old format (input): 1 dataset per protein, with the ID as key and the embedding as value.
171
+ New format (output): 2 datasets for the whole file, one of all protein IDs and one of all
172
+ the embeddings together."""
173
+
174
+ def __init__(self):
175
+ super().__init__()
176
+ # Set the device to GPU if available, otherwise CPU
177
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
178
+ logger.info("Device is: %s", self.device)
179
+
180
+ def embedding_file_to_df(self, file_name: str) -> pd.DataFrame:
181
+ with h5py.File(file_name, "r") as f:
182
+ gene_names = list(f.keys())
183
+ Xg = [f[key][()] for key in gene_names] # type:ignore
184
+ return pd.DataFrame(np.asarray(Xg), index=gene_names)  # type:ignore
185
+
186
+ def write_df_to_h5(self, df: pd.DataFrame, filename: str, description: str) -> None:
187
+ """
188
+ Write a DataFrame to an HDF5 file, separating row IDs and vectors.
189
+
190
+ Parameters:
191
+ - df: pandas DataFrame, where the index contains the IDs and
192
+ the columns contain the vector components.
193
+ - filename: String, the path to the output HDF5 file.
+ - description: String stored as a file-level "description" attribute.
194
+ """
195
+ df.index = df.index.astype(
196
+ str
197
+ ) # Ensure the index is of a string type for the row IDs
198
+ vectors = df.values
199
+ ids = df.index.to_numpy(dtype=str)
200
+
201
+ with h5py.File(filename, "w") as h5f:
202
+ h5f.create_dataset("vectors", data=vectors)
203
+ dt = h5py.special_dtype(
204
+ vlen=str
205
+ ) # Use variable-length strings to accommodate any ID size
206
+ h5f.create_dataset("ids", data=ids.astype("S"), dtype=dt)
207
+
208
+ # add the attributes
209
+ h5f.attrs["description"] = description
210
+ h5f.attrs["num_vecs"] = vectors.shape[0]
211
+ h5f.attrs["vec_dim"] = vectors.shape[1]
212
+
213
+ def run(self, input_file: str) -> str:
214
+ """Load an H5 file as a DataFrame, delete the file, and then export
215
+ the dataframe as an H5 file in the new format."""
216
+ df = self.embedding_file_to_df(input_file)
217
+ os.remove(input_file)
218
+ new_file_description = "Embeddings formatted for global ID and Vector tables."
219
+ self.write_df_to_h5(df, input_file, new_file_description)
220
+
221
+ return input_file
222
+
223
+ def run2(self, input_file: str, output_file: str):
224
+ """Load an H5 file as a DataFrame, delete the file, and then export
225
+ the dataframe as an H5 file in the new format."""
226
+ df = self.embedding_file_to_df(input_file)
227
+ new_file_description = "Embeddings formatted for global ID and Vector tables."
228
+ self.write_df_to_h5(df, output_file, new_file_description)
229
+
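A sketch of reading the reformatted output back with h5py (the file name is hypothetical); the datasets and attributes match those written by write_df_to_h5:

import h5py

with h5py.File("embeddings_reformatted.h5", "r") as f:
    vectors = f["vectors"][()]  # shape: (num_vecs, vec_dim)
    ids = [i.decode() if isinstance(i, bytes) else str(i) for i in f["ids"][()]]
    print(f.attrs["description"], f.attrs["num_vecs"], f.attrs["vec_dim"])
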
230
+
231
+ class Embedder(Processor):
232
+ """Base class for protein sequence embedders with optimized memory management.
233
+
234
+ This class provides the core functionality for embedding protein sequences using
235
+ transformer models, with built-in memory management and batch processing capabilities.
236
+ It handles sequences of different sizes appropriately, processing large sequences
237
+ individually and smaller sequences in batches for efficiency.
238
+
239
+ Memory Management Features:
240
+ - Periodic cleanup of GPU memory to prevent fragmentation
241
+ - Separate handling of large and small sequences to optimize memory usage
242
+ - Batch size limits based on total residues to prevent OOM errors
243
+ - Configurable cleanup frequency to balance performance and memory usage
244
+ - Empirically tested sequence length limits (5000-5500 residues depending on model)
245
+
246
+ Memory Fragmentation Prevention:
247
+ - Large sequences (>2500 residues) are processed individually to maintain contiguous memory blocks
248
+ - Small sequences are batched to efficiently use memory fragments
249
+ - Forced cleanup after processing large proteins
250
+ - Memory cleanup after every N sequences (configurable)
251
+ - Aggressive garbage collection settings for CUDA memory allocator
252
+
253
+ Memory Usage Patterns:
254
+ - Base memory: 2.4-4.8GB (model dependent)
255
+ - Peak memory: 12-15GB during large sequence processing
256
+ - Fragmentation ratio maintained above 0.92 for efficient memory use
257
+ - Maximum sequence length determined by model:
258
+ * T5: ~5000 residues
259
+ * ProstT5: ~5500 residues
260
+
261
+ Implementation Details:
262
+ - Uses PyTorch's CUDA memory allocator with optimized settings
263
+ - Configurable thresholds for large protein handling
264
+ - Automatic batch size adjustment based on sequence lengths
265
+ - Optional chunking for sequences exceeding maximum length
266
+ - Detailed memory statistics logging for monitoring
267
+
268
+ Note:
269
+ Memory limits are hardware-dependent. The above values are based on testing
270
+ with a 16GB GPU (such as NVIDIA T4). Adjust parameters based on available GPU memory.
271
+ """
272
+
273
+ def __init__(
274
+ self,
275
+ model,
276
+ tokenizer,
277
+ max_seq_length: int = 5000,
278
+ large_protein_threshold: int = 2500,
279
+ batch_residue_limit: int = 5000,
280
+ cleanup_frequency: int = 100,
281
+ skip_long_proteins: bool = False,
282
+ ):
283
+ """Initialize the Embedder with model and memory management parameters.
284
+
285
+ Args:
286
+ model: The transformer model to use for embeddings
287
+ tokenizer: The tokenizer matching the model
288
+ max_seq_length: Maximum sequence length before chunking or skipping.
289
+ Also used as chunk size when processing long sequences.
290
+ large_protein_threshold: Sequences longer than this are processed individually
291
+ batch_residue_limit: Maximum total residues allowed in a batch
292
+ cleanup_frequency: Number of sequences to process before performing memory cleanup
293
+ skip_long_proteins: If True, skip proteins longer than max_seq_length.
294
+ If False, process them in chunks of max_seq_length size.
295
+
296
+ Note:
297
+ The class automatically configures CUDA memory allocation if GPU is available.
298
+ """
299
+ # Configure memory allocator for CUDA
300
+ if torch.cuda.is_available():
301
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
302
+ "max_split_size_mb:128," # Smaller allocation chunks
303
+ "garbage_collection_threshold:0.8" # More aggressive GC
304
+ )
305
+
306
+ self.max_seq_length = max_seq_length
307
+ self.large_protein_threshold = large_protein_threshold
308
+ self.batch_residue_limit = batch_residue_limit
309
+ self.cleanup_frequency = cleanup_frequency
310
+ self.skip_long_proteins = skip_long_proteins
311
+ self.processed_count = 0
312
+
313
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
314
+ self.model = model
315
+ self.tokenizer = tokenizer
316
+
317
+ self.model.to(self.device)
318
+ self.model.eval()
319
+
320
+ def get_embeddings(self, seqs: Dict[str, str]) -> Dict[str, np.ndarray]:
321
+ """Process sequences and generate embeddings with memory management.
322
+
323
+ This method handles the core embedding logic, including:
324
+ - Handling sequences that exceed maximum length (skip or chunk)
325
+ - Splitting sequences into large and small batches
326
+ - Periodic memory cleanup
327
+ - Batch processing for efficiency
328
+
329
+ Args:
330
+ seqs: Dictionary mapping sequence IDs to their amino acid sequences
331
+
332
+ Returns:
333
+ Dictionary mapping sequence IDs to their embedding vectors
334
+
335
+ Note:
336
+ Long sequences are either skipped or processed in chunks based on skip_long_proteins
337
+ """
338
+ results: Dict[str, np.ndarray] = {}
339
+ try:
340
+ # Initialize progress tracking
341
+ total_sequences = len(seqs)
342
+ processed_sequences = 0
343
+ start_time = time.time()
344
+
345
+ logger.info(f"Starting embedding of {total_sequences} sequences")
346
+
347
+ # Handle sequences based on length
348
+ long_seqs = {
349
+ id: seq for id, seq in seqs.items() if len(seq) > self.max_seq_length
350
+ }
351
+ valid_seqs = {
352
+ id: seq for id, seq in seqs.items() if len(seq) <= self.max_seq_length
353
+ }
354
+
355
+ if long_seqs:
356
+ if self.skip_long_proteins:
357
+ logger.warning(
358
+ f"Skipping {len(long_seqs)} sequences exceeding max length {self.max_seq_length}: {', '.join(long_seqs.keys())}"
359
+ )
360
+ else:
361
+ logger.info(
362
+ f"Processing {len(long_seqs)} long sequences in chunks: {', '.join(long_seqs.keys())}"
363
+ )
364
+ for i, (seq_id, seq) in enumerate(long_seqs.items(), 1):
365
+ logger.info(
366
+ f"Embedding long sequence {i}/{len(long_seqs)}: {seq_id}"
367
+ )
368
+ results[seq_id] = self.embed_big_prot(seq_id, seq)
369
+ self.cleanup_memory()
370
+
371
+ # Update progress
372
+ processed_sequences += 1
373
+ elapsed_time = time.time() - start_time
374
+ remaining_sequences = total_sequences - processed_sequences
375
+ avg_time_per_seq = (
376
+ elapsed_time / processed_sequences
377
+ if processed_sequences > 0
378
+ else 0
379
+ )
380
+ estimated_time_left = avg_time_per_seq * remaining_sequences
381
+
382
+ logger.info(
383
+ f"Progress: {processed_sequences}/{total_sequences} sequences ({processed_sequences/total_sequences*100:.1f}%) | "
384
+ f"Elapsed: {elapsed_time/60:.1f} min | "
385
+ f"Est. remaining: {estimated_time_left/60:.1f} min"
386
+ )
387
+
388
+ # Split remaining sequences based on size
389
+ large_seqs = {
390
+ id: seq
391
+ for id, seq in valid_seqs.items()
392
+ if len(seq) > self.large_protein_threshold
393
+ }
394
+ small_seqs = {
395
+ id: seq
396
+ for id, seq in valid_seqs.items()
397
+ if len(seq) <= self.large_protein_threshold
398
+ }
399
+
400
+ logger.info(
401
+ f"Split into {len(large_seqs)} large and {len(small_seqs)} small sequences"
402
+ )
403
+
404
+ # Process large sequences individually
405
+ for i, (seq_id, seq) in enumerate(large_seqs.items(), 1):
406
+ logger.info(
407
+ f"Processing large sequence {i}/{len(large_seqs)}: {seq_id}"
408
+ )
409
+ batch = [(seq_id, seq, len(seq))]
410
+ results.update(self.embed_batch(batch))
411
+ self.cleanup_memory() # Cleanup after each large sequence
412
+
413
+ # Update progress
414
+ processed_sequences += 1
415
+ elapsed_time = time.time() - start_time
416
+ remaining_sequences = total_sequences - processed_sequences
417
+ avg_time_per_seq = (
418
+ elapsed_time / processed_sequences if processed_sequences > 0 else 0
419
+ )
420
+ estimated_time_left = avg_time_per_seq * remaining_sequences
421
+
422
+ logger.info(
423
+ f"Progress: {processed_sequences}/{total_sequences} sequences ({processed_sequences/total_sequences*100:.1f}%) | "
424
+ f"Elapsed: {elapsed_time/60:.1f} min | "
425
+ f"Est. remaining: {estimated_time_left/60:.1f} min"
426
+ )
427
+
428
+ # Process small sequences in batches
429
+ current_batch: List[Tuple[str, str, int]] = []
430
+ current_size = 0
431
+ small_batch_count = 0
432
+ # Ceiling-division estimate over total residues; the true batch count may be slightly higher
+ # because batches break on sequence boundaries.
+ total_small_batches = (
433
+ sum(len(seq) for seq in small_seqs.values())
434
+ + self.batch_residue_limit
435
+ - 1
436
+ ) // self.batch_residue_limit
437
+
438
+ # Sort sequences by length in descending order (reduces unnecessary padding --> speeds up embedding)
439
+ small_seqs_sorted = sorted(
440
+ small_seqs.items(), key=lambda x: len(x[1]), reverse=True
441
+ )
442
+
443
+ for seq_id, seq in small_seqs_sorted:
444
+ seq_len = len(seq)
445
+
446
+ if current_size + seq_len > self.batch_residue_limit:
447
+ if current_batch:
448
+ small_batch_count += 1
449
+ logger.info(
450
+ f"Processing small batch {small_batch_count}/{total_small_batches} with {len(current_batch)} sequences"
451
+ )
452
+ batch_results = self.embed_batch(current_batch)
453
+ results.update(batch_results)
454
+ self.cleanup_memory()
455
+
456
+ # Update progress
457
+ processed_sequences += len(current_batch)
458
+ elapsed_time = time.time() - start_time
459
+ remaining_sequences = total_sequences - processed_sequences
460
+ avg_time_per_seq = (
461
+ elapsed_time / processed_sequences
462
+ if processed_sequences > 0
463
+ else 0
464
+ )
465
+ estimated_time_left = avg_time_per_seq * remaining_sequences
466
+
467
+ logger.info(
468
+ f"Progress: {processed_sequences}/{total_sequences} sequences ({processed_sequences/total_sequences*100:.1f}%) | "
469
+ f"Elapsed: {elapsed_time/60:.1f} min | "
470
+ f"Est. remaining: {estimated_time_left/60:.1f} min"
471
+ )
472
+ current_batch = []
473
+ current_size = 0
474
+
475
+ current_batch.append((seq_id, seq, seq_len))
476
+ current_size += seq_len
477
+
478
+ # Process remaining batch
479
+ if current_batch:
480
+ small_batch_count += 1
481
+ logger.info(
482
+ f"Processing final small batch {small_batch_count}/{total_small_batches} with {len(current_batch)} sequences"
483
+ )
484
+ batch_results = self.embed_batch(current_batch)
485
+ results.update(batch_results)
486
+
487
+ # Update final progress
488
+ processed_sequences += len(current_batch)
489
+ elapsed_time = time.time() - start_time
490
+
491
+ logger.info(
492
+ f"Completed embedding {processed_sequences}/{total_sequences} sequences in {elapsed_time/60:.1f} minutes"
493
+ )
494
+
495
+ return results
496
+
497
+ finally:
498
+ self.cleanup_memory(deep=True)
499
+
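To make the routing above concrete, a small illustration with hypothetical IDs and lengths, using the class defaults for the two thresholds:

lengths = {"A": 6_000, "B": 3_000, "C": 800, "D": 400}
max_seq_length, large_protein_threshold = 5_000, 2_500
long_ids = {i for i, n in lengths.items() if n > max_seq_length}  # chunked or skipped
large_ids = {i for i, n in lengths.items() if large_protein_threshold < n <= max_seq_length}  # one at a time
small_ids = {i for i, n in lengths.items() if n <= large_protein_threshold}  # packed into batches
assert (long_ids, large_ids, small_ids) == ({"A"}, {"B"}, {"C", "D"})
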
500
+ def cleanup_memory(self, deep: bool = False):
501
+ """Perform memory cleanup operations.
502
+
503
+ Args:
504
+ deep: If True, performs aggressive cleanup including model transfer
505
+ and garbage collection. Takes longer but frees more memory.
506
+
507
+ Note:
508
+ Regular cleanup is performed based on cleanup_frequency.
509
+ Deep cleanup is more thorough but takes longer.
510
+ """
511
+ self.processed_count += 1
512
+
513
+ if deep or self.processed_count % self.cleanup_frequency == 0:
514
+ logger.info(
515
+ f"Performing memory cleanup after {self.processed_count} sequences"
516
+ )
517
+ if torch.cuda.is_available():
518
+ before_mem = torch.cuda.memory_allocated() / 1e9
519
+
520
+ torch.cuda.empty_cache()
521
+ if deep:
522
+ self.model = self.model.cpu()
523
+ torch.cuda.empty_cache()
524
+ self.model = self.model.to(self.device)
525
+
526
+ after_mem = torch.cuda.memory_allocated() / 1e9
527
+ logger.info(
528
+ f"Memory cleaned up: {before_mem:.2f}GB -> {after_mem:.2f}GB"
529
+ )
530
+
531
+ if deep:
532
+ import gc
533
+
534
+ gc.collect()
535
+
536
+ def run(self, input_file, output_file=None):
537
+ """
538
+ Run the embedding process on the input file.
539
+
540
+ Args:
541
+ input_file (str): Path to the input FASTA file.
542
+ output_file (str, optional): Path to the output H5 file. If not provided,
543
+ it will be generated from the input file name.
544
+
545
+ Returns:
546
+ str: Path to the output H5 file containing the embeddings.
547
+ """
548
+ logger.info(f"Loading sequences from {input_file}")
549
+ start_time = time.time()
550
+ sequences = clean_noncanonical_fasta_to_dict(input_file)
551
+ load_time = time.time() - start_time
552
+ logger.info(
553
+ f"Loaded {len(sequences)} sequences from {input_file} in {load_time:.2f} seconds"
554
+ )
555
+
556
+ logger.info(f"Starting embedding process for {len(sequences)} sequences")
557
+ embed_start_time = time.time()
558
+ embeddings = self.get_embeddings(sequences)
559
+ embed_time = time.time() - embed_start_time
560
+ logger.info(
561
+ f"Completed embedding {len(embeddings)} sequences in {embed_time/60:.2f} minutes"
562
+ )
563
+
564
+ if output_file is None:
565
+ output_file = input_file.replace(".fasta", ".h5")
566
+
567
+ logger.info(f"Saving embeddings to {output_file}")
568
+ save_start_time = time.time()
569
+ self.save_to_h5(output_file, embeddings)
570
+ save_time = time.time() - save_start_time
571
+ logger.info(
572
+ f"Saved {len(embeddings)} embeddings to {output_file} in {save_time:.2f} seconds"
573
+ )
574
+
575
+ total_time = time.time() - start_time
576
+ logger.info(f"Total processing time: {total_time/60:.2f} minutes")
577
+
578
+ return output_file
579
+
580
+ def save_to_h5(self, output_file: str, embeddings: Dict[str, np.ndarray]) -> None:
581
+ """
582
+ Save protein embeddings to an HDF5 file.
583
+
584
+ Args:
585
+ output_file (str): Path to save the embeddings.
586
+ embeddings (Dict[str, np.ndarray]): Dictionary of embeddings.
587
+
588
+ The method creates an H5 file with two datasets:
589
+ - 'ids': contains protein IDs as variable-length strings
590
+ - 'vectors': contains embedding vectors as float32 arrays
591
+ """
592
+ # Convert the embeddings dictionary to lists for ids and vectors
593
+ ids = list(embeddings.keys())
594
+ vectors = np.array(list(embeddings.values()), dtype=np.float32)
595
+
596
+ # Create the HDF5 file, with datasets for vectors and IDs
597
+ with h5py.File(output_file, "w") as h5f:
598
+ # Create the 'vectors' dataset
599
+ h5f.create_dataset("vectors", data=vectors)
600
+
601
+ # Create the 'ids' dataset with variable-length strings
602
+ dt = h5py.special_dtype(vlen=str)
603
+ h5f.create_dataset("ids", data=ids, dtype=dt)
604
+
605
+ # Add the attributes
606
+ h5f.attrs["num_vecs"] = len(embeddings)
607
+ h5f.attrs["vec_dim"] = vectors.shape[1] if vectors.size > 0 else 0
608
+
609
+ def embed_big_prot(self, seq_id: str, sequence: str) -> np.ndarray:
610
+ """Embed a large protein sequence by chunking it and averaging the embeddings.
611
+
612
+ Args:
613
+ seq_id: The identifier for the protein sequence
614
+ sequence: The protein sequence to embed
615
+
616
+ Returns:
617
+ np.ndarray: The averaged embedding for the entire sequence
618
+
619
+ Note:
620
+ This method processes the sequence in chunks of size max_seq_length
621
+ and averages the resulting embeddings.
622
+ """
623
+ if not isinstance(sequence, str):
624
+ raise TypeError("Sequence must be a string.")
625
+
626
+ if not sequence:
627
+ raise ValueError("Sequence cannot be empty.")
628
+
629
+ if self.max_seq_length <= 0:
630
+ raise ValueError("max_seq_length must be greater than 0.")
631
+
632
+ # Create chunks of the sequence using max_seq_length
633
+ chunks: List[Tuple[str, str, int]] = [
634
+ (
635
+ seq_id,
636
+ sequence[i : i + self.max_seq_length],
637
+ min(self.max_seq_length, len(sequence) - i),
638
+ )
639
+ for i in range(0, len(sequence), self.max_seq_length)
640
+ ]
641
+
642
+ logger.info(
643
+ f"Processing {seq_id} in {len(chunks)} chunks (total length: {len(sequence)})"
644
+ )
645
+
646
+ # Embed each chunk
647
+ chunk_embeddings = []
648
+ for i, chunk in enumerate(chunks, 1):
649
+ logger.info(
650
+ f"Processing chunk {i}/{len(chunks)} for {seq_id} (length: {chunk[2]})"
651
+ )
652
+ chunk_start_time = time.time()
653
+ result = self.embed_batch([chunk])
654
+ chunk_embeddings.append(result[seq_id])
655
+ chunk_time = time.time() - chunk_start_time
656
+ logger.info(
657
+ f"Processed chunk {i}/{len(chunks)} for {seq_id} in {chunk_time:.2f} seconds"
658
+ )
659
+
660
+ # Average the per-chunk embeddings (unweighted; a short final chunk counts the same as a full one)
661
+ average_embedding = np.mean(chunk_embeddings, axis=0)
662
+ logger.info(f"Completed processing {seq_id} (averaged {len(chunks)} chunks)")
663
+
664
+ return average_embedding
665
+
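A worked illustration of the chunking used by embed_big_prot, with a hypothetical 12,000-residue protein and the default max_seq_length of 5,000:

seq_len, max_len = 12_000, 5_000
chunk_lengths = [min(max_len, seq_len - i) for i in range(0, seq_len, max_len)]
assert chunk_lengths == [5_000, 5_000, 2_000]
# Each chunk is embedded separately; the per-chunk mean embeddings are then averaged.
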
666
+ def embed_batch(self, batch: List[Tuple[str, str, int]]) -> Dict[str, np.ndarray]:
667
+ """
668
+ Generate embeddings for a batch of sequences.
669
+
670
+ Args:
671
+ batch: A list of tuples, each containing (sequence_id, sequence, sequence_length)
672
+
673
+ Returns:
674
+ A dictionary mapping sequence IDs to their embeddings as numpy arrays
675
+ """
676
+ if not batch:
677
+ raise ValueError(
678
+ "Cannot embed an empty batch. Please provide at least one sequence."
679
+ )
680
+
681
+ sequence_ids, sequences, sequence_lengths = zip(*batch)
682
+
683
+ # Prepare sequences for tokenization
684
+ tokenizer_input = self.prepare_tokenizer_input(sequences)
685
+
686
+ # Tokenize sequences
687
+ encoded_input = self.tokenizer.batch_encode_plus(
688
+ tokenizer_input,
689
+ add_special_tokens=True,
690
+ padding="longest",
691
+ return_tensors="pt",
692
+ )
693
+
694
+ # Move tensors to the appropriate device
695
+ input_ids = encoded_input["input_ids"].to(self.device)
696
+ attention_mask = encoded_input["attention_mask"].to(self.device)
697
+
698
+ # Generate embeddings
699
+ with torch.no_grad():
700
+ embedding_output = self.model(
701
+ input_ids, attention_mask=attention_mask
702
+ ).last_hidden_state
703
+
704
+ # Process embeddings for each sequence
705
+ embeddings = {}
706
+ for idx, (seq_id, seq_len) in enumerate(zip(sequence_ids, sequence_lengths)):
707
+ # Extract embedding for the sequence
708
+ seq_embedding = self.extract_sequence_embedding(
709
+ embedding_output[idx], seq_len
710
+ )
711
+
712
+ # Calculate mean embedding and convert to numpy array
713
+ mean_embedding = seq_embedding.mean(dim=0).detach().cpu().numpy().squeeze()
714
+
715
+ embeddings[seq_id] = mean_embedding
716
+
717
+ return embeddings
718
+
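A shape sketch of the pooling performed in embed_batch, with illustrative dimensions that are not tied to any real model:

import torch

hidden = torch.randn(2, 7, 16)  # (batch, padded_tokens, hidden_dim)
seq_embedding = hidden[0][:5]  # slice kept by extract_sequence_embedding for a length-5 sequence
mean_embedding = seq_embedding.mean(dim=0).numpy()  # one vector per protein, shape (16,)
assert mean_embedding.shape == (16,)
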
719
+ def prepare_tokenizer_input(self, sequences: List[str]) -> List[str]:
720
+ """Prepare sequences for tokenization."""
721
+ raise NotImplementedError
722
+
723
+ def extract_sequence_embedding(
724
+ self, embedding: torch.Tensor, seq_len: int
725
+ ) -> torch.Tensor:
726
+ """Extract the relevant part of the embedding for a sequence."""
727
+ raise NotImplementedError
728
+
729
+
730
+ class ProstT5Embedder(Embedder):
731
+ """Protein sequence embedder using the ProstT5 model.
732
+
733
+ This class implements protein sequence embedding using the ProstT5 model,
734
+ which is specifically trained for protein structure prediction tasks.
735
+ It includes memory-efficient processing and automatic precision selection
736
+ based on available hardware.
737
+
738
+ Memory management features are inherited from the base Embedder class:
739
+ - Periodic cleanup of GPU memory
740
+ - Separate handling of large and small sequences
741
+ - Batch size limits based on total residues
742
+ - Configurable cleanup frequency
743
+ """
744
+
745
+ def __init__(
746
+ self,
747
+ max_seq_length: int = 5000,
748
+ large_protein_threshold: int = 2500,
749
+ batch_residue_limit: int = 5000,
750
+ cleanup_frequency: int = 100,
751
+ skip_long_proteins: bool = False,
752
+ ):
753
+ """Initialize ProstT5Embedder with memory management parameters.
754
+
755
+ Args:
756
+ max_seq_length: Maximum sequence length before chunking or skipping.
757
+ Also used as chunk size when processing long sequences.
758
+ large_protein_threshold: Sequences longer than this are processed individually
759
+ batch_residue_limit: Maximum total residues in a batch
760
+ cleanup_frequency: Frequency of memory cleanup operations
761
+ skip_long_proteins: If True, skip proteins longer than max_seq_length
762
+
763
+ Note:
764
+ The model automatically selects half precision (float16) when running on GPU
765
+ and full precision (float32) when running on CPU.
766
+ """
767
+ tokenizer = T5Tokenizer.from_pretrained(
768
+ "Rostlab/ProstT5", do_lower_case=False, legacy=True
769
+ )
770
+ model = T5EncoderModel.from_pretrained("Rostlab/ProstT5")
771
+
772
+ super().__init__(
773
+ model,
774
+ tokenizer,
775
+ max_seq_length,
776
+ large_protein_threshold,
777
+ batch_residue_limit,
778
+ cleanup_frequency,
779
+ skip_long_proteins,
780
+ )
781
+
782
+ # Set precision based on device
783
+ self.model = (
784
+ self.model.half() if torch.cuda.is_available() else self.model.float()
785
+ )
786
+
787
+ def prepare_tokenizer_input(self, sequences: List[str]) -> List[str]:
788
+ """Prepare sequences for ProstT5 tokenization.
789
+
790
+ Args:
791
+ sequences: List of amino acid sequences
792
+
793
+ Returns:
794
+ List of sequences with ProstT5-specific formatting, including the
795
+ <AA2fold> prefix and space-separated residues.
796
+ """
797
+ return [f"<AA2fold> {' '.join(seq)}" for seq in sequences]
798
+
799
+ def extract_sequence_embedding(
800
+ self, embedding: torch.Tensor, seq_len: int
801
+ ) -> torch.Tensor:
802
+ """Extract relevant embeddings for a sequence.
803
+
804
+ Args:
805
+ embedding: Raw embedding tensor from the model
806
+ seq_len: Length of the original sequence
807
+
808
+ Returns:
809
+ Tensor containing only the relevant sequence embeddings,
810
+ excluding special tokens. For ProstT5, we skip the first token
811
+ (corresponding to <AA2fold>) and take the next seq_len tokens.
812
+ """
813
+ return embedding[1 : seq_len + 1]
814
+
815
+
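An example of the ProstT5 input/output convention implemented above, using a hypothetical four-residue sequence:

seq = "MKVL"
formatted = f"<AA2fold> {' '.join(seq)}"  # what prepare_tokenizer_input builds
assert formatted == "<AA2fold> M K V L"
# extract_sequence_embedding then drops the leading <AA2fold> token: embedding[1 : len(seq) + 1]
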
816
+ class T5Embedder(Embedder):
817
+ """Protein sequence embedder using the T5 transformer model.
818
+
819
+ This class implements protein sequence embedding using the T5 model from Rostlab,
820
+ specifically designed for protein sequences. It includes memory-efficient processing
821
+ of both large and small sequences.
822
+
823
+ The model used is 'Rostlab/prot_t5_xl_half_uniref50-enc', which was trained on
824
+ UniRef50 sequences and provides state-of-the-art protein embeddings.
825
+
826
+ Memory management features are inherited from the base Embedder class:
827
+ - Periodic cleanup of GPU memory
828
+ - Separate handling of large and small sequences
829
+ - Batch size limits based on total residues
830
+ - Configurable cleanup frequency
831
+ """
832
+
833
+ def __init__(
834
+ self,
835
+ max_seq_length: int = 5000,
836
+ large_protein_threshold: int = 2500,
837
+ batch_residue_limit: int = 5000,
838
+ cleanup_frequency: int = 100,
839
+ skip_long_proteins: bool = False,
840
+ ):
841
+ """Initialize T5Embedder with memory management parameters.
842
+
843
+ Args:
844
+ max_seq_length: Maximum sequence length before chunking or skipping.
845
+ Also used as chunk size when processing long sequences.
846
+ large_protein_threshold: Sequences longer than this are processed individually
847
+ batch_residue_limit: Maximum total residues in a batch
848
+ cleanup_frequency: Frequency of memory cleanup operations
849
+ skip_long_proteins: If True, skip proteins longer than max_seq_length
850
+
851
+ Note:
852
+ The model automatically handles memory management and batch processing
853
+ based on sequence sizes and available resources.
854
+ """
855
+ tokenizer = T5Tokenizer.from_pretrained(
856
+ "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
857
+ )
858
+ model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
859
+
860
+ super().__init__(
861
+ model,
862
+ tokenizer,
863
+ max_seq_length,
864
+ large_protein_threshold,
865
+ batch_residue_limit,
866
+ cleanup_frequency,
867
+ skip_long_proteins,
868
+ )
869
+
870
+ def prepare_tokenizer_input(self, sequences: List[str]) -> List[str]:
871
+ """Prepare sequences for T5 tokenization.
872
+
873
+ Args:
874
+ sequences: List of amino acid sequences
875
+
876
+ Returns:
877
+ List of space-separated sequences ready for tokenization
878
+ """
879
+ return [" ".join(seq) for seq in sequences]
880
+
881
+ def extract_sequence_embedding(
882
+ self, embedding: torch.Tensor, seq_len: int
883
+ ) -> torch.Tensor:
884
+ """Extract relevant embeddings for a sequence.
885
+
886
+ Args:
887
+ embedding: Raw embedding tensor from the model
888
+ seq_len: Length of the original sequence
889
+
890
+ Returns:
891
+ Tensor containing only the relevant sequence embeddings
892
+ """
893
+ return embedding[:seq_len]
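
Finally, a minimal end-to-end sketch for the T5-based embedder (the FASTA path is hypothetical; parameters shown are the defaults except skip_long_proteins):

embedder = T5Embedder(skip_long_proteins=True)  # skip, rather than chunk, proteins over max_seq_length
h5_path = embedder.run("proteins.fasta")  # writes proteins.h5 with "ids" and "vectors" datasets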