bio2zarr 0.0.9__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of bio2zarr might be problematic.

bio2zarr/vcf.py DELETED
@@ -1,2445 +0,0 @@
1
- import collections
2
- import contextlib
3
- import dataclasses
4
- import json
5
- import logging
6
- import math
7
- import os
8
- import os.path
9
- import pathlib
10
- import pickle
11
- import shutil
12
- import sys
13
- import tempfile
14
- from typing import Any
15
-
16
- import cyvcf2
17
- import humanfriendly
18
- import numcodecs
19
- import numpy as np
20
- import numpy.testing as nt
21
- import tqdm
22
- import zarr
23
-
24
- from . import core, provenance, vcf_utils
25
-
26
- logger = logging.getLogger(__name__)
27
-
28
- INT_MISSING = -1
29
- INT_FILL = -2
30
- STR_MISSING = "."
31
- STR_FILL = ""
32
-
33
- FLOAT32_MISSING, FLOAT32_FILL = np.array([0x7F800001, 0x7F800002], dtype=np.int32).view(
34
- np.float32
35
- )
36
- FLOAT32_MISSING_AS_INT32, FLOAT32_FILL_AS_INT32 = np.array(
37
- [0x7F800001, 0x7F800002], dtype=np.int32
38
- )
39
-
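# Illustration (not part of the original vcf.py): both sentinels above are NaN
# payloads, so ordinary float comparisons cannot distinguish them; they are told
# apart by viewing the underlying bits as int32 again.
import numpy as np

sentinels = np.array([0x7F800001, 0x7F800002], dtype=np.int32).view(np.float32)
assert np.isnan(sentinels).all()  # both read as NaN through the float API
assert (sentinels.view(np.int32) == [0x7F800001, 0x7F800002]).all()  # bit patterns preserved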
40
-
41
- def display_number(x):
42
- ret = "n/a"
43
- if math.isfinite(x):
44
- ret = f"{x: 0.2g}"
45
- return ret
46
-
47
-
48
- def display_size(n):
49
- return humanfriendly.format_size(n, binary=True)
50
-
51
-
52
- @dataclasses.dataclass
53
- class VcfFieldSummary:
54
- num_chunks: int = 0
55
- compressed_size: int = 0
56
- uncompressed_size: int = 0
57
- max_number: int = 0 # Corresponds to VCF Number field, depends on context
58
- # Only defined for numeric fields
59
- max_value: Any = -math.inf
60
- min_value: Any = math.inf
61
-
62
- def update(self, other):
63
- self.num_chunks += other.num_chunks
64
- self.compressed_size += other.compressed_size
65
- self.uncompressed_size += other.uncompressed_size
66
- self.max_number = max(self.max_number, other.max_number)
67
- self.min_value = min(self.min_value, other.min_value)
68
- self.max_value = max(self.max_value, other.max_value)
69
-
70
- def asdict(self):
71
- return dataclasses.asdict(self)
72
-
73
- @staticmethod
74
- def fromdict(d):
75
- return VcfFieldSummary(**d)
76
-
77
-
78
- @dataclasses.dataclass
79
- class VcfField:
80
- category: str
81
- name: str
82
- vcf_number: str
83
- vcf_type: str
84
- description: str
85
- summary: VcfFieldSummary
86
-
87
- @staticmethod
88
- def from_header(definition):
89
- category = definition["HeaderType"]
90
- name = definition["ID"]
91
- vcf_number = definition["Number"]
92
- vcf_type = definition["Type"]
93
- return VcfField(
94
- category=category,
95
- name=name,
96
- vcf_number=vcf_number,
97
- vcf_type=vcf_type,
98
- description=definition["Description"].strip('"'),
99
- summary=VcfFieldSummary(),
100
- )
101
-
102
- @staticmethod
103
- def fromdict(d):
104
- f = VcfField(**d)
105
- f.summary = VcfFieldSummary(**d["summary"])
106
- return f
107
-
108
- @property
109
- def full_name(self):
110
- if self.category == "fixed":
111
- return self.name
112
- return f"{self.category}/{self.name}"
113
-
114
- def smallest_dtype(self):
115
- """
116
- Returns the smallest dtype suitable for this field based
117
- on its type and values.
118
- """
119
- s = self.summary
120
- if self.vcf_type == "Float":
121
- ret = "f4"
122
- elif self.vcf_type == "Integer":
123
- if not math.isfinite(s.max_value):
124
- # All missing values; use i1. Note we should have some API to
125
- # check more explicitly for missingness:
126
- # https://github.com/sgkit-dev/bio2zarr/issues/131
127
- ret = "i1"
128
- else:
129
- ret = core.min_int_dtype(s.min_value, s.max_value)
130
- elif self.vcf_type == "Flag":
131
- ret = "bool"
132
- elif self.vcf_type == "Character":
133
- ret = "U1"
134
- else:
135
- assert self.vcf_type == "String"
136
- ret = "O"
137
- return ret
138
-
139
-
140
- @dataclasses.dataclass
141
- class VcfPartition:
142
- vcf_path: str
143
- region: str
144
- num_records: int = -1
145
-
146
-
147
- ICF_METADATA_FORMAT_VERSION = "0.3"
148
- ICF_DEFAULT_COMPRESSOR = numcodecs.Blosc(
149
- cname="zstd", clevel=7, shuffle=numcodecs.Blosc.NOSHUFFLE
150
- )
151
-
152
-
153
- @dataclasses.dataclass
154
- class Contig:
155
- id: str
156
- length: int = None
157
-
158
-
159
- @dataclasses.dataclass
160
- class Sample:
161
- id: str
162
-
163
-
164
- @dataclasses.dataclass
165
- class Filter:
166
- id: str
167
- description: str = ""
168
-
169
-
170
- @dataclasses.dataclass
171
- class IcfMetadata:
172
- samples: list
173
- contigs: list
174
- filters: list
175
- fields: list
176
- partitions: list = None
177
- format_version: str = None
178
- compressor: dict = None
179
- column_chunk_size: int = None
180
- provenance: dict = None
181
- num_records: int = -1
182
-
183
- @property
184
- def info_fields(self):
185
- fields = []
186
- for field in self.fields:
187
- if field.category == "INFO":
188
- fields.append(field)
189
- return fields
190
-
191
- @property
192
- def format_fields(self):
193
- fields = []
194
- for field in self.fields:
195
- if field.category == "FORMAT":
196
- fields.append(field)
197
- return fields
198
-
199
- @property
200
- def num_contigs(self):
201
- return len(self.contigs)
202
-
203
- @property
204
- def num_filters(self):
205
- return len(self.filters)
206
-
207
- @staticmethod
208
- def fromdict(d):
209
- if d["format_version"] != ICF_METADATA_FORMAT_VERSION:
210
- raise ValueError(
211
- "Intermediate columnar metadata format version mismatch: "
212
- f"{d['format_version']} != {ICF_METADATA_FORMAT_VERSION}"
213
- )
214
- partitions = [VcfPartition(**pd) for pd in d["partitions"]]
215
- for p in partitions:
216
- p.region = vcf_utils.Region(**p.region)
217
- d = d.copy()
218
- d["partitions"] = partitions
219
- d["fields"] = [VcfField.fromdict(fd) for fd in d["fields"]]
220
- d["samples"] = [Sample(**sd) for sd in d["samples"]]
221
- d["filters"] = [Filter(**fd) for fd in d["filters"]]
222
- d["contigs"] = [Contig(**cd) for cd in d["contigs"]]
223
- return IcfMetadata(**d)
224
-
225
- def asdict(self):
226
- return dataclasses.asdict(self)
227
-
228
- def asjson(self):
229
- return json.dumps(self.asdict(), indent=4)
230
-
231
-
232
- def fixed_vcf_field_definitions():
233
- def make_field_def(name, vcf_type, vcf_number):
234
- return VcfField(
235
- category="fixed",
236
- name=name,
237
- vcf_type=vcf_type,
238
- vcf_number=vcf_number,
239
- description="",
240
- summary=VcfFieldSummary(),
241
- )
242
-
243
- fields = [
244
- make_field_def("CHROM", "String", "1"),
245
- make_field_def("POS", "Integer", "1"),
246
- make_field_def("QUAL", "Float", "1"),
247
- make_field_def("ID", "String", "."),
248
- make_field_def("FILTERS", "String", "."),
249
- make_field_def("REF", "String", "1"),
250
- make_field_def("ALT", "String", "."),
251
- ]
252
- return fields
253
-
254
-
255
- def scan_vcf(path, target_num_partitions):
256
- with vcf_utils.IndexedVcf(path) as indexed_vcf:
257
- vcf = indexed_vcf.vcf
258
- filters = []
259
- pass_index = -1
260
- for h in vcf.header_iter():
261
- if h["HeaderType"] == "FILTER" and isinstance(h["ID"], str):
262
- try:
263
- description = h["Description"].strip('"')
264
- except KeyError:
265
- description = ""
266
- if h["ID"] == "PASS":
267
- pass_index = len(filters)
268
- filters.append(Filter(h["ID"], description))
269
-
270
- # Ensure PASS is the first filter if present
271
- if pass_index > 0:
272
- pass_filter = filters.pop(pass_index)
273
- filters.insert(0, pass_filter)
274
-
275
- fields = fixed_vcf_field_definitions()
276
- for h in vcf.header_iter():
277
- if h["HeaderType"] in ["INFO", "FORMAT"]:
278
- field = VcfField.from_header(h)
279
- if field.name == "GT":
280
- field.vcf_type = "Integer"
281
- field.vcf_number = "."
282
- fields.append(field)
283
-
284
- try:
285
- contig_lengths = vcf.seqlens
286
- except AttributeError:
287
- contig_lengths = [None for _ in vcf.seqnames]
288
-
289
- metadata = IcfMetadata(
290
- samples=[Sample(sample_id) for sample_id in vcf.samples],
291
- contigs=[
292
- Contig(contig_id, length)
293
- for contig_id, length in zip(vcf.seqnames, contig_lengths)
294
- ],
295
- filters=filters,
296
- fields=fields,
297
- partitions=[],
298
- num_records=sum(indexed_vcf.contig_record_counts().values()),
299
- )
300
-
301
- regions = indexed_vcf.partition_into_regions(num_parts=target_num_partitions)
302
- logger.info(
303
- f"Split {path} into {len(regions)} regions (target={target_num_partitions})"
304
- )
305
- for region in regions:
306
- metadata.partitions.append(
307
- VcfPartition(
308
- # TODO should this be fully resolving the path? Otherwise it's all
309
- # relative to the original WD
310
- vcf_path=str(path),
311
- region=region,
312
- )
313
- )
314
- core.update_progress(1)
315
- return metadata, vcf.raw_header
316
-
317
-
318
- def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
319
- logger.info(
320
- f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}"
321
- f" partitions."
322
- )
323
- # An easy mistake to make is to pass the same file twice. Check this early on.
324
- for path, count in collections.Counter(paths).items():
325
- if not path.exists(): # NEEDS TEST
326
- raise FileNotFoundError(path)
327
- if count > 1:
328
- raise ValueError(f"Duplicate path provided: {path}")
329
-
330
- progress_config = core.ProgressConfig(
331
- total=len(paths),
332
- units="files",
333
- title="Scan",
334
- show=show_progress,
335
- )
336
- with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
337
- for path in paths:
338
- pwm.submit(scan_vcf, path, max(1, target_num_partitions // len(paths)))
339
- results = list(pwm.results_as_completed())
340
-
341
- # Sort to make the ordering deterministic
342
- results.sort(key=lambda t: t[0].partitions[0].vcf_path)
343
- # We just take the first header, assuming the others
344
- # are compatible.
345
- all_partitions = []
346
- total_records = 0
347
- for metadata, _ in results:
348
- for partition in metadata.partitions:
349
- logger.debug(f"Scanned partition {partition}")
350
- all_partitions.append(partition)
351
- total_records += metadata.num_records
352
- metadata.num_records = 0
353
- metadata.partitions = []
354
-
355
- icf_metadata, header = results[0]
356
- for metadata, _ in results[1:]:
357
- if metadata != icf_metadata:
358
- raise ValueError("Incompatible VCF chunks")
359
-
360
- # Note: this will be infinity here if any of the chunks has an index
361
- # that doesn't keep track of the number of records per-contig
362
- icf_metadata.num_records = total_records
363
-
364
- # Sort by contig (in the order they appear in the header) first,
365
- # then by start coordinate
366
- contig_index_map = {contig.id: j for j, contig in enumerate(metadata.contigs)}
367
- all_partitions.sort(
368
- key=lambda x: (contig_index_map[x.region.contig], x.region.start)
369
- )
370
- icf_metadata.partitions = all_partitions
371
- logger.info(f"Scan complete, resulting in {len(all_partitions)} partitions.")
372
- return icf_metadata, header
373
-
374
-
375
- def sanitise_value_bool(buff, j, value):
376
- x = True
377
- if value is None:
378
- x = False
379
- buff[j] = x
380
-
381
-
382
- def sanitise_value_float_scalar(buff, j, value):
383
- x = value
384
- if value is None:
385
- x = FLOAT32_MISSING
386
- buff[j] = x
387
-
388
-
389
- def sanitise_value_int_scalar(buff, j, value):
390
- x = value
391
- if value is None:
392
- # print("MISSING", INT_MISSING, INT_FILL)
393
- x = [INT_MISSING]
394
- else:
395
- x = sanitise_int_array([value], ndmin=1, dtype=np.int32)
396
- buff[j] = x[0]
397
-
398
-
399
- def sanitise_value_string_scalar(buff, j, value):
400
- if value is None:
401
- buff[j] = "."
402
- else:
403
- buff[j] = value[0]
404
-
405
-
406
- def sanitise_value_string_1d(buff, j, value):
407
- if value is None:
408
- buff[j] = "."
409
- else:
410
- # value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False)
411
- # FIXME failure isn't coming from here, it seems to be from an
412
- # incorrectly detected dimension in the zarr array
413
- # The dimensions look all wrong, and the dtype should be Object
414
- # not str
415
- value = drop_empty_second_dim(value)
416
- buff[j] = ""
417
- buff[j, : value.shape[0]] = value
418
-
419
-
420
- def sanitise_value_string_2d(buff, j, value):
421
- if value is None:
422
- buff[j] = "."
423
- else:
424
- # print(buff.shape, value.dtype, value)
425
- # assert value.ndim == 2
426
- buff[j] = ""
427
- if value.ndim == 2:
428
- buff[j, :, : value.shape[1]] = value
429
- else:
430
- # TODO check if this is still necessary
431
- for k, val in enumerate(value):
432
- buff[j, k, : len(val)] = val
433
-
434
-
435
- def drop_empty_second_dim(value):
436
- assert len(value.shape) == 1 or value.shape[1] == 1
437
- if len(value.shape) == 2 and value.shape[1] == 1:
438
- value = value[..., 0]
439
- return value
440
-
441
-
442
- def sanitise_value_float_1d(buff, j, value):
443
- if value is None:
444
- buff[j] = FLOAT32_MISSING
445
- else:
446
- value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False)
447
- # numpy will map None values to NaN, but we need a
448
- # specific NaN
449
- value[np.isnan(value)] = FLOAT32_MISSING
450
- value = drop_empty_second_dim(value)
451
- buff[j] = FLOAT32_FILL
452
- buff[j, : value.shape[0]] = value
453
-
454
-
455
- def sanitise_value_float_2d(buff, j, value):
456
- if value is None:
457
- buff[j] = FLOAT32_MISSING
458
- else:
459
- # print("value = ", value)
460
- value = np.array(value, ndmin=2, dtype=buff.dtype, copy=False)
461
- buff[j] = FLOAT32_FILL
462
- buff[j, :, : value.shape[1]] = value
463
-
464
-
465
- def sanitise_int_array(value, ndmin, dtype):
466
- if isinstance(value, tuple):
467
- value = [VCF_INT_MISSING if x is None else x for x in value] # NEEDS TEST
468
- value = np.array(value, ndmin=ndmin, copy=False)
469
- value[value == VCF_INT_MISSING] = -1
470
- value[value == VCF_INT_FILL] = -2
471
- # TODO watch out for clipping here!
472
- return value.astype(dtype)
473
-
474
-
475
- def sanitise_value_int_1d(buff, j, value):
476
- if value is None:
477
- buff[j] = -1
478
- else:
479
- value = sanitise_int_array(value, 1, buff.dtype)
480
- value = drop_empty_second_dim(value)
481
- buff[j] = -2
482
- buff[j, : value.shape[0]] = value
483
-
484
-
485
- def sanitise_value_int_2d(buff, j, value):
486
- if value is None:
487
- buff[j] = -1
488
- else:
489
- value = sanitise_int_array(value, 2, buff.dtype)
490
- buff[j] = -2
491
- buff[j, :, : value.shape[1]] = value
492
-
493
-
494
- MIN_INT_VALUE = np.iinfo(np.int32).min + 2
495
- VCF_INT_MISSING = np.iinfo(np.int32).min
496
- VCF_INT_FILL = np.iinfo(np.int32).min + 1
497
-
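# Illustration (not part of the original vcf.py): sanitise_int_array() above maps
# the int32 sentinels VCF_INT_MISSING / VCF_INT_FILL to the compact -1 / -2 values
# used in the intermediate storage, then narrows to the requested dtype.
import numpy as np

raw = np.array([7, VCF_INT_MISSING, VCF_INT_FILL], dtype=np.int32)
clean = sanitise_int_array(raw, ndmin=1, dtype=np.int8)
assert list(clean) == [7, -1, -2]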
498
- missing_value_map = {
499
- "Integer": -1,
500
- "Float": FLOAT32_MISSING,
501
- "String": ".",
502
- "Character": ".",
503
- "Flag": False,
504
- }
505
-
506
-
507
- class VcfValueTransformer:
508
- """
509
- Transform VCF values into the stored intermediate format used
510
- in the IntermediateColumnarFormat, and update field summaries.
511
- """
512
-
513
- def __init__(self, field, num_samples):
514
- self.field = field
515
- self.num_samples = num_samples
516
- self.dimension = 1
517
- if field.category == "FORMAT":
518
- self.dimension = 2
519
- self.missing = missing_value_map[field.vcf_type]
520
-
521
- @staticmethod
522
- def factory(field, num_samples):
523
- if field.vcf_type in ("Integer", "Flag"):
524
- return IntegerValueTransformer(field, num_samples)
525
- if field.vcf_type == "Float":
526
- return FloatValueTransformer(field, num_samples)
527
- if field.name in ["REF", "FILTERS", "ALT", "ID", "CHROM"]:
528
- return SplitStringValueTransformer(field, num_samples)
529
- return StringValueTransformer(field, num_samples)
530
-
531
- def transform(self, vcf_value):
532
- if isinstance(vcf_value, tuple):
533
- vcf_value = [self.missing if v is None else v for v in vcf_value]
534
- value = np.array(vcf_value, ndmin=self.dimension, copy=False)
535
- return value
536
-
537
- def transform_and_update_bounds(self, vcf_value):
538
- if vcf_value is None:
539
- return None
540
- value = self.transform(vcf_value)
541
- self.update_bounds(value)
542
- # print(self.field.full_name, "T", vcf_value, "->", value)
543
- return value
544
-
545
-
546
- MIN_INT_VALUE = np.iinfo(np.int32).min + 2
547
- VCF_INT_MISSING = np.iinfo(np.int32).min
548
- VCF_INT_FILL = np.iinfo(np.int32).min + 1
549
-
550
-
551
- class IntegerValueTransformer(VcfValueTransformer):
552
- def update_bounds(self, value):
553
- summary = self.field.summary
554
- # Mask out missing and fill values
555
- # print(value)
556
- a = value[value >= MIN_INT_VALUE]
557
- if a.size > 0:
558
- summary.max_value = int(max(summary.max_value, np.max(a)))
559
- summary.min_value = int(min(summary.min_value, np.min(a)))
560
- number = value.shape[-1]
561
- summary.max_number = max(summary.max_number, number)
562
-
563
-
564
- class FloatValueTransformer(VcfValueTransformer):
565
- def update_bounds(self, value):
566
- summary = self.field.summary
567
- summary.max_value = float(max(summary.max_value, np.max(value)))
568
- summary.min_value = float(min(summary.min_value, np.min(value)))
569
- number = value.shape[-1]
570
- summary.max_number = max(summary.max_number, number)
571
-
572
-
573
- class StringValueTransformer(VcfValueTransformer):
574
- def update_bounds(self, value):
575
- summary = self.field.summary
576
- number = value.shape[-1]
577
- # TODO would be nice to report string lengths, but not
578
- # really necessary.
579
- summary.max_number = max(summary.max_number, number)
580
-
581
- def transform(self, vcf_value):
582
- # print("transform", vcf_value)
583
- if self.dimension == 1:
584
- value = np.array(list(vcf_value.split(",")))
585
- else:
586
- # TODO can we make this faster??
587
- value = np.array([v.split(",") for v in vcf_value], dtype="O")
588
- # print("HERE", vcf_value, value)
589
- # for v in vcf_value:
590
- # print("\t", type(v), len(v), v.split(","))
591
- # print("S: ", self.dimension, ":", value.shape, value)
592
- return value
593
-
594
-
595
- class SplitStringValueTransformer(StringValueTransformer):
596
- def transform(self, vcf_value):
597
- if vcf_value is None:
598
- return self.missing_value # NEEDS TEST
599
- assert self.dimension == 1
600
- return np.array(vcf_value, ndmin=1, dtype="str")
601
-
602
-
603
- def get_vcf_field_path(base_path, vcf_field):
604
- if vcf_field.category == "fixed":
605
- return base_path / vcf_field.name
606
- return base_path / vcf_field.category / vcf_field.name
607
-
608
-
609
- class IntermediateColumnarFormatField:
610
- def __init__(self, icf, vcf_field):
611
- self.vcf_field = vcf_field
612
- self.path = get_vcf_field_path(icf.path, vcf_field)
613
- self.compressor = icf.compressor
614
- self.num_partitions = icf.num_partitions
615
- self.num_records = icf.num_records
616
- self.partition_record_index = icf.partition_record_index
617
- # A map of partition id to the cumulative number of records
618
- # in chunks within that partition
619
- self._chunk_record_index = {}
620
-
621
- @property
622
- def name(self):
623
- return self.vcf_field.full_name
624
-
625
- def partition_path(self, partition_id):
626
- return self.path / f"p{partition_id}"
627
-
628
- def __repr__(self):
629
- partition_chunks = [self.num_chunks(j) for j in range(self.num_partitions)]
630
- return (
631
- f"IntermediateColumnarFormatField(name={self.name}, "
632
- f"partition_chunks={partition_chunks}, "
633
- f"path={self.path})"
634
- )
635
-
636
- def num_chunks(self, partition_id):
637
- return len(self.chunk_record_index(partition_id)) - 1
638
-
639
- def chunk_record_index(self, partition_id):
640
- if partition_id not in self._chunk_record_index:
641
- index_path = self.partition_path(partition_id) / "chunk_index"
642
- with open(index_path, "rb") as f:
643
- a = pickle.load(f)
644
- assert len(a) > 1
645
- assert a[0] == 0
646
- self._chunk_record_index[partition_id] = a
647
- return self._chunk_record_index[partition_id]
648
-
649
- def read_chunk(self, path):
650
- with open(path, "rb") as f:
651
- pkl = self.compressor.decode(f.read())
652
- return pickle.loads(pkl)
653
-
654
- def chunk_num_records(self, partition_id):
655
- return np.diff(self.chunk_record_index(partition_id))
656
-
657
- def chunks(self, partition_id, start_chunk=0):
658
- partition_path = self.partition_path(partition_id)
659
- chunk_cumulative_records = self.chunk_record_index(partition_id)
660
- chunk_num_records = np.diff(chunk_cumulative_records)
661
- for count, cumulative in zip(
662
- chunk_num_records[start_chunk:], chunk_cumulative_records[start_chunk + 1 :]
663
- ):
664
- path = partition_path / f"{cumulative}"
665
- chunk = self.read_chunk(path)
666
- if len(chunk) != count:
667
- raise ValueError(f"Corruption detected in chunk: {path}")
668
- yield chunk
669
-
670
- def iter_values(self, start=None, stop=None):
671
- start = 0 if start is None else start
672
- stop = self.num_records if stop is None else stop
673
- start_partition = (
674
- np.searchsorted(self.partition_record_index, start, side="right") - 1
675
- )
676
- offset = self.partition_record_index[start_partition]
677
- assert offset <= start
678
- chunk_offset = start - offset
679
-
680
- chunk_record_index = self.chunk_record_index(start_partition)
681
- start_chunk = (
682
- np.searchsorted(chunk_record_index, chunk_offset, side="right") - 1
683
- )
684
- record_id = offset + chunk_record_index[start_chunk]
685
- assert record_id <= start
686
- logger.debug(
687
- f"Read {self.vcf_field.full_name} slice [{start}:{stop}]:"
688
- f"p_start={start_partition}, c_start={start_chunk}, r_start={record_id}"
689
- )
690
- for chunk in self.chunks(start_partition, start_chunk):
691
- for record in chunk:
692
- if record_id == stop:
693
- return
694
- if record_id >= start:
695
- yield record
696
- record_id += 1
697
- assert record_id > start
698
- for partition_id in range(start_partition + 1, self.num_partitions):
699
- for chunk in self.chunks(partition_id):
700
- for record in chunk:
701
- if record_id == stop:
702
- return
703
- yield record
704
- record_id += 1
705
-
706
- # Note: this involves some computation so should arguably be a method,
707
- # but making it a property for consistency with xarray etc.
708
- @property
709
- def values(self):
710
- ret = [None] * self.num_records
711
- j = 0
712
- for partition_id in range(self.num_partitions):
713
- for chunk in self.chunks(partition_id):
714
- for record in chunk:
715
- ret[j] = record
716
- j += 1
717
- assert j == self.num_records
718
- return ret
719
-
720
- def sanitiser_factory(self, shape):
721
- """
722
- Return a function that sanitised values from this column
723
- and writes into a buffer of the specified shape.
724
- """
725
- assert len(shape) <= 3
726
- if self.vcf_field.vcf_type == "Flag":
727
- assert len(shape) == 1
728
- return sanitise_value_bool
729
- elif self.vcf_field.vcf_type == "Float":
730
- if len(shape) == 1:
731
- return sanitise_value_float_scalar
732
- elif len(shape) == 2:
733
- return sanitise_value_float_1d
734
- else:
735
- return sanitise_value_float_2d
736
- elif self.vcf_field.vcf_type == "Integer":
737
- if len(shape) == 1:
738
- return sanitise_value_int_scalar
739
- elif len(shape) == 2:
740
- return sanitise_value_int_1d
741
- else:
742
- return sanitise_value_int_2d
743
- else:
744
- assert self.vcf_field.vcf_type in ("String", "Character")
745
- if len(shape) == 1:
746
- return sanitise_value_string_scalar
747
- elif len(shape) == 2:
748
- return sanitise_value_string_1d
749
- else:
750
- return sanitise_value_string_2d
751
-
752
-
753
- @dataclasses.dataclass
754
- class IcfFieldWriter:
755
- vcf_field: VcfField
756
- path: pathlib.Path
757
- transformer: VcfValueTransformer
758
- compressor: Any
759
- max_buffered_bytes: int
760
- buff: list[Any] = dataclasses.field(default_factory=list)
761
- buffered_bytes: int = 0
762
- chunk_index: list[int] = dataclasses.field(default_factory=lambda: [0])
763
- num_records: int = 0
764
-
765
- def append(self, val):
766
- val = self.transformer.transform_and_update_bounds(val)
767
- assert val is None or isinstance(val, np.ndarray)
768
- self.buff.append(val)
769
- val_bytes = sys.getsizeof(val)
770
- self.buffered_bytes += val_bytes
771
- self.num_records += 1
772
- if self.buffered_bytes >= self.max_buffered_bytes:
773
- logger.debug(
774
- f"Flush {self.path} buffered={self.buffered_bytes} "
775
- f"max={self.max_buffered_bytes}"
776
- )
777
- self.write_chunk()
778
- self.buff.clear()
779
- self.buffered_bytes = 0
780
-
781
- def write_chunk(self):
782
- # Update index
783
- self.chunk_index.append(self.num_records)
784
- path = self.path / f"{self.num_records}"
785
- logger.debug(f"Start write: {path}")
786
- pkl = pickle.dumps(self.buff)
787
- compressed = self.compressor.encode(pkl)
788
- with open(path, "wb") as f:
789
- f.write(compressed)
790
-
791
- # Update the summary
792
- self.vcf_field.summary.num_chunks += 1
793
- self.vcf_field.summary.compressed_size += len(compressed)
794
- self.vcf_field.summary.uncompressed_size += self.buffered_bytes
795
- logger.debug(f"Finish write: {path}")
796
-
797
- def flush(self):
798
- logger.debug(
799
- f"Flush {self.path} records={len(self.buff)} buffered={self.buffered_bytes}"
800
- )
801
- if len(self.buff) > 0:
802
- self.write_chunk()
803
- with open(self.path / "chunk_index", "wb") as f:
804
- a = np.array(self.chunk_index, dtype=int)
805
- pickle.dump(a, f)
806
-
807
-
808
- class IcfPartitionWriter(contextlib.AbstractContextManager):
809
- """
810
- Writes the data for an IntermediateColumnarFormat partition.
811
- """
812
-
813
- def __init__(
814
- self,
815
- icf_metadata,
816
- out_path,
817
- partition_index,
818
- ):
819
- self.partition_index = partition_index
820
- # chunk_size is in megabytes
821
- max_buffered_bytes = icf_metadata.column_chunk_size * 2**20
822
- assert max_buffered_bytes > 0
823
- compressor = numcodecs.get_codec(icf_metadata.compressor)
824
-
825
- self.field_writers = {}
826
- num_samples = len(icf_metadata.samples)
827
- for vcf_field in icf_metadata.fields:
828
- field_path = get_vcf_field_path(out_path, vcf_field)
829
- field_partition_path = field_path / f"p{partition_index}"
830
- # Should be robust to running explode_partition twice.
831
- field_partition_path.mkdir(exist_ok=True)
832
- transformer = VcfValueTransformer.factory(vcf_field, num_samples)
833
- self.field_writers[vcf_field.full_name] = IcfFieldWriter(
834
- vcf_field,
835
- field_partition_path,
836
- transformer,
837
- compressor,
838
- max_buffered_bytes,
839
- )
840
-
841
- @property
842
- def field_summaries(self):
843
- return {
844
- name: field.vcf_field.summary for name, field in self.field_writers.items()
845
- }
846
-
847
- def append(self, name, value):
848
- self.field_writers[name].append(value)
849
-
850
- def __exit__(self, exc_type, exc_val, exc_tb):
851
- if exc_type is None:
852
- for field in self.field_writers.values():
853
- field.flush()
854
- return False
855
-
856
-
857
- class IntermediateColumnarFormat(collections.abc.Mapping):
858
- def __init__(self, path):
859
- self.path = pathlib.Path(path)
860
- # TODO raise a more informative error here telling people this
861
- # directory is either a WIP or the wrong format.
862
- with open(self.path / "metadata.json") as f:
863
- self.metadata = IcfMetadata.fromdict(json.load(f))
864
- with open(self.path / "header.txt") as f:
865
- self.vcf_header = f.read()
866
- self.compressor = numcodecs.get_codec(self.metadata.compressor)
867
- self.fields = {}
868
- partition_num_records = [
869
- partition.num_records for partition in self.metadata.partitions
870
- ]
871
- # Allow us to find which partition a given record is in
872
- self.partition_record_index = np.cumsum([0, *partition_num_records])
873
- for field in self.metadata.fields:
874
- self.fields[field.full_name] = IntermediateColumnarFormatField(self, field)
875
- logger.info(
876
- f"Loaded IntermediateColumnarFormat(partitions={self.num_partitions}, "
877
- f"records={self.num_records}, fields={self.num_fields})"
878
- )
879
-
880
- def __repr__(self):
881
- return (
882
- f"IntermediateColumnarFormat(fields={len(self)}, "
883
- f"partitions={self.num_partitions}, "
884
- f"records={self.num_records}, path={self.path})"
885
- )
886
-
887
- def __getitem__(self, key):
888
- return self.fields[key]
889
-
890
- def __iter__(self):
891
- return iter(self.fields)
892
-
893
- def __len__(self):
894
- return len(self.fields)
895
-
896
- def summary_table(self):
897
- data = []
898
- for name, col in self.fields.items():
899
- summary = col.vcf_field.summary
900
- d = {
901
- "name": name,
902
- "type": col.vcf_field.vcf_type,
903
- "chunks": summary.num_chunks,
904
- "size": display_size(summary.uncompressed_size),
905
- "compressed": display_size(summary.compressed_size),
906
- "max_n": summary.max_number,
907
- "min_val": display_number(summary.min_value),
908
- "max_val": display_number(summary.max_value),
909
- }
910
-
911
- data.append(d)
912
- return data
913
-
914
- @property
915
- def num_records(self):
916
- return self.metadata.num_records
917
-
918
- @property
919
- def num_partitions(self):
920
- return len(self.metadata.partitions)
921
-
922
- @property
923
- def num_samples(self):
924
- return len(self.metadata.samples)
925
-
926
- @property
927
- def num_fields(self):
928
- return len(self.fields)
929
-
930
-
931
- @dataclasses.dataclass
932
- class IcfPartitionMetadata:
933
- num_records: int
934
- last_position: int
935
- field_summaries: dict
936
-
937
- def asdict(self):
938
- return dataclasses.asdict(self)
939
-
940
- def asjson(self):
941
- return json.dumps(self.asdict(), indent=4)
942
-
943
- @staticmethod
944
- def fromdict(d):
945
- md = IcfPartitionMetadata(**d)
946
- for k, v in md.field_summaries.items():
947
- md.field_summaries[k] = VcfFieldSummary.fromdict(v)
948
- return md
949
-
950
-
951
- def check_overlapping_partitions(partitions):
952
- for i in range(1, len(partitions)):
953
- prev_region = partitions[i - 1].region
954
- current_region = partitions[i].region
955
- if prev_region.contig == current_region.contig:
956
- assert prev_region.end is not None
957
- # Regions are *inclusive*
958
- if prev_region.end >= current_region.start:
959
- raise ValueError(
960
- f"Overlapping VCF regions in partitions {i - 1} and {i}: "
961
- f"{prev_region} and {current_region}"
962
- )
963
-
964
-
965
- class IntermediateColumnarFormatWriter:
966
- def __init__(self, path):
967
- self.path = pathlib.Path(path)
968
- self.wip_path = self.path / "wip"
969
- self.metadata = None
970
-
971
- @property
972
- def num_partitions(self):
973
- return len(self.metadata.partitions)
974
-
975
- def init(
976
- self,
977
- vcfs,
978
- *,
979
- column_chunk_size=16,
980
- worker_processes=1,
981
- target_num_partitions=None,
982
- show_progress=False,
983
- compressor=None,
984
- ):
985
- if self.path.exists():
986
- raise ValueError("ICF path already exists")
987
- if compressor is None:
988
- compressor = ICF_DEFAULT_COMPRESSOR
989
- vcfs = [pathlib.Path(vcf) for vcf in vcfs]
990
- target_num_partitions = max(target_num_partitions, len(vcfs))
991
-
992
- # TODO move scan_vcfs into this class
993
- icf_metadata, header = scan_vcfs(
994
- vcfs,
995
- worker_processes=worker_processes,
996
- show_progress=show_progress,
997
- target_num_partitions=target_num_partitions,
998
- )
999
- self.metadata = icf_metadata
1000
- self.metadata.format_version = ICF_METADATA_FORMAT_VERSION
1001
- self.metadata.compressor = compressor.get_config()
1002
- self.metadata.column_chunk_size = column_chunk_size
1003
- # Bare minimum here for provenance - would be nice to include versions of key
1004
- # dependencies as well.
1005
- self.metadata.provenance = {"source": f"bio2zarr-{provenance.__version__}"}
1006
-
1007
- self.mkdirs()
1008
-
1009
- # Note: this is needed for the current version of the vcfzarr spec, but it's
1010
- # probably going to be dropped.
1011
- # https://github.com/pystatgen/vcf-zarr-spec/issues/15
1012
- # May be useful to keep lying around still though?
1013
- logger.info("Writing VCF header")
1014
- with open(self.path / "header.txt", "w") as f:
1015
- f.write(header)
1016
-
1017
- logger.info("Writing WIP metadata")
1018
- with open(self.wip_path / "metadata.json", "w") as f:
1019
- json.dump(self.metadata.asdict(), f, indent=4)
1020
- return self.num_partitions
1021
-
1022
- def mkdirs(self):
1023
- num_dirs = len(self.metadata.fields)
1024
- logger.info(f"Creating {num_dirs} field directories")
1025
- self.path.mkdir()
1026
- self.wip_path.mkdir()
1027
- for field in self.metadata.fields:
1028
- col_path = get_vcf_field_path(self.path, field)
1029
- col_path.mkdir(parents=True)
1030
-
1031
- def load_partition_summaries(self):
1032
- summaries = []
1033
- not_found = []
1034
- for j in range(self.num_partitions):
1035
- try:
1036
- with open(self.wip_path / f"p{j}.json") as f:
1037
- summaries.append(IcfPartitionMetadata.fromdict(json.load(f)))
1038
- except FileNotFoundError:
1039
- not_found.append(j)
1040
- if len(not_found) > 0:
1041
- raise FileNotFoundError(
1042
- f"Partition metadata not found for {len(not_found)}"
1043
- f" partitions: {not_found}"
1044
- )
1045
- return summaries
1046
-
1047
- def load_metadata(self):
1048
- if self.metadata is None:
1049
- with open(self.wip_path / "metadata.json") as f:
1050
- self.metadata = IcfMetadata.fromdict(json.load(f))
1051
-
1052
- def process_partition(self, partition_index):
1053
- self.load_metadata()
1054
- summary_path = self.wip_path / f"p{partition_index}.json"
1055
- # If someone is rewriting a summary path (for whatever reason), make sure it
1056
- # doesn't look like it's already been completed.
1057
- # NOTE to do this properly we probably need to take a lock on this file - but
1058
- # this simple approach will catch the vast majority of problems.
1059
- if summary_path.exists():
1060
- summary_path.unlink()
1061
-
1062
- partition = self.metadata.partitions[partition_index]
1063
- logger.info(
1064
- f"Start p{partition_index} {partition.vcf_path}__{partition.region}"
1065
- )
1066
- info_fields = self.metadata.info_fields
1067
- format_fields = []
1068
- has_gt = False
1069
- for field in self.metadata.format_fields:
1070
- if field.name == "GT":
1071
- has_gt = True
1072
- else:
1073
- format_fields.append(field)
1074
-
1075
- last_position = None
1076
- with IcfPartitionWriter(
1077
- self.metadata,
1078
- self.path,
1079
- partition_index,
1080
- ) as tcw:
1081
- with vcf_utils.IndexedVcf(partition.vcf_path) as ivcf:
1082
- num_records = 0
1083
- for variant in ivcf.variants(partition.region):
1084
- num_records += 1
1085
- last_position = variant.POS
1086
- tcw.append("CHROM", variant.CHROM)
1087
- tcw.append("POS", variant.POS)
1088
- tcw.append("QUAL", variant.QUAL)
1089
- tcw.append("ID", variant.ID)
1090
- tcw.append("FILTERS", variant.FILTERS)
1091
- tcw.append("REF", variant.REF)
1092
- tcw.append("ALT", variant.ALT)
1093
- for field in info_fields:
1094
- tcw.append(field.full_name, variant.INFO.get(field.name, None))
1095
- if has_gt:
1096
- tcw.append("FORMAT/GT", variant.genotype.array())
1097
- for field in format_fields:
1098
- val = variant.format(field.name)
1099
- tcw.append(field.full_name, val)
1100
- # Note: an issue with updating the progress per variant here like
1101
- # this is that we get a significant pause at the end of the counter
1102
- # while all the "small" fields get flushed. Possibly not much to be
1103
- # done about it.
1104
- core.update_progress(1)
1105
- logger.info(
1106
- f"Finished reading VCF for partition {partition_index}, "
1107
- f"flushing buffers"
1108
- )
1109
-
1110
- partition_metadata = IcfPartitionMetadata(
1111
- num_records=num_records,
1112
- last_position=last_position,
1113
- field_summaries=tcw.field_summaries,
1114
- )
1115
- with open(summary_path, "w") as f:
1116
- f.write(partition_metadata.asjson())
1117
- logger.info(
1118
- f"Finish p{partition_index} {partition.vcf_path}__{partition.region} "
1119
- f"{num_records} records last_pos={last_position}"
1120
- )
1121
-
1122
- def explode(self, *, worker_processes=1, show_progress=False):
1123
- self.load_metadata()
1124
- num_records = self.metadata.num_records
1125
- if np.isinf(num_records):
1126
- logger.warning(
1127
- "Total records unknown, cannot show progress; "
1128
- "reindex VCFs with bcftools index to fix"
1129
- )
1130
- num_records = None
1131
- num_fields = len(self.metadata.fields)
1132
- num_samples = len(self.metadata.samples)
1133
- logger.info(
1134
- f"Exploding fields={num_fields} samples={num_samples}; "
1135
- f"partitions={self.num_partitions} "
1136
- f"variants={'unknown' if num_records is None else num_records}"
1137
- )
1138
- progress_config = core.ProgressConfig(
1139
- total=num_records,
1140
- units="vars",
1141
- title="Explode",
1142
- show=show_progress,
1143
- )
1144
- with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
1145
- for j in range(self.num_partitions):
1146
- pwm.submit(self.process_partition, j)
1147
-
1148
- def explode_partition(self, partition):
1149
- self.load_metadata()
1150
- if partition < 0 or partition >= self.num_partitions:
1151
- raise ValueError(
1152
- "Partition index must be in the range 0 <= index < num_partitions"
1153
- )
1154
- self.process_partition(partition)
1155
-
1156
- def finalise(self):
1157
- self.load_metadata()
1158
- partition_summaries = self.load_partition_summaries()
1159
- total_records = 0
1160
- for index, summary in enumerate(partition_summaries):
1161
- partition_records = summary.num_records
1162
- self.metadata.partitions[index].num_records = partition_records
1163
- self.metadata.partitions[index].region.end = summary.last_position
1164
- total_records += partition_records
1165
- if not np.isinf(self.metadata.num_records):
1166
- # Note: this is just telling us that there's a bug in the
1167
- # index based record counting code, but it doesn't actually
1168
- # matter much. We may want to just make this a warning if
1169
- # we hit regular problems.
1170
- assert total_records == self.metadata.num_records
1171
- self.metadata.num_records = total_records
1172
-
1173
- check_overlapping_partitions(self.metadata.partitions)
1174
-
1175
- for field in self.metadata.fields:
1176
- for summary in partition_summaries:
1177
- field.summary.update(summary.field_summaries[field.full_name])
1178
-
1179
- logger.info("Finalising metadata")
1180
- with open(self.path / "metadata.json", "w") as f:
1181
- f.write(self.metadata.asjson())
1182
-
1183
- logger.debug("Removing WIP directory")
1184
- shutil.rmtree(self.wip_path)
1185
-
1186
-
1187
- def explode(
1188
- icf_path,
1189
- vcfs,
1190
- *,
1191
- column_chunk_size=16,
1192
- worker_processes=1,
1193
- show_progress=False,
1194
- compressor=None,
1195
- ):
1196
- writer = IntermediateColumnarFormatWriter(icf_path)
1197
- writer.init(
1198
- vcfs,
1199
- # Heuristic to get reasonable worker utilisation with lumpy partition sizing
1200
- target_num_partitions=max(1, worker_processes * 4),
1201
- worker_processes=worker_processes,
1202
- show_progress=show_progress,
1203
- column_chunk_size=column_chunk_size,
1204
- compressor=compressor,
1205
- )
1206
- writer.explode(worker_processes=worker_processes, show_progress=show_progress)
1207
- writer.finalise()
1208
- return IntermediateColumnarFormat(icf_path)
1209
-
1210
-
1211
- def explode_init(
1212
- icf_path,
1213
- vcfs,
1214
- *,
1215
- column_chunk_size=16,
1216
- target_num_partitions=1,
1217
- worker_processes=1,
1218
- show_progress=False,
1219
- compressor=None,
1220
- ):
1221
- writer = IntermediateColumnarFormatWriter(icf_path)
1222
- return writer.init(
1223
- vcfs,
1224
- target_num_partitions=target_num_partitions,
1225
- worker_processes=worker_processes,
1226
- show_progress=show_progress,
1227
- column_chunk_size=column_chunk_size,
1228
- compressor=compressor,
1229
- )
1230
-
1231
-
1232
- def explode_partition(icf_path, partition):
1233
- writer = IntermediateColumnarFormatWriter(icf_path)
1234
- writer.explode_partition(partition)
1235
-
1236
-
1237
- def explode_finalise(icf_path):
1238
- writer = IntermediateColumnarFormatWriter(icf_path)
1239
- writer.finalise()
1240
-
1241
-
1242
- def inspect(path):
1243
- path = pathlib.Path(path)
1244
- # TODO add support for the Zarr format also
1245
- if (path / "metadata.json").exists():
1246
- obj = IntermediateColumnarFormat(path)
1247
- elif (path / ".zmetadata").exists():
1248
- obj = VcfZarr(path)
1249
- else:
1250
- raise ValueError("Format not recognised") # NEEDS TEST
1251
- return obj.summary_table()
1252
-
1253
-
1254
- DEFAULT_ZARR_COMPRESSOR = numcodecs.Blosc(cname="zstd", clevel=7)
1255
-
1256
-
1257
- @dataclasses.dataclass
1258
- class ZarrColumnSpec:
1259
- name: str
1260
- dtype: str
1261
- shape: tuple
1262
- chunks: tuple
1263
- dimensions: tuple
1264
- description: str
1265
- vcf_field: str
1266
- compressor: dict
1267
- filters: list
1268
-
1269
- def __post_init__(self):
1270
- # Ensure these are tuples for ease of comparison and consistency
1271
- self.shape = tuple(self.shape)
1272
- self.chunks = tuple(self.chunks)
1273
- self.dimensions = tuple(self.dimensions)
1274
-
1275
- @staticmethod
1276
- def new(**kwargs):
1277
- spec = ZarrColumnSpec(
1278
- **kwargs, compressor=DEFAULT_ZARR_COMPRESSOR.get_config(), filters=[]
1279
- )
1280
- spec._choose_compressor_settings()
1281
- return spec
1282
-
1283
- @staticmethod
1284
- def from_field(
1285
- vcf_field,
1286
- *,
1287
- num_variants,
1288
- num_samples,
1289
- variants_chunk_size,
1290
- samples_chunk_size,
1291
- variable_name=None,
1292
- ):
1293
- shape = [num_variants]
1294
- prefix = "variant_"
1295
- dimensions = ["variants"]
1296
- chunks = [variants_chunk_size]
1297
- if vcf_field.category == "FORMAT":
1298
- prefix = "call_"
1299
- shape.append(num_samples)
1300
- chunks.append(samples_chunk_size)
1301
- dimensions.append("samples")
1302
- if variable_name is None:
1303
- variable_name = prefix + vcf_field.name
1304
- # TODO make an option to add in the empty extra dimension
1305
- if vcf_field.summary.max_number > 1:
1306
- shape.append(vcf_field.summary.max_number)
1307
- # TODO we should really be checking this to see if the named dimensions
1308
- # are actually correct.
1309
- if vcf_field.vcf_number == "R":
1310
- dimensions.append("alleles")
1311
- elif vcf_field.vcf_number == "A":
1312
- dimensions.append("alt_alleles")
1313
- elif vcf_field.vcf_number == "G":
1314
- dimensions.append("genotypes")
1315
- else:
1316
- dimensions.append(f"{vcf_field.category}_{vcf_field.name}_dim")
1317
- return ZarrColumnSpec.new(
1318
- vcf_field=vcf_field.full_name,
1319
- name=variable_name,
1320
- dtype=vcf_field.smallest_dtype(),
1321
- shape=shape,
1322
- chunks=chunks,
1323
- dimensions=dimensions,
1324
- description=vcf_field.description,
1325
- )
1326
-
1327
- def _choose_compressor_settings(self):
1328
- """
1329
- Choose compressor and filter settings based on the size and
1330
- type of the array, plus some heuristics from observed properties
1331
- of VCFs.
1332
-
1333
- See https://github.com/pystatgen/bio2zarr/discussions/74
1334
- """
1335
- # Default is to not shuffle, because autoshuffle isn't recognised
1336
- # by many Zarr implementations, and shuffling can lead to worse
1337
- # performance in some cases anyway. Turning on shuffle should be a
1338
- # deliberate choice.
1339
- shuffle = numcodecs.Blosc.NOSHUFFLE
1340
- if self.name == "call_genotype" and self.dtype == "i1":
1341
- # call_genotype gets BITSHUFFLE by default as it gets
1342
- # significantly better compression (at a cost of slower
1343
- # decoding)
1344
- shuffle = numcodecs.Blosc.BITSHUFFLE
1345
- elif self.dtype == "bool":
1346
- shuffle = numcodecs.Blosc.BITSHUFFLE
1347
-
1348
- self.compressor["shuffle"] = shuffle
1349
-
1350
- @property
1351
- def variant_chunk_nbytes(self):
1352
- """
1353
- Returns the nbytes for a single variant chunk of this array.
1354
- """
1355
- chunk_items = self.chunks[0]
1356
- for size in self.shape[1:]:
1357
- chunk_items *= size
1358
- dt = np.dtype(self.dtype)
1359
- return chunk_items * dt.itemsize
1360
-
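# Worked example (not part of the original vcf.py): for a call-level field with
# shape=(num_variants, 1000, 2), chunks=(10_000, 100) and dtype="i1",
# variant_chunk_nbytes is 10_000 * 1000 * 2 * 1 = 20,000,000 bytes (~19 MiB):
# only the first (variants) chunk dimension is used, and the remaining shape
# dimensions are taken in full.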
1361
-
1362
- ZARR_SCHEMA_FORMAT_VERSION = "0.3"
1363
-
1364
-
1365
- @dataclasses.dataclass
1366
- class VcfZarrSchema:
1367
- format_version: str
1368
- samples_chunk_size: int
1369
- variants_chunk_size: int
1370
- dimensions: list
1371
- samples: list
1372
- contigs: list
1373
- filters: list
1374
- fields: dict
1375
-
1376
- def asdict(self):
1377
- return dataclasses.asdict(self)
1378
-
1379
- def asjson(self):
1380
- return json.dumps(self.asdict(), indent=4)
1381
-
1382
- @staticmethod
1383
- def fromdict(d):
1384
- if d["format_version"] != ZARR_SCHEMA_FORMAT_VERSION:
1385
- raise ValueError(
1386
- "Zarr schema format version mismatch: "
1387
- f"{d['format_version']} != {ZARR_SCHEMA_FORMAT_VERSION}"
1388
- )
1389
- ret = VcfZarrSchema(**d)
1390
- ret.samples = [Sample(**sd) for sd in d["samples"]]
1391
- ret.contigs = [Contig(**sd) for sd in d["contigs"]]
1392
- ret.filters = [Filter(**sd) for sd in d["filters"]]
1393
- ret.fields = {
1394
- key: ZarrColumnSpec(**value) for key, value in d["fields"].items()
1395
- }
1396
- return ret
1397
-
1398
- @staticmethod
1399
- def fromjson(s):
1400
- return VcfZarrSchema.fromdict(json.loads(s))
1401
-
1402
- @staticmethod
1403
- def generate(icf, variants_chunk_size=None, samples_chunk_size=None):
1404
- m = icf.num_records
1405
- n = icf.num_samples
1406
- # FIXME
1407
- if samples_chunk_size is None:
1408
- samples_chunk_size = 1000
1409
- if variants_chunk_size is None:
1410
- variants_chunk_size = 10_000
1411
- logger.info(
1412
- f"Generating schema with chunks={variants_chunk_size, samples_chunk_size}"
1413
- )
1414
-
1415
- def spec_from_field(field, variable_name=None):
1416
- return ZarrColumnSpec.from_field(
1417
- field,
1418
- num_samples=n,
1419
- num_variants=m,
1420
- samples_chunk_size=samples_chunk_size,
1421
- variants_chunk_size=variants_chunk_size,
1422
- variable_name=variable_name,
1423
- )
1424
-
1425
- def fixed_field_spec(
1426
- name, dtype, vcf_field=None, shape=(m,), dimensions=("variants",)
1427
- ):
1428
- return ZarrColumnSpec.new(
1429
- vcf_field=vcf_field,
1430
- name=name,
1431
- dtype=dtype,
1432
- shape=shape,
1433
- description="",
1434
- dimensions=dimensions,
1435
- chunks=[variants_chunk_size],
1436
- )
1437
-
1438
- alt_col = icf.fields["ALT"]
1439
- max_alleles = alt_col.vcf_field.summary.max_number + 1
1440
-
1441
- colspecs = [
1442
- fixed_field_spec(
1443
- name="variant_contig",
1444
- dtype=core.min_int_dtype(0, icf.metadata.num_contigs),
1445
- ),
1446
- fixed_field_spec(
1447
- name="variant_filter",
1448
- dtype="bool",
1449
- shape=(m, icf.metadata.num_filters),
1450
- dimensions=["variants", "filters"],
1451
- ),
1452
- fixed_field_spec(
1453
- name="variant_allele",
1454
- dtype="str",
1455
- shape=(m, max_alleles),
1456
- dimensions=["variants", "alleles"],
1457
- ),
1458
- fixed_field_spec(
1459
- name="variant_id",
1460
- dtype="str",
1461
- ),
1462
- fixed_field_spec(
1463
- name="variant_id_mask",
1464
- dtype="bool",
1465
- ),
1466
- ]
1467
- name_map = {field.full_name: field for field in icf.metadata.fields}
1468
-
1469
- # Only two of the fixed fields have a direct one-to-one mapping.
1470
- colspecs.extend(
1471
- [
1472
- spec_from_field(name_map["QUAL"], variable_name="variant_quality"),
1473
- spec_from_field(name_map["POS"], variable_name="variant_position"),
1474
- ]
1475
- )
1476
- colspecs.extend([spec_from_field(field) for field in icf.metadata.info_fields])
1477
-
1478
- gt_field = None
1479
- for field in icf.metadata.format_fields:
1480
- if field.name == "GT":
1481
- gt_field = field
1482
- continue
1483
- colspecs.append(spec_from_field(field))
1484
-
1485
- if gt_field is not None:
1486
- ploidy = gt_field.summary.max_number - 1
1487
- shape = [m, n]
1488
- chunks = [variants_chunk_size, samples_chunk_size]
1489
- dimensions = ["variants", "samples"]
1490
- colspecs.append(
1491
- ZarrColumnSpec.new(
1492
- vcf_field=None,
1493
- name="call_genotype_phased",
1494
- dtype="bool",
1495
- shape=list(shape),
1496
- chunks=list(chunks),
1497
- dimensions=list(dimensions),
1498
- description="",
1499
- )
1500
- )
1501
- shape += [ploidy]
1502
- dimensions += ["ploidy"]
1503
- colspecs.append(
1504
- ZarrColumnSpec.new(
1505
- vcf_field=None,
1506
- name="call_genotype",
1507
- dtype=gt_field.smallest_dtype(),
1508
- shape=list(shape),
1509
- chunks=list(chunks),
1510
- dimensions=list(dimensions),
1511
- description="",
1512
- )
1513
- )
1514
- colspecs.append(
1515
- ZarrColumnSpec.new(
1516
- vcf_field=None,
1517
- name="call_genotype_mask",
1518
- dtype="bool",
1519
- shape=list(shape),
1520
- chunks=list(chunks),
1521
- dimensions=list(dimensions),
1522
- description="",
1523
- )
1524
- )
1525
-
1526
- return VcfZarrSchema(
1527
- format_version=ZARR_SCHEMA_FORMAT_VERSION,
1528
- samples_chunk_size=samples_chunk_size,
1529
- variants_chunk_size=variants_chunk_size,
1530
- fields={col.name: col for col in colspecs},
1531
- dimensions=["variants", "samples", "ploidy", "alleles", "filters"],
1532
- samples=icf.metadata.samples,
1533
- contigs=icf.metadata.contigs,
1534
- filters=icf.metadata.filters,
1535
- )
1536
-
1537
-
1538
- class VcfZarr:
1539
- def __init__(self, path):
1540
- if not (path / ".zmetadata").exists():
1541
- raise ValueError("Not in VcfZarr format") # NEEDS TEST
1542
- self.path = path
1543
- self.root = zarr.open(path, mode="r")
1544
-
1545
- def summary_table(self):
1546
- data = []
1547
- arrays = [(core.du(self.path / a.basename), a) for _, a in self.root.arrays()]
1548
- arrays.sort(key=lambda x: x[0])
1549
- for stored, array in reversed(arrays):
1550
- d = {
1551
- "name": array.name,
1552
- "dtype": str(array.dtype),
1553
- "stored": display_size(stored),
1554
- "size": display_size(array.nbytes),
1555
- "ratio": display_number(array.nbytes / stored),
1556
- "nchunks": str(array.nchunks),
1557
- "chunk_size": display_size(array.nbytes / array.nchunks),
1558
- "avg_chunk_stored": display_size(int(stored / array.nchunks)),
1559
- "shape": str(array.shape),
1560
- "chunk_shape": str(array.chunks),
1561
- "compressor": str(array.compressor),
1562
- "filters": str(array.filters),
1563
- }
1564
- data.append(d)
1565
- return data
1566
-
1567
-
1568
- def parse_max_memory(max_memory):
1569
- if max_memory is None:
1570
- # Effectively unbounded
1571
- return 2**63
1572
- if isinstance(max_memory, str):
1573
- max_memory = humanfriendly.parse_size(max_memory)
1574
- logger.info(f"Set memory budget to {display_size(max_memory)}")
1575
- return max_memory
1576
-
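# Illustration (not part of the original vcf.py): the memory budget accepts
# humanfriendly size strings, e.g. parse_max_memory("2GiB") == 2147483648,
# while parse_max_memory(None) falls back to the effectively unbounded 2**63.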
1577
-
1578
- @dataclasses.dataclass
1579
- class VcfZarrPartition:
1580
- start: int
1581
- stop: int
1582
-
1583
- @staticmethod
1584
- def generate_partitions(num_records, chunk_size, num_partitions, max_chunks=None):
1585
- num_chunks = int(np.ceil(num_records / chunk_size))
1586
- if max_chunks is not None:
1587
- num_chunks = min(num_chunks, max_chunks)
1588
- partitions = []
1589
- splits = np.array_split(np.arange(num_chunks), min(num_partitions, num_chunks))
1590
- for chunk_slice in splits:
1591
- start_chunk = int(chunk_slice[0])
1592
- stop_chunk = int(chunk_slice[-1]) + 1
1593
- start_index = start_chunk * chunk_size
1594
- stop_index = min(stop_chunk * chunk_size, num_records)
1595
- partitions.append(VcfZarrPartition(start_index, stop_index))
1596
- return partitions
1597
-
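# Worked example (not part of the original vcf.py): partitions are aligned to
# whole variant chunks, e.g.
#   VcfZarrPartition.generate_partitions(
#       num_records=25_000, chunk_size=10_000, num_partitions=2
#   )
# returns [VcfZarrPartition(0, 20000), VcfZarrPartition(20000, 25000)]: the three
# 10_000-record chunks are split 2 + 1 across the two partitions.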
1598
-
1599
- VZW_METADATA_FORMAT_VERSION = "0.1"
1600
-
1601
-
1602
- @dataclasses.dataclass
1603
- class VcfZarrWriterMetadata:
1604
- format_version: str
1605
- icf_path: str
1606
- schema: VcfZarrSchema
1607
- dimension_separator: str
1608
- partitions: list
1609
- provenance: dict
1610
-
1611
- def asdict(self):
1612
- return dataclasses.asdict(self)
1613
-
1614
- @staticmethod
1615
- def fromdict(d):
1616
- if d["format_version"] != VZW_METADATA_FORMAT_VERSION:
1617
- raise ValueError(
1618
- "VcfZarrWriter format version mismatch: "
1619
- f"{d['format_version']} != {VZW_METADATA_FORMAT_VERSION}"
1620
- )
1621
- ret = VcfZarrWriterMetadata(**d)
1622
- ret.schema = VcfZarrSchema.fromdict(ret.schema)
1623
- ret.partitions = [VcfZarrPartition(**p) for p in ret.partitions]
1624
- return ret
1625
-
1626
-
1627
- class VcfZarrWriter:
1628
- def __init__(self, path):
1629
- self.path = pathlib.Path(path)
1630
- self.wip_path = self.path / "wip"
1631
- self.arrays_path = self.wip_path / "arrays"
1632
- self.partitions_path = self.wip_path / "partitions"
1633
- self.metadata = None
1634
- self.icf = None
1635
-
1636
- @property
1637
- def schema(self):
1638
- return self.metadata.schema
1639
-
1640
- @property
1641
- def num_partitions(self):
1642
- return len(self.metadata.partitions)
1643
-
1644
- #######################
1645
- # init
1646
- #######################
1647
-
1648
- def init(
1649
- self,
1650
- icf,
1651
- *,
1652
- target_num_partitions,
1653
- schema,
1654
- dimension_separator=None,
1655
- max_variant_chunks=None,
1656
- ):
1657
- self.icf = icf
1658
- if self.path.exists():
1659
- raise ValueError("Zarr path already exists") # NEEDS TEST
1660
- partitions = VcfZarrPartition.generate_partitions(
1661
- self.icf.num_records,
1662
- schema.variants_chunk_size,
1663
- target_num_partitions,
1664
- max_chunks=max_variant_chunks,
1665
- )
1666
- # Default to using nested directories following the Zarr v3 default.
1667
- # This seems to require version 2.17+ to work properly
1668
- dimension_separator = (
1669
- "/" if dimension_separator is None else dimension_separator
1670
- )
1671
- self.metadata = VcfZarrWriterMetadata(
1672
- format_version=VZW_METADATA_FORMAT_VERSION,
1673
- icf_path=str(self.icf.path),
1674
- schema=schema,
1675
- dimension_separator=dimension_separator,
1676
- partitions=partitions,
1677
- # Bare minimum here for provenance - see comments above
1678
- provenance={"source": f"bio2zarr-{provenance.__version__}"},
1679
- )
1680
-
1681
- self.path.mkdir()
1682
- store = zarr.DirectoryStore(self.path)
1683
- root = zarr.group(store=store)
1684
- root.attrs.update(
1685
- {
1686
- "vcf_zarr_version": "0.2",
1687
- "vcf_header": self.icf.vcf_header,
1688
- "source": f"bio2zarr-{provenance.__version__}",
1689
- }
1690
- )
1691
- # Doing this synchronously - this is fine surely
1692
- self.encode_samples(root)
1693
- self.encode_filter_id(root)
1694
- self.encode_contig_id(root)
1695
-
1696
- self.wip_path.mkdir()
1697
- self.arrays_path.mkdir()
1698
- self.partitions_path.mkdir()
1699
- store = zarr.DirectoryStore(self.arrays_path)
1700
- root = zarr.group(store=store)
1701
-
1702
- for column in self.schema.fields.values():
1703
- self.init_array(root, column, partitions[-1].stop)
1704
-
1705
- logger.info("Writing WIP metadata")
1706
- with open(self.wip_path / "metadata.json", "w") as f:
1707
- json.dump(self.metadata.asdict(), f, indent=4)
1708
- return len(partitions)
1709
-
1710
- def encode_samples(self, root):
1711
- if self.schema.samples != self.icf.metadata.samples:
1712
- raise ValueError(
1713
- "Subsetting or reordering samples not supported currently"
1714
- ) # NEEDS TEST
1715
- array = root.array(
1716
- "sample_id",
1717
- [sample.id for sample in self.schema.samples],
1718
- dtype="str",
1719
- compressor=DEFAULT_ZARR_COMPRESSOR,
1720
- chunks=(self.schema.samples_chunk_size,),
1721
- )
1722
- array.attrs["_ARRAY_DIMENSIONS"] = ["samples"]
1723
- logger.debug("Samples done")
1724
-
1725
- def encode_contig_id(self, root):
1726
- array = root.array(
1727
- "contig_id",
1728
- [contig.id for contig in self.schema.contigs],
1729
- dtype="str",
1730
- compressor=DEFAULT_ZARR_COMPRESSOR,
1731
- )
1732
- array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
1733
- if all(contig.length is not None for contig in self.schema.contigs):
1734
- array = root.array(
1735
- "contig_length",
1736
- [contig.length for contig in self.schema.contigs],
1737
- dtype=np.int64,
1738
- compressor=DEFAULT_ZARR_COMPRESSOR,
1739
- )
1740
- array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
1741
-
1742
- def encode_filter_id(self, root):
1743
- # TODO need a way to store description also
1744
- # https://github.com/sgkit-dev/vcf-zarr-spec/issues/19
1745
- array = root.array(
1746
- "filter_id",
1747
- [filt.id for filt in self.schema.filters],
1748
- dtype="str",
1749
- compressor=DEFAULT_ZARR_COMPRESSOR,
1750
- )
1751
- array.attrs["_ARRAY_DIMENSIONS"] = ["filters"]
1752
-
1753
- def init_array(self, root, variable, variants_dim_size):
1754
- object_codec = None
1755
- if variable.dtype == "O":
1756
- object_codec = numcodecs.VLenUTF8()
1757
- shape = list(variable.shape)
1758
- # Truncate the variants dimension if max_variant_chunks was specified
1759
- shape[0] = variants_dim_size
1760
- a = root.empty(
1761
- variable.name,
1762
- shape=shape,
1763
- chunks=variable.chunks,
1764
- dtype=variable.dtype,
1765
- compressor=numcodecs.get_codec(variable.compressor),
1766
- filters=[numcodecs.get_codec(filt) for filt in variable.filters],
1767
- object_codec=object_codec,
1768
- dimension_separator=self.metadata.dimension_separator,
1769
- )
1770
- a.attrs.update(
1771
- {
1772
- "description": variable.description,
1773
- # Dimension names are part of the spec in Zarr v3
1774
- "_ARRAY_DIMENSIONS": variable.dimensions,
1775
- }
1776
- )
1777
- logger.debug(f"Initialised {a}")
1778
-
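init_array above attaches a VLenUTF8 object codec whenever the dtype is object ("O"); a standalone sketch of that pattern, assuming the zarr-python 2.x creation API (values illustrative):

    import numcodecs
    import zarr

    # String data needs an object codec; without one, creation of an "O" array fails.
    z = zarr.empty(shape=(4,), chunks=(2,), dtype="O",
                   object_codec=numcodecs.VLenUTF8())
    z[:] = ["A", "T", "AT", "."]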
1779
- #######################
1780
- # encode_partition
1781
- #######################
1782
-
1783
- def load_metadata(self):
1784
- if self.metadata is None:
1785
- with open(self.wip_path / "metadata.json") as f:
1786
- self.metadata = VcfZarrWriterMetadata.fromdict(json.load(f))
1787
- self.icf = IntermediateColumnarFormat(self.metadata.icf_path)
1788
-
1789
- def partition_path(self, partition_index):
1790
- return self.partitions_path / f"p{partition_index}"
1791
-
1792
- def wip_partition_path(self, partition_index):
1793
- return self.partitions_path / f"wip_p{partition_index}"
1794
-
1795
- def wip_partition_array_path(self, partition_index, name):
1796
- return self.wip_partition_path(partition_index) / name
1797
-
1798
- def partition_array_path(self, partition_index, name):
1799
- return self.partition_path(partition_index) / name
1800
-
1801
- def encode_partition(self, partition_index):
1802
- self.load_metadata()
1803
- if partition_index < 0 or partition_index >= self.num_partitions:
1804
- raise ValueError(
1805
- "Partition index must be in the range 0 <= index < num_partitions"
1806
- )
1807
- partition_path = self.wip_partition_path(partition_index)
1808
- partition_path.mkdir(exist_ok=True)
1809
- logger.info(f"Encoding partition {partition_index} to {partition_path}")
1810
-
1811
- self.encode_id_partition(partition_index)
1812
- self.encode_filters_partition(partition_index)
1813
- self.encode_contig_partition(partition_index)
1814
- self.encode_alleles_partition(partition_index)
1815
- for col in self.schema.fields.values():
1816
- if col.vcf_field is not None:
1817
- self.encode_array_partition(col, partition_index)
1818
- if "call_genotype" in self.schema.fields:
1819
- self.encode_genotypes_partition(partition_index)
1820
-
1821
- final_path = self.partition_path(partition_index)
1822
- logger.info(f"Finalising {partition_index} at {final_path}")
1823
- if final_path.exists():
1824
- logger.warning(f"Removing existing partition at {final_path}")
1825
- shutil.rmtree(final_path)
1826
- os.rename(partition_path, final_path)
1827
-
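The write-to-WIP-then-rename step above is a common crash-safety pattern: partially encoded output never appears under the final partition name. A generic standard-library sketch of the same idea (directory names are illustrative):

    import os
    import pathlib
    import shutil

    wip = pathlib.Path("wip_p0")
    final = pathlib.Path("p0")
    wip.mkdir(exist_ok=True)
    (wip / "chunk").write_text("data")   # stand-in for the encoded chunk files
    if final.exists():
        shutil.rmtree(final)             # discard any stale previous attempt
    os.rename(wip, final)                # atomic on POSIX within one filesystem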
1828
- def init_partition_array(self, partition_index, name):
1829
- wip_path = self.wip_partition_array_path(partition_index, name)
1830
- # Create an empty array like the definition
1831
- src = self.arrays_path / name
1832
- # Overwrite any existing WIP files
1833
- shutil.copytree(src, wip_path, dirs_exist_ok=True)
1834
- array = zarr.open(wip_path)
1835
- logger.debug(f"Opened empty array {array} @ {wip_path}")
1836
- return array
1837
-
1838
- def finalise_partition_array(self, partition_index, name):
1839
- logger.debug(f"Encoded {name} partition {partition_index}")
1840
-
1841
- def encode_array_partition(self, column, partition_index):
1842
- array = self.init_partition_array(partition_index, column.name)
1843
-
1844
- partition = self.metadata.partitions[partition_index]
1845
- ba = core.BufferedArray(array, partition.start)
1846
- source_col = self.icf.fields[column.vcf_field]
1847
- sanitiser = source_col.sanitiser_factory(ba.buff.shape)
1848
-
1849
- for value in source_col.iter_values(partition.start, partition.stop):
1850
- # We write directly into the buffer in the sanitiser function
1851
- # to make it easier to reason about dimension padding
1852
- j = ba.next_buffer_row()
1853
- sanitiser(ba.buff, j, value)
1854
- ba.flush()
1855
- self.finalise_partition_array(partition_index, column.name)
1856
-
1857
- def encode_genotypes_partition(self, partition_index):
1858
- gt_array = self.init_partition_array(partition_index, "call_genotype")
1859
- gt_mask_array = self.init_partition_array(partition_index, "call_genotype_mask")
1860
- gt_phased_array = self.init_partition_array(
1861
- partition_index, "call_genotype_phased"
1862
- )
1863
-
1864
- partition = self.metadata.partitions[partition_index]
1865
- gt = core.BufferedArray(gt_array, partition.start)
1866
- gt_mask = core.BufferedArray(gt_mask_array, partition.start)
1867
- gt_phased = core.BufferedArray(gt_phased_array, partition.start)
1868
-
1869
- source_col = self.icf.fields["FORMAT/GT"]
1870
- for value in source_col.iter_values(partition.start, partition.stop):
1871
- j = gt.next_buffer_row()
1872
- sanitise_value_int_2d(gt.buff, j, value[:, :-1])
1873
- j = gt_phased.next_buffer_row()
1874
- sanitise_value_int_1d(gt_phased.buff, j, value[:, -1])
1875
- # TODO check whether this is the correct semantics when we are padding
1876
- # with mixed ploidies?
1877
- j = gt_mask.next_buffer_row()
1878
- gt_mask.buff[j] = gt.buff[j] < 0
1879
- gt.flush()
1880
- gt_phased.flush()
1881
- gt_mask.flush()
1882
-
1883
- self.finalise_partition_array(partition_index, "call_genotype")
1884
- self.finalise_partition_array(partition_index, "call_genotype_mask")
1885
- self.finalise_partition_array(partition_index, "call_genotype_phased")
1886
-
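A small sketch of the per-variant genotype layout the loop above assumes (cyvcf2-style: one row per sample, allele indexes followed by a trailing phased flag; the values here are made up):

    import numpy as np

    value = np.array([[0, 1, 1],     # 0|1, phased
                      [1, 1, 0],     # 1/1, unphased
                      [-1, -1, 0]])  # ./., missing call
    calls = value[:, :-1]       # allele indexes per sample
    phased = value[:, -1]       # 1 = phased, 0 = unphased
    missing_mask = calls < 0    # negative indexes mark missing/fill entries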
1887
- def encode_alleles_partition(self, partition_index):
1888
- array_name = "variant_allele"
1889
- alleles_array = self.init_partition_array(partition_index, array_name)
1890
- partition = self.metadata.partitions[partition_index]
1891
- alleles = core.BufferedArray(alleles_array, partition.start)
1892
- ref_col = self.icf.fields["REF"]
1893
- alt_col = self.icf.fields["ALT"]
1894
-
1895
- for ref, alt in zip(
1896
- ref_col.iter_values(partition.start, partition.stop),
1897
- alt_col.iter_values(partition.start, partition.stop),
1898
- ):
1899
- j = alleles.next_buffer_row()
1900
- alleles.buff[j, :] = STR_FILL
1901
- alleles.buff[j, 0] = ref[0]
1902
- alleles.buff[j, 1 : 1 + len(alt)] = alt
1903
- alleles.flush()
1904
-
1905
- self.finalise_partition_array(partition_index, array_name)
1906
-
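A sketch of the fixed-width allele packing done above: REF goes in column 0, ALT alleles follow, and any remaining slots keep the "" fill value (max_alleles is illustrative):

    import numpy as np

    max_alleles = 4
    buff = np.full(max_alleles, "", dtype=object)
    ref, alt = ["A"], ["T", "TC"]
    buff[0] = ref[0]
    buff[1 : 1 + len(alt)] = alt
    # buff is now ['A', 'T', 'TC', '']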
1907
- def encode_id_partition(self, partition_index):
1908
- vid_array = self.init_partition_array(partition_index, "variant_id")
1909
- vid_mask_array = self.init_partition_array(partition_index, "variant_id_mask")
1910
- partition = self.metadata.partitions[partition_index]
1911
- vid = core.BufferedArray(vid_array, partition.start)
1912
- vid_mask = core.BufferedArray(vid_mask_array, partition.start)
1913
- col = self.icf.fields["ID"]
1914
-
1915
- for value in col.iter_values(partition.start, partition.stop):
1916
- j = vid.next_buffer_row()
1917
- k = vid_mask.next_buffer_row()
1918
- assert j == k
1919
- if value is not None:
1920
- vid.buff[j] = value[0]
1921
- vid_mask.buff[j] = False
1922
- else:
1923
- vid.buff[j] = STR_MISSING
1924
- vid_mask.buff[j] = True
1925
- vid.flush()
1926
- vid_mask.flush()
1927
-
1928
- self.finalise_partition_array(partition_index, "variant_id")
1929
- self.finalise_partition_array(partition_index, "variant_id_mask")
1930
-
1931
- def encode_filters_partition(self, partition_index):
1932
- lookup = {filt.id: index for index, filt in enumerate(self.schema.filters)}
1933
- array_name = "variant_filter"
1934
- array = self.init_partition_array(partition_index, array_name)
1935
- partition = self.metadata.partitions[partition_index]
1936
- var_filter = core.BufferedArray(array, partition.start)
1937
-
1938
- col = self.icf.fields["FILTERS"]
1939
- for value in col.iter_values(partition.start, partition.stop):
1940
- j = var_filter.next_buffer_row()
1941
- var_filter.buff[j] = False
1942
- for f in value:
1943
- try:
1944
- var_filter.buff[j, lookup[f]] = True
1945
- except KeyError:
1946
- raise ValueError(
1947
- f"Filter '{f}' was not defined in the header."
1948
- ) from None
1949
- var_filter.flush()
1950
-
1951
- self.finalise_partition_array(partition_index, array_name)
1952
-
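The filter encoding above is a simple presence matrix: one boolean column per header FILTER id, set True for each filter observed on the variant. A tiny sketch with made-up ids:

    import numpy as np

    filters = ["PASS", "q10", "s50"]
    lookup = {f: i for i, f in enumerate(filters)}
    row = np.zeros(len(filters), dtype=bool)
    for f in ("q10",):            # FILTER values for one variant
        row[lookup[f]] = True     # an id missing from the header raises KeyError here
    # row is [False, True, False]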
1953
- def encode_contig_partition(self, partition_index):
1954
- lookup = {contig.id: index for index, contig in enumerate(self.schema.contigs)}
1955
- array_name = "variant_contig"
1956
- array = self.init_partition_array(partition_index, array_name)
1957
- partition = self.metadata.partitions[partition_index]
1958
- contig = core.BufferedArray(array, partition.start)
1959
- col = self.icf.fields["CHROM"]
1960
-
1961
- for value in col.iter_values(partition.start, partition.stop):
1962
- j = contig.next_buffer_row()
1963
- # Note: because we are using the indexes to define the lookups
1964
- # and we always have an index, it seems that the contig lookup
1965
- # will always succeed. However, if anyone ever does hit a KeyError
1966
- # here, please do open an issue with a reproducible example!
1967
- contig.buff[j] = lookup[value[0]]
1968
- contig.flush()
1969
-
1970
- self.finalise_partition_array(partition_index, array_name)
1971
-
1972
- #######################
1973
- # finalise
1974
- #######################
1975
-
1976
- def finalise_array(self, name):
1977
- logger.info(f"Finalising {name}")
1978
- final_path = self.path / name
1979
- if final_path.exists():
1980
- # NEEDS TEST
1981
- raise ValueError(f"Array {name} already exists")
1982
- for partition in range(self.num_partitions):
1983
- # Move all the files in partition dir to dest dir
1984
- src = self.partition_array_path(partition, name)
1985
- if not src.exists():
1986
- # NEEDS TEST
1987
- raise ValueError(f"Partition {partition} of {name} does not exist")
1988
- dest = self.arrays_path / name
1989
- # This is Zarr v2 specific. Chunks in v3 will start with a "c" prefix.
1990
- chunk_files = [
1991
- path for path in src.iterdir() if not path.name.startswith(".")
1992
- ]
1993
- # TODO check for a count of the number of files. If we require a
1994
- # dimension_separator of "/" then we could make stronger assertions
1995
- # here, as we'd always have num_variant_chunks
1996
- logger.debug(
1997
- f"Moving {len(chunk_files)} chunks for {name} partition {partition}"
1998
- )
1999
- for chunk_file in chunk_files:
2000
- os.rename(chunk_file, dest / chunk_file.name)
2001
- # Finally, once all the chunks have moved into the arrays dir,
2002
- # we move it out of wip
2003
- os.rename(self.arrays_path / name, self.path / name)
2004
- core.update_progress(1)
2005
-
2006
- def finalise(self, show_progress=False):
2007
- self.load_metadata()
2008
-
2009
- logger.info("Scanning {self.num_partitions} partitions")
2010
- missing = []
2011
- # TODO may need a progress bar here
2012
- for partition_id in range(self.num_partitions):
2013
- if not self.partition_path(partition_id).exists():
2014
- missing.append(partition_id)
2015
- if len(missing) > 0:
2016
- raise FileNotFoundError(f"Partitions not encoded: {missing}")
2017
-
2018
- progress_config = core.ProgressConfig(
2019
- total=len(self.schema.fields),
2020
- title="Finalise",
2021
- units="array",
2022
- show=show_progress,
2023
- )
2024
- # NOTE: it's not clear that adding more workers will make this quicker,
2025
- # as it's just going to be causing contention on the file system.
2026
- # Something to check empirically in some deployments.
2027
- # FIXME we're just using worker_processes=0 here to hook into the
2028
- # SynchronousExecutor which is intended for testing purposes so
2029
- # that we get test coverage. Should fix this either by allowing
2030
- # for multiple workers, or making a standard wrapper for tqdm
2031
- # that allows us to have a consistent look and feel.
2032
- with core.ParallelWorkManager(0, progress_config) as pwm:
2033
- for name in self.schema.fields:
2034
- pwm.submit(self.finalise_array, name)
2035
- logger.debug(f"Removing {self.wip_path}")
2036
- shutil.rmtree(self.wip_path)
2037
- logger.info("Consolidating Zarr metadata")
2038
- zarr.consolidate_metadata(self.path)
2039
-
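Consolidating metadata, the last step of finalise above, writes a single .zmetadata document so that readers can open the store without listing every array. A self-contained sketch using the zarr-python 2.x API (paths and array names illustrative):

    import zarr

    store = zarr.DirectoryStore("example.zarr")
    root = zarr.group(store=store)
    root.zeros("x", shape=(4,), chunks=(2,), dtype="i4")
    zarr.consolidate_metadata(store)          # writes the single .zmetadata key
    reopened = zarr.open_consolidated(store)  # reads metadata in one request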
2040
- ######################
2041
- # encode_all_partitions
2042
- ######################
2043
-
2044
- def get_max_encoding_memory(self):
2045
- """
2046
- Return the approximate maximum memory used to encode a variant chunk.
2047
- """
2048
- max_encoding_mem = max(
2049
- col.variant_chunk_nbytes for col in self.schema.fields.values()
2050
- )
2051
- gt_mem = 0
2052
- if "call_genotype" in self.schema.fields:
2053
- encoded_together = [
2054
- "call_genotype",
2055
- "call_genotype_phased",
2056
- "call_genotype_mask",
2057
- ]
2058
- gt_mem = sum(
2059
- self.schema.fields[col].variant_chunk_nbytes for col in encoded_together
2060
- )
2061
- return max(max_encoding_mem, gt_mem)
2062
-
2063
- def encode_all_partitions(
2064
- self, *, worker_processes=1, show_progress=False, max_memory=None
2065
- ):
2066
- max_memory = parse_max_memory(max_memory)
2067
- self.load_metadata()
2068
- num_partitions = self.num_partitions
2069
- per_worker_memory = self.get_max_encoding_memory()
2070
- logger.info(
2071
- f"Encoding Zarr over {num_partitions} partitions with "
2072
- f"{worker_processes} workers and {display_size(per_worker_memory)} "
2073
- "per worker"
2074
- )
2075
- # Each partition requires per_worker_memory bytes, so to prevent more than
2076
- # max_memory being used, we clamp the number of workers
2077
- max_num_workers = max_memory // per_worker_memory
2078
- if max_num_workers < worker_processes:
2079
- logger.warning(
2080
- f"Limiting number of workers to {max_num_workers} to "
2081
- f"keep within specified memory budget of {display_size(max_memory)}"
2082
- )
2083
- if max_num_workers <= 0:
2084
- raise ValueError(
2085
- f"Insufficient memory to encode a partition:"
2086
- f"{display_size(per_worker_memory)} > {display_size(max_memory)}"
2087
- )
2088
- num_workers = min(max_num_workers, worker_processes)
2089
-
2090
- total_bytes = 0
2091
- for col in self.schema.fields.values():
2092
- # Open the array definition to get the total size
2093
- total_bytes += zarr.open(self.arrays_path / col.name).nbytes
2094
-
2095
- progress_config = core.ProgressConfig(
2096
- total=total_bytes,
2097
- title="Encode",
2098
- units="B",
2099
- show=show_progress,
2100
- )
2101
- with core.ParallelWorkManager(num_workers, progress_config) as pwm:
2102
- for partition_index in range(num_partitions):
2103
- pwm.submit(self.encode_partition, partition_index)
2104
-
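The memory clamping in encode_all_partitions is plain integer arithmetic; with illustrative numbers (2 GiB per partition, a 7 GiB budget, 8 requested workers) it resolves like this:

    per_worker_memory = 2 * 2**30          # from get_max_encoding_memory()
    max_memory = 7 * 2**30                 # user-supplied budget
    worker_processes = 8

    max_num_workers = max_memory // per_worker_memory      # 3
    num_workers = min(max_num_workers, worker_processes)   # 3 workers actually run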
2105
-
2106
- def mkschema(if_path, out):
2107
- icf = IntermediateColumnarFormat(if_path)
2108
- spec = VcfZarrSchema.generate(icf)
2109
- out.write(spec.asjson())
2110
-
2111
-
2112
- def encode(
2113
- if_path,
2114
- zarr_path,
2115
- schema_path=None,
2116
- variants_chunk_size=None,
2117
- samples_chunk_size=None,
2118
- max_variant_chunks=None,
2119
- dimension_separator=None,
2120
- max_memory=None,
2121
- worker_processes=1,
2122
- show_progress=False,
2123
- ):
2124
- # Rough heuristic to split work up enough to keep utilisation high
2125
- target_num_partitions = max(1, worker_processes * 4)
2126
- encode_init(
2127
- if_path,
2128
- zarr_path,
2129
- target_num_partitions,
2130
- schema_path=schema_path,
2131
- variants_chunk_size=variants_chunk_size,
2132
- samples_chunk_size=samples_chunk_size,
2133
- max_variant_chunks=max_variant_chunks,
2134
- dimension_separator=dimension_separator,
2135
- )
2136
- vzw = VcfZarrWriter(zarr_path)
2137
- vzw.encode_all_partitions(
2138
- worker_processes=worker_processes,
2139
- show_progress=show_progress,
2140
- max_memory=max_memory,
2141
- )
2142
- vzw.finalise(show_progress)
2143
-
2144
-
2145
- def encode_init(
2146
- icf_path,
2147
- zarr_path,
2148
- target_num_partitions,
2149
- *,
2150
- schema_path=None,
2151
- variants_chunk_size=None,
2152
- samples_chunk_size=None,
2153
- max_variant_chunks=None,
2154
- dimension_separator=None,
2155
- max_memory=None,
2156
- worker_processes=1,
2157
- show_progress=False,
2158
- ):
2159
- icf = IntermediateColumnarFormat(icf_path)
2160
- if schema_path is None:
2161
- schema = VcfZarrSchema.generate(
2162
- icf,
2163
- variants_chunk_size=variants_chunk_size,
2164
- samples_chunk_size=samples_chunk_size,
2165
- )
2166
- else:
2167
- logger.info(f"Reading schema from {schema_path}")
2168
- if variants_chunk_size is not None or samples_chunk_size is not None:
2169
- raise ValueError(
2170
- "Cannot specify schema along with chunk sizes"
2171
- ) # NEEDS TEST
2172
- with open(schema_path) as f:
2173
- schema = VcfZarrSchema.fromjson(f.read())
2174
- zarr_path = pathlib.Path(zarr_path)
2175
- vzw = VcfZarrWriter(zarr_path)
2176
- vzw.init(
2177
- icf,
2178
- target_num_partitions=target_num_partitions,
2179
- schema=schema,
2180
- dimension_separator=dimension_separator,
2181
- max_variant_chunks=max_variant_chunks,
2182
- )
2183
- return vzw.num_partitions, vzw.get_max_encoding_memory()
2184
-
2185
-
2186
- def encode_partition(zarr_path, partition):
2187
- writer = VcfZarrWriter(zarr_path)
2188
- writer.encode_partition(partition)
2189
-
2190
-
2191
- def encode_finalise(zarr_path, show_progress=False):
2192
- writer = VcfZarrWriter(zarr_path)
2193
- writer.finalise(show_progress=show_progress)
2194
-
2195
-
2196
- def convert(
2197
- vcfs,
2198
- out_path,
2199
- *,
2200
- variants_chunk_size=None,
2201
- samples_chunk_size=None,
2202
- worker_processes=1,
2203
- show_progress=False,
2204
- # TODO add arguments to control location of tmpdir
2205
- ):
2206
- with tempfile.TemporaryDirectory(prefix="vcf2zarr") as tmp:
2207
- if_dir = pathlib.Path(tmp) / "if"
2208
- explode(
2209
- if_dir,
2210
- vcfs,
2211
- worker_processes=worker_processes,
2212
- show_progress=show_progress,
2213
- )
2214
- encode(
2215
- if_dir,
2216
- out_path,
2217
- variants_chunk_size=variants_chunk_size,
2218
- samples_chunk_size=samples_chunk_size,
2219
- worker_processes=worker_processes,
2220
- show_progress=show_progress,
2221
- )
2222
-
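For reference, the two-stage pipeline above (explode into the intermediate columnar format, then encode to Zarr) is wrapped by convert; a usage sketch against the 0.0.9 module shown in this diff, with placeholder file names:

    from bio2zarr import vcf

    vcf.convert(
        ["chr20.vcf.gz"],        # indexed input VCF(s); placeholder name
        "chr20.vcz",             # output Zarr store; placeholder name
        worker_processes=4,
        show_progress=True,
    )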
2223
-
2224
- def assert_all_missing_float(a):
2225
- v = np.array(a, dtype=np.float32).view(np.int32)
2226
- nt.assert_equal(v, FLOAT32_MISSING_AS_INT32)
2227
-
2228
-
2229
- def assert_all_fill_float(a):
2230
- v = np.array(a, dtype=np.float32).view(np.int32)
2231
- nt.assert_equal(v, FLOAT32_FILL_AS_INT32)
2232
-
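assert_all_missing_float and assert_all_fill_float compare int32 bit patterns rather than float values because distinct NaN payloads are indistinguishable (and mutually unequal) at the float level. A short illustration with made-up payloads, not the module's own sentinels:

    import numpy as np

    a = np.array([0x7FC00001, 0x7FC00002], dtype=np.int32).view(np.float32)
    print(np.isnan(a))          # [ True  True] - both are NaN
    print(a[0] == a[1])         # False - but NaN != NaN regardless of payload
    print(a.view(np.int32))     # the payloads are only recoverable as ints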
2233
-
2234
- def assert_all_missing_int(a):
2235
- v = np.array(a, dtype=int)
2236
- nt.assert_equal(v, -1)
2237
-
2238
-
2239
- def assert_all_fill_int(a):
2240
- v = np.array(a, dtype=int)
2241
- nt.assert_equal(v, -2)
2242
-
2243
-
2244
- def assert_all_missing_string(a):
2245
- nt.assert_equal(a, ".")
2246
-
2247
-
2248
- def assert_all_fill_string(a):
2249
- nt.assert_equal(a, "")
2250
-
2251
-
2252
- def assert_all_fill(zarr_val, vcf_type):
2253
- if vcf_type == "Integer":
2254
- assert_all_fill_int(zarr_val)
2255
- elif vcf_type in ("String", "Character"):
2256
- assert_all_fill_string(zarr_val)
2257
- elif vcf_type == "Float":
2258
- assert_all_fill_float(zarr_val)
2259
- else: # pragma: no cover
2260
- assert False # noqa PT015
2261
-
2262
-
2263
- def assert_all_missing(zarr_val, vcf_type):
2264
- if vcf_type == "Integer":
2265
- assert_all_missing_int(zarr_val)
2266
- elif vcf_type in ("String", "Character"):
2267
- assert_all_missing_string(zarr_val)
2268
- elif vcf_type == "Flag":
2269
- assert zarr_val == False # noqa: E712
2270
- elif vcf_type == "Float":
2271
- assert_all_missing_float(zarr_val)
2272
- else: # pragma: no cover
2273
- assert False # noqa PT015
2274
-
2275
-
2276
- def assert_info_val_missing(zarr_val, vcf_type):
2277
- assert_all_missing(zarr_val, vcf_type)
2278
-
2279
-
2280
- def assert_format_val_missing(zarr_val, vcf_type):
2281
- assert_info_val_missing(zarr_val, vcf_type)
2282
-
2283
-
2284
- # Note: checking exact equality may prove problematic here
2285
- # but we should be deterministically storing what cyvcf2
2286
- # provides, which should compare equal.
2287
-
2288
-
2289
- def assert_info_val_equal(vcf_val, zarr_val, vcf_type):
2290
- assert vcf_val is not None
2291
- if vcf_type in ("String", "Character"):
2292
- split = list(vcf_val.split(","))
2293
- k = len(split)
2294
- if isinstance(zarr_val, str):
2295
- assert k == 1
2296
- # Scalar
2297
- assert vcf_val == zarr_val
2298
- else:
2299
- nt.assert_equal(split, zarr_val[:k])
2300
- assert_all_fill(zarr_val[k:], vcf_type)
2301
-
2302
- elif isinstance(vcf_val, tuple):
2303
- vcf_missing_value_map = {
2304
- "Integer": -1,
2305
- "Float": FLOAT32_MISSING,
2306
- }
2307
- v = [vcf_missing_value_map[vcf_type] if x is None else x for x in vcf_val]
2308
- missing = np.array([j for j, x in enumerate(vcf_val) if x is None], dtype=int)
2309
- a = np.array(v)
2310
- k = len(a)
2311
- # We are checking for int missing twice here, but it's necessary to have
2312
- # a separate check for floats because assert_equal treats all NaN payloads as equal
2313
- nt.assert_equal(a, zarr_val[:k])
2314
- assert_all_missing(zarr_val[missing], vcf_type)
2315
- if k < len(zarr_val):
2316
- assert_all_fill(zarr_val[k:], vcf_type)
2317
- else:
2318
- # Scalar
2319
- zarr_val = np.array(zarr_val, ndmin=1)
2320
- assert len(zarr_val.shape) == 1
2321
- assert vcf_val == zarr_val[0]
2322
- if len(zarr_val) > 1:
2323
- assert_all_fill(zarr_val[1:], vcf_type)
2324
-
2325
-
2326
- def assert_format_val_equal(vcf_val, zarr_val, vcf_type):
2327
- assert vcf_val is not None
2328
- assert isinstance(vcf_val, np.ndarray)
2329
- if vcf_type in ("String", "Character"):
2330
- assert len(vcf_val) == len(zarr_val)
2331
- for v, z in zip(vcf_val, zarr_val):
2332
- split = list(v.split(","))
2333
- # Note: deliberately duplicating logic here between this and the
2334
- # INFO col above to make sure all combinations are covered by tests
2335
- k = len(split)
2336
- if k == 1:
2337
- assert v == z
2338
- else:
2339
- nt.assert_equal(split, z[:k])
2340
- assert_all_fill(z[k:], vcf_type)
2341
- else:
2342
- assert vcf_val.shape[0] == zarr_val.shape[0]
2343
- if len(vcf_val.shape) == len(zarr_val.shape) + 1:
2344
- assert vcf_val.shape[-1] == 1
2345
- vcf_val = vcf_val[..., 0]
2346
- assert len(vcf_val.shape) <= 2
2347
- assert len(vcf_val.shape) == len(zarr_val.shape)
2348
- if len(vcf_val.shape) == 2:
2349
- k = vcf_val.shape[1]
2350
- if zarr_val.shape[1] != k:
2351
- assert_all_fill(zarr_val[:, k:], vcf_type)
2352
- zarr_val = zarr_val[:, :k]
2353
- assert vcf_val.shape == zarr_val.shape
2354
- if vcf_type == "Integer":
2355
- vcf_val[vcf_val == VCF_INT_MISSING] = INT_MISSING
2356
- vcf_val[vcf_val == VCF_INT_FILL] = INT_FILL
2357
- elif vcf_type == "Float":
2358
- nt.assert_equal(vcf_val.view(np.int32), zarr_val.view(np.int32))
2359
-
2360
- nt.assert_equal(vcf_val, zarr_val)
2361
-
2362
-
2363
- # TODO rename to "verify"
2364
- def validate(vcf_path, zarr_path, show_progress=False):
2365
- store = zarr.DirectoryStore(zarr_path)
2366
-
2367
- root = zarr.group(store=store)
2368
- pos = root["variant_position"][:]
2369
- allele = root["variant_allele"][:]
2370
- chrom = root["contig_id"][:][root["variant_contig"][:]]
2371
- vid = root["variant_id"][:]
2372
- call_genotype = None
2373
- if "call_genotype" in root:
2374
- call_genotype = iter(root["call_genotype"])
2375
-
2376
- vcf = cyvcf2.VCF(vcf_path)
2377
- format_headers = {}
2378
- info_headers = {}
2379
- for h in vcf.header_iter():
2380
- if h["HeaderType"] == "FORMAT":
2381
- format_headers[h["ID"]] = h
2382
- if h["HeaderType"] == "INFO":
2383
- info_headers[h["ID"]] = h
2384
-
2385
- format_fields = {}
2386
- info_fields = {}
2387
- for colname in root.keys():
2388
- if colname.startswith("call") and not colname.startswith("call_genotype"):
2389
- vcf_name = colname.split("_", 1)[1]
2390
- vcf_type = format_headers[vcf_name]["Type"]
2391
- format_fields[vcf_name] = vcf_type, iter(root[colname])
2392
- if colname.startswith("variant"):
2393
- name = colname.split("_", 1)[1]
2394
- if name.isupper():
2395
- vcf_type = info_headers[name]["Type"]
2396
- info_fields[name] = vcf_type, iter(root[colname])
2397
-
2398
- first_pos = next(vcf).POS
2399
- start_index = np.searchsorted(pos, first_pos)
2400
- assert pos[start_index] == first_pos
2401
- vcf = cyvcf2.VCF(vcf_path)
2402
- if show_progress:
2403
- iterator = tqdm.tqdm(vcf, desc=" Verify", total=vcf.num_records) # NEEDS TEST
2404
- else:
2405
- iterator = vcf
2406
- for j, row in enumerate(iterator, start_index):
2407
- assert chrom[j] == row.CHROM
2408
- assert pos[j] == row.POS
2409
- assert vid[j] == ("." if row.ID is None else row.ID)
2410
- assert allele[j, 0] == row.REF
2411
- k = len(row.ALT)
2412
- nt.assert_array_equal(allele[j, 1 : k + 1], row.ALT)
2413
- assert np.all(allele[j, k + 1 :] == "")
2414
- # TODO FILTERS
2415
-
2416
- if call_genotype is None:
2417
- val = None
2418
- try:
2419
- val = row.format("GT")
2420
- except KeyError:
2421
- pass
2422
- assert val is None
2423
- else:
2424
- gt = row.genotype.array()
2425
- gt_zarr = next(call_genotype)
2426
- gt_vcf = gt[:, :-1]
2427
- # NOTE cyvcf2 remaps genotypes automatically
2428
- # into the same missing/pad encoding that sgkit uses.
2429
- nt.assert_array_equal(gt_zarr, gt_vcf)
2430
-
2431
- for name, (vcf_type, zarr_iter) in info_fields.items():
2432
- vcf_val = row.INFO.get(name, None)
2433
- zarr_val = next(zarr_iter)
2434
- if vcf_val is None:
2435
- assert_info_val_missing(zarr_val, vcf_type)
2436
- else:
2437
- assert_info_val_equal(vcf_val, zarr_val, vcf_type)
2438
-
2439
- for name, (vcf_type, zarr_iter) in format_fields.items():
2440
- vcf_val = row.format(name)
2441
- zarr_val = next(zarr_iter)
2442
- if vcf_val is None:
2443
- assert_format_val_missing(zarr_val, vcf_type)
2444
- else:
2445
- assert_format_val_equal(vcf_val, zarr_val, vcf_type)