dayhoff-tools 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,434 @@
+ import csv
+ import gzip
+ import io
+ import os
+ import time
+ from typing import Iterator, List, Set
+
+ import h5py
+ import numpy as np
+ import pandas as pd
+ from Bio import SwissProt
+ from tqdm import tqdm
+
+
+ def parse_entries(file: Iterator[str]) -> Iterator[str]:
+     """
+     Generator function to parse SwissProt entries efficiently.
+     """
+     buffer = io.StringIO()
+     for line in file:
+         buffer.write(line)
+         if line.startswith("//"):
+             yield buffer.getvalue()
+             buffer = io.StringIO()
+     if buffer.getvalue():
+         yield buffer.getvalue()
+
+
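An illustrative usage sketch for parse_entries, not taken from the published module: it assumes the module's imports are in scope and a hypothetical SwissProt flat file named uniprot_sprot.dat.gz. Each yielded string is one complete entry terminated by "//".

with gzip.open("uniprot_sprot.dat.gz", "rt") as handle:
    for raw_entry in parse_entries(handle):
        record = SwissProt.read(io.StringIO(raw_entry))
        print(record.accessions[0])
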
+ def extract_cofactors(inpath: str, outpath: str, source_db: str):
+     """
+     Read a compressed .dat.gz file, extract the cofactors for each entry,
+     and write the results to a TSV file.
+
+     Args:
+         inpath (str): Path to the input .dat.gz file as obtained from UniProt
+         outpath (str): Path to the output TSV file
+         source_db (str): Name of the database to list as a source for this data
+     """
+     unparseable_entries = 0
+     total_entries = 0
+
+     with gzip.open(inpath, "rt") as file, open(outpath, "w", newline="") as tsv_file:
+         tsv_writer = csv.writer(tsv_file, delimiter="\t")
+         tsv_writer.writerow(
+             ["protein_id", "cofactor_names", "cofactor_chebi", "source_db"]
+         )
+
+         for entry in parse_entries(file):
+             total_entries += 1
+             if total_entries % 1_000_000 == 0:
+                 print(f"Processed {total_entries:,} entries")
+
+             try:
+                 record = SwissProt.read(io.StringIO(entry))
+                 ac = record.accessions[0]
+                 all_names = []
+                 all_chebis = []
+
+                 for comment in record.comments:
+                     if "COFACTOR" in comment:
+                         names = []
+                         chebis = []
+                         parts = comment.split(";")
+                         for part in parts:
+                             if "Name=" in part:
+                                 names.append(part.split("Name=")[1].strip())
+                             if "Xref=ChEBI:CHEBI:" in part:
+                                 chebis.append(
+                                     part.split("Xref=ChEBI:CHEBI:")[1].strip()
+                                 )
+                         if names and chebis:
+                             all_names.append("|".join(names))
+                             all_chebis.append("|".join(chebis))
+
+                 tsv_writer.writerow(
+                     [
+                         ac,
+                         ";".join(all_names) if all_names else None,
+                         ";".join(all_chebis) if all_chebis else None,
+                         source_db,
+                     ]
+                 )
+             except SwissProt.SwissProtParserError:
+                 unparseable_entries += 1
+
+     print(f"Total entries processed: {total_entries:,}")
+     print(f"Number of entries that couldn't be parsed: {unparseable_entries:,}")
+
+
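An illustrative call, not from the published module, with hypothetical file names. The output TSV has one row per entry; names and ChEBI IDs within a single COFACTOR block are joined with "|", and separate blocks are joined with ";".

extract_cofactors(
    inpath="uniprot_sprot.dat.gz",
    outpath="sprot_cofactors.tsv",
    source_db="sprot",
)
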
+ def one_hot_encode_cofactors(df: pd.DataFrame) -> pd.DataFrame:
+     """
+     One-hot encode the cofactors in the 'cofactor_chebi' column of a dataframe.
+
+     This function takes a dataframe with a 'cofactor_chebi' column and creates new columns
+     for each unique cofactor, with binary values indicating the presence or absence of
+     each cofactor for each protein.
+
+     Parameters:
+     -----------
+     df : pandas.DataFrame
+         The input dataframe containing the 'cofactor_chebi' column.
+
+     Returns:
+     --------
+     pandas.DataFrame
+         A new dataframe with additional columns for each unique cofactor.
+
+     Notes:
+     ------
+     - The function handles multiple cofactors separated by '|' or ';'.
+     - NaN values and empty strings in the 'cofactor_chebi' column are preserved.
+     - The function uses tqdm to show progress during processing.
+     - Print statements are included to indicate major transitions in the process.
+     - Duplicate cofactors in a single entry are handled correctly (only encoded once).
+     """
+     print("Starting one-hot encoding process...")
+
+     # Extract unique cofactors
+     print("Extracting unique cofactors...")
+     all_cofactors = set(
+         cofactor.strip()
+         for cofactors in df["cofactor_chebi"].dropna()
+         for cofactor in cofactors.replace("|", ";").split(";")
+         if cofactor.strip()
+     )
+     print(f"Found {len(all_cofactors)} unique cofactors.")
+
+     # Create a dictionary to map cofactors to column indices
+     cofactor_to_index = {cofactor: i for i, cofactor in enumerate(all_cofactors)}
+
+     # Initialize the result array
+     print("Initializing result array...")
+     result_array = np.zeros((len(df), len(all_cofactors)), dtype=np.int8)
+
+     # Perform one-hot encoding
+     print("Performing one-hot encoding...")
+     mask = df["cofactor_chebi"].notna()
+     # Map index labels to row positions so the assignment below is correct
+     # even when df does not have a default RangeIndex
+     row_position = {label: pos for pos, label in enumerate(df.index)}
+     for idx, cofactors in tqdm(
+         df.loc[mask, "cofactor_chebi"].items(), total=mask.sum(), desc="Processing rows"
+     ):
+         if cofactors.strip():
+             for cofactor in set(
+                 cofactor.strip() for cofactor in cofactors.replace("|", ";").split(";")
+             ):
+                 if cofactor:
+                     result_array[row_position[idx], cofactor_to_index[cofactor]] = 1
+
+     # Create the result dataframe
+     print("Creating result dataframe...")
+     result_df = pd.DataFrame(result_array, columns=list(all_cofactors), index=df.index)
+
+     # Combine with original dataframe
+     print("Combining results...")
+     result_df = pd.concat([df, result_df], axis=1)
+
+     print("One-hot encoding process completed.")
+     return result_df
+
+
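A small sketch of the expected input and output, not from the published module, using made-up ChEBI identifiers. Rows that share a cofactor receive a 1 in the same new column.

df = pd.DataFrame(
    {
        "protein_id": ["P1", "P2"],
        "cofactor_names": ["Mg(2+)", "Zn(2+)|Mg(2+)"],
        "cofactor_chebi": ["18420", "29105|18420"],
        "source_db": ["sprot", "sprot"],
    }
)
encoded = one_hot_encode_cofactors(df)
# encoded now contains int8 columns "18420" and "29105" holding 0/1 flags.
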
+ def merge_one_hot_encoded_datasets(
+     df1: pd.DataFrame, df2: pd.DataFrame
+ ) -> pd.DataFrame:
+     """
+     Merge two dataframes with one-hot encoded cofactors, preserving the encoding structure.
+
+     Args:
+         df1 (pd.DataFrame): The first input dataframe containing one-hot encoded cofactor columns.
+         df2 (pd.DataFrame): The second input dataframe containing one-hot encoded cofactor columns.
+
+     Returns:
+         pd.DataFrame: A new dataframe with all original columns and a complete set of one-hot encoded cofactor columns.
+
+     Raises:
+         ValueError: If any cofactor column in the input dataframes contains NaN values.
+         ValueError: If the input dataframes do not contain the expected original columns.
+
+     Notes:
+         This function assumes that the non-cofactor columns are:
+         ["protein_id", "cofactor_names", "cofactor_chebi", "source_db"]
+     """
+     expected_columns: List[str] = [
+         "protein_id",
+         "cofactor_names",
+         "cofactor_chebi",
+         "source_db",
+     ]
+
+     def check_original_columns(df: pd.DataFrame, name: str) -> None:
+         missing_columns = set(expected_columns) - set(df.columns)
+         if missing_columns:
+             raise ValueError(
+                 f"{name} is missing the following expected columns: {', '.join(missing_columns)}"
+             )
+
+     check_original_columns(df1, "First dataframe")
+     check_original_columns(df2, "Second dataframe")
+
+     def get_cofactor_columns(df: pd.DataFrame) -> List[str]:
+         return [col for col in df.columns if col not in expected_columns]
+
+     cofactor_columns1 = get_cofactor_columns(df1)
+     cofactor_columns2 = get_cofactor_columns(df2)
+
+     # Verify no NaN values in cofactor columns
+     for df, cols, name in [
+         (df1, cofactor_columns1, "First dataframe"),
+         (df2, cofactor_columns2, "Second dataframe"),
+     ]:
+         if df[cols].isna().any().any():
+             raise ValueError(f"{name} contains NaN values in cofactor columns")
+
+     all_cofactor_columns = sorted(set(cofactor_columns1 + cofactor_columns2))
+
+     def add_missing_columns(df: pd.DataFrame, all_columns: List[str]) -> pd.DataFrame:
+         # Work on a copy so the caller's dataframe is not modified in place
+         df = df.copy()
+         for col in all_columns:
+             if col not in df.columns:
+                 df[col] = 0
+         return df
+
+     df1_complete = add_missing_columns(df1, all_cofactor_columns)
+     df2_complete = add_missing_columns(df2, all_cofactor_columns)
+
+     column_order = expected_columns + all_cofactor_columns
+     df1_complete = df1_complete[column_order]
+     df2_complete = df2_complete[column_order]
+
+     merged_df = pd.concat([df1_complete, df2_complete], axis=0, ignore_index=True)
+
+     return merged_df
+
+
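An illustrative merge of two hypothetical encoded tables (sprot_encoded and trembl_encoded), not from the published module. Cofactor columns present in only one input are added to the other and filled with 0 before the rows are stacked.

sprot_encoded = one_hot_encode_cofactors(sprot_df)
trembl_encoded = one_hot_encode_cofactors(trembl_df)
combined = merge_one_hot_encoded_datasets(sprot_encoded, trembl_encoded)
print(combined.shape)
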
+ def filter_tsv_by_protein_ids(
+     input_file: str,
+     output_file: str,
+     protein_id_set: Set[str],
+     batch_size: int = 10000,
+     report_interval: float = 5.0,
+ ) -> None:
+     """
+     Filter a TSV file based on a set of protein IDs and write the results to a new file.
+
+     This function reads the input TSV file, filters rows based on the provided protein ID set,
+     and writes the filtered data to the output file. It also provides progress reporting
+     during the process.
+
+     Args:
+         input_file (str): Path to the input TSV file.
+         output_file (str): Path to the output TSV file.
+         protein_id_set (Set[str]): Set of protein IDs to filter by in the 'protein_id' column.
+         batch_size (int, optional): Number of rows to write in each batch. Defaults to 10000.
+         report_interval (float, optional): Interval in seconds for progress reporting. Defaults to 5.0.
+
+     Raises:
+         FileNotFoundError: If the input file is not found.
+         ValueError: If the 'protein_id' column is missing in the input file.
+     """
+     try:
+         start_time = time.time()
+         last_report_time = start_time
+
+         # Count data rows up front (excluding the header) for progress estimates
+         with open(input_file, "r") as row_counter:
+             total_rows = max(sum(1 for _ in row_counter) - 1, 1)
+         processed_rows = 0
+         matched_rows = 0
+
+         with (
+             open(input_file, "r", newline="") as infile,
+             open(output_file, "w", newline="") as outfile,
+         ):
+             reader = csv.DictReader(infile, delimiter="\t")
+             if reader.fieldnames is None or "protein_id" not in reader.fieldnames:
+                 raise ValueError(
+                     "Required column 'protein_id' is missing in the input file."
+                 )
+
+             writer = csv.DictWriter(
+                 outfile, fieldnames=reader.fieldnames, delimiter="\t"
+             )
+             writer.writeheader()
+
+             batch = []
+
+             for row in reader:
+                 processed_rows += 1
+
+                 if row["protein_id"] in protein_id_set:
+                     batch.append(row)
+                     matched_rows += 1
+
+                     if len(batch) >= batch_size:
+                         writer.writerows(batch)
+                         batch = []
+
+                 current_time = time.time()
+                 if current_time - last_report_time >= report_interval:
+                     progress = processed_rows / total_rows
+                     elapsed_time = current_time - start_time
+                     estimated_total_time = (
+                         elapsed_time / progress if progress > 0 else 0
+                     )
+                     estimated_remaining_time = max(
+                         estimated_total_time - elapsed_time, 0
+                     )
+                     print(
+                         f"Progress: {progress:.2%} | Rows processed: {processed_rows}/{total_rows} | "
+                         f"Matches: {matched_rows} | Est. time remaining: {estimated_remaining_time:.2f} seconds"
+                     )
+                     last_report_time = current_time
+
+             if batch:
+                 writer.writerows(batch)
+
+         print(f"Processing complete. Output written to {output_file}")
+         print(f"Total rows processed: {processed_rows}")
+         print(f"Total matches found: {matched_rows}")
+         print(f"Total processing time: {time.time() - start_time:.2f} seconds")
+
+     except FileNotFoundError:
+         print(f"Error: Input file '{input_file}' not found.")
+         raise
+     except ValueError as e:
+         print(f"Error: {str(e)}")
+         raise
+
+
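A usage sketch, not from the published module, with hypothetical file names; only rows whose protein_id appears in the given set are copied to the output file.

wanted_ids = {"P12345", "Q67890"}
filter_tsv_by_protein_ids(
    input_file="all_cofactors.tsv",
    output_file="subset_cofactors.tsv",
    protein_id_set=wanted_ids,
)
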
+ def concatenate_embeddings_with_ohe(
+     embedding_file: str,
+     ohe_file: str,
+     output_file: str,
+     non_ohe_columns: List[str] = [
+         "protein_id",
+         "cofactor_names",
+         "cofactor_chebi",
+         "source_db",
+     ],
+     chunk_size: int = 10000,
+ ) -> Set[str]:
+     """
+     Concatenate protein embeddings from a large H5 file with one-hot encoded (OHE) cofactor data from a TSV file.
+
+     This function efficiently processes large H5 files by reading and writing in chunks.
+     It handles missing proteins by filling their OHE data with zeros.
+     Progress updates include elapsed time and estimated time remaining.
+
+     Args:
+         embedding_file (str): Path to the input H5 file containing protein embeddings.
+         ohe_file (str): Path to the input TSV file containing OHE cofactor data.
+         output_file (str): Path to the output H5 file where concatenated data will be written.
+         non_ohe_columns (List[str]): List of column names in the TSV file that are not OHE data.
+         chunk_size (int): Number of proteins to process in each chunk.
+
+     Returns:
+         Set[str]: A set of protein IDs that were in the embedding file but not in the OHE file.
+
+     Raises:
+         FileNotFoundError: If either input file is not found.
+         ValueError: If the input H5 file is empty or in an unexpected format.
+     """
+     start_time = time.time()
+     print(
+         f"Starting optimized protein embedding concatenation process at {time.strftime('%Y-%m-%d %H:%M:%S')}"
+     )
+
+     # Check if input files exist
+     if not os.path.exists(embedding_file) or not os.path.exists(ohe_file):
+         raise FileNotFoundError("One or both input files not found.")
+
+     # Load protein IDs from H5 file
+     with h5py.File(embedding_file, "r") as h5_in:
+         protein_ids = [id.decode("utf-8") for id in h5_in["ids"]]
+
+     if not protein_ids:
+         raise ValueError("Input H5 file is empty or in unexpected format.")
+
+     print(f"Loaded {len(protein_ids)} protein IDs from H5 file.")
+
+     # Load and process OHE data
+     print("Loading and filtering OHE data...")
+     ohe_df = pd.read_csv(ohe_file, sep="\t")
+     ohe_columns = [col for col in ohe_df.columns if col not in non_ohe_columns]
+     ohe_df = ohe_df.set_index("protein_id")
+
+     num_ohe_columns = len(ohe_columns)
+
+     # Identify missing proteins
+     missing_proteins = set(protein_ids) - set(ohe_df.index)
+     print(f"Number of proteins missing from OHE data: {len(missing_proteins)}")
+
+     # Process data in chunks and write to output file
+     with h5py.File(embedding_file, "r") as h5_in, h5py.File(output_file, "w") as h5_out:
+         # Create datasets in output file
+         h5_out.create_dataset("ids", data=[id.encode("utf-8") for id in protein_ids])
+
+         embedding_size = h5_in["vectors"].shape[1]
+         h5_out.create_dataset(
+             "vectors",
+             shape=(len(protein_ids), embedding_size + num_ohe_columns),
+             dtype=np.float32,
+         )
+
+         total_chunks = (len(protein_ids) + chunk_size - 1) // chunk_size
+         for chunk_index in range(total_chunks):
+             chunk_start_time = time.time()
+             start_idx = chunk_index * chunk_size
+             end_idx = min((chunk_index + 1) * chunk_size, len(protein_ids))
+
+             chunk_proteins = protein_ids[start_idx:end_idx]
+             embedding_vectors = h5_in["vectors"][start_idx:end_idx]
+
+             # Create OHE vectors for the chunk, filling with zeros for missing proteins
+             ohe_vectors = np.zeros((len(chunk_proteins), num_ohe_columns))
+             for i, protein_id in enumerate(chunk_proteins):
+                 if protein_id in ohe_df.index:
+                     ohe_vectors[i] = ohe_df.loc[protein_id, ohe_columns].values
+
+             # Concatenate embedding vectors with OHE vectors
+             concatenated_vectors = np.hstack((embedding_vectors, ohe_vectors))
+             h5_out["vectors"][start_idx:end_idx] = concatenated_vectors
+
+             # Calculate and print progress information
+             elapsed_time = time.time() - start_time
+             estimated_total_time = elapsed_time * total_chunks / (chunk_index + 1)
+             estimated_remaining_time = estimated_total_time - elapsed_time
+
+             print(f"Processed chunk {chunk_index + 1}/{total_chunks}")
+             print(f" Elapsed time: {elapsed_time:.2f} seconds")
+             print(f" Estimated time remaining: {estimated_remaining_time:.2f} seconds")
+             print(f" Estimated total time: {estimated_total_time:.2f} seconds")
+
+     total_time = time.time() - start_time
+     print(
+         f"Concatenation complete. Wrote {len(protein_ids)} proteins to output file."
+     )
+     print(f"Total execution time: {total_time:.2f} seconds")
+     return missing_proteins
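
An illustrative call, not from the published module, with hypothetical paths. The embedding H5 file is assumed to contain an "ids" dataset and a row-aligned "vectors" dataset, as the function above expects; the returned set lists proteins whose OHE block was zero-filled.

missing = concatenate_embeddings_with_ohe(
    embedding_file="embeddings.h5",
    ohe_file="cofactors_ohe.tsv",
    output_file="embeddings_with_ohe.h5",
)
print(f"{len(missing)} proteins had no cofactor annotations.")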