dayhoff-tools 1.1.40__py3-none-any.whl → 1.1.42__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
@@ -374,6 +374,11 @@ def create_or_update_job_definition(
374
374
  "timeout": {"attemptDurationSeconds": aws_config.get("timeout_seconds", 86400)},
375
375
  }
376
376
 
377
+ # Add tags if specified in config
378
+ if "tags" in aws_config:
379
+ job_definition_args["tags"] = aws_config["tags"]
380
+ print(f"Adding tags to job definition: {aws_config['tags']}")
381
+
377
382
  # Register new revision using the session client
378
383
  response = batch.register_job_definition(**job_definition_args)
379
384
 
@@ -472,6 +477,11 @@ def submit_aws_batch_job(
472
477
  print(f"Setting retry attempts to {retry_attempts}")
473
478
  job_submit_args["retryStrategy"] = {"attempts": retry_attempts}
474
479
 
480
+ # Add tags if specified in config
481
+ if "tags" in aws_config:
482
+ job_submit_args["tags"] = aws_config["tags"]
483
+ print(f"Adding tags to batch job: {aws_config['tags']}")
484
+
475
485
  # Submit the job using the session client
476
486
  response = batch.submit_job(**job_submit_args)
477
487
 
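For reference, the new tags handling in deploy_aws.py simply forwards a mapping from the job config to AWS Batch. A minimal sketch of the relevant config, assuming aws_config is the dictionary loaded from the job's configuration (the tag keys and values below are illustrative):

    aws_config = {
        "timeout_seconds": 86400,
        "tags": {  # forwarded verbatim to AWS Batch
            "project": "example-project",
            "owner": "example-user",
        },
    }
    # create_or_update_job_definition() passes aws_config["tags"] through to
    # batch.register_job_definition(**job_definition_args), and submit_aws_batch_job()
    # passes it to batch.submit_job(**job_submit_args), so keys and values must be
    # valid AWS resource tag strings.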
@@ -1,9 +1,5 @@
1
- import csv
2
1
  import json
3
2
  import logging
4
- import os
5
- import shlex
6
- import shutil
7
3
  import subprocess
8
4
  from abc import ABC, abstractmethod
9
5
  from pathlib import Path
@@ -295,108 +291,3 @@ class InterProScanProcessor(Processor):
295
291
  cleaned_input_file_path.unlink()
296
292
 
297
293
  return str(chunk_output_dir)
298
-
299
-
300
- class BoltzPredictor(Processor):
301
- """Processor for running Boltz docking predictions.
302
-
303
- This class wraps the Boltz docking tool to predict protein structures
304
- from sequence data.
305
- """
306
-
307
- def __init__(self, num_workers: int, boltz_options: str | None = None):
308
- """Initialize the BoltzPredictor.
309
-
310
- Args:
311
- num_workers: Number of worker threads to use as a default.
312
- This can be overridden if --num_workers is present
313
- in boltz_options.
314
- boltz_options: A string containing additional command-line options
315
- to pass to the Boltz predictor. Options should be
316
- space-separated (e.g., "--option1 value1 --option2").
317
- """
318
- self.num_workers = num_workers
319
- self.boltz_options = boltz_options
320
-
321
- def run(self, input_file: str) -> str:
322
- """Run Boltz prediction on the input file.
323
-
324
- Constructs the command using the input file, default number of workers,
325
- and any additional options provided via `boltz_options`. If `--num_workers`
326
- is specified in `boltz_options`, it overrides the default `num_workers`.
327
-
328
- Args:
329
- input_file: Path to the input file containing sequences
330
-
331
- Returns:
332
- Path to the output directory created by Boltz
333
-
334
- Raises:
335
- subprocess.CalledProcessError: If Boltz prediction fails
336
- """
337
- # Determine expected output directory name
338
- input_base = os.path.splitext(os.path.basename(input_file))[0]
339
- expected_output_dir = f"boltz_results_{input_base}"
340
- logger.info(f"Expected output directory: {expected_output_dir}")
341
-
342
- # Start building the command
343
- cmd = ["boltz", "predict", input_file]
344
-
345
- # Parse additional options if provided
346
- additional_args = []
347
- num_workers_in_opts = False
348
- if self.boltz_options:
349
- try:
350
- parsed_opts = shlex.split(self.boltz_options)
351
- additional_args.extend(parsed_opts)
352
- if "--num_workers" in parsed_opts:
353
- num_workers_in_opts = True
354
- logger.info(
355
- f"Using --num_workers from BOLTZ_OPTIONS: {self.boltz_options}"
356
- )
357
- except ValueError as e:
358
- logger.error(f"Error parsing BOLTZ_OPTIONS '{self.boltz_options}': {e}")
359
- # Decide if we should raise an error or proceed without options
360
- # For now, proceed without the additional options
361
- additional_args = [] # Clear potentially partially parsed args
362
-
363
- # Add num_workers if not specified in options
364
- if not num_workers_in_opts:
365
- logger.info(f"Using default num_workers: {self.num_workers}")
366
- cmd.extend(["--num_workers", str(self.num_workers)])
367
-
368
- # Add the parsed additional arguments
369
- cmd.extend(additional_args)
370
-
371
- # Log the final command
372
- # Use shlex.join for safer command logging, especially if paths/args have spaces
373
- try:
374
- safe_cmd_str = shlex.join(cmd)
375
- logger.info(f"Running command: {safe_cmd_str}")
376
- except AttributeError: # shlex.join is Python 3.8+
377
- logger.info(f"Running command: {' '.join(cmd)}")
378
-
379
- # Stream output in real-time
380
- process = subprocess.Popen(
381
- cmd,
382
- stdout=subprocess.PIPE,
383
- stderr=subprocess.STDOUT,
384
- text=True,
385
- bufsize=1,
386
- )
387
-
388
- stdout = process.stdout
389
- if stdout:
390
- for line in iter(stdout.readline, ""):
391
- logger.info(f"BOLTZ: {line.rstrip()}")
392
-
393
- # Wait for process to complete
394
- return_code = process.wait()
395
- if return_code != 0:
396
- logger.error(f"Boltz prediction failed with exit code {return_code}")
397
- raise subprocess.CalledProcessError(return_code, cmd)
398
-
399
- logger.info(
400
- f"Boltz prediction completed successfully. Output in {expected_output_dir}"
401
- )
402
- return expected_output_dir
@@ -0,0 +1,895 @@
1
+ import logging
2
+ import os
3
+ import time
4
+ from typing import Dict, List, Literal, Optional, Tuple, cast
5
+
6
+ import h5py
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ import torch.utils.data
11
+ from dayhoff_tools.deployment.processors import Processor
12
+ from dayhoff_tools.fasta import (
13
+ clean_noncanonical_fasta,
14
+ clean_noncanonical_fasta_to_dict,
15
+ )
16
+ from esm import FastaBatchedDataset, pretrained
17
+ from transformers import T5EncoderModel, T5Tokenizer
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class ESMEmbedder(Processor):
23
+ """A processor that calculates ESM embeddings for a file of protein sequences.
24
+ Embeddings come from the last layer, and can be per-protein or per-residue."""
25
+
26
+ def __init__(
27
+ self,
28
+ # PyTorch model file OR name of pretrained model to download
29
+ # https://github.com/facebookresearch/esm?tab=readme-ov-file#available
30
+ model_name: Literal[
31
+ "esm2_t33_650M_UR50D", # ESM2 version of the size used in CLEAN
32
+ "esm2_t6_8M_UR50D", # Smallest
33
+ "esm1b_t33_650M_UR50S", # Same as CLEAN
34
+ "esm2_t36_3B_UR50D", # 2nd largest
35
+ "esm2_t48_15B_UR50D", # Largest
36
+ ],
37
+ # Whether to return per-protein or per-residue embeddings.
38
+ embedding_level: Literal["protein", "residue"],
39
+ # Maximum number of tokens per batch
40
+ toks_per_batch: int = 4096,
41
+ # Truncate sequences longer than the given value
42
+ truncation_seq_length: int = 1022,
43
+ ):
44
+ super().__init__()
45
+ self.model_name = model_name
46
+ self.toks_per_batch = toks_per_batch
47
+ self.embedding_level = embedding_level
48
+ self.truncation_seq_length = truncation_seq_length
49
+
50
+ # Instance variables set by other methods below:
51
+ # self.model, self.alphabet, self.len_batches, self.data_loader, self.dataset_base_name
52
+
53
+ def _load_model(self):
54
+ """Download pre-trained model and load onto device"""
55
+ self.model, self.alphabet = pretrained.load_model_and_alphabet(self.model_name)
56
+ self.model.eval()
57
+ if torch.cuda.is_available():
58
+ self.model.cuda()
59
+ logger.info("Transferred model to GPU.")
60
+ else:
61
+ logger.info("GPU not available. Running model on CPU.")
62
+
63
+ def _load_dataset(self, fasta_file: str) -> None:
64
+ """Load FASTA file into batched dataset and dataloader"""
65
+ if not fasta_file.endswith(".fasta"):
66
+ raise ValueError("Input file must have .fasta extension.")
67
+
68
+ self.dataset_base_name = fasta_file.replace(".fasta", "")
69
+ clean_fasta_file = fasta_file.replace(".fasta", "_clean.fasta")
70
+ clean_noncanonical_fasta(input_path=fasta_file, output_path=clean_fasta_file)
71
+
72
+ dataset = FastaBatchedDataset.from_file(clean_fasta_file)
73
+ logger.info("Read %s and loaded %s sequences.", fasta_file, len(dataset))
74
+
75
+ batches = dataset.get_batch_indices(self.toks_per_batch, extra_toks_per_seq=1)
76
+ self.len_batches = len(batches)
77
+ self.data_loader = torch.utils.data.DataLoader(
78
+ dataset, # type: ignore
79
+ collate_fn=self.alphabet.get_batch_converter(self.truncation_seq_length),
80
+ batch_sampler=batches,
81
+ )
82
+ os.remove(clean_fasta_file)
83
+
84
+ def embed_fasta(self) -> str:
85
+ """Calculate embeddings from the FASTA file, return the path to the .h5 file of results
86
+ Write the H5 file with one dataset per protein (id plus embedding vector, where the vector
87
+ is 2D if it was calculated per-residue).
88
+ """
89
+ output_path = self.dataset_base_name + ".h5"
90
+ with h5py.File(output_path, "w") as h5_file, torch.no_grad():
91
+ start_time = time.time()
92
+ logger.info(
93
+ f"Calculating per-{self.embedding_level} embeddings. This dataset contains {self.len_batches} batches."
94
+ )
95
+ total_batches = self.len_batches
96
+
97
+ for batch_idx, (labels, sequences, toks) in enumerate(
98
+ self.data_loader, start=1
99
+ ):
100
+ if batch_idx % 10 == 0:
101
+ elapsed_time = time.time() - start_time
102
+ time_left = elapsed_time * (total_batches - batch_idx) / batch_idx
103
+ logger.info(
104
+ f"{self.dataset_base_name} | Batch {batch_idx}/{total_batches} | Elapsed time {elapsed_time / 60:.0f} min | Time left {time_left / 60:.0f} min, or {time_left / 3_600:.2f} hours"
105
+ )
106
+
107
+ if torch.cuda.is_available():
108
+ toks = toks.to(device="cuda", non_blocking=True)
109
+ out = self.model(
110
+ toks,
111
+ repr_layers=[
112
+ cast(int, self.model.num_layers)
113
+ ], # Get the last layer
114
+ return_contacts=False,
115
+ )
116
+ out["logits"].to(device="cpu")
117
+ representations = {
118
+ layer: t.to(device="cpu")
119
+ for layer, t in out["representations"].items()
120
+ }
121
+ for label_idx, label_full in enumerate(labels):
122
+ label = label_full.split()[
123
+ 0
124
+ ] # Shorten the label to only the first word
125
+ truncate_len = min(
126
+ self.truncation_seq_length, len(sequences[label_idx])
127
+ )
128
+
129
+ full_embeds = list(representations.items())[0][1]
130
+ if self.embedding_level == "protein":
131
+ protein_embeds = (
132
+ full_embeds[label_idx, 1 : truncate_len + 1].mean(0).clone()
133
+ )
134
+ h5_file.create_dataset(label, data=protein_embeds)
135
+ elif self.embedding_level == "residue":
136
+ residue_embeds = full_embeds[
137
+ label_idx, 1 : truncate_len + 1
138
+ ].clone()
139
+ h5_file.create_dataset(label, data=residue_embeds)
140
+ else:
141
+ raise NotImplementedError(
142
+ f"Embedding level {self.embedding_level} not implemented."
143
+ )
144
+
145
+ logger.info(f"Saved embeddings to {output_path}")
146
+ return output_path
147
+
148
+ def run(self, input_file: str, output_file: Optional[str] = None) -> str:
149
+ self._load_model()
150
+ self._load_dataset(input_file)
151
+ embedding_file = self.embed_fasta()
152
+
153
+ # If embedding per-protein, reformat the H5 file.
154
+ # By default, write to the same filename; otherwise to output_file.
155
+ reformatted_embedding_file = output_file or embedding_file
156
+ if self.embedding_level == "protein":
157
+ print("Reformatting H5 file...")
158
+ formatter = H5Reformatter()
159
+ formatter.run2(
160
+ input_file=embedding_file,
161
+ output_file=reformatted_embedding_file,
162
+ )
163
+
164
+ return reformatted_embedding_file
165
+
166
+
167
+ class H5Reformatter(Processor):
168
+ """A processor that reformats per-protein T5 embeddings
169
+ Old format (input): 1 dataset per protein, with the ID as key and the embedding as value.
170
+ New format (output): 2 datasets for the whole file, one of all protein IDs and one of all
171
+ the embeddings together."""
172
+
173
+ def __init__(self):
174
+ super().__init__()
175
+ # Set the device to GPU if available, otherwise CPU
176
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
177
+ logger.info("Device is: %s", self.device)
178
+
179
+ def embedding_file_to_df(self, file_name: str) -> pd.DataFrame:
180
+ with h5py.File(file_name, "r") as f:
181
+ gene_names = list(f.keys())
182
+ Xg = [f[key][()] for key in gene_names] # type:ignore
183
+ return pd.DataFrame(np.asmatrix(Xg), index=gene_names) # type:ignore
184
+
185
+ def write_df_to_h5(self, df: pd.DataFrame, filename: str, description: str) -> None:
186
+ """
187
+ Write a DataFrame to an HDF5 file, separating row IDs and vectors.
188
+
189
+ Parameters:
190
+ - df: pandas DataFrame, where the index contains the IDs and
191
+ the columns contain the vector components.
192
+ - filename: String, the path to the output HDF5 file.
+ - description: String, stored as the file's "description" attribute.
193
+ """
194
+ df.index = df.index.astype(
195
+ str
196
+ ) # Ensure the index is of a string type for the row IDs
197
+ vectors = df.values
198
+ ids = df.index.to_numpy(dtype=str)
199
+
200
+ with h5py.File(filename, "w") as h5f:
201
+ h5f.create_dataset("vectors", data=vectors)
202
+ dt = h5py.special_dtype(
203
+ vlen=str
204
+ ) # Use variable-length strings to accommodate any ID size
205
+ h5f.create_dataset("ids", data=ids.astype("S"), dtype=dt)
206
+
207
+ # add the attributes
208
+ h5f.attrs["description"] = description
209
+ h5f.attrs["num_vecs"] = vectors.shape[0]
210
+ h5f.attrs["vec_dim"] = vectors.shape[1]
211
+
212
+ def run(self, input_file: str) -> str:
213
+ """Load an H5 file as a DataFrame, delete the file, and then export
214
+ the dataframe as an H5 file in the new format."""
215
+ df = self.embedding_file_to_df(input_file)
216
+ os.remove(input_file)
217
+ new_file_description = "Embeddings formatted for global ID and Vector tables."
218
+ self.write_df_to_h5(df, input_file, new_file_description)
219
+
220
+ return input_file
221
+
222
+ def run2(self, input_file: str, output_file: str):
223
+ """Load an H5 file as a DataFrame, delete the file, and then export
224
+ the dataframe as an H5 file in the new format."""
225
+ df = self.embedding_file_to_df(input_file)
226
+ new_file_description = "Embeddings formatted for global ID and Vector tables."
227
+ self.write_df_to_h5(df, output_file, new_file_description)
228
+
229
+
230
+ class Embedder(Processor):
231
+ """Base class for protein sequence embedders with optimized memory management.
232
+
233
+ This class provides the core functionality for embedding protein sequences using
234
+ transformer models, with built-in memory management and batch processing capabilities.
235
+ It handles sequences of different sizes appropriately, processing large sequences
236
+ individually and smaller sequences in batches for efficiency.
237
+
238
+ Memory Management Features:
239
+ - Periodic cleanup of GPU memory to prevent fragmentation
240
+ - Separate handling of large and small sequences to optimize memory usage
241
+ - Batch size limits based on total residues to prevent OOM errors
242
+ - Configurable cleanup frequency to balance performance and memory usage
243
+ - Empirically tested sequence length limits (5000-5500 residues depending on model)
244
+
245
+ Memory Fragmentation Prevention:
246
+ - Large sequences (>2500 residues) are processed individually to maintain contiguous memory blocks
247
+ - Small sequences are batched to efficiently use memory fragments
248
+ - Forced cleanup after processing large proteins
249
+ - Memory cleanup after every N sequences (configurable)
250
+ - Aggressive garbage collection settings for CUDA memory allocator
251
+
252
+ Memory Usage Patterns:
253
+ - Base memory: 2.4-4.8GB (model dependent)
254
+ - Peak memory: 12-15GB during large sequence processing
255
+ - Fragmentation ratio maintained above 0.92 for efficient memory use
256
+ - Maximum sequence length determined by model:
257
+ * T5: ~5000 residues
258
+ * ProstT5: ~5500 residues
259
+
260
+ Implementation Details:
261
+ - Uses PyTorch's CUDA memory allocator with optimized settings
262
+ - Configurable thresholds for large protein handling
263
+ - Automatic batch size adjustment based on sequence lengths
264
+ - Optional chunking for sequences exceeding maximum length
265
+ - Detailed memory statistics logging for monitoring
266
+
267
+ Note:
268
+ Memory limits are hardware-dependent. The above values are based on testing
269
+ with a 16GB GPU (such as NVIDIA T4). Adjust parameters based on available GPU memory.
270
+ """
271
+
272
+ def __init__(
273
+ self,
274
+ model,
275
+ tokenizer,
276
+ max_seq_length: int = 5000,
277
+ large_protein_threshold: int = 2500,
278
+ batch_residue_limit: int = 5000,
279
+ cleanup_frequency: int = 100,
280
+ skip_long_proteins: bool = False,
281
+ ):
282
+ """Initialize the Embedder with model and memory management parameters.
283
+
284
+ Args:
285
+ model: The transformer model to use for embeddings
286
+ tokenizer: The tokenizer matching the model
287
+ max_seq_length: Maximum sequence length before chunking or skipping.
288
+ Also used as chunk size when processing long sequences.
289
+ large_protein_threshold: Sequences longer than this are processed individually
290
+ batch_residue_limit: Maximum total residues allowed in a batch
291
+ cleanup_frequency: Number of sequences to process before performing memory cleanup
292
+ skip_long_proteins: If True, skip proteins longer than max_seq_length.
293
+ If False, process them in chunks of max_seq_length size.
294
+
295
+ Note:
296
+ The class automatically configures CUDA memory allocation if GPU is available.
297
+ """
298
+ # Configure memory allocator for CUDA
299
+ if torch.cuda.is_available():
300
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
301
+ "max_split_size_mb:128," # Smaller allocation chunks
302
+ "garbage_collection_threshold:0.8" # More aggressive GC
303
+ )
304
+
305
+ self.max_seq_length = max_seq_length
306
+ self.large_protein_threshold = large_protein_threshold
307
+ self.batch_residue_limit = batch_residue_limit
308
+ self.cleanup_frequency = cleanup_frequency
309
+ self.skip_long_proteins = skip_long_proteins
310
+ self.processed_count = 0
311
+
312
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
313
+ self.model = model
314
+ self.tokenizer = tokenizer
315
+
316
+ self.model.to(self.device)
317
+ self.model.eval()
318
+
319
+ def get_embeddings(self, seqs: Dict[str, str]) -> Dict[str, np.ndarray]:
320
+ """Process sequences and generate embeddings with memory management.
321
+
322
+ This method handles the core embedding logic, including:
323
+ - Handling sequences that exceed maximum length (skip or chunk)
324
+ - Splitting sequences into large and small batches
325
+ - Periodic memory cleanup
326
+ - Batch processing for efficiency
327
+
328
+ Args:
329
+ seqs: Dictionary mapping sequence IDs to their amino acid sequences
330
+
331
+ Returns:
332
+ Dictionary mapping sequence IDs to their embedding vectors
333
+
334
+ Note:
335
+ Long sequences are either skipped or processed in chunks based on skip_long_proteins
336
+ """
337
+ results: Dict[str, np.ndarray] = {}
338
+ try:
339
+ # Initialize progress tracking
340
+ total_sequences = len(seqs)
341
+ processed_sequences = 0
342
+ start_time = time.time()
343
+
344
+ logger.info(f"Starting embedding of {total_sequences} sequences")
345
+
346
+ # Handle sequences based on length
347
+ long_seqs = {
348
+ id: seq for id, seq in seqs.items() if len(seq) > self.max_seq_length
349
+ }
350
+ valid_seqs = {
351
+ id: seq for id, seq in seqs.items() if len(seq) <= self.max_seq_length
352
+ }
353
+
354
+ if long_seqs:
355
+ if self.skip_long_proteins:
356
+ logger.warning(
357
+ f"Skipping {len(long_seqs)} sequences exceeding max length {self.max_seq_length}: {', '.join(long_seqs.keys())}"
358
+ )
359
+ else:
360
+ logger.info(
361
+ f"Processing {len(long_seqs)} long sequences in chunks: {', '.join(long_seqs.keys())}"
362
+ )
363
+ for i, (seq_id, seq) in enumerate(long_seqs.items(), 1):
364
+ logger.info(
365
+ f"Embedding long sequence {i}/{len(long_seqs)}: {seq_id}"
366
+ )
367
+ results[seq_id] = self.embed_big_prot(seq_id, seq)
368
+ self.cleanup_memory()
369
+
370
+ # Update progress
371
+ processed_sequences += 1
372
+ elapsed_time = time.time() - start_time
373
+ remaining_sequences = total_sequences - processed_sequences
374
+ avg_time_per_seq = (
375
+ elapsed_time / processed_sequences
376
+ if processed_sequences > 0
377
+ else 0
378
+ )
379
+ estimated_time_left = avg_time_per_seq * remaining_sequences
380
+
381
+ logger.info(
382
+ f"Progress: {processed_sequences}/{total_sequences} sequences ({processed_sequences/total_sequences*100:.1f}%) | "
383
+ f"Elapsed: {elapsed_time/60:.1f} min | "
384
+ f"Est. remaining: {estimated_time_left/60:.1f} min"
385
+ )
386
+
387
+ # Split remaining sequences based on size
388
+ large_seqs = {
389
+ id: seq
390
+ for id, seq in valid_seqs.items()
391
+ if len(seq) > self.large_protein_threshold
392
+ }
393
+ small_seqs = {
394
+ id: seq
395
+ for id, seq in valid_seqs.items()
396
+ if len(seq) <= self.large_protein_threshold
397
+ }
398
+
399
+ logger.info(
400
+ f"Split into {len(large_seqs)} large and {len(small_seqs)} small sequences"
401
+ )
402
+
403
+ # Process large sequences individually
404
+ for i, (seq_id, seq) in enumerate(large_seqs.items(), 1):
405
+ logger.info(
406
+ f"Processing large sequence {i}/{len(large_seqs)}: {seq_id}"
407
+ )
408
+ batch = [(seq_id, seq, len(seq))]
409
+ results.update(self.embed_batch(batch))
410
+ self.cleanup_memory() # Cleanup after each large sequence
411
+
412
+ # Update progress
413
+ processed_sequences += 1
414
+ elapsed_time = time.time() - start_time
415
+ remaining_sequences = total_sequences - processed_sequences
416
+ avg_time_per_seq = (
417
+ elapsed_time / processed_sequences if processed_sequences > 0 else 0
418
+ )
419
+ estimated_time_left = avg_time_per_seq * remaining_sequences
420
+
421
+ logger.info(
422
+ f"Progress: {processed_sequences}/{total_sequences} sequences ({processed_sequences/total_sequences*100:.1f}%) | "
423
+ f"Elapsed: {elapsed_time/60:.1f} min | "
424
+ f"Est. remaining: {estimated_time_left/60:.1f} min"
425
+ )
426
+
427
+ # Process small sequences in batches
428
+ current_batch: List[Tuple[str, str, int]] = []
429
+ current_size = 0
430
+ small_batch_count = 0
431
+ total_small_batches = (
432
+ sum(len(seq) for seq in small_seqs.values())
433
+ + self.batch_residue_limit
434
+ - 1
435
+ ) // self.batch_residue_limit
436
+
437
+ # Sort sequences by length in descending order (reduces unnecessary padding --> speeds up embedding)
438
+ small_seqs_sorted = sorted(
439
+ small_seqs.items(), key=lambda x: len(x[1]), reverse=True
440
+ )
441
+
442
+ for seq_id, seq in small_seqs_sorted:
443
+ seq_len = len(seq)
444
+
445
+ # Check if adding this sequence would exceed the limit
446
+ if current_batch and current_size + seq_len > self.batch_residue_limit:
447
+ # Process current batch before adding the new sequence
448
+ small_batch_count += 1
449
+ logger.info(
450
+ f"Processing small batch {small_batch_count}/{total_small_batches} with {len(current_batch)} sequences"
451
+ )
452
+ batch_results = self.embed_batch(current_batch)
453
+ results.update(batch_results)
454
+ self.cleanup_memory()
455
+
456
+ # Update progress
457
+ processed_sequences += len(current_batch)
458
+ elapsed_time = time.time() - start_time
459
+ remaining_sequences = total_sequences - processed_sequences
460
+ avg_time_per_seq = (
461
+ elapsed_time / processed_sequences
462
+ if processed_sequences > 0
463
+ else 0
464
+ )
465
+ estimated_time_left = avg_time_per_seq * remaining_sequences
466
+
467
+ logger.info(
468
+ f"Progress: {processed_sequences}/{total_sequences} sequences ({processed_sequences/total_sequences*100:.1f}%) | "
469
+ f"Elapsed: {elapsed_time/60:.1f} min | "
470
+ f"Est. remaining: {estimated_time_left/60:.1f} min"
471
+ )
472
+ # Start new batch
473
+ current_batch = []
474
+ current_size = 0
475
+
476
+ # Add the current sequence to the batch
477
+ current_batch.append((seq_id, seq, seq_len))
478
+ current_size += seq_len
479
+
480
+ # Process remaining batch
481
+ if current_batch:
482
+ small_batch_count += 1
483
+ logger.info(
484
+ f"Processing final small batch {small_batch_count}/{total_small_batches} with {len(current_batch)} sequences"
485
+ )
486
+ batch_results = self.embed_batch(current_batch)
487
+ results.update(batch_results)
488
+
489
+ # Update final progress
490
+ processed_sequences += len(current_batch)
491
+ elapsed_time = time.time() - start_time
492
+
493
+ logger.info(
494
+ f"Completed embedding {processed_sequences}/{total_sequences} sequences in {elapsed_time/60:.1f} minutes"
495
+ )
496
+
497
+ return results
498
+
499
+ finally:
500
+ self.cleanup_memory(deep=True)
501
+
502
+ def cleanup_memory(self, deep: bool = False):
503
+ """Perform memory cleanup operations.
504
+
505
+ Args:
506
+ deep: If True, performs aggressive cleanup including model transfer
507
+ and garbage collection. Takes longer but frees more memory.
508
+
509
+ Note:
510
+ Regular cleanup is performed based on cleanup_frequency.
511
+ Deep cleanup is more thorough but takes longer.
512
+ """
513
+ self.processed_count += 1
514
+
515
+ if deep or self.processed_count % self.cleanup_frequency == 0:
516
+ logger.info(
517
+ f"Performing memory cleanup after {self.processed_count} sequences"
518
+ )
519
+ if torch.cuda.is_available():
520
+ before_mem = torch.cuda.memory_allocated() / 1e9
521
+
522
+ torch.cuda.empty_cache()
523
+ if deep:
524
+ self.model = self.model.cpu()
525
+ torch.cuda.empty_cache()
526
+ self.model = self.model.to(self.device)
527
+
528
+ after_mem = torch.cuda.memory_allocated() / 1e9
529
+ logger.info(
530
+ f"Memory cleaned up: {before_mem:.2f}GB -> {after_mem:.2f}GB"
531
+ )
532
+
533
+ if deep:
534
+ import gc
535
+
536
+ gc.collect()
537
+
538
+ def run(self, input_file, output_file=None):
539
+ """
540
+ Run the embedding process on the input file.
541
+
542
+ Args:
543
+ input_file (str): Path to the input FASTA file.
544
+ output_file (str, optional): Path to the output H5 file. If not provided,
545
+ it will be generated from the input file name.
546
+
547
+ Returns:
548
+ str: Path to the output H5 file containing the embeddings.
549
+ """
550
+ logger.info(f"Loading sequences from {input_file}")
551
+ start_time = time.time()
552
+ sequences = clean_noncanonical_fasta_to_dict(input_file)
553
+ load_time = time.time() - start_time
554
+ logger.info(
555
+ f"Loaded {len(sequences)} sequences from {input_file} in {load_time:.2f} seconds"
556
+ )
557
+
558
+ logger.info(f"Starting embedding process for {len(sequences)} sequences")
559
+ embed_start_time = time.time()
560
+ embeddings = self.get_embeddings(sequences)
561
+ embed_time = time.time() - embed_start_time
562
+ logger.info(
563
+ f"Completed embedding {len(embeddings)} sequences in {embed_time/60:.2f} minutes"
564
+ )
565
+
566
+ if output_file is None:
567
+ output_file = input_file.replace(".fasta", ".h5")
568
+
569
+ logger.info(f"Saving embeddings to {output_file}")
570
+ save_start_time = time.time()
571
+ self.save_to_h5(output_file, embeddings)
572
+ save_time = time.time() - save_start_time
573
+ logger.info(
574
+ f"Saved {len(embeddings)} embeddings to {output_file} in {save_time:.2f} seconds"
575
+ )
576
+
577
+ total_time = time.time() - start_time
578
+ logger.info(f"Total processing time: {total_time/60:.2f} minutes")
579
+
580
+ return output_file
581
+
582
+ def save_to_h5(self, output_file: str, embeddings: Dict[str, np.ndarray]) -> None:
583
+ """
584
+ Save protein embeddings to an HDF5 file.
585
+
586
+ Args:
587
+ output_file (str): Path to save the embeddings.
588
+ embeddings (Dict[str, np.ndarray]): Dictionary of embeddings.
589
+
590
+ The method creates an H5 file with two datasets:
591
+ - 'ids': contains protein IDs as variable-length strings
592
+ - 'vectors': contains embedding vectors as float32 arrays
593
+ """
594
+ # Convert the embeddings dictionary to lists for ids and vectors
595
+ ids = list(embeddings.keys())
596
+ vectors = np.array(list(embeddings.values()), dtype=np.float32)
597
+
598
+ # Create the HDF5 file, with datasets for vectors and IDs
599
+ with h5py.File(output_file, "w") as h5f:
600
+ # Create the 'vectors' dataset
601
+ h5f.create_dataset("vectors", data=vectors)
602
+
603
+ # Create the 'ids' dataset with variable-length strings
604
+ dt = h5py.special_dtype(vlen=str)
605
+ h5f.create_dataset("ids", data=ids, dtype=dt)
606
+
607
+ # Add the attributes
608
+ h5f.attrs["num_vecs"] = len(embeddings)
609
+ h5f.attrs["vec_dim"] = vectors.shape[1] if vectors.size > 0 else 0
610
+
611
+ def embed_big_prot(self, seq_id: str, sequence: str) -> np.ndarray:
612
+ """Embed a large protein sequence by chunking it and averaging the embeddings.
613
+
614
+ Args:
615
+ seq_id: The identifier for the protein sequence
616
+ sequence: The protein sequence to embed
617
+
618
+ Returns:
619
+ np.ndarray: The averaged embedding for the entire sequence
620
+
621
+ Note:
622
+ This method processes the sequence in chunks of size max_seq_length
623
+ and averages the resulting embeddings.
624
+ """
625
+ if not isinstance(sequence, str):
626
+ raise TypeError("Sequence must be a string.")
627
+
628
+ if not sequence:
629
+ raise ValueError("Sequence cannot be empty.")
630
+
631
+ if self.max_seq_length <= 0:
632
+ raise ValueError("max_seq_length must be greater than 0.")
633
+
634
+ # Create chunks of the sequence using max_seq_length
635
+ chunks: List[Tuple[str, str, int]] = [
636
+ (
637
+ seq_id,
638
+ sequence[i : i + self.max_seq_length],
639
+ min(self.max_seq_length, len(sequence) - i),
640
+ )
641
+ for i in range(0, len(sequence), self.max_seq_length)
642
+ ]
643
+
644
+ logger.info(
645
+ f"Processing {seq_id} in {len(chunks)} chunks (total length: {len(sequence)})"
646
+ )
647
+
648
+ # Embed each chunk
649
+ chunk_embeddings = []
650
+ for i, chunk in enumerate(chunks, 1):
651
+ logger.info(
652
+ f"Processing chunk {i}/{len(chunks)} for {seq_id} (length: {chunk[2]})"
653
+ )
654
+ chunk_start_time = time.time()
655
+ result = self.embed_batch([chunk])
656
+ chunk_embeddings.append(result[seq_id])
657
+ chunk_time = time.time() - chunk_start_time
658
+ logger.info(
659
+ f"Processed chunk {i}/{len(chunks)} for {seq_id} in {chunk_time:.2f} seconds"
660
+ )
661
+
662
+ # Average the embeddings
663
+ average_embedding = np.mean(chunk_embeddings, axis=0)
664
+ logger.info(f"Completed processing {seq_id} (averaged {len(chunks)} chunks)")
665
+
666
+ return average_embedding
667
+
668
+ def embed_batch(self, batch: List[Tuple[str, str, int]]) -> Dict[str, np.ndarray]:
669
+ """
670
+ Generate embeddings for a batch of sequences.
671
+
672
+ Args:
673
+ batch: A list of tuples, each containing (sequence_id, sequence, sequence_length)
674
+
675
+ Returns:
676
+ A dictionary mapping sequence IDs to their embeddings as numpy arrays
677
+ """
678
+ if not batch:
679
+ raise ValueError(
680
+ "Cannot embed an empty batch. Please provide at least one sequence."
681
+ )
682
+
683
+ sequence_ids, sequences, sequence_lengths = zip(*batch)
684
+
685
+ # Prepare sequences for tokenization
686
+ tokenizer_input = self.prepare_tokenizer_input(sequences)
687
+
688
+ # Tokenize sequences
689
+ encoded_input = self.tokenizer.batch_encode_plus(
690
+ tokenizer_input,
691
+ add_special_tokens=True,
692
+ padding="longest",
693
+ return_tensors="pt",
694
+ )
695
+
696
+ # Move tensors to the appropriate device
697
+ input_ids = encoded_input["input_ids"].to(self.device)
698
+ attention_mask = encoded_input["attention_mask"].to(self.device)
699
+
700
+ # Generate embeddings
701
+ with torch.no_grad():
702
+ embedding_output = self.model(
703
+ input_ids, attention_mask=attention_mask
704
+ ).last_hidden_state
705
+
706
+ # Process embeddings for each sequence
707
+ embeddings = {}
708
+ for idx, (seq_id, seq_len) in enumerate(zip(sequence_ids, sequence_lengths)):
709
+ # Extract embedding for the sequence
710
+ seq_embedding = self.extract_sequence_embedding(
711
+ embedding_output[idx], seq_len
712
+ )
713
+
714
+ # Calculate mean embedding and convert to numpy array
715
+ mean_embedding = seq_embedding.mean(dim=0).detach().cpu().numpy().squeeze()
716
+
717
+ embeddings[seq_id] = mean_embedding
718
+
719
+ return embeddings
720
+
721
+ def prepare_tokenizer_input(self, sequences: List[str]) -> List[str]:
722
+ """Prepare sequences for tokenization."""
723
+ raise NotImplementedError
724
+
725
+ def extract_sequence_embedding(
726
+ self, embedding: torch.Tensor, seq_len: int
727
+ ) -> torch.Tensor:
728
+ """Extract the relevant part of the embedding for a sequence."""
729
+ raise NotImplementedError
730
+
731
+
732
+ class ProstT5Embedder(Embedder):
733
+ """Protein sequence embedder using the ProstT5 model.
734
+
735
+ This class implements protein sequence embedding using the ProstT5 model,
736
+ which is specifically trained for protein structure prediction tasks.
737
+ It includes memory-efficient processing and automatic precision selection
738
+ based on available hardware.
739
+
740
+ Memory management features are inherited from the base Embedder class:
741
+ - Periodic cleanup of GPU memory
742
+ - Separate handling of large and small sequences
743
+ - Batch size limits based on total residues
744
+ - Configurable cleanup frequency
745
+ """
746
+
747
+ def __init__(
748
+ self,
749
+ max_seq_length: int = 5000,
750
+ large_protein_threshold: int = 2500,
751
+ batch_residue_limit: int = 5000,
752
+ cleanup_frequency: int = 100,
753
+ skip_long_proteins: bool = False,
754
+ ):
755
+ """Initialize ProstT5Embedder with memory management parameters.
756
+
757
+ Args:
758
+ max_seq_length: Maximum sequence length before chunking or skipping.
759
+ Also used as chunk size when processing long sequences.
760
+ large_protein_threshold: Sequences longer than this are processed individually
761
+ batch_residue_limit: Maximum total residues in a batch
762
+ cleanup_frequency: Frequency of memory cleanup operations
763
+ skip_long_proteins: If True, skip proteins longer than max_seq_length
764
+
765
+ Note:
766
+ The model automatically selects half precision (float16) when running on GPU
767
+ and full precision (float32) when running on CPU.
768
+ """
769
+ tokenizer = T5Tokenizer.from_pretrained(
770
+ "Rostlab/ProstT5", do_lower_case=False, legacy=True
771
+ )
772
+ model = T5EncoderModel.from_pretrained("Rostlab/ProstT5")
773
+
774
+ super().__init__(
775
+ model,
776
+ tokenizer,
777
+ max_seq_length,
778
+ large_protein_threshold,
779
+ batch_residue_limit,
780
+ cleanup_frequency,
781
+ skip_long_proteins,
782
+ )
783
+
784
+ # Set precision based on device
785
+ self.model = (
786
+ self.model.half() if torch.cuda.is_available() else self.model.float()
787
+ )
788
+
789
+ def prepare_tokenizer_input(self, sequences: List[str]) -> List[str]:
790
+ """Prepare sequences for ProstT5 tokenization.
791
+
792
+ Args:
793
+ sequences: List of amino acid sequences
794
+
795
+ Returns:
796
+ List of sequences with ProstT5-specific formatting, including the
797
+ <AA2fold> prefix and space-separated residues.
798
+ """
799
+ return [f"<AA2fold> {' '.join(seq)}" for seq in sequences]
800
+
801
+ def extract_sequence_embedding(
802
+ self, embedding: torch.Tensor, seq_len: int
803
+ ) -> torch.Tensor:
804
+ """Extract relevant embeddings for a sequence.
805
+
806
+ Args:
807
+ embedding: Raw embedding tensor from the model
808
+ seq_len: Length of the original sequence
809
+
810
+ Returns:
811
+ Tensor containing only the relevant sequence embeddings,
812
+ excluding special tokens. For ProstT5, we skip the first token
813
+ (corresponding to <AA2fold>) and take the next seq_len tokens.
814
+ """
815
+ return embedding[1 : seq_len + 1]
816
+
817
+
818
+ class T5Embedder(Embedder):
819
+ """Protein sequence embedder using the T5 transformer model.
820
+
821
+ This class implements protein sequence embedding using the T5 model from Rostlab,
822
+ specifically designed for protein sequences. It includes memory-efficient processing
823
+ of both large and small sequences.
824
+
825
+ The model used is 'Rostlab/prot_t5_xl_half_uniref50-enc', which was trained on
826
+ UniRef50 sequences and provides state-of-the-art protein embeddings.
827
+
828
+ Memory management features are inherited from the base Embedder class:
829
+ - Periodic cleanup of GPU memory
830
+ - Separate handling of large and small sequences
831
+ - Batch size limits based on total residues
832
+ - Configurable cleanup frequency
833
+ """
834
+
835
+ def __init__(
836
+ self,
837
+ max_seq_length: int = 5000,
838
+ large_protein_threshold: int = 2500,
839
+ batch_residue_limit: int = 5000,
840
+ cleanup_frequency: int = 100,
841
+ skip_long_proteins: bool = False,
842
+ ):
843
+ """Initialize T5Embedder with memory management parameters.
844
+
845
+ Args:
846
+ max_seq_length: Maximum sequence length before chunking or skipping.
847
+ Also used as chunk size when processing long sequences.
848
+ large_protein_threshold: Sequences longer than this are processed individually
849
+ batch_residue_limit: Maximum total residues in a batch
850
+ cleanup_frequency: Frequency of memory cleanup operations
851
+ skip_long_proteins: If True, skip proteins longer than max_seq_length
852
+
853
+ Note:
854
+ The model automatically handles memory management and batch processing
855
+ based on sequence sizes and available resources.
856
+ """
857
+ tokenizer = T5Tokenizer.from_pretrained(
858
+ "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
859
+ )
860
+ model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
861
+
862
+ super().__init__(
863
+ model,
864
+ tokenizer,
865
+ max_seq_length,
866
+ large_protein_threshold,
867
+ batch_residue_limit,
868
+ cleanup_frequency,
869
+ skip_long_proteins,
870
+ )
871
+
872
+ def prepare_tokenizer_input(self, sequences: List[str]) -> List[str]:
873
+ """Prepare sequences for T5 tokenization.
874
+
875
+ Args:
876
+ sequences: List of amino acid sequences
877
+
878
+ Returns:
879
+ List of space-separated sequences ready for tokenization
880
+ """
881
+ return [" ".join(seq) for seq in sequences]
882
+
883
+ def extract_sequence_embedding(
884
+ self, embedding: torch.Tensor, seq_len: int
885
+ ) -> torch.Tensor:
886
+ """Extract relevant embeddings for a sequence.
887
+
888
+ Args:
889
+ embedding: Raw embedding tensor from the model
890
+ seq_len: Length of the original sequence
891
+
892
+ Returns:
893
+ Tensor containing only the relevant sequence embeddings
894
+ """
895
+ return embedding[:seq_len]
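The new dayhoff_tools/embedders.py shown above is usable on its own; a minimal usage sketch, assuming the package is installed with its torch/esm/transformers dependencies and that "proteins.fasta" is a placeholder path to a small FASTA file:

    import h5py
    from dayhoff_tools.embedders import ESMEmbedder, T5Embedder

    # Per-protein ESM2 embeddings; run() writes the .h5 next to the input
    # unless output_file is given.
    esm = ESMEmbedder(model_name="esm2_t6_8M_UR50D", embedding_level="protein")
    h5_path = esm.run("proteins.fasta")

    # T5Embedder/ProstT5Embedder expose memory-management knobs instead of a model name:
    # t5 = T5Embedder(batch_residue_limit=4000, skip_long_proteins=True)
    # h5_path = t5.run("proteins.fasta", output_file="proteins_t5.h5")

    # Per-protein outputs use the reformatted layout: one "ids" dataset and one
    # "vectors" dataset, plus num_vecs/vec_dim attributes.
    with h5py.File(h5_path, "r") as f:
        ids = f["ids"][:]
        vectors = f["vectors"][:]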
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: dayhoff-tools
3
- Version: 1.1.40
3
+ Version: 1.1.42
4
4
  Summary: Common tools for all the repos at Dayhoff Labs
5
5
  Author: Daniel Martin-Alarcon
6
6
  Author-email: dma@dayhofflabs.com
@@ -7,12 +7,13 @@ dayhoff_tools/cli/main.py,sha256=47EGb28ALaYFc7oAUGlY1D66AIDmc4RZiXxN-gPVrpQ,451
7
7
  dayhoff_tools/cli/swarm_commands.py,sha256=5EyKj8yietvT5lfoz8Zx0iQvVaNgc3SJX1z2zQR6o6M,5614
8
8
  dayhoff_tools/cli/utility_commands.py,sha256=ER4VrJt4hu904MwrcltUXjwBWT4uFrP-aPXjdXyT3F8,24685
9
9
  dayhoff_tools/deployment/base.py,sha256=8tXwsPYvRo-zV-aNhHw1c7Rji-KWg8S5xoCCznFnVVI,17412
10
- dayhoff_tools/deployment/deploy_aws.py,sha256=jQyQ0fbm2793jEHFO84lr5tNqiOpdBg6U0S5zCVJr1M,17884
10
+ dayhoff_tools/deployment/deploy_aws.py,sha256=GvZpE2YIFA5Dl9rkAljFjtUypmPDNbWgw8NicHYTP24,18265
11
11
  dayhoff_tools/deployment/deploy_gcp.py,sha256=xgaOVsUDmP6wSEMYNkm1yRNcVskfdz80qJtCulkBIAM,8860
12
12
  dayhoff_tools/deployment/deploy_utils.py,sha256=StFwbqnr2_FWiKVg3xnJF4kagTHzndqqDkpaIOaAn_4,26027
13
13
  dayhoff_tools/deployment/job_runner.py,sha256=hljvFpH2Bw96uYyUup5Ths72PZRL_X27KxlYzBMgguo,5086
14
- dayhoff_tools/deployment/processors.py,sha256=f4L52ekx_zYirl8C4WfavxtOioyD-c34TdTJVDoLpWs,16572
14
+ dayhoff_tools/deployment/processors.py,sha256=LM0CQbr4XCb3AtLbrcuDQm4tYPXsoNqgVJ4WQYDjzJc,12406
15
15
  dayhoff_tools/deployment/swarm.py,sha256=MGcS2_x4RNFtnVjWlU_SwNfhICz8NlGYr9cYBK4ZKDA,21688
16
+ dayhoff_tools/embedders.py,sha256=fRkyWjHo8OmbNUBY_FwrgfvyiLqpmrpI57UAb1Szn1Y,36609
16
17
  dayhoff_tools/fasta.py,sha256=_kA2Cpiy7JAGbBqLrjElkzbcUD_p-nO2d5Aj1LVmOvc,50509
17
18
  dayhoff_tools/file_ops.py,sha256=JlGowvr-CUJFidV-4g_JmhUTN9bsYuaxtqKmnKomm-Q,8506
18
19
  dayhoff_tools/h5.py,sha256=j1nxxaiHsMidVX_XwB33P1Pz9d7K8ZKiDZwJWQUUQSY,21158
@@ -25,7 +26,7 @@ dayhoff_tools/intake/uniprot.py,sha256=BZYJQF63OtPcBBnQ7_P9gulxzJtqyorgyuDiPeOJq
25
26
  dayhoff_tools/logs.py,sha256=DKdeP0k0kliRcilwvX0mUB2eipO5BdWUeHwh-VnsICs,838
26
27
  dayhoff_tools/sqlite.py,sha256=jV55ikF8VpTfeQqqlHSbY8OgfyfHj8zgHNpZjBLos_E,18672
27
28
  dayhoff_tools/warehouse.py,sha256=TqV8nex1AluNaL4JuXH5zuu9P7qmE89lSo6f_oViy6U,14965
28
- dayhoff_tools-1.1.40.dist-info/METADATA,sha256=BVFDhkYLynnht5-XLaT8k7HjOk-L2eBQVLljVvb7pbc,2843
29
- dayhoff_tools-1.1.40.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
30
- dayhoff_tools-1.1.40.dist-info/entry_points.txt,sha256=iAf4jteNqW3cJm6CO6czLxjW3vxYKsyGLZ8WGmxamSc,49
31
- dayhoff_tools-1.1.40.dist-info/RECORD,,
29
+ dayhoff_tools-1.1.42.dist-info/METADATA,sha256=tXnXIVVY3rpyNqyzl8cfzciGdW8ajHlJPpdliTYMD0M,2843
30
+ dayhoff_tools-1.1.42.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
31
+ dayhoff_tools-1.1.42.dist-info/entry_points.txt,sha256=iAf4jteNqW3cJm6CO6czLxjW3vxYKsyGLZ8WGmxamSc,49
32
+ dayhoff_tools-1.1.42.dist-info/RECORD,,