bio2zarr 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



bio2zarr/vcf_utils.py ADDED
@@ -0,0 +1,513 @@
1
+ from typing import IO, Any, Dict, Optional, Sequence, Union
2
+ import contextlib
3
+ import struct
4
+ import pathlib
5
+ import gzip
6
+ from dataclasses import dataclass
7
+ import os
8
+
9
+ import numpy as np
10
+ import cyvcf2
11
+ import humanfriendly
12
+
13
+ from bio2zarr.typing import PathType
14
+
15
+ CSI_EXTENSION = ".csi"
16
+ TABIX_EXTENSION = ".tbi"
17
+ TABIX_LINEAR_INDEX_INTERVAL_SIZE = 1 << 14 # 16kb interval size
18
+
19
+
20
+ def ceildiv(a: int, b: int) -> int:
21
+ """Safe integer ceil function"""
22
+ return -(-a // b)
23
+
24
+
25
+ def get_file_offset(vfp: int) -> int:
26
+ """Convert a block compressed virtual file pointer to a file offset."""
27
+ address_mask = 0xFFFFFFFFFFFF
28
+ return vfp >> 16 & address_mask
29
+
30
+
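+ # Illustrative note (not part of the original module): a BGZF virtual file
+ # pointer packs the offset of a compressed block in its upper 48 bits and
+ # the offset within that block's uncompressed data in its lower 16 bits, e.g.
+ #
+ #     vfp = (1234 << 16) | 56
+ #     get_file_offset(vfp)   # -> 1234, where the compressed block starts
+ #     vfp & 0xFFFF           # -> 56, offset inside the decompressed block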
31
+ def read_bytes_as_value(f: IO[Any], fmt: str, nodata: Optional[Any] = None) -> Any:
32
+ """Read bytes using a `struct` format string and return the unpacked data value.
33
+
34
+ Parameters
35
+ ----------
36
+ f : IO[Any]
37
+ The IO stream to read bytes from.
38
+ fmt : str
39
+ A Python `struct` format string.
40
+ nodata : Optional[Any], optional
41
+ The value to return in case there is no further data in the stream, by default None
42
+
43
+ Returns
44
+ -------
45
+ Any
46
+ The unpacked data value read from the stream.
47
+ """
48
+ data = f.read(struct.calcsize(fmt))
49
+ if not data:
50
+ return nodata
51
+ values = struct.Struct(fmt).unpack(data)
52
+ assert len(values) == 1
53
+ return values[0]
54
+
55
+
56
+ def read_bytes_as_tuple(f: IO[Any], fmt: str) -> Sequence[Any]:
57
+ """Read bytes using a `struct` format string and return the unpacked data values.
58
+
59
+ Parameters
60
+ ----------
61
+ f : IO[Any]
62
+ The IO stream to read bytes from.
63
+ fmt : str
64
+ A Python `struct` format string.
65
+
66
+ Returns
67
+ -------
68
+ Sequence[Any]
69
+ The unpacked data values read from the stream.
70
+ """
71
+ data = f.read(struct.calcsize(fmt))
72
+ return struct.Struct(fmt).unpack(data)
73
+
74
+
75
+ @dataclass
76
+ class Region:
77
+ contig: str
78
+ start: Optional[int] = None
79
+ end: Optional[int] = None
80
+
81
+ def __post_init__(self):
82
+ if self.start is not None:
83
+ self.start = int(self.start)
84
+ assert self.start > 0
85
+ if self.end is not None:
86
+ self.end = int(self.end)
87
+ assert self.end > self.start
88
+
89
+ def __str__(self):
90
+ s = f"{self.contig}"
91
+ if self.start is not None:
92
+ s += f":{self.start}-"
93
+ if self.end is not None:
94
+ s += str(self.end)
95
+ return s
96
+
97
+ # TODO add "parse" class methoda for when we accept regions
98
+ # as input
99
+
100
+
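+ # Illustrative examples (not part of the original module) of the query
+ # strings produced by __str__ above:
+ #
+ #     str(Region("chr1"))            # -> "chr1"
+ #     str(Region("chr1", 100))       # -> "chr1:100-"
+ #     str(Region("chr1", 100, 200))  # -> "chr1:100-200"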
101
+ @dataclass
102
+ class Chunk:
103
+ cnk_beg: int
104
+ cnk_end: int
105
+
106
+
107
+ @dataclass
108
+ class CSIBin:
109
+ bin: int
110
+ loffset: int
111
+ chunks: Sequence[Chunk]
112
+
113
+
114
+ @dataclass
115
+ class CSIIndex:
116
+ min_shift: int
117
+ depth: int
118
+ aux: str
119
+ bins: Sequence[Sequence[CSIBin]]
120
+ record_counts: Sequence[int]
121
+ n_no_coor: int
122
+
123
+ def parse_vcf_aux(self):
124
+ assert len(self.aux) > 0
125
+ # The first 7 values form a Tabix-style header, but it's not clear how
126
+ # to interpret them. The n_ref value there doesn't seem to correspond to
127
+ # the number of contigs anyway, so the header is just
128
+ # ignored for now.
129
+ # values = struct.Struct("<7i").unpack(self.aux[:28])
130
+ # tabix_header = Header(*values, 0)
131
+ names = self.aux[28:]
132
+ # Convert \0-terminated names to strings
133
+ sequence_names = [str(name, "utf-8") for name in names.split(b"\x00")[:-1]]
134
+ return sequence_names
135
+
136
+ def offsets(self) -> Any:
137
+ pseudo_bin = bin_limit(self.min_shift, self.depth) + 1
138
+
139
+ file_offsets = []
140
+ contig_indexes = []
141
+ positions = []
142
+ for contig_index, bins in enumerate(self.bins):
143
+ # bins may be in any order within a contig, so sort by loffset
144
+ for bin in sorted(bins, key=lambda b: b.loffset):
145
+ if bin.bin == pseudo_bin:
146
+ continue # skip pseudo bins
147
+ file_offset = get_file_offset(bin.loffset)
148
+ position = get_first_locus_in_bin(self, bin.bin)
149
+ file_offsets.append(file_offset)
150
+ contig_indexes.append(contig_index)
151
+ positions.append(position)
152
+
153
+ return np.array(file_offsets), np.array(contig_indexes), np.array(positions)
154
+
155
+
156
+ def bin_limit(min_shift: int, depth: int) -> int:
157
+ """Defined in CSI spec"""
158
+ return ((1 << (depth + 1) * 3) - 1) // 7
159
+
160
+
161
+ def get_first_bin_in_level(level: int) -> int:
162
+ return ((1 << level * 3) - 1) // 7
163
+
164
+
165
+ def get_level_size(level: int) -> int:
166
+ return 1 << level * 3
167
+
168
+
169
+ def get_level_for_bin(csi: CSIIndex, bin: int) -> int:
170
+ for i in range(csi.depth, -1, -1):
171
+ if bin >= get_first_bin_in_level(i):
172
+ return i
173
+ raise ValueError(f"Cannot find level for bin {bin}.") # pragma: no cover
174
+
175
+
176
+ def get_first_locus_in_bin(csi: CSIIndex, bin: int) -> int:
177
+ level = get_level_for_bin(csi, bin)
178
+ first_bin_on_level = get_first_bin_in_level(level)
179
+ level_size = get_level_size(level)
180
+ max_span = 1 << (csi.min_shift + 3 * csi.depth)
181
+ return (bin - first_bin_on_level) * (max_span // level_size) + 1
182
+
183
+
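+ # Worked example (illustrative, not part of the original module): with the
+ # usual VCF/BCF parameters min_shift=14 and depth=5, bin_limit(14, 5) is
+ # ((1 << 18) - 1) // 7 == 37449, so the pseudo-bin number is 37450 (matching
+ # tabix). The deepest level starts at get_first_bin_in_level(5) == 4681 and
+ # each of its bins spans 1 << 14 == 16384 bases, so for such an index
+ # get_first_locus_in_bin(csi, 4682) == 16385.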
184
+ def read_csi(
185
+ file: PathType, storage_options: Optional[Dict[str, str]] = None
186
+ ) -> CSIIndex:
187
+ """Parse a CSI file into a `CSIIndex` object.
188
+
189
+ Parameters
190
+ ----------
191
+ file : PathType
192
+ The path to the CSI file.
193
+
194
+ Returns
195
+ -------
196
+ CSIIndex
197
+ An object representing a CSI index.
198
+
199
+ Raises
200
+ ------
201
+ ValueError
202
+ If the file is not a CSI file.
203
+ """
204
+ with gzip.open(file) as f:
205
+ magic = read_bytes_as_value(f, "4s")
206
+ if magic != b"CSI\x01":
207
+ raise ValueError("File not in CSI format.")
208
+
209
+ min_shift, depth, l_aux = read_bytes_as_tuple(f, "<3i")
210
+ aux = read_bytes_as_value(f, f"{l_aux}s", "")
211
+ n_ref = read_bytes_as_value(f, "<i")
212
+
213
+ pseudo_bin = bin_limit(min_shift, depth) + 1
214
+
215
+ bins = []
216
+ record_counts = []
217
+
218
+ if n_ref > 0:
219
+ for _ in range(n_ref):
220
+ n_bin = read_bytes_as_value(f, "<i")
221
+ seq_bins = []
222
+ record_count = -1
223
+ for _ in range(n_bin):
224
+ bin, loffset, n_chunk = read_bytes_as_tuple(f, "<IQi")
225
+ chunks = []
226
+ for _ in range(n_chunk):
227
+ chunk = Chunk(*read_bytes_as_tuple(f, "<QQ"))
228
+ chunks.append(chunk)
229
+ seq_bins.append(CSIBin(bin, loffset, chunks))
230
+
231
+ if bin == pseudo_bin:
232
+ assert len(chunks) == 2
233
+ n_mapped, n_unmapped = chunks[1].cnk_beg, chunks[1].cnk_end
234
+ record_count = n_mapped + n_unmapped
235
+ bins.append(seq_bins)
236
+ record_counts.append(record_count)
237
+
238
+ n_no_coor = read_bytes_as_value(f, "<Q", 0)
239
+
240
+ assert len(f.read(1)) == 0
241
+
242
+ return CSIIndex(min_shift, depth, aux, bins, record_counts, n_no_coor)
243
+
244
+
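+ # Illustrative usage (not part of the original module), assuming a
+ # hypothetical sample.bcf.csi exists:
+ #
+ #     csi = read_csi("sample.bcf.csi")
+ #     csi.record_counts    # per-contig record counts, -1 where unknown
+ #     csi.offsets()        # (file_offsets, contig_indexes, positions) arrays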
245
+ @dataclass
246
+ class Header:
247
+ n_ref: int
248
+ format: int
249
+ col_seq: int
250
+ col_beg: int
251
+ col_end: int
252
+ meta: int
253
+ skip: int
254
+ l_nm: int
255
+
256
+
257
+ @dataclass
258
+ class TabixBin:
259
+ bin: int
260
+ chunks: Sequence[Chunk]
261
+
262
+
263
+ @dataclass
264
+ class TabixIndex:
265
+ header: Header
266
+ sequence_names: Sequence[str]
267
+ bins: Sequence[Sequence[TabixBin]]
268
+ linear_indexes: Sequence[Sequence[int]]
269
+ record_counts: Sequence[int]
270
+ n_no_coor: int
271
+
272
+ def offsets(self) -> Any:
273
+ # Combine the linear indexes into one stacked array
274
+ linear_indexes = self.linear_indexes
275
+ linear_index = np.hstack([np.array(li) for li in linear_indexes])
276
+
277
+ # Create file offsets for each element in the linear index
278
+ file_offsets = np.array([get_file_offset(vfp) for vfp in linear_index])
279
+
280
+ # Calculate corresponding contigs and positions for each element in the linear index
281
+ contig_indexes = np.hstack(
282
+ [np.full(len(li), i) for (i, li) in enumerate(linear_indexes)]
283
+ )
284
+ # positions are 1-based and inclusive
285
+ positions = np.hstack(
286
+ [
287
+ np.arange(len(li)) * TABIX_LINEAR_INDEX_INTERVAL_SIZE + 1
288
+ for li in linear_indexes
289
+ ]
290
+ )
291
+ assert len(file_offsets) == len(contig_indexes)
292
+ assert len(file_offsets) == len(positions)
293
+
294
+ return file_offsets, contig_indexes, positions
295
+
296
+
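+ # Illustrative note (not part of the original module): each linear-index
+ # entry covers a 16 kb window, so a contig with three entries yields
+ # positions [1, 16385, 32769] from the arithmetic above
+ # (np.arange(3) * TABIX_LINEAR_INDEX_INTERVAL_SIZE + 1).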
297
+ def read_tabix(
298
+ file: PathType, storage_options: Optional[Dict[str, str]] = None
299
+ ) -> TabixIndex:
300
+ """Parse a tabix file into a `TabixIndex` object.
301
+
302
+ Parameters
303
+ ----------
304
+ file : PathType
305
+ The path to the tabix file.
306
+
307
+ Returns
308
+ -------
309
+ TabixIndex
310
+ An object representing a tabix index.
311
+
312
+ Raises
313
+ ------
314
+ ValueError
315
+ If the file is not a tabix file.
316
+ """
317
+ with gzip.open(file) as f:
318
+ magic = read_bytes_as_value(f, "4s")
319
+ if magic != b"TBI\x01":
320
+ raise ValueError("File not in Tabix format.")
321
+
322
+ header = Header(*read_bytes_as_tuple(f, "<8i"))
323
+
324
+ sequence_names = []
325
+ bins = []
326
+ linear_indexes = []
327
+ record_counts = []
328
+
329
+ if header.l_nm > 0:
330
+ names = read_bytes_as_value(f, f"<{header.l_nm}s")
331
+ # Convert \0-terminated names to strings
332
+ sequence_names = [str(name, "utf-8") for name in names.split(b"\x00")[:-1]]
333
+
334
+ for _ in range(header.n_ref):
335
+ n_bin = read_bytes_as_value(f, "<i")
336
+ seq_bins = []
337
+ record_count = -1
338
+ for _ in range(n_bin):
339
+ bin, n_chunk = read_bytes_as_tuple(f, "<Ii")
340
+ chunks = []
341
+ for _ in range(n_chunk):
342
+ chunk = Chunk(*read_bytes_as_tuple(f, "<QQ"))
343
+ chunks.append(chunk)
344
+ seq_bins.append(TabixBin(bin, chunks))
345
+
346
+ if bin == 37450: # pseudo-bin, see section 5.2 of BAM spec
347
+ assert len(chunks) == 2
348
+ n_mapped, n_unmapped = chunks[1].cnk_beg, chunks[1].cnk_end
349
+ record_count = n_mapped + n_unmapped
350
+ n_intv = read_bytes_as_value(f, "<i")
351
+ linear_index = []
352
+ for _ in range(n_intv):
353
+ ioff = read_bytes_as_value(f, "<Q")
354
+ linear_index.append(ioff)
355
+ bins.append(seq_bins)
356
+ linear_indexes.append(linear_index)
357
+ record_counts.append(record_count)
358
+
359
+ n_no_coor = read_bytes_as_value(f, "<Q", 0)
360
+
361
+ assert len(f.read(1)) == 0
362
+
363
+ return TabixIndex(
364
+ header, sequence_names, bins, linear_indexes, record_counts, n_no_coor
365
+ )
366
+
367
+
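+ # Illustrative usage (not part of the original module), assuming a
+ # hypothetical sample.vcf.gz.tbi exists:
+ #
+ #     tbx = read_tabix("sample.vcf.gz.tbi")
+ #     tbx.sequence_names   # contig names stored in the index
+ #     tbx.offsets()        # (file_offsets, contig_indexes, positions) arrays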
368
+ class IndexedVcf(contextlib.AbstractContextManager):
369
+ def __init__(self, vcf_path, index_path=None):
370
+ self.vcf = None
371
+ vcf_path = pathlib.Path(vcf_path)
372
+ if not vcf_path.exists():
373
+ raise FileNotFoundError(vcf_path)
374
+ # TODO use constants here instead of strings
375
+ if index_path is None:
376
+ index_path = vcf_path.with_suffix(vcf_path.suffix + ".tbi")
377
+ if not index_path.exists():
378
+ index_path = vcf_path.with_suffix(vcf_path.suffix + ".csi")
379
+ if not index_path.exists():
380
+ raise FileNotFoundError(
381
+ "Cannot find .tbi or .csi file for {vcf_path}"
382
+ )
383
+ else:
384
+ index_path = pathlib.Path(index_path)
385
+
386
+ self.vcf_path = vcf_path
387
+ self.index_path = index_path
388
+ # TODO use Enums for these
389
+ self.file_type = None
390
+ self.index_type = None
391
+ if index_path.suffix == ".csi":
392
+ self.index_type = "csi"
393
+ elif index_path.suffix == ".tbi":
394
+ self.index_type = "tabix"
395
+ self.file_type = "vcf"
396
+ else:
397
+ raise ValueError("Only .tbi or .csi indexes are supported.")
398
+ self.vcf = cyvcf2.VCF(vcf_path)
399
+ self.vcf.set_index(str(self.index_path))
400
+ self.sequence_names = None
401
+ if self.index_type == "csi":
402
+ # Determine the file-type based on the "aux" field.
403
+ self.index = read_csi(self.index_path)
404
+ self.file_type = "bcf"
405
+ if len(self.index.aux) > 0:
406
+ self.file_type = "vcf"
407
+ self.sequence_names = self.index.parse_vcf_aux()
408
+ else:
409
+ self.sequence_names = self.vcf.seqnames
410
+ else:
411
+ self.index = read_tabix(self.index_path)
412
+ self.sequence_names = self.index.sequence_names
413
+
414
+ def __exit__(self, exc_type, exc_val, exc_tb):
415
+ if self.vcf is not None:
416
+ self.vcf.close()
417
+ self.vcf = None
418
+ return False
419
+
420
+ def contig_record_counts(self):
421
+ d = dict(zip(self.sequence_names, self.index.record_counts))
422
+ if self.file_type == "bcf":
423
+ d = {k: v for k, v in d.items() if v > 0}
424
+ return d
425
+
426
+ def count_variants(self, region):
427
+ return sum(1 for _ in self.variants(region))
428
+
429
+ def variants(self, region):
430
+ # Need to filter because of indels overlapping the region
431
+ start = 1 if region.start is None else region.start
432
+ for var in self.vcf(str(region)):
433
+ if var.POS >= start:
434
+ yield var
435
+
436
+ def partition_into_regions(
437
+ self,
438
+ num_parts: Optional[int] = None,
439
+ target_part_size: Union[None, int, str] = None,
440
+ ):
441
+ if num_parts is None and target_part_size is None:
442
+ raise ValueError("One of num_parts or target_part_size must be specified")
443
+
444
+ if num_parts is not None and target_part_size is not None:
445
+ raise ValueError(
446
+ "Only one of num_parts or target_part_size may be specified"
447
+ )
448
+
449
+ if num_parts is not None and num_parts < 1:
450
+ raise ValueError("num_parts must be positive")
451
+
452
+ if target_part_size is not None:
453
+ if isinstance(target_part_size, int):
454
+ target_part_size_bytes = target_part_size
455
+ else:
456
+ target_part_size_bytes = humanfriendly.parse_size(target_part_size)
457
+ if target_part_size_bytes < 1:
458
+ raise ValueError("target_part_size must be positive")
459
+
460
+ # Calculate the desired part file boundaries
461
+ file_length = os.stat(self.vcf_path).st_size
462
+ if num_parts is not None:
463
+ target_part_size_bytes = file_length // num_parts
464
+ elif target_part_size_bytes is not None:
465
+ num_parts = ceildiv(file_length, target_part_size_bytes)
466
+ part_lengths = np.array([i * target_part_size_bytes for i in range(num_parts)])
467
+
468
+ file_offsets, region_contig_indexes, region_positions = self.index.offsets()
469
+
470
+ # Search the file offsets to find which indexes the part lengths fall at
471
+ ind = np.searchsorted(file_offsets, part_lengths)
472
+
473
+ # Drop any parts that are greater than the file offsets
474
+ # (these will be covered by a region with no end)
475
+ ind = np.delete(ind, ind >= len(file_offsets))
476
+
477
+ # Drop any duplicates
478
+ ind = np.unique(ind)
479
+
480
+ # Calculate region contig and start for each index
481
+ region_contigs = region_contig_indexes[ind]
482
+ region_starts = region_positions[ind]
483
+
484
+ # Build region query strings
485
+ regions = []
486
+ for i in range(len(region_starts)):
487
+ contig = self.sequence_names[region_contigs[i]]
488
+ start = region_starts[i]
489
+
490
+ if i == len(region_starts) - 1: # final region
491
+ regions.append(Region(contig, start))
492
+ else:
493
+ next_contig = self.sequence_names[region_contigs[i + 1]]
494
+ next_start = region_starts[i + 1]
495
+ end = next_start - 1 # subtract one since positions are inclusive
496
+ # print("next_start", next_contig, next_start)
497
+ if next_contig == contig: # contig doesn't change
498
+ regions.append(Region(contig, start, end))
499
+ else:
500
+ # contig changes, so need two regions (or possibly more if any
501
+ # sequences were skipped)
502
+ regions.append(Region(contig, start))
503
+ for ri in range(region_contigs[i] + 1, region_contigs[i + 1]):
504
+ regions.append(Region(self.sequence_names[ri]))
505
+ if end >= 1:
506
+ regions.append(Region(next_contig, 1, end))
507
+
508
+ # Add any remaining contigs at the end, skipping those with no records
509
+ for ri in range(region_contigs[-1] + 1, len(self.sequence_names)):
510
+ if self.index.record_counts[ri] > 0:
511
+ regions.append(Region(self.sequence_names[ri]))
512
+
513
+ return regions
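
A minimal usage sketch of this module (illustrative only; the file name is hypothetical, and any bgzip-compressed VCF with a .tbi or .csi index alongside it would do):

    from bio2zarr.vcf_utils import IndexedVcf

    with IndexedVcf("sample.vcf.gz") as ivcf:
        # Per-contig record counts read from the index, without scanning the VCF
        print(ivcf.contig_record_counts())
        # Split the file into roughly 100 MB chunks of work
        for region in ivcf.partition_into_regions(target_part_size="100MB"):
            print(region, ivcf.count_variants(region))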