dayhoff-tools 1.0.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dayhoff_tools/__init__.py +0 -0
- dayhoff_tools/chemistry/standardizer.py +297 -0
- dayhoff_tools/chemistry/utils.py +63 -0
- dayhoff_tools/cli/__init__.py +0 -0
- dayhoff_tools/cli/main.py +90 -0
- dayhoff_tools/cli/swarm_commands.py +156 -0
- dayhoff_tools/cli/utility_commands.py +244 -0
- dayhoff_tools/deployment/base.py +434 -0
- dayhoff_tools/deployment/deploy_aws.py +458 -0
- dayhoff_tools/deployment/deploy_gcp.py +176 -0
- dayhoff_tools/deployment/deploy_utils.py +781 -0
- dayhoff_tools/deployment/job_runner.py +153 -0
- dayhoff_tools/deployment/processors.py +125 -0
- dayhoff_tools/deployment/swarm.py +591 -0
- dayhoff_tools/embedders.py +893 -0
- dayhoff_tools/fasta.py +1082 -0
- dayhoff_tools/file_ops.py +261 -0
- dayhoff_tools/gcp.py +85 -0
- dayhoff_tools/h5.py +542 -0
- dayhoff_tools/kegg.py +37 -0
- dayhoff_tools/logs.py +27 -0
- dayhoff_tools/mmseqs.py +164 -0
- dayhoff_tools/sqlite.py +516 -0
- dayhoff_tools/structure.py +751 -0
- dayhoff_tools/uniprot.py +434 -0
- dayhoff_tools/warehouse.py +418 -0
- dayhoff_tools-1.0.0.dist-info/METADATA +122 -0
- dayhoff_tools-1.0.0.dist-info/RECORD +30 -0
- dayhoff_tools-1.0.0.dist-info/WHEEL +4 -0
- dayhoff_tools-1.0.0.dist-info/entry_points.txt +3 -0
dayhoff_tools/h5.py
ADDED
@@ -0,0 +1,542 @@
import os
import re
import time
from collections import Counter
from pathlib import Path
from typing import Dict, List, Set, Union

import h5py
import numpy as np
from tqdm import tqdm


def combine_h5_files(
    input_files: Union[str, List[str]], output_file: str, chunk_size: int = 10000
) -> None:
    """Combine several .h5 embedding files into one efficiently, with detailed progress indication.

    Assumes they have two datasets: `ids` and `vectors`, within which order is important.

    Args:
        input_files (Union[str, List[str]]): Either a path to the folder containing .h5 files,
            or a list of paths to individual .h5 files.
        output_file (str): The path to the output .h5 file.
        chunk_size (int): Number of rows to process at a time. Default is 10000.

    Raises:
        FileExistsError: If the output file already exists.
    """
    if os.path.exists(output_file):
        raise FileExistsError(
            f"Output file '{output_file}' already exists. Please choose a different output file name."
        )

    def get_file_list(input_files: Union[str, List[str]]) -> List[str]:
        if isinstance(input_files, str):
            if os.path.isdir(input_files):
                files = [f for f in os.listdir(input_files) if f.endswith(".h5")]
                files = [os.path.join(input_files, f) for f in files]
            else:
                raise ValueError(
                    "If a string is provided, it must be a directory path."
                )
        elif isinstance(input_files, list):
            files = [f for f in input_files if os.path.isfile(f) and f.endswith(".h5")]
            if len(files) != len(input_files):
                raise ValueError("All input files must exist and have .h5 extension.")
        else:
            raise TypeError(
                "input_files must be either a string (directory path) or a list of file paths."
            )

        return sorted(
            files,
            key=lambda x: (
                int(re.search(r"_(\d+)\.h5$", x).group(1))
                if re.search(r"_(\d+)\.h5$", x)
                else float("inf")
            ),
        )

    files = get_file_list(input_files)

    # First pass: calculate total size and determine vector dimension
    total_rows = 0
    vector_dim = None
    for file_path in tqdm(files, desc="Calculating total size"):
        with h5py.File(file_path, "r") as h5_in:
            total_rows += h5_in["ids"].shape[0]
            if vector_dim is None:
                vector_dim = h5_in["vectors"].shape[1]

    with h5py.File(output_file, "w") as h5_out:
        # Initialize datasets in the output file
        id_dataset = h5_out.create_dataset(
            "ids",
            shape=(total_rows,),
            dtype=h5py.special_dtype(vlen=str),
            chunks=True,
        )
        vector_dataset = h5_out.create_dataset(
            "vectors",
            shape=(total_rows, vector_dim),
            dtype=np.float32,
            chunks=True,
        )

        # Second pass: copy data with more granular progress updates
        current_index = 0
        with tqdm(total=total_rows, desc="Combining files", unit="rows") as pbar:
            for file_path in files:
                with h5py.File(file_path, "r") as h5_in:
                    file_size = h5_in["ids"].shape[0]

                    for i in range(0, file_size, chunk_size):
                        end = min(i + chunk_size, file_size)
                        chunk_size_actual = end - i

                        # Read and write IDs
                        ids = h5_in["ids"][i:end]
                        id_dataset[
                            current_index : current_index + chunk_size_actual
                        ] = [id.decode("utf-8") for id in ids]

                        # Read and write vectors
                        vectors = h5_in["vectors"][i:end]
                        vector_dataset[
                            current_index : current_index + chunk_size_actual
                        ] = vectors

                        current_index += chunk_size_actual
                        pbar.update(chunk_size_actual)

    print(f"Combined {total_rows} rows into {output_file}")
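
For orientation, a minimal usage sketch (not part of the package; the shard directory and output path below are hypothetical): merging a folder of numbered embedding shards into a single file.

# Illustrative usage of combine_h5_files; paths are placeholders.
from dayhoff_tools.h5 import combine_h5_files

# Shards such as embeddings_0.h5, embeddings_1.h5, ... are sorted by their
# numeric suffix and concatenated into one ids/vectors file.
combine_h5_files(
    input_files="embeddings/shards",       # directory containing .h5 shards
    output_file="embeddings/combined.h5",  # must not already exist
    chunk_size=10000,                      # rows copied per write
)
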
def extract_h5_entries(
    input_file: str, output_file: str, targets: Set[str], chunk_size: int = 10000
) -> Set[str]:
    """
    Extract specific entries from a large H5 file based on target IDs and save them to a new H5 file.

    This function reads 'ids' and 'vectors' datasets from the input H5 file in chunks,
    finds the entries corresponding to the target IDs, and saves them
    to the output H5 file. The order of entries in the output file may differ
    from the input file, but the correspondence between IDs and vectors is preserved.

    Args:
        input_file (str): Path to the input H5 file.
        output_file (str): Path to the output H5 file.
        targets (Set[str]): Set of IDs to extract.
        chunk_size (int): Number of entries to process at a time. Default is 10000.

    Returns:
        Set[str]: A set of IDs that were not found in the H5 file.

    Raises:
        ValueError: If 'ids' or 'vectors' datasets are not found in the input file.
        ValueError: If the number of IDs and vectors in the input file don't match.
        FileExistsError: If the output file already exists.
    """
    if os.path.exists(output_file):
        raise FileExistsError(
            f"Output file '{output_file}' already exists. Please choose a different output file name."
        )

    print(f"Opening input file: {input_file}")
    with h5py.File(input_file, "r") as in_file, h5py.File(output_file, "w") as out_file:
        if "ids" not in in_file or "vectors" not in in_file:
            raise ValueError("Input file must contain 'ids' and 'vectors' datasets")

        ids_dataset = in_file["ids"]
        vectors_dataset = in_file["vectors"]

        if len(ids_dataset) != len(vectors_dataset):
            raise ValueError("Number of IDs and vectors in input file don't match")

        total_entries = len(ids_dataset)
        vector_shape = vectors_dataset.shape[1:]

        # Create datasets in the output file
        out_ids = out_file.create_dataset(
            "ids", shape=(0,), maxshape=(None,), dtype=ids_dataset.dtype
        )
        out_vectors = out_file.create_dataset(
            "vectors",
            shape=(0,) + vector_shape,
            maxshape=(None,) + vector_shape,
            dtype=vectors_dataset.dtype,
        )

        found_count = 0
        not_found_ids = set(targets)  # Initialize with all target IDs

        for start_idx in tqdm(
            range(0, total_entries, chunk_size), desc="Processing chunks"
        ):
            end_idx = min(start_idx + chunk_size, total_entries)

            chunk_ids = ids_dataset[start_idx:end_idx]
            chunk_vectors = vectors_dataset[start_idx:end_idx]

            # Find matching IDs in the current chunk
            chunk_id_set = set(id.decode() for id in chunk_ids)
            matching_ids = chunk_id_set.intersection(not_found_ids)

            if matching_ids:
                mask = np.array([id.decode() in matching_ids for id in chunk_ids])
                matching_chunk_ids = chunk_ids[mask]
                matching_chunk_vectors = chunk_vectors[mask]

                # Resize output datasets
                current_size = out_ids.shape[0]
                new_size = current_size + len(matching_chunk_ids)
                out_ids.resize(new_size, axis=0)
                out_vectors.resize(new_size, axis=0)

                # Add matching data to output datasets
                out_ids[current_size:new_size] = matching_chunk_ids
                out_vectors[current_size:new_size] = matching_chunk_vectors

                found_count += len(matching_chunk_ids)
                not_found_ids -= matching_ids

    print(f"Found {found_count} out of {len(targets)} target IDs.")
    print(f"Still missing: {len(not_found_ids)}")
    return not_found_ids
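
A hedged usage sketch for extract_h5_entries (the file names and accession IDs below are hypothetical): pull a subset of embeddings into a new file and collect the target IDs that were not found.

# Illustrative usage of extract_h5_entries; paths and IDs are placeholders.
from dayhoff_tools.h5 import extract_h5_entries

targets = {"P12345", "Q67890"}
missing = extract_h5_entries(
    input_file="embeddings/combined.h5",
    output_file="embeddings/subset.h5",
    targets=targets,
)
# Any IDs absent from the input file are returned for follow-up.
print(f"{len(missing)} target IDs were not found in the input file")
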
def extract_h5_ids(file_path: str) -> Set[str]:
    """
    Extract and decode IDs from an HDF5 file, keeping track of duplicates.

    This function opens an HDF5 file, reads the 'ids' dataset,
    decodes each ID from bytes to UTF-8 strings, and returns
    a set of unique IDs. It also prints the number of duplicate IDs found.

    Args:
        file_path (str): The path to the HDF5 file.

    Returns:
        Set[str]: A set of unique, decoded IDs.

    Raises:
        IOError: If the file cannot be opened or read.
        KeyError: If the 'ids' dataset is not found in the file.
    """
    with h5py.File(file_path, "r") as h5_file:
        ids = h5_file["ids"][:]
        id_counter = Counter()
        for id in tqdm(ids, desc="Extracting IDs", unit="id"):
            id_counter[id.decode("utf-8")] += 1

        id_set = set(id_counter.keys())
        duplicate_count = sum(count - 1 for count in id_counter.values() if count > 1)

        print(f"Successfully extracted {len(id_set):,} unique ids")
        print(f"Also noticed {duplicate_count:,} duplicate ids")

        return id_set


def deduplicate_h5_file(input_filename, output_filename, chunk_size=10000):
    """
    Create a de-duplicated version of the input H5 file, optimized for very large files.
    Works without knowing the total number of entries in advance.

    :param input_filename: Name of the input H5 file
    :param output_filename: Name of the output H5 file
    :param chunk_size: Number of entries to process at a time
    :return: Number of duplicates removed
    :raises FileExistsError: If the output file already exists
    """
    if os.path.exists(output_filename):
        raise FileExistsError(
            f"Output file '{output_filename}' already exists. Please choose a different output file name."
        )

    with (
        h5py.File(input_filename, "r") as input_file,
        h5py.File(output_filename, "w") as output_file,
    ):
        # Get dataset information
        ids_dataset = input_file["ids"]
        vectors_dataset = input_file["vectors"]
        total_entries = ids_dataset.shape[0]

        # Create datasets in the output file with unlimited max shape
        output_ids = output_file.create_dataset(
            "ids", shape=(0,), dtype=ids_dataset.dtype, maxshape=(None,), chunks=True
        )
        output_vectors = output_file.create_dataset(
            "vectors",
            shape=(0, vectors_dataset.shape[1]),
            dtype=vectors_dataset.dtype,
            maxshape=(None, vectors_dataset.shape[1]),
            chunks=True,
        )

        unique_ids = {}

        # Process the file in chunks
        with tqdm(total=total_entries, desc="De-duplicating", unit="entry") as pbar:
            for start_idx in range(0, total_entries, chunk_size):
                end_idx = min(start_idx + chunk_size, total_entries)

                chunk_ids = ids_dataset[start_idx:end_idx]
                chunk_vectors = vectors_dataset[start_idx:end_idx]

                for i, id_value in enumerate(chunk_ids):
                    id_str = id_value.decode("utf-8")
                    if id_str not in unique_ids:
                        unique_ids[id_str] = len(unique_ids)

                # Resize output datasets
                new_size = len(unique_ids)
                output_ids.resize((new_size,))
                output_vectors.resize((new_size, vectors_dataset.shape[1]))

                # Write new unique entries
                for i, id_value in enumerate(chunk_ids):
                    id_str = id_value.decode("utf-8")
                    index = unique_ids[id_str]
                    output_ids[index] = id_value
                    output_vectors[index] = chunk_vectors[i]

                pbar.update(end_idx - start_idx)

    duplicates_removed = total_entries - len(unique_ids)
    print(f"De-duplication complete. Unique entries: {len(unique_ids)}")
    print(f"Number of duplicates removed: {duplicates_removed}")
    return duplicates_removed
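
A sketch of how extract_h5_ids and deduplicate_h5_file might be used together (file names are hypothetical): report duplicates, then write a de-duplicated copy.

# Illustrative usage; paths are placeholders.
from dayhoff_tools.h5 import deduplicate_h5_file, extract_h5_ids

# Reports the number of unique and duplicate IDs in the file.
unique_ids = extract_h5_ids("embeddings/combined.h5")

# Writes a de-duplicated copy and returns how many duplicate rows were dropped.
removed = deduplicate_h5_file("embeddings/combined.h5", "embeddings/deduplicated.h5")
print(f"Removed {removed} duplicate rows; {len(unique_ids)} unique IDs remain")
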
def create_id_mapping(correct_ids: Set[str]) -> Dict[str, str]:
    """
    Create a mapping from incorrectly underscored IDs to correct IDs.
    Only considers IDs that contain a period (.) in the correct version.

    Args:
        correct_ids (Set[str]): Set of correct IDs.

    Returns:
        Dict[str, str]: A dictionary mapping incorrectly underscored IDs to correct IDs.

    Raises:
        ValueError: If the set of correct IDs is empty.
    """
    if not correct_ids:
        raise ValueError("The set of correct IDs is empty")

    id_mapping = {}
    for correct_id in correct_ids:
        if "." in correct_id:
            incorrect_id = correct_id.replace(".", "_")
            id_mapping[incorrect_id] = correct_id

    print(f"Created mapping for {len(id_mapping)} IDs")
    return id_mapping


def fix_underscored_ids_in_h5(
    input_file: str,
    output_file: str,
    id_mapping: Dict[str, str],
    chunk_size: int = 10000,
) -> None:
    """
    Fix underscored IDs in a large H5 file based on the provided mapping and write to a new file.

    This function processes the input file in chunks to minimize memory usage and improve
    efficiency when working with large files.

    Args:
        input_file (str): Path to the input H5 file to be fixed.
        output_file (str): Path to the output H5 file with fixed IDs.
        id_mapping (Dict[str, str]): A dictionary mapping incorrectly underscored IDs to correct IDs.
        chunk_size (int): Number of rows to process at a time. Default is 10000.

    Raises:
        FileNotFoundError: If the specified input file does not exist.
        KeyError: If the 'ids' dataset is not found in the H5 file.
        FileExistsError: If the output file already exists.
    """
    if os.path.exists(output_file):
        raise FileExistsError(
            f"Output file '{output_file}' already exists. Please choose a different output file name."
        )

    with h5py.File(input_file, "r") as in_file, h5py.File(output_file, "w") as out_file:
        if "ids" not in in_file or "vectors" not in in_file:
            raise KeyError("The 'ids' or 'vectors' dataset is not found in the H5 file")

        total_rows = in_file["ids"].shape[0]
        vector_dim = in_file["vectors"].shape[1]

        # Create datasets in the output file
        out_ids = out_file.create_dataset(
            "ids", shape=(total_rows,), dtype=h5py.special_dtype(vlen=str), chunks=True
        )
        out_vectors = out_file.create_dataset(
            "vectors",
            shape=(total_rows, vector_dim),
            dtype=in_file["vectors"].dtype,
            chunks=True,
        )

        print("Processing file in chunks")
        for start_idx in tqdm(
            range(0, total_rows, chunk_size), desc="Fixing IDs", unit="chunk"
        ):
            end_idx = min(start_idx + chunk_size, total_rows)

            # Read chunk of IDs and vectors
            chunk_ids = in_file["ids"][start_idx:end_idx]
            chunk_vectors = in_file["vectors"][start_idx:end_idx]

            # Fix IDs in the chunk
            fixed_chunk_ids = [
                id_mapping.get(id.decode(), id.decode()) for id in chunk_ids
            ]

            # Write fixed IDs and corresponding vectors to output file
            out_ids[start_idx:end_idx] = fixed_chunk_ids
            out_vectors[start_idx:end_idx] = chunk_vectors

    print("IDs have been successfully fixed and written to the new file")
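
A sketch of the two-step ID repair workflow using create_id_mapping and fix_underscored_ids_in_h5 (the accessions and paths below are hypothetical): build the underscore-to-period mapping from a set of correct IDs, then rewrite the file.

# Illustrative usage; accessions and paths are placeholders.
from dayhoff_tools.h5 import create_id_mapping, fix_underscored_ids_in_h5

# Correct IDs contain a period (e.g. versioned accessions); the mapping keys
# are the same IDs with "." replaced by "_".
correct_ids = {"A0A123B4C5.1", "P98765.2"}
id_mapping = create_id_mapping(correct_ids)

fix_underscored_ids_in_h5(
    input_file="embeddings/underscored.h5",
    output_file="embeddings/fixed.h5",
    id_mapping=id_mapping,
)
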
def optimize_protein_embedding_chunks(
    src_path: str | Path, dst_path: str | Path, train_batch_size: int = 16384
):
    """
    Create a new HDF5 file with chunking optimized for protein embedding access during training.

    This function specifically optimizes HDF5 chunking for protein embedding files that contain:
    - 'ids': Protein identifiers
    - 'vectors': Protein embeddings (typically 1024-dimensional ProtT5 vectors)

    The chunking strategy is optimized for the training access pattern where we load
    batches of protein embeddings sequentially. Poor chunk sizing can severely impact
    performance because:
    1. If chunks are too small (e.g., 64 proteins), reading a training batch of 16384
       proteins requires reading 256 separate chunks from disk
    2. If chunks are poorly shaped (e.g., 27317x4), reading a single protein's embedding
       requires loading data for thousands of other proteins

    Args:
        src_path: Source HDF5 file path containing protein embeddings
        dst_path: Destination HDF5 file path for optimized version
        train_batch_size: Batch size used during training (default: 16384)
            This should match the train_batch_size in your training config

    Raises:
        FileNotFoundError: If the source file doesn't exist
        KeyError: If required datasets are missing
        FileExistsError: If the destination file already exists
        ValueError: If the input file contains empty datasets
    """
    src_path = Path(src_path)
    dst_path = Path(dst_path)

    if dst_path.exists():
        raise FileExistsError(f"Output file '{dst_path}' already exists")

    print(f"Optimizing chunking for {src_path.name}")
    print(f"Writing to {dst_path}")

    with h5py.File(src_path, "r") as src, h5py.File(dst_path, "w") as dst:
        # Get total size for progress bar
        if "ids" not in src or "vectors" not in src:
            raise KeyError("Input file must contain 'ids' and 'vectors' datasets")

        total_vectors = src["vectors"].shape[0]
        if total_vectors == 0:
            raise ValueError("Input file contains empty datasets")

        # Calculate optimal chunk size (min of dataset size and batch size)
        chunk_size = min(total_vectors, train_batch_size)

        # Copy ids with optimized chunking
        print("Copying ids dataset...")
        dst.create_dataset(
            "ids",
            data=src["ids"][:],
            chunks=(chunk_size,),
            dtype=h5py.special_dtype(vlen=str),
        )

        # Create vectors dataset with optimized chunking
        print("Creating vectors dataset...")
        vectors_shape = src["vectors"].shape
        dst.create_dataset(
            "vectors",
            shape=vectors_shape,
            chunks=(chunk_size, vectors_shape[1]),
            dtype=np.float32,
        )

        # Copy vectors in chunks to manage memory
        print("Copying vectors dataset...")
        for i in tqdm(range(0, total_vectors, chunk_size)):
            end_idx = min(i + chunk_size, total_vectors)
            dst["vectors"][i:end_idx] = src["vectors"][i:end_idx].astype(np.float32)

    # Verify the copy
    print("Verifying copy...")
    with h5py.File(src_path, "r") as src, h5py.File(dst_path, "r") as dst:
        assert np.all(src["ids"][:5] == dst["ids"][:5]), "IDs don't match"
        assert np.allclose(
            src["vectors"][:5], dst["vectors"][:5], rtol=1e-6, atol=1e-6
        ), "Vectors don't match"
        print(f"Original chunks: {src['vectors'].chunks}")
        print(f"New chunks: {dst['vectors'].chunks}")

    print("Done!")


def test_protein_embedding_read_speed(filepath: str, batch_size: int = 16384):
    """
    Test read speed for random batches of protein embeddings from an HDF5 file.

    This function performs a simple benchmark by reading 50 random batches of protein
    embeddings and calculating the average read time. This can be useful for:
    1. Comparing different chunking strategies
    2. Validating I/O performance after file optimization
    3. Debugging slow read performance

    Args:
        filepath: Path to the HDF5 file containing protein embeddings.
            Must contain a 'vectors' dataset.
        batch_size: Number of embeddings to read in each batch.
            Default is 16384 to match typical training batch sizes.

    Prints:
        Average time taken to read a batch of embeddings.

    Raises:
        FileNotFoundError: If the specified file doesn't exist
        KeyError: If the file doesn't contain a 'vectors' dataset
        ValueError: If batch_size is larger than the total number of vectors
    """
    with h5py.File(filepath, "r") as f:
        vectors = f["vectors"]
        total_vectors = vectors.shape[0]

        if batch_size > total_vectors:
            raise ValueError(
                f"Batch size ({batch_size}) cannot be larger than "
                f"total number of vectors ({total_vectors})"
            )

        read_times = []

        for _ in tqdm(range(50), desc="Testing read speed", unit="batch"):
            batch_start = time.time()
            idx = np.random.randint(0, total_vectors - batch_size)
            _ = vectors[idx : idx + batch_size]
            read_times.append(time.time() - batch_start)

    avg_time = sum(read_times) / len(read_times)
    print(f"\nResults:")
    print(f"  Average time per read: {avg_time:.3f} seconds")
    print(f"  Min time per read: {min(read_times):.3f} seconds")
    print(f"  Max time per read: {max(read_times):.3f} seconds")
dayhoff_tools/kegg.py
ADDED
@@ -0,0 +1,37 @@
import sqlite3
from typing import List

import pandas as pd


def get_ko2gene_df(db: str, ko: str | list[str] | None = None) -> pd.DataFrame:
    """Specialized function that extracts KO-to-gene mappings from a SQLite database,
    and returns them as a dataframe.

    Args:
        db: Path to an SQLite database file that contains a table called `gene_to_ko`.
        ko: KO or list of KOs to query. If None, all KOs will be queried.

    Returns:
        pd.DataFrame: KO to gene mappings.
    """
    if type(ko) == str:
        ko = [ko]

    conn = sqlite3.connect(db)

    if ko is not None:
        query = (
            f"SELECT gene,ko FROM gene_to_ko WHERE ko IN ({','.join('?' * len(ko))})"
        )
        result_df = pd.read_sql_query(
            query, conn, params=ko  # type:ignore
        )
    else:
        query = f"SELECT gene,ko FROM gene_to_ko"
        result_df = pd.read_sql_query(query, conn)

    conn.close()

    return result_df
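
A usage sketch for get_ko2gene_df (the database path and KO identifiers are hypothetical; the SQLite file must already contain a gene_to_ko table with gene and ko columns):

# Illustrative usage; database path and KO identifiers are placeholders.
from dayhoff_tools.kegg import get_ko2gene_df

single = get_ko2gene_df("kegg.sqlite", ko="K00001")               # one KO
several = get_ko2gene_df("kegg.sqlite", ko=["K00001", "K00002"])  # a list of KOs
everything = get_ko2gene_df("kegg.sqlite")                        # all mappings
print(several.head())
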
dayhoff_tools/logs.py
ADDED
@@ -0,0 +1,27 @@
import logging
import sys


def configure_logs(logging_level=logging.INFO):
    """Configure logging"""
    logging.getLogger().handlers.clear()
    logging.basicConfig(
        level=logging_level,  # Level of logging you want to capture
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",  # Format of the log message
        stream=sys.stdout,
    )


def test_logs():
    """Ensure that logging and printing work as expected."""
    # Create a logger object
    logger = logging.getLogger(__name__)

    # Log some messages
    logger.debug(
        "This is a debug message"
    )  # Will not be logged due to the INFO level setting
    logger.info("This is an info message")
    logger.warning("This is a warning message")
    logger.error("This is an error message")
    logger.critical("This is a critical message")
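
A brief usage sketch for the logging helpers (the level shown is just the default):

# Illustrative usage of configure_logs and test_logs.
import logging

from dayhoff_tools.logs import configure_logs, test_logs

configure_logs(logging_level=logging.INFO)  # INFO is also the default
test_logs()  # logs one message per level; the debug line is filtered out at INFO
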
|