bio2zarr 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


bio2zarr/vcf2zarr/vcz.py CHANGED
@@ -1,3 +1,4 @@
+import contextlib
 import dataclasses
 import json
 import logging
@@ -12,6 +13,8 @@ import numcodecs
 import numpy as np
 import zarr
 
+from bio2zarr.zarr_utils import ZARR_FORMAT_KWARGS, zarr_v3
+
 from .. import constants, core, provenance
 from . import icf
 
@@ -20,18 +23,31 @@ logger = logging.getLogger(__name__)
 
 def inspect(path):
     path = pathlib.Path(path)
-    # TODO add support for the Zarr format also
+    if not path.exists():
+        raise ValueError(f"Path not found: {path}")
     if (path / "metadata.json").exists():
         obj = icf.IntermediateColumnarFormat(path)
+    # NOTE: this is too strict, we should support more general Zarrs, see #276
     elif (path / ".zmetadata").exists():
         obj = VcfZarr(path)
     else:
-        raise ValueError("Format not recognised")  # NEEDS TEST
+        raise ValueError(f"{path} not in ICF or VCF Zarr format")
     return obj.summary_table()
 
 
 DEFAULT_ZARR_COMPRESSOR = numcodecs.Blosc(cname="zstd", clevel=7)
 
+_fixed_field_descriptions = {
+    "variant_contig": "An identifier from the reference genome or an angle-bracketed ID"
+    " string pointing to a contig in the assembly file",
+    "variant_position": "The reference position",
+    "variant_length": "The length of the variant measured in bases",
+    "variant_id": "List of unique identifiers where applicable",
+    "variant_allele": "List of the reference and alternate alleles",
+    "variant_quality": "Phred-scaled quality score",
+    "variant_filter": "Filter status of the variant",
+}
+
 
 @dataclasses.dataclass
 class ZarrArraySpec:
@@ -46,6 +62,9 @@ class ZarrArraySpec:
     filters: list
 
     def __post_init__(self):
+        if self.name in _fixed_field_descriptions:
+            self.description = self.description or _fixed_field_descriptions[self.name]
+
         # Ensure these are tuples for ease of comparison and consistency
         self.shape = tuple(self.shape)
         self.chunks = tuple(self.chunks)
@@ -68,7 +87,7 @@ class ZarrArraySpec:
         num_samples,
         variants_chunk_size,
         samples_chunk_size,
-        variable_name=None,
+        array_name=None,
     ):
         shape = [num_variants]
         prefix = "variant_"
@@ -79,11 +98,12 @@ class ZarrArraySpec:
             shape.append(num_samples)
             chunks.append(samples_chunk_size)
             dimensions.append("samples")
-        if variable_name is None:
-            variable_name = prefix + vcf_field.name
+        if array_name is None:
+            array_name = prefix + vcf_field.name
         # TODO make an option to add in the empty extra dimension
-        if vcf_field.summary.max_number > 1:
+        if vcf_field.summary.max_number > 1 or vcf_field.full_name == "FORMAT/LAA":
             shape.append(vcf_field.summary.max_number)
+            chunks.append(vcf_field.summary.max_number)
         # TODO we should really be checking this to see if the named dimensions
         # are actually correct.
         if vcf_field.vcf_number == "R":
@@ -96,7 +116,7 @@ class ZarrArraySpec:
             dimensions.append(f"{vcf_field.category}_{vcf_field.name}_dim")
         return ZarrArraySpec.new(
             vcf_field=vcf_field.full_name,
-            name=variable_name,
+            name=array_name,
             dtype=vcf_field.smallest_dtype(),
             shape=shape,
             chunks=chunks,
@@ -164,6 +184,62 @@ class ZarrArraySpec:
 ZARR_SCHEMA_FORMAT_VERSION = "0.4"
 
 
+def convert_local_allele_field_types(fields):
+    """
+    Update the specified list of fields to include the LAA field, and to convert
+    any supported localisable fields to the L* counterpart.
+
+    Note that we currently support only two ALT alleles per sample, and so the
+    dimensions of these fields are fixed by that requirement. Later versions may
+    use summary data stored in the ICF to make different choices, if information
+    about subsequent alleles (not in the actual genotype calls) should also be
+    stored.
+    """
+    fields_by_name = {field.name: field for field in fields}
+    gt = fields_by_name["call_genotype"]
+    if gt.shape[-1] != 2:
+        raise ValueError("Local alleles only supported on diploid data")
+
+    # TODO check if LA is already in here
+
+    shape = gt.shape[:-1]
+    chunks = gt.chunks[:-1]
+    dimensions = gt.dimensions[:-1]
+
+    la = ZarrArraySpec.new(
+        vcf_field=None,
+        name="call_LA",
+        dtype="i1",
+        shape=gt.shape,
+        chunks=gt.chunks,
+        dimensions=(*dimensions, "local_alleles"),
+        description=(
+            "0-based indices into REF+ALT, indicating which alleles"
+            " are relevant (local) for the current sample"
+        ),
+    )
+    ad = fields_by_name.get("call_AD", None)
+    if ad is not None:
+        # TODO check if call_LAD is in the list already
+        ad.name = "call_LAD"
+        ad.vcf_field = None
+        ad.shape = (*shape, 2)
+        ad.chunks = (*chunks, 2)
+        ad.dimensions = (*dimensions, "local_alleles")
+        ad.description += " (local-alleles)"
+
+    pl = fields_by_name.get("call_PL", None)
+    if pl is not None:
+        # TODO check if call_LPL is in the list already
+        pl.name = "call_LPL"
+        pl.vcf_field = None
+        pl.shape = (*shape, 3)
+        pl.chunks = (*chunks, 3)
+        pl.description += " (local-alleles)"
+        pl.dimensions = (*dimensions, "local_" + pl.dimensions[-1])
+    return [*fields, la]
+
+
 @dataclasses.dataclass
 class VcfZarrSchema(core.JsonDataclass):
     format_version: str
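
The dimension bookkeeping in `convert_local_allele_field_types` is easier to follow with concrete numbers. A sketch of the intended effect, using illustrative sizes (the names and values below are not part of the library):

```python
# Diploid data with up to 4 alleles (1 REF + 3 ALT); all values illustrative.
num_variants, num_samples, num_alleles = 1000, 100, 4

ad_shape = (num_variants, num_samples, num_alleles)  # call_AD: one depth per allele
lad_shape = (num_variants, num_samples, 2)           # call_LAD: only 2 local alleles

# PL has one value per unordered diploid genotype: n * (n + 1) // 2 entries.
pl_entries = num_alleles * (num_alleles + 1) // 2    # 10
lpl_entries = 3                                      # (a, a), (a, b), (b, b) locally
```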
@@ -214,30 +290,38 @@ class VcfZarrSchema(core.JsonDataclass):
         return VcfZarrSchema.fromdict(json.loads(s))
 
     @staticmethod
-    def generate(icf, variants_chunk_size=None, samples_chunk_size=None):
+    def generate(
+        icf, variants_chunk_size=None, samples_chunk_size=None, local_alleles=None
+    ):
         m = icf.num_records
         n = icf.num_samples
-        # FIXME
         if samples_chunk_size is None:
-            samples_chunk_size = 1000
+            samples_chunk_size = 10_000
         if variants_chunk_size is None:
-            variants_chunk_size = 10_000
+            variants_chunk_size = 1000
+        if local_alleles is None:
+            local_alleles = False
         logger.info(
             f"Generating schema with chunks={variants_chunk_size, samples_chunk_size}"
         )
 
-        def spec_from_field(field, variable_name=None):
+        def spec_from_field(field, array_name=None):
             return ZarrArraySpec.from_field(
                 field,
                 num_samples=n,
                 num_variants=m,
                 samples_chunk_size=samples_chunk_size,
                 variants_chunk_size=variants_chunk_size,
-                variable_name=variable_name,
+                array_name=array_name,
             )
 
         def fixed_field_spec(
-            name, dtype, vcf_field=None, shape=(m,), dimensions=("variants",)
+            name,
+            dtype,
+            vcf_field=None,
+            shape=(m,),
+            dimensions=("variants",),
+            chunks=None,
         ):
             return ZarrArraySpec.new(
                 vcf_field=vcf_field,
@@ -246,13 +330,13 @@ class VcfZarrSchema(core.JsonDataclass):
                 shape=shape,
                 description="",
                 dimensions=dimensions,
-                chunks=[variants_chunk_size],
+                chunks=chunks or [variants_chunk_size],
             )
 
-        alt_col = icf.fields["ALT"]
-        max_alleles = alt_col.vcf_field.summary.max_number + 1
+        alt_field = icf.fields["ALT"]
+        max_alleles = alt_field.vcf_field.summary.max_number + 1
 
-        colspecs = [
+        array_specs = [
             fixed_field_spec(
                 name="variant_contig",
                 dtype=core.min_int_dtype(0, icf.metadata.num_contigs),
@@ -262,12 +346,14 @@ class VcfZarrSchema(core.JsonDataclass):
                 dtype="bool",
                 shape=(m, icf.metadata.num_filters),
                 dimensions=["variants", "filters"],
+                chunks=(variants_chunk_size, icf.metadata.num_filters),
             ),
             fixed_field_spec(
                 name="variant_allele",
                 dtype="O",
                 shape=(m, max_alleles),
                 dimensions=["variants", "alleles"],
+                chunks=(variants_chunk_size, max_alleles),
             ),
             fixed_field_spec(
                 name="variant_id",
@@ -280,28 +366,31 @@ class VcfZarrSchema(core.JsonDataclass):
         ]
         name_map = {field.full_name: field for field in icf.metadata.fields}
 
-        # Only two of the fixed fields have a direct one-to-one mapping.
-        colspecs.extend(
+        # Only three of the fixed fields have a direct one-to-one mapping.
+        array_specs.extend(
             [
-                spec_from_field(name_map["QUAL"], variable_name="variant_quality"),
-                spec_from_field(name_map["POS"], variable_name="variant_position"),
+                spec_from_field(name_map["QUAL"], array_name="variant_quality"),
+                spec_from_field(name_map["POS"], array_name="variant_position"),
+                spec_from_field(name_map["rlen"], array_name="variant_length"),
             ]
         )
-        colspecs.extend([spec_from_field(field) for field in icf.metadata.info_fields])
+        array_specs.extend(
+            [spec_from_field(field) for field in icf.metadata.info_fields]
+        )
 
         gt_field = None
         for field in icf.metadata.format_fields:
             if field.name == "GT":
                 gt_field = field
                 continue
-            colspecs.append(spec_from_field(field))
+            array_specs.append(spec_from_field(field))
 
-        if gt_field is not None:
-            ploidy = gt_field.summary.max_number - 1
+        if gt_field is not None and n > 0:
+            ploidy = max(gt_field.summary.max_number - 1, 1)
             shape = [m, n]
             chunks = [variants_chunk_size, samples_chunk_size]
             dimensions = ["variants", "samples"]
-            colspecs.append(
+            array_specs.append(
                 ZarrArraySpec.new(
                     vcf_field=None,
                     name="call_genotype_phased",
@@ -313,8 +402,9 @@ class VcfZarrSchema(core.JsonDataclass):
                 )
             )
             shape += [ploidy]
+            chunks += [ploidy]
             dimensions += ["ploidy"]
-            colspecs.append(
+            array_specs.append(
                 ZarrArraySpec.new(
                     vcf_field=None,
                     name="call_genotype",
@@ -325,7 +415,7 @@ class VcfZarrSchema(core.JsonDataclass):
                     description="",
                 )
             )
-            colspecs.append(
+            array_specs.append(
                 ZarrArraySpec.new(
                     vcf_field=None,
                     name="call_genotype_mask",
@@ -337,11 +427,14 @@ class VcfZarrSchema(core.JsonDataclass):
                 )
             )
 
+        if local_alleles:
+            array_specs = convert_local_allele_field_types(array_specs)
+
         return VcfZarrSchema(
             format_version=ZARR_SCHEMA_FORMAT_VERSION,
             samples_chunk_size=samples_chunk_size,
             variants_chunk_size=variants_chunk_size,
-            fields=colspecs,
+            fields=array_specs,
             samples=icf.metadata.samples,
             contigs=icf.metadata.contigs,
             filters=icf.metadata.filters,
@@ -434,6 +527,84 @@ class VcfZarrWriterMetadata(core.JsonDataclass):
         return ret
 
 
+def compute_la_field(genotypes):
+    """
+    Computes the value of the LA field for each sample given the genotypes
+    for a variant. The LA field lists the unique alleles observed for
+    each sample, including the REF.
+    """
+    v = 2**31 - 1
+    if np.any(genotypes >= v):
+        raise ValueError("Extreme allele value not supported")
+    G = genotypes.astype(np.int32)
+    if len(G) > 0:
+        # Anything < 0 gets mapped to -2 (pad) in the output, which comes last.
+        # So, to get this sorting correctly, we remap to the largest value for
+        # sorting, then map back. We promote the genotypes up to 32 bit for
+        # convenience here, assuming that we'll never have an allele of 2**31 - 1.
+        assert np.all(G != v)
+        G[G < 0] = v
+        G.sort(axis=1)
+        G[G[:, 0] == G[:, 1], 1] = -2
+        # Equal values result in padding also
+        G[G == v] = -2
+    return G.astype(genotypes.dtype)
+
+
+def compute_lad_field(ad, la):
+    assert ad.shape[0] == la.shape[0]
+    assert la.shape[1] == 2
+    lad = np.full((ad.shape[0], 2), -2, dtype=ad.dtype)
+    homs = np.where((la[:, 0] != -2) & (la[:, 1] == -2))
+    lad[homs, 0] = ad[homs, la[homs, 0]]
+    hets = np.where(la[:, 1] != -2)
+    lad[hets, 0] = ad[hets, la[hets, 0]]
+    lad[hets, 1] = ad[hets, la[hets, 1]]
+    return lad
+
+
+def pl_index(a, b):
+    """
+    Returns the PL index for alleles a and b.
+    """
+    return b * (b + 1) // 2 + a
+
+
+def compute_lpl_field(pl, la):
+    lpl = np.full((pl.shape[0], 3), -2, dtype=pl.dtype)
+
+    homs = np.where((la[:, 0] != -2) & (la[:, 1] == -2))
+    a = la[homs, 0]
+    lpl[homs, 0] = pl[homs, pl_index(a, a)]
+
+    hets = np.where(la[:, 1] != -2)[0]
+    a = la[hets, 0]
+    b = la[hets, 1]
+    lpl[hets, 0] = pl[hets, pl_index(a, a)]
+    lpl[hets, 1] = pl[hets, pl_index(a, b)]
+    lpl[hets, 2] = pl[hets, pl_index(b, b)]
+
+    return lpl
+
+
+@dataclasses.dataclass
+class LocalisableFieldDescriptor:
+    array_name: str
+    vcf_field: str
+    sanitise: callable
+    convert: callable
+
+
+localisable_fields = [
+    LocalisableFieldDescriptor(
+        "call_LAD", "FORMAT/AD", icf.sanitise_int_array, compute_lad_field
+    ),
+    LocalisableFieldDescriptor(
+        "call_LPL", "FORMAT/PL", icf.sanitise_int_array, compute_lpl_field
+    ),
+]
+
+
 @dataclasses.dataclass
 class VcfZarrWriteSummary(core.JsonDataclass):
     num_partitions: int
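
A worked example of these helpers' semantics, assuming the functions defined above are in scope (expected values checked by hand against the code):

```python
import numpy as np

genotypes = np.array(
    [
        [2, 0],   # het REF/ALT2   -> local alleles [0, 2]
        [1, 1],   # hom ALT1       -> [1, -2] (the duplicate is padded out)
        [0, -1],  # missing allele -> [0, -2]
    ],
    dtype=np.int8,
)
la = compute_la_field(genotypes)
assert la.tolist() == [[0, 2], [1, -2], [0, -2]]

# pl_index maps an unordered genotype (a, b), with a <= b, to its position
# in the standard VCF PL ordering.
assert [pl_index(0, 0), pl_index(0, 1), pl_index(1, 1)] == [0, 1, 2]
assert [pl_index(0, 2), pl_index(2, 2)] == [3, 5]

# A 3-allele PL vector has 6 entries; localising to alleles [0, 2] keeps
# the (0, 0), (0, 2) and (2, 2) entries.
pl = np.array([[10, 20, 30, 40, 50, 60]])
lpl = compute_lpl_field(pl, np.array([[0, 2]]))
assert lpl.tolist() == [[10, 40, 60]]
```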
@@ -466,6 +637,12 @@ class VcfZarrWriter:
             return True
         return False
 
+    def has_local_alleles(self):
+        for field in self.schema.fields:
+            if field.name == "call_LA" and field.vcf_field is None:
+                return True
+        return False
+
     #######################
     # init
     #######################
@@ -505,8 +682,7 @@ class VcfZarrWriter:
         )
 
         self.path.mkdir()
-        store = zarr.DirectoryStore(self.path)
-        root = zarr.group(store=store)
+        root = zarr.open(store=self.path, mode="a", **ZARR_FORMAT_KWARGS)
         root.attrs.update(
             {
                 "vcf_zarr_version": "0.2",
@@ -522,8 +698,7 @@ class VcfZarrWriter:
         self.wip_path.mkdir()
         self.arrays_path.mkdir()
         self.partitions_path.mkdir()
-        store = zarr.DirectoryStore(self.arrays_path)
-        root = zarr.group(store=store)
+        root = zarr.open(store=self.arrays_path, mode="a", **ZARR_FORMAT_KWARGS)
 
         total_chunks = 0
         for field in self.schema.fields:
@@ -547,7 +722,8 @@ class VcfZarrWriter:
             raise ValueError("Subsetting or reordering samples not supported currently")
         array = root.array(
             "sample_id",
-            [sample.id for sample in self.schema.samples],
+            data=[sample.id for sample in self.schema.samples],
+            shape=len(self.schema.samples),
             dtype="str",
             compressor=DEFAULT_ZARR_COMPRESSOR,
             chunks=(self.schema.samples_chunk_size,),
@@ -558,7 +734,8 @@ class VcfZarrWriter:
     def encode_contig_id(self, root):
         array = root.array(
             "contig_id",
-            [contig.id for contig in self.schema.contigs],
+            data=[contig.id for contig in self.schema.contigs],
+            shape=len(self.schema.contigs),
             dtype="str",
             compressor=DEFAULT_ZARR_COMPRESSOR,
         )
@@ -566,7 +743,8 @@ class VcfZarrWriter:
         if all(contig.length is not None for contig in self.schema.contigs):
             array = root.array(
                 "contig_length",
-                [contig.length for contig in self.schema.contigs],
+                data=[contig.length for contig in self.schema.contigs],
+                shape=len(self.schema.contigs),
                 dtype=np.int64,
                 compressor=DEFAULT_ZARR_COMPRESSOR,
             )
@@ -577,34 +755,42 @@ class VcfZarrWriter:
         # https://github.com/sgkit-dev/vcf-zarr-spec/issues/19
         array = root.array(
             "filter_id",
-            [filt.id for filt in self.schema.filters],
+            data=[filt.id for filt in self.schema.filters],
+            shape=len(self.schema.filters),
             dtype="str",
             compressor=DEFAULT_ZARR_COMPRESSOR,
         )
         array.attrs["_ARRAY_DIMENSIONS"] = ["filters"]
 
-    def init_array(self, root, variable, variants_dim_size):
-        object_codec = None
-        if variable.dtype == "O":
-            object_codec = numcodecs.VLenUTF8()
-        shape = list(variable.shape)
+    def init_array(self, root, array_spec, variants_dim_size):
+        kwargs = dict(ZARR_FORMAT_KWARGS)
+        filters = [numcodecs.get_codec(filt) for filt in array_spec.filters]
+        if array_spec.dtype == "O":
+            if zarr_v3():
+                filters = [*list(filters), numcodecs.VLenUTF8()]
+            else:
+                kwargs["object_codec"] = numcodecs.VLenUTF8()
+
+        if not zarr_v3():
+            kwargs["dimension_separator"] = self.metadata.dimension_separator
+
+        shape = list(array_spec.shape)
         # Truncate the variants dimension if max_variant_chunks was specified
         shape[0] = variants_dim_size
         a = root.empty(
-            variable.name,
+            name=array_spec.name,
             shape=shape,
-            chunks=variable.chunks,
-            dtype=variable.dtype,
-            compressor=numcodecs.get_codec(variable.compressor),
-            filters=[numcodecs.get_codec(filt) for filt in variable.filters],
-            object_codec=object_codec,
-            dimension_separator=self.metadata.dimension_separator,
+            chunks=array_spec.chunks,
+            dtype=array_spec.dtype,
+            compressor=numcodecs.get_codec(array_spec.compressor),
+            filters=filters,
+            **kwargs,
        )
         a.attrs.update(
             {
-                "description": variable.description,
+                "description": array_spec.description,
                 # Dimension names are part of the spec in Zarr v3
-                "_ARRAY_DIMENSIONS": variable.dimensions,
+                "_ARRAY_DIMENSIONS": array_spec.dimensions,
             }
         )
         logger.debug(f"Initialised {a}")
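
The object-dtype branch above is the main point of divergence between the two zarr major versions: `numcodecs.VLenUTF8` is passed as `object_codec` under the v2 API but rides along as an ordinary filter under v3. A minimal sketch of that branch in isolation (`using_v3` stands in for the `zarr_v3()` helper; this is not a bio2zarr API):

```python
import numcodecs

def codec_kwargs_for(dtype, filters, using_v3):
    # Mirrors the init_array branch above, in isolation.
    kwargs = {"filters": list(filters)}
    if dtype == "O":
        if using_v3:
            kwargs["filters"].append(numcodecs.VLenUTF8())
        else:
            kwargs["object_codec"] = numcodecs.VLenUTF8()
    return kwargs
```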
@@ -644,11 +830,15 @@ class VcfZarrWriter:
         self.encode_filters_partition(partition_index)
         self.encode_contig_partition(partition_index)
         self.encode_alleles_partition(partition_index)
-        for col in self.schema.fields:
-            if col.vcf_field is not None:
-                self.encode_array_partition(col, partition_index)
+        for array_spec in self.schema.fields:
+            if array_spec.vcf_field is not None:
+                self.encode_array_partition(array_spec, partition_index)
         if self.has_genotypes():
             self.encode_genotypes_partition(partition_index)
+            self.encode_genotype_mask_partition(partition_index)
+        if self.has_local_alleles():
+            self.encode_local_alleles_partition(partition_index)
+            self.encode_local_allele_fields_partition(partition_index)
 
         final_path = self.partition_path(partition_index)
         logger.info(f"Finalising {partition_index} at {final_path}")
@@ -658,95 +848,144 @@ class VcfZarrWriter:
         os.rename(partition_path, final_path)
 
     def init_partition_array(self, partition_index, name):
+        field_map = self.schema.field_map()
+        array_spec = field_map[name]
         # Create an empty array like the definition
-        src = self.arrays_path / name
+        src = self.arrays_path / array_spec.name
         # Overwrite any existing WIP files
-        wip_path = self.wip_partition_array_path(partition_index, name)
+        wip_path = self.wip_partition_array_path(partition_index, array_spec.name)
         shutil.copytree(src, wip_path, dirs_exist_ok=True)
-        store = zarr.DirectoryStore(self.wip_partition_path(partition_index))
-        wip_root = zarr.group(store=store)
-        array = wip_root[name]
-        logger.debug(f"Opened empty array {array.name} <{array.dtype}> @ {wip_path}")
-        return array
-
-    def finalise_partition_array(self, partition_index, name):
-        logger.debug(f"Encoded {name} partition {partition_index}")
+        array = zarr.open_array(store=wip_path, mode="a")
+        partition = self.metadata.partitions[partition_index]
+        ba = core.BufferedArray(array, partition.start, name)
+        logger.info(
+            f"Start partition {partition_index} array {name} <{array.dtype}> "
+            f"{array.shape} @ {wip_path}"
+        )
+        return ba
 
-    def encode_array_partition(self, column, partition_index):
-        array = self.init_partition_array(partition_index, column.name)
+    def finalise_partition_array(self, partition_index, buffered_array):
+        buffered_array.flush()
+        logger.info(
+            f"Completed partition {partition_index} array {buffered_array.name} "
+            f"max_memory={core.display_size(buffered_array.max_buff_size)}"
+        )
 
+    def encode_array_partition(self, array_spec, partition_index):
         partition = self.metadata.partitions[partition_index]
-        ba = core.BufferedArray(array, partition.start)
-        source_col = self.icf.fields[column.vcf_field]
-        sanitiser = source_col.sanitiser_factory(ba.buff.shape)
+        ba = self.init_partition_array(partition_index, array_spec.name)
+        source_field = self.icf.fields[array_spec.vcf_field]
+        sanitiser = source_field.sanitiser_factory(ba.buff.shape)
 
-        for value in source_col.iter_values(partition.start, partition.stop):
+        for value in source_field.iter_values(partition.start, partition.stop):
             # We write directly into the buffer in the sanitiser function
             # to make it easier to reason about dimension padding
             j = ba.next_buffer_row()
             sanitiser(ba.buff, j, value)
-        ba.flush()
-        self.finalise_partition_array(partition_index, column.name)
+        self.finalise_partition_array(partition_index, ba)
 
     def encode_genotypes_partition(self, partition_index):
-        gt_array = self.init_partition_array(partition_index, "call_genotype")
-        gt_mask_array = self.init_partition_array(partition_index, "call_genotype_mask")
-        gt_phased_array = self.init_partition_array(
-            partition_index, "call_genotype_phased"
-        )
-
         partition = self.metadata.partitions[partition_index]
-        gt = core.BufferedArray(gt_array, partition.start)
-        gt_mask = core.BufferedArray(gt_mask_array, partition.start)
-        gt_phased = core.BufferedArray(gt_phased_array, partition.start)
+        gt = self.init_partition_array(partition_index, "call_genotype")
+        gt_phased = self.init_partition_array(partition_index, "call_genotype_phased")
 
-        source_col = self.icf.fields["FORMAT/GT"]
-        for value in source_col.iter_values(partition.start, partition.stop):
+        source_field = self.icf.fields["FORMAT/GT"]
+        for value in source_field.iter_values(partition.start, partition.stop):
             j = gt.next_buffer_row()
-            icf.sanitise_value_int_2d(gt.buff, j, value[:, :-1])
+            icf.sanitise_value_int_2d(
+                gt.buff, j, value[:, :-1] if value is not None else None
+            )
             j = gt_phased.next_buffer_row()
-            icf.sanitise_value_int_1d(gt_phased.buff, j, value[:, -1])
+            icf.sanitise_value_int_1d(
+                gt_phased.buff, j, value[:, -1] if value is not None else None
+            )
+
+        self.finalise_partition_array(partition_index, gt)
+        self.finalise_partition_array(partition_index, gt_phased)
+
+    def encode_genotype_mask_partition(self, partition_index):
+        partition = self.metadata.partitions[partition_index]
+        gt_mask = self.init_partition_array(partition_index, "call_genotype_mask")
+        # Read back in the genotypes so we can compute the mask
+        gt_array = zarr.open_array(
+            store=self.wip_partition_array_path(partition_index, "call_genotype"),
+            mode="r",
+        )
+        for genotypes in core.first_dim_slice_iter(
+            gt_array, partition.start, partition.stop
+        ):
             # TODO check is this the correct semantics when we are padding
             # with mixed ploidies?
             j = gt_mask.next_buffer_row()
-            gt_mask.buff[j] = gt.buff[j] < 0
-        gt.flush()
-        gt_phased.flush()
-        gt_mask.flush()
+            gt_mask.buff[j] = genotypes < 0
+        self.finalise_partition_array(partition_index, gt_mask)
+
+    def encode_local_alleles_partition(self, partition_index):
+        partition = self.metadata.partitions[partition_index]
+        call_LA = self.init_partition_array(partition_index, "call_LA")
+
+        gt_array = zarr.open_array(
+            store=self.wip_partition_array_path(partition_index, "call_genotype"),
+            mode="r",
+        )
+        for genotypes in core.first_dim_slice_iter(
+            gt_array, partition.start, partition.stop
+        ):
+            la = compute_la_field(genotypes)
+            j = call_LA.next_buffer_row()
+            call_LA.buff[j] = la
+        self.finalise_partition_array(partition_index, call_LA)
 
-        self.finalise_partition_array(partition_index, "call_genotype")
-        self.finalise_partition_array(partition_index, "call_genotype_mask")
-        self.finalise_partition_array(partition_index, "call_genotype_phased")
+    def encode_local_allele_fields_partition(self, partition_index):
+        partition = self.metadata.partitions[partition_index]
+        la_array = zarr.open_array(
+            store=self.wip_partition_array_path(partition_index, "call_LA"),
+            mode="r",
+        )
+        # We go through the localisable fields one-by-one so that we don't need to
+        # keep several large arrays in memory at once for each partition.
+        field_map = self.schema.field_map()
+        for descriptor in localisable_fields:
+            if descriptor.array_name not in field_map:
+                continue
+            assert field_map[descriptor.array_name].vcf_field is None
+
+            buff = self.init_partition_array(partition_index, descriptor.array_name)
+            source = self.icf.fields[descriptor.vcf_field].iter_values(
+                partition.start, partition.stop
+            )
+            for la in core.first_dim_slice_iter(
+                la_array, partition.start, partition.stop
+            ):
+                raw_value = next(source)
+                value = descriptor.sanitise(raw_value, 2, raw_value.dtype)
+                j = buff.next_buffer_row()
+                buff.buff[j] = descriptor.convert(value, la)
+            self.finalise_partition_array(partition_index, buff)
 
     def encode_alleles_partition(self, partition_index):
-        array_name = "variant_allele"
-        alleles_array = self.init_partition_array(partition_index, array_name)
+        alleles = self.init_partition_array(partition_index, "variant_allele")
         partition = self.metadata.partitions[partition_index]
-        alleles = core.BufferedArray(alleles_array, partition.start)
-        ref_col = self.icf.fields["REF"]
-        alt_col = self.icf.fields["ALT"]
+        ref_field = self.icf.fields["REF"]
+        alt_field = self.icf.fields["ALT"]
 
         for ref, alt in zip(
-            ref_col.iter_values(partition.start, partition.stop),
-            alt_col.iter_values(partition.start, partition.stop),
+            ref_field.iter_values(partition.start, partition.stop),
+            alt_field.iter_values(partition.start, partition.stop),
        ):
             j = alleles.next_buffer_row()
             alleles.buff[j, :] = constants.STR_FILL
             alleles.buff[j, 0] = ref[0]
             alleles.buff[j, 1 : 1 + len(alt)] = alt
-        alleles.flush()
-
-        self.finalise_partition_array(partition_index, array_name)
+        self.finalise_partition_array(partition_index, alleles)
 
     def encode_id_partition(self, partition_index):
-        vid_array = self.init_partition_array(partition_index, "variant_id")
-        vid_mask_array = self.init_partition_array(partition_index, "variant_id_mask")
+        vid = self.init_partition_array(partition_index, "variant_id")
+        vid_mask = self.init_partition_array(partition_index, "variant_id_mask")
         partition = self.metadata.partitions[partition_index]
-        vid = core.BufferedArray(vid_array, partition.start)
-        vid_mask = core.BufferedArray(vid_mask_array, partition.start)
-        col = self.icf.fields["ID"]
+        field = self.icf.fields["ID"]
 
-        for value in col.iter_values(partition.start, partition.stop):
+        for value in field.iter_values(partition.start, partition.stop):
             j = vid.next_buffer_row()
             k = vid_mask.next_buffer_row()
             assert j == k
@@ -756,21 +995,17 @@ class VcfZarrWriter:
             else:
                 vid.buff[j] = constants.STR_MISSING
                 vid_mask.buff[j] = True
-        vid.flush()
-        vid_mask.flush()
 
-        self.finalise_partition_array(partition_index, "variant_id")
-        self.finalise_partition_array(partition_index, "variant_id_mask")
+        self.finalise_partition_array(partition_index, vid)
+        self.finalise_partition_array(partition_index, vid_mask)
 
     def encode_filters_partition(self, partition_index):
         lookup = {filt.id: index for index, filt in enumerate(self.schema.filters)}
-        array_name = "variant_filter"
-        array = self.init_partition_array(partition_index, array_name)
+        var_filter = self.init_partition_array(partition_index, "variant_filter")
         partition = self.metadata.partitions[partition_index]
-        var_filter = core.BufferedArray(array, partition.start)
 
-        col = self.icf.fields["FILTERS"]
-        for value in col.iter_values(partition.start, partition.stop):
+        field = self.icf.fields["FILTERS"]
+        for value in field.iter_values(partition.start, partition.stop):
             j = var_filter.next_buffer_row()
             var_filter.buff[j] = False
             for f in value:
@@ -780,28 +1015,24 @@ class VcfZarrWriter:
                     raise ValueError(
                         f"Filter '{f}' was not defined in the header."
                     ) from None
-        var_filter.flush()
 
-        self.finalise_partition_array(partition_index, array_name)
+        self.finalise_partition_array(partition_index, var_filter)
 
     def encode_contig_partition(self, partition_index):
         lookup = {contig.id: index for index, contig in enumerate(self.schema.contigs)}
-        array_name = "variant_contig"
-        array = self.init_partition_array(partition_index, array_name)
+        contig = self.init_partition_array(partition_index, "variant_contig")
         partition = self.metadata.partitions[partition_index]
-        contig = core.BufferedArray(array, partition.start)
-        col = self.icf.fields["CHROM"]
+        field = self.icf.fields["CHROM"]
 
-        for value in col.iter_values(partition.start, partition.stop):
+        for value in field.iter_values(partition.start, partition.stop):
             j = contig.next_buffer_row()
             # Note: because we are using the indexes to define the lookups
             # and we always have an index, it seems that the contig lookup
             # will always succeed. However, if anyone ever does hit a KeyError
             # here, please do open an issue with a reproducible example!
             contig.buff[j] = lookup[value[0]]
-        contig.flush()
 
-        self.finalise_partition_array(partition_index, array_name)
+        self.finalise_partition_array(partition_index, contig)
 
     #######################
     # finalise
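
All of the `encode_*_partition` methods above now converge on one buffered-write pattern: `init_partition_array` returns a `core.BufferedArray`, rows are staged via `next_buffer_row`, and `finalise_partition_array` flushes. A self-contained toy version of that pattern (the real `core.BufferedArray` API is inferred from the calls above; this mini class only illustrates the idea):

```python
import numpy as np
import zarr

class MiniBufferedArray:
    """Stage rows in memory, flushing to the zarr array in aligned blocks."""

    def __init__(self, array, offset, name, buffer_rows=100):
        self.array, self.offset, self.name = array, offset, name
        self.buff = np.zeros((buffer_rows, *array.shape[1:]), dtype=array.dtype)
        self.rows_used = 0

    def next_buffer_row(self):
        if self.rows_used == len(self.buff):
            self.flush()
        j = self.rows_used
        self.rows_used += 1
        return j

    def flush(self):
        n = self.rows_used
        self.array[self.offset : self.offset + n] = self.buff[:n]
        self.offset += n
        self.rows_used = 0

a = zarr.zeros(shape=(1000,), chunks=(100,), dtype="i4")
ba = MiniBufferedArray(a, offset=0, name="demo")
for value in range(250):
    j = ba.next_buffer_row()
    ba.buff[j] = value
ba.flush()  # the final flush is what finalise_partition_array() performs
```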
@@ -871,6 +1102,68 @@ class VcfZarrWriter:
         logger.info("Consolidating Zarr metadata")
         zarr.consolidate_metadata(self.path)
 
+    #######################
+    # index
+    #######################
+
+    def create_index(self):
+        """Create an index to support efficient region queries."""
+
+        root = zarr.open_group(store=self.path, mode="r+")
+
+        contig = root["variant_contig"]
+        pos = root["variant_position"]
+        length = root["variant_length"]
+
+        assert contig.cdata_shape == pos.cdata_shape
+
+        index = []
+
+        logger.info("Creating region index")
+        for v_chunk in range(pos.cdata_shape[0]):
+            c = contig.blocks[v_chunk]
+            p = pos.blocks[v_chunk]
+            e = p + length.blocks[v_chunk] - 1
+
+            # create a row for each contig in the chunk
+            d = np.diff(c, append=-1)
+            c_start_idx = 0
+            for c_end_idx in np.nonzero(d)[0]:
+                assert c[c_start_idx] == c[c_end_idx]
+                index.append(
+                    (
+                        v_chunk,  # chunk index
+                        c[c_start_idx],  # contig ID
+                        p[c_start_idx],  # start
+                        p[c_end_idx],  # end
+                        np.max(e[c_start_idx : c_end_idx + 1]),  # max end
+                        c_end_idx - c_start_idx + 1,  # num records
+                    )
+                )
+                c_start_idx = c_end_idx + 1
+
+        index = np.array(index, dtype=pos.dtype)
+        kwargs = {}
+        if not zarr_v3():
+            kwargs["dimension_separator"] = self.metadata.dimension_separator
+        array = root.array(
+            "region_index",
+            data=index,
+            shape=index.shape,
+            chunks=index.shape,
+            dtype=index.dtype,
+            compressor=numcodecs.Blosc("zstd", clevel=9, shuffle=0),
+            fill_value=None,
+            **kwargs,
+        )
+        array.attrs["_ARRAY_DIMENSIONS"] = [
+            "region_index_values",
+            "region_index_fields",
+        ]
+
+        logger.info("Consolidating Zarr metadata")
+        zarr.consolidate_metadata(self.path)
+
     ######################
     # encode_all_partitions
     ######################
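
A hypothetical usage sketch for the new region index: each row stores (chunk index, contig ID, start, end, max end, num records), so the variant chunks overlapping a query region can be found without touching the genotype data (the store path below is a placeholder):

```python
import zarr

root = zarr.open_group("out.vcz", mode="r")
region_index = root["region_index"][:]

def overlapping_chunks(contig_id, start, end):
    rows = region_index[region_index[:, 1] == contig_id]
    # A chunk overlaps unless it begins after the query ends, or every
    # variant in it ends (per the "max end" column) before the query starts.
    hits = rows[(rows[:, 2] <= end) & (rows[:, 4] >= start)]
    return hits[:, 0]

print(overlapping_chunks(contig_id=0, start=100_000, end=200_000))
```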
@@ -880,8 +1173,8 @@ class VcfZarrWriter:
         Return the approximate maximum memory used to encode a variant chunk.
         """
         max_encoding_mem = 0
-        for col in self.schema.fields:
-            max_encoding_mem = max(max_encoding_mem, col.variant_chunk_nbytes)
+        for array_spec in self.schema.fields:
+            max_encoding_mem = max(max_encoding_mem, array_spec.variant_chunk_nbytes)
         gt_mem = 0
         if self.has_genotypes:
             gt_mem = sum(
@@ -921,9 +1214,9 @@ class VcfZarrWriter:
             num_workers = min(max_num_workers, worker_processes)
 
         total_bytes = 0
-        for col in self.schema.fields:
+        for array_spec in self.schema.fields:
             # Open the array definition to get the total size
-            total_bytes += zarr.open(self.arrays_path / col.name).nbytes
+            total_bytes += zarr.open(self.arrays_path / array_spec.name).nbytes
 
         progress_config = core.ProgressConfig(
             total=total_bytes,
@@ -936,9 +1229,21 @@ class VcfZarrWriter:
             pwm.submit(self.encode_partition, partition_index)
 
 
-def mkschema(if_path, out):
+def mkschema(
+    if_path,
+    out,
+    *,
+    variants_chunk_size=None,
+    samples_chunk_size=None,
+    local_alleles=None,
+):
     store = icf.IntermediateColumnarFormat(if_path)
-    spec = VcfZarrSchema.generate(store)
+    spec = VcfZarrSchema.generate(
+        store,
+        variants_chunk_size=variants_chunk_size,
+        samples_chunk_size=samples_chunk_size,
+        local_alleles=local_alleles,
+    )
     out.write(spec.asjson())
 
 
@@ -951,6 +1256,7 @@ def encode(
     max_variant_chunks=None,
     dimension_separator=None,
     max_memory=None,
+    local_alleles=None,
     worker_processes=1,
     show_progress=False,
 ):
@@ -963,6 +1269,7 @@ def encode(
         schema_path=schema_path,
         variants_chunk_size=variants_chunk_size,
         samples_chunk_size=samples_chunk_size,
+        local_alleles=local_alleles,
         max_variant_chunks=max_variant_chunks,
         dimension_separator=dimension_separator,
     )
@@ -973,6 +1280,7 @@ def encode(
         max_memory=max_memory,
     )
     vzw.finalise(show_progress)
+    vzw.create_index()
 
 
 def encode_init(
@@ -983,6 +1291,7 @@ def encode_init(
     schema_path=None,
     variants_chunk_size=None,
     samples_chunk_size=None,
+    local_alleles=None,
     max_variant_chunks=None,
     dimension_separator=None,
     max_memory=None,
@@ -995,6 +1304,7 @@ def encode_init(
             icf_store,
             variants_chunk_size=variants_chunk_size,
             samples_chunk_size=samples_chunk_size,
+            local_alleles=local_alleles,
         )
     else:
         logger.info(f"Reading schema from {schema_path}")
@@ -1032,22 +1342,34 @@ def convert(
1032
1342
  variants_chunk_size=None,
1033
1343
  samples_chunk_size=None,
1034
1344
  worker_processes=1,
1345
+ local_alleles=None,
1035
1346
  show_progress=False,
1036
- # TODO add arguments to control location of tmpdir
1347
+ icf_path=None,
1037
1348
  ):
1038
- with tempfile.TemporaryDirectory(prefix="vcf2zarr") as tmp:
1039
- if_dir = pathlib.Path(tmp) / "icf"
1349
+ if icf_path is None:
1350
+ cm = temp_icf_path(prefix="vcf2zarr")
1351
+ else:
1352
+ cm = contextlib.nullcontext(icf_path)
1353
+
1354
+ with cm as icf_path:
1040
1355
  icf.explode(
1041
- if_dir,
1356
+ icf_path,
1042
1357
  vcfs,
1043
1358
  worker_processes=worker_processes,
1044
1359
  show_progress=show_progress,
1045
1360
  )
1046
1361
  encode(
1047
- if_dir,
1362
+ icf_path,
1048
1363
  out_path,
1049
1364
  variants_chunk_size=variants_chunk_size,
1050
1365
  samples_chunk_size=samples_chunk_size,
1051
1366
  worker_processes=worker_processes,
1052
1367
  show_progress=show_progress,
1368
+ local_alleles=local_alleles,
1053
1369
  )
1370
+
1371
+
1372
+ @contextlib.contextmanager
1373
+ def temp_icf_path(prefix=None):
1374
+ with tempfile.TemporaryDirectory(prefix=prefix) as tmp:
1375
+ yield pathlib.Path(tmp) / "icf"