dayhoff_tools-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dayhoff_tools/__init__.py +0 -0
- dayhoff_tools/chemistry/standardizer.py +297 -0
- dayhoff_tools/chemistry/utils.py +63 -0
- dayhoff_tools/cli/__init__.py +0 -0
- dayhoff_tools/cli/main.py +90 -0
- dayhoff_tools/cli/swarm_commands.py +156 -0
- dayhoff_tools/cli/utility_commands.py +244 -0
- dayhoff_tools/deployment/base.py +434 -0
- dayhoff_tools/deployment/deploy_aws.py +458 -0
- dayhoff_tools/deployment/deploy_gcp.py +176 -0
- dayhoff_tools/deployment/deploy_utils.py +781 -0
- dayhoff_tools/deployment/job_runner.py +153 -0
- dayhoff_tools/deployment/processors.py +125 -0
- dayhoff_tools/deployment/swarm.py +591 -0
- dayhoff_tools/embedders.py +893 -0
- dayhoff_tools/fasta.py +1082 -0
- dayhoff_tools/file_ops.py +261 -0
- dayhoff_tools/gcp.py +85 -0
- dayhoff_tools/h5.py +542 -0
- dayhoff_tools/kegg.py +37 -0
- dayhoff_tools/logs.py +27 -0
- dayhoff_tools/mmseqs.py +164 -0
- dayhoff_tools/sqlite.py +516 -0
- dayhoff_tools/structure.py +751 -0
- dayhoff_tools/uniprot.py +434 -0
- dayhoff_tools/warehouse.py +418 -0
- dayhoff_tools-1.0.0.dist-info/METADATA +122 -0
- dayhoff_tools-1.0.0.dist-info/RECORD +30 -0
- dayhoff_tools-1.0.0.dist-info/WHEEL +4 -0
- dayhoff_tools-1.0.0.dist-info/entry_points.txt +3 -0
dayhoff_tools/embedders.py
@@ -0,0 +1,893 @@
import logging
import os
import time
from abc import ABC, abstractmethod
from typing import Dict, List, Literal, Optional, Tuple, cast

import h5py
import numpy as np
import pandas as pd
import torch
import torch.utils.data
from dayhoff_tools.deployment.processors import Processor
from dayhoff_tools.fasta import (
    clean_noncanonical_fasta,
    clean_noncanonical_fasta_to_dict,
)
from esm import FastaBatchedDataset, pretrained
from transformers import T5EncoderModel, T5Tokenizer

logger = logging.getLogger(__name__)

class ESMEmbedder(Processor):
    """A processor that calculates ESM embeddings for a file of protein sequences.
    Embeddings come from the last layer, and can be per-protein or per-residue."""

    def __init__(
        self,
        # PyTorch model file OR name of pretrained model to download
        # https://github.com/facebookresearch/esm?tab=readme-ov-file#available
        model_name: Literal[
            "esm2_t33_650M_UR50D",  # ESM2 version of the size used in CLEAN
            "esm2_t6_8M_UR50D",  # Smallest
            "esm1b_t33_650M_UR50S",  # Same as CLEAN
            "esm2_t36_3B_UR50D",  # 2nd largest
            "esm2_t48_15B_UR50D",  # Largest
        ],
        # Whether to return per-protein or per-residue embeddings.
        embedding_level: Literal["protein", "residue"],
        # Maximum batch size
        toks_per_batch: int = 4096,
        # Truncate sequences longer than the given value
        truncation_seq_length: int = 1022,
    ):
        super().__init__()
        self.model_name = model_name
        self.toks_per_batch = toks_per_batch
        self.embedding_level = embedding_level
        self.truncation_seq_length = truncation_seq_length

        # Instance variables set by other methods below:
        # self.model, self.alphabet, self.len_batches, self.data_loader, self.dataset_base_name

    def _load_model(self):
        """Download pre-trained model and load onto device"""
        self.model, self.alphabet = pretrained.load_model_and_alphabet(self.model_name)
        self.model.eval()
        if torch.cuda.is_available():
            self.model.cuda()
            logger.info("Transferred model to GPU.")
        else:
            logger.info("GPU not available. Running model on CPU.")

    def _load_dataset(self, fasta_file: str) -> None:
        """Load FASTA file into batched dataset and dataloader"""
        if not fasta_file.endswith(".fasta"):
            raise ValueError("Input file must have .fasta extension.")

        self.dataset_base_name = fasta_file.replace(".fasta", "")
        clean_fasta_file = fasta_file.replace(".fasta", "_clean.fasta")
        clean_noncanonical_fasta(input_path=fasta_file, output_path=clean_fasta_file)

        dataset = FastaBatchedDataset.from_file(clean_fasta_file)
        logger.info("Read %s and loaded %s sequences.", fasta_file, len(dataset))

        batches = dataset.get_batch_indices(self.toks_per_batch, extra_toks_per_seq=1)
        self.len_batches = len(batches)
        self.data_loader = torch.utils.data.DataLoader(
            dataset,  # type: ignore
            collate_fn=self.alphabet.get_batch_converter(self.truncation_seq_length),
            batch_sampler=batches,
        )
        os.remove(clean_fasta_file)

    def embed_fasta(self) -> str:
        """Calculate embeddings from the FASTA file, return the path to the .h5 file of results

        Write the H5 file with one dataset per protein (id plus embedding vector, where the vector
        is 2D if it was calculated per-residue).
        """
        output_path = self.dataset_base_name + ".h5"
        with h5py.File(output_path, "w") as h5_file, torch.no_grad():
            start_time = time.time()
            logger.info(
                f"Calculating per-{self.embedding_level} embeddings. This dataset contains {self.len_batches} batches."
            )
            total_batches = self.len_batches

            for batch_idx, (labels, sequences, toks) in enumerate(
                self.data_loader, start=1
            ):
                if batch_idx % 10 == 0:
                    elapsed_time = time.time() - start_time
                    time_left = elapsed_time * (total_batches - batch_idx) / batch_idx
                    logger.info(
                        f"{self.dataset_base_name} | Batch {batch_idx}/{total_batches} | Elapsed time {elapsed_time / 60:.0f} min | Time left {time_left / 60:.0f} min, or {time_left / 3_600:.2f} hours"
                    )

                if torch.cuda.is_available():
                    toks = toks.to(device="cuda", non_blocking=True)
                out = self.model(
                    toks,
                    repr_layers=[
                        cast(int, self.model.num_layers)
                    ],  # Get the last layer
                    return_contacts=False,
                )
                out["logits"].to(device="cpu")
                representations = {
                    layer: t.to(device="cpu")
                    for layer, t in out["representations"].items()
                }
                for label_idx, label_full in enumerate(labels):
                    label = label_full.split()[
                        0
                    ]  # Shorten the label to only the first word
                    truncate_len = min(
                        self.truncation_seq_length, len(sequences[label_idx])
                    )

                    full_embeds = list(representations.items())[0][1]
                    if self.embedding_level == "protein":
                        protein_embeds = (
                            full_embeds[label_idx, 1 : truncate_len + 1].mean(0).clone()
                        )
                        h5_file.create_dataset(label, data=protein_embeds)
                    elif self.embedding_level == "residue":
                        residue_embeds = full_embeds[
                            label_idx, 1 : truncate_len + 1
                        ].clone()
                        h5_file.create_dataset(label, data=residue_embeds)
                    else:
                        raise NotImplementedError(
                            f"Embedding level {self.embedding_level} not implemented."
                        )

        logger.info(f"Saved embeddings to {output_path}")
        return output_path

    def run(self, input_file: str, output_file: Optional[str] = None) -> str:
        self._load_model()
        self._load_dataset(input_file)
        embedding_file = self.embed_fasta()

        # If embedding per-protein, reformat the H5 file.
        # By default, write to the same filename; otherwise to output_file.
        reformatted_embedding_file = output_file or embedding_file
        if self.embedding_level == "protein":
            print("Reformatting H5 file...")
            formatter = H5Reformatter()
            formatter.run2(
                input_file=embedding_file,
                output_file=reformatted_embedding_file,
            )

        return reformatted_embedding_file

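# --- Editor's note: usage sketch (not part of the packaged module) -----------
# A minimal, hypothetical example of driving the ESMEmbedder class above end to
# end. The FASTA path and model choice are illustrative assumptions, not values
# shipped with the package; run() cleans the FASTA, batches it by tokens, and
# writes the embeddings next to the input as an .h5 file.
def _example_esm_usage() -> str:
    embedder = ESMEmbedder(
        model_name="esm2_t6_8M_UR50D",  # smallest ESM2 model, fastest to download
        embedding_level="protein",  # one mean-pooled vector per sequence
    )
    # Hypothetical input path; any FASTA of amino-acid sequences would do.
    return embedder.run("proteins.fasta")  # -> "proteins.h5"
# ------------------------------------------------------------------------------
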
class H5Reformatter(Processor):
    """A processor that reformats per-protein T5 embeddings
    Old format (input): 1 dataset per protein, with the ID as key and the embedding as value.
    New format (output): 2 datasets for the whole file, one of all protein IDs and one of all
    the embeddings together."""

    def __init__(self):
        super().__init__()
        # Set the device to GPU if available, otherwise CPU
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        logger.info("Device is: %s", self.device)

    def embedding_file_to_df(self, file_name: str) -> pd.DataFrame:
        with h5py.File(file_name, "r") as f:
            gene_names = list(f.keys())
            Xg = [f[key][()] for key in gene_names]  # type:ignore
            return pd.DataFrame(np.asmatrix(Xg), index=gene_names)  # type:ignore

    def write_df_to_h5(self, df: pd.DataFrame, filename: str, description: str) -> None:
        """
        Write a DataFrame to an HDF5 file, separating row IDs and vectors.

        Parameters:
        - df: pandas DataFrame, where the index contains the IDs and
          the columns contain the vector components.
        - filename: String, the path to the output HDF5 file.
        """
        df.index = df.index.astype(
            str
        )  # Ensure the index is of a string type for the row IDs
        vectors = df.values
        ids = df.index.to_numpy(dtype=str)

        with h5py.File(filename, "w") as h5f:
            h5f.create_dataset("vectors", data=vectors)
            dt = h5py.special_dtype(
                vlen=str
            )  # Use variable-length strings to accommodate any ID size
            h5f.create_dataset("ids", data=ids.astype("S"), dtype=dt)

            # add the attributes
            h5f.attrs["description"] = description
            h5f.attrs["num_vecs"] = vectors.shape[0]
            h5f.attrs["vec_dim"] = vectors.shape[1]

    def run(self, input_file: str) -> str:
        """Load an H5 file as a DataFrame, delete the file, and then export
        the dataframe as an H5 file in the new format."""
        df = self.embedding_file_to_df(input_file)
        os.remove(input_file)
        new_file_description = "Embeddings formatted for global ID and Vector tables."
        self.write_df_to_h5(df, input_file, new_file_description)

        return input_file

    def run2(self, input_file: str, output_file: str):
        """Load an H5 file as a DataFrame, then export the dataframe as an
        H5 file in the new format at output_file (the input file is kept)."""
        df = self.embedding_file_to_df(input_file)
        new_file_description = "Embeddings formatted for global ID and Vector tables."
        self.write_df_to_h5(df, output_file, new_file_description)

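# --- Editor's note: reader sketch (not part of the packaged module) ----------
# The "new format" described above stores two whole-file datasets ("ids" and
# "vectors") plus a few attributes. A hypothetical reader for a file produced
# by H5Reformatter.run2() could look like this; the function name is editorial.
def _example_read_reformatted_h5(path: str) -> Dict[str, np.ndarray]:
    with h5py.File(path, "r") as h5f:
        # h5py may return bytes for variable-length strings; decode defensively.
        ids = [i.decode() if isinstance(i, bytes) else i for i in h5f["ids"][()]]
        vectors = h5f["vectors"][()]  # shape: (num_vecs, vec_dim)
        logger.info(
            "Read %s vectors of dim %s", h5f.attrs["num_vecs"], h5f.attrs["vec_dim"]
        )
    return dict(zip(ids, vectors))
# ------------------------------------------------------------------------------
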
class Embedder(Processor):
    """Base class for protein sequence embedders with optimized memory management.

    This class provides the core functionality for embedding protein sequences using
    transformer models, with built-in memory management and batch processing capabilities.
    It handles sequences of different sizes appropriately, processing large sequences
    individually and smaller sequences in batches for efficiency.

    Memory Management Features:
    - Periodic cleanup of GPU memory to prevent fragmentation
    - Separate handling of large and small sequences to optimize memory usage
    - Batch size limits based on total residues to prevent OOM errors
    - Configurable cleanup frequency to balance performance and memory usage
    - Empirically tested sequence length limits (5000-5500 residues depending on model)

    Memory Fragmentation Prevention:
    - Large sequences (>2500 residues) are processed individually to maintain contiguous memory blocks
    - Small sequences are batched to efficiently use memory fragments
    - Forced cleanup after processing large proteins
    - Memory cleanup after every N sequences (configurable)
    - Aggressive garbage collection settings for CUDA memory allocator

    Memory Usage Patterns:
    - Base memory: 2.4-4.8GB (model dependent)
    - Peak memory: 12-15GB during large sequence processing
    - Fragmentation ratio maintained above 0.92 for efficient memory use
    - Maximum sequence length determined by model:
        * T5: ~5000 residues
        * ProstT5: ~5500 residues

    Implementation Details:
    - Uses PyTorch's CUDA memory allocator with optimized settings
    - Configurable thresholds for large protein handling
    - Automatic batch size adjustment based on sequence lengths
    - Optional chunking for sequences exceeding maximum length
    - Detailed memory statistics logging for monitoring

    Note:
        Memory limits are hardware-dependent. The above values are based on testing
        with a 16GB GPU (such as NVIDIA T4). Adjust parameters based on available GPU memory.
    """

    def __init__(
        self,
        model,
        tokenizer,
        max_seq_length: int = 5000,
        large_protein_threshold: int = 2500,
        batch_residue_limit: int = 5000,
        cleanup_frequency: int = 100,
        skip_long_proteins: bool = False,
    ):
        """Initialize the Embedder with model and memory management parameters.

        Args:
            model: The transformer model to use for embeddings
            tokenizer: The tokenizer matching the model
            max_seq_length: Maximum sequence length before chunking or skipping.
                Also used as chunk size when processing long sequences.
            large_protein_threshold: Sequences longer than this are processed individually
            batch_residue_limit: Maximum total residues allowed in a batch
            cleanup_frequency: Number of sequences to process before performing memory cleanup
            skip_long_proteins: If True, skip proteins longer than max_seq_length.
                If False, process them in chunks of max_seq_length size.

        Note:
            The class automatically configures CUDA memory allocation if GPU is available.
        """
        # Configure memory allocator for CUDA
        if torch.cuda.is_available():
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
                "max_split_size_mb:128,"  # Smaller allocation chunks
                "garbage_collection_threshold:0.8"  # More aggressive GC
            )

        self.max_seq_length = max_seq_length
        self.large_protein_threshold = large_protein_threshold
        self.batch_residue_limit = batch_residue_limit
        self.cleanup_frequency = cleanup_frequency
        self.skip_long_proteins = skip_long_proteins
        self.processed_count = 0

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = model
        self.tokenizer = tokenizer

        self.model.to(self.device)
        self.model.eval()

    def get_embeddings(self, seqs: Dict[str, str]) -> Dict[str, np.ndarray]:
        """Process sequences and generate embeddings with memory management.

        This method handles the core embedding logic, including:
        - Handling sequences that exceed maximum length (skip or chunk)
        - Splitting sequences into large and small batches
        - Periodic memory cleanup
        - Batch processing for efficiency

        Args:
            seqs: Dictionary mapping sequence IDs to their amino acid sequences

        Returns:
            Dictionary mapping sequence IDs to their embedding vectors

        Note:
            Long sequences are either skipped or processed in chunks based on skip_long_proteins
        """
        results: Dict[str, np.ndarray] = {}
        try:
            # Initialize progress tracking
            total_sequences = len(seqs)
            processed_sequences = 0
            start_time = time.time()

            logger.info(f"Starting embedding of {total_sequences} sequences")

            # Handle sequences based on length
            long_seqs = {
                id: seq for id, seq in seqs.items() if len(seq) > self.max_seq_length
            }
            valid_seqs = {
                id: seq for id, seq in seqs.items() if len(seq) <= self.max_seq_length
            }

            if long_seqs:
                if self.skip_long_proteins:
                    logger.warning(
                        f"Skipping {len(long_seqs)} sequences exceeding max length {self.max_seq_length}: {', '.join(long_seqs.keys())}"
                    )
                else:
                    logger.info(
                        f"Processing {len(long_seqs)} long sequences in chunks: {', '.join(long_seqs.keys())}"
                    )
                    for i, (seq_id, seq) in enumerate(long_seqs.items(), 1):
                        logger.info(
                            f"Embedding long sequence {i}/{len(long_seqs)}: {seq_id}"
                        )
                        results[seq_id] = self.embed_big_prot(seq_id, seq)
                        self.cleanup_memory()

                        # Update progress
                        processed_sequences += 1
                        elapsed_time = time.time() - start_time
                        remaining_sequences = total_sequences - processed_sequences
                        avg_time_per_seq = (
                            elapsed_time / processed_sequences
                            if processed_sequences > 0
                            else 0
                        )
                        estimated_time_left = avg_time_per_seq * remaining_sequences

                        logger.info(
                            f"Progress: {processed_sequences}/{total_sequences} sequences ({processed_sequences/total_sequences*100:.1f}%) | "
                            f"Elapsed: {elapsed_time/60:.1f} min | "
                            f"Est. remaining: {estimated_time_left/60:.1f} min"
                        )

            # Split remaining sequences based on size
            large_seqs = {
                id: seq
                for id, seq in valid_seqs.items()
                if len(seq) > self.large_protein_threshold
            }
            small_seqs = {
                id: seq
                for id, seq in valid_seqs.items()
                if len(seq) <= self.large_protein_threshold
            }

            logger.info(
                f"Split into {len(large_seqs)} large and {len(small_seqs)} small sequences"
            )

            # Process large sequences individually
            for i, (seq_id, seq) in enumerate(large_seqs.items(), 1):
                logger.info(
                    f"Processing large sequence {i}/{len(large_seqs)}: {seq_id}"
                )
                batch = [(seq_id, seq, len(seq))]
                results.update(self.embed_batch(batch))
                self.cleanup_memory()  # Cleanup after each large sequence

                # Update progress
                processed_sequences += 1
                elapsed_time = time.time() - start_time
                remaining_sequences = total_sequences - processed_sequences
                avg_time_per_seq = (
                    elapsed_time / processed_sequences if processed_sequences > 0 else 0
                )
                estimated_time_left = avg_time_per_seq * remaining_sequences

                logger.info(
                    f"Progress: {processed_sequences}/{total_sequences} sequences ({processed_sequences/total_sequences*100:.1f}%) | "
                    f"Elapsed: {elapsed_time/60:.1f} min | "
                    f"Est. remaining: {estimated_time_left/60:.1f} min"
                )

            # Process small sequences in batches
            current_batch: List[Tuple[str, str, int]] = []
            current_size = 0
            small_batch_count = 0
            total_small_batches = (
                sum(len(seq) for seq in small_seqs.values())
                + self.batch_residue_limit
                - 1
            ) // self.batch_residue_limit

            # Sort sequences by length in descending order (reduces unnecessary padding --> speeds up embedding)
            small_seqs_sorted = sorted(
                small_seqs.items(), key=lambda x: len(x[1]), reverse=True
            )

            for seq_id, seq in small_seqs_sorted:
                seq_len = len(seq)

                if current_size + seq_len > self.batch_residue_limit:
                    if current_batch:
                        small_batch_count += 1
                        logger.info(
                            f"Processing small batch {small_batch_count}/{total_small_batches} with {len(current_batch)} sequences"
                        )
                        batch_results = self.embed_batch(current_batch)
                        results.update(batch_results)
                        self.cleanup_memory()

                        # Update progress
                        processed_sequences += len(current_batch)
                        elapsed_time = time.time() - start_time
                        remaining_sequences = total_sequences - processed_sequences
                        avg_time_per_seq = (
                            elapsed_time / processed_sequences
                            if processed_sequences > 0
                            else 0
                        )
                        estimated_time_left = avg_time_per_seq * remaining_sequences

                        logger.info(
                            f"Progress: {processed_sequences}/{total_sequences} sequences ({processed_sequences/total_sequences*100:.1f}%) | "
                            f"Elapsed: {elapsed_time/60:.1f} min | "
                            f"Est. remaining: {estimated_time_left/60:.1f} min"
                        )
                    current_batch = []
                    current_size = 0

                current_batch.append((seq_id, seq, seq_len))
                current_size += seq_len

            # Process remaining batch
            if current_batch:
                small_batch_count += 1
                logger.info(
                    f"Processing final small batch {small_batch_count}/{total_small_batches} with {len(current_batch)} sequences"
                )
                batch_results = self.embed_batch(current_batch)
                results.update(batch_results)

                # Update final progress
                processed_sequences += len(current_batch)
                elapsed_time = time.time() - start_time

                logger.info(
                    f"Completed embedding {processed_sequences}/{total_sequences} sequences in {elapsed_time/60:.1f} minutes"
                )

            return results

        finally:
            self.cleanup_memory(deep=True)

    def cleanup_memory(self, deep: bool = False):
        """Perform memory cleanup operations.

        Args:
            deep: If True, performs aggressive cleanup including model transfer
                and garbage collection. Takes longer but frees more memory.

        Note:
            Regular cleanup is performed based on cleanup_frequency.
            Deep cleanup is more thorough but takes longer.
        """
        self.processed_count += 1

        if deep or self.processed_count % self.cleanup_frequency == 0:
            logger.info(
                f"Performing memory cleanup after {self.processed_count} sequences"
            )
            if torch.cuda.is_available():
                before_mem = torch.cuda.memory_allocated() / 1e9

                torch.cuda.empty_cache()
                if deep:
                    self.model = self.model.cpu()
                    torch.cuda.empty_cache()
                    self.model = self.model.to(self.device)

                after_mem = torch.cuda.memory_allocated() / 1e9
                logger.info(
                    f"Memory cleaned up: {before_mem:.2f}GB -> {after_mem:.2f}GB"
                )

            if deep:
                import gc

                gc.collect()

    def run(self, input_file, output_file=None):
        """
        Run the embedding process on the input file.

        Args:
            input_file (str): Path to the input FASTA file.
            output_file (str, optional): Path to the output H5 file. If not provided,
                it will be generated from the input file name.

        Returns:
            str: Path to the output H5 file containing the embeddings.
        """
        logger.info(f"Loading sequences from {input_file}")
        start_time = time.time()
        sequences = clean_noncanonical_fasta_to_dict(input_file)
        load_time = time.time() - start_time
        logger.info(
            f"Loaded {len(sequences)} sequences from {input_file} in {load_time:.2f} seconds"
        )

        logger.info(f"Starting embedding process for {len(sequences)} sequences")
        embed_start_time = time.time()
        embeddings = self.get_embeddings(sequences)
        embed_time = time.time() - embed_start_time
        logger.info(
            f"Completed embedding {len(embeddings)} sequences in {embed_time/60:.2f} minutes"
        )

        if output_file is None:
            output_file = input_file.replace(".fasta", ".h5")

        logger.info(f"Saving embeddings to {output_file}")
        save_start_time = time.time()
        self.save_to_h5(output_file, embeddings)
        save_time = time.time() - save_start_time
        logger.info(
            f"Saved {len(embeddings)} embeddings to {output_file} in {save_time:.2f} seconds"
        )

        total_time = time.time() - start_time
        logger.info(f"Total processing time: {total_time/60:.2f} minutes")

        return output_file

    def save_to_h5(self, output_file: str, embeddings: Dict[str, np.ndarray]) -> None:
        """
        Save protein embeddings to an HDF5 file.

        Args:
            output_file (str): Path to save the embeddings.
            embeddings (Dict[str, np.ndarray]): Dictionary of embeddings.

        The method creates an H5 file with two datasets:
        - 'ids': contains protein IDs as variable-length strings
        - 'vectors': contains embedding vectors as float32 arrays
        """
        # Convert the embeddings dictionary to lists for ids and vectors
        ids = list(embeddings.keys())
        vectors = np.array(list(embeddings.values()), dtype=np.float32)

        # Create the HDF5 file, with datasets for vectors and IDs
        with h5py.File(output_file, "w") as h5f:
            # Create the 'vectors' dataset
            h5f.create_dataset("vectors", data=vectors)

            # Create the 'ids' dataset with variable-length strings
            dt = h5py.special_dtype(vlen=str)
            h5f.create_dataset("ids", data=ids, dtype=dt)

            # Add the attributes
            h5f.attrs["num_vecs"] = len(embeddings)
            h5f.attrs["vec_dim"] = vectors.shape[1] if vectors.size > 0 else 0

    def embed_big_prot(self, seq_id: str, sequence: str) -> np.ndarray:
        """Embed a large protein sequence by chunking it and averaging the embeddings.

        Args:
            seq_id: The identifier for the protein sequence
            sequence: The protein sequence to embed

        Returns:
            np.ndarray: The averaged embedding for the entire sequence

        Note:
            This method processes the sequence in chunks of size max_seq_length
            and averages the resulting embeddings.
        """
        if not isinstance(sequence, str):
            raise TypeError("Sequence must be a string.")

        if not sequence:
            raise ValueError("Sequence cannot be empty.")

        if self.max_seq_length <= 0:
            raise ValueError("max_seq_length must be greater than 0.")

        # Create chunks of the sequence using max_seq_length
        chunks: List[Tuple[str, str, int]] = [
            (
                seq_id,
                sequence[i : i + self.max_seq_length],
                min(self.max_seq_length, len(sequence) - i),
            )
            for i in range(0, len(sequence), self.max_seq_length)
        ]

        logger.info(
            f"Processing {seq_id} in {len(chunks)} chunks (total length: {len(sequence)})"
        )

        # Embed each chunk
        chunk_embeddings = []
        for i, chunk in enumerate(chunks, 1):
            logger.info(
                f"Processing chunk {i}/{len(chunks)} for {seq_id} (length: {chunk[2]})"
            )
            chunk_start_time = time.time()
            result = self.embed_batch([chunk])
            chunk_embeddings.append(result[seq_id])
            chunk_time = time.time() - chunk_start_time
            logger.info(
                f"Processed chunk {i}/{len(chunks)} for {seq_id} in {chunk_time:.2f} seconds"
            )

        # Average the embeddings
        average_embedding = np.mean(chunk_embeddings, axis=0)
        logger.info(f"Completed processing {seq_id} (averaged {len(chunks)} chunks)")

        return average_embedding

    def embed_batch(self, batch: List[Tuple[str, str, int]]) -> Dict[str, np.ndarray]:
        """
        Generate embeddings for a batch of sequences.

        Args:
            batch: A list of tuples, each containing (sequence_id, sequence, sequence_length)

        Returns:
            A dictionary mapping sequence IDs to their embeddings as numpy arrays
        """
        if not batch:
            raise ValueError(
                "Cannot embed an empty batch. Please provide at least one sequence."
            )

        sequence_ids, sequences, sequence_lengths = zip(*batch)

        # Prepare sequences for tokenization
        tokenizer_input = self.prepare_tokenizer_input(sequences)

        # Tokenize sequences
        encoded_input = self.tokenizer.batch_encode_plus(
            tokenizer_input,
            add_special_tokens=True,
            padding="longest",
            return_tensors="pt",
        )

        # Move tensors to the appropriate device
        input_ids = encoded_input["input_ids"].to(self.device)
        attention_mask = encoded_input["attention_mask"].to(self.device)

        # Generate embeddings
        with torch.no_grad():
            embedding_output = self.model(
                input_ids, attention_mask=attention_mask
            ).last_hidden_state

        # Process embeddings for each sequence
        embeddings = {}
        for idx, (seq_id, seq_len) in enumerate(zip(sequence_ids, sequence_lengths)):
            # Extract embedding for the sequence
            seq_embedding = self.extract_sequence_embedding(
                embedding_output[idx], seq_len
            )

            # Calculate mean embedding and convert to numpy array
            mean_embedding = seq_embedding.mean(dim=0).detach().cpu().numpy().squeeze()

            embeddings[seq_id] = mean_embedding

        return embeddings

    def prepare_tokenizer_input(self, sequences: List[str]) -> List[str]:
        """Prepare sequences for tokenization."""
        raise NotImplementedError

    def extract_sequence_embedding(
        self, embedding: torch.Tensor, seq_len: int
    ) -> torch.Tensor:
        """Extract the relevant part of the embedding for a sequence."""
        raise NotImplementedError

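# --- Editor's note: batching sketch (not part of the packaged module) --------
# Embedder.get_embeddings() packs small sequences greedily until the running
# residue count would exceed batch_residue_limit, then flushes the batch. A
# hypothetical, model-free illustration of that packing rule on toy inputs;
# the function name is editorial, not part of the package API.
def _example_residue_packing(
    seqs: Dict[str, str], batch_residue_limit: int = 5000
) -> List[List[Tuple[str, str, int]]]:
    batches: List[List[Tuple[str, str, int]]] = []
    current_batch: List[Tuple[str, str, int]] = []
    current_size = 0
    # Longest-first ordering mirrors get_embeddings(); it reduces padding waste.
    for seq_id, seq in sorted(seqs.items(), key=lambda x: len(x[1]), reverse=True):
        if current_size + len(seq) > batch_residue_limit and current_batch:
            batches.append(current_batch)
            current_batch, current_size = [], 0
        current_batch.append((seq_id, seq, len(seq)))
        current_size += len(seq)
    if current_batch:
        batches.append(current_batch)
    return batches
# ------------------------------------------------------------------------------
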
class ProstT5Embedder(Embedder):
    """Protein sequence embedder using the ProstT5 model.

    This class implements protein sequence embedding using the ProstT5 model,
    which is specifically trained for protein structure prediction tasks.
    It includes memory-efficient processing and automatic precision selection
    based on available hardware.

    Memory management features are inherited from the base Embedder class:
    - Periodic cleanup of GPU memory
    - Separate handling of large and small sequences
    - Batch size limits based on total residues
    - Configurable cleanup frequency
    """

    def __init__(
        self,
        max_seq_length: int = 5000,
        large_protein_threshold: int = 2500,
        batch_residue_limit: int = 5000,
        cleanup_frequency: int = 100,
        skip_long_proteins: bool = False,
    ):
        """Initialize ProstT5Embedder with memory management parameters.

        Args:
            max_seq_length: Maximum sequence length before chunking or skipping.
                Also used as chunk size when processing long sequences.
            large_protein_threshold: Sequences longer than this are processed individually
            batch_residue_limit: Maximum total residues in a batch
            cleanup_frequency: Frequency of memory cleanup operations
            skip_long_proteins: If True, skip proteins longer than max_seq_length

        Note:
            The model automatically selects half precision (float16) when running on GPU
            and full precision (float32) when running on CPU.
        """
        tokenizer = T5Tokenizer.from_pretrained(
            "Rostlab/ProstT5", do_lower_case=False, legacy=True
        )
        model = T5EncoderModel.from_pretrained("Rostlab/ProstT5")

        super().__init__(
            model,
            tokenizer,
            max_seq_length,
            large_protein_threshold,
            batch_residue_limit,
            cleanup_frequency,
            skip_long_proteins,
        )

        # Set precision based on device
        self.model = (
            self.model.half() if torch.cuda.is_available() else self.model.float()
        )

    def prepare_tokenizer_input(self, sequences: List[str]) -> List[str]:
        """Prepare sequences for ProstT5 tokenization.

        Args:
            sequences: List of amino acid sequences

        Returns:
            List of sequences with ProstT5-specific formatting, including the
            <AA2fold> prefix and space-separated residues.
        """
        return [f"<AA2fold> {' '.join(seq)}" for seq in sequences]

    def extract_sequence_embedding(
        self, embedding: torch.Tensor, seq_len: int
    ) -> torch.Tensor:
        """Extract relevant embeddings for a sequence.

        Args:
            embedding: Raw embedding tensor from the model
            seq_len: Length of the original sequence

        Returns:
            Tensor containing only the relevant sequence embeddings,
            excluding special tokens. For ProstT5, we skip the first token
            (corresponding to <AA2fold>) and take the next seq_len tokens.
        """
        return embedding[1 : seq_len + 1]

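# --- Editor's note: input-format sketch (not part of the packaged module) ----
# prepare_tokenizer_input() above rewrites a plain amino-acid string into the
# ProstT5 input format: an "<AA2fold>" prefix plus space-separated residues.
# A hypothetical, model-free check of that rule (function name is editorial):
def _example_prostt5_input_format() -> None:
    formatted = [f"<AA2fold> {' '.join(seq)}" for seq in ["MKT"]]
    assert formatted == ["<AA2fold> M K T"]
    # The leading <AA2fold> token is why extract_sequence_embedding() skips
    # embedding[0] and keeps the next seq_len positions.
# ------------------------------------------------------------------------------
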
class T5Embedder(Embedder):
    """Protein sequence embedder using the T5 transformer model.

    This class implements protein sequence embedding using the T5 model from Rostlab,
    specifically designed for protein sequences. It includes memory-efficient processing
    of both large and small sequences.

    The model used is 'Rostlab/prot_t5_xl_half_uniref50-enc', which was trained on
    UniRef50 sequences and provides state-of-the-art protein embeddings.

    Memory management features are inherited from the base Embedder class:
    - Periodic cleanup of GPU memory
    - Separate handling of large and small sequences
    - Batch size limits based on total residues
    - Configurable cleanup frequency
    """

    def __init__(
        self,
        max_seq_length: int = 5000,
        large_protein_threshold: int = 2500,
        batch_residue_limit: int = 5000,
        cleanup_frequency: int = 100,
        skip_long_proteins: bool = False,
    ):
        """Initialize T5Embedder with memory management parameters.

        Args:
            max_seq_length: Maximum sequence length before chunking or skipping.
                Also used as chunk size when processing long sequences.
            large_protein_threshold: Sequences longer than this are processed individually
            batch_residue_limit: Maximum total residues in a batch
            cleanup_frequency: Frequency of memory cleanup operations
            skip_long_proteins: If True, skip proteins longer than max_seq_length

        Note:
            The model automatically handles memory management and batch processing
            based on sequence sizes and available resources.
        """
        tokenizer = T5Tokenizer.from_pretrained(
            "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
        )
        model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")

        super().__init__(
            model,
            tokenizer,
            max_seq_length,
            large_protein_threshold,
            batch_residue_limit,
            cleanup_frequency,
            skip_long_proteins,
        )

    def prepare_tokenizer_input(self, sequences: List[str]) -> List[str]:
        """Prepare sequences for T5 tokenization.

        Args:
            sequences: List of amino acid sequences

        Returns:
            List of space-separated sequences ready for tokenization
        """
        return [" ".join(seq) for seq in sequences]

    def extract_sequence_embedding(
        self, embedding: torch.Tensor, seq_len: int
    ) -> torch.Tensor:
        """Extract relevant embeddings for a sequence.

        Args:
            embedding: Raw embedding tensor from the model
            seq_len: Length of the original sequence

        Returns:
            Tensor containing only the relevant sequence embeddings
        """
        return embedding[:seq_len]