bio2zarr 0.0.9__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,1221 @@
1
+ import collections
2
+ import contextlib
3
+ import dataclasses
4
+ import json
5
+ import logging
6
+ import math
7
+ import pathlib
8
+ import pickle
9
+ import shutil
10
+ import sys
11
+ from typing import Any
12
+
13
+ import numcodecs
14
+ import numpy as np
15
+
16
+ from .. import constants, core, provenance, vcf_utils
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ @dataclasses.dataclass
22
+ class VcfFieldSummary(core.JsonDataclass):
23
+ num_chunks: int = 0
24
+ compressed_size: int = 0
25
+ uncompressed_size: int = 0
26
+ max_number: int = 0 # Corresponds to VCF Number field, depends on context
27
+ # Only defined for numeric fields
28
+ max_value: Any = -math.inf
29
+ min_value: Any = math.inf
30
+
31
+ def update(self, other):
32
+ self.num_chunks += other.num_chunks
33
+ self.compressed_size += other.compressed_size
34
+ self.uncompressed_size += other.uncompressed_size
35
+ self.max_number = max(self.max_number, other.max_number)
36
+ self.min_value = min(self.min_value, other.min_value)
37
+ self.max_value = max(self.max_value, other.max_value)
38
+
39
+ @staticmethod
40
+ def fromdict(d):
41
+ return VcfFieldSummary(**d)
42
+
43
+
44
+ @dataclasses.dataclass
45
+ class VcfField:
46
+ category: str
47
+ name: str
48
+ vcf_number: str
49
+ vcf_type: str
50
+ description: str
51
+ summary: VcfFieldSummary
52
+
53
+ @staticmethod
54
+ def from_header(definition):
55
+ category = definition["HeaderType"]
56
+ name = definition["ID"]
57
+ vcf_number = definition["Number"]
58
+ vcf_type = definition["Type"]
59
+ return VcfField(
60
+ category=category,
61
+ name=name,
62
+ vcf_number=vcf_number,
63
+ vcf_type=vcf_type,
64
+ description=definition["Description"].strip('"'),
65
+ summary=VcfFieldSummary(),
66
+ )
67
+
68
+ @staticmethod
69
+ def fromdict(d):
70
+ f = VcfField(**d)
71
+ f.summary = VcfFieldSummary(**d["summary"])
72
+ return f
73
+
74
+ @property
75
+ def full_name(self):
76
+ if self.category == "fixed":
77
+ return self.name
78
+ return f"{self.category}/{self.name}"
79
+
80
+ def smallest_dtype(self):
81
+ """
82
+ Returns the smallest dtype suitable for this field based
83
+ on its type and observed values.
84
+ """
85
+ s = self.summary
86
+ if self.vcf_type == "Float":
87
+ ret = "f4"
88
+ elif self.vcf_type == "Integer":
89
+ if not math.isfinite(s.max_value):
90
+ # All missing values; use i1. Note we should have some API to
91
+ # check more explicitly for missingness:
92
+ # https://github.com/sgkit-dev/bio2zarr/issues/131
93
+ ret = "i1"
94
+ else:
95
+ ret = core.min_int_dtype(s.min_value, s.max_value)
96
+ elif self.vcf_type == "Flag":
97
+ ret = "bool"
98
+ elif self.vcf_type == "Character":
99
+ ret = "U1"
100
+ else:
101
+ assert self.vcf_type == "String"
102
+ ret = "O"
103
+ return ret
104
+
105
+
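As a quick illustration of the dtype narrowing above, a minimal sketch using the definitions in this module (the INFO/DP field and its value range are hypothetical):

    field = VcfField(
        category="INFO", name="DP", vcf_number="1", vcf_type="Integer",
        description="Total depth", summary=VcfFieldSummary(),
    )
    field.summary.min_value = 0
    field.summary.max_value = 1500
    field.smallest_dtype()  # delegates to core.min_int_dtype(0, 1500), e.g. "i2"
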
106
+ @dataclasses.dataclass
107
+ class VcfPartition:
108
+ vcf_path: str
109
+ region: str
110
+ num_records: int = -1
111
+
112
+
113
+ ICF_METADATA_FORMAT_VERSION = "0.3"
114
+ ICF_DEFAULT_COMPRESSOR = numcodecs.Blosc(
115
+ cname="zstd", clevel=7, shuffle=numcodecs.Blosc.NOSHUFFLE
116
+ )
117
+
118
+
119
+ @dataclasses.dataclass
120
+ class Contig:
121
+ id: str
122
+ length: int = None
123
+
124
+
125
+ @dataclasses.dataclass
126
+ class Sample:
127
+ id: str
128
+
129
+
130
+ @dataclasses.dataclass
131
+ class Filter:
132
+ id: str
133
+ description: str = ""
134
+
135
+
136
+ @dataclasses.dataclass
137
+ class IcfMetadata(core.JsonDataclass):
138
+ samples: list
139
+ contigs: list
140
+ filters: list
141
+ fields: list
142
+ partitions: list = None
143
+ format_version: str = None
144
+ compressor: dict = None
145
+ column_chunk_size: int = None
146
+ provenance: dict = None
147
+ num_records: int = -1
148
+
149
+ @property
150
+ def info_fields(self):
151
+ fields = []
152
+ for field in self.fields:
153
+ if field.category == "INFO":
154
+ fields.append(field)
155
+ return fields
156
+
157
+ @property
158
+ def format_fields(self):
159
+ fields = []
160
+ for field in self.fields:
161
+ if field.category == "FORMAT":
162
+ fields.append(field)
163
+ return fields
164
+
165
+ @property
166
+ def num_contigs(self):
167
+ return len(self.contigs)
168
+
169
+ @property
170
+ def num_filters(self):
171
+ return len(self.filters)
172
+
173
+ @property
174
+ def num_samples(self):
175
+ return len(self.samples)
176
+
177
+ @staticmethod
178
+ def fromdict(d):
179
+ if d["format_version"] != ICF_METADATA_FORMAT_VERSION:
180
+ raise ValueError(
181
+ "Intermediate columnar metadata format version mismatch: "
182
+ f"{d['format_version']} != {ICF_METADATA_FORMAT_VERSION}"
183
+ )
184
+ partitions = [VcfPartition(**pd) for pd in d["partitions"]]
185
+ for p in partitions:
186
+ p.region = vcf_utils.Region(**p.region)
187
+ d = d.copy()
188
+ d["partitions"] = partitions
189
+ d["fields"] = [VcfField.fromdict(fd) for fd in d["fields"]]
190
+ d["samples"] = [Sample(**sd) for sd in d["samples"]]
191
+ d["filters"] = [Filter(**fd) for fd in d["filters"]]
192
+ d["contigs"] = [Contig(**cd) for cd in d["contigs"]]
193
+ return IcfMetadata(**d)
194
+
195
+
196
+ def fixed_vcf_field_definitions():
197
+ def make_field_def(name, vcf_type, vcf_number):
198
+ return VcfField(
199
+ category="fixed",
200
+ name=name,
201
+ vcf_type=vcf_type,
202
+ vcf_number=vcf_number,
203
+ description="",
204
+ summary=VcfFieldSummary(),
205
+ )
206
+
207
+ fields = [
208
+ make_field_def("CHROM", "String", "1"),
209
+ make_field_def("POS", "Integer", "1"),
210
+ make_field_def("QUAL", "Float", "1"),
211
+ make_field_def("ID", "String", "."),
212
+ make_field_def("FILTERS", "String", "."),
213
+ make_field_def("REF", "String", "1"),
214
+ make_field_def("ALT", "String", "."),
215
+ ]
216
+ return fields
217
+
218
+
219
+ def scan_vcf(path, target_num_partitions):
220
+ with vcf_utils.IndexedVcf(path) as indexed_vcf:
221
+ vcf = indexed_vcf.vcf
222
+ filters = []
223
+ pass_index = -1
224
+ for h in vcf.header_iter():
225
+ if h["HeaderType"] == "FILTER" and isinstance(h["ID"], str):
226
+ try:
227
+ description = h["Description"].strip('"')
228
+ except KeyError:
229
+ description = ""
230
+ if h["ID"] == "PASS":
231
+ pass_index = len(filters)
232
+ filters.append(Filter(h["ID"], description))
233
+
234
+ # Ensure PASS is the first filter if present
235
+ if pass_index > 0:
236
+ pass_filter = filters.pop(pass_index)
237
+ filters.insert(0, pass_filter)
238
+
239
+ fields = fixed_vcf_field_definitions()
240
+ for h in vcf.header_iter():
241
+ if h["HeaderType"] in ["INFO", "FORMAT"]:
242
+ field = VcfField.from_header(h)
243
+ if field.name == "GT":
244
+ field.vcf_type = "Integer"
245
+ field.vcf_number = "."
246
+ fields.append(field)
247
+
248
+ try:
249
+ contig_lengths = vcf.seqlens
250
+ except AttributeError:
251
+ contig_lengths = [None for _ in vcf.seqnames]
252
+
253
+ metadata = IcfMetadata(
254
+ samples=[Sample(sample_id) for sample_id in vcf.samples],
255
+ contigs=[
256
+ Contig(contig_id, length)
257
+ for contig_id, length in zip(vcf.seqnames, contig_lengths)
258
+ ],
259
+ filters=filters,
260
+ fields=fields,
261
+ partitions=[],
262
+ num_records=sum(indexed_vcf.contig_record_counts().values()),
263
+ )
264
+
265
+ regions = indexed_vcf.partition_into_regions(num_parts=target_num_partitions)
266
+ for region in regions:
267
+ metadata.partitions.append(
268
+ VcfPartition(
269
+ # TODO should this be fully resolving the path? Otherwise it's all
270
+ # relative to the original WD
271
+ vcf_path=str(path),
272
+ region=region,
273
+ )
274
+ )
275
+ logger.info(
276
+ f"Split {path} into {len(metadata.partitions)} "
277
+ f"partitions target={target_num_partitions})"
278
+ )
279
+ core.update_progress(1)
280
+ return metadata, vcf.raw_header
281
+
282
+
283
+ def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
284
+ logger.info(
285
+ f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}"
286
+ f" partitions."
287
+ )
288
+ # An easy mistake to make is to pass the same file twice. Check this early on.
289
+ for path, count in collections.Counter(paths).items():
290
+ if not path.exists(): # NEEDS TEST
291
+ raise FileNotFoundError(path)
292
+ if count > 1:
293
+ raise ValueError(f"Duplicate path provided: {path}")
294
+
295
+ progress_config = core.ProgressConfig(
296
+ total=len(paths),
297
+ units="files",
298
+ title="Scan",
299
+ show=show_progress,
300
+ )
301
+ with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
302
+ for path in paths:
303
+ pwm.submit(scan_vcf, path, max(1, target_num_partitions // len(paths)))
304
+ results = list(pwm.results_as_completed())
305
+
306
+ # Sort to make the ordering deterministic
307
+ results.sort(key=lambda t: t[0].partitions[0].vcf_path)
308
+ # We just take the first header, assuming the others
309
+ # are compatible.
310
+ all_partitions = []
311
+ total_records = 0
312
+ for metadata, _ in results:
313
+ for partition in metadata.partitions:
314
+ logger.debug(f"Scanned partition {partition}")
315
+ all_partitions.append(partition)
316
+ total_records += metadata.num_records
317
+ metadata.num_records = 0
318
+ metadata.partitions = []
319
+
320
+ icf_metadata, header = results[0]
321
+ for metadata, _ in results[1:]:
322
+ if metadata != icf_metadata:
323
+ raise ValueError("Incompatible VCF chunks")
324
+
325
+ # Note: this will be infinity here if any of the chunks has an index
326
+ # that doesn't keep track of the number of records per-contig
327
+ icf_metadata.num_records = total_records
328
+
329
+ # Sort by contig (in the order they appear in the header) first,
330
+ # then by start coordinate
331
+ contig_index_map = {contig.id: j for j, contig in enumerate(metadata.contigs)}
332
+ all_partitions.sort(
333
+ key=lambda x: (contig_index_map[x.region.contig], x.region.start)
334
+ )
335
+ icf_metadata.partitions = all_partitions
336
+ logger.info(f"Scan complete, resulting in {len(all_partitions)} partitions.")
337
+ return icf_metadata, header
338
+
339
+
340
+ def sanitise_value_bool(buff, j, value):
341
+ x = True
342
+ if value is None:
343
+ x = False
344
+ buff[j] = x
345
+
346
+
347
+ def sanitise_value_float_scalar(buff, j, value):
348
+ x = value
349
+ if value is None:
350
+ x = [constants.FLOAT32_MISSING]
351
+ buff[j] = x[0]
352
+
353
+
354
+ def sanitise_value_int_scalar(buff, j, value):
355
+ x = value
356
+ if value is None:
357
+ # print("MISSING", INT_MISSING, INT_FILL)
358
+ x = [constants.INT_MISSING]
359
+ else:
360
+ x = sanitise_int_array(value, ndmin=1, dtype=np.int32)
361
+ buff[j] = x[0]
362
+
363
+
364
+ def sanitise_value_string_scalar(buff, j, value):
365
+ if value is None:
366
+ buff[j] = "."
367
+ else:
368
+ buff[j] = value[0]
369
+
370
+
371
+ def sanitise_value_string_1d(buff, j, value):
372
+ if value is None:
373
+ buff[j] = "."
374
+ else:
375
+ # value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False)
376
+ # FIXME failure isn't coming from here, it seems to be from an
377
+ # incorrectly detected dimension in the zarr array
378
+ # The dimensions look all wrong, and the dtype should be Object
379
+ # not str
380
+ value = drop_empty_second_dim(value)
381
+ buff[j] = ""
382
+ buff[j, : value.shape[0]] = value
383
+
384
+
385
+ def sanitise_value_string_2d(buff, j, value):
386
+ if value is None:
387
+ buff[j] = "."
388
+ else:
389
+ # print(buff.shape, value.dtype, value)
390
+ # assert value.ndim == 2
391
+ buff[j] = ""
392
+ if value.ndim == 2:
393
+ buff[j, :, : value.shape[1]] = value
394
+ else:
395
+ # TODO check if this is still necessary
396
+ for k, val in enumerate(value):
397
+ buff[j, k, : len(val)] = val
398
+
399
+
400
+ def drop_empty_second_dim(value):
401
+ assert len(value.shape) == 1 or value.shape[1] == 1
402
+ if len(value.shape) == 2 and value.shape[1] == 1:
403
+ value = value[..., 0]
404
+ return value
405
+
406
+
407
+ def sanitise_value_float_1d(buff, j, value):
408
+ if value is None:
409
+ buff[j] = constants.FLOAT32_MISSING
410
+ else:
411
+ value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False)
412
+ # numpy will map None values to NaN, but we need a
413
+ # specific NaN
414
+ value[np.isnan(value)] = constants.FLOAT32_MISSING
415
+ value = drop_empty_second_dim(value)
416
+ buff[j] = constants.FLOAT32_FILL
417
+ buff[j, : value.shape[0]] = value
418
+
419
+
420
+ def sanitise_value_float_2d(buff, j, value):
421
+ if value is None:
422
+ buff[j] = constants.FLOAT32_MISSING
423
+ else:
424
+ # print("value = ", value)
425
+ value = np.array(value, ndmin=2, dtype=buff.dtype, copy=False)
426
+ buff[j] = constants.FLOAT32_FILL
427
+ buff[j, :, : value.shape[1]] = value
428
+
429
+
430
+ def sanitise_int_array(value, ndmin, dtype):
431
+ if isinstance(value, tuple):
432
+ value = [
433
+ constants.VCF_INT_MISSING if x is None else x for x in value
434
+ ] # NEEDS TEST
435
+ value = np.array(value, ndmin=ndmin, copy=False)
436
+ value[value == constants.VCF_INT_MISSING] = -1
437
+ value[value == constants.VCF_INT_FILL] = -2
438
+ # TODO watch out for clipping here!
439
+ return value.astype(dtype)
440
+
441
+
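The sentinel mapping performed by sanitise_int_array() can be seen with a small sketch (input values are illustrative; the constants are the ones imported above):

    raw = (4, None, constants.VCF_INT_FILL)        # tuple with a missing entry
    sanitise_int_array(raw, ndmin=1, dtype=np.int32)
    # -> array([ 4, -1, -2], dtype=int32): missing becomes -1, fill becomes -2
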
442
+ def sanitise_value_int_1d(buff, j, value):
443
+ if value is None:
444
+ buff[j] = -1
445
+ else:
446
+ value = sanitise_int_array(value, 1, buff.dtype)
447
+ value = drop_empty_second_dim(value)
448
+ buff[j] = -2
449
+ buff[j, : value.shape[0]] = value
450
+
451
+
452
+ def sanitise_value_int_2d(buff, j, value):
453
+ if value is None:
454
+ buff[j] = -1
455
+ else:
456
+ value = sanitise_int_array(value, 2, buff.dtype)
457
+ buff[j] = -2
458
+ buff[j, :, : value.shape[1]] = value
459
+
460
+
461
+ missing_value_map = {
462
+ "Integer": constants.INT_MISSING,
463
+ "Float": constants.FLOAT32_MISSING,
464
+ "String": constants.STR_MISSING,
465
+ "Character": constants.STR_MISSING,
466
+ "Flag": False,
467
+ }
468
+
469
+
470
+ class VcfValueTransformer:
471
+ """
472
+ Transform VCF values into the stored intermediate format used
473
+ in the IntermediateColumnarFormat, and update field summaries.
474
+ """
475
+
476
+ def __init__(self, field, num_samples):
477
+ self.field = field
478
+ self.num_samples = num_samples
479
+ self.dimension = 1
480
+ if field.category == "FORMAT":
481
+ self.dimension = 2
482
+ self.missing = missing_value_map[field.vcf_type]
483
+
484
+ @staticmethod
485
+ def factory(field, num_samples):
486
+ if field.vcf_type in ("Integer", "Flag"):
487
+ return IntegerValueTransformer(field, num_samples)
488
+ if field.vcf_type == "Float":
489
+ return FloatValueTransformer(field, num_samples)
490
+ if field.name in ["REF", "FILTERS", "ALT", "ID", "CHROM"]:
491
+ return SplitStringValueTransformer(field, num_samples)
492
+ return StringValueTransformer(field, num_samples)
493
+
494
+ def transform(self, vcf_value):
495
+ if isinstance(vcf_value, tuple):
496
+ vcf_value = [self.missing if v is None else v for v in vcf_value]
497
+ value = np.array(vcf_value, ndmin=self.dimension, copy=False)
498
+ return value
499
+
500
+ def transform_and_update_bounds(self, vcf_value):
501
+ if vcf_value is None:
502
+ return None
503
+ value = self.transform(vcf_value)
504
+ self.update_bounds(value)
505
+ # print(self.field.full_name, "T", vcf_value, "->", value)
506
+ return value
507
+
508
+
509
+ class IntegerValueTransformer(VcfValueTransformer):
510
+ def update_bounds(self, value):
511
+ summary = self.field.summary
512
+ # Mask out missing and fill values
513
+ # print(value)
514
+ a = value[value >= constants.MIN_INT_VALUE]
515
+ if a.size > 0:
516
+ summary.max_value = int(max(summary.max_value, np.max(a)))
517
+ summary.min_value = int(min(summary.min_value, np.min(a)))
518
+ number = value.shape[-1]
519
+ summary.max_number = max(summary.max_number, number)
520
+
521
+
522
+ class FloatValueTransformer(VcfValueTransformer):
523
+ def update_bounds(self, value):
524
+ summary = self.field.summary
525
+ summary.max_value = float(max(summary.max_value, np.max(value)))
526
+ summary.min_value = float(min(summary.min_value, np.min(value)))
527
+ number = value.shape[-1]
528
+ summary.max_number = max(summary.max_number, number)
529
+
530
+
531
+ class StringValueTransformer(VcfValueTransformer):
532
+ def update_bounds(self, value):
533
+ summary = self.field.summary
534
+ number = value.shape[-1]
535
+ # TODO would be nice to report string lengths, but not
536
+ # really necessary.
537
+ summary.max_number = max(summary.max_number, number)
538
+
539
+ def transform(self, vcf_value):
540
+ # print("transform", vcf_value)
541
+ if self.dimension == 1:
542
+ value = np.array(list(vcf_value.split(",")))
543
+ else:
544
+ # TODO can we make this faster??
545
+ value = np.array([v.split(",") for v in vcf_value], dtype="O")
546
+ # print("HERE", vcf_value, value)
547
+ # for v in vcf_value:
548
+ # print("\t", type(v), len(v), v.split(","))
549
+ # print("S: ", self.dimension, ":", value.shape, value)
550
+ return value
551
+
552
+
553
+ class SplitStringValueTransformer(StringValueTransformer):
554
+ def transform(self, vcf_value):
555
+ if vcf_value is None:
556
+ return self.missing  # NEEDS TEST
557
+ assert self.dimension == 1
558
+ return np.array(vcf_value, ndmin=1, dtype="str")
559
+
560
+
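A short sketch of how the factory dispatch above might be exercised (the field definition is hypothetical):

    dp = VcfField("INFO", "DP", "1", "Integer", "", VcfFieldSummary())
    transformer = VcfValueTransformer.factory(dp, num_samples=3)
    type(transformer).__name__                       # IntegerValueTransformer
    transformer.transform_and_update_bounds((7, 9))  # array([7, 9]); bounds updated
    dp.summary.min_value, dp.summary.max_value, dp.summary.max_number  # 7, 9, 2
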
561
+ def get_vcf_field_path(base_path, vcf_field):
562
+ if vcf_field.category == "fixed":
563
+ return base_path / vcf_field.name
564
+ return base_path / vcf_field.category / vcf_field.name
565
+
566
+
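The resulting on-disk layout puts fixed fields at the root of the ICF directory and INFO/FORMAT fields under their category; a sketch with hypothetical paths:

    base = pathlib.Path("sample.icf")
    pos = VcfField("fixed", "POS", "1", "Integer", "", VcfFieldSummary())
    dp = VcfField("FORMAT", "DP", "1", "Integer", "", VcfFieldSummary())
    get_vcf_field_path(base, pos)  # sample.icf/POS
    get_vcf_field_path(base, dp)   # sample.icf/FORMAT/DP
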
567
+ class IntermediateColumnarFormatField:
568
+ def __init__(self, icf, vcf_field):
569
+ self.vcf_field = vcf_field
570
+ self.path = get_vcf_field_path(icf.path, vcf_field)
571
+ self.compressor = icf.compressor
572
+ self.num_partitions = icf.num_partitions
573
+ self.num_records = icf.num_records
574
+ self.partition_record_index = icf.partition_record_index
575
+ # A map of partition id to the cumulative number of records
576
+ # in chunks within that partition
577
+ self._chunk_record_index = {}
578
+
579
+ @property
580
+ def name(self):
581
+ return self.vcf_field.full_name
582
+
583
+ def partition_path(self, partition_id):
584
+ return self.path / f"p{partition_id}"
585
+
586
+ def __repr__(self):
587
+ partition_chunks = [self.num_chunks(j) for j in range(self.num_partitions)]
588
+ return (
589
+ f"IntermediateColumnarFormatField(name={self.name}, "
590
+ f"partition_chunks={partition_chunks}, "
591
+ f"path={self.path})"
592
+ )
593
+
594
+ def num_chunks(self, partition_id):
595
+ return len(self.chunk_record_index(partition_id)) - 1
596
+
597
+ def chunk_record_index(self, partition_id):
598
+ if partition_id not in self._chunk_record_index:
599
+ index_path = self.partition_path(partition_id) / "chunk_index"
600
+ with open(index_path, "rb") as f:
601
+ a = pickle.load(f)
602
+ assert len(a) > 1
603
+ assert a[0] == 0
604
+ self._chunk_record_index[partition_id] = a
605
+ return self._chunk_record_index[partition_id]
606
+
607
+ def read_chunk(self, path):
608
+ with open(path, "rb") as f:
609
+ pkl = self.compressor.decode(f.read())
610
+ return pickle.loads(pkl)
611
+
612
+ def chunk_num_records(self, partition_id):
613
+ return np.diff(self.chunk_record_index(partition_id))
614
+
615
+ def chunks(self, partition_id, start_chunk=0):
616
+ partition_path = self.partition_path(partition_id)
617
+ chunk_cumulative_records = self.chunk_record_index(partition_id)
618
+ chunk_num_records = np.diff(chunk_cumulative_records)
619
+ for count, cumulative in zip(
620
+ chunk_num_records[start_chunk:], chunk_cumulative_records[start_chunk + 1 :]
621
+ ):
622
+ path = partition_path / f"{cumulative}"
623
+ chunk = self.read_chunk(path)
624
+ if len(chunk) != count:
625
+ raise ValueError(f"Corruption detected in chunk: {path}")
626
+ yield chunk
627
+
628
+ def iter_values(self, start=None, stop=None):
629
+ start = 0 if start is None else start
630
+ stop = self.num_records if stop is None else stop
631
+ start_partition = (
632
+ np.searchsorted(self.partition_record_index, start, side="right") - 1
633
+ )
634
+ offset = self.partition_record_index[start_partition]
635
+ assert offset <= start
636
+ chunk_offset = start - offset
637
+
638
+ chunk_record_index = self.chunk_record_index(start_partition)
639
+ start_chunk = (
640
+ np.searchsorted(chunk_record_index, chunk_offset, side="right") - 1
641
+ )
642
+ record_id = offset + chunk_record_index[start_chunk]
643
+ assert record_id <= start
644
+ logger.debug(
645
+ f"Read {self.vcf_field.full_name} slice [{start}:{stop}]:"
646
+ f"p_start={start_partition}, c_start={start_chunk}, r_start={record_id}"
647
+ )
648
+ for chunk in self.chunks(start_partition, start_chunk):
649
+ for record in chunk:
650
+ if record_id == stop:
651
+ return
652
+ if record_id >= start:
653
+ yield record
654
+ record_id += 1
655
+ assert record_id > start
656
+ for partition_id in range(start_partition + 1, self.num_partitions):
657
+ for chunk in self.chunks(partition_id):
658
+ for record in chunk:
659
+ if record_id == stop:
660
+ return
661
+ yield record
662
+ record_id += 1
663
+
664
+ # Note: this involves some computation so should arguably be a method,
665
+ # but we make it a property for consistency with xarray etc.
666
+ @property
667
+ def values(self):
668
+ ret = [None] * self.num_records
669
+ j = 0
670
+ for partition_id in range(self.num_partitions):
671
+ for chunk in self.chunks(partition_id):
672
+ for record in chunk:
673
+ ret[j] = record
674
+ j += 1
675
+ assert j == self.num_records
676
+ return ret
677
+
678
+ def sanitiser_factory(self, shape):
679
+ """
680
+ Return a function that sanitises values from this column
681
+ and writes into a buffer of the specified shape.
682
+ """
683
+ assert len(shape) <= 3
684
+ if self.vcf_field.vcf_type == "Flag":
685
+ assert len(shape) == 1
686
+ return sanitise_value_bool
687
+ elif self.vcf_field.vcf_type == "Float":
688
+ if len(shape) == 1:
689
+ return sanitise_value_float_scalar
690
+ elif len(shape) == 2:
691
+ return sanitise_value_float_1d
692
+ else:
693
+ return sanitise_value_float_2d
694
+ elif self.vcf_field.vcf_type == "Integer":
695
+ if len(shape) == 1:
696
+ return sanitise_value_int_scalar
697
+ elif len(shape) == 2:
698
+ return sanitise_value_int_1d
699
+ else:
700
+ return sanitise_value_int_2d
701
+ else:
702
+ assert self.vcf_field.vcf_type in ("String", "Character")
703
+ if len(shape) == 1:
704
+ return sanitise_value_string_scalar
705
+ elif len(shape) == 2:
706
+ return sanitise_value_string_1d
707
+ else:
708
+ return sanitise_value_string_2d
709
+
710
+
711
+ @dataclasses.dataclass
712
+ class IcfFieldWriter:
713
+ vcf_field: VcfField
714
+ path: pathlib.Path
715
+ transformer: VcfValueTransformer
716
+ compressor: Any
717
+ max_buffered_bytes: int
718
+ buff: list[Any] = dataclasses.field(default_factory=list)
719
+ buffered_bytes: int = 0
720
+ chunk_index: list[int] = dataclasses.field(default_factory=lambda: [0])
721
+ num_records: int = 0
722
+
723
+ def append(self, val):
724
+ val = self.transformer.transform_and_update_bounds(val)
725
+ assert val is None or isinstance(val, np.ndarray)
726
+ self.buff.append(val)
727
+ val_bytes = sys.getsizeof(val)
728
+ self.buffered_bytes += val_bytes
729
+ self.num_records += 1
730
+ if self.buffered_bytes >= self.max_buffered_bytes:
731
+ logger.debug(
732
+ f"Flush {self.path} buffered={self.buffered_bytes} "
733
+ f"max={self.max_buffered_bytes}"
734
+ )
735
+ self.write_chunk()
736
+ self.buff.clear()
737
+ self.buffered_bytes = 0
738
+
739
+ def write_chunk(self):
740
+ # Update index
741
+ self.chunk_index.append(self.num_records)
742
+ path = self.path / f"{self.num_records}"
743
+ logger.debug(f"Start write: {path}")
744
+ pkl = pickle.dumps(self.buff)
745
+ compressed = self.compressor.encode(pkl)
746
+ with open(path, "wb") as f:
747
+ f.write(compressed)
748
+
749
+ # Update the summary
750
+ self.vcf_field.summary.num_chunks += 1
751
+ self.vcf_field.summary.compressed_size += len(compressed)
752
+ self.vcf_field.summary.uncompressed_size += self.buffered_bytes
753
+ logger.debug(f"Finish write: {path}")
754
+
755
+ def flush(self):
756
+ logger.debug(
757
+ f"Flush {self.path} records={len(self.buff)} buffered={self.buffered_bytes}"
758
+ )
759
+ if len(self.buff) > 0:
760
+ self.write_chunk()
761
+ with open(self.path / "chunk_index", "wb") as f:
762
+ a = np.array(self.chunk_index, dtype=int)
763
+ pickle.dump(a, f)
764
+
765
+
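The chunk_index written by flush() above is a cumulative record count starting at 0; the reader side (chunk_record_index/chunk_num_records) recovers per-chunk counts with np.diff. A sketch with illustrative numbers:

    chunk_index = np.array([0, 1000, 1800, 2500])  # three flushes in one partition
    np.diff(chunk_index)                           # per-chunk counts: [1000, 800, 700]
    # each chunk file is named after its cumulative count, e.g. p0/1000, p0/1800, p0/2500
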
766
+ class IcfPartitionWriter(contextlib.AbstractContextManager):
767
+ """
768
+ Writes the data for an IntermediateColumnarFormat partition.
769
+ """
770
+
771
+ def __init__(
772
+ self,
773
+ icf_metadata,
774
+ out_path,
775
+ partition_index,
776
+ ):
777
+ self.partition_index = partition_index
778
+ # column_chunk_size is in megabytes
779
+ max_buffered_bytes = icf_metadata.column_chunk_size * 2**20
780
+ assert max_buffered_bytes > 0
781
+ compressor = numcodecs.get_codec(icf_metadata.compressor)
782
+
783
+ self.field_writers = {}
784
+ num_samples = len(icf_metadata.samples)
785
+ for vcf_field in icf_metadata.fields:
786
+ field_path = get_vcf_field_path(out_path, vcf_field)
787
+ field_partition_path = field_path / f"p{partition_index}"
788
+ # Should be robust to running explode_partition twice.
789
+ field_partition_path.mkdir(exist_ok=True)
790
+ transformer = VcfValueTransformer.factory(vcf_field, num_samples)
791
+ self.field_writers[vcf_field.full_name] = IcfFieldWriter(
792
+ vcf_field,
793
+ field_partition_path,
794
+ transformer,
795
+ compressor,
796
+ max_buffered_bytes,
797
+ )
798
+
799
+ @property
800
+ def field_summaries(self):
801
+ return {
802
+ name: field.vcf_field.summary for name, field in self.field_writers.items()
803
+ }
804
+
805
+ def append(self, name, value):
806
+ self.field_writers[name].append(value)
807
+
808
+ def __exit__(self, exc_type, exc_val, exc_tb):
809
+ if exc_type is None:
810
+ for field in self.field_writers.values():
811
+ field.flush()
812
+ return False
813
+
814
+
815
+ class IntermediateColumnarFormat(collections.abc.Mapping):
816
+ def __init__(self, path):
817
+ self.path = pathlib.Path(path)
818
+ # TODO raise a more informative error here telling people this
819
+ # directory is either a WIP or the wrong format.
820
+ with open(self.path / "metadata.json") as f:
821
+ self.metadata = IcfMetadata.fromdict(json.load(f))
822
+ with open(self.path / "header.txt") as f:
823
+ self.vcf_header = f.read()
824
+ self.compressor = numcodecs.get_codec(self.metadata.compressor)
825
+ self.fields = {}
826
+ partition_num_records = [
827
+ partition.num_records for partition in self.metadata.partitions
828
+ ]
829
+ # Allow us to find which partition a given record is in
830
+ self.partition_record_index = np.cumsum([0, *partition_num_records])
831
+ for field in self.metadata.fields:
832
+ self.fields[field.full_name] = IntermediateColumnarFormatField(self, field)
833
+ logger.info(
834
+ f"Loaded IntermediateColumnarFormat(partitions={self.num_partitions}, "
835
+ f"records={self.num_records}, fields={self.num_fields})"
836
+ )
837
+
838
+ def __repr__(self):
839
+ return (
840
+ f"IntermediateColumnarFormat(fields={len(self)}, "
841
+ f"partitions={self.num_partitions}, "
842
+ f"records={self.num_records}, path={self.path})"
843
+ )
844
+
845
+ def __getitem__(self, key):
846
+ return self.fields[key]
847
+
848
+ def __iter__(self):
849
+ return iter(self.fields)
850
+
851
+ def __len__(self):
852
+ return len(self.fields)
853
+
854
+ def summary_table(self):
855
+ data = []
856
+ for name, col in self.fields.items():
857
+ summary = col.vcf_field.summary
858
+ d = {
859
+ "name": name,
860
+ "type": col.vcf_field.vcf_type,
861
+ "chunks": summary.num_chunks,
862
+ "size": core.display_size(summary.uncompressed_size),
863
+ "compressed": core.display_size(summary.compressed_size),
864
+ "max_n": summary.max_number,
865
+ "min_val": core.display_number(summary.min_value),
866
+ "max_val": core.display_number(summary.max_value),
867
+ }
868
+
869
+ data.append(d)
870
+ return data
871
+
872
+ @property
873
+ def num_records(self):
874
+ return self.metadata.num_records
875
+
876
+ @property
877
+ def num_partitions(self):
878
+ return len(self.metadata.partitions)
879
+
880
+ @property
881
+ def num_samples(self):
882
+ return len(self.metadata.samples)
883
+
884
+ @property
885
+ def num_fields(self):
886
+ return len(self.fields)
887
+
888
+
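A minimal sketch of reading an exploded ICF through the Mapping interface above ("sample.icf" is a hypothetical path produced by the explode() function defined later in this module):

    icf = IntermediateColumnarFormat("sample.icf")
    icf.num_records, icf.num_partitions, icf.num_fields
    pos = icf["POS"]                        # an IntermediateColumnarFormatField
    first_ten = list(pos.iter_values(0, 10))
    icf.summary_table()                     # per-field size/compression summary
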
889
+ @dataclasses.dataclass
890
+ class IcfPartitionMetadata(core.JsonDataclass):
891
+ num_records: int
892
+ last_position: int
893
+ field_summaries: dict
894
+
895
+ @staticmethod
896
+ def fromdict(d):
897
+ md = IcfPartitionMetadata(**d)
898
+ for k, v in md.field_summaries.items():
899
+ md.field_summaries[k] = VcfFieldSummary.fromdict(v)
900
+ return md
901
+
902
+
903
+ def check_overlapping_partitions(partitions):
904
+ for i in range(1, len(partitions)):
905
+ prev_region = partitions[i - 1].region
906
+ current_region = partitions[i].region
907
+ if prev_region.contig == current_region.contig:
908
+ assert prev_region.end is not None
909
+ # Regions are *inclusive*
910
+ if prev_region.end >= current_region.start:
911
+ raise ValueError(
912
+ f"Overlapping VCF regions in partitions {i - 1} and {i}: "
913
+ f"{prev_region} and {current_region}"
914
+ )
915
+
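A sketch of the inclusive-region check above, assuming vcf_utils.Region takes contig/start/end as used elsewhere in this module (the partitions are hypothetical):

    p1 = VcfPartition("a.vcf.gz", vcf_utils.Region(contig="chr1", start=1, end=1000))
    p2 = VcfPartition("a.vcf.gz", vcf_utils.Region(contig="chr1", start=1000, end=2000))
    check_overlapping_partitions([p1, p2])  # raises ValueError: regions are inclusive
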
916
+
917
+ def check_field_clobbering(icf_metadata):
918
+ info_field_names = set(field.name for field in icf_metadata.info_fields)
919
+ fixed_variant_fields = set(
920
+ ["contig", "id", "id_mask", "position", "allele", "filter", "quality"]
921
+ )
922
+ intersection = info_field_names & fixed_variant_fields
923
+ if len(intersection) > 0:
924
+ raise ValueError(
925
+ f"INFO field name(s) clashing with VCF Zarr spec: {intersection}"
926
+ )
927
+
928
+ format_field_names = set(field.name for field in icf_metadata.format_fields)
929
+ fixed_variant_fields = set(["genotype", "genotype_phased", "genotype_mask"])
930
+ intersection = format_field_names & fixed_variant_fields
931
+ if len(intersection) > 0:
932
+ raise ValueError(
933
+ f"FORMAT field name(s) clashing with VCF Zarr spec: {intersection}"
934
+ )
935
+
936
+
937
+ @dataclasses.dataclass
938
+ class IcfWriteSummary(core.JsonDataclass):
939
+ num_partitions: int
940
+ num_samples: int
941
+ num_variants: int
942
+
943
+
944
+ class IntermediateColumnarFormatWriter:
945
+ def __init__(self, path):
946
+ self.path = pathlib.Path(path)
947
+ self.wip_path = self.path / "wip"
948
+ self.metadata = None
949
+
950
+ @property
951
+ def num_partitions(self):
952
+ return len(self.metadata.partitions)
953
+
954
+ def init(
955
+ self,
956
+ vcfs,
957
+ *,
958
+ column_chunk_size=16,
959
+ worker_processes=1,
960
+ target_num_partitions=None,
961
+ show_progress=False,
962
+ compressor=None,
963
+ ):
964
+ if self.path.exists():
965
+ raise ValueError("ICF path already exists")
966
+ if compressor is None:
967
+ compressor = ICF_DEFAULT_COMPRESSOR
968
+ vcfs = [pathlib.Path(vcf) for vcf in vcfs]
969
+ target_num_partitions = max(target_num_partitions, len(vcfs))
970
+
971
+ # TODO move scan_vcfs into this class
972
+ icf_metadata, header = scan_vcfs(
973
+ vcfs,
974
+ worker_processes=worker_processes,
975
+ show_progress=show_progress,
976
+ target_num_partitions=target_num_partitions,
977
+ )
978
+ check_field_clobbering(icf_metadata)
979
+ self.metadata = icf_metadata
980
+ self.metadata.format_version = ICF_METADATA_FORMAT_VERSION
981
+ self.metadata.compressor = compressor.get_config()
982
+ self.metadata.column_chunk_size = column_chunk_size
983
+ # Bare minimum here for provenance - would be nice to include versions of key
984
+ # dependencies as well.
985
+ self.metadata.provenance = {"source": f"bio2zarr-{provenance.__version__}"}
986
+
987
+ self.mkdirs()
988
+
989
+ # Note: this is needed for the current version of the vcfzarr spec, but it's
990
+ # probably going to be dropped.
991
+ # https://github.com/pystatgen/vcf-zarr-spec/issues/15
992
+ # May be useful to keep lying around still though?
993
+ logger.info("Writing VCF header")
994
+ with open(self.path / "header.txt", "w") as f:
995
+ f.write(header)
996
+
997
+ logger.info("Writing WIP metadata")
998
+ with open(self.wip_path / "metadata.json", "w") as f:
999
+ json.dump(self.metadata.asdict(), f, indent=4)
1000
+ return IcfWriteSummary(
1001
+ num_partitions=self.num_partitions,
1002
+ num_variants=icf_metadata.num_records,
1003
+ num_samples=icf_metadata.num_samples,
1004
+ )
1005
+
1006
+ def mkdirs(self):
1007
+ num_dirs = len(self.metadata.fields)
1008
+ logger.info(f"Creating {num_dirs} field directories")
1009
+ self.path.mkdir()
1010
+ self.wip_path.mkdir()
1011
+ for field in self.metadata.fields:
1012
+ col_path = get_vcf_field_path(self.path, field)
1013
+ col_path.mkdir(parents=True)
1014
+
1015
+ def load_partition_summaries(self):
1016
+ summaries = []
1017
+ not_found = []
1018
+ for j in range(self.num_partitions):
1019
+ try:
1020
+ with open(self.wip_path / f"p{j}.json") as f:
1021
+ summaries.append(IcfPartitionMetadata.fromdict(json.load(f)))
1022
+ except FileNotFoundError:
1023
+ not_found.append(j)
1024
+ if len(not_found) > 0:
1025
+ raise FileNotFoundError(
1026
+ f"Partition metadata not found for {len(not_found)}"
1027
+ f" partitions: {not_found}"
1028
+ )
1029
+ return summaries
1030
+
1031
+ def load_metadata(self):
1032
+ if self.metadata is None:
1033
+ with open(self.wip_path / "metadata.json") as f:
1034
+ self.metadata = IcfMetadata.fromdict(json.load(f))
1035
+
1036
+ def process_partition(self, partition_index):
1037
+ self.load_metadata()
1038
+ summary_path = self.wip_path / f"p{partition_index}.json"
1039
+ # If someone is rewriting a summary path (for whatever reason), make sure it
1040
+ # doesn't look like it's already been completed.
1041
+ # NOTE to do this properly we probably need to take a lock on this file - but
1042
+ # this simple approach will catch the vast majority of problems.
1043
+ if summary_path.exists():
1044
+ summary_path.unlink()
1045
+
1046
+ partition = self.metadata.partitions[partition_index]
1047
+ logger.info(
1048
+ f"Start p{partition_index} {partition.vcf_path}__{partition.region}"
1049
+ )
1050
+ info_fields = self.metadata.info_fields
1051
+ format_fields = []
1052
+ has_gt = False
1053
+ for field in self.metadata.format_fields:
1054
+ if field.name == "GT":
1055
+ has_gt = True
1056
+ else:
1057
+ format_fields.append(field)
1058
+
1059
+ last_position = None
1060
+ with IcfPartitionWriter(
1061
+ self.metadata,
1062
+ self.path,
1063
+ partition_index,
1064
+ ) as tcw:
1065
+ with vcf_utils.IndexedVcf(partition.vcf_path) as ivcf:
1066
+ num_records = 0
1067
+ for variant in ivcf.variants(partition.region):
1068
+ num_records += 1
1069
+ last_position = variant.POS
1070
+ tcw.append("CHROM", variant.CHROM)
1071
+ tcw.append("POS", variant.POS)
1072
+ tcw.append("QUAL", variant.QUAL)
1073
+ tcw.append("ID", variant.ID)
1074
+ tcw.append("FILTERS", variant.FILTERS)
1075
+ tcw.append("REF", variant.REF)
1076
+ tcw.append("ALT", variant.ALT)
1077
+ for field in info_fields:
1078
+ tcw.append(field.full_name, variant.INFO.get(field.name, None))
1079
+ if has_gt:
1080
+ tcw.append("FORMAT/GT", variant.genotype.array())
1081
+ for field in format_fields:
1082
+ val = variant.format(field.name)
1083
+ tcw.append(field.full_name, val)
1084
+ # Note: an issue with updating the progress per variant here like
1085
+ # this is that we get a significant pause at the end of the counter
1086
+ # while all the "small" fields get flushed. Possibly not much to be
1087
+ # done about it.
1088
+ core.update_progress(1)
1089
+ logger.info(
1090
+ f"Finished reading VCF for partition {partition_index}, "
1091
+ f"flushing buffers"
1092
+ )
1093
+
1094
+ partition_metadata = IcfPartitionMetadata(
1095
+ num_records=num_records,
1096
+ last_position=last_position,
1097
+ field_summaries=tcw.field_summaries,
1098
+ )
1099
+ with open(summary_path, "w") as f:
1100
+ f.write(partition_metadata.asjson())
1101
+ logger.info(
1102
+ f"Finish p{partition_index} {partition.vcf_path}__{partition.region} "
1103
+ f"{num_records} records last_pos={last_position}"
1104
+ )
1105
+
1106
+ def explode(self, *, worker_processes=1, show_progress=False):
1107
+ self.load_metadata()
1108
+ num_records = self.metadata.num_records
1109
+ if np.isinf(num_records):
1110
+ logger.warning(
1111
+ "Total records unknown, cannot show progress; "
1112
+ "reindex VCFs with bcftools index to fix"
1113
+ )
1114
+ num_records = None
1115
+ num_fields = len(self.metadata.fields)
1116
+ num_samples = len(self.metadata.samples)
1117
+ logger.info(
1118
+ f"Exploding fields={num_fields} samples={num_samples}; "
1119
+ f"partitions={self.num_partitions} "
1120
+ f"variants={'unknown' if num_records is None else num_records}"
1121
+ )
1122
+ progress_config = core.ProgressConfig(
1123
+ total=num_records,
1124
+ units="vars",
1125
+ title="Explode",
1126
+ show=show_progress,
1127
+ )
1128
+ with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
1129
+ for j in range(self.num_partitions):
1130
+ pwm.submit(self.process_partition, j)
1131
+
1132
+ def explode_partition(self, partition):
1133
+ self.load_metadata()
1134
+ if partition < 0 or partition >= self.num_partitions:
1135
+ raise ValueError("Partition index not in the valid range")
1136
+ self.process_partition(partition)
1137
+
1138
+ def finalise(self):
1139
+ self.load_metadata()
1140
+ partition_summaries = self.load_partition_summaries()
1141
+ total_records = 0
1142
+ for index, summary in enumerate(partition_summaries):
1143
+ partition_records = summary.num_records
1144
+ self.metadata.partitions[index].num_records = partition_records
1145
+ self.metadata.partitions[index].region.end = summary.last_position
1146
+ total_records += partition_records
1147
+ if not np.isinf(self.metadata.num_records):
1148
+ # Note: this is just telling us that there's a bug in the
1149
+ # index based record counting code, but it doesn't actually
1150
+ # matter much. We may want to just make this a warning if
1151
+ # we hit regular problems.
1152
+ assert total_records == self.metadata.num_records
1153
+ self.metadata.num_records = total_records
1154
+
1155
+ check_overlapping_partitions(self.metadata.partitions)
1156
+
1157
+ for field in self.metadata.fields:
1158
+ for summary in partition_summaries:
1159
+ field.summary.update(summary.field_summaries[field.full_name])
1160
+
1161
+ logger.info("Finalising metadata")
1162
+ with open(self.path / "metadata.json", "w") as f:
1163
+ f.write(self.metadata.asjson())
1164
+
1165
+ logger.debug("Removing WIP directory")
1166
+ shutil.rmtree(self.wip_path)
1167
+
1168
+
1169
+ def explode(
1170
+ icf_path,
1171
+ vcfs,
1172
+ *,
1173
+ column_chunk_size=16,
1174
+ worker_processes=1,
1175
+ show_progress=False,
1176
+ compressor=None,
1177
+ ):
1178
+ writer = IntermediateColumnarFormatWriter(icf_path)
1179
+ writer.init(
1180
+ vcfs,
1181
+ # Heuristic to get reasonable worker utilisation with lumpy partition sizing
1182
+ target_num_partitions=max(1, worker_processes * 4),
1183
+ worker_processes=worker_processes,
1184
+ show_progress=show_progress,
1185
+ column_chunk_size=column_chunk_size,
1186
+ compressor=compressor,
1187
+ )
1188
+ writer.explode(worker_processes=worker_processes, show_progress=show_progress)
1189
+ writer.finalise()
1190
+ return IntermediateColumnarFormat(icf_path)
1191
+
1192
+
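A one-shot usage sketch of the explode() entry point above; the input VCFs and output path are hypothetical:

    icf = explode(
        "sample.icf",
        ["chr20.vcf.gz", "chr21.vcf.gz"],
        worker_processes=4,
        show_progress=True,
    )
    icf.num_records, icf.num_fields
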
1193
+ def explode_init(
1194
+ icf_path,
1195
+ vcfs,
1196
+ *,
1197
+ column_chunk_size=16,
1198
+ target_num_partitions=1,
1199
+ worker_processes=1,
1200
+ show_progress=False,
1201
+ compressor=None,
1202
+ ):
1203
+ writer = IntermediateColumnarFormatWriter(icf_path)
1204
+ return writer.init(
1205
+ vcfs,
1206
+ target_num_partitions=target_num_partitions,
1207
+ worker_processes=worker_processes,
1208
+ show_progress=show_progress,
1209
+ column_chunk_size=column_chunk_size,
1210
+ compressor=compressor,
1211
+ )
1212
+
1213
+
1214
+ def explode_partition(icf_path, partition):
1215
+ writer = IntermediateColumnarFormatWriter(icf_path)
1216
+ writer.explode_partition(partition)
1217
+
1218
+
1219
+ def explode_finalise(icf_path):
1220
+ writer = IntermediateColumnarFormatWriter(icf_path)
1221
+ writer.finalise()
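
For completeness, a sketch of the distributed variant using the three functions above: init once, explode each partition (typically fanned out across separate jobs), then finalise. The paths and partition target are hypothetical:

    summary = explode_init("sample.icf", ["chr20.vcf.gz"], target_num_partitions=8)
    for j in range(summary.num_partitions):
        explode_partition("sample.icf", j)
    explode_finalise("sample.icf")
    icf = IntermediateColumnarFormat("sample.icf")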