bio2zarr 0.0.1 (bio2zarr-0.0.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of bio2zarr might be problematic.

bio2zarr/vcf.py ADDED
@@ -0,0 +1,1802 @@
1
+ import collections
2
+ import dataclasses
3
+ import functools
4
+ import logging
5
+ import os
6
+ import pathlib
7
+ import pickle
8
+ import sys
9
+ import shutil
10
+ import json
11
+ import math
12
+ import tempfile
13
+ import contextlib
14
+ from typing import Any, List
15
+
16
+ import humanfriendly
17
+ import cyvcf2
18
+ import numcodecs
19
+ import numpy as np
20
+ import numpy.testing as nt
21
+ import tqdm
22
+ import zarr
23
+
24
+ from . import core
25
+ from . import provenance
26
+ from . import vcf_utils
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+ INT_MISSING = -1
31
+ INT_FILL = -2
32
+ STR_MISSING = "."
33
+ STR_FILL = ""
34
+
35
+ FLOAT32_MISSING, FLOAT32_FILL = np.array([0x7F800001, 0x7F800002], dtype=np.int32).view(
36
+ np.float32
37
+ )
38
+ FLOAT32_MISSING_AS_INT32, FLOAT32_FILL_AS_INT32 = np.array(
39
+ [0x7F800001, 0x7F800002], dtype=np.int32
40
+ )
41
+
42
+
43
+ @dataclasses.dataclass
44
+ class VcfFieldSummary:
45
+ num_chunks: int = 0
46
+ compressed_size: int = 0
47
+ uncompressed_size: int = 0
48
+ max_number: int = 0 # Corresponds to VCF Number field, depends on context
49
+ # Only defined for numeric fields
50
+ max_value: Any = -math.inf
51
+ min_value: Any = math.inf
52
+
53
+ def update(self, other):
54
+ self.num_chunks += other.num_chunks
55
+ self.compressed_size += other.compressed_size
56
+ self.uncompressed_size += other.uncompressed_size
57
+ self.max_number = max(self.max_number, other.max_number)
58
+ self.min_value = min(self.min_value, other.min_value)
59
+ self.max_value = max(self.max_value, other.max_value)
60
+
61
+
62
+ @dataclasses.dataclass
63
+ class VcfField:
64
+ category: str
65
+ name: str
66
+ vcf_number: str
67
+ vcf_type: str
68
+ description: str
69
+ summary: VcfFieldSummary
70
+
71
+ @staticmethod
72
+ def from_header(definition):
73
+ category = definition["HeaderType"]
74
+ name = definition["ID"]
75
+ vcf_number = definition["Number"]
76
+ vcf_type = definition["Type"]
77
+ return VcfField(
78
+ category=category,
79
+ name=name,
80
+ vcf_number=vcf_number,
81
+ vcf_type=vcf_type,
82
+ description=definition["Description"].strip('"'),
83
+ summary=VcfFieldSummary(),
84
+ )
85
+
86
+ @staticmethod
87
+ def fromdict(d):
88
+ f = VcfField(**d)
89
+ f.summary = VcfFieldSummary(**d["summary"])
90
+ return f
91
+
92
+ @property
93
+ def full_name(self):
94
+ if self.category == "fixed":
95
+ return self.name
96
+ return f"{self.category}/{self.name}"
97
+
98
+ # TODO add a method here to choose a good set of compressor and
99
+ # filter defaults for this field.
100
+
101
+ def smallest_dtype(self):
102
+ """
103
+ Returns the smallest dtype suitable for this field based
104
+ on type, and values.
105
+ """
106
+ s = self.summary
107
+ if self.vcf_type == "Float":
108
+ ret = "f4"
109
+ elif self.vcf_type == "Integer":
110
+ dtype = "i4"
111
+ for a_dtype in ["i1", "i2"]:
112
+ info = np.iinfo(a_dtype)
113
+ if info.min <= s.min_value and s.max_value <= info.max:
114
+ dtype = a_dtype
115
+ break
116
+ ret = dtype
117
+ elif self.vcf_type == "Flag":
118
+ ret = "bool"
119
+ elif self.vcf_type == "Character":
120
+ ret = "U1"
121
+ else:
122
+ assert self.vcf_type == "String"
123
+ ret = "O"
124
+ return ret
125
+
126
+
127
+ @dataclasses.dataclass
128
+ class VcfPartition:
129
+ vcf_path: str
130
+ region: str
131
+ num_records: int = -1
132
+
133
+
134
+ @dataclasses.dataclass
135
+ class VcfMetadata:
136
+ format_version: str
137
+ samples: list
138
+ contig_names: list
139
+ contig_record_counts: dict
140
+ filters: list
141
+ fields: list
142
+ partitions: list = None
143
+ contig_lengths: list = None
144
+
145
+ @property
146
+ def info_fields(self):
147
+ fields = []
148
+ for field in self.fields:
149
+ if field.category == "INFO":
150
+ fields.append(field)
151
+ return fields
152
+
153
+ @property
154
+ def format_fields(self):
155
+ fields = []
156
+ for field in self.fields:
157
+ if field.category == "FORMAT":
158
+ fields.append(field)
159
+ return fields
160
+
161
+ @property
162
+ def num_records(self):
163
+ return sum(self.contig_record_counts.values())
164
+
165
+ @staticmethod
166
+ def fromdict(d):
167
+ fields = [VcfField.fromdict(fd) for fd in d["fields"]]
168
+ partitions = [VcfPartition(**pd) for pd in d["partitions"]]
169
+ d = d.copy()
170
+ d["fields"] = fields
171
+ d["partitions"] = partitions
172
+ return VcfMetadata(**d)
173
+
174
+ def asdict(self):
175
+ return dataclasses.asdict(self)
176
+
177
+
178
+ def fixed_vcf_field_definitions():
179
+ def make_field_def(name, vcf_type, vcf_number):
180
+ return VcfField(
181
+ category="fixed",
182
+ name=name,
183
+ vcf_type=vcf_type,
184
+ vcf_number=vcf_number,
185
+ description="",
186
+ summary=VcfFieldSummary(),
187
+ )
188
+
189
+ fields = [
190
+ make_field_def("CHROM", "String", "1"),
191
+ make_field_def("POS", "Integer", "1"),
192
+ make_field_def("QUAL", "Float", "1"),
193
+ make_field_def("ID", "String", "."),
194
+ make_field_def("FILTERS", "String", "."),
195
+ make_field_def("REF", "String", "1"),
196
+ make_field_def("ALT", "String", "."),
197
+ ]
198
+ return fields
199
+
200
+
201
+ def scan_vcf(path, target_num_partitions):
202
+ with vcf_utils.IndexedVcf(path) as indexed_vcf:
203
+ vcf = indexed_vcf.vcf
204
+ filters = [
205
+ h["ID"]
206
+ for h in vcf.header_iter()
207
+ if h["HeaderType"] == "FILTER" and isinstance(h["ID"], str)
208
+ ]
209
+ # Ensure PASS is the first filter if present
210
+ if "PASS" in filters:
211
+ filters.remove("PASS")
212
+ filters.insert(0, "PASS")
213
+
214
+ fields = fixed_vcf_field_definitions()
215
+ for h in vcf.header_iter():
216
+ if h["HeaderType"] in ["INFO", "FORMAT"]:
217
+ field = VcfField.from_header(h)
218
+ if field.name == "GT":
219
+ field.vcf_type = "Integer"
220
+ field.vcf_number = "."
221
+ fields.append(field)
222
+
223
+ metadata = VcfMetadata(
224
+ samples=vcf.samples,
225
+ contig_names=vcf.seqnames,
226
+ contig_record_counts=indexed_vcf.contig_record_counts(),
227
+ filters=filters,
228
+ # TODO use the mapping dictionary
229
+ fields=fields,
230
+ partitions=[],
231
+ # FIXME do something systematic with this
232
+ format_version="0.1",
233
+ )
234
+ try:
235
+ metadata.contig_lengths = vcf.seqlens
236
+ except AttributeError:
237
+ pass
238
+
239
+ regions = indexed_vcf.partition_into_regions(num_parts=target_num_partitions)
240
+ logger.info(
241
+ f"Split {path} into {len(regions)} regions (target={target_num_partitions})"
242
+ )
243
+ for region in regions:
244
+ metadata.partitions.append(
245
+ VcfPartition(
246
+ vcf_path=str(path),
247
+ region=region,
248
+ )
249
+ )
250
+ core.update_progress(1)
251
+ return metadata, vcf.raw_header
252
+
253
+
254
+ def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
255
+ logger.info(f"Scanning {len(paths)} VCFs")
256
+ progress_config = core.ProgressConfig(
257
+ total=len(paths),
258
+ units="files",
259
+ title="Scan",
260
+ show=show_progress,
261
+ )
262
+ with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
263
+ for path in paths:
264
+ pwm.submit(scan_vcf, path, target_num_partitions)
265
+ results = list(pwm.results_as_completed())
266
+
267
+ # Sort to make the ordering deterministic
268
+ results.sort(key=lambda t: t[0].partitions[0].vcf_path)
269
+ # We just take the first header, assuming the others
270
+ # are compatible.
271
+ all_partitions = []
272
+ contig_record_counts = collections.Counter()
273
+ for metadata, _ in results:
274
+ all_partitions.extend(metadata.partitions)
275
+ metadata.partitions.clear()
276
+ contig_record_counts += metadata.contig_record_counts
277
+ metadata.contig_record_counts.clear()
278
+
279
+ vcf_metadata, header = results[0]
280
+ for metadata, _ in results[1:]:
281
+ if metadata != vcf_metadata:
282
+ raise ValueError("Incompatible VCF chunks")
283
+
284
+ vcf_metadata.contig_record_counts = dict(contig_record_counts)
285
+
286
+ # Sort by contig (in the order they appear in the header) first,
287
+ # then by start coordinate
288
+ contig_index_map = {contig: j for j, contig in enumerate(metadata.contig_names)}
289
+ all_partitions.sort(
290
+ key=lambda x: (contig_index_map[x.region.contig], x.region.start)
291
+ )
292
+ vcf_metadata.partitions = all_partitions
293
+ return vcf_metadata, header
294
+
295
+
296
+ def sanitise_value_bool(buff, j, value):
297
+ x = True
298
+ if value is None:
299
+ x = False
300
+ buff[j] = x
301
+
302
+
303
+ def sanitise_value_float_scalar(buff, j, value):
304
+ x = value
305
+ if value is None:
306
+ x = FLOAT32_MISSING
307
+ buff[j] = x
308
+
309
+
310
+ def sanitise_value_int_scalar(buff, j, value):
311
+ x = value
312
+ if value is None:
313
+ # print("MISSING", INT_MISSING, INT_FILL)
314
+ x = [INT_MISSING]
315
+ else:
316
+ x = sanitise_int_array([value], ndmin=1, dtype=np.int32)
317
+ buff[j] = x[0]
318
+
319
+
320
+ def sanitise_value_string_scalar(buff, j, value):
321
+ if value is None:
322
+ buff[j] = "."
323
+ else:
324
+ buff[j] = value[0]
325
+
326
+
327
+ def sanitise_value_string_1d(buff, j, value):
328
+ if value is None:
329
+ buff[j] = "."
330
+ else:
331
+ # value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False)
332
+ # FIXME failure isn't coming from here, it seems to be from an
333
+ # incorrectly detected dimension in the zarr array
334
+ # The dimensions look all wrong, and the dtype should be Object
335
+ # not str
336
+ value = drop_empty_second_dim(value)
337
+ buff[j] = ""
338
+ buff[j, : value.shape[0]] = value
339
+
340
+
341
+ def sanitise_value_string_2d(buff, j, value):
342
+ if value is None:
343
+ buff[j] = "."
344
+ else:
345
+ # print(buff.shape, value.dtype, value)
346
+ # assert value.ndim == 2
347
+ buff[j] = ""
348
+ if value.ndim == 2:
349
+ buff[j, :, : value.shape[1]] = value
350
+ else:
351
+ # TODO check if this is still necessary
352
+ for k, val in enumerate(value):
353
+ buff[j, k, : len(val)] = val
354
+
355
+
356
+ def drop_empty_second_dim(value):
357
+ assert len(value.shape) == 1 or value.shape[1] == 1
358
+ if len(value.shape) == 2 and value.shape[1] == 1:
359
+ value = value[..., 0]
360
+ return value
361
+
362
+
363
+ def sanitise_value_float_1d(buff, j, value):
364
+ if value is None:
365
+ buff[j] = FLOAT32_MISSING
366
+ else:
367
+ value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False)
368
+ # numpy will map None values to NaN, but we need a
369
+ # specific NaN
370
+ value[np.isnan(value)] = FLOAT32_MISSING
371
+ value = drop_empty_second_dim(value)
372
+ buff[j] = FLOAT32_FILL
373
+ buff[j, : value.shape[0]] = value
374
+
375
+
376
+ def sanitise_value_float_2d(buff, j, value):
377
+ if value is None:
378
+ buff[j] = FLOAT32_MISSING
379
+ else:
380
+ # print("value = ", value)
381
+ value = np.array(value, ndmin=2, dtype=buff.dtype, copy=False)
382
+ buff[j] = FLOAT32_FILL
383
+ buff[j, :, : value.shape[1]] = value
384
+
385
+
386
+ def sanitise_int_array(value, ndmin, dtype):
387
+ if isinstance(value, tuple):
388
+ value = [VCF_INT_MISSING if x is None else x for x in value]
389
+ value = np.array(value, ndmin=ndmin, copy=False)
390
+ value[value == VCF_INT_MISSING] = -1
391
+ value[value == VCF_INT_FILL] = -2
392
+ # TODO watch out for clipping here!
393
+ return value.astype(dtype)
394
+
395
+
396
+ def sanitise_value_int_1d(buff, j, value):
397
+ if value is None:
398
+ buff[j] = -1
399
+ else:
400
+ value = sanitise_int_array(value, 1, buff.dtype)
401
+ value = drop_empty_second_dim(value)
402
+ buff[j] = -2
403
+ buff[j, : value.shape[0]] = value
404
+
405
+
406
+ def sanitise_value_int_2d(buff, j, value):
407
+ if value is None:
408
+ buff[j] = -1
409
+ else:
410
+ value = sanitise_int_array(value, 2, buff.dtype)
411
+ buff[j] = -2
412
+ buff[j, :, : value.shape[1]] = value
413
+
414
+
415
+ MIN_INT_VALUE = np.iinfo(np.int32).min + 2
416
+ VCF_INT_MISSING = np.iinfo(np.int32).min
417
+ VCF_INT_FILL = np.iinfo(np.int32).min + 1
418
+
419
+ missing_value_map = {
420
+ "Integer": -1,
421
+ "Float": FLOAT32_MISSING,
422
+ "String": ".",
423
+ "Character": ".",
424
+ "Flag": False,
425
+ }
426
+
427
+
428
+ class VcfValueTransformer:
429
+ """
430
+ Transform VCF values into the stored intermediate format used
431
+ in the PickleChunkedVcf, and update field summaries.
432
+ """
433
+
434
+ def __init__(self, field, num_samples):
435
+ self.field = field
436
+ self.num_samples = num_samples
437
+ self.dimension = 1
438
+ if field.category == "FORMAT":
439
+ self.dimension = 2
440
+ self.missing = missing_value_map[field.vcf_type]
441
+
442
+ @staticmethod
443
+ def factory(field, num_samples):
444
+ if field.vcf_type in ("Integer", "Flag"):
445
+ return IntegerValueTransformer(field, num_samples)
446
+ if field.vcf_type == "Float":
447
+ return FloatValueTransformer(field, num_samples)
448
+ if field.name in ["REF", "FILTERS", "ALT", "ID", "CHROM"]:
449
+ return SplitStringValueTransformer(field, num_samples)
450
+ return StringValueTransformer(field, num_samples)
451
+
452
+ def transform(self, vcf_value):
453
+ if isinstance(vcf_value, tuple):
454
+ vcf_value = [self.missing if v is None else v for v in vcf_value]
455
+ value = np.array(vcf_value, ndmin=self.dimension, copy=False)
456
+ return value
457
+
458
+ def transform_and_update_bounds(self, vcf_value):
459
+ if vcf_value is None:
460
+ return None
461
+ value = self.transform(vcf_value)
462
+ self.update_bounds(value)
463
+ # print(self.field.full_name, "T", vcf_value, "->", value)
464
+ return value
465
+
466
+
472
+ class IntegerValueTransformer(VcfValueTransformer):
473
+ def update_bounds(self, value):
474
+ summary = self.field.summary
475
+ # Mask out missing and fill values
476
+ # print(value)
477
+ a = value[value >= MIN_INT_VALUE]
478
+ if a.size > 0:
479
+ summary.max_value = int(max(summary.max_value, np.max(a)))
480
+ summary.min_value = int(min(summary.min_value, np.min(a)))
481
+ number = value.shape[-1]
482
+ summary.max_number = max(summary.max_number, number)
483
+
484
+
485
+ class FloatValueTransformer(VcfValueTransformer):
486
+ def update_bounds(self, value):
487
+ summary = self.field.summary
488
+ summary.max_value = float(max(summary.max_value, np.max(value)))
489
+ summary.min_value = float(min(summary.min_value, np.min(value)))
490
+ number = value.shape[-1]
491
+ summary.max_number = max(summary.max_number, number)
492
+
493
+
494
+ class StringValueTransformer(VcfValueTransformer):
495
+ def update_bounds(self, value):
496
+ summary = self.field.summary
497
+ number = value.shape[-1]
498
+ # TODO would be nice to report string lengths, but not
499
+ # really necessary.
500
+ summary.max_number = max(summary.max_number, number)
501
+
502
+ def transform(self, vcf_value):
503
+ # print("transform", vcf_value)
504
+ if self.dimension == 1:
505
+ value = np.array(list(vcf_value.split(",")))
506
+ else:
507
+ # TODO can we make this faster??
508
+ value = np.array([v.split(",") for v in vcf_value], dtype="O")
509
+ # print("HERE", vcf_value, value)
510
+ # for v in vcf_value:
511
+ # print("\t", type(v), len(v), v.split(","))
512
+ # print("S: ", self.dimension, ":", value.shape, value)
513
+ return value
514
+
515
+
516
+ class SplitStringValueTransformer(StringValueTransformer):
517
+ def transform(self, vcf_value):
518
+ if vcf_value is None:
519
+ return self.missing
520
+ assert self.dimension == 1
521
+ return np.array(vcf_value, ndmin=1, dtype="str")
522
+
523
+
524
+ class PickleChunkedVcfField:
525
+ def __init__(self, pcvcf, vcf_field):
526
+ self.vcf_field = vcf_field
527
+ self.path = self.get_path(pcvcf.path, vcf_field)
528
+ self.compressor = pcvcf.compressor
529
+ self.num_partitions = pcvcf.num_partitions
530
+ self.num_records = pcvcf.num_records
531
+ self.partition_record_index = pcvcf.partition_record_index
532
+ # A map of partition id to the cumulative number of records
533
+ # in chunks within that partition
534
+ self._chunk_record_index = {}
535
+
536
+ @staticmethod
537
+ def get_path(base_path, vcf_field):
538
+ if vcf_field.category == "fixed":
539
+ return base_path / vcf_field.name
540
+ return base_path / vcf_field.category / vcf_field.name
541
+
542
+ @property
543
+ def name(self):
544
+ return self.vcf_field.full_name
545
+
546
+ def partition_path(self, partition_id):
547
+ return self.path / f"p{partition_id}"
548
+
549
+ def __repr__(self):
550
+ partition_chunks = [self.num_chunks(j) for j in range(self.num_partitions)]
551
+ return (
552
+ f"PickleChunkedVcfField(name={self.name}, "
553
+ f"partition_chunks={partition_chunks}, "
554
+ f"path={self.path})"
555
+ )
556
+
557
+ def num_chunks(self, partition_id):
558
+ return len(self.chunk_cumulative_records(partition_id))
559
+
560
+ def chunk_record_index(self, partition_id):
561
+ if partition_id not in self._chunk_record_index:
562
+ index_path = self.partition_path(partition_id) / "chunk_index.pkl"
563
+ with open(index_path, "rb") as f:
564
+ a = pickle.load(f)
565
+ assert len(a) > 1
566
+ assert a[0] == 0
567
+ self._chunk_record_index[partition_id] = a
568
+ return self._chunk_record_index[partition_id]
569
+
570
+ def chunk_cumulative_records(self, partition_id):
571
+ return self.chunk_record_index(partition_id)[1:]
572
+
573
+ def chunk_num_records(self, partition_id):
574
+ return np.diff(self.chunk_cumulative_records(partition_id))
575
+
576
+ def chunk_files(self, partition_id, start=0):
577
+ partition_path = self.partition_path(partition_id)
578
+ for n in self.chunk_cumulative_records(partition_id)[start:]:
579
+ yield partition_path / f"{n}.pkl"
580
+
581
+ def read_chunk(self, path):
582
+ with open(path, "rb") as f:
583
+ pkl = self.compressor.decode(f.read())
584
+ return pickle.loads(pkl)
585
+
586
+ def iter_values(self, start=None, stop=None):
587
+ start = 0 if start is None else start
588
+ stop = self.num_records if stop is None else stop
589
+ start_partition = (
590
+ np.searchsorted(self.partition_record_index, start, side="right") - 1
591
+ )
592
+ offset = self.partition_record_index[start_partition]
593
+ assert offset <= start
594
+ chunk_offset = start - offset
595
+
596
+ chunk_record_index = self.chunk_record_index(start_partition)
597
+ start_chunk = (
598
+ np.searchsorted(chunk_record_index, chunk_offset, side="right") - 1
599
+ )
600
+ record_id = offset + chunk_record_index[start_chunk]
601
+ assert record_id <= start
602
+ logger.debug(
603
+ f"Read {self.vcf_field.full_name} slice [{start}:{stop}]:"
604
+ f"p_start={start_partition}, c_start={start_chunk}, r_start={record_id}"
605
+ )
606
+
607
+ for chunk_path in self.chunk_files(start_partition, start_chunk):
608
+ chunk = self.read_chunk(chunk_path)
609
+ for record in chunk:
610
+ if record_id == stop:
611
+ return
612
+ if record_id >= start:
613
+ yield record
614
+ record_id += 1
615
+ assert record_id > start
616
+ for partition_id in range(start_partition + 1, self.num_partitions):
617
+ for chunk_path in self.chunk_files(partition_id):
618
+ chunk = self.read_chunk(chunk_path)
619
+ for record in chunk:
620
+ if record_id == stop:
621
+ return
622
+ yield record
623
+ record_id += 1
624
+
625
+ # Note: this involves some computation so should arguably be a method,
626
+ # but making a property for consistency with xarray etc
627
+ @property
628
+ def values(self):
629
+ ret = [None] * self.num_records
630
+ j = 0
631
+ for partition_id in range(self.num_partitions):
632
+ for chunk_path in self.chunk_files(partition_id):
633
+ chunk = self.read_chunk(chunk_path)
634
+ for record in chunk:
635
+ ret[j] = record
636
+ j += 1
637
+ if j != self.num_records:
638
+ raise ValueError(
639
+ f"Corruption detected: incorrect number of records in {str(self.path)}."
640
+ )
641
+ return ret
642
+
643
+ def sanitiser_factory(self, shape):
644
+ """
645
+ Return a function that sanitises values from this column
646
+ and writes them into a buffer of the specified shape.
647
+ """
648
+ assert len(shape) <= 3
649
+ if self.vcf_field.vcf_type == "Flag":
650
+ assert len(shape) == 1
651
+ return sanitise_value_bool
652
+ elif self.vcf_field.vcf_type == "Float":
653
+ if len(shape) == 1:
654
+ return sanitise_value_float_scalar
655
+ elif len(shape) == 2:
656
+ return sanitise_value_float_1d
657
+ else:
658
+ return sanitise_value_float_2d
659
+ elif self.vcf_field.vcf_type == "Integer":
660
+ if len(shape) == 1:
661
+ return sanitise_value_int_scalar
662
+ elif len(shape) == 2:
663
+ return sanitise_value_int_1d
664
+ else:
665
+ return sanitise_value_int_2d
666
+ else:
667
+ assert self.vcf_field.vcf_type in ("String", "Character")
668
+ if len(shape) == 1:
669
+ return sanitise_value_string_scalar
670
+ elif len(shape) == 2:
671
+ return sanitise_value_string_1d
672
+ else:
673
+ return sanitise_value_string_2d
674
+
675
+
676
+ @dataclasses.dataclass
677
+ class PcvcfFieldWriter:
678
+ vcf_field: VcfField
679
+ path: pathlib.Path
680
+ transformer: VcfValueTransformer
681
+ compressor: Any
682
+ max_buffered_bytes: int
683
+ buff: List[Any] = dataclasses.field(default_factory=list)
684
+ buffered_bytes: int = 0
685
+ chunk_index: List[int] = dataclasses.field(default_factory=lambda: [0])
686
+ num_records: int = 0
687
+
688
+ def append(self, val):
689
+ val = self.transformer.transform_and_update_bounds(val)
690
+ assert val is None or isinstance(val, np.ndarray)
691
+ self.buff.append(val)
692
+ val_bytes = sys.getsizeof(val)
693
+ self.buffered_bytes += val_bytes
694
+ self.num_records += 1
695
+ if self.buffered_bytes >= self.max_buffered_bytes:
696
+ logger.debug(
697
+ f"Flush {self.path} buffered={self.buffered_bytes} "
698
+ f"max={self.max_buffered_bytes}"
699
+ )
700
+ self.write_chunk()
701
+ self.buff.clear()
702
+ self.buffered_bytes = 0
703
+
704
+ def write_chunk(self):
705
+ # Update index
706
+ self.chunk_index.append(self.num_records)
707
+ path = self.path / f"{self.num_records}.pkl"
708
+ logger.debug(f"Start write: {path}")
709
+ pkl = pickle.dumps(self.buff)
710
+ compressed = self.compressor.encode(pkl)
711
+ with open(path, "wb") as f:
712
+ f.write(compressed)
713
+
714
+ # Update the summary
715
+ self.vcf_field.summary.num_chunks += 1
716
+ self.vcf_field.summary.compressed_size += len(compressed)
717
+ self.vcf_field.summary.uncompressed_size += self.buffered_bytes
718
+ logger.debug(f"Finish write: {path}")
719
+
720
+ def flush(self):
721
+ logger.debug(
722
+ f"Flush {self.path} records={len(self.buff)} buffered={self.buffered_bytes}"
723
+ )
724
+ if len(self.buff) > 0:
725
+ self.write_chunk()
726
+ with open(self.path / "chunk_index.pkl", "wb") as f:
727
+ a = np.array(self.chunk_index, dtype=int)
728
+ pickle.dump(a, f)
729
+
730
+
731
+ class PcvcfPartitionWriter(contextlib.AbstractContextManager):
732
+ """
733
+ Writes the data for a PickleChunkedVcf partition.
734
+ """
735
+
736
+ def __init__(
737
+ self,
738
+ vcf_metadata,
739
+ out_path,
740
+ partition_index,
741
+ compressor,
742
+ *,
743
+ chunk_size=1,
744
+ ):
745
+ self.partition_index = partition_index
746
+ # chunk_size is in megabytes
747
+ max_buffered_bytes = chunk_size * 2**20
748
+ assert max_buffered_bytes > 0
749
+
750
+ self.field_writers = {}
751
+ num_samples = len(vcf_metadata.samples)
752
+ for vcf_field in vcf_metadata.fields:
753
+ field_path = PickleChunkedVcfField.get_path(out_path, vcf_field)
754
+ field_partition_path = field_path / f"p{partition_index}"
755
+ transformer = VcfValueTransformer.factory(vcf_field, num_samples)
756
+ self.field_writers[vcf_field.full_name] = PcvcfFieldWriter(
757
+ vcf_field,
758
+ field_partition_path,
759
+ transformer,
760
+ compressor,
761
+ max_buffered_bytes,
762
+ )
763
+
764
+ @property
765
+ def field_summaries(self):
766
+ return {
767
+ name: field.vcf_field.summary for name, field in self.field_writers.items()
768
+ }
769
+
770
+ def append(self, name, value):
771
+ self.field_writers[name].append(value)
772
+
773
+ def __exit__(self, exc_type, exc_val, exc_tb):
774
+ if exc_type is None:
775
+ for field in self.field_writers.values():
776
+ field.flush()
777
+ return False
778
+
779
+
780
+ class PickleChunkedVcf(collections.abc.Mapping):
781
+ # TODO Check if other compressors would give reasonable compression
782
+ # with significantly faster times
783
+ DEFAULT_COMPRESSOR = numcodecs.Blosc(cname="zstd", clevel=7)
784
+
785
+ def __init__(self, path, metadata, vcf_header):
786
+ self.path = path
787
+ self.metadata = metadata
788
+ self.vcf_header = vcf_header
789
+ self.compressor = self.DEFAULT_COMPRESSOR
790
+ self.columns = {}
791
+ partition_num_records = [
792
+ partition.num_records for partition in self.metadata.partitions
793
+ ]
794
+ # Allow us to find which partition a given record is in
795
+ self.partition_record_index = np.cumsum([0] + partition_num_records)
796
+ for field in self.metadata.fields:
797
+ self.columns[field.full_name] = PickleChunkedVcfField(self, field)
798
+
799
+ def __repr__(self):
800
+ return (
801
+ f"PickleChunkedVcf(fields={len(self)}, partitions={self.num_partitions}, "
802
+ f"records={self.num_records}, path={self.path})"
803
+ )
804
+
805
+ def __getitem__(self, key):
806
+ return self.columns[key]
807
+
808
+ def __iter__(self):
809
+ return iter(self.columns)
810
+
811
+ def __len__(self):
812
+ return len(self.columns)
813
+
814
+ def summary_table(self):
815
+ def display_number(x):
816
+ ret = "n/a"
817
+ if math.isfinite(x):
818
+ ret = f"{x: 0.2g}"
819
+ return ret
820
+
821
+ def display_size(n):
822
+ return humanfriendly.format_size(n)
823
+
824
+ data = []
825
+ for name, col in self.columns.items():
826
+ summary = col.vcf_field.summary
827
+ d = {
828
+ "name": name,
829
+ "type": col.vcf_field.vcf_type,
830
+ "chunks": summary.num_chunks,
831
+ "size": display_size(summary.uncompressed_size),
832
+ "compressed": display_size(summary.compressed_size),
833
+ "max_n": summary.max_number,
834
+ "min_val": display_number(summary.min_value),
835
+ "max_val": display_number(summary.max_value),
836
+ }
837
+
838
+ data.append(d)
839
+ return data
840
+
841
+ @functools.cached_property
842
+ def total_uncompressed_bytes(self):
843
+ total = 0
844
+ for col in self.columns.values():
845
+ summary = col.vcf_field.summary
846
+ total += summary.uncompressed_size
847
+ return total
848
+
849
+ @functools.cached_property
850
+ def num_records(self):
851
+ return sum(self.metadata.contig_record_counts.values())
852
+
853
+ @property
854
+ def num_partitions(self):
855
+ return len(self.metadata.partitions)
856
+
857
+ @property
858
+ def num_samples(self):
859
+ return len(self.metadata.samples)
860
+
861
+ @property
862
+ def num_columns(self):
863
+ return len(self.columns)
864
+
865
+ def mkdirs(self):
866
+ self.path.mkdir()
867
+ for col in self.columns.values():
868
+ col.path.mkdir(parents=True)
869
+ for j in range(self.num_partitions):
870
+ part_path = col.path / f"p{j}"
871
+ part_path.mkdir()
872
+
873
+ @staticmethod
874
+ def load(path):
875
+ path = pathlib.Path(path)
876
+ with open(path / "metadata.json") as f:
877
+ metadata = VcfMetadata.fromdict(json.load(f))
878
+ with open(path / "header.txt") as f:
879
+ header = f.read()
880
+ pcvcf = PickleChunkedVcf(path, metadata, header)
881
+ logger.info(
882
+ f"Loaded PickleChunkedVcf(partitions={pcvcf.num_partitions}, "
883
+ f"records={pcvcf.num_records}, columns={pcvcf.num_columns})"
884
+ )
885
+ return pcvcf
886
+
887
+ @staticmethod
888
+ def convert_partition(
889
+ vcf_metadata,
890
+ partition_index,
891
+ out_path,
892
+ *,
893
+ column_chunk_size=16,
894
+ ):
895
+ partition = vcf_metadata.partitions[partition_index]
896
+ logger.info(
897
+ f"Start p{partition_index} {partition.vcf_path}__{partition.region}"
898
+ )
899
+ info_fields = vcf_metadata.info_fields
900
+ format_fields = []
901
+ has_gt = False
902
+ for field in vcf_metadata.format_fields:
903
+ if field.name == "GT":
904
+ has_gt = True
905
+ else:
906
+ format_fields.append(field)
907
+
908
+ compressor = PickleChunkedVcf.DEFAULT_COMPRESSOR
909
+
910
+ with PcvcfPartitionWriter(
911
+ vcf_metadata,
912
+ out_path,
913
+ partition_index,
914
+ compressor,
915
+ chunk_size=column_chunk_size,
916
+ ) as tcw:
917
+ with vcf_utils.IndexedVcf(partition.vcf_path) as ivcf:
918
+ num_records = 0
919
+ for variant in ivcf.variants(partition.region):
920
+ num_records += 1
921
+ tcw.append("CHROM", variant.CHROM)
922
+ tcw.append("POS", variant.POS)
923
+ tcw.append("QUAL", variant.QUAL)
924
+ tcw.append("ID", variant.ID)
925
+ tcw.append("FILTERS", variant.FILTERS)
926
+ tcw.append("REF", variant.REF)
927
+ tcw.append("ALT", variant.ALT)
928
+ for field in info_fields:
929
+ tcw.append(field.full_name, variant.INFO.get(field.name, None))
930
+ if has_gt:
931
+ tcw.append("FORMAT/GT", variant.genotype.array())
932
+ for field in format_fields:
933
+ val = None
934
+ try:
935
+ val = variant.format(field.name)
936
+ except KeyError:
937
+ pass
938
+ tcw.append(field.full_name, val)
939
+ # Note: an issue with updating the progress per variant here like this
940
+ # is that we get a significant pause at the end of the counter while
941
+ # all the "small" fields get flushed. Possibly not much to be done about it.
942
+ core.update_progress(1)
943
+
944
+ logger.info(
945
+ f"Finish p{partition_index} {partition.vcf_path}__{partition.region}="
946
+ f"{num_records} records"
947
+ )
948
+ return partition_index, tcw.field_summaries, num_records
949
+
950
+ @staticmethod
951
+ def convert(
952
+ vcfs, out_path, *, column_chunk_size=16, worker_processes=1, show_progress=False
953
+ ):
954
+ out_path = pathlib.Path(out_path)
955
+ # TODO make scan work in parallel using general progress code too
956
+ target_num_partitions = max(1, worker_processes * 4)
957
+ vcf_metadata, header = scan_vcfs(
958
+ vcfs,
959
+ worker_processes=worker_processes,
960
+ show_progress=show_progress,
961
+ target_num_partitions=target_num_partitions,
962
+ )
963
+ pcvcf = PickleChunkedVcf(out_path, vcf_metadata, header)
964
+ pcvcf.mkdirs()
965
+
966
+ logger.info(
967
+ f"Exploding {pcvcf.num_columns} columns {vcf_metadata.num_records} variants "
968
+ f"{pcvcf.num_samples} samples"
969
+ )
970
+ progress_config = core.ProgressConfig(
971
+ total=vcf_metadata.num_records,
972
+ units="vars",
973
+ title="Explode",
974
+ show=show_progress,
975
+ )
976
+ with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
977
+ for j, partition in enumerate(vcf_metadata.partitions):
978
+ pwm.submit(
979
+ PickleChunkedVcf.convert_partition,
980
+ vcf_metadata,
981
+ j,
982
+ out_path,
983
+ column_chunk_size=column_chunk_size,
984
+ )
985
+ num_records = 0
986
+ partition_summaries = []
987
+ for index, summary, num_records in pwm.results_as_completed():
988
+ partition_summaries.append(summary)
989
+ vcf_metadata.partitions[index].num_records = num_records
990
+
991
+ total_records = sum(
992
+ partition.num_records for partition in vcf_metadata.partitions
993
+ )
994
+ assert total_records == pcvcf.num_records
995
+
996
+ for field in vcf_metadata.fields:
997
+ # Clear the summary to avoid problems when running in debug
998
+ # synchronous mode
999
+ field.summary = VcfFieldSummary()
1000
+ for summary in partition_summaries:
1001
+ field.summary.update(summary[field.full_name])
1002
+
1003
+ with open(out_path / "metadata.json", "w") as f:
1004
+ json.dump(vcf_metadata.asdict(), f, indent=4)
1005
+ with open(out_path / "header.txt", "w") as f:
1006
+ f.write(header)
1007
+
1008
+
1009
+ def explode(
1010
+ vcfs,
1011
+ out_path,
1012
+ *,
1013
+ column_chunk_size=16,
1014
+ worker_processes=1,
1015
+ show_progress=False,
1016
+ ):
1017
+ out_path = pathlib.Path(out_path)
1018
+ if out_path.exists():
1019
+ shutil.rmtree(out_path)
1020
+
1021
+ PickleChunkedVcf.convert(
1022
+ vcfs,
1023
+ out_path,
1024
+ column_chunk_size=column_chunk_size,
1025
+ worker_processes=worker_processes,
1026
+ show_progress=show_progress,
1027
+ )
1028
+ return PickleChunkedVcf.load(out_path)
1029
+
1030
+
1031
+ def inspect(if_path):
1032
+ # TODO add support for the Zarr format also
1033
+ pcvcf = PickleChunkedVcf.load(if_path)
1034
+ return pcvcf.summary_table()
1035
+
1036
+
1037
+ @dataclasses.dataclass
1038
+ class ZarrColumnSpec:
1039
+ name: str
1040
+ dtype: str
1041
+ shape: tuple
1042
+ chunks: tuple
1043
+ dimensions: list
1044
+ description: str
1045
+ vcf_field: str
1046
+ compressor: dict
1047
+ # TODO add filters
1048
+
1049
+ def __post_init__(self):
1050
+ self.shape = tuple(self.shape)
1051
+ self.chunks = tuple(self.chunks)
1052
+ self.dimensions = tuple(self.dimensions)
1053
+
1054
+
1055
+ @dataclasses.dataclass
1056
+ class ZarrConversionSpec:
1057
+ format_version: str
1058
+ chunk_width: int
1059
+ chunk_length: int
1060
+ dimensions: list
1061
+ sample_id: list
1062
+ contig_id: list
1063
+ contig_length: list
1064
+ filter_id: list
1065
+ columns: dict
1066
+
1067
+ def asdict(self):
1068
+ return dataclasses.asdict(self)
1069
+
1070
+ def asjson(self):
1071
+ return json.dumps(self.asdict(), indent=4)
1072
+
1073
+ @staticmethod
1074
+ def fromdict(d):
1075
+ ret = ZarrConversionSpec(**d)
1076
+ ret.columns = {
1077
+ key: ZarrColumnSpec(**value) for key, value in d["columns"].items()
1078
+ }
1079
+ return ret
1080
+
1081
+ @staticmethod
1082
+ def fromjson(s):
1083
+ return ZarrConversionSpec.fromdict(json.loads(s))
1084
+
1085
+ @staticmethod
1086
+ def generate(pcvcf, chunk_length=None, chunk_width=None):
1087
+ m = pcvcf.num_records
1088
+ n = pcvcf.num_samples
1089
+ # FIXME
1090
+ if chunk_width is None:
1091
+ chunk_width = 1000
1092
+ if chunk_length is None:
1093
+ chunk_length = 10_000
1094
+ logger.info(f"Generating schema with chunks={chunk_length, chunk_width}")
1095
+ compressor = core.default_compressor.get_config()
1096
+
1097
+ def fixed_field_spec(
1098
+ name, dtype, vcf_field=None, shape=(m,), dimensions=("variants",)
1099
+ ):
1100
+ return ZarrColumnSpec(
1101
+ vcf_field=vcf_field,
1102
+ name=name,
1103
+ dtype=dtype,
1104
+ shape=shape,
1105
+ description="",
1106
+ dimensions=dimensions,
1107
+ chunks=[chunk_length],
1108
+ compressor=compressor,
1109
+ )
1110
+
1111
+ alt_col = pcvcf.columns["ALT"]
1112
+ max_alleles = alt_col.vcf_field.summary.max_number + 1
1113
+ num_filters = len(pcvcf.metadata.filters)
1114
+
1115
+ # # FIXME get dtype from lookup table
1116
+ colspecs = [
1117
+ fixed_field_spec(
1118
+ name="variant_contig",
1119
+ dtype="i2", # FIXME
1120
+ ),
1121
+ fixed_field_spec(
1122
+ name="variant_filter",
1123
+ dtype="bool",
1124
+ shape=(m, num_filters),
1125
+ dimensions=["variants", "filters"],
1126
+ ),
1127
+ fixed_field_spec(
1128
+ name="variant_allele",
1129
+ dtype="str",
1130
+ shape=[m, max_alleles],
1131
+ dimensions=["variants", "alleles"],
1132
+ ),
1133
+ fixed_field_spec(
1134
+ vcf_field="POS",
1135
+ name="variant_position",
1136
+ dtype="i4",
1137
+ ),
1138
+ fixed_field_spec(
1139
+ vcf_field=None,
1140
+ name="variant_id",
1141
+ dtype="str",
1142
+ ),
1143
+ fixed_field_spec(
1144
+ vcf_field=None,
1145
+ name="variant_id_mask",
1146
+ dtype="bool",
1147
+ ),
1148
+ fixed_field_spec(
1149
+ vcf_field="QUAL",
1150
+ name="variant_quality",
1151
+ dtype="f4",
1152
+ ),
1153
+ ]
1154
+
1155
+ gt_field = None
1156
+ for field in pcvcf.metadata.fields:
1157
+ if field.category == "fixed":
1158
+ continue
1159
+ if field.name == "GT":
1160
+ gt_field = field
1161
+ continue
1162
+ shape = [m]
1163
+ prefix = "variant_"
1164
+ dimensions = ["variants"]
1165
+ chunks = [chunk_length]
1166
+ if field.category == "FORMAT":
1167
+ prefix = "call_"
1168
+ shape.append(n)
1169
+ chunks.append(chunk_width)
1170
+ dimensions.append("samples")
1171
+ # TODO make an option to add in the empty extra dimension
1172
+ if field.summary.max_number > 1:
1173
+ shape.append(field.summary.max_number)
1174
+ dimensions.append(field.name)
1175
+ variable_name = prefix + field.name
1176
+ colspec = ZarrColumnSpec(
1177
+ vcf_field=field.full_name,
1178
+ name=variable_name,
1179
+ dtype=field.smallest_dtype(),
1180
+ shape=shape,
1181
+ chunks=chunks,
1182
+ dimensions=dimensions,
1183
+ description=field.description,
1184
+ compressor=compressor,
1185
+ )
1186
+ colspecs.append(colspec)
1187
+
1188
+ if gt_field is not None:
1189
+ ploidy = gt_field.summary.max_number - 1
1190
+ shape = [m, n]
1191
+ chunks = [chunk_length, chunk_width]
1192
+ dimensions = ["variants", "samples"]
1193
+
1194
+ colspecs.append(
1195
+ ZarrColumnSpec(
1196
+ vcf_field=None,
1197
+ name="call_genotype_phased",
1198
+ dtype="bool",
1199
+ shape=list(shape),
1200
+ chunks=list(chunks),
1201
+ dimensions=list(dimensions),
1202
+ description="",
1203
+ compressor=compressor,
1204
+ )
1205
+ )
1206
+ shape += [ploidy]
1207
+ dimensions += ["ploidy"]
1208
+ colspecs.append(
1209
+ ZarrColumnSpec(
1210
+ vcf_field=None,
1211
+ name="call_genotype",
1212
+ dtype=gt_field.smallest_dtype(),
1213
+ shape=list(shape),
1214
+ chunks=list(chunks),
1215
+ dimensions=list(dimensions),
1216
+ description="",
1217
+ compressor=compressor,
1218
+ )
1219
+ )
1220
+ colspecs.append(
1221
+ ZarrColumnSpec(
1222
+ vcf_field=None,
1223
+ name="call_genotype_mask",
1224
+ dtype="bool",
1225
+ shape=list(shape),
1226
+ chunks=list(chunks),
1227
+ dimensions=list(dimensions),
1228
+ description="",
1229
+ compressor=compressor,
1230
+ )
1231
+ )
1232
+
1233
+ return ZarrConversionSpec(
1234
+ # TODO do something systematic
1235
+ format_version="0.1",
1236
+ chunk_width=chunk_width,
1237
+ chunk_length=chunk_length,
1238
+ columns={col.name: col for col in colspecs},
1239
+ dimensions=["variants", "samples", "ploidy", "alleles", "filters"],
1240
+ sample_id=pcvcf.metadata.samples,
1241
+ contig_id=pcvcf.metadata.contig_names,
1242
+ contig_length=pcvcf.metadata.contig_lengths,
1243
+ filter_id=pcvcf.metadata.filters,
1244
+ )
1245
+
1246
+
1247
+ class SgvcfZarr:
1248
+ def __init__(self, path):
1249
+ self.path = pathlib.Path(path)
1250
+ self.root = None
1251
+
1252
+ def create_array(self, variable):
1253
+ # print("CREATE", variable)
1254
+ object_codec = None
1255
+ if variable.dtype == "O":
1256
+ object_codec = numcodecs.VLenUTF8()
1257
+ a = self.root.empty(
1258
+ variable.name,
1259
+ shape=variable.shape,
1260
+ chunks=variable.chunks,
1261
+ dtype=variable.dtype,
1262
+ compressor=numcodecs.get_codec(variable.compressor),
1263
+ object_codec=object_codec,
1264
+ )
1265
+ a.attrs["_ARRAY_DIMENSIONS"] = variable.dimensions
1266
+
1267
+ def encode_column_slice(self, pcvcf, column, start, stop):
1268
+ source_col = pcvcf.columns[column.vcf_field]
1269
+ array = self.root[column.name]
1270
+ ba = core.BufferedArray(array, start)
1271
+ sanitiser = source_col.sanitiser_factory(ba.buff.shape)
1272
+
1273
+ for value in source_col.iter_values(start, stop):
1274
+ # We write directly into the buffer in the sanitiser function
1275
+ # to make it easier to reason about dimension padding
1276
+ j = ba.next_buffer_row()
1277
+ sanitiser(ba.buff, j, value)
1278
+ ba.flush()
1279
+ logger.debug(f"Encoded {column.name} slice {start}:{stop}")
1280
+
1281
+ def encode_genotypes_slice(self, pcvcf, start, stop):
1282
+ source_col = pcvcf.columns["FORMAT/GT"]
1283
+ gt = core.BufferedArray(self.root["call_genotype"], start)
1284
+ gt_mask = core.BufferedArray(self.root["call_genotype_mask"], start)
1285
+ gt_phased = core.BufferedArray(self.root["call_genotype_phased"], start)
1286
+
1287
+ for value in source_col.iter_values(start, stop):
1288
+ j = gt.next_buffer_row()
1289
+ sanitise_value_int_2d(gt.buff, j, value[:, :-1])
1290
+ j = gt_phased.next_buffer_row()
1291
+ sanitise_value_int_1d(gt_phased.buff, j, value[:, -1])
1292
+ # TODO check whether this is the correct semantics when we are padding
1293
+ # with mixed ploidies?
1294
+ j = gt_mask.next_buffer_row()
1295
+ gt_mask.buff[j] = gt.buff[j] < 0
1296
+ gt.flush()
1297
+ gt_phased.flush()
1298
+ gt_mask.flush()
1299
+ logger.debug(f"Encoded GT slice {start}:{stop}")
1300
+
1301
+ def encode_alleles_slice(self, pcvcf, start, stop):
1302
+ ref_col = pcvcf.columns["REF"]
1303
+ alt_col = pcvcf.columns["ALT"]
1304
+ alleles = core.BufferedArray(self.root["variant_allele"], start)
1305
+
1306
+ for ref, alt in zip(
1307
+ ref_col.iter_values(start, stop), alt_col.iter_values(start, stop)
1308
+ ):
1309
+ j = alleles.next_buffer_row()
1310
+ alleles.buff[j, :] = STR_FILL
1311
+ alleles.buff[j, 0] = ref[0]
1312
+ alleles.buff[j, 1 : 1 + len(alt)] = alt
1313
+ alleles.flush()
1314
+ logger.debug(f"Encoded alleles slice {start}:{stop}")
1315
+
1316
+ def encode_id_slice(self, pcvcf, start, stop):
1317
+ col = pcvcf.columns["ID"]
1318
+ vid = core.BufferedArray(self.root["variant_id"], start)
1319
+ vid_mask = core.BufferedArray(self.root["variant_id_mask"], start)
1320
+
1321
+ for value in col.iter_values(start, stop):
1322
+ j = vid.next_buffer_row()
1323
+ k = vid_mask.next_buffer_row()
1324
+ assert j == k
1325
+ if value is not None:
1326
+ vid.buff[j] = value[0]
1327
+ vid_mask.buff[j] = False
1328
+ else:
1329
+ vid.buff[j] = STR_MISSING
1330
+ vid_mask.buff[j] = True
1331
+ vid.flush()
1332
+ vid_mask.flush()
1333
+ logger.debug(f"Encoded ID slice {start}:{stop}")
1334
+
1335
+ def encode_filters_slice(self, pcvcf, lookup, start, stop):
1336
+ col = pcvcf.columns["FILTERS"]
1337
+ var_filter = core.BufferedArray(self.root["variant_filter"], start)
1338
+
1339
+ for value in col.iter_values(start, stop):
1340
+ j = var_filter.next_buffer_row()
1341
+ var_filter.buff[j] = False
1342
+ try:
1343
+ for f in value:
1344
+ var_filter.buff[j, lookup[f]] = True
1345
+ except KeyError:
1346
+ raise ValueError(f"Filter '{f}' was not defined in the header.")
1347
+ var_filter.flush()
1348
+ logger.debug(f"Encoded FILTERS slice {start}:{stop}")
1349
+
1350
+ def encode_contig_slice(self, pcvcf, lookup, start, stop):
1351
+ col = pcvcf.columns["CHROM"]
1352
+ contig = core.BufferedArray(self.root["variant_contig"], start)
1353
+
1354
+ for value in col.iter_values(start, stop):
1355
+ j = contig.next_buffer_row()
1356
+ try:
1357
+ contig.buff[j] = lookup[value[0]]
1358
+ except KeyError:
1359
+ # TODO add advice about adding it to the spec
1360
+ raise ValueError(f"Contig '{contig}' was not defined in the header.")
1361
+ contig.flush()
1362
+ logger.debug(f"Encoded CHROM slice {start}:{stop}")
1363
+
1364
+ def encode_samples(self, pcvcf, sample_id, chunk_width):
1365
+ if not np.array_equal(sample_id, pcvcf.metadata.samples):
1366
+ raise ValueError("Subsetting or reordering samples not supported currently")
1367
+ array = self.root.array(
1368
+ "sample_id",
1369
+ sample_id,
1370
+ dtype="str",
1371
+ compressor=core.default_compressor,
1372
+ chunks=(chunk_width,),
1373
+ )
1374
+ array.attrs["_ARRAY_DIMENSIONS"] = ["samples"]
1375
+ logger.debug("Samples done")
1376
+
1377
+ def encode_contig_id(self, pcvcf, contig_names, contig_lengths):
1378
+ array = self.root.array(
1379
+ "contig_id",
1380
+ contig_names,
1381
+ dtype="str",
1382
+ compressor=core.default_compressor,
1383
+ )
1384
+ array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
1385
+ if contig_lengths is not None:
1386
+ array = self.root.array(
1387
+ "contig_length",
1388
+ contig_lengths,
1389
+ dtype=np.int64,
1390
+ )
1391
+ array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
1392
+ return {v: j for j, v in enumerate(contig_names)}
1393
+
1394
+ def encode_filter_id(self, pcvcf, filter_names):
1395
+ array = self.root.array(
1396
+ "filter_id",
1397
+ filter_names,
1398
+ dtype="str",
1399
+ compressor=core.default_compressor,
1400
+ )
1401
+ array.attrs["_ARRAY_DIMENSIONS"] = ["filters"]
1402
+ return {v: j for j, v in enumerate(filter_names)}
1403
+
1404
+ @staticmethod
1405
+ def encode(
1406
+ pcvcf,
1407
+ path,
1408
+ conversion_spec,
1409
+ *,
1410
+ worker_processes=1,
1411
+ max_v_chunks=None,
1412
+ show_progress=False,
1413
+ ):
1414
+ path = pathlib.Path(path)
1415
+ # TODO: we should do this as a future to avoid blocking
1416
+ if path.exists():
1417
+ logger.warning(f"Deleting existing {path}")
1418
+ shutil.rmtree(path)
1419
+ write_path = path.with_suffix(path.suffix + f".{os.getpid()}.build")
1420
+ store = zarr.DirectoryStore(write_path)
1421
+ # FIXME, duplicating logic about the store
1422
+ logger.info(f"Create zarr at {write_path}")
1423
+ sgvcf = SgvcfZarr(write_path)
1424
+ sgvcf.root = zarr.group(store=store, overwrite=True)
1425
+ for column in conversion_spec.columns.values():
1426
+ sgvcf.create_array(column)
1427
+
1428
+ sgvcf.root.attrs["vcf_zarr_version"] = "0.2"
1429
+ sgvcf.root.attrs["vcf_header"] = pcvcf.vcf_header
1430
+ sgvcf.root.attrs["source"] = f"bio2zarr-{provenance.__version__}"
1431
+
1432
+ num_slices = max(1, worker_processes * 4)
1433
+ # Using POS arbitrarily to get the array slices
1434
+ slices = core.chunk_aligned_slices(
1435
+ sgvcf.root["variant_position"], num_slices, max_chunks=max_v_chunks
1436
+ )
1437
+ truncated = slices[-1][-1]
1438
+ for array in sgvcf.root.values():
1439
+ if array.attrs["_ARRAY_DIMENSIONS"][0] == "variants":
1440
+ shape = list(array.shape)
1441
+ shape[0] = truncated
1442
+ array.resize(shape)
1443
+
1444
+ chunked_1d = [
1445
+ col for col in conversion_spec.columns.values() if len(col.chunks) <= 1
1446
+ ]
1447
+ progress_config = core.ProgressConfig(
1448
+ total=sum(sgvcf.root[col.name].nchunks for col in chunked_1d),
1449
+ title="Encode 1D",
1450
+ units="chunks",
1451
+ show=show_progress,
1452
+ )
1453
+
1454
+ # Do these synchronously for simplicity so we have the mapping
1455
+ filter_id_map = sgvcf.encode_filter_id(pcvcf, conversion_spec.filter_id)
1456
+ contig_id_map = sgvcf.encode_contig_id(
1457
+ pcvcf, conversion_spec.contig_id, conversion_spec.contig_length
1458
+ )
1459
+
1460
+ with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
1461
+ pwm.submit(
1462
+ sgvcf.encode_samples,
1463
+ pcvcf,
1464
+ conversion_spec.sample_id,
1465
+ conversion_spec.chunk_width,
1466
+ )
1467
+ for start, stop in slices:
1468
+ pwm.submit(sgvcf.encode_alleles_slice, pcvcf, start, stop)
1469
+ pwm.submit(sgvcf.encode_id_slice, pcvcf, start, stop)
1470
+ pwm.submit(
1471
+ sgvcf.encode_filters_slice, pcvcf, filter_id_map, start, stop
1472
+ )
1473
+ pwm.submit(sgvcf.encode_contig_slice, pcvcf, contig_id_map, start, stop)
1474
+ for col in chunked_1d:
1475
+ if col.vcf_field is not None:
1476
+ pwm.submit(sgvcf.encode_column_slice, pcvcf, col, start, stop)
1477
+
1478
+ chunked_2d = [
1479
+ col for col in conversion_spec.columns.values() if len(col.chunks) >= 2
1480
+ ]
1481
+ if len(chunked_2d) > 0:
1482
+ progress_config = core.ProgressConfig(
1483
+ total=sum(sgvcf.root[col.name].nchunks for col in chunked_2d),
1484
+ title="Encode 2D",
1485
+ units="chunks",
1486
+ show=show_progress,
1487
+ )
1488
+ with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
1489
+ if "call_genotype" in conversion_spec.columns:
1490
+ logger.info(f"Submit encode call_genotypes in {len(slices)} slices")
1491
+ for start, stop in slices:
1492
+ pwm.submit(sgvcf.encode_genotypes_slice, pcvcf, start, stop)
1493
+
1494
+ for col in chunked_2d:
1495
+ if col.vcf_field is not None:
1496
+ logger.info(f"Submit encode {col.name} in {len(slices)} slices")
1497
+ for start, stop in slices:
1498
+ pwm.submit(
1499
+ sgvcf.encode_column_slice, pcvcf, col, start, stop
1500
+ )
1501
+
1502
+ zarr.consolidate_metadata(write_path)
1503
+ # Atomic swap, now we've completely finished.
1504
+ logger.info(f"Moving to final path {path}")
1505
+ os.rename(write_path, path)
1506
+
1507
+
1508
+ def mkschema(if_path, out):
1509
+ pcvcf = PickleChunkedVcf.load(if_path)
1510
+ spec = ZarrConversionSpec.generate(pcvcf)
1511
+ out.write(spec.asjson())
1512
+
1513
+
1514
+ def encode(
1515
+ if_path,
1516
+ zarr_path,
1517
+ schema_path=None,
1518
+ chunk_length=None,
1519
+ chunk_width=None,
1520
+ max_v_chunks=None,
1521
+ worker_processes=1,
1522
+ show_progress=False,
1523
+ ):
1524
+ pcvcf = PickleChunkedVcf.load(if_path)
1525
+ if schema_path is None:
1526
+ schema = ZarrConversionSpec.generate(
1527
+ pcvcf,
1528
+ chunk_length=chunk_length,
1529
+ chunk_width=chunk_width,
1530
+ )
1531
+ else:
1532
+ logger.info(f"Reading schema from {schema_path}")
1533
+ if chunk_length is not None or chunk_width is not None:
1534
+ raise ValueError("Cannot specify schema along with chunk sizes")
1535
+ with open(schema_path, "r") as f:
1536
+ schema = ZarrConversionSpec.fromjson(f.read())
1537
+
1538
+ SgvcfZarr.encode(
1539
+ pcvcf,
1540
+ zarr_path,
1541
+ conversion_spec=schema,
1542
+ max_v_chunks=max_v_chunks,
1543
+ worker_processes=worker_processes,
1544
+ show_progress=show_progress,
1545
+ )
1546
+
1547
+
1548
+ def convert(
1549
+ vcfs,
1550
+ out_path,
1551
+ *,
1552
+ chunk_length=None,
1553
+ chunk_width=None,
1554
+ worker_processes=1,
1555
+ show_progress=False,
1556
+ # TODO add arguments to control location of tmpdir
1557
+ ):
1558
+ with tempfile.TemporaryDirectory(prefix="vcf2zarr_if_") as if_dir:
1559
+ explode(
1560
+ vcfs,
1561
+ if_dir,
1562
+ worker_processes=worker_processes,
1563
+ show_progress=show_progress,
1564
+ )
1565
+ encode(
1566
+ if_dir,
1567
+ out_path,
1568
+ chunk_length=chunk_length,
1569
+ chunk_width=chunk_width,
1570
+ worker_processes=worker_processes,
1571
+ show_progress=show_progress,
1572
+ )
1573
+
1574
+
1575
+ def assert_all_missing_float(a):
1576
+ v = np.array(a, dtype=np.float32).view(np.int32)
1577
+ nt.assert_equal(v, FLOAT32_MISSING_AS_INT32)
1578
+
1579
+
1580
+ def assert_all_fill_float(a):
1581
+ v = np.array(a, dtype=np.float32).view(np.int32)
1582
+ nt.assert_equal(v, FLOAT32_FILL_AS_INT32)
1583
+
1584
+
1585
+ def assert_all_missing_int(a):
1586
+ v = np.array(a, dtype=int)
1587
+ nt.assert_equal(v, -1)
1588
+
1589
+
1590
+ def assert_all_fill_int(a):
1591
+ v = np.array(a, dtype=int)
1592
+ nt.assert_equal(v, -2)
1593
+
1594
+
1595
+ def assert_all_missing_string(a):
1596
+ nt.assert_equal(a, ".")
1597
+
1598
+
1599
+ def assert_all_fill_string(a):
1600
+ nt.assert_equal(a, "")
1601
+
1602
+
1603
+ def assert_all_fill(zarr_val, vcf_type):
1604
+ if vcf_type == "Integer":
1605
+ assert_all_fill_int(zarr_val)
1606
+ elif vcf_type in ("String", "Character"):
1607
+ assert_all_fill_string(zarr_val)
1608
+ elif vcf_type == "Float":
1609
+ assert_all_fill_float(zarr_val)
1610
+ else: # pragma: no cover
1611
+ assert False
1612
+
1613
+
1614
+ def assert_all_missing(zarr_val, vcf_type):
1615
+ if vcf_type == "Integer":
1616
+ assert_all_missing_int(zarr_val)
1617
+ elif vcf_type in ("String", "Character"):
1618
+ assert_all_missing_string(zarr_val)
1619
+ elif vcf_type == "Flag":
1620
+ assert zarr_val == False # noqa 712
1621
+ elif vcf_type == "Float":
1622
+ assert_all_missing_float(zarr_val)
1623
+ else: # pragma: no cover
1624
+ assert False
1625
+
1626
+
1627
+ def assert_info_val_missing(zarr_val, vcf_type):
1628
+ assert_all_missing(zarr_val, vcf_type)
1629
+
1630
+
1631
+ def assert_format_val_missing(zarr_val, vcf_type):
1632
+ assert_info_val_missing(zarr_val, vcf_type)
1633
+
1634
+
1635
+ # Note: checking exact equality may prove problematic here
1636
+ # but we should be deterministically storing what cyvcf2
1637
+ # provides, which should compare equal.
1638
+
1639
+
1640
+ def assert_info_val_equal(vcf_val, zarr_val, vcf_type):
1641
+ assert vcf_val is not None
1642
+ if vcf_type in ("String", "Character"):
1643
+ split = list(vcf_val.split(","))
1644
+ k = len(split)
1645
+ if isinstance(zarr_val, str):
1646
+ assert k == 1
1647
+ # Scalar
1648
+ assert vcf_val == zarr_val
1649
+ else:
1650
+ nt.assert_equal(split, zarr_val[:k])
1651
+ assert_all_fill(zarr_val[k:], vcf_type)
1652
+
1653
+ elif isinstance(vcf_val, tuple):
1654
+ vcf_missing_value_map = {
1655
+ "Integer": -1,
1656
+ "Float": FLOAT32_MISSING,
1657
+ }
1658
+ v = [vcf_missing_value_map[vcf_type] if x is None else x for x in vcf_val]
1659
+ missing = np.array([j for j, x in enumerate(vcf_val) if x is None], dtype=int)
1660
+ a = np.array(v)
1661
+ k = len(a)
1662
+ # We are checking for int missing twice here, but it's necessary to have
1663
+ # a separate check for floats because different NaNs compare equal
1664
+ nt.assert_equal(a, zarr_val[:k])
1665
+ assert_all_missing(zarr_val[missing], vcf_type)
1666
+ if k < len(zarr_val):
1667
+ assert_all_fill(zarr_val[k:], vcf_type)
1668
+ else:
1669
+ # Scalar
1670
+ zarr_val = np.array(zarr_val, ndmin=1)
1671
+ assert len(zarr_val.shape) == 1
1672
+ assert vcf_val == zarr_val[0]
1673
+ if len(zarr_val) > 1:
1674
+ assert_all_fill(zarr_val[1:], vcf_type)
1675
+
1676
+
1677
+ def assert_format_val_equal(vcf_val, zarr_val, vcf_type):
1678
+ assert vcf_val is not None
1679
+ assert isinstance(vcf_val, np.ndarray)
1680
+ if vcf_type in ("String", "Character"):
1681
+ assert len(vcf_val) == len(zarr_val)
1682
+ for v, z in zip(vcf_val, zarr_val):
1683
+ split = list(v.split(","))
1684
+ # Note: deliberately duplicating logic here between this and the
1685
+ # INFO col above to make sure all combinations are covered by tests
1686
+ k = len(split)
1687
+ if k == 1:
1688
+ assert v == z
1689
+ else:
1690
+ nt.assert_equal(split, z[:k])
1691
+ assert_all_fill(z[k:], vcf_type)
1692
+ else:
1693
+ assert vcf_val.shape[0] == zarr_val.shape[0]
1694
+ if len(vcf_val.shape) == len(zarr_val.shape) + 1:
1695
+ assert vcf_val.shape[-1] == 1
1696
+ vcf_val = vcf_val[..., 0]
1697
+ assert len(vcf_val.shape) <= 2
1698
+ assert len(vcf_val.shape) == len(zarr_val.shape)
1699
+ if len(vcf_val.shape) == 2:
1700
+ k = vcf_val.shape[1]
1701
+ if zarr_val.shape[1] != k:
1702
+ assert_all_fill(zarr_val[:, k:], vcf_type)
1703
+ zarr_val = zarr_val[:, :k]
1704
+ assert vcf_val.shape == zarr_val.shape
1705
+ if vcf_type == "Integer":
1706
+ vcf_val[vcf_val == VCF_INT_MISSING] = INT_MISSING
1707
+ vcf_val[vcf_val == VCF_INT_FILL] = INT_FILL
1708
+ elif vcf_type == "Float":
1709
+ nt.assert_equal(vcf_val.view(np.int32), zarr_val.view(np.int32))
1710
+
1711
+ nt.assert_equal(vcf_val, zarr_val)
1712
+
1713
+
1714
+ # TODO rename to "verify"
1715
+ def validate(vcf_path, zarr_path, show_progress=False):
1716
+ store = zarr.DirectoryStore(zarr_path)
1717
+
1718
+ root = zarr.group(store=store)
1719
+ pos = root["variant_position"][:]
1720
+ allele = root["variant_allele"][:]
1721
+ chrom = root["contig_id"][:][root["variant_contig"][:]]
1722
+ vid = root["variant_id"][:]
1723
+ call_genotype = None
1724
+ if "call_genotype" in root:
1725
+ call_genotype = iter(root["call_genotype"])
1726
+
1727
+ vcf = cyvcf2.VCF(vcf_path)
1728
+ format_headers = {}
1729
+ info_headers = {}
1730
+ for h in vcf.header_iter():
1731
+ if h["HeaderType"] == "FORMAT":
1732
+ format_headers[h["ID"]] = h
1733
+ if h["HeaderType"] == "INFO":
1734
+ info_headers[h["ID"]] = h
1735
+
1736
+ format_fields = {}
1737
+ info_fields = {}
1738
+ for colname in root.keys():
1739
+ if colname.startswith("call") and not colname.startswith("call_genotype"):
1740
+ vcf_name = colname.split("_", 1)[1]
1741
+ vcf_type = format_headers[vcf_name]["Type"]
1742
+ format_fields[vcf_name] = vcf_type, iter(root[colname])
1743
+ if colname.startswith("variant"):
1744
+ name = colname.split("_", 1)[1]
1745
+ if name.isupper():
1746
+ vcf_type = info_headers[name]["Type"]
1747
+ # print(root[colname])
1748
+ info_fields[name] = vcf_type, iter(root[colname])
1749
+ # print(info_fields)
1750
+
1751
+ first_pos = next(vcf).POS
1752
+ start_index = np.searchsorted(pos, first_pos)
1753
+ assert pos[start_index] == first_pos
1754
+ vcf = cyvcf2.VCF(vcf_path)
1755
+ if show_progress:
1756
+ iterator = tqdm.tqdm(vcf, desc=" Verify", total=vcf.num_records)
1757
+ else:
1758
+ iterator = vcf
1759
+ for j, row in enumerate(iterator, start_index):
1760
+ assert chrom[j] == row.CHROM
1761
+ assert pos[j] == row.POS
1762
+ assert vid[j] == ("." if row.ID is None else row.ID)
1763
+ assert allele[j, 0] == row.REF
1764
+ k = len(row.ALT)
1765
+ nt.assert_array_equal(allele[j, 1 : k + 1], row.ALT)
1766
+ assert np.all(allele[j, k + 1 :] == "")
1767
+ # TODO FILTERS
1768
+
1769
+ if call_genotype is None:
1770
+ val = None
1771
+ try:
1772
+ val = row.format("GT")
1773
+ except KeyError:
1774
+ pass
1775
+ assert val is None
1776
+ else:
1777
+ gt = row.genotype.array()
1778
+ gt_zarr = next(call_genotype)
1779
+ gt_vcf = gt[:, :-1]
1780
+ # NOTE cyvcf2 remaps genotypes automatically
1781
+ # into the same missing/pad encoding that sgkit uses.
1782
+ nt.assert_array_equal(gt_zarr, gt_vcf)
1783
+
1784
+ for name, (vcf_type, zarr_iter) in info_fields.items():
1785
+ vcf_val = row.INFO.get(name, None)
1786
+ zarr_val = next(zarr_iter)
1787
+ if vcf_val is None:
1788
+ assert_info_val_missing(zarr_val, vcf_type)
1789
+ else:
1790
+ assert_info_val_equal(vcf_val, zarr_val, vcf_type)
1791
+
1792
+ for name, (vcf_type, zarr_iter) in format_fields.items():
1793
+ vcf_val = None
1794
+ try:
1795
+ vcf_val = row.format(name)
1796
+ except KeyError:
1797
+ pass
1798
+ zarr_val = next(zarr_iter)
1799
+ if vcf_val is None:
1800
+ assert_format_val_missing(zarr_val, vcf_type)
1801
+ else:
1802
+ assert_format_val_equal(vcf_val, zarr_val, vcf_type)
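
The module added above implements a two-stage pipeline: explode() converts one or more indexed VCFs into a pickle-chunked columnar intermediate store, encode() turns that store into a VCF Zarr, convert() chains both steps through a temporary directory, and validate() cross-checks the resulting Zarr against the original VCF. A minimal usage sketch follows, based only on the function signatures defined in this file; the file paths, chunk sizes, and worker counts are hypothetical and the import path assumes a standard package layout.

    # Minimal sketch, assuming bio2zarr 0.0.1 is installed; paths are illustrative.
    from bio2zarr import vcf

    # One-shot conversion via a temporary intermediate directory.
    vcf.convert(["sample.vcf.gz"], "sample.vcf.zarr", worker_processes=4, show_progress=True)

    # Or run the stages separately, keeping the intermediate columnar store.
    vcf.explode(["sample.vcf.gz"], "sample.exploded", worker_processes=4)

    # Optionally generate a conversion schema that can be edited before encoding.
    with open("schema.json", "w") as f:
        vcf.mkschema("sample.exploded", f)
    vcf.encode("sample.exploded", "sample.vcf.zarr", schema_path="schema.json")

    # Cross-check the Zarr output against the original VCF.
    vcf.validate("sample.vcf.gz", "sample.vcf.zarr")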