bio2zarr 0.0.6__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bio2zarr might be problematic. Click here for more details.

bio2zarr/vcf.py DELETED
@@ -1,2406 +0,0 @@
1
- import collections
2
- import contextlib
3
- import dataclasses
4
- import functools
5
- import json
6
- import logging
7
- import math
8
- import os
9
- import pathlib
10
- import pickle
11
- import shutil
12
- import sys
13
- import tempfile
14
- from typing import Any, List
15
-
16
- import cyvcf2
17
- import humanfriendly
18
- import numcodecs
19
- import numpy as np
20
- import numpy.testing as nt
21
- import tqdm
22
- import zarr
23
-
24
- from . import core, provenance, vcf_utils
25
-
26
- logger = logging.getLogger(__name__)
27
-
28
- INT_MISSING = -1
29
- INT_FILL = -2
30
- STR_MISSING = "."
31
- STR_FILL = ""
32
-
33
- FLOAT32_MISSING, FLOAT32_FILL = np.array([0x7F800001, 0x7F800002], dtype=np.int32).view(
34
- np.float32
35
- )
36
- FLOAT32_MISSING_AS_INT32, FLOAT32_FILL_AS_INT32 = np.array(
37
- [0x7F800001, 0x7F800002], dtype=np.int32
38
- )
39
-
40
-
41
- def display_number(x):
42
- ret = "n/a"
43
- if math.isfinite(x):
44
- ret = f"{x: 0.2g}"
45
- return ret
46
-
47
-
48
- def display_size(n):
49
- return humanfriendly.format_size(n, binary=True)
50
-
51
-
52
- @dataclasses.dataclass
53
- class VcfFieldSummary:
54
- num_chunks: int = 0
55
- compressed_size: int = 0
56
- uncompressed_size: int = 0
57
- max_number: int = 0 # Corresponds to VCF Number field, depends on context
58
- # Only defined for numeric fields
59
- max_value: Any = -math.inf
60
- min_value: Any = math.inf
61
-
62
- def update(self, other):
63
- self.num_chunks += other.num_chunks
64
- self.compressed_size += other.compressed_size
65
- self.uncompressed_size += other.uncompressed_size
66
- self.max_number = max(self.max_number, other.max_number)
67
- self.min_value = min(self.min_value, other.min_value)
68
- self.max_value = max(self.max_value, other.max_value)
69
-
70
- def asdict(self):
71
- return dataclasses.asdict(self)
72
-
73
- @staticmethod
74
- def fromdict(d):
75
- return VcfFieldSummary(**d)
76
-
77
-
78
- @dataclasses.dataclass
79
- class VcfField:
80
- category: str
81
- name: str
82
- vcf_number: str
83
- vcf_type: str
84
- description: str
85
- summary: VcfFieldSummary
86
-
87
- @staticmethod
88
- def from_header(definition):
89
- category = definition["HeaderType"]
90
- name = definition["ID"]
91
- vcf_number = definition["Number"]
92
- vcf_type = definition["Type"]
93
- return VcfField(
94
- category=category,
95
- name=name,
96
- vcf_number=vcf_number,
97
- vcf_type=vcf_type,
98
- description=definition["Description"].strip('"'),
99
- summary=VcfFieldSummary(),
100
- )
101
-
102
- @staticmethod
103
- def fromdict(d):
104
- f = VcfField(**d)
105
- f.summary = VcfFieldSummary(**d["summary"])
106
- return f
107
-
108
- @property
109
- def full_name(self):
110
- if self.category == "fixed":
111
- return self.name
112
- return f"{self.category}/{self.name}"
113
-
114
- def smallest_dtype(self):
115
- """
116
- Returns the smallest dtype suitable for this field based
117
- on type, and values.
118
- """
119
- s = self.summary
120
- if self.vcf_type == "Float":
121
- ret = "f4"
122
- elif self.vcf_type == "Integer":
123
- if not math.isfinite(s.max_value):
124
- # All missing values; use i1. Note we should have some API to
125
- # check more explicitly for missingness:
126
- # https://github.com/sgkit-dev/bio2zarr/issues/131
127
- ret = "i1"
128
- else:
129
- ret = core.min_int_dtype(s.min_value, s.max_value)
130
- elif self.vcf_type == "Flag":
131
- ret = "bool"
132
- elif self.vcf_type == "Character":
133
- ret = "U1"
134
- else:
135
- assert self.vcf_type == "String"
136
- ret = "O"
137
- return ret
138
-
139
-
140
- @dataclasses.dataclass
141
- class VcfPartition:
142
- vcf_path: str
143
- region: str
144
- num_records: int = -1
145
-
146
-
147
- ICF_METADATA_FORMAT_VERSION = "0.2"
148
- ICF_DEFAULT_COMPRESSOR = numcodecs.Blosc(
149
- cname="zstd", clevel=7, shuffle=numcodecs.Blosc.NOSHUFFLE
150
- )
151
-
152
- # TODO refactor this to have embedded Contig dataclass, Filters
153
- # and Samples dataclasses to allow for more information to be
154
- # retained and forward compatibility.
155
-
156
-
157
- @dataclasses.dataclass
158
- class IcfMetadata:
159
- samples: list
160
- contig_names: list
161
- contig_record_counts: dict
162
- filters: list
163
- fields: list
164
- partitions: list = None
165
- contig_lengths: list = None
166
- format_version: str = None
167
- compressor: dict = None
168
- column_chunk_size: int = None
169
- provenance: dict = None
170
-
171
- @property
172
- def info_fields(self):
173
- fields = []
174
- for field in self.fields:
175
- if field.category == "INFO":
176
- fields.append(field)
177
- return fields
178
-
179
- @property
180
- def format_fields(self):
181
- fields = []
182
- for field in self.fields:
183
- if field.category == "FORMAT":
184
- fields.append(field)
185
- return fields
186
-
187
- @property
188
- def num_contigs(self):
189
- return len(self.contig_names)
190
-
191
- @property
192
- def num_filters(self):
193
- return len(self.filters)
194
-
195
- @property
196
- def num_records(self):
197
- return sum(self.contig_record_counts.values())
198
-
199
- @staticmethod
200
- def fromdict(d):
201
- if d["format_version"] != ICF_METADATA_FORMAT_VERSION:
202
- raise ValueError(
203
- "Intermediate columnar metadata format version mismatch: "
204
- f"{d['format_version']} != {ICF_METADATA_FORMAT_VERSION}"
205
- )
206
- fields = [VcfField.fromdict(fd) for fd in d["fields"]]
207
- partitions = [VcfPartition(**pd) for pd in d["partitions"]]
208
- for p in partitions:
209
- p.region = vcf_utils.Region(**p.region)
210
- d = d.copy()
211
- d["fields"] = fields
212
- d["partitions"] = partitions
213
- return IcfMetadata(**d)
214
-
215
- def asdict(self):
216
- return dataclasses.asdict(self)
217
-
218
-
219
- def fixed_vcf_field_definitions():
220
- def make_field_def(name, vcf_type, vcf_number):
221
- return VcfField(
222
- category="fixed",
223
- name=name,
224
- vcf_type=vcf_type,
225
- vcf_number=vcf_number,
226
- description="",
227
- summary=VcfFieldSummary(),
228
- )
229
-
230
- fields = [
231
- make_field_def("CHROM", "String", "1"),
232
- make_field_def("POS", "Integer", "1"),
233
- make_field_def("QUAL", "Float", "1"),
234
- make_field_def("ID", "String", "."),
235
- make_field_def("FILTERS", "String", "."),
236
- make_field_def("REF", "String", "1"),
237
- make_field_def("ALT", "String", "."),
238
- ]
239
- return fields
240
-
241
-
242
- def scan_vcf(path, target_num_partitions):
243
- with vcf_utils.IndexedVcf(path) as indexed_vcf:
244
- vcf = indexed_vcf.vcf
245
- filters = [
246
- h["ID"]
247
- for h in vcf.header_iter()
248
- if h["HeaderType"] == "FILTER" and isinstance(h["ID"], str)
249
- ]
250
- # Ensure PASS is the first filter if present
251
- if "PASS" in filters:
252
- filters.remove("PASS")
253
- filters.insert(0, "PASS")
254
-
255
- fields = fixed_vcf_field_definitions()
256
- for h in vcf.header_iter():
257
- if h["HeaderType"] in ["INFO", "FORMAT"]:
258
- field = VcfField.from_header(h)
259
- if field.name == "GT":
260
- field.vcf_type = "Integer"
261
- field.vcf_number = "."
262
- fields.append(field)
263
-
264
- metadata = IcfMetadata(
265
- samples=vcf.samples,
266
- contig_names=vcf.seqnames,
267
- contig_record_counts=indexed_vcf.contig_record_counts(),
268
- filters=filters,
269
- fields=fields,
270
- partitions=[],
271
- )
272
- try:
273
- metadata.contig_lengths = vcf.seqlens
274
- except AttributeError:
275
- pass
276
-
277
- regions = indexed_vcf.partition_into_regions(num_parts=target_num_partitions)
278
- logger.info(
279
- f"Split {path} into {len(regions)} regions (target={target_num_partitions})"
280
- )
281
- for region in regions:
282
- metadata.partitions.append(
283
- VcfPartition(
284
- # TODO should this be fully resolving the path? Otherwise it's all
285
- # relative to the original WD
286
- vcf_path=str(path),
287
- region=region,
288
- )
289
- )
290
- core.update_progress(1)
291
- return metadata, vcf.raw_header
292
-
293
-
294
- def check_overlap(partitions):
295
- for i in range(1, len(partitions)):
296
- prev_partition = partitions[i - 1]
297
- current_partition = partitions[i]
298
- if (
299
- prev_partition.region.contig == current_partition.region.contig
300
- and prev_partition.region.end > current_partition.region.start
301
- ):
302
- raise ValueError(
303
- f"Multiple VCFs have the region "
304
- f"{prev_partition.region.contig}:{prev_partition.region.start}-"
305
- f"{current_partition.region.end}"
306
- )
307
-
308
-
309
- def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
310
- logger.info(
311
- f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}"
312
- f" partitions."
313
- )
314
- # An easy mistake to make is to pass the same file twice. Check this early on.
315
- for path, count in collections.Counter(paths).items():
316
- if not path.exists(): # NEEDS TEST
317
- raise FileNotFoundError(path)
318
- if count > 1:
319
- raise ValueError(f"Duplicate path provided: {path}")
320
-
321
- progress_config = core.ProgressConfig(
322
- total=len(paths),
323
- units="files",
324
- title="Scan",
325
- show=show_progress,
326
- )
327
- with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
328
- for path in paths:
329
- pwm.submit(scan_vcf, path, max(1, target_num_partitions // len(paths)))
330
- results = list(pwm.results_as_completed())
331
-
332
- # Sort to make the ordering deterministic
333
- results.sort(key=lambda t: t[0].partitions[0].vcf_path)
334
- # We just take the first header, assuming the others
335
- # are compatible.
336
- all_partitions = []
337
- contig_record_counts = collections.Counter()
338
- for metadata, _ in results:
339
- all_partitions.extend(metadata.partitions)
340
- metadata.partitions.clear()
341
- contig_record_counts += metadata.contig_record_counts
342
- metadata.contig_record_counts.clear()
343
-
344
- icf_metadata, header = results[0]
345
- for metadata, _ in results[1:]:
346
- if metadata != icf_metadata:
347
- raise ValueError("Incompatible VCF chunks")
348
-
349
- icf_metadata.contig_record_counts = dict(contig_record_counts)
350
-
351
- # Sort by contig (in the order they appear in the header) first,
352
- # then by start coordinate
353
- contig_index_map = {contig: j for j, contig in enumerate(metadata.contig_names)}
354
- all_partitions.sort(
355
- key=lambda x: (contig_index_map[x.region.contig], x.region.start)
356
- )
357
- check_overlap(all_partitions)
358
- icf_metadata.partitions = all_partitions
359
- logger.info(f"Scan complete, resulting in {len(all_partitions)} partitions.")
360
- return icf_metadata, header
361
-
362
-
363
- def sanitise_value_bool(buff, j, value):
364
- x = True
365
- if value is None:
366
- x = False
367
- buff[j] = x
368
-
369
-
370
- def sanitise_value_float_scalar(buff, j, value):
371
- x = value
372
- if value is None:
373
- x = FLOAT32_MISSING
374
- buff[j] = x
375
-
376
-
377
- def sanitise_value_int_scalar(buff, j, value):
378
- x = value
379
- if value is None:
380
- # print("MISSING", INT_MISSING, INT_FILL)
381
- x = [INT_MISSING]
382
- else:
383
- x = sanitise_int_array([value], ndmin=1, dtype=np.int32)
384
- buff[j] = x[0]
385
-
386
-
387
- def sanitise_value_string_scalar(buff, j, value):
388
- if value is None:
389
- buff[j] = "."
390
- else:
391
- buff[j] = value[0]
392
-
393
-
394
- def sanitise_value_string_1d(buff, j, value):
395
- if value is None:
396
- buff[j] = "."
397
- else:
398
- # value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False)
399
- # FIXME failure isn't coming from here, it seems to be from an
400
- # incorrectly detected dimension in the zarr array
401
- # The dimesions look all wrong, and the dtype should be Object
402
- # not str
403
- value = drop_empty_second_dim(value)
404
- buff[j] = ""
405
- buff[j, : value.shape[0]] = value
406
-
407
-
408
- def sanitise_value_string_2d(buff, j, value):
409
- if value is None:
410
- buff[j] = "."
411
- else:
412
- # print(buff.shape, value.dtype, value)
413
- # assert value.ndim == 2
414
- buff[j] = ""
415
- if value.ndim == 2:
416
- buff[j, :, : value.shape[1]] = value
417
- else:
418
- # TODO check if this is still necessary
419
- for k, val in enumerate(value):
420
- buff[j, k, : len(val)] = val
421
-
422
-
423
- def drop_empty_second_dim(value):
424
- assert len(value.shape) == 1 or value.shape[1] == 1
425
- if len(value.shape) == 2 and value.shape[1] == 1:
426
- value = value[..., 0]
427
- return value
428
-
429
-
430
- def sanitise_value_float_1d(buff, j, value):
431
- if value is None:
432
- buff[j] = FLOAT32_MISSING
433
- else:
434
- value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False)
435
- # numpy will map None values to Nan, but we need a
436
- # specific NaN
437
- value[np.isnan(value)] = FLOAT32_MISSING
438
- value = drop_empty_second_dim(value)
439
- buff[j] = FLOAT32_FILL
440
- buff[j, : value.shape[0]] = value
441
-
442
-
443
- def sanitise_value_float_2d(buff, j, value):
444
- if value is None:
445
- buff[j] = FLOAT32_MISSING
446
- else:
447
- # print("value = ", value)
448
- value = np.array(value, ndmin=2, dtype=buff.dtype, copy=False)
449
- buff[j] = FLOAT32_FILL
450
- buff[j, :, : value.shape[1]] = value
451
-
452
-
453
- def sanitise_int_array(value, ndmin, dtype):
454
- if isinstance(value, tuple):
455
- value = [VCF_INT_MISSING if x is None else x for x in value] # NEEDS TEST
456
- value = np.array(value, ndmin=ndmin, copy=False)
457
- value[value == VCF_INT_MISSING] = -1
458
- value[value == VCF_INT_FILL] = -2
459
- # TODO watch out for clipping here!
460
- return value.astype(dtype)
461
-
462
-
463
- def sanitise_value_int_1d(buff, j, value):
464
- if value is None:
465
- buff[j] = -1
466
- else:
467
- value = sanitise_int_array(value, 1, buff.dtype)
468
- value = drop_empty_second_dim(value)
469
- buff[j] = -2
470
- buff[j, : value.shape[0]] = value
471
-
472
-
473
- def sanitise_value_int_2d(buff, j, value):
474
- if value is None:
475
- buff[j] = -1
476
- else:
477
- value = sanitise_int_array(value, 2, buff.dtype)
478
- buff[j] = -2
479
- buff[j, :, : value.shape[1]] = value
480
-
481
-
482
- MIN_INT_VALUE = np.iinfo(np.int32).min + 2
483
- VCF_INT_MISSING = np.iinfo(np.int32).min
484
- VCF_INT_FILL = np.iinfo(np.int32).min + 1
485
-
486
- missing_value_map = {
487
- "Integer": -1,
488
- "Float": FLOAT32_MISSING,
489
- "String": ".",
490
- "Character": ".",
491
- "Flag": False,
492
- }
493
-
494
-
495
- class VcfValueTransformer:
496
- """
497
- Transform VCF values into the stored intermediate format used
498
- in the IntermediateColumnarFormat, and update field summaries.
499
- """
500
-
501
- def __init__(self, field, num_samples):
502
- self.field = field
503
- self.num_samples = num_samples
504
- self.dimension = 1
505
- if field.category == "FORMAT":
506
- self.dimension = 2
507
- self.missing = missing_value_map[field.vcf_type]
508
-
509
- @staticmethod
510
- def factory(field, num_samples):
511
- if field.vcf_type in ("Integer", "Flag"):
512
- return IntegerValueTransformer(field, num_samples)
513
- if field.vcf_type == "Float":
514
- return FloatValueTransformer(field, num_samples)
515
- if field.name in ["REF", "FILTERS", "ALT", "ID", "CHROM"]:
516
- return SplitStringValueTransformer(field, num_samples)
517
- return StringValueTransformer(field, num_samples)
518
-
519
- def transform(self, vcf_value):
520
- if isinstance(vcf_value, tuple):
521
- vcf_value = [self.missing if v is None else v for v in vcf_value]
522
- value = np.array(vcf_value, ndmin=self.dimension, copy=False)
523
- return value
524
-
525
- def transform_and_update_bounds(self, vcf_value):
526
- if vcf_value is None:
527
- return None
528
- value = self.transform(vcf_value)
529
- self.update_bounds(value)
530
- # print(self.field.full_name, "T", vcf_value, "->", value)
531
- return value
532
-
533
-
534
- MIN_INT_VALUE = np.iinfo(np.int32).min + 2
535
- VCF_INT_MISSING = np.iinfo(np.int32).min
536
- VCF_INT_FILL = np.iinfo(np.int32).min + 1
537
-
538
-
539
- class IntegerValueTransformer(VcfValueTransformer):
540
- def update_bounds(self, value):
541
- summary = self.field.summary
542
- # Mask out missing and fill values
543
- # print(value)
544
- a = value[value >= MIN_INT_VALUE]
545
- if a.size > 0:
546
- summary.max_value = int(max(summary.max_value, np.max(a)))
547
- summary.min_value = int(min(summary.min_value, np.min(a)))
548
- number = value.shape[-1]
549
- summary.max_number = max(summary.max_number, number)
550
-
551
-
552
- class FloatValueTransformer(VcfValueTransformer):
553
- def update_bounds(self, value):
554
- summary = self.field.summary
555
- summary.max_value = float(max(summary.max_value, np.max(value)))
556
- summary.min_value = float(min(summary.min_value, np.min(value)))
557
- number = value.shape[-1]
558
- summary.max_number = max(summary.max_number, number)
559
-
560
-
561
- class StringValueTransformer(VcfValueTransformer):
562
- def update_bounds(self, value):
563
- summary = self.field.summary
564
- number = value.shape[-1]
565
- # TODO would be nice to report string lengths, but not
566
- # really necessary.
567
- summary.max_number = max(summary.max_number, number)
568
-
569
- def transform(self, vcf_value):
570
- # print("transform", vcf_value)
571
- if self.dimension == 1:
572
- value = np.array(list(vcf_value.split(",")))
573
- else:
574
- # TODO can we make this faster??
575
- value = np.array([v.split(",") for v in vcf_value], dtype="O")
576
- # print("HERE", vcf_value, value)
577
- # for v in vcf_value:
578
- # print("\t", type(v), len(v), v.split(","))
579
- # print("S: ", self.dimension, ":", value.shape, value)
580
- return value
581
-
582
-
583
- class SplitStringValueTransformer(StringValueTransformer):
584
- def transform(self, vcf_value):
585
- if vcf_value is None:
586
- return self.missing_value # NEEDS TEST
587
- assert self.dimension == 1
588
- return np.array(vcf_value, ndmin=1, dtype="str")
589
-
590
-
591
- def get_vcf_field_path(base_path, vcf_field):
592
- if vcf_field.category == "fixed":
593
- return base_path / vcf_field.name
594
- return base_path / vcf_field.category / vcf_field.name
595
-
596
-
597
- class IntermediateColumnarFormatField:
598
- def __init__(self, icf, vcf_field):
599
- self.vcf_field = vcf_field
600
- self.path = get_vcf_field_path(icf.path, vcf_field)
601
- self.compressor = icf.compressor
602
- self.num_partitions = icf.num_partitions
603
- self.num_records = icf.num_records
604
- self.partition_record_index = icf.partition_record_index
605
- # A map of partition id to the cumulative number of records
606
- # in chunks within that partition
607
- self._chunk_record_index = {}
608
-
609
- @property
610
- def name(self):
611
- return self.vcf_field.full_name
612
-
613
- def partition_path(self, partition_id):
614
- return self.path / f"p{partition_id}"
615
-
616
- def __repr__(self):
617
- partition_chunks = [self.num_chunks(j) for j in range(self.num_partitions)]
618
- return (
619
- f"IntermediateColumnarFormatField(name={self.name}, "
620
- f"partition_chunks={partition_chunks}, "
621
- f"path={self.path})"
622
- )
623
-
624
- def num_chunks(self, partition_id):
625
- return len(self.chunk_record_index(partition_id)) - 1
626
-
627
- def chunk_record_index(self, partition_id):
628
- if partition_id not in self._chunk_record_index:
629
- index_path = self.partition_path(partition_id) / "chunk_index"
630
- with open(index_path, "rb") as f:
631
- a = pickle.load(f)
632
- assert len(a) > 1
633
- assert a[0] == 0
634
- self._chunk_record_index[partition_id] = a
635
- return self._chunk_record_index[partition_id]
636
-
637
- def read_chunk(self, path):
638
- with open(path, "rb") as f:
639
- pkl = self.compressor.decode(f.read())
640
- return pickle.loads(pkl)
641
-
642
- def chunk_num_records(self, partition_id):
643
- return np.diff(self.chunk_record_index(partition_id))
644
-
645
- def chunks(self, partition_id, start_chunk=0):
646
- partition_path = self.partition_path(partition_id)
647
- chunk_cumulative_records = self.chunk_record_index(partition_id)
648
- chunk_num_records = np.diff(chunk_cumulative_records)
649
- for count, cumulative in zip(
650
- chunk_num_records[start_chunk:], chunk_cumulative_records[start_chunk + 1 :]
651
- ):
652
- path = partition_path / f"{cumulative}"
653
- chunk = self.read_chunk(path)
654
- if len(chunk) != count:
655
- raise ValueError(f"Corruption detected in chunk: {path}")
656
- yield chunk
657
-
658
- def iter_values(self, start=None, stop=None):
659
- start = 0 if start is None else start
660
- stop = self.num_records if stop is None else stop
661
- start_partition = (
662
- np.searchsorted(self.partition_record_index, start, side="right") - 1
663
- )
664
- offset = self.partition_record_index[start_partition]
665
- assert offset <= start
666
- chunk_offset = start - offset
667
-
668
- chunk_record_index = self.chunk_record_index(start_partition)
669
- start_chunk = (
670
- np.searchsorted(chunk_record_index, chunk_offset, side="right") - 1
671
- )
672
- record_id = offset + chunk_record_index[start_chunk]
673
- assert record_id <= start
674
- logger.debug(
675
- f"Read {self.vcf_field.full_name} slice [{start}:{stop}]:"
676
- f"p_start={start_partition}, c_start={start_chunk}, r_start={record_id}"
677
- )
678
- for chunk in self.chunks(start_partition, start_chunk):
679
- for record in chunk:
680
- if record_id == stop:
681
- return
682
- if record_id >= start:
683
- yield record
684
- record_id += 1
685
- assert record_id > start
686
- for partition_id in range(start_partition + 1, self.num_partitions):
687
- for chunk in self.chunks(partition_id):
688
- for record in chunk:
689
- if record_id == stop:
690
- return
691
- yield record
692
- record_id += 1
693
-
694
- # Note: this involves some computation so should arguably be a method,
695
- # but making a property for consistency with xarray etc
696
- @property
697
- def values(self):
698
- ret = [None] * self.num_records
699
- j = 0
700
- for partition_id in range(self.num_partitions):
701
- for chunk in self.chunks(partition_id):
702
- for record in chunk:
703
- ret[j] = record
704
- j += 1
705
- assert j == self.num_records
706
- return ret
707
-
708
- def sanitiser_factory(self, shape):
709
- """
710
- Return a function that sanitised values from this column
711
- and writes into a buffer of the specified shape.
712
- """
713
- assert len(shape) <= 3
714
- if self.vcf_field.vcf_type == "Flag":
715
- assert len(shape) == 1
716
- return sanitise_value_bool
717
- elif self.vcf_field.vcf_type == "Float":
718
- if len(shape) == 1:
719
- return sanitise_value_float_scalar
720
- elif len(shape) == 2:
721
- return sanitise_value_float_1d
722
- else:
723
- return sanitise_value_float_2d
724
- elif self.vcf_field.vcf_type == "Integer":
725
- if len(shape) == 1:
726
- return sanitise_value_int_scalar
727
- elif len(shape) == 2:
728
- return sanitise_value_int_1d
729
- else:
730
- return sanitise_value_int_2d
731
- else:
732
- assert self.vcf_field.vcf_type in ("String", "Character")
733
- if len(shape) == 1:
734
- return sanitise_value_string_scalar
735
- elif len(shape) == 2:
736
- return sanitise_value_string_1d
737
- else:
738
- return sanitise_value_string_2d
739
-
740
-
741
- @dataclasses.dataclass
742
- class IcfFieldWriter:
743
- vcf_field: VcfField
744
- path: pathlib.Path
745
- transformer: VcfValueTransformer
746
- compressor: Any
747
- max_buffered_bytes: int
748
- buff: List[Any] = dataclasses.field(default_factory=list)
749
- buffered_bytes: int = 0
750
- chunk_index: List[int] = dataclasses.field(default_factory=lambda: [0])
751
- num_records: int = 0
752
-
753
- def append(self, val):
754
- val = self.transformer.transform_and_update_bounds(val)
755
- assert val is None or isinstance(val, np.ndarray)
756
- self.buff.append(val)
757
- val_bytes = sys.getsizeof(val)
758
- self.buffered_bytes += val_bytes
759
- self.num_records += 1
760
- if self.buffered_bytes >= self.max_buffered_bytes:
761
- logger.debug(
762
- f"Flush {self.path} buffered={self.buffered_bytes} "
763
- f"max={self.max_buffered_bytes}"
764
- )
765
- self.write_chunk()
766
- self.buff.clear()
767
- self.buffered_bytes = 0
768
-
769
- def write_chunk(self):
770
- # Update index
771
- self.chunk_index.append(self.num_records)
772
- path = self.path / f"{self.num_records}"
773
- logger.debug(f"Start write: {path}")
774
- pkl = pickle.dumps(self.buff)
775
- compressed = self.compressor.encode(pkl)
776
- with open(path, "wb") as f:
777
- f.write(compressed)
778
-
779
- # Update the summary
780
- self.vcf_field.summary.num_chunks += 1
781
- self.vcf_field.summary.compressed_size += len(compressed)
782
- self.vcf_field.summary.uncompressed_size += self.buffered_bytes
783
- logger.debug(f"Finish write: {path}")
784
-
785
- def flush(self):
786
- logger.debug(
787
- f"Flush {self.path} records={len(self.buff)} buffered={self.buffered_bytes}"
788
- )
789
- if len(self.buff) > 0:
790
- self.write_chunk()
791
- with open(self.path / "chunk_index", "wb") as f:
792
- a = np.array(self.chunk_index, dtype=int)
793
- pickle.dump(a, f)
794
-
795
-
796
- class IcfPartitionWriter(contextlib.AbstractContextManager):
797
- """
798
- Writes the data for a IntermediateColumnarFormat partition.
799
- """
800
-
801
- def __init__(
802
- self,
803
- icf_metadata,
804
- out_path,
805
- partition_index,
806
- ):
807
- self.partition_index = partition_index
808
- # chunk_size is in megabytes
809
- max_buffered_bytes = icf_metadata.column_chunk_size * 2**20
810
- assert max_buffered_bytes > 0
811
- compressor = numcodecs.get_codec(icf_metadata.compressor)
812
-
813
- self.field_writers = {}
814
- num_samples = len(icf_metadata.samples)
815
- for vcf_field in icf_metadata.fields:
816
- field_path = get_vcf_field_path(out_path, vcf_field)
817
- field_partition_path = field_path / f"p{partition_index}"
818
- # Should be robust to running explode_partition twice.
819
- field_partition_path.mkdir(exist_ok=True)
820
- transformer = VcfValueTransformer.factory(vcf_field, num_samples)
821
- self.field_writers[vcf_field.full_name] = IcfFieldWriter(
822
- vcf_field,
823
- field_partition_path,
824
- transformer,
825
- compressor,
826
- max_buffered_bytes,
827
- )
828
-
829
- @property
830
- def field_summaries(self):
831
- return {
832
- name: field.vcf_field.summary for name, field in self.field_writers.items()
833
- }
834
-
835
- def append(self, name, value):
836
- self.field_writers[name].append(value)
837
-
838
- def __exit__(self, exc_type, exc_val, exc_tb):
839
- if exc_type is None:
840
- for field in self.field_writers.values():
841
- field.flush()
842
- return False
843
-
844
-
845
- class IntermediateColumnarFormat(collections.abc.Mapping):
846
- def __init__(self, path):
847
- self.path = pathlib.Path(path)
848
- # TODO raise a more informative error here telling people this
849
- # directory is either a WIP or the wrong format.
850
- with open(self.path / "metadata.json") as f:
851
- self.metadata = IcfMetadata.fromdict(json.load(f))
852
- with open(self.path / "header.txt") as f:
853
- self.vcf_header = f.read()
854
-
855
- self.compressor = numcodecs.get_codec(self.metadata.compressor)
856
- self.columns = {}
857
- partition_num_records = [
858
- partition.num_records for partition in self.metadata.partitions
859
- ]
860
- # Allow us to find which partition a given record is in
861
- self.partition_record_index = np.cumsum([0, *partition_num_records])
862
- for field in self.metadata.fields:
863
- self.columns[field.full_name] = IntermediateColumnarFormatField(self, field)
864
- logger.info(
865
- f"Loaded IntermediateColumnarFormat(partitions={self.num_partitions}, "
866
- f"records={self.num_records}, columns={self.num_columns})"
867
- )
868
-
869
- def __repr__(self):
870
- return (
871
- f"IntermediateColumnarFormat(fields={len(self)}, "
872
- f"partitions={self.num_partitions}, "
873
- f"records={self.num_records}, path={self.path})"
874
- )
875
-
876
- def __getitem__(self, key):
877
- return self.columns[key]
878
-
879
- def __iter__(self):
880
- return iter(self.columns)
881
-
882
- def __len__(self):
883
- return len(self.columns)
884
-
885
- def summary_table(self):
886
- data = []
887
- for name, col in self.columns.items():
888
- summary = col.vcf_field.summary
889
- d = {
890
- "name": name,
891
- "type": col.vcf_field.vcf_type,
892
- "chunks": summary.num_chunks,
893
- "size": display_size(summary.uncompressed_size),
894
- "compressed": display_size(summary.compressed_size),
895
- "max_n": summary.max_number,
896
- "min_val": display_number(summary.min_value),
897
- "max_val": display_number(summary.max_value),
898
- }
899
-
900
- data.append(d)
901
- return data
902
-
903
- @functools.cached_property
904
- def num_records(self):
905
- return sum(self.metadata.contig_record_counts.values())
906
-
907
- @property
908
- def num_partitions(self):
909
- return len(self.metadata.partitions)
910
-
911
- @property
912
- def num_samples(self):
913
- return len(self.metadata.samples)
914
-
915
- @property
916
- def num_columns(self):
917
- return len(self.columns)
918
-
919
-
920
- class IntermediateColumnarFormatWriter:
921
- def __init__(self, path):
922
- self.path = pathlib.Path(path)
923
- self.wip_path = self.path / "wip"
924
- self.metadata = None
925
-
926
- @property
927
- def num_partitions(self):
928
- return len(self.metadata.partitions)
929
-
930
- def init(
931
- self,
932
- vcfs,
933
- *,
934
- column_chunk_size=16,
935
- worker_processes=1,
936
- target_num_partitions=None,
937
- show_progress=False,
938
- compressor=None,
939
- ):
940
- if self.path.exists():
941
- raise ValueError("ICF path already exists")
942
- if compressor is None:
943
- compressor = ICF_DEFAULT_COMPRESSOR
944
- vcfs = [pathlib.Path(vcf) for vcf in vcfs]
945
- target_num_partitions = max(target_num_partitions, len(vcfs))
946
-
947
- # TODO move scan_vcfs into this class
948
- icf_metadata, header = scan_vcfs(
949
- vcfs,
950
- worker_processes=worker_processes,
951
- show_progress=show_progress,
952
- target_num_partitions=target_num_partitions,
953
- )
954
- self.metadata = icf_metadata
955
- self.metadata.format_version = ICF_METADATA_FORMAT_VERSION
956
- self.metadata.compressor = compressor.get_config()
957
- self.metadata.column_chunk_size = column_chunk_size
958
- # Bare minimum here for provenance - would be nice to include versions of key
959
- # dependencies as well.
960
- self.metadata.provenance = {"source": f"bio2zarr-{provenance.__version__}"}
961
-
962
- self.mkdirs()
963
-
964
- # Note: this is needed for the current version of the vcfzarr spec, but it's
965
- # probably going to be dropped.
966
- # https://github.com/pystatgen/vcf-zarr-spec/issues/15
967
- # May be useful to keep lying around still though?
968
- logger.info("Writing VCF header")
969
- with open(self.path / "header.txt", "w") as f:
970
- f.write(header)
971
-
972
- logger.info("Writing WIP metadata")
973
- with open(self.wip_path / "metadata.json", "w") as f:
974
- json.dump(self.metadata.asdict(), f, indent=4)
975
- return self.num_partitions
976
-
977
- def mkdirs(self):
978
- num_dirs = len(self.metadata.fields)
979
- logger.info(f"Creating {num_dirs} field directories")
980
- self.path.mkdir()
981
- self.wip_path.mkdir()
982
- for field in self.metadata.fields:
983
- col_path = get_vcf_field_path(self.path, field)
984
- col_path.mkdir(parents=True)
985
-
986
- def load_partition_summaries(self):
987
- summaries = []
988
- not_found = []
989
- for j in range(self.num_partitions):
990
- try:
991
- with open(self.wip_path / f"p{j}_summary.json") as f:
992
- summary = json.load(f)
993
- for k, v in summary["field_summaries"].items():
994
- summary["field_summaries"][k] = VcfFieldSummary.fromdict(v)
995
- summaries.append(summary)
996
- except FileNotFoundError:
997
- not_found.append(j)
998
- if len(not_found) > 0:
999
- raise FileNotFoundError(
1000
- f"Partition metadata not found for {len(not_found)}"
1001
- f" partitions: {not_found}"
1002
- )
1003
- return summaries
1004
-
1005
- def load_metadata(self):
1006
- if self.metadata is None:
1007
- with open(self.wip_path / "metadata.json") as f:
1008
- self.metadata = IcfMetadata.fromdict(json.load(f))
1009
-
1010
- def process_partition(self, partition_index):
1011
- self.load_metadata()
1012
- summary_path = self.wip_path / f"p{partition_index}_summary.json"
1013
- # If someone is rewriting a summary path (for whatever reason), make sure it
1014
- # doesn't look like it's already been completed.
1015
- # NOTE to do this properly we probably need to take a lock on this file - but
1016
- # this simple approach will catch the vast majority of problems.
1017
- if summary_path.exists():
1018
- summary_path.unlink()
1019
-
1020
- partition = self.metadata.partitions[partition_index]
1021
- logger.info(
1022
- f"Start p{partition_index} {partition.vcf_path}__{partition.region}"
1023
- )
1024
- info_fields = self.metadata.info_fields
1025
- format_fields = []
1026
- has_gt = False
1027
- for field in self.metadata.format_fields:
1028
- if field.name == "GT":
1029
- has_gt = True
1030
- else:
1031
- format_fields.append(field)
1032
-
1033
- with IcfPartitionWriter(
1034
- self.metadata,
1035
- self.path,
1036
- partition_index,
1037
- ) as tcw:
1038
- with vcf_utils.IndexedVcf(partition.vcf_path) as ivcf:
1039
- num_records = 0
1040
- for variant in ivcf.variants(partition.region):
1041
- num_records += 1
1042
- tcw.append("CHROM", variant.CHROM)
1043
- tcw.append("POS", variant.POS)
1044
- tcw.append("QUAL", variant.QUAL)
1045
- tcw.append("ID", variant.ID)
1046
- tcw.append("FILTERS", variant.FILTERS)
1047
- tcw.append("REF", variant.REF)
1048
- tcw.append("ALT", variant.ALT)
1049
- for field in info_fields:
1050
- tcw.append(field.full_name, variant.INFO.get(field.name, None))
1051
- if has_gt:
1052
- tcw.append("FORMAT/GT", variant.genotype.array())
1053
- for field in format_fields:
1054
- val = variant.format(field.name)
1055
- tcw.append(field.full_name, val)
1056
- # Note: an issue with updating the progress per variant here like
1057
- # this is that we get a significant pause at the end of the counter
1058
- # while all the "small" fields get flushed. Possibly not much to be
1059
- # done about it.
1060
- core.update_progress(1)
1061
- logger.info(
1062
- f"Finished reading VCF for partition {partition_index}, "
1063
- f"flushing buffers"
1064
- )
1065
-
1066
- partition_metadata = {
1067
- "num_records": num_records,
1068
- "field_summaries": {k: v.asdict() for k, v in tcw.field_summaries.items()},
1069
- }
1070
- with open(summary_path, "w") as f:
1071
- json.dump(partition_metadata, f, indent=4)
1072
- logger.info(
1073
- f"Finish p{partition_index} {partition.vcf_path}__{partition.region}="
1074
- f"{num_records} records"
1075
- )
1076
-
1077
- def process_partition_slice(
1078
- self,
1079
- start,
1080
- stop,
1081
- *,
1082
- worker_processes=1,
1083
- show_progress=False,
1084
- ):
1085
- self.load_metadata()
1086
- if start == 0 and stop == self.num_partitions:
1087
- num_records = self.metadata.num_records
1088
- else:
1089
- # We only know the number of records if all partitions are done at once,
1090
- # and we signal this to tqdm by passing None as the total.
1091
- num_records = None
1092
- num_columns = len(self.metadata.fields)
1093
- num_samples = len(self.metadata.samples)
1094
- logger.info(
1095
- f"Exploding columns={num_columns} samples={num_samples}; "
1096
- f"partitions={stop - start} "
1097
- f"variants={'unknown' if num_records is None else num_records}"
1098
- )
1099
- progress_config = core.ProgressConfig(
1100
- total=num_records,
1101
- units="vars",
1102
- title="Explode",
1103
- show=show_progress,
1104
- )
1105
- with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
1106
- for j in range(start, stop):
1107
- pwm.submit(self.process_partition, j)
1108
-
1109
- def explode(self, *, worker_processes=1, show_progress=False):
1110
- self.load_metadata()
1111
- return self.process_partition_slice(
1112
- 0,
1113
- self.num_partitions,
1114
- worker_processes=worker_processes,
1115
- show_progress=show_progress,
1116
- )
1117
-
1118
- def explode_partition(self, partition, *, show_progress=False, worker_processes=1):
1119
- self.load_metadata()
1120
- if partition < 0 or partition >= self.num_partitions:
1121
- raise ValueError(
1122
- "Partition index must be in the range 0 <= index < num_partitions"
1123
- )
1124
- return self.process_partition_slice(
1125
- partition,
1126
- partition + 1,
1127
- worker_processes=worker_processes,
1128
- show_progress=show_progress,
1129
- )
1130
-
1131
- def finalise(self):
1132
- self.load_metadata()
1133
- partition_summaries = self.load_partition_summaries()
1134
- total_records = 0
1135
- for index, summary in enumerate(partition_summaries):
1136
- partition_records = summary["num_records"]
1137
- self.metadata.partitions[index].num_records = partition_records
1138
- total_records += partition_records
1139
- assert total_records == self.metadata.num_records
1140
-
1141
- for field in self.metadata.fields:
1142
- for summary in partition_summaries:
1143
- field.summary.update(summary["field_summaries"][field.full_name])
1144
-
1145
- logger.info("Finalising metadata")
1146
- with open(self.path / "metadata.json", "w") as f:
1147
- json.dump(self.metadata.asdict(), f, indent=4)
1148
-
1149
- logger.debug("Removing WIP directory")
1150
- shutil.rmtree(self.wip_path)
1151
-
1152
-
1153
- def explode(
1154
- icf_path,
1155
- vcfs,
1156
- *,
1157
- column_chunk_size=16,
1158
- worker_processes=1,
1159
- show_progress=False,
1160
- compressor=None,
1161
- ):
1162
- writer = IntermediateColumnarFormatWriter(icf_path)
1163
- writer.init(
1164
- vcfs,
1165
- # Heuristic to get reasonable worker utilisation with lumpy partition sizing
1166
- target_num_partitions=max(1, worker_processes * 4),
1167
- worker_processes=worker_processes,
1168
- show_progress=show_progress,
1169
- column_chunk_size=column_chunk_size,
1170
- compressor=compressor,
1171
- )
1172
- writer.explode(worker_processes=worker_processes, show_progress=show_progress)
1173
- writer.finalise()
1174
- return IntermediateColumnarFormat(icf_path)
1175
-
1176
-
1177
- def explode_init(
1178
- icf_path,
1179
- vcfs,
1180
- *,
1181
- column_chunk_size=16,
1182
- target_num_partitions=1,
1183
- worker_processes=1,
1184
- show_progress=False,
1185
- compressor=None,
1186
- ):
1187
- writer = IntermediateColumnarFormatWriter(icf_path)
1188
- return writer.init(
1189
- vcfs,
1190
- target_num_partitions=target_num_partitions,
1191
- worker_processes=worker_processes,
1192
- show_progress=show_progress,
1193
- column_chunk_size=column_chunk_size,
1194
- compressor=compressor,
1195
- )
1196
-
1197
-
1198
- # NOTE only including worker_processes here so we can use the 0 option to get the
1199
- # work done syncronously and so we can get test coverage on it. Should find a
1200
- # better way to do this.
1201
- def explode_partition(icf_path, partition, *, show_progress=False, worker_processes=1):
1202
- writer = IntermediateColumnarFormatWriter(icf_path)
1203
- writer.explode_partition(
1204
- partition, show_progress=show_progress, worker_processes=worker_processes
1205
- )
1206
-
1207
-
1208
- def explode_finalise(icf_path):
1209
- writer = IntermediateColumnarFormatWriter(icf_path)
1210
- writer.finalise()
1211
-
1212
-
1213
- def inspect(path):
1214
- path = pathlib.Path(path)
1215
- # TODO add support for the Zarr format also
1216
- if (path / "metadata.json").exists():
1217
- obj = IntermediateColumnarFormat(path)
1218
- elif (path / ".zmetadata").exists():
1219
- obj = VcfZarr(path)
1220
- else:
1221
- raise ValueError("Format not recognised") # NEEDS TEST
1222
- return obj.summary_table()
1223
-
1224
-
1225
- DEFAULT_ZARR_COMPRESSOR = numcodecs.Blosc(cname="zstd", clevel=7)
1226
-
1227
-
1228
- @dataclasses.dataclass
1229
- class ZarrColumnSpec:
1230
- name: str
1231
- dtype: str
1232
- shape: tuple
1233
- chunks: tuple
1234
- dimensions: tuple
1235
- description: str
1236
- vcf_field: str
1237
- compressor: dict
1238
- filters: list
1239
-
1240
- def __post_init__(self):
1241
- # Ensure these are tuples for ease of comparison and consistency
1242
- self.shape = tuple(self.shape)
1243
- self.chunks = tuple(self.chunks)
1244
- self.dimensions = tuple(self.dimensions)
1245
-
1246
- @staticmethod
1247
- def new(**kwargs):
1248
- spec = ZarrColumnSpec(
1249
- **kwargs, compressor=DEFAULT_ZARR_COMPRESSOR.get_config(), filters=[]
1250
- )
1251
- spec._choose_compressor_settings()
1252
- return spec
1253
-
1254
- @staticmethod
1255
- def from_field(
1256
- vcf_field,
1257
- *,
1258
- num_variants,
1259
- num_samples,
1260
- variants_chunk_size,
1261
- samples_chunk_size,
1262
- variable_name=None,
1263
- ):
1264
- shape = [num_variants]
1265
- prefix = "variant_"
1266
- dimensions = ["variants"]
1267
- chunks = [variants_chunk_size]
1268
- if vcf_field.category == "FORMAT":
1269
- prefix = "call_"
1270
- shape.append(num_samples)
1271
- chunks.append(samples_chunk_size)
1272
- dimensions.append("samples")
1273
- if variable_name is None:
1274
- variable_name = prefix + vcf_field.name
1275
- # TODO make an option to add in the empty extra dimension
1276
- if vcf_field.summary.max_number > 1:
1277
- shape.append(vcf_field.summary.max_number)
1278
- # TODO we should really be checking this to see if the named dimensions
1279
- # are actually correct.
1280
- if vcf_field.vcf_number == "R":
1281
- dimensions.append("alleles")
1282
- elif vcf_field.vcf_number == "A":
1283
- dimensions.append("alt_alleles")
1284
- elif vcf_field.vcf_number == "G":
1285
- dimensions.append("genotypes")
1286
- else:
1287
- dimensions.append(f"{vcf_field.category}_{vcf_field.name}_dim")
1288
- return ZarrColumnSpec.new(
1289
- vcf_field=vcf_field.full_name,
1290
- name=variable_name,
1291
- dtype=vcf_field.smallest_dtype(),
1292
- shape=shape,
1293
- chunks=chunks,
1294
- dimensions=dimensions,
1295
- description=vcf_field.description,
1296
- )
1297
-
1298
- def _choose_compressor_settings(self):
1299
- """
1300
- Choose compressor and filter settings based on the size and
1301
- type of the array, plus some hueristics from observed properties
1302
- of VCFs.
1303
-
1304
- See https://github.com/pystatgen/bio2zarr/discussions/74
1305
- """
1306
- # Default is to not shuffle, because autoshuffle isn't recognised
1307
- # by many Zarr implementations, and shuffling can lead to worse
1308
- # performance in some cases anyway. Turning on shuffle should be a
1309
- # deliberate choice.
1310
- shuffle = numcodecs.Blosc.NOSHUFFLE
1311
- if self.name == "call_genotype" and self.dtype == "i1":
1312
- # call_genotype gets BITSHUFFLE by default as it gets
1313
- # significantly better compression (at a cost of slower
1314
- # decoding)
1315
- shuffle = numcodecs.Blosc.BITSHUFFLE
1316
- elif self.dtype == "bool":
1317
- shuffle = numcodecs.Blosc.BITSHUFFLE
1318
-
1319
- self.compressor["shuffle"] = shuffle
1320
-
1321
- @property
1322
- def variant_chunk_nbytes(self):
1323
- """
1324
- Returns the nbytes for a single variant chunk of this array.
1325
- """
1326
- chunk_items = self.chunks[0]
1327
- for size in self.shape[1:]:
1328
- chunk_items *= size
1329
- dt = np.dtype(self.dtype)
1330
- return chunk_items * dt.itemsize
1331
-
1332
-
1333
- ZARR_SCHEMA_FORMAT_VERSION = "0.2"
1334
-
1335
-
1336
- @dataclasses.dataclass
1337
- class VcfZarrSchema:
1338
- format_version: str
1339
- samples_chunk_size: int
1340
- variants_chunk_size: int
1341
- dimensions: list
1342
- sample_id: list
1343
- contig_id: list
1344
- contig_length: list
1345
- filter_id: list
1346
- columns: dict
1347
-
1348
- def asdict(self):
1349
- return dataclasses.asdict(self)
1350
-
1351
- def asjson(self):
1352
- return json.dumps(self.asdict(), indent=4)
1353
-
1354
- @staticmethod
1355
- def fromdict(d):
1356
- if d["format_version"] != ZARR_SCHEMA_FORMAT_VERSION:
1357
- raise ValueError(
1358
- "Zarr schema format version mismatch: "
1359
- f"{d['format_version']} != {ZARR_SCHEMA_FORMAT_VERSION}"
1360
- )
1361
- ret = VcfZarrSchema(**d)
1362
- ret.columns = {
1363
- key: ZarrColumnSpec(**value) for key, value in d["columns"].items()
1364
- }
1365
- return ret
1366
-
1367
- @staticmethod
1368
- def fromjson(s):
1369
- return VcfZarrSchema.fromdict(json.loads(s))
1370
-
1371
- @staticmethod
1372
- def generate(icf, variants_chunk_size=None, samples_chunk_size=None):
1373
- m = icf.num_records
1374
- n = icf.num_samples
1375
- # FIXME
1376
- if samples_chunk_size is None:
1377
- samples_chunk_size = 1000
1378
- if variants_chunk_size is None:
1379
- variants_chunk_size = 10_000
1380
- logger.info(
1381
- f"Generating schema with chunks={variants_chunk_size, samples_chunk_size}"
1382
- )
1383
-
1384
- def spec_from_field(field, variable_name=None):
1385
- return ZarrColumnSpec.from_field(
1386
- field,
1387
- num_samples=n,
1388
- num_variants=m,
1389
- samples_chunk_size=samples_chunk_size,
1390
- variants_chunk_size=variants_chunk_size,
1391
- variable_name=variable_name,
1392
- )
1393
-
1394
- def fixed_field_spec(
1395
- name, dtype, vcf_field=None, shape=(m,), dimensions=("variants",)
1396
- ):
1397
- return ZarrColumnSpec.new(
1398
- vcf_field=vcf_field,
1399
- name=name,
1400
- dtype=dtype,
1401
- shape=shape,
1402
- description="",
1403
- dimensions=dimensions,
1404
- chunks=[variants_chunk_size],
1405
- )
1406
-
1407
- alt_col = icf.columns["ALT"]
1408
- max_alleles = alt_col.vcf_field.summary.max_number + 1
1409
-
1410
- colspecs = [
1411
- fixed_field_spec(
1412
- name="variant_contig",
1413
- dtype=core.min_int_dtype(0, icf.metadata.num_contigs),
1414
- ),
1415
- fixed_field_spec(
1416
- name="variant_filter",
1417
- dtype="bool",
1418
- shape=(m, icf.metadata.num_filters),
1419
- dimensions=["variants", "filters"],
1420
- ),
1421
- fixed_field_spec(
1422
- name="variant_allele",
1423
- dtype="str",
1424
- shape=(m, max_alleles),
1425
- dimensions=["variants", "alleles"],
1426
- ),
1427
- fixed_field_spec(
1428
- name="variant_id",
1429
- dtype="str",
1430
- ),
1431
- fixed_field_spec(
1432
- name="variant_id_mask",
1433
- dtype="bool",
1434
- ),
1435
- ]
1436
- name_map = {field.full_name: field for field in icf.metadata.fields}
1437
-
1438
- # Only two of the fixed fields have a direct one-to-one mapping.
1439
- colspecs.extend(
1440
- [
1441
- spec_from_field(name_map["QUAL"], variable_name="variant_quality"),
1442
- spec_from_field(name_map["POS"], variable_name="variant_position"),
1443
- ]
1444
- )
1445
- colspecs.extend([spec_from_field(field) for field in icf.metadata.info_fields])
1446
-
1447
- gt_field = None
1448
- for field in icf.metadata.format_fields:
1449
- if field.name == "GT":
1450
- gt_field = field
1451
- continue
1452
- colspecs.append(spec_from_field(field))
1453
-
1454
- if gt_field is not None:
1455
- ploidy = gt_field.summary.max_number - 1
1456
- shape = [m, n]
1457
- chunks = [variants_chunk_size, samples_chunk_size]
1458
- dimensions = ["variants", "samples"]
1459
- colspecs.append(
1460
- ZarrColumnSpec.new(
1461
- vcf_field=None,
1462
- name="call_genotype_phased",
1463
- dtype="bool",
1464
- shape=list(shape),
1465
- chunks=list(chunks),
1466
- dimensions=list(dimensions),
1467
- description="",
1468
- )
1469
- )
1470
- shape += [ploidy]
1471
- dimensions += ["ploidy"]
1472
- colspecs.append(
1473
- ZarrColumnSpec.new(
1474
- vcf_field=None,
1475
- name="call_genotype",
1476
- dtype=gt_field.smallest_dtype(),
1477
- shape=list(shape),
1478
- chunks=list(chunks),
1479
- dimensions=list(dimensions),
1480
- description="",
1481
- )
1482
- )
1483
- colspecs.append(
1484
- ZarrColumnSpec.new(
1485
- vcf_field=None,
1486
- name="call_genotype_mask",
1487
- dtype="bool",
1488
- shape=list(shape),
1489
- chunks=list(chunks),
1490
- dimensions=list(dimensions),
1491
- description="",
1492
- )
1493
- )
1494
-
1495
- return VcfZarrSchema(
1496
- format_version=ZARR_SCHEMA_FORMAT_VERSION,
1497
- samples_chunk_size=samples_chunk_size,
1498
- variants_chunk_size=variants_chunk_size,
1499
- columns={col.name: col for col in colspecs},
1500
- dimensions=["variants", "samples", "ploidy", "alleles", "filters"],
1501
- sample_id=icf.metadata.samples,
1502
- contig_id=icf.metadata.contig_names,
1503
- contig_length=icf.metadata.contig_lengths,
1504
- filter_id=icf.metadata.filters,
1505
- )
1506
-
1507
-
1508
- class VcfZarr:
1509
- def __init__(self, path):
1510
- if not (path / ".zmetadata").exists():
1511
- raise ValueError("Not in VcfZarr format") # NEEDS TEST
1512
- self.root = zarr.open(path, mode="r")
1513
-
1514
- def __repr__(self):
1515
- return repr(self.root) # NEEDS TEST
1516
-
1517
- def summary_table(self):
1518
- data = []
1519
- arrays = [(a.nbytes_stored, a) for _, a in self.root.arrays()]
1520
- arrays.sort(key=lambda x: x[0])
1521
- for stored, array in reversed(arrays):
1522
- d = {
1523
- "name": array.name,
1524
- "dtype": str(array.dtype),
1525
- "stored": display_size(stored),
1526
- "size": display_size(array.nbytes),
1527
- "ratio": display_number(array.nbytes / stored),
1528
- "nchunks": str(array.nchunks),
1529
- "chunk_size": display_size(array.nbytes / array.nchunks),
1530
- "avg_chunk_stored": display_size(int(stored / array.nchunks)),
1531
- "shape": str(array.shape),
1532
- "chunk_shape": str(array.chunks),
1533
- "compressor": str(array.compressor),
1534
- "filters": str(array.filters),
1535
- }
1536
- data.append(d)
1537
- return data
1538
-
1539
-
1540
- def parse_max_memory(max_memory):
1541
- if max_memory is None:
1542
- # Effectively unbounded
1543
- return 2**63
1544
- if isinstance(max_memory, str):
1545
- max_memory = humanfriendly.parse_size(max_memory)
1546
- logger.info(f"Set memory budget to {display_size(max_memory)}")
1547
- return max_memory
1548
-
1549
-
1550
- @dataclasses.dataclass
1551
- class VcfZarrPartition:
1552
- start_index: int
1553
- stop_index: int
1554
- start_chunk: int
1555
- stop_chunk: int
1556
-
1557
- @staticmethod
1558
- def generate_partitions(num_records, chunk_size, num_partitions, max_chunks=None):
1559
- num_chunks = int(np.ceil(num_records / chunk_size))
1560
- if max_chunks is not None:
1561
- num_chunks = min(num_chunks, max_chunks)
1562
- partitions = []
1563
- splits = np.array_split(np.arange(num_chunks), min(num_partitions, num_chunks))
1564
- for chunk_slice in splits:
1565
- start_chunk = int(chunk_slice[0])
1566
- stop_chunk = int(chunk_slice[-1]) + 1
1567
- start_index = start_chunk * chunk_size
1568
- stop_index = min(stop_chunk * chunk_size, num_records)
1569
- partitions.append(
1570
- VcfZarrPartition(start_index, stop_index, start_chunk, stop_chunk)
1571
- )
1572
- return partitions
1573
-
1574
-
1575
- VZW_METADATA_FORMAT_VERSION = "0.1"
1576
-
1577
-
1578
- @dataclasses.dataclass
1579
- class VcfZarrWriterMetadata:
1580
- format_version: str
1581
- icf_path: str
1582
- schema: VcfZarrSchema
1583
- dimension_separator: str
1584
- partitions: list
1585
- provenance: dict
1586
-
1587
- def asdict(self):
1588
- return dataclasses.asdict(self)
1589
-
1590
- @staticmethod
1591
- def fromdict(d):
1592
- if d["format_version"] != VZW_METADATA_FORMAT_VERSION:
1593
- raise ValueError(
1594
- "VcfZarrWriter format version mismatch: "
1595
- f"{d['format_version']} != {VZW_METADATA_FORMAT_VERSION}"
1596
- )
1597
- ret = VcfZarrWriterMetadata(**d)
1598
- ret.schema = VcfZarrSchema.fromdict(ret.schema)
1599
- ret.partitions = [VcfZarrPartition(**p) for p in ret.partitions]
1600
- return ret
1601
-
1602
-
1603
- class VcfZarrWriter:
1604
- def __init__(self, path):
1605
- self.path = pathlib.Path(path)
1606
- self.wip_path = self.path / "wip"
1607
- self.arrays_path = self.wip_path / "arrays"
1608
- self.partitions_path = self.wip_path / "partitions"
1609
- self.metadata = None
1610
- self.icf = None
1611
-
1612
- @property
1613
- def schema(self):
1614
- return self.metadata.schema
1615
-
1616
- @property
1617
- def num_partitions(self):
1618
- return len(self.metadata.partitions)
1619
-
1620
- #######################
1621
- # init
1622
- #######################
1623
-
1624
- def init(
1625
- self,
1626
- icf,
1627
- *,
1628
- target_num_partitions,
1629
- schema,
1630
- dimension_separator=None,
1631
- max_variant_chunks=None,
1632
- ):
1633
- self.icf = icf
1634
- if self.path.exists():
1635
- raise ValueError("Zarr path already exists") # NEEDS TEST
1636
- partitions = VcfZarrPartition.generate_partitions(
1637
- self.icf.num_records,
1638
- schema.variants_chunk_size,
1639
- target_num_partitions,
1640
- max_chunks=max_variant_chunks,
1641
- )
1642
- # Default to using nested directories following the Zarr v3 default.
1643
- # This seems to require version 2.17+ to work properly
1644
- dimension_separator = (
1645
- "/" if dimension_separator is None else dimension_separator
1646
- )
1647
- self.metadata = VcfZarrWriterMetadata(
1648
- format_version=VZW_METADATA_FORMAT_VERSION,
1649
- icf_path=str(self.icf.path),
1650
- schema=schema,
1651
- dimension_separator=dimension_separator,
1652
- partitions=partitions,
1653
- # Bare minimum here for provenance - see comments above
1654
- provenance={"source": f"bio2zarr-{provenance.__version__}"},
1655
- )
1656
-
1657
- self.path.mkdir()
1658
- store = zarr.DirectoryStore(self.path)
1659
- root = zarr.group(store=store)
1660
- root.attrs.update(
1661
- {
1662
- "vcf_zarr_version": "0.2",
1663
- "vcf_header": self.icf.vcf_header,
1664
- "source": f"bio2zarr-{provenance.__version__}",
1665
- }
1666
- )
1667
- # Doing this syncronously - this is fine surely
1668
- self.encode_samples(root)
1669
- self.encode_filter_id(root)
1670
- self.encode_contig_id(root)
1671
-
1672
- self.wip_path.mkdir()
1673
- self.arrays_path.mkdir()
1674
- self.partitions_path.mkdir()
1675
- store = zarr.DirectoryStore(self.arrays_path)
1676
- root = zarr.group(store=store)
1677
-
1678
- for column in self.schema.columns.values():
1679
- self.init_array(root, column, partitions[-1].stop_index)
1680
-
1681
- logger.info("Writing WIP metadata")
1682
- with open(self.wip_path / "metadata.json", "w") as f:
1683
- json.dump(self.metadata.asdict(), f, indent=4)
1684
- return len(partitions)
1685
-
1686
- def encode_samples(self, root):
1687
- if not np.array_equal(self.schema.sample_id, self.icf.metadata.samples):
1688
- raise ValueError(
1689
- "Subsetting or reordering samples not supported currently"
1690
- ) # NEEDS TEST
1691
- array = root.array(
1692
- "sample_id",
1693
- self.schema.sample_id,
1694
- dtype="str",
1695
- compressor=DEFAULT_ZARR_COMPRESSOR,
1696
- chunks=(self.schema.samples_chunk_size,),
1697
- )
1698
- array.attrs["_ARRAY_DIMENSIONS"] = ["samples"]
1699
- logger.debug("Samples done")
1700
-
1701
- def encode_contig_id(self, root):
1702
- array = root.array(
1703
- "contig_id",
1704
- self.schema.contig_id,
1705
- dtype="str",
1706
- compressor=DEFAULT_ZARR_COMPRESSOR,
1707
- )
1708
- array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
1709
- if self.schema.contig_length is not None:
1710
- array = root.array(
1711
- "contig_length",
1712
- self.schema.contig_length,
1713
- dtype=np.int64,
1714
- compressor=DEFAULT_ZARR_COMPRESSOR,
1715
- )
1716
- array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
1717
-
1718
- def encode_filter_id(self, root):
1719
- array = root.array(
1720
- "filter_id",
1721
- self.schema.filter_id,
1722
- dtype="str",
1723
- compressor=DEFAULT_ZARR_COMPRESSOR,
1724
- )
1725
- array.attrs["_ARRAY_DIMENSIONS"] = ["filters"]
1726
-
1727
- def init_array(self, root, variable, variants_dim_size):
1728
- object_codec = None
1729
- if variable.dtype == "O":
1730
- object_codec = numcodecs.VLenUTF8()
1731
- shape = list(variable.shape)
1732
- # Truncate the variants dimension is max_variant_chunks was specified
1733
- shape[0] = variants_dim_size
1734
- a = root.empty(
1735
- variable.name,
1736
- shape=shape,
1737
- chunks=variable.chunks,
1738
- dtype=variable.dtype,
1739
- compressor=numcodecs.get_codec(variable.compressor),
1740
- filters=[numcodecs.get_codec(filt) for filt in variable.filters],
1741
- object_codec=object_codec,
1742
- dimension_separator=self.metadata.dimension_separator,
1743
- )
1744
- a.attrs.update(
1745
- {
1746
- "description": variable.description,
1747
- # Dimension names are part of the spec in Zarr v3
1748
- "_ARRAY_DIMENSIONS": variable.dimensions,
1749
- }
1750
- )
1751
- logger.debug(f"Initialised {a}")
1752
-
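The init_array method above only creates empty Zarr v2 arrays whose chunk data is filled in later by the partition encoders. A minimal sketch of that pattern follows; the array name, shape and codec choices are illustrative, not taken from bio2zarr itself.

    import numcodecs
    import zarr

    root = zarr.group(store=zarr.DirectoryStore("example.zarr"))
    a = root.empty(
        "variant_position",            # illustrative array name
        shape=(1000,),
        chunks=(100,),
        dtype="i4",
        compressor=numcodecs.Blosc(cname="zstd", clevel=7),
        dimension_separator="/",       # chunk files are nested per dimension
    )
    # Under Zarr v2 the dimension names are recorded as an attribute
    # (they become part of the spec in Zarr v3).
    a.attrs["_ARRAY_DIMENSIONS"] = ["variants"]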
1753
- #######################
1754
- # encode_partition
1755
- #######################
1756
-
1757
- def load_metadata(self):
1758
- if self.metadata is None:
1759
- with open(self.wip_path / "metadata.json") as f:
1760
- self.metadata = VcfZarrWriterMetadata.fromdict(json.load(f))
1761
- self.icf = IntermediateColumnarFormat(self.metadata.icf_path)
1762
-
1763
- def partition_path(self, partition_index):
1764
- return self.partitions_path / f"p{partition_index}"
1765
-
1766
- def wip_partition_array_path(self, partition_index, name):
1767
- return self.partition_path(partition_index) / f"wip_{name}"
1768
-
1769
- def partition_array_path(self, partition_index, name):
1770
- return self.partition_path(partition_index) / name
1771
-
1772
- def encode_partition(self, partition_index):
1773
- self.load_metadata()
1774
- partition_path = self.partition_path(partition_index)
1775
- partition_path.mkdir(exist_ok=True)
1776
- logger.info(f"Encoding partition {partition_index} to {partition_path}")
1777
-
1778
- self.encode_alleles_partition(partition_index)
1779
- self.encode_id_partition(partition_index)
1780
- self.encode_filters_partition(partition_index)
1781
- self.encode_contig_partition(partition_index)
1782
- for col in self.schema.columns.values():
1783
- if col.vcf_field is not None:
1784
- self.encode_array_partition(col, partition_index)
1785
- if "call_genotype" in self.schema.columns:
1786
- self.encode_genotypes_partition(partition_index)
1787
-
1788
- def init_partition_array(self, partition_index, name):
1789
- wip_path = self.wip_partition_array_path(partition_index, name)
1790
- # Create an empty array like the definition
1791
- src = self.arrays_path / name
1792
- # Overwrite any existing WIP files
1793
- shutil.copytree(src, wip_path, dirs_exist_ok=True)
1794
- array = zarr.open(wip_path)
1795
- logger.debug(f"Opened empty array {array} @ {wip_path}")
1796
- return array
1797
-
1798
- def finalise_partition_array(self, partition_index, name):
1799
- wip_path = self.wip_partition_array_path(partition_index, name)
1800
- final_path = self.partition_array_path(partition_index, name)
1801
- if final_path.exists():
1802
- # NEEDS TEST
1803
- logger.warning(f"Removing existing {final_path}")
1804
- shutil.rmtree(final_path)
1805
- # Atomic swap
1806
- os.rename(wip_path, final_path)
1807
- logger.debug(f"Encoded {name} partition {partition_index}")
1808
-
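The pair of methods above implement a write-ahead pattern: each partition's chunks are written into a wip_* copy of the array directory and only become visible under the final name via a single os.rename. A stripped-down sketch of that idea, with illustrative paths rather than bio2zarr's API:

    import os
    import shutil

    def publish(wip_path, final_path):
        # Everything is written under wip_path first; the rename is the commit
        # point, so an interrupted encode never leaves a partially written
        # directory under the final name.
        if os.path.exists(final_path):
            shutil.rmtree(final_path)   # discard a previous attempt
        os.rename(wip_path, final_path)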
1809
- def encode_array_partition(self, column, partition_index):
1810
- array = self.init_partition_array(partition_index, column.name)
1811
-
1812
- partition = self.metadata.partitions[partition_index]
1813
- ba = core.BufferedArray(array, partition.start_index)
1814
- source_col = self.icf.columns[column.vcf_field]
1815
- sanitiser = source_col.sanitiser_factory(ba.buff.shape)
1816
-
1817
- for value in source_col.iter_values(
1818
- partition.start_index, partition.stop_index
1819
- ):
1820
- # We write directly into the buffer in the sanitiser function
1821
- # to make it easier to reason about dimension padding
1822
- j = ba.next_buffer_row()
1823
- sanitiser(ba.buff, j, value)
1824
- ba.flush()
1825
- self.finalise_partition_array(partition_index, column.name)
1826
-
1827
- def encode_genotypes_partition(self, partition_index):
1828
- gt_array = self.init_partition_array(partition_index, "call_genotype")
1829
- gt_mask_array = self.init_partition_array(partition_index, "call_genotype_mask")
1830
- gt_phased_array = self.init_partition_array(
1831
- partition_index, "call_genotype_phased"
1832
- )
1833
-
1834
- partition = self.metadata.partitions[partition_index]
1835
- gt = core.BufferedArray(gt_array, partition.start_index)
1836
- gt_mask = core.BufferedArray(gt_mask_array, partition.start_index)
1837
- gt_phased = core.BufferedArray(gt_phased_array, partition.start_index)
1838
-
1839
- source_col = self.icf.columns["FORMAT/GT"]
1840
- for value in source_col.iter_values(
1841
- partition.start_index, partition.stop_index
1842
- ):
1843
- j = gt.next_buffer_row()
1844
- sanitise_value_int_2d(gt.buff, j, value[:, :-1])
1845
- j = gt_phased.next_buffer_row()
1846
- sanitise_value_int_1d(gt_phased.buff, j, value[:, -1])
1847
- # TODO check whether this is the correct semantics when we are padding
1848
- # with mixed ploidies?
1849
- j = gt_mask.next_buffer_row()
1850
- gt_mask.buff[j] = gt.buff[j] < 0
1851
- gt.flush()
1852
- gt_phased.flush()
1853
- gt_mask.flush()
1854
-
1855
- self.finalise_partition_array(partition_index, "call_genotype")
1856
- self.finalise_partition_array(partition_index, "call_genotype_mask")
1857
- self.finalise_partition_array(partition_index, "call_genotype_phased")
1858
-
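For genotypes, the stored value carries the phased flag in its last column (the layout produced by cyvcf2), and the mask is derived from negative call values. A small illustrative example of that slicing, assuming that layout:

    import numpy as np

    # One variant, two samples: 0|1 (phased) and a missing haploid call padded with -2.
    value = np.array([[0, 1, 1], [-1, -2, 0]], dtype=np.int8)
    gt = value[:, :-1]      # allele calls
    phased = value[:, -1]   # 1 = phased, 0 = unphased
    mask = gt < 0           # True for missing (-1) and fill (-2) entries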
1859
- def encode_alleles_partition(self, partition_index):
1860
- array_name = "variant_allele"
1861
- alleles_array = self.init_partition_array(partition_index, array_name)
1862
- partition = self.metadata.partitions[partition_index]
1863
- alleles = core.BufferedArray(alleles_array, partition.start_index)
1864
- ref_col = self.icf.columns["REF"]
1865
- alt_col = self.icf.columns["ALT"]
1866
-
1867
- for ref, alt in zip(
1868
- ref_col.iter_values(partition.start_index, partition.stop_index),
1869
- alt_col.iter_values(partition.start_index, partition.stop_index),
1870
- ):
1871
- j = alleles.next_buffer_row()
1872
- alleles.buff[j, :] = STR_FILL
1873
- alleles.buff[j, 0] = ref[0]
1874
- alleles.buff[j, 1 : 1 + len(alt)] = alt
1875
- alleles.flush()
1876
-
1877
- self.finalise_partition_array(partition_index, array_name)
1878
-
1879
- def encode_id_partition(self, partition_index):
1880
- vid_array = self.init_partition_array(partition_index, "variant_id")
1881
- vid_mask_array = self.init_partition_array(partition_index, "variant_id_mask")
1882
- partition = self.metadata.partitions[partition_index]
1883
- vid = core.BufferedArray(vid_array, partition.start_index)
1884
- vid_mask = core.BufferedArray(vid_mask_array, partition.start_index)
1885
- col = self.icf.columns["ID"]
1886
-
1887
- for value in col.iter_values(partition.start_index, partition.stop_index):
1888
- j = vid.next_buffer_row()
1889
- k = vid_mask.next_buffer_row()
1890
- assert j == k
1891
- if value is not None:
1892
- vid.buff[j] = value[0]
1893
- vid_mask.buff[j] = False
1894
- else:
1895
- vid.buff[j] = STR_MISSING
1896
- vid_mask.buff[j] = True
1897
- vid.flush()
1898
- vid_mask.flush()
1899
-
1900
- self.finalise_partition_array(partition_index, "variant_id")
1901
- self.finalise_partition_array(partition_index, "variant_id_mask")
1902
-
1903
- def encode_filters_partition(self, partition_index):
1904
- lookup = {filt: index for index, filt in enumerate(self.schema.filter_id)}
1905
- array_name = "variant_filter"
1906
- array = self.init_partition_array(partition_index, array_name)
1907
- partition = self.metadata.partitions[partition_index]
1908
- var_filter = core.BufferedArray(array, partition.start_index)
1909
-
1910
- col = self.icf.columns["FILTERS"]
1911
- for value in col.iter_values(partition.start_index, partition.stop_index):
1912
- j = var_filter.next_buffer_row()
1913
- var_filter.buff[j] = False
1914
- for f in value:
1915
- try:
1916
- var_filter.buff[j, lookup[f]] = True
1917
- except KeyError:
1918
- raise ValueError(
1919
- f"Filter '{f}' was not defined in the header."
1920
- ) from None
1921
- var_filter.flush()
1922
-
1923
- self.finalise_partition_array(partition_index, array_name)
1924
-
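variant_filter is a per-variant boolean matrix with one column per FILTER declared in the header; the lookup dict maps each filter name to its column. A sketch with made-up filter names:

    import numpy as np

    filter_id = ["PASS", "q10", "s50"]          # illustrative header filters
    lookup = {f: i for i, f in enumerate(filter_id)}
    row = np.zeros(len(filter_id), dtype=bool)
    for f in ["PASS", "q10"]:                   # filters reported for one variant
        row[lookup[f]] = True                   # an undeclared name raises KeyError, reported as above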
1925
- def encode_contig_partition(self, partition_index):
1926
- lookup = {contig: index for index, contig in enumerate(self.schema.contig_id)}
1927
- array_name = "variant_contig"
1928
- array = self.init_partition_array(partition_index, array_name)
1929
- partition = self.metadata.partitions[partition_index]
1930
- contig = core.BufferedArray(array, partition.start_index)
1931
- col = self.icf.columns["CHROM"]
1932
-
1933
- for value in col.iter_values(partition.start_index, partition.stop_index):
1934
- j = contig.next_buffer_row()
1935
- # Note: because we are using the indexes to define the lookups
1936
- # and we always have an index, it seems that the contig lookup
1937
- # will always succeed. However, if anyone ever does hit a KeyError
1938
- # here, please do open an issue with a reproducible example!
1939
- contig.buff[j] = lookup[value[0]]
1940
- contig.flush()
1941
-
1942
- self.finalise_partition_array(partition_index, array_name)
1943
-
1944
- #######################
1945
- # finalise
1946
- #######################
1947
-
1948
- def finalise_array(self, name):
1949
- logger.info(f"Finalising {name}")
1950
- final_path = self.path / name
1951
- if final_path.exists():
1952
- # NEEDS TEST
1953
- raise ValueError(f"Array {name} already exists")
1954
- for partition in range(len(self.metadata.partitions)):
1955
- # Move all the files in partition dir to dest dir
1956
- src = self.partition_array_path(partition, name)
1957
- if not src.exists():
1958
- # Needs test
1959
- raise ValueError(f"Partition {partition} of {name} does not exist")
1960
- dest = self.arrays_path / name
1961
- # This is Zarr v2 specific. Chunks in v3 will start with a "c" prefix.
1962
- chunk_files = [
1963
- path for path in src.iterdir() if not path.name.startswith(".")
1964
- ]
1965
- # TODO check for a count of the number of files. If we require a
1966
- # dimension_separator of "/" then we could make stronger assertions
1967
- # here, as we'd always have num_variant_chunks
1968
- logger.debug(
1969
- f"Moving {len(chunk_files)} chunks for {name} partition {partition}"
1970
- )
1971
- for chunk_file in chunk_files:
1972
- os.rename(chunk_file, dest / chunk_file.name)
1973
- # Finally, once all the chunks have moved into the arrays dir,
1974
- # we move the array directory out of the wip area
1975
- os.rename(self.arrays_path / name, self.path / name)
1976
- core.update_progress(1)
1977
-
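finalise_array relies on the Zarr v2 directory layout: per-array metadata lives in dotfiles (.zarray, .zattrs) and every other file in the array directory is a chunk, so moving the non-dotfiles merges a partition's chunks into the final array. An illustrative sketch of that listing (the path is made up):

    import pathlib

    src = pathlib.Path("wip/partitions/p0/variant_position")   # illustrative path
    chunk_files = [p for p in src.iterdir() if not p.name.startswith(".")]
    # chunk_files now holds only chunk data, e.g. "0", "1", ... for a 1-D array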
1978
- def finalise(self, show_progress=False):
1979
- self.load_metadata()
1980
-
1981
- progress_config = core.ProgressConfig(
1982
- total=len(self.schema.columns),
1983
- title="Finalise",
1984
- units="array",
1985
- show=show_progress,
1986
- )
1987
- # NOTE: it's not clear that adding more workers will make this quicker,
1988
- # as it's just going to be causing contention on the file system.
1989
- # Something to check empirically in some deployments.
1990
- # FIXME we're just using worker_processes=0 here to hook into the
1991
- # SynchronousExecutor which is intended for testing purposes so
1992
- # that we get test coverage. Should fix this either by allowing
1993
- # for multiple workers, or making a standard wrapper for tqdm
1994
- # that allows us to have a consistent look and feel.
1995
- with core.ParallelWorkManager(0, progress_config) as pwm:
1996
- for name in self.schema.columns:
1997
- pwm.submit(self.finalise_array, name)
1998
- zarr.consolidate_metadata(self.path)
1999
-
2000
- ######################
2001
- # encode_all_partitions
2002
- ######################
2003
-
2004
- def get_max_encoding_memory(self):
2005
- """
2006
- Return the approximate maximum memory used to encode a variant chunk.
2007
- """
2008
- max_encoding_mem = max(
2009
- col.variant_chunk_nbytes for col in self.schema.columns.values()
2010
- )
2011
- gt_mem = 0
2012
- if "call_genotype" in self.schema.columns:
2013
- encoded_together = [
2014
- "call_genotype",
2015
- "call_genotype_phased",
2016
- "call_genotype_mask",
2017
- ]
2018
- gt_mem = sum(
2019
- self.schema.columns[col].variant_chunk_nbytes
2020
- for col in encoded_together
2021
- )
2022
- return max(max_encoding_mem, gt_mem)
2023
-
2024
- def encode_all_partitions(
2025
- self, *, worker_processes=1, show_progress=False, max_memory=None
2026
- ):
2027
- max_memory = parse_max_memory(max_memory)
2028
- self.load_metadata()
2029
- num_partitions = self.num_partitions
2030
- per_worker_memory = self.get_max_encoding_memory()
2031
- logger.info(
2032
- f"Encoding Zarr over {num_partitions} partitions with "
2033
- f"{worker_processes} workers and {display_size(per_worker_memory)} "
2034
- "per worker"
2035
- )
2036
- # Each partition requires per_worker_memory bytes, so to prevent more than
2037
- # max_memory being used, we clamp the number of workers
2038
- max_num_workers = max_memory // per_worker_memory
2039
- if max_num_workers < worker_processes:
2040
- logger.warning(
2041
- f"Limiting number of workers to {max_num_workers} to "
2042
- f"keep within specified memory budget of {display_size(max_memory)}"
2043
- )
2044
- if max_num_workers <= 0:
2045
- raise ValueError(
2046
- f"Insufficient memory to encode a partition:"
2047
- f"{display_size(per_worker_memory)} > {display_size(max_memory)}"
2048
- )
2049
- num_workers = min(max_num_workers, worker_processes)
2050
-
2051
- total_bytes = 0
2052
- for col in self.schema.columns.values():
2053
- # Open the array definition to get the total size
2054
- total_bytes += zarr.open(self.arrays_path / col.name).nbytes
2055
-
2056
- progress_config = core.ProgressConfig(
2057
- total=total_bytes,
2058
- title="Encode",
2059
- units="B",
2060
- show=show_progress,
2061
- )
2062
- with core.ParallelWorkManager(num_workers, progress_config) as pwm:
2063
- for partition_index in range(num_partitions):
2064
- pwm.submit(self.encode_partition, partition_index)
2065
-
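The worker clamp above is simple integer arithmetic: each worker holds roughly one variant chunk's worth of buffers, so the effective worker count is bounded by the memory budget. For example, with illustrative numbers:

    max_memory = 8 * 2**30          # 8 GiB budget
    per_worker_memory = 3 * 2**30   # ~3 GiB needed to encode one variant chunk
    worker_processes = 4            # requested workers
    num_workers = min(max_memory // per_worker_memory, worker_processes)   # -> 2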
2066
-
2067
- def mkschema(if_path, out):
2068
- icf = IntermediateColumnarFormat(if_path)
2069
- spec = VcfZarrSchema.generate(icf)
2070
- out.write(spec.asjson())
2071
-
2072
-
2073
- def encode(
2074
- if_path,
2075
- zarr_path,
2076
- schema_path=None,
2077
- variants_chunk_size=None,
2078
- samples_chunk_size=None,
2079
- max_variant_chunks=None,
2080
- dimension_separator=None,
2081
- max_memory=None,
2082
- worker_processes=1,
2083
- show_progress=False,
2084
- ):
2085
- # Rough heuristic to split work up enough to keep utilisation high
2086
- target_num_partitions = max(1, worker_processes * 4)
2087
- encode_init(
2088
- if_path,
2089
- zarr_path,
2090
- target_num_partitions,
2091
- schema_path=schema_path,
2092
- variants_chunk_size=variants_chunk_size,
2093
- samples_chunk_size=samples_chunk_size,
2094
- max_variant_chunks=max_variant_chunks,
2095
- dimension_separator=dimension_separator,
2096
- )
2097
- vzw = VcfZarrWriter(zarr_path)
2098
- vzw.encode_all_partitions(
2099
- worker_processes=worker_processes,
2100
- show_progress=show_progress,
2101
- max_memory=max_memory,
2102
- )
2103
- vzw.finalise(show_progress)
2104
-
2105
-
2106
- def encode_init(
2107
- icf_path,
2108
- zarr_path,
2109
- target_num_partitions,
2110
- *,
2111
- schema_path=None,
2112
- variants_chunk_size=None,
2113
- samples_chunk_size=None,
2114
- max_variant_chunks=None,
2115
- dimension_separator=None,
2116
- max_memory=None,
2117
- worker_processes=1,
2118
- show_progress=False,
2119
- ):
2120
- icf = IntermediateColumnarFormat(icf_path)
2121
- if schema_path is None:
2122
- schema = VcfZarrSchema.generate(
2123
- icf,
2124
- variants_chunk_size=variants_chunk_size,
2125
- samples_chunk_size=samples_chunk_size,
2126
- )
2127
- else:
2128
- logger.info(f"Reading schema from {schema_path}")
2129
- if variants_chunk_size is not None or samples_chunk_size is not None:
2130
- raise ValueError(
2131
- "Cannot specify schema along with chunk sizes"
2132
- ) # NEEDS TEST
2133
- with open(schema_path) as f:
2134
- schema = VcfZarrSchema.fromjson(f.read())
2135
- zarr_path = pathlib.Path(zarr_path)
2136
- vzw = VcfZarrWriter(zarr_path)
2137
- vzw.init(
2138
- icf,
2139
- target_num_partitions=target_num_partitions,
2140
- schema=schema,
2141
- dimension_separator=dimension_separator,
2142
- max_variant_chunks=max_variant_chunks,
2143
- )
2144
- return vzw.num_partitions, vzw.get_max_encoding_memory()
2145
-
2146
-
2147
- def encode_partition(zarr_path, partition):
2148
- writer = VcfZarrWriter(zarr_path)
2149
- writer.encode_partition(partition)
2150
-
2151
-
2152
- def encode_finalise(zarr_path, show_progress=False):
2153
- writer = VcfZarrWriter(zarr_path)
2154
- writer.finalise(show_progress=show_progress)
2155
-
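Together, encode_init, encode_partition and encode_finalise split the encode step into three phases so partitions can be processed independently, for example across separate processes or cluster jobs. A sketch of driving them directly, with illustrative paths:

    num_partitions, per_worker_memory = encode_init("sample.icf", "sample.vcf.zarr", 16)
    for j in range(num_partitions):
        encode_partition("sample.vcf.zarr", j)   # each call is independent
    encode_finalise("sample.vcf.zarr", show_progress=True)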
2156
-
2157
- def convert(
2158
- vcfs,
2159
- out_path,
2160
- *,
2161
- variants_chunk_size=None,
2162
- samples_chunk_size=None,
2163
- worker_processes=1,
2164
- show_progress=False,
2165
- # TODO add arguments to control location of tmpdir
2166
- ):
2167
- with tempfile.TemporaryDirectory(prefix="vcf2zarr") as tmp:
2168
- if_dir = pathlib.Path(tmp) / "if"
2169
- explode(
2170
- if_dir,
2171
- vcfs,
2172
- worker_processes=worker_processes,
2173
- show_progress=show_progress,
2174
- )
2175
- encode(
2176
- if_dir,
2177
- out_path,
2178
- variants_chunk_size=variants_chunk_size,
2179
- samples_chunk_size=samples_chunk_size,
2180
- worker_processes=worker_processes,
2181
- show_progress=show_progress,
2182
- )
2183
-
2184
-
2185
- def assert_all_missing_float(a):
2186
- v = np.array(a, dtype=np.float32).view(np.int32)
2187
- nt.assert_equal(v, FLOAT32_MISSING_AS_INT32)
2188
-
2189
-
2190
- def assert_all_fill_float(a):
2191
- v = np.array(a, dtype=np.float32).view(np.int32)
2192
- nt.assert_equal(v, FLOAT32_FILL_AS_INT32)
2193
-
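These float checks go through an int32 view because the missing and fill markers are both NaNs that differ only in their bit patterns, and numpy's assert_equal treats any NaN as equal to any other NaN, so comparing the floats directly cannot tell them apart. A tiny illustration (the 0x7FC00001 payload is made up):

    import numpy as np
    import numpy.testing as nt

    x = np.array([np.nan], dtype=np.float32)                      # default quiet NaN
    y = np.array([0x7FC00001], dtype=np.uint32).view(np.float32)  # NaN with another payload
    nt.assert_equal(x, y)                              # passes: NaN positions count as equal
    assert x.view(np.int32)[0] != y.view(np.int32)[0]  # but the bit patterns differ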
2194
-
2195
- def assert_all_missing_int(a):
2196
- v = np.array(a, dtype=int)
2197
- nt.assert_equal(v, -1)
2198
-
2199
-
2200
- def assert_all_fill_int(a):
2201
- v = np.array(a, dtype=int)
2202
- nt.assert_equal(v, -2)
2203
-
2204
-
2205
- def assert_all_missing_string(a):
2206
- nt.assert_equal(a, ".")
2207
-
2208
-
2209
- def assert_all_fill_string(a):
2210
- nt.assert_equal(a, "")
2211
-
2212
-
2213
- def assert_all_fill(zarr_val, vcf_type):
2214
- if vcf_type == "Integer":
2215
- assert_all_fill_int(zarr_val)
2216
- elif vcf_type in ("String", "Character"):
2217
- assert_all_fill_string(zarr_val)
2218
- elif vcf_type == "Float":
2219
- assert_all_fill_float(zarr_val)
2220
- else: # pragma: no cover
2221
- assert False # noqa PT015
2222
-
2223
-
2224
- def assert_all_missing(zarr_val, vcf_type):
2225
- if vcf_type == "Integer":
2226
- assert_all_missing_int(zarr_val)
2227
- elif vcf_type in ("String", "Character"):
2228
- assert_all_missing_string(zarr_val)
2229
- elif vcf_type == "Flag":
2230
- assert zarr_val == False # noqa 712
2231
- elif vcf_type == "Float":
2232
- assert_all_missing_float(zarr_val)
2233
- else: # pragma: no cover
2234
- assert False # noqa PT015
2235
-
2236
-
2237
- def assert_info_val_missing(zarr_val, vcf_type):
2238
- assert_all_missing(zarr_val, vcf_type)
2239
-
2240
-
2241
- def assert_format_val_missing(zarr_val, vcf_type):
2242
- assert_info_val_missing(zarr_val, vcf_type)
2243
-
2244
-
2245
- # Note: checking exact equality may prove problematic here
2246
- # but we should be deterministically storing what cyvcf2
2247
- # provides, which should compare equal.
2248
-
2249
-
2250
- def assert_info_val_equal(vcf_val, zarr_val, vcf_type):
2251
- assert vcf_val is not None
2252
- if vcf_type in ("String", "Character"):
2253
- split = list(vcf_val.split(","))
2254
- k = len(split)
2255
- if isinstance(zarr_val, str):
2256
- assert k == 1
2257
- # Scalar
2258
- assert vcf_val == zarr_val
2259
- else:
2260
- nt.assert_equal(split, zarr_val[:k])
2261
- assert_all_fill(zarr_val[k:], vcf_type)
2262
-
2263
- elif isinstance(vcf_val, tuple):
2264
- vcf_missing_value_map = {
2265
- "Integer": -1,
2266
- "Float": FLOAT32_MISSING,
2267
- }
2268
- v = [vcf_missing_value_map[vcf_type] if x is None else x for x in vcf_val]
2269
- missing = np.array([j for j, x in enumerate(vcf_val) if x is None], dtype=int)
2270
- a = np.array(v)
2271
- k = len(a)
2272
- # We are checking for int missing twice here, but it's necessary to have
2273
- # a separate check for floats because different NaNs compare equal
2274
- nt.assert_equal(a, zarr_val[:k])
2275
- assert_all_missing(zarr_val[missing], vcf_type)
2276
- if k < len(zarr_val):
2277
- assert_all_fill(zarr_val[k:], vcf_type)
2278
- else:
2279
- # Scalar
2280
- zarr_val = np.array(zarr_val, ndmin=1)
2281
- assert len(zarr_val.shape) == 1
2282
- assert vcf_val == zarr_val[0]
2283
- if len(zarr_val) > 1:
2284
- assert_all_fill(zarr_val[1:], vcf_type)
2285
-
2286
-
2287
- def assert_format_val_equal(vcf_val, zarr_val, vcf_type):
2288
- assert vcf_val is not None
2289
- assert isinstance(vcf_val, np.ndarray)
2290
- if vcf_type in ("String", "Character"):
2291
- assert len(vcf_val) == len(zarr_val)
2292
- for v, z in zip(vcf_val, zarr_val):
2293
- split = list(v.split(","))
2294
- # Note: deliberately duplicating logic here between this and the
2295
- # INFO col above to make sure all combinations are covered by tests
2296
- k = len(split)
2297
- if k == 1:
2298
- assert v == z
2299
- else:
2300
- nt.assert_equal(split, z[:k])
2301
- assert_all_fill(z[k:], vcf_type)
2302
- else:
2303
- assert vcf_val.shape[0] == zarr_val.shape[0]
2304
- if len(vcf_val.shape) == len(zarr_val.shape) + 1:
2305
- assert vcf_val.shape[-1] == 1
2306
- vcf_val = vcf_val[..., 0]
2307
- assert len(vcf_val.shape) <= 2
2308
- assert len(vcf_val.shape) == len(zarr_val.shape)
2309
- if len(vcf_val.shape) == 2:
2310
- k = vcf_val.shape[1]
2311
- if zarr_val.shape[1] != k:
2312
- assert_all_fill(zarr_val[:, k:], vcf_type)
2313
- zarr_val = zarr_val[:, :k]
2314
- assert vcf_val.shape == zarr_val.shape
2315
- if vcf_type == "Integer":
2316
- vcf_val[vcf_val == VCF_INT_MISSING] = INT_MISSING
2317
- vcf_val[vcf_val == VCF_INT_FILL] = INT_FILL
2318
- elif vcf_type == "Float":
2319
- nt.assert_equal(vcf_val.view(np.int32), zarr_val.view(np.int32))
2320
-
2321
- nt.assert_equal(vcf_val, zarr_val)
2322
-
2323
-
2324
- # TODO rename to "verify"
2325
- def validate(vcf_path, zarr_path, show_progress=False):
2326
- store = zarr.DirectoryStore(zarr_path)
2327
-
2328
- root = zarr.group(store=store)
2329
- pos = root["variant_position"][:]
2330
- allele = root["variant_allele"][:]
2331
- chrom = root["contig_id"][:][root["variant_contig"][:]]
2332
- vid = root["variant_id"][:]
2333
- call_genotype = None
2334
- if "call_genotype" in root:
2335
- call_genotype = iter(root["call_genotype"])
2336
-
2337
- vcf = cyvcf2.VCF(vcf_path)
2338
- format_headers = {}
2339
- info_headers = {}
2340
- for h in vcf.header_iter():
2341
- if h["HeaderType"] == "FORMAT":
2342
- format_headers[h["ID"]] = h
2343
- if h["HeaderType"] == "INFO":
2344
- info_headers[h["ID"]] = h
2345
-
2346
- format_fields = {}
2347
- info_fields = {}
2348
- for colname in root.keys():
2349
- if colname.startswith("call") and not colname.startswith("call_genotype"):
2350
- vcf_name = colname.split("_", 1)[1]
2351
- vcf_type = format_headers[vcf_name]["Type"]
2352
- format_fields[vcf_name] = vcf_type, iter(root[colname])
2353
- if colname.startswith("variant"):
2354
- name = colname.split("_", 1)[1]
2355
- if name.isupper():
2356
- vcf_type = info_headers[name]["Type"]
2357
- info_fields[name] = vcf_type, iter(root[colname])
2358
-
2359
- first_pos = next(vcf).POS
2360
- start_index = np.searchsorted(pos, first_pos)
2361
- assert pos[start_index] == first_pos
2362
- vcf = cyvcf2.VCF(vcf_path)
2363
- if show_progress:
2364
- iterator = tqdm.tqdm(vcf, desc=" Verify", total=vcf.num_records) # NEEDS TEST
2365
- else:
2366
- iterator = vcf
2367
- for j, row in enumerate(iterator, start_index):
2368
- assert chrom[j] == row.CHROM
2369
- assert pos[j] == row.POS
2370
- assert vid[j] == ("." if row.ID is None else row.ID)
2371
- assert allele[j, 0] == row.REF
2372
- k = len(row.ALT)
2373
- nt.assert_array_equal(allele[j, 1 : k + 1], row.ALT)
2374
- assert np.all(allele[j, k + 1 :] == "")
2375
- # TODO FILTERS
2376
-
2377
- if call_genotype is None:
2378
- val = None
2379
- try:
2380
- val = row.format("GT")
2381
- except KeyError:
2382
- pass
2383
- assert val is None
2384
- else:
2385
- gt = row.genotype.array()
2386
- gt_zarr = next(call_genotype)
2387
- gt_vcf = gt[:, :-1]
2388
- # NOTE cyvcf2 remaps genotypes automatically
2389
- # into the same missing/pad encoding that sgkit uses.
2390
- nt.assert_array_equal(gt_zarr, gt_vcf)
2391
-
2392
- for name, (vcf_type, zarr_iter) in info_fields.items():
2393
- vcf_val = row.INFO.get(name, None)
2394
- zarr_val = next(zarr_iter)
2395
- if vcf_val is None:
2396
- assert_info_val_missing(zarr_val, vcf_type)
2397
- else:
2398
- assert_info_val_equal(vcf_val, zarr_val, vcf_type)
2399
-
2400
- for name, (vcf_type, zarr_iter) in format_fields.items():
2401
- vcf_val = row.format(name)
2402
- zarr_val = next(zarr_iter)
2403
- if vcf_val is None:
2404
- assert_format_val_missing(zarr_val, vcf_type)
2405
- else:
2406
- assert_format_val_equal(vcf_val, zarr_val, vcf_type)
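Finally, validate re-reads the original VCF with cyvcf2 and compares it field by field against the Zarr output. A typical round-trip, with illustrative paths:

    convert(["sample.vcf.gz"], "sample.vcf.zarr", worker_processes=4)
    validate("sample.vcf.gz", "sample.vcf.zarr", show_progress=True)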