bio2zarr 0.0.6__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,1220 @@
1
+ import collections
2
+ import contextlib
3
+ import dataclasses
4
+ import json
5
+ import logging
6
+ import math
7
+ import pathlib
8
+ import pickle
9
+ import shutil
10
+ import sys
11
+ from typing import Any
12
+
13
+ import numcodecs
14
+ import numpy as np
15
+
16
+ from .. import constants, core, provenance, vcf_utils
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ @dataclasses.dataclass
22
+ class VcfFieldSummary(core.JsonDataclass):
23
+ num_chunks: int = 0
24
+ compressed_size: int = 0
25
+ uncompressed_size: int = 0
26
+ max_number: int = 0 # Corresponds to VCF Number field, depends on context
27
+ # Only defined for numeric fields
28
+ max_value: Any = -math.inf
29
+ min_value: Any = math.inf
30
+
31
+ def update(self, other):
32
+ self.num_chunks += other.num_chunks
33
+ self.compressed_size += other.compressed_size
34
+ self.uncompressed_size += other.uncompressed_size
35
+ self.max_number = max(self.max_number, other.max_number)
36
+ self.min_value = min(self.min_value, other.min_value)
37
+ self.max_value = max(self.max_value, other.max_value)
38
+
39
+ @staticmethod
40
+ def fromdict(d):
41
+ return VcfFieldSummary(**d)
42
+
43
+
44
+ @dataclasses.dataclass
45
+ class VcfField:
46
+ category: str
47
+ name: str
48
+ vcf_number: str
49
+ vcf_type: str
50
+ description: str
51
+ summary: VcfFieldSummary
52
+
53
+ @staticmethod
54
+ def from_header(definition):
55
+ category = definition["HeaderType"]
56
+ name = definition["ID"]
57
+ vcf_number = definition["Number"]
58
+ vcf_type = definition["Type"]
59
+ return VcfField(
60
+ category=category,
61
+ name=name,
62
+ vcf_number=vcf_number,
63
+ vcf_type=vcf_type,
64
+ description=definition["Description"].strip('"'),
65
+ summary=VcfFieldSummary(),
66
+ )
67
+
68
+ @staticmethod
69
+ def fromdict(d):
70
+ f = VcfField(**d)
71
+ f.summary = VcfFieldSummary(**d["summary"])
72
+ return f
73
+
74
+ @property
75
+ def full_name(self):
76
+ if self.category == "fixed":
77
+ return self.name
78
+ return f"{self.category}/{self.name}"
79
+
80
+ def smallest_dtype(self):
81
+ """
82
+ Returns the smallest dtype suitable for this field based
83
+ on its type and the values observed.
84
+ """
85
+ s = self.summary
86
+ if self.vcf_type == "Float":
87
+ ret = "f4"
88
+ elif self.vcf_type == "Integer":
89
+ if not math.isfinite(s.max_value):
90
+ # All missing values; use i1. Note we should have some API to
91
+ # check more explicitly for missingness:
92
+ # https://github.com/sgkit-dev/bio2zarr/issues/131
93
+ ret = "i1"
94
+ else:
95
+ ret = core.min_int_dtype(s.min_value, s.max_value)
96
+ elif self.vcf_type == "Flag":
97
+ ret = "bool"
98
+ elif self.vcf_type == "Character":
99
+ ret = "U1"
100
+ else:
101
+ assert self.vcf_type == "String"
102
+ ret = "O"
103
+ return ret
104
+
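# Illustrative sketch only (core.min_int_dtype is defined elsewhere in bio2zarr):
# choosing the narrowest signed integer dtype for an observed [min, max] range
# might look roughly like this, assuming the sentinel values also fit:
#
#     def smallest_int_dtype(min_value, max_value):
#         for dtype in ("i1", "i2", "i4"):
#             info = np.iinfo(dtype)
#             if info.min <= min_value and max_value <= info.max:
#                 return dtype
#         return "i8"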
105
+
106
+ @dataclasses.dataclass
107
+ class VcfPartition:
108
+ vcf_path: str
109
+ region: str
110
+ num_records: int = -1
111
+
112
+
113
+ ICF_METADATA_FORMAT_VERSION = "0.3"
114
+ ICF_DEFAULT_COMPRESSOR = numcodecs.Blosc(
115
+ cname="zstd", clevel=7, shuffle=numcodecs.Blosc.NOSHUFFLE
116
+ )
117
+
118
+
119
+ @dataclasses.dataclass
120
+ class Contig:
121
+ id: str
122
+ length: int = None
123
+
124
+
125
+ @dataclasses.dataclass
126
+ class Sample:
127
+ id: str
128
+
129
+
130
+ @dataclasses.dataclass
131
+ class Filter:
132
+ id: str
133
+ description: str = ""
134
+
135
+
136
+ @dataclasses.dataclass
137
+ class IcfMetadata(core.JsonDataclass):
138
+ samples: list
139
+ contigs: list
140
+ filters: list
141
+ fields: list
142
+ partitions: list = None
143
+ format_version: str = None
144
+ compressor: dict = None
145
+ column_chunk_size: int = None
146
+ provenance: dict = None
147
+ num_records: int = -1
148
+
149
+ @property
150
+ def info_fields(self):
151
+ fields = []
152
+ for field in self.fields:
153
+ if field.category == "INFO":
154
+ fields.append(field)
155
+ return fields
156
+
157
+ @property
158
+ def format_fields(self):
159
+ fields = []
160
+ for field in self.fields:
161
+ if field.category == "FORMAT":
162
+ fields.append(field)
163
+ return fields
164
+
165
+ @property
166
+ def num_contigs(self):
167
+ return len(self.contigs)
168
+
169
+ @property
170
+ def num_filters(self):
171
+ return len(self.filters)
172
+
173
+ @property
174
+ def num_samples(self):
175
+ return len(self.samples)
176
+
177
+ @staticmethod
178
+ def fromdict(d):
179
+ if d["format_version"] != ICF_METADATA_FORMAT_VERSION:
180
+ raise ValueError(
181
+ "Intermediate columnar metadata format version mismatch: "
182
+ f"{d['format_version']} != {ICF_METADATA_FORMAT_VERSION}"
183
+ )
184
+ partitions = [VcfPartition(**pd) for pd in d["partitions"]]
185
+ for p in partitions:
186
+ p.region = vcf_utils.Region(**p.region)
187
+ d = d.copy()
188
+ d["partitions"] = partitions
189
+ d["fields"] = [VcfField.fromdict(fd) for fd in d["fields"]]
190
+ d["samples"] = [Sample(**sd) for sd in d["samples"]]
191
+ d["filters"] = [Filter(**fd) for fd in d["filters"]]
192
+ d["contigs"] = [Contig(**cd) for cd in d["contigs"]]
193
+ return IcfMetadata(**d)
194
+
195
+
196
+ def fixed_vcf_field_definitions():
197
+ def make_field_def(name, vcf_type, vcf_number):
198
+ return VcfField(
199
+ category="fixed",
200
+ name=name,
201
+ vcf_type=vcf_type,
202
+ vcf_number=vcf_number,
203
+ description="",
204
+ summary=VcfFieldSummary(),
205
+ )
206
+
207
+ fields = [
208
+ make_field_def("CHROM", "String", "1"),
209
+ make_field_def("POS", "Integer", "1"),
210
+ make_field_def("QUAL", "Float", "1"),
211
+ make_field_def("ID", "String", "."),
212
+ make_field_def("FILTERS", "String", "."),
213
+ make_field_def("REF", "String", "1"),
214
+ make_field_def("ALT", "String", "."),
215
+ ]
216
+ return fields
217
+
218
+
219
+ def scan_vcf(path, target_num_partitions):
220
+ with vcf_utils.IndexedVcf(path) as indexed_vcf:
221
+ vcf = indexed_vcf.vcf
222
+ filters = []
223
+ pass_index = -1
224
+ for h in vcf.header_iter():
225
+ if h["HeaderType"] == "FILTER" and isinstance(h["ID"], str):
226
+ try:
227
+ description = h["Description"].strip('"')
228
+ except KeyError:
229
+ description = ""
230
+ if h["ID"] == "PASS":
231
+ pass_index = len(filters)
232
+ filters.append(Filter(h["ID"], description))
233
+
234
+ # Ensure PASS is the first filter if present
235
+ if pass_index > 0:
236
+ pass_filter = filters.pop(pass_index)
237
+ filters.insert(0, pass_filter)
238
+
239
+ fields = fixed_vcf_field_definitions()
240
+ for h in vcf.header_iter():
241
+ if h["HeaderType"] in ["INFO", "FORMAT"]:
242
+ field = VcfField.from_header(h)
243
+ if field.name == "GT":
244
+ field.vcf_type = "Integer"
245
+ field.vcf_number = "."
246
+ fields.append(field)
247
+
248
+ try:
249
+ contig_lengths = vcf.seqlens
250
+ except AttributeError:
251
+ contig_lengths = [None for _ in vcf.seqnames]
252
+
253
+ metadata = IcfMetadata(
254
+ samples=[Sample(sample_id) for sample_id in vcf.samples],
255
+ contigs=[
256
+ Contig(contig_id, length)
257
+ for contig_id, length in zip(vcf.seqnames, contig_lengths)
258
+ ],
259
+ filters=filters,
260
+ fields=fields,
261
+ partitions=[],
262
+ num_records=sum(indexed_vcf.contig_record_counts().values()),
263
+ )
264
+
265
+ regions = indexed_vcf.partition_into_regions(num_parts=target_num_partitions)
266
+ logger.info(
267
+ f"Split {path} into {len(regions)} regions (target={target_num_partitions})"
268
+ )
269
+ for region in regions:
270
+ metadata.partitions.append(
271
+ VcfPartition(
272
+ # TODO should this be fully resolving the path? Otherwise it's all
273
+ # relative to the original WD
274
+ vcf_path=str(path),
275
+ region=region,
276
+ )
277
+ )
278
+ core.update_progress(1)
279
+ return metadata, vcf.raw_header
280
+
281
+
282
+ def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
283
+ logger.info(
284
+ f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}"
285
+ f" partitions."
286
+ )
287
+ # An easy mistake to make is to pass the same file twice. Check this early on.
288
+ for path, count in collections.Counter(paths).items():
289
+ if not path.exists(): # NEEDS TEST
290
+ raise FileNotFoundError(path)
291
+ if count > 1:
292
+ raise ValueError(f"Duplicate path provided: {path}")
293
+
294
+ progress_config = core.ProgressConfig(
295
+ total=len(paths),
296
+ units="files",
297
+ title="Scan",
298
+ show=show_progress,
299
+ )
300
+ with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
301
+ for path in paths:
302
+ pwm.submit(scan_vcf, path, max(1, target_num_partitions // len(paths)))
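# e.g. with target_num_partitions=100 and 4 input VCFs, each file is asked to
# split itself into roughly 25 regions; the number actually returned by the
# index may differ.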
303
+ results = list(pwm.results_as_completed())
304
+
305
+ # Sort to make the ordering deterministic
306
+ results.sort(key=lambda t: t[0].partitions[0].vcf_path)
307
+ # We just take the first header, assuming the others
308
+ # are compatible.
309
+ all_partitions = []
310
+ total_records = 0
311
+ for metadata, _ in results:
312
+ for partition in metadata.partitions:
313
+ logger.debug(f"Scanned partition {partition}")
314
+ all_partitions.append(partition)
315
+ total_records += metadata.num_records
316
+ metadata.num_records = 0
317
+ metadata.partitions = []
318
+
319
+ icf_metadata, header = results[0]
320
+ for metadata, _ in results[1:]:
321
+ if metadata != icf_metadata:
322
+ raise ValueError("Incompatible VCF chunks")
323
+
324
+ # Note: this will be infinity here if any of the chunks has an index
325
+ # that doesn't keep track of the number of records per-contig
326
+ icf_metadata.num_records = total_records
327
+
328
+ # Sort by contig (in the order they appear in the header) first,
329
+ # then by start coordinate
330
+ contig_index_map = {contig.id: j for j, contig in enumerate(metadata.contigs)}
331
+ all_partitions.sort(
332
+ key=lambda x: (contig_index_map[x.region.contig], x.region.start)
333
+ )
334
+ icf_metadata.partitions = all_partitions
335
+ logger.info(f"Scan complete, resulting in {len(all_partitions)} partitions.")
336
+ return icf_metadata, header
337
+
338
+
339
+ def sanitise_value_bool(buff, j, value):
340
+ x = True
341
+ if value is None:
342
+ x = False
343
+ buff[j] = x
344
+
345
+
346
+ def sanitise_value_float_scalar(buff, j, value):
347
+ x = value
348
+ if value is None:
349
+ x = [constants.FLOAT32_MISSING]
350
+ buff[j] = x[0]
351
+
352
+
353
+ def sanitise_value_int_scalar(buff, j, value):
354
+ x = value
355
+ if value is None:
356
+ # print("MISSING", INT_MISSING, INT_FILL)
357
+ x = [constants.INT_MISSING]
358
+ else:
359
+ x = sanitise_int_array(value, ndmin=1, dtype=np.int32)
360
+ buff[j] = x[0]
361
+
362
+
363
+ def sanitise_value_string_scalar(buff, j, value):
364
+ if value is None:
365
+ buff[j] = "."
366
+ else:
367
+ buff[j] = value[0]
368
+
369
+
370
+ def sanitise_value_string_1d(buff, j, value):
371
+ if value is None:
372
+ buff[j] = "."
373
+ else:
374
+ # value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False)
375
+ # FIXME failure isn't coming from here, it seems to be from an
376
+ # incorrectly detected dimension in the zarr array
377
+ # The dimensions look all wrong, and the dtype should be Object
378
+ # not str
379
+ value = drop_empty_second_dim(value)
380
+ buff[j] = ""
381
+ buff[j, : value.shape[0]] = value
382
+
383
+
384
+ def sanitise_value_string_2d(buff, j, value):
385
+ if value is None:
386
+ buff[j] = "."
387
+ else:
388
+ # print(buff.shape, value.dtype, value)
389
+ # assert value.ndim == 2
390
+ buff[j] = ""
391
+ if value.ndim == 2:
392
+ buff[j, :, : value.shape[1]] = value
393
+ else:
394
+ # TODO check if this is still necessary
395
+ for k, val in enumerate(value):
396
+ buff[j, k, : len(val)] = val
397
+
398
+
399
+ def drop_empty_second_dim(value):
400
+ assert len(value.shape) == 1 or value.shape[1] == 1
401
+ if len(value.shape) == 2 and value.shape[1] == 1:
402
+ value = value[..., 0]
403
+ return value
404
+
405
+
406
+ def sanitise_value_float_1d(buff, j, value):
407
+ if value is None:
408
+ buff[j] = constants.FLOAT32_MISSING
409
+ else:
410
+ value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False)
411
+ # numpy will map None values to NaN, but we need a
412
+ # specific NaN
413
+ value[np.isnan(value)] = constants.FLOAT32_MISSING
414
+ value = drop_empty_second_dim(value)
415
+ buff[j] = constants.FLOAT32_FILL
416
+ buff[j, : value.shape[0]] = value
417
+
418
+
419
+ def sanitise_value_float_2d(buff, j, value):
420
+ if value is None:
421
+ buff[j] = constants.FLOAT32_MISSING
422
+ else:
423
+ # print("value = ", value)
424
+ value = np.array(value, ndmin=2, dtype=buff.dtype, copy=False)
425
+ buff[j] = constants.FLOAT32_FILL
426
+ buff[j, :, : value.shape[1]] = value
427
+
428
+
429
+ def sanitise_int_array(value, ndmin, dtype):
430
+ if isinstance(value, tuple):
431
+ value = [
432
+ constants.VCF_INT_MISSING if x is None else x for x in value
433
+ ] # NEEDS TEST
434
+ value = np.array(value, ndmin=ndmin, copy=False)
435
+ value[value == constants.VCF_INT_MISSING] = -1
436
+ value[value == constants.VCF_INT_FILL] = -2
437
+ # TODO watch out for clipping here!
438
+ return value.astype(dtype)
439
+
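# In the sanitised output, -1 encodes a missing value and -2 encodes fill/padding;
# the sanitise_value_int_* helpers in this module rely on the same convention.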
440
+
441
+ def sanitise_value_int_1d(buff, j, value):
442
+ if value is None:
443
+ buff[j] = -1
444
+ else:
445
+ value = sanitise_int_array(value, 1, buff.dtype)
446
+ value = drop_empty_second_dim(value)
447
+ buff[j] = -2
448
+ buff[j, : value.shape[0]] = value
449
+
450
+
451
+ def sanitise_value_int_2d(buff, j, value):
452
+ if value is None:
453
+ buff[j] = -1
454
+ else:
455
+ value = sanitise_int_array(value, 2, buff.dtype)
456
+ buff[j] = -2
457
+ buff[j, :, : value.shape[1]] = value
458
+
459
+
460
+ missing_value_map = {
461
+ "Integer": constants.INT_MISSING,
462
+ "Float": constants.FLOAT32_MISSING,
463
+ "String": constants.STR_MISSING,
464
+ "Character": constants.STR_MISSING,
465
+ "Flag": False,
466
+ }
467
+
468
+
469
+ class VcfValueTransformer:
470
+ """
471
+ Transform VCF values into the stored intermediate format used
472
+ in the IntermediateColumnarFormat, and update field summaries.
473
+ """
474
+
475
+ def __init__(self, field, num_samples):
476
+ self.field = field
477
+ self.num_samples = num_samples
478
+ self.dimension = 1
479
+ if field.category == "FORMAT":
480
+ self.dimension = 2
481
+ self.missing = missing_value_map[field.vcf_type]
482
+
483
+ @staticmethod
484
+ def factory(field, num_samples):
485
+ if field.vcf_type in ("Integer", "Flag"):
486
+ return IntegerValueTransformer(field, num_samples)
487
+ if field.vcf_type == "Float":
488
+ return FloatValueTransformer(field, num_samples)
489
+ if field.name in ["REF", "FILTERS", "ALT", "ID", "CHROM"]:
490
+ return SplitStringValueTransformer(field, num_samples)
491
+ return StringValueTransformer(field, num_samples)
492
+
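# For example, an INFO field declared as Integer is handled by
# IntegerValueTransformer: a cyvcf2 value of (10, None, 12) becomes
# np.array([10, INT_MISSING, 12]) in transform(), and update_bounds() then
# records the observed min/max and the largest Number seen.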
493
+ def transform(self, vcf_value):
494
+ if isinstance(vcf_value, tuple):
495
+ vcf_value = [self.missing if v is None else v for v in vcf_value]
496
+ value = np.array(vcf_value, ndmin=self.dimension, copy=False)
497
+ return value
498
+
499
+ def transform_and_update_bounds(self, vcf_value):
500
+ if vcf_value is None:
501
+ return None
502
+ value = self.transform(vcf_value)
503
+ self.update_bounds(value)
504
+ # print(self.field.full_name, "T", vcf_value, "->", value)
505
+ return value
506
+
507
+
508
+ class IntegerValueTransformer(VcfValueTransformer):
509
+ def update_bounds(self, value):
510
+ summary = self.field.summary
511
+ # Mask out missing and fill values
512
+ # print(value)
513
+ a = value[value >= constants.MIN_INT_VALUE]
514
+ if a.size > 0:
515
+ summary.max_value = int(max(summary.max_value, np.max(a)))
516
+ summary.min_value = int(min(summary.min_value, np.min(a)))
517
+ number = value.shape[-1]
518
+ summary.max_number = max(summary.max_number, number)
519
+
520
+
521
+ class FloatValueTransformer(VcfValueTransformer):
522
+ def update_bounds(self, value):
523
+ summary = self.field.summary
524
+ summary.max_value = float(max(summary.max_value, np.max(value)))
525
+ summary.min_value = float(min(summary.min_value, np.min(value)))
526
+ number = value.shape[-1]
527
+ summary.max_number = max(summary.max_number, number)
528
+
529
+
530
+ class StringValueTransformer(VcfValueTransformer):
531
+ def update_bounds(self, value):
532
+ summary = self.field.summary
533
+ number = value.shape[-1]
534
+ # TODO would be nice to report string lengths, but not
535
+ # really necessary.
536
+ summary.max_number = max(summary.max_number, number)
537
+
538
+ def transform(self, vcf_value):
539
+ # print("transform", vcf_value)
540
+ if self.dimension == 1:
541
+ value = np.array(list(vcf_value.split(",")))
542
+ else:
543
+ # TODO can we make this faster??
544
+ value = np.array([v.split(",") for v in vcf_value], dtype="O")
545
+ # print("HERE", vcf_value, value)
546
+ # for v in vcf_value:
547
+ # print("\t", type(v), len(v), v.split(","))
548
+ # print("S: ", self.dimension, ":", value.shape, value)
549
+ return value
550
+
551
+
552
+ class SplitStringValueTransformer(StringValueTransformer):
553
+ def transform(self, vcf_value):
554
+ if vcf_value is None:
555
+ return self.missing  # NEEDS TEST
556
+ assert self.dimension == 1
557
+ return np.array(vcf_value, ndmin=1, dtype="str")
558
+
559
+
560
+ def get_vcf_field_path(base_path, vcf_field):
561
+ if vcf_field.category == "fixed":
562
+ return base_path / vcf_field.name
563
+ return base_path / vcf_field.category / vcf_field.name
564
+
565
+
566
+ class IntermediateColumnarFormatField:
567
+ def __init__(self, icf, vcf_field):
568
+ self.vcf_field = vcf_field
569
+ self.path = get_vcf_field_path(icf.path, vcf_field)
570
+ self.compressor = icf.compressor
571
+ self.num_partitions = icf.num_partitions
572
+ self.num_records = icf.num_records
573
+ self.partition_record_index = icf.partition_record_index
574
+ # A map of partition id to the cumulative number of records
575
+ # in chunks within that partition
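# e.g. a chunk index of [0, 1000, 1800, 2000] means three chunks of 1000, 800
# and 200 records, stored in files named "1000", "1800" and "2000" (see chunks()).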
576
+ self._chunk_record_index = {}
577
+
578
+ @property
579
+ def name(self):
580
+ return self.vcf_field.full_name
581
+
582
+ def partition_path(self, partition_id):
583
+ return self.path / f"p{partition_id}"
584
+
585
+ def __repr__(self):
586
+ partition_chunks = [self.num_chunks(j) for j in range(self.num_partitions)]
587
+ return (
588
+ f"IntermediateColumnarFormatField(name={self.name}, "
589
+ f"partition_chunks={partition_chunks}, "
590
+ f"path={self.path})"
591
+ )
592
+
593
+ def num_chunks(self, partition_id):
594
+ return len(self.chunk_record_index(partition_id)) - 1
595
+
596
+ def chunk_record_index(self, partition_id):
597
+ if partition_id not in self._chunk_record_index:
598
+ index_path = self.partition_path(partition_id) / "chunk_index"
599
+ with open(index_path, "rb") as f:
600
+ a = pickle.load(f)
601
+ assert len(a) > 1
602
+ assert a[0] == 0
603
+ self._chunk_record_index[partition_id] = a
604
+ return self._chunk_record_index[partition_id]
605
+
606
+ def read_chunk(self, path):
607
+ with open(path, "rb") as f:
608
+ pkl = self.compressor.decode(f.read())
609
+ return pickle.loads(pkl)
610
+
611
+ def chunk_num_records(self, partition_id):
612
+ return np.diff(self.chunk_record_index(partition_id))
613
+
614
+ def chunks(self, partition_id, start_chunk=0):
615
+ partition_path = self.partition_path(partition_id)
616
+ chunk_cumulative_records = self.chunk_record_index(partition_id)
617
+ chunk_num_records = np.diff(chunk_cumulative_records)
618
+ for count, cumulative in zip(
619
+ chunk_num_records[start_chunk:], chunk_cumulative_records[start_chunk + 1 :]
620
+ ):
621
+ path = partition_path / f"{cumulative}"
622
+ chunk = self.read_chunk(path)
623
+ if len(chunk) != count:
624
+ raise ValueError(f"Corruption detected in chunk: {path}")
625
+ yield chunk
626
+
627
+ def iter_values(self, start=None, stop=None):
628
+ start = 0 if start is None else start
629
+ stop = self.num_records if stop is None else stop
630
+ start_partition = (
631
+ np.searchsorted(self.partition_record_index, start, side="right") - 1
632
+ )
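# e.g. with partition_record_index = [0, 100, 350, 400] and start = 120 this
# gives start_partition = 1 and offset = 100, so chunk_offset below is 20.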
633
+ offset = self.partition_record_index[start_partition]
634
+ assert offset <= start
635
+ chunk_offset = start - offset
636
+
637
+ chunk_record_index = self.chunk_record_index(start_partition)
638
+ start_chunk = (
639
+ np.searchsorted(chunk_record_index, chunk_offset, side="right") - 1
640
+ )
641
+ record_id = offset + chunk_record_index[start_chunk]
642
+ assert record_id <= start
643
+ logger.debug(
644
+ f"Read {self.vcf_field.full_name} slice [{start}:{stop}]:"
645
+ f"p_start={start_partition}, c_start={start_chunk}, r_start={record_id}"
646
+ )
647
+ for chunk in self.chunks(start_partition, start_chunk):
648
+ for record in chunk:
649
+ if record_id == stop:
650
+ return
651
+ if record_id >= start:
652
+ yield record
653
+ record_id += 1
654
+ assert record_id > start
655
+ for partition_id in range(start_partition + 1, self.num_partitions):
656
+ for chunk in self.chunks(partition_id):
657
+ for record in chunk:
658
+ if record_id == stop:
659
+ return
660
+ yield record
661
+ record_id += 1
662
+
663
+ # Note: this involves some computation so should arguably be a method,
664
+ # but we make it a property for consistency with xarray etc.
665
+ @property
666
+ def values(self):
667
+ ret = [None] * self.num_records
668
+ j = 0
669
+ for partition_id in range(self.num_partitions):
670
+ for chunk in self.chunks(partition_id):
671
+ for record in chunk:
672
+ ret[j] = record
673
+ j += 1
674
+ assert j == self.num_records
675
+ return ret
676
+
677
+ def sanitiser_factory(self, shape):
678
+ """
679
+ Return a function that sanitises values from this column
680
+ and writes them into a buffer of the specified shape.
681
+ """
682
+ assert len(shape) <= 3
683
+ if self.vcf_field.vcf_type == "Flag":
684
+ assert len(shape) == 1
685
+ return sanitise_value_bool
686
+ elif self.vcf_field.vcf_type == "Float":
687
+ if len(shape) == 1:
688
+ return sanitise_value_float_scalar
689
+ elif len(shape) == 2:
690
+ return sanitise_value_float_1d
691
+ else:
692
+ return sanitise_value_float_2d
693
+ elif self.vcf_field.vcf_type == "Integer":
694
+ if len(shape) == 1:
695
+ return sanitise_value_int_scalar
696
+ elif len(shape) == 2:
697
+ return sanitise_value_int_1d
698
+ else:
699
+ return sanitise_value_int_2d
700
+ else:
701
+ assert self.vcf_field.vcf_type in ("String", "Character")
702
+ if len(shape) == 1:
703
+ return sanitise_value_string_scalar
704
+ elif len(shape) == 2:
705
+ return sanitise_value_string_1d
706
+ else:
707
+ return sanitise_value_string_2d
708
+
709
+
710
+ @dataclasses.dataclass
711
+ class IcfFieldWriter:
712
+ vcf_field: VcfField
713
+ path: pathlib.Path
714
+ transformer: VcfValueTransformer
715
+ compressor: Any
716
+ max_buffered_bytes: int
717
+ buff: list[Any] = dataclasses.field(default_factory=list)
718
+ buffered_bytes: int = 0
719
+ chunk_index: list[int] = dataclasses.field(default_factory=lambda: [0])
720
+ num_records: int = 0
721
+
722
+ def append(self, val):
723
+ val = self.transformer.transform_and_update_bounds(val)
724
+ assert val is None or isinstance(val, np.ndarray)
725
+ self.buff.append(val)
726
+ val_bytes = sys.getsizeof(val)
727
+ self.buffered_bytes += val_bytes
728
+ self.num_records += 1
729
+ if self.buffered_bytes >= self.max_buffered_bytes:
730
+ logger.debug(
731
+ f"Flush {self.path} buffered={self.buffered_bytes} "
732
+ f"max={self.max_buffered_bytes}"
733
+ )
734
+ self.write_chunk()
735
+ self.buff.clear()
736
+ self.buffered_bytes = 0
737
+
738
+ def write_chunk(self):
739
+ # Update index
740
+ self.chunk_index.append(self.num_records)
741
+ path = self.path / f"{self.num_records}"
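# Chunk files are named by the cumulative record count at flush time, so the
# file names line up with the entries appended to chunk_index above.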
742
+ logger.debug(f"Start write: {path}")
743
+ pkl = pickle.dumps(self.buff)
744
+ compressed = self.compressor.encode(pkl)
745
+ with open(path, "wb") as f:
746
+ f.write(compressed)
747
+
748
+ # Update the summary
749
+ self.vcf_field.summary.num_chunks += 1
750
+ self.vcf_field.summary.compressed_size += len(compressed)
751
+ self.vcf_field.summary.uncompressed_size += self.buffered_bytes
752
+ logger.debug(f"Finish write: {path}")
753
+
754
+ def flush(self):
755
+ logger.debug(
756
+ f"Flush {self.path} records={len(self.buff)} buffered={self.buffered_bytes}"
757
+ )
758
+ if len(self.buff) > 0:
759
+ self.write_chunk()
760
+ with open(self.path / "chunk_index", "wb") as f:
761
+ a = np.array(self.chunk_index, dtype=int)
762
+ pickle.dump(a, f)
763
+
764
+
765
+ class IcfPartitionWriter(contextlib.AbstractContextManager):
766
+ """
767
+ Writes the data for an IntermediateColumnarFormat partition.
768
+ """
769
+
770
+ def __init__(
771
+ self,
772
+ icf_metadata,
773
+ out_path,
774
+ partition_index,
775
+ ):
776
+ self.partition_index = partition_index
777
+ # column_chunk_size is given in mebibytes (MiB)
778
+ max_buffered_bytes = icf_metadata.column_chunk_size * 2**20
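# e.g. the default column_chunk_size of 16 gives a 16 MiB (16_777_216 byte)
# per-field buffer before a chunk is flushed to disk.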
779
+ assert max_buffered_bytes > 0
780
+ compressor = numcodecs.get_codec(icf_metadata.compressor)
781
+
782
+ self.field_writers = {}
783
+ num_samples = len(icf_metadata.samples)
784
+ for vcf_field in icf_metadata.fields:
785
+ field_path = get_vcf_field_path(out_path, vcf_field)
786
+ field_partition_path = field_path / f"p{partition_index}"
787
+ # Should be robust to running explode_partition twice.
788
+ field_partition_path.mkdir(exist_ok=True)
789
+ transformer = VcfValueTransformer.factory(vcf_field, num_samples)
790
+ self.field_writers[vcf_field.full_name] = IcfFieldWriter(
791
+ vcf_field,
792
+ field_partition_path,
793
+ transformer,
794
+ compressor,
795
+ max_buffered_bytes,
796
+ )
797
+
798
+ @property
799
+ def field_summaries(self):
800
+ return {
801
+ name: field.vcf_field.summary for name, field in self.field_writers.items()
802
+ }
803
+
804
+ def append(self, name, value):
805
+ self.field_writers[name].append(value)
806
+
807
+ def __exit__(self, exc_type, exc_val, exc_tb):
808
+ if exc_type is None:
809
+ for field in self.field_writers.values():
810
+ field.flush()
811
+ return False
812
+
813
+
814
+ class IntermediateColumnarFormat(collections.abc.Mapping):
815
+ def __init__(self, path):
816
+ self.path = pathlib.Path(path)
817
+ # TODO raise a more informative error here telling people this
818
+ # directory is either a WIP or the wrong format.
819
+ with open(self.path / "metadata.json") as f:
820
+ self.metadata = IcfMetadata.fromdict(json.load(f))
821
+ with open(self.path / "header.txt") as f:
822
+ self.vcf_header = f.read()
823
+ self.compressor = numcodecs.get_codec(self.metadata.compressor)
824
+ self.fields = {}
825
+ partition_num_records = [
826
+ partition.num_records for partition in self.metadata.partitions
827
+ ]
828
+ # Allow us to find which partition a given record is in
829
+ self.partition_record_index = np.cumsum([0, *partition_num_records])
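# e.g. partitions holding [100, 250, 50] records give
# partition_record_index = [0, 100, 350, 400] (cumulative totals, starting at 0).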
830
+ for field in self.metadata.fields:
831
+ self.fields[field.full_name] = IntermediateColumnarFormatField(self, field)
832
+ logger.info(
833
+ f"Loaded IntermediateColumnarFormat(partitions={self.num_partitions}, "
834
+ f"records={self.num_records}, fields={self.num_fields})"
835
+ )
836
+
837
+ def __repr__(self):
838
+ return (
839
+ f"IntermediateColumnarFormat(fields={len(self)}, "
840
+ f"partitions={self.num_partitions}, "
841
+ f"records={self.num_records}, path={self.path})"
842
+ )
843
+
844
+ def __getitem__(self, key):
845
+ return self.fields[key]
846
+
847
+ def __iter__(self):
848
+ return iter(self.fields)
849
+
850
+ def __len__(self):
851
+ return len(self.fields)
852
+
853
+ def summary_table(self):
854
+ data = []
855
+ for name, col in self.fields.items():
856
+ summary = col.vcf_field.summary
857
+ d = {
858
+ "name": name,
859
+ "type": col.vcf_field.vcf_type,
860
+ "chunks": summary.num_chunks,
861
+ "size": core.display_size(summary.uncompressed_size),
862
+ "compressed": core.display_size(summary.compressed_size),
863
+ "max_n": summary.max_number,
864
+ "min_val": core.display_number(summary.min_value),
865
+ "max_val": core.display_number(summary.max_value),
866
+ }
867
+
868
+ data.append(d)
869
+ return data
870
+
871
+ @property
872
+ def num_records(self):
873
+ return self.metadata.num_records
874
+
875
+ @property
876
+ def num_partitions(self):
877
+ return len(self.metadata.partitions)
878
+
879
+ @property
880
+ def num_samples(self):
881
+ return len(self.metadata.samples)
882
+
883
+ @property
884
+ def num_fields(self):
885
+ return len(self.fields)
886
+
887
+
888
+ @dataclasses.dataclass
889
+ class IcfPartitionMetadata(core.JsonDataclass):
890
+ num_records: int
891
+ last_position: int
892
+ field_summaries: dict
893
+
894
+ @staticmethod
895
+ def fromdict(d):
896
+ md = IcfPartitionMetadata(**d)
897
+ for k, v in md.field_summaries.items():
898
+ md.field_summaries[k] = VcfFieldSummary.fromdict(v)
899
+ return md
900
+
901
+
902
+ def check_overlapping_partitions(partitions):
903
+ for i in range(1, len(partitions)):
904
+ prev_region = partitions[i - 1].region
905
+ current_region = partitions[i].region
906
+ if prev_region.contig == current_region.contig:
907
+ assert prev_region.end is not None
908
+ # Regions are *inclusive*
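# e.g. a partition ending at chr1:2000 overlaps one starting at chr1:2000,
# since both would claim the record at position 2000.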
909
+ if prev_region.end >= current_region.start:
910
+ raise ValueError(
911
+ f"Overlapping VCF regions in partitions {i - 1} and {i}: "
912
+ f"{prev_region} and {current_region}"
913
+ )
914
+
915
+
916
+ def check_field_clobbering(icf_metadata):
917
+ info_field_names = set(field.name for field in icf_metadata.info_fields)
918
+ fixed_variant_fields = set(
919
+ ["contig", "id", "id_mask", "position", "allele", "filter", "quality"]
920
+ )
921
+ intersection = info_field_names & fixed_variant_fields
922
+ if len(intersection) > 0:
923
+ raise ValueError(
924
+ f"INFO field name(s) clashing with VCF Zarr spec: {intersection}"
925
+ )
926
+
927
+ format_field_names = set(field.name for field in icf_metadata.format_fields)
928
+ fixed_variant_fields = set(["genotype", "genotype_phased", "genotype_mask"])
929
+ intersection = format_field_names & fixed_variant_fields
930
+ if len(intersection) > 0:
931
+ raise ValueError(
932
+ f"FORMAT field name(s) clashing with VCF Zarr spec: {intersection}"
933
+ )
934
+
935
+
936
+ @dataclasses.dataclass
937
+ class IcfWriteSummary(core.JsonDataclass):
938
+ num_partitions: int
939
+ num_samples: int
940
+ num_variants: int
941
+
942
+
943
+ class IntermediateColumnarFormatWriter:
944
+ def __init__(self, path):
945
+ self.path = pathlib.Path(path)
946
+ self.wip_path = self.path / "wip"
947
+ self.metadata = None
948
+
949
+ @property
950
+ def num_partitions(self):
951
+ return len(self.metadata.partitions)
952
+
953
+ def init(
954
+ self,
955
+ vcfs,
956
+ *,
957
+ column_chunk_size=16,
958
+ worker_processes=1,
959
+ target_num_partitions=None,
960
+ show_progress=False,
961
+ compressor=None,
962
+ ):
963
+ if self.path.exists():
964
+ raise ValueError("ICF path already exists")
965
+ if compressor is None:
966
+ compressor = ICF_DEFAULT_COMPRESSOR
967
+ vcfs = [pathlib.Path(vcf) for vcf in vcfs]
968
+ target_num_partitions = max(target_num_partitions, len(vcfs))
969
+
970
+ # TODO move scan_vcfs into this class
971
+ icf_metadata, header = scan_vcfs(
972
+ vcfs,
973
+ worker_processes=worker_processes,
974
+ show_progress=show_progress,
975
+ target_num_partitions=target_num_partitions,
976
+ )
977
+ check_field_clobbering(icf_metadata)
978
+ self.metadata = icf_metadata
979
+ self.metadata.format_version = ICF_METADATA_FORMAT_VERSION
980
+ self.metadata.compressor = compressor.get_config()
981
+ self.metadata.column_chunk_size = column_chunk_size
982
+ # Bare minimum here for provenance - would be nice to include versions of key
983
+ # dependencies as well.
984
+ self.metadata.provenance = {"source": f"bio2zarr-{provenance.__version__}"}
985
+
986
+ self.mkdirs()
987
+
988
+ # Note: this is needed for the current version of the vcfzarr spec, but it's
989
+ # probably going to be dropped.
990
+ # https://github.com/pystatgen/vcf-zarr-spec/issues/15
991
+ # May be useful to keep lying around still though?
992
+ logger.info("Writing VCF header")
993
+ with open(self.path / "header.txt", "w") as f:
994
+ f.write(header)
995
+
996
+ logger.info("Writing WIP metadata")
997
+ with open(self.wip_path / "metadata.json", "w") as f:
998
+ json.dump(self.metadata.asdict(), f, indent=4)
999
+ return IcfWriteSummary(
1000
+ num_partitions=self.num_partitions,
1001
+ num_variants=icf_metadata.num_records,
1002
+ num_samples=icf_metadata.num_samples,
1003
+ )
1004
+
1005
+ def mkdirs(self):
1006
+ num_dirs = len(self.metadata.fields)
1007
+ logger.info(f"Creating {num_dirs} field directories")
1008
+ self.path.mkdir()
1009
+ self.wip_path.mkdir()
1010
+ for field in self.metadata.fields:
1011
+ col_path = get_vcf_field_path(self.path, field)
1012
+ col_path.mkdir(parents=True)
1013
+
1014
+ def load_partition_summaries(self):
1015
+ summaries = []
1016
+ not_found = []
1017
+ for j in range(self.num_partitions):
1018
+ try:
1019
+ with open(self.wip_path / f"p{j}.json") as f:
1020
+ summaries.append(IcfPartitionMetadata.fromdict(json.load(f)))
1021
+ except FileNotFoundError:
1022
+ not_found.append(j)
1023
+ if len(not_found) > 0:
1024
+ raise FileNotFoundError(
1025
+ f"Partition metadata not found for {len(not_found)}"
1026
+ f" partitions: {not_found}"
1027
+ )
1028
+ return summaries
1029
+
1030
+ def load_metadata(self):
1031
+ if self.metadata is None:
1032
+ with open(self.wip_path / "metadata.json") as f:
1033
+ self.metadata = IcfMetadata.fromdict(json.load(f))
1034
+
1035
+ def process_partition(self, partition_index):
1036
+ self.load_metadata()
1037
+ summary_path = self.wip_path / f"p{partition_index}.json"
1038
+ # If someone is rewriting a summary path (for whatever reason), make sure it
1039
+ # doesn't look like it's already been completed.
1040
+ # NOTE to do this properly we probably need to take a lock on this file - but
1041
+ # this simple approach will catch the vast majority of problems.
1042
+ if summary_path.exists():
1043
+ summary_path.unlink()
1044
+
1045
+ partition = self.metadata.partitions[partition_index]
1046
+ logger.info(
1047
+ f"Start p{partition_index} {partition.vcf_path}__{partition.region}"
1048
+ )
1049
+ info_fields = self.metadata.info_fields
1050
+ format_fields = []
1051
+ has_gt = False
1052
+ for field in self.metadata.format_fields:
1053
+ if field.name == "GT":
1054
+ has_gt = True
1055
+ else:
1056
+ format_fields.append(field)
1057
+
1058
+ last_position = None
1059
+ with IcfPartitionWriter(
1060
+ self.metadata,
1061
+ self.path,
1062
+ partition_index,
1063
+ ) as tcw:
1064
+ with vcf_utils.IndexedVcf(partition.vcf_path) as ivcf:
1065
+ num_records = 0
1066
+ for variant in ivcf.variants(partition.region):
1067
+ num_records += 1
1068
+ last_position = variant.POS
1069
+ tcw.append("CHROM", variant.CHROM)
1070
+ tcw.append("POS", variant.POS)
1071
+ tcw.append("QUAL", variant.QUAL)
1072
+ tcw.append("ID", variant.ID)
1073
+ tcw.append("FILTERS", variant.FILTERS)
1074
+ tcw.append("REF", variant.REF)
1075
+ tcw.append("ALT", variant.ALT)
1076
+ for field in info_fields:
1077
+ tcw.append(field.full_name, variant.INFO.get(field.name, None))
1078
+ if has_gt:
1079
+ tcw.append("FORMAT/GT", variant.genotype.array())
1080
+ for field in format_fields:
1081
+ val = variant.format(field.name)
1082
+ tcw.append(field.full_name, val)
1083
+ # Note: an issue with updating the progress per variant here like
1084
+ # this is that we get a significant pause at the end of the counter
1085
+ # while all the "small" fields get flushed. Possibly not much to be
1086
+ # done about it.
1087
+ core.update_progress(1)
1088
+ logger.info(
1089
+ f"Finished reading VCF for partition {partition_index}, "
1090
+ f"flushing buffers"
1091
+ )
1092
+
1093
+ partition_metadata = IcfPartitionMetadata(
1094
+ num_records=num_records,
1095
+ last_position=last_position,
1096
+ field_summaries=tcw.field_summaries,
1097
+ )
1098
+ with open(summary_path, "w") as f:
1099
+ f.write(partition_metadata.asjson())
1100
+ logger.info(
1101
+ f"Finish p{partition_index} {partition.vcf_path}__{partition.region} "
1102
+ f"{num_records} records last_pos={last_position}"
1103
+ )
1104
+
1105
+ def explode(self, *, worker_processes=1, show_progress=False):
1106
+ self.load_metadata()
1107
+ num_records = self.metadata.num_records
1108
+ if np.isinf(num_records):
1109
+ logger.warning(
1110
+ "Total records unknown, cannot show progress; "
1111
+ "reindex VCFs with bcftools index to fix"
1112
+ )
1113
+ num_records = None
1114
+ num_fields = len(self.metadata.fields)
1115
+ num_samples = len(self.metadata.samples)
1116
+ logger.info(
1117
+ f"Exploding fields={num_fields} samples={num_samples}; "
1118
+ f"partitions={self.num_partitions} "
1119
+ f"variants={'unknown' if num_records is None else num_records}"
1120
+ )
1121
+ progress_config = core.ProgressConfig(
1122
+ total=num_records,
1123
+ units="vars",
1124
+ title="Explode",
1125
+ show=show_progress,
1126
+ )
1127
+ with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
1128
+ for j in range(self.num_partitions):
1129
+ pwm.submit(self.process_partition, j)
1130
+
1131
+ def explode_partition(self, partition):
1132
+ self.load_metadata()
1133
+ if partition < 0 or partition >= self.num_partitions:
1134
+ raise ValueError("Partition index not in the valid range")
1135
+ self.process_partition(partition)
1136
+
1137
+ def finalise(self):
1138
+ self.load_metadata()
1139
+ partition_summaries = self.load_partition_summaries()
1140
+ total_records = 0
1141
+ for index, summary in enumerate(partition_summaries):
1142
+ partition_records = summary.num_records
1143
+ self.metadata.partitions[index].num_records = partition_records
1144
+ self.metadata.partitions[index].region.end = summary.last_position
1145
+ total_records += partition_records
1146
+ if not np.isinf(self.metadata.num_records):
1147
+ # Note: this is just telling us that there's a bug in the
1148
+ # index based record counting code, but it doesn't actually
1149
+ # matter much. We may want to just make this a warning if
1150
+ # we hit regular problems.
1151
+ assert total_records == self.metadata.num_records
1152
+ self.metadata.num_records = total_records
1153
+
1154
+ check_overlapping_partitions(self.metadata.partitions)
1155
+
1156
+ for field in self.metadata.fields:
1157
+ for summary in partition_summaries:
1158
+ field.summary.update(summary.field_summaries[field.full_name])
1159
+
1160
+ logger.info("Finalising metadata")
1161
+ with open(self.path / "metadata.json", "w") as f:
1162
+ f.write(self.metadata.asjson())
1163
+
1164
+ logger.debug("Removing WIP directory")
1165
+ shutil.rmtree(self.wip_path)
1166
+
1167
+
1168
+ def explode(
1169
+ icf_path,
1170
+ vcfs,
1171
+ *,
1172
+ column_chunk_size=16,
1173
+ worker_processes=1,
1174
+ show_progress=False,
1175
+ compressor=None,
1176
+ ):
1177
+ writer = IntermediateColumnarFormatWriter(icf_path)
1178
+ writer.init(
1179
+ vcfs,
1180
+ # Heuristic to get reasonable worker utilisation with lumpy partition sizing
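# (e.g. worker_processes=8 gives a target of 32 partitions)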
1181
+ target_num_partitions=max(1, worker_processes * 4),
1182
+ worker_processes=worker_processes,
1183
+ show_progress=show_progress,
1184
+ column_chunk_size=column_chunk_size,
1185
+ compressor=compressor,
1186
+ )
1187
+ writer.explode(worker_processes=worker_processes, show_progress=show_progress)
1188
+ writer.finalise()
1189
+ return IntermediateColumnarFormat(icf_path)
1190
+
1191
+
1192
+ def explode_init(
1193
+ icf_path,
1194
+ vcfs,
1195
+ *,
1196
+ column_chunk_size=16,
1197
+ target_num_partitions=1,
1198
+ worker_processes=1,
1199
+ show_progress=False,
1200
+ compressor=None,
1201
+ ):
1202
+ writer = IntermediateColumnarFormatWriter(icf_path)
1203
+ return writer.init(
1204
+ vcfs,
1205
+ target_num_partitions=target_num_partitions,
1206
+ worker_processes=worker_processes,
1207
+ show_progress=show_progress,
1208
+ column_chunk_size=column_chunk_size,
1209
+ compressor=compressor,
1210
+ )
1211
+
1212
+
1213
+ def explode_partition(icf_path, partition):
1214
+ writer = IntermediateColumnarFormatWriter(icf_path)
1215
+ writer.explode_partition(partition)
1216
+
1217
+
1218
+ def explode_finalise(icf_path):
1219
+ writer = IntermediateColumnarFormatWriter(icf_path)
1220
+ writer.finalise()
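# A minimal usage sketch (paths and partition counts below are illustrative only).
# explode() runs the whole conversion in one process pool, while the
# explode_init / explode_partition / explode_finalise trio lets each partition
# be processed independently, e.g. on separate cluster nodes:
#
#     summary = explode_init("out.icf", ["sample.vcf.gz"], target_num_partitions=8)
#     for j in range(summary.num_partitions):
#         explode_partition("out.icf", j)  # independent units of work
#     explode_finalise("out.icf")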