dayhoff-tools 1.0.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
@@ -0,0 +1,751 @@
1
+ import gzip
2
+ import io
3
+ import multiprocessing
4
+ import os
5
+ import re
6
+ from dataclasses import dataclass
8
+ from typing import Iterator
9
+
10
+ import h5py
11
+ import numpy as np
12
+ from Bio import PDB, SeqIO
13
+ from tqdm import tqdm
14
+
15
+
16
+ @dataclass
17
+ class PDBData:
18
+ """Stores parsed PDB file data."""
19
+
20
+ atom_coords: np.ndarray
21
+ uncertainty: np.ndarray
22
+ id: str
23
+ aa_sequence: str
24
+
25
+ @property
26
+ def atom_count(self) -> int:
27
+ return len(self.atom_coords)
28
+
29
+
30
+ class PDBParser:
31
+ """Parses PDB files and extracts relevant information."""
32
+
33
+ def __init__(self):
34
+ self.parser = PDB.PDBParser(QUIET=True) # type: ignore
35
+
36
+ def parse(self, pdb_file: str, backbone_only: bool = False) -> PDBData:
37
+ """
38
+ Parse a PDB file and extract data for a single-chain protein.
39
+
40
+ Args:
41
+ pdb_file: Path to the PDB file.
42
+ backbone_only: If True, only extract coordinates for alpha carbons (CA).
43
+
44
+ Returns:
45
+ PDBData containing parsed PDB data.
46
+
47
+ Raises:
48
+ ValueError: If the PDB file is invalid or doesn't meet criteria.
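+
+ Example (illustrative; the file name is a placeholder following the AlphaFold naming scheme):
+ parser = PDBParser()
+ data = parser.parse("AF-A3DBM5-F1-model_v4.pdb.gz", backbone_only=True)
+ print(data.id, data.atom_count)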
49
+ """
50
+ with self._open_pdb_file(pdb_file) as file:
51
+ try:
52
+ structure = self.parser.get_structure("protein", file)
53
+ model = next(structure.get_models()) # type: ignore
54
+ except StopIteration:
55
+ raise ValueError("Invalid PDB file: No models found.")
56
+
57
+ chains = list(model.get_chains())
58
+
59
+ if len(chains) != 1:
60
+ raise ValueError(
61
+ f"Expected a single chain, but found {len(chains)} chains."
62
+ )
63
+
64
+ chain = chains[0]
65
+
66
+ atom_coords, uncertainty = self._extract_coords_and_uncertainty(
67
+ chain, backbone_only
68
+ )
69
+
70
+ id = self._extract_id(pdb_file)
71
+ aa_sequence = self._extract_aa_sequence(pdb_file)
72
+
73
+ return PDBData(
74
+ atom_coords=atom_coords,
75
+ uncertainty=uncertainty,
76
+ id=id,
77
+ aa_sequence=aa_sequence,
78
+ )
79
+
80
+ def _extract_coords_and_uncertainty(
81
+ self, chain: PDB.Chain.Chain, backbone_only: bool = False
82
+ ) -> tuple[np.ndarray, np.ndarray]:
83
+ """
84
+ Extract atom coordinates and uncertainty values from a chain.
85
+
86
+ Args:
87
+ chain: A Bio.PDB.Chain object.
88
+ backbone_only: If True, only extract coordinates for alpha carbons (CA).
89
+
90
+ Returns:
91
+ A tuple containing two NumPy arrays:
92
+ - atom_coords: Array of atom coordinates with shape (n_atoms, 3).
93
+ - uncertainty: Array of uncertainty values (B-factors) with shape (n_atoms,).
94
+ """
95
+ atoms = self._get_atoms(chain, backbone_only)
96
+
97
+ atom_coords = []
98
+ uncertainty = []
99
+ for atom in atoms:
100
+ atom_coords.append(atom.coord)
101
+ uncertainty.append(atom.bfactor if atom.bfactor is not None else 0.0)
102
+
103
+ return np.array(atom_coords), np.array(uncertainty)
104
+
105
+ @staticmethod
106
+ def _get_atoms(
107
+ chain: PDB.Chain.Chain, backbone_only: bool
108
+ ) -> Iterator[PDB.Atom.Atom]:
109
+ """Get atoms from the chain based on the backbone_only flag."""
110
+ if backbone_only:
111
+ return (residue["CA"] for residue in chain if "CA" in residue)
112
+ return (atom for residue in chain for atom in residue)
113
+
114
+ def _extract_id(self, file_path: str) -> str:
115
+ """
116
+ Extract UniProt ID from PDB file TITLE lines.
117
+
118
+ Args:
119
+ file_path: Path to the PDB file.
120
+
121
+ Returns:
122
+ UniProt ID extracted from the PDB file.
123
+
124
+ Raises:
125
+ ValueError: If UniProt ID is not found.
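+
+ For example, an AlphaFold model whose TITLE record ends in "(A3DBM5)" yields "A3DBM5".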
126
+ """
127
+ with self._open_pdb_file(file_path) as file:
128
+ title = ""
129
+ for line in file:
130
+ if line.startswith("TITLE"): # type: ignore
131
+ title += line[10:].strip() # type: ignore
132
+ elif title:  # the contiguous TITLE block has ended
133
+ break
134
+
135
+ match = re.search(r"\(([A-Z0-9]+)\)$", title)
136
+ if match:
137
+ return match.group(1)
138
+ raise ValueError("UniProt ID not found in the PDB file.")
139
+
140
+ def _extract_aa_sequence(self, file_path: str) -> str:
141
+ """
142
+ Extract amino acid sequence from PDB file SEQRES records.
143
+
144
+ Args:
145
+ file_path: Path to the PDB file.
146
+
147
+ Returns:
148
+ Amino acid sequence extracted from the PDB file.
149
+
150
+ Raises:
151
+ ValueError: If sequence is not found.
152
+ """
153
+ with self._open_pdb_file(file_path) as file:
154
+ for record in SeqIO.parse(file, "pdb-seqres"):
155
+ return str(record.seq)
156
+ raise ValueError("Amino acid sequence not found in the PDB file.")
157
+
158
+ @staticmethod
159
+ def _open_pdb_file(file_path: str) -> io.TextIOWrapper | gzip.GzipFile:
160
+ """
161
+ Open a PDB file, handling both .pdb and .pdb.gz formats.
162
+
163
+ Args:
164
+ file_path: Path to the PDB file.
165
+
166
+ Returns:
167
+ File-like object containing the PDB data.
168
+ """
169
+ if file_path.endswith(".gz"):
170
+ return gzip.open(file_path, "rt") # type: ignore
171
+ return open(file_path, "r")
172
+
173
+
174
+ class HDF5Writer:
175
+ """Writes protein data to an HDF5 file."""
176
+
177
+ def __init__(self, output_file: str, total_proteins: int):
178
+ """
179
+ Initialize the HDF5Writer.
180
+
181
+ Args:
182
+ output_file: Path to the output HDF5 file.
183
+ total_proteins: Total number of proteins to be processed.
184
+ """
185
+ self.output_file = output_file
186
+ self.total_proteins = total_proteins
187
+ self.file = None
188
+
189
+ def create_datasets(self):
190
+ """
191
+ Create resizable datasets in the HDF5 file.
192
+
193
+ Creates the following datasets:
194
+ - ids: UniProt IDs of the proteins (string)
195
+ - aa_sequences: Amino acid sequences of the proteins (string)
196
+ - atom_counts: Number of atoms in each protein (integer)
197
+ - prot_start_idx: Starting index of each protein's atoms in the atom_coords and uncertainty datasets (integer)
198
+ - atom_coords: 3D coordinates of atoms for all proteins (float)
199
+ - uncertainty: B-factors or uncertainty values for each atom (float)
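+
+ All datasets are created empty and resizable along their first axis; update_datasets
+ grows them as each chunk of proteins is processed.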
200
+ """
201
+ compression = "gzip"
202
+ compression_opts = 4 # Compression level (1-9)
203
+ chunk_size = min(1000, self.total_proteins)
204
+
205
+ self._create_string_dataset("ids", chunk_size, compression, compression_opts)
206
+ self._create_string_dataset(
207
+ "aa_sequences", chunk_size, compression, compression_opts
208
+ )
209
+ self._create_int_dataset(
210
+ "atom_counts", chunk_size, compression, compression_opts
211
+ )
212
+ self._create_int_dataset(
213
+ "prot_start_idx", chunk_size, compression, compression_opts
214
+ )
215
+ self._create_float_dataset(
216
+ "atom_coords", (1000, 3), compression, compression_opts
217
+ )
218
+ self._create_float_dataset(
219
+ "uncertainty", (1000,), compression, compression_opts
220
+ )
221
+
222
+ def _create_string_dataset(
223
+ self, name: str, chunk_size: int, compression: str, compression_opts: int
224
+ ):
225
+ self.file.create_dataset(
226
+ name,
227
+ (0,),
228
+ maxshape=(None,),
229
+ dtype=h5py.string_dtype(encoding="utf-8"),
230
+ chunks=(chunk_size,),
231
+ compression=compression,
232
+ compression_opts=compression_opts,
233
+ )
234
+
235
+ def _create_int_dataset(
236
+ self, name: str, chunk_size: int, compression: str, compression_opts: int
237
+ ):
238
+ self.file.create_dataset(
239
+ name,
240
+ (0,),
241
+ maxshape=(None,),
242
+ dtype=int,
243
+ chunks=(chunk_size,),
244
+ compression=compression,
245
+ compression_opts=compression_opts,
246
+ )
247
+
248
+ def _create_float_dataset(
249
+ self, name: str, chunk: tuple, compression: str, compression_opts: int
250
+ ):
251
+ self.file.create_dataset(
252
+ name,
253
+ (0, *chunk[1:]),
254
+ maxshape=(None, *chunk[1:]),
255
+ dtype=float,
256
+ chunks=chunk,
257
+ compression=compression,
258
+ compression_opts=compression_opts,
259
+ )
260
+
261
+ def update_datasets(self, start_idx: int, end_idx: int, data: list[PDBData]):
262
+ """
263
+ Update datasets in the HDF5 file with new data, resizing as necessary.
264
+
265
+ Args:
266
+ start_idx: Starting index for updating the datasets.
267
+ end_idx: Ending index for updating the datasets.
268
+ data: List of PDBData objects containing the new data to be added.
269
+
270
+ This method updates all datasets, including prot_start_idx, which stores
271
+ the starting index of each protein's atoms in the atom_coords and uncertainty datasets.
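+ Equivalently, once a chunk is written, prot_start_idx[i] == sum(atom_counts[:i]).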
272
+ """
273
+ if not data:
274
+ raise ValueError("No data provided to update_datasets")
275
+
276
+ if any(pdb is None for pdb in data):
277
+ raise ValueError("Invalid data: None values are not allowed")
278
+
279
+ current_size = self.file["ids"].shape[0]
280
+ new_size = max(current_size, end_idx)
281
+
282
+ # Resize datasets if necessary
283
+ if new_size > current_size:
284
+ self.file["ids"].resize((new_size,))
285
+ self.file["aa_sequences"].resize((new_size,))
286
+ self.file["atom_counts"].resize((new_size,))
287
+ self.file["prot_start_idx"].resize((new_size,))
288
+
289
+ # Update datasets
290
+ self.file["ids"][start_idx:end_idx] = [pdb.id for pdb in data]
291
+ self.file["aa_sequences"][start_idx:end_idx] = [pdb.aa_sequence for pdb in data]
292
+
293
+ atom_counts = [pdb.atom_count for pdb in data]
294
+ self.file["atom_counts"][start_idx:end_idx] = atom_counts
295
+
296
+ # Update prot_start_idx
297
+ if start_idx == 0:
298
+ self.file["prot_start_idx"][0] = 0
299
+ else:
300
+ previous_start = self.file["prot_start_idx"][start_idx - 1]
301
+ previous_count = self.file["atom_counts"][start_idx - 1]
302
+ self.file["prot_start_idx"][start_idx] = previous_start + previous_count
303
+
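+ # Remaining start indices in this chunk: each protein begins immediately after
+ # the atoms of the proteins written before it (a running sum of atom_counts).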
304
+ cumulative_counts = np.cumsum([0] + atom_counts[:-1])
305
+ self.file["prot_start_idx"][start_idx + 1 : end_idx] = (
306
+ self.file["prot_start_idx"][start_idx] + cumulative_counts[1:]
307
+ )
308
+
309
+ # Calculate total atoms for the current chunk
310
+ total_atoms = sum(atom_counts)
311
+
312
+ # Resize atom_coords and uncertainty datasets
313
+ current_atoms = self.file["atom_coords"].shape[0]
314
+ new_atoms = current_atoms + total_atoms
315
+ self.file["atom_coords"].resize((new_atoms, 3))
316
+ self.file["uncertainty"].resize((new_atoms,))
317
+
318
+ # Update atom_coords and uncertainty datasets
319
+ atom_index = current_atoms
320
+ for pdb in data:
321
+ self.file["atom_coords"][
322
+ atom_index : atom_index + pdb.atom_count
323
+ ] = pdb.atom_coords
324
+ self.file["uncertainty"][
325
+ atom_index : atom_index + pdb.atom_count
326
+ ] = pdb.uncertainty
327
+ atom_index += pdb.atom_count
328
+
329
+ def __enter__(self):
330
+ self.file = h5py.File(self.output_file, "w")
331
+ self.create_datasets()
332
+ return self
333
+
334
+ def __exit__(self, exc_type, exc_val, exc_tb):
335
+ if self.file:
336
+ self.file.close()
337
+
338
+
339
+ class PDBFolderProcessor:
340
+ """Processes multiple PDB files and writes data to an HDF5 file."""
341
+
342
+ def __init__(
343
+ self,
344
+ pdb_dir: str,
345
+ output_file: str,
346
+ chunk_size: int = 1000,
347
+ id_set: set[str] | None = None,
348
+ backbone_only: bool = False,
349
+ ):
350
+ """
351
+ Initialize the PDBFolderProcessor.
352
+
353
+ Args:
354
+ pdb_dir: Path to the directory containing PDB files.
355
+ output_file: Path to the output HDF5 file.
356
+ chunk_size: Number of PDB files to process in each chunk.
357
+ id_set: Optional set of IDs to filter PDB files.
358
+ backbone_only: If True, only extract coordinates for alpha carbons (CA).
359
+ """
360
+ self.pdb_dir = pdb_dir
361
+ self.output_file = output_file
362
+ self.chunk_size = chunk_size
363
+ self.parser = PDBParser()
364
+ self.id_set = id_set
365
+ self.backbone_only = backbone_only
366
+
367
+ def process(self):
368
+ """
369
+ Process PDB files and write data to HDF5 file.
370
+ """
371
+ print(f"Starting to process PDB files from {self.pdb_dir}")
372
+ pdb_files = self._get_pdb_files()
373
+ total_proteins = len(pdb_files)
374
+ print(f"Found {total_proteins} PDB files to process")
375
+
376
+ with HDF5Writer(self.output_file, total_proteins) as writer:
377
+ with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
379
+ processed_proteins = 0
380
+ for start_idx in tqdm(
381
+ range(0, total_proteins, self.chunk_size),
382
+ desc="Processing PDB files",
383
+ unit="chunk",
384
+ ):
385
+ end_idx = min(start_idx + self.chunk_size, total_proteins)
386
+ chunk = pdb_files[start_idx:end_idx]
387
+
388
+ print(f"\nProcessing chunk {start_idx // self.chunk_size + 1}")
389
+ data = pool.map(self._process_single_pdb, chunk)
390
+ valid_data = [item for item in data if item is not None]
391
+
392
+ if valid_data:
393
+ writer.update_datasets(
394
+ processed_proteins,
395
+ processed_proteins + len(valid_data),
396
+ valid_data,
397
+ )
398
+ processed_proteins += len(valid_data)
399
+
400
+ print(
401
+ f"Processed {processed_proteins} valid proteins out of {end_idx} total files"
402
+ )
403
+
404
+ print(
405
+ f"\nFinished processing all PDB files. Output saved to {self.output_file}"
406
+ )
407
+ print(f"Total valid proteins processed: {processed_proteins}")
408
+
409
+ def _get_pdb_files(self) -> list[str]:
410
+ """
411
+ Get a list of PDB files in the specified directory, optionally filtered by ID set.
412
+ Files are sorted by os.path.getctime (creation time on Windows, last metadata change on most Unix systems) to keep the processing order consistent.
413
+
414
+ Returns:
415
+ List of PDB file names sorted by creation time.
416
+ """
417
+ print("Scanning directory for PDB files...")
418
+ pdb_files = [
419
+ f for f in os.listdir(self.pdb_dir) if f.endswith((".pdb", ".pdb.gz"))
420
+ ]
421
+
422
+ if self.id_set:
423
+ pdb_files = [
424
+ f for f in pdb_files if self._extract_id_from_filename(f) in self.id_set
425
+ ]
426
+
427
+ # Sort files by creation time
428
+ pdb_files.sort(key=lambda f: os.path.getctime(os.path.join(self.pdb_dir, f)))
429
+
430
+ print(f"Found {len(pdb_files)} PDB files")
431
+ return pdb_files
432
+
433
+ @staticmethod
434
+ def _extract_id_from_filename(filename: str) -> str:
435
+ """
436
+ Extract the ID from a PDB filename.
437
+
438
+ Args:
439
+ filename: The filename of the PDB file (e.g., "AF-A3DBM5-F1-model_v4.pdb.gz").
440
+
441
+ Returns:
442
+ The extracted ID (e.g., "A3DBM5").
443
+
444
+ Raises:
445
+ ValueError: If the filename doesn't match the expected format.
446
+ """
447
+ match = re.match(r"AF-([A-Z0-9]+)-F", filename)
448
+ if match:
449
+ return match.group(1)
450
+ raise ValueError(f"Invalid filename format: {filename}")
451
+
452
+ def _process_single_pdb(self, pdb_file: str) -> PDBData | None:
453
+ """
454
+ Process a single PDB file.
455
+
456
+ Args:
457
+ pdb_file: PDB file name to process.
458
+
459
+ Returns:
460
+ PDBData object containing parsed PDB data, or None if processing fails.
461
+ """
462
+ try:
463
+ file_path = os.path.join(self.pdb_dir, pdb_file)
464
+ parsed_data = self.parser.parse(file_path, backbone_only=self.backbone_only)
465
+ return parsed_data
466
+ except Exception:
467
+ # Skip files that fail to parse (e.g., malformed or multi-chain structures).
+ return None
468
+
469
+
470
+ def parse_pdb_folder_to_h5(
471
+ pdb_dir: str,
472
+ output_file: str,
473
+ chunk_size: int = 1000,
474
+ id_set: set[str] | None = None,
475
+ backbone_only: bool = False,
476
+ ):
477
+ """
478
+ Create an HDF5 file containing data from multiple PDB files.
479
+
480
+ This function processes a folder of PDB files and stores their structural
481
+ information in an efficient HDF5 format. The resulting file is optimized
482
+ for fast access by machine learning dataloaders.
483
+
484
+ Args:
485
+ pdb_dir: Path to the directory containing PDB files.
486
+ output_file: Path to the output HDF5 file.
487
+ chunk_size: Number of PDB files to process in each chunk.
488
+ id_set: Optional set of IDs to filter PDB files.
489
+ backbone_only: If True, only extract coordinates for alpha carbons (CA).
490
+
491
+ H5 File Structure:
492
+ The output HDF5 file contains the following datasets:
493
+ - ids: UniProt IDs of the proteins (string) [n_proteins]
494
+ - aa_sequences: Amino acid sequences of the proteins (string) [n_proteins]
495
+ - atom_counts: Number of atoms in each protein (integer) [n_proteins]
496
+ - prot_start_idx: Starting index of each protein's atoms (integer) [n_proteins]
497
+ - atom_coords: 3D coordinates of atoms for all proteins (float) [total_atoms, 3]
498
+ - uncertainty: B-factors or uncertainty values for each atom (float) [total_atoms]
499
+
500
+ Benefits for ML Dataloaders:
501
+ 1. Efficient Storage: Atom coordinates are stored sequentially in a single
502
+ contiguous array, allowing for efficient disk I/O and memory usage.
503
+ 2. Fast Retrieval: Using atom_counts and prot_start_idx, dataloaders
504
+ can quickly locate and extract coordinates for specific proteins without
505
+ loading the entire dataset.
506
+ 3. Vectorized Operations: The sequential storage of atom coordinates enables
507
+ efficient vectorized operations on entire proteins or batches of proteins.
508
+ 4. Memory Mapping: The contiguous storage allows for easy memory mapping,
509
+ enabling access to large datasets without loading them entirely into RAM.
510
+ 5. Batch Processing: Dataloaders can efficiently create batches of proteins
511
+ with varying numbers of atoms, using the atom count information to slice
512
+ the coordinate array.
513
+
514
+ Retrieving Atom Coordinates:
515
+ To get the atom coordinates for the i-th protein:
516
+ 1. start_idx = prot_start_idx[i]
517
+ 2. end_idx = start_idx + atom_counts[i]
518
+ 3. coords = atom_coords[start_idx:end_idx]
519
+
520
+ This approach allows for quick access to protein data without loading or
521
+ processing unnecessary information, making it ideal for ML tasks involving
522
+ protein structural data, such as structure prediction or function analysis.
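+
+ Example (illustrative sketch; the directory and file names are placeholders):
+ parse_pdb_folder_to_h5("alphafold_pdbs/", "proteins.h5", backbone_only=True)
+ with h5py.File("proteins.h5", "r") as f:
+ i = 0 # index of the protein of interest
+ start = f["prot_start_idx"][i]
+ end = start + f["atom_counts"][i]
+ coords = f["atom_coords"][start:end] # shape (atom_counts[i], 3)
+ bfactors = f["uncertainty"][start:end]
+ uniprot_id = f["ids"][i].decode()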
523
+ """
524
+ print(f"Starting PDB folder processing")
525
+ print(f"Input directory: {pdb_dir}")
526
+ print(f"Output file: {output_file}")
527
+ print(f"Chunk size: {chunk_size}")
528
+ print(f"Backbone only: {backbone_only}")
529
+ if id_set:
530
+ print(f"Filtering PDB files using {len(id_set)} provided IDs")
531
+
532
+ processor = PDBFolderProcessor(
533
+ pdb_dir, output_file, chunk_size, id_set, backbone_only
534
+ )
535
+ processor.process()
536
+
537
+ print("PDB folder processing completed successfully")
538
+
539
+
540
+ def filter_structural_h5(
541
+ input_file: str, output_file: str, id_set: set[str], chunk_size: int = 1000000
542
+ ) -> set[str]:
543
+ """
544
+ WARNING: This function is impractically slow on files of only ~3 GB, for reasons that are still unclear.
545
+ It is kept here because it works on small datasets, so it is a reasonable starting point for future work.
546
+
547
+ Filter an H5 file generated by parse_pdb_folder_to_h5 to include only specified IDs.
548
+
549
+ Args:
550
+ input_file: Path to the input H5 file.
551
+ output_file: Path to the output H5 file.
552
+ id_set: Set of IDs to include in the output file.
553
+ chunk_size: Number of atoms to process in each chunk.
554
+
555
+ Returns:
556
+ Set of IDs that were not found in the input file.
557
+
558
+ Raises:
559
+ FileNotFoundError: If the input file does not exist.
560
+ FileExistsError: If the output file already exists.
561
+ ValueError: If the input file is missing required datasets or contains unexpected datasets.
562
+ OSError: If there's an error reading from or writing to the H5 files.
563
+
564
+ Note:
565
+ The function expects the input H5 file to have the following datasets:
566
+ - ids
567
+ - aa_sequences
568
+ - atom_counts
569
+ - prot_start_idx
570
+ - atom_coords
571
+ - uncertainty
572
+
573
+ If any of these datasets are missing or if there are additional unexpected
574
+ datasets, a ValueError will be raised.
575
+
576
+ The output file will have the same structure as the input file, but only
577
+ containing data for the specified IDs. The prot_start_idx dataset will be
578
+ recalculated to reflect the new positions of proteins in the filtered file.
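+
+ Example (illustrative; paths and IDs are placeholders):
+ missing = filter_structural_h5("all_proteins.h5", "filtered.h5", {"A3DBM5", "P12345"})
+ print(f"{len(missing)} requested IDs were not present in the input file")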
579
+ """
580
+ print(f"\nStarting H5 filtering process:")
581
+ print(f"Input file: {input_file}")
582
+ print(f"Output file: {output_file}")
583
+ print(f"Number of IDs to filter: {len(id_set)}")
584
+ print(f"Chunk size: {chunk_size} atoms")
585
+
586
+ if not os.path.exists(input_file):
587
+ raise FileNotFoundError(f"Input file not found: {input_file}")
588
+
589
+ if os.path.exists(output_file):
590
+ raise FileExistsError(f"Output file already exists: {output_file}")
591
+
592
+ expected_datasets = [
593
+ "ids",
594
+ "aa_sequences",
595
+ "atom_counts",
596
+ "prot_start_idx",
597
+ "atom_coords",
598
+ "uncertainty",
599
+ ]
600
+
601
+ with h5py.File(input_file, "r") as input_h5:
602
+ # Validate datasets
603
+ missing_datasets = set(expected_datasets) - set(input_h5.keys())
604
+ if missing_datasets:
605
+ raise ValueError(
606
+ f"Missing required datasets: {', '.join(missing_datasets)}"
607
+ )
608
+
609
+ unexpected_datasets = set(input_h5.keys()) - set(expected_datasets)
610
+ if unexpected_datasets:
611
+ raise ValueError(
612
+ f"Unexpected datasets found: {', '.join(unexpected_datasets)}"
613
+ )
614
+
615
+ # Find matching indices
616
+ ids = input_h5["ids"][:]
617
+ id_list = [id.decode() for id in ids]
618
+ indices = [i for i, id in enumerate(id_list) if id in id_set]
619
+
620
+ if not indices:
621
+ return id_set
622
+
623
+ print(f"Found {len(indices)} matching IDs")
624
+
625
+ with h5py.File(output_file, "w") as output_h5:
626
+ # Copy basic datasets
627
+ output_h5.create_dataset("ids", data=input_h5["ids"][indices])
628
+ output_h5.create_dataset(
629
+ "aa_sequences", data=input_h5["aa_sequences"][indices]
630
+ )
631
+ output_h5.create_dataset(
632
+ "atom_counts", data=input_h5["atom_counts"][indices]
633
+ )
634
+
635
+ # Calculate new prot_start_idx
636
+ atom_counts = input_h5["atom_counts"][indices]
637
+ prot_start_idx = np.zeros(len(indices), dtype=int)
638
+ np.cumsum(atom_counts[:-1], out=prot_start_idx[1:])
639
+ output_h5.create_dataset("prot_start_idx", data=prot_start_idx)
640
+
641
+ # Create output datasets for coordinates and uncertainty
642
+ total_atoms = prot_start_idx[-1] + atom_counts[-1]
643
+ output_h5.create_dataset("atom_coords", shape=(total_atoms, 3), dtype=float)
644
+ output_h5.create_dataset("uncertainty", shape=(total_atoms,), dtype=float)
645
+
646
+ # Copy atom data in chunks
647
+ output_idx = 0
648
+ for i, protein_idx in enumerate(tqdm(indices, desc="Filtering proteins")):
649
+ start_idx = input_h5["prot_start_idx"][protein_idx]
650
+ n_atoms = atom_counts[i]
651
+ end_idx = start_idx + n_atoms
652
+
653
+ # Process this protein's atoms in chunks if needed
654
+ for chunk_start in range(start_idx, end_idx, chunk_size):
655
+ chunk_end = min(chunk_start + chunk_size, end_idx)
656
+ chunk_size_actual = chunk_end - chunk_start
657
+
658
+ output_h5["atom_coords"][
659
+ output_idx : output_idx + chunk_size_actual
660
+ ] = input_h5["atom_coords"][chunk_start:chunk_end]
661
+ output_h5["uncertainty"][
662
+ output_idx : output_idx + chunk_size_actual
663
+ ] = input_h5["uncertainty"][chunk_start:chunk_end]
664
+
665
+ output_idx += chunk_size_actual
666
+
667
+ not_found = id_set - set(id_list)
668
+ return not_found
669
+
670
+
671
+ def update_cumulative_to_start_idx(input_file: str, output_file: str):
672
+ """
673
+ Update an H5 file that uses cumulative_atom_counts to use prot_start_idx instead.
674
+
675
+ This function reads the existing cumulative_atom_counts dataset, calculates the
676
+ corresponding prot_start_idx, and writes a new H5 file with the updated structure.
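+ For example, assuming the usual convention of a leading zero and a trailing total, a
+ cumulative_atom_counts dataset of [0, 120, 350, 500] becomes prot_start_idx [0, 120, 350].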
677
+
678
+ Args:
679
+ input_file: Path to the input H5 file with cumulative_atom_counts.
680
+ output_file: Path to the output H5 file that will use prot_start_idx.
681
+
682
+ Raises:
683
+ FileNotFoundError: If the input file does not exist.
684
+ FileExistsError: If the output file already exists.
685
+ ValueError: If the input file is missing required datasets or contains unexpected datasets.
686
+ OSError: If there's an error reading from or writing to the H5 files.
687
+ """
688
+ if not os.path.exists(input_file):
689
+ raise FileNotFoundError(f"Input file not found: {input_file}")
690
+
691
+ if os.path.exists(output_file):
692
+ raise FileExistsError(f"Output file already exists: {output_file}")
693
+
694
+ expected_datasets = [
695
+ "ids",
696
+ "aa_sequences",
697
+ "atom_counts",
698
+ "cumulative_atom_counts",
699
+ "atom_coords",
700
+ "uncertainty",
701
+ ]
702
+
703
+ with (
704
+ h5py.File(input_file, "r") as input_h5,
705
+ h5py.File(output_file, "w") as output_h5,
706
+ ):
707
+ # Check if the input file has the expected structure
708
+ missing_datasets = set(expected_datasets) - set(input_h5.keys())
709
+ if missing_datasets:
710
+ raise ValueError(
711
+ f"Missing required datasets: {', '.join(missing_datasets)}"
712
+ )
713
+
714
+ unexpected_datasets = set(input_h5.keys()) - set(expected_datasets)
715
+ if unexpected_datasets:
716
+ raise ValueError(
717
+ f"Unexpected datasets found: {', '.join(unexpected_datasets)}"
718
+ )
719
+
720
+ # Copy datasets with original compression and chunking if available
721
+ for dataset in [
722
+ "ids",
723
+ "aa_sequences",
724
+ "atom_counts",
725
+ "atom_coords",
726
+ "uncertainty",
727
+ ]:
728
+ input_dataset = input_h5[dataset]
729
+ kwargs = {}
730
+ if input_dataset.compression is not None:
731
+ kwargs["compression"] = input_dataset.compression
732
+ kwargs["compression_opts"] = input_dataset.compression_opts
733
+ if input_dataset.chunks is not None:
734
+ kwargs["chunks"] = input_dataset.chunks
735
+ output_h5.create_dataset(dataset, data=input_dataset, **kwargs)
736
+
737
+ # Calculate prot_start_idx from cumulative_atom_counts
738
+ cumulative_atom_counts = input_h5["cumulative_atom_counts"][:]
739
+ prot_start_idx = cumulative_atom_counts[:-1]
740
+
741
+ # Create prot_start_idx dataset with compression and chunking similar to cumulative_atom_counts
742
+ cumulative_dataset = input_h5["cumulative_atom_counts"]
743
+ kwargs = {}
744
+ if cumulative_dataset.compression is not None:
745
+ kwargs["compression"] = cumulative_dataset.compression
746
+ kwargs["compression_opts"] = cumulative_dataset.compression_opts
747
+ if cumulative_dataset.chunks is not None:
748
+ kwargs["chunks"] = cumulative_dataset.chunks
749
+ output_h5.create_dataset("prot_start_idx", data=prot_start_idx, **kwargs)
750
+
751
+ print(f"Updated H5 file saved to {output_file}")