bio2zarr 0.0.6__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bio2zarr/vcf.py CHANGED
@@ -1,17 +1,17 @@
  import collections
  import contextlib
  import dataclasses
- import functools
  import json
  import logging
  import math
  import os
+ import os.path
  import pathlib
  import pickle
  import shutil
  import sys
  import tempfile
- from typing import Any, List
+ from typing import Any

  import cyvcf2
  import humanfriendly
@@ -144,29 +144,41 @@ class VcfPartition:
      num_records: int = -1


- ICF_METADATA_FORMAT_VERSION = "0.2"
+ ICF_METADATA_FORMAT_VERSION = "0.3"
  ICF_DEFAULT_COMPRESSOR = numcodecs.Blosc(
      cname="zstd", clevel=7, shuffle=numcodecs.Blosc.NOSHUFFLE
  )

- # TODO refactor this to have embedded Contig dataclass, Filters
- # and Samples dataclasses to allow for more information to be
- # retained and forward compatibility.
+
+ @dataclasses.dataclass
+ class Contig:
+     id: str
+     length: int = None
+
+
+ @dataclasses.dataclass
+ class Sample:
+     id: str
+
+
+ @dataclasses.dataclass
+ class Filter:
+     id: str
+     description: str = ""


  @dataclasses.dataclass
  class IcfMetadata:
      samples: list
-     contig_names: list
-     contig_record_counts: dict
+     contigs: list
      filters: list
      fields: list
      partitions: list = None
-     contig_lengths: list = None
      format_version: str = None
      compressor: dict = None
      column_chunk_size: int = None
      provenance: dict = None
+     num_records: int = -1

      @property
      def info_fields(self):
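
The hunk above replaces the flat contig_names/contig_lengths/filter lists with small dataclasses. As a minimal illustrative sketch (not taken from the package; the contig name and length below are made up), this is how such dataclasses flatten to plain dicts via dataclasses.asdict, which is what IcfMetadata.asdict()/asjson() and fromdict() rely on:

    import dataclasses
    import json


    @dataclasses.dataclass
    class Contig:
        id: str
        length: int = None


    @dataclasses.dataclass
    class Filter:
        id: str
        description: str = ""


    # Nested dataclasses serialise to plain dicts/lists, so json.dumps gives
    # the same layout that the new metadata format version stores on disk.
    metadata = {
        "contigs": [dataclasses.asdict(Contig("chr1", 248956422))],
        "filters": [dataclasses.asdict(Filter("PASS", "All filters passed"))],
    }
    print(json.dumps(metadata, indent=4))

    # Reading back mirrors the pattern used in IcfMetadata.fromdict.
    contigs = [Contig(**cd) for cd in metadata["contigs"]]
    assert contigs[0].length == 248956422
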
@@ -186,16 +198,12 @@ class IcfMetadata:

      @property
      def num_contigs(self):
-         return len(self.contig_names)
+         return len(self.contigs)

      @property
      def num_filters(self):
          return len(self.filters)

-     @property
-     def num_records(self):
-         return sum(self.contig_record_counts.values())
-
      @staticmethod
      def fromdict(d):
          if d["format_version"] != ICF_METADATA_FORMAT_VERSION:
@@ -203,18 +211,23 @@ class IcfMetadata:
                  "Intermediate columnar metadata format version mismatch: "
                  f"{d['format_version']} != {ICF_METADATA_FORMAT_VERSION}"
              )
-         fields = [VcfField.fromdict(fd) for fd in d["fields"]]
          partitions = [VcfPartition(**pd) for pd in d["partitions"]]
          for p in partitions:
              p.region = vcf_utils.Region(**p.region)
          d = d.copy()
-         d["fields"] = fields
          d["partitions"] = partitions
+         d["fields"] = [VcfField.fromdict(fd) for fd in d["fields"]]
+         d["samples"] = [Sample(**sd) for sd in d["samples"]]
+         d["filters"] = [Filter(**fd) for fd in d["filters"]]
+         d["contigs"] = [Contig(**cd) for cd in d["contigs"]]
          return IcfMetadata(**d)

      def asdict(self):
          return dataclasses.asdict(self)

+     def asjson(self):
+         return json.dumps(self.asdict(), indent=4)
+


  def fixed_vcf_field_definitions():
@@ -242,15 +255,22 @@ def fixed_vcf_field_definitions():
  def scan_vcf(path, target_num_partitions):
      with vcf_utils.IndexedVcf(path) as indexed_vcf:
          vcf = indexed_vcf.vcf
-         filters = [
-             h["ID"]
-             for h in vcf.header_iter()
-             if h["HeaderType"] == "FILTER" and isinstance(h["ID"], str)
-         ]
+         filters = []
+         pass_index = -1
+         for h in vcf.header_iter():
+             if h["HeaderType"] == "FILTER" and isinstance(h["ID"], str):
+                 try:
+                     description = h["Description"].strip('"')
+                 except KeyError:
+                     description = ""
+                 if h["ID"] == "PASS":
+                     pass_index = len(filters)
+                 filters.append(Filter(h["ID"], description))
+
          # Ensure PASS is the first filter if present
-         if "PASS" in filters:
-             filters.remove("PASS")
-             filters.insert(0, "PASS")
+         if pass_index > 0:
+             pass_filter = filters.pop(pass_index)
+             filters.insert(0, pass_filter)

          fields = fixed_vcf_field_definitions()
          for h in vcf.header_iter():
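
A small standalone sketch of the PASS-reordering logic introduced above, assuming made-up header entries (the real values come from cyvcf2's header_iter): the index of PASS is recorded as the filters are collected, and the Filter object is moved to the front only when it is present and not already first.

    import dataclasses


    @dataclasses.dataclass
    class Filter:
        id: str
        description: str = ""


    # Hypothetical FILTER header entries; PASS is not listed first here.
    header_filters = [("q10", "Quality below 10"), ("PASS", "All filters passed")]

    filters = []
    pass_index = -1
    for filter_id, description in header_filters:
        if filter_id == "PASS":
            pass_index = len(filters)
        filters.append(Filter(filter_id, description))

    # pass_index > 0 means PASS exists but is not already the first filter.
    if pass_index > 0:
        filters.insert(0, filters.pop(pass_index))

    assert [f.id for f in filters] == ["PASS", "q10"]
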
@@ -261,18 +281,22 @@ def scan_vcf(path, target_num_partitions):
                      field.vcf_number = "."
                  fields.append(field)

+         try:
+             contig_lengths = vcf.seqlens
+         except AttributeError:
+             contig_lengths = [None for _ in vcf.seqnames]
+
          metadata = IcfMetadata(
-             samples=vcf.samples,
-             contig_names=vcf.seqnames,
-             contig_record_counts=indexed_vcf.contig_record_counts(),
+             samples=[Sample(sample_id) for sample_id in vcf.samples],
+             contigs=[
+                 Contig(contig_id, length)
+                 for contig_id, length in zip(vcf.seqnames, contig_lengths)
+             ],
              filters=filters,
              fields=fields,
              partitions=[],
+             num_records=sum(indexed_vcf.contig_record_counts().values()),
          )
-         try:
-             metadata.contig_lengths = vcf.seqlens
-         except AttributeError:
-             pass

          regions = indexed_vcf.partition_into_regions(num_parts=target_num_partitions)
          logger.info(
@@ -291,21 +315,6 @@ def scan_vcf(path, target_num_partitions):
          return metadata, vcf.raw_header


- def check_overlap(partitions):
-     for i in range(1, len(partitions)):
-         prev_partition = partitions[i - 1]
-         current_partition = partitions[i]
-         if (
-             prev_partition.region.contig == current_partition.region.contig
-             and prev_partition.region.end > current_partition.region.start
-         ):
-             raise ValueError(
-                 f"Multiple VCFs have the region "
-                 f"{prev_partition.region.contig}:{prev_partition.region.start}-"
-                 f"{current_partition.region.end}"
-             )
-
-
  def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
      logger.info(
          f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}"
@@ -334,27 +343,30 @@ def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
      # We just take the first header, assuming the others
      # are compatible.
      all_partitions = []
-     contig_record_counts = collections.Counter()
+     total_records = 0
      for metadata, _ in results:
-         all_partitions.extend(metadata.partitions)
-         metadata.partitions.clear()
-         contig_record_counts += metadata.contig_record_counts
-         metadata.contig_record_counts.clear()
+         for partition in metadata.partitions:
+             logger.debug(f"Scanned partition {partition}")
+             all_partitions.append(partition)
+         total_records += metadata.num_records
+         metadata.num_records = 0
+         metadata.partitions = []

      icf_metadata, header = results[0]
      for metadata, _ in results[1:]:
          if metadata != icf_metadata:
              raise ValueError("Incompatible VCF chunks")

-     icf_metadata.contig_record_counts = dict(contig_record_counts)
+     # Note: this will be infinity here if any of the chunks has an index
+     # that doesn't keep track of the number of records per-contig
+     icf_metadata.num_records = total_records

      # Sort by contig (in the order they appear in the header) first,
      # then by start coordinate
-     contig_index_map = {contig: j for j, contig in enumerate(metadata.contig_names)}
+     contig_index_map = {contig.id: j for j, contig in enumerate(metadata.contigs)}
      all_partitions.sort(
          key=lambda x: (contig_index_map[x.region.contig], x.region.start)
      )
-     check_overlap(all_partitions)
      icf_metadata.partitions = all_partitions
      logger.info(f"Scan complete, resulting in {len(all_partitions)} partitions.")
      return icf_metadata, header
@@ -452,7 +464,7 @@ def sanitise_value_float_2d(buff, j, value):

  def sanitise_int_array(value, ndmin, dtype):
      if isinstance(value, tuple):
-         value = [VCF_INT_MISSING if x is None else x for x in value] # NEEDS TEST
+         value = [VCF_INT_MISSING if x is None else x for x in value]  # NEEDS TEST
      value = np.array(value, ndmin=ndmin, copy=False)
      value[value == VCF_INT_MISSING] = -1
      value[value == VCF_INT_FILL] = -2
@@ -745,9 +757,9 @@ class IcfFieldWriter:
      transformer: VcfValueTransformer
      compressor: Any
      max_buffered_bytes: int
-     buff: List[Any] = dataclasses.field(default_factory=list)
+     buff: list[Any] = dataclasses.field(default_factory=list)
      buffered_bytes: int = 0
-     chunk_index: List[int] = dataclasses.field(default_factory=lambda: [0])
+     chunk_index: list[int] = dataclasses.field(default_factory=lambda: [0])
      num_records: int = 0

      def append(self, val):
@@ -851,19 +863,18 @@ class IntermediateColumnarFormat(collections.abc.Mapping):
              self.metadata = IcfMetadata.fromdict(json.load(f))
          with open(self.path / "header.txt") as f:
              self.vcf_header = f.read()
-
          self.compressor = numcodecs.get_codec(self.metadata.compressor)
-         self.columns = {}
+         self.fields = {}
          partition_num_records = [
              partition.num_records for partition in self.metadata.partitions
          ]
          # Allow us to find which partition a given record is in
          self.partition_record_index = np.cumsum([0, *partition_num_records])
          for field in self.metadata.fields:
-             self.columns[field.full_name] = IntermediateColumnarFormatField(self, field)
+             self.fields[field.full_name] = IntermediateColumnarFormatField(self, field)
          logger.info(
              f"Loaded IntermediateColumnarFormat(partitions={self.num_partitions}, "
-             f"records={self.num_records}, columns={self.num_columns})"
+             f"records={self.num_records}, fields={self.num_fields})"
          )

      def __repr__(self):
@@ -874,17 +885,17 @@ class IntermediateColumnarFormat(collections.abc.Mapping):
          )

      def __getitem__(self, key):
-         return self.columns[key]
+         return self.fields[key]

      def __iter__(self):
-         return iter(self.columns)
+         return iter(self.fields)

      def __len__(self):
-         return len(self.columns)
+         return len(self.fields)

      def summary_table(self):
          data = []
-         for name, col in self.columns.items():
+         for name, col in self.fields.items():
              summary = col.vcf_field.summary
              d = {
                  "name": name,
@@ -900,9 +911,9 @@ class IntermediateColumnarFormat(collections.abc.Mapping):
              data.append(d)
          return data

-     @functools.cached_property
+     @property
      def num_records(self):
-         return sum(self.metadata.contig_record_counts.values())
+         return self.metadata.num_records

      @property
      def num_partitions(self):
@@ -913,8 +924,42 @@ class IntermediateColumnarFormat(collections.abc.Mapping):
          return len(self.metadata.samples)

      @property
-     def num_columns(self):
-         return len(self.columns)
+     def num_fields(self):
+         return len(self.fields)
+
+
+ @dataclasses.dataclass
+ class IcfPartitionMetadata:
+     num_records: int
+     last_position: int
+     field_summaries: dict
+
+     def asdict(self):
+         return dataclasses.asdict(self)
+
+     def asjson(self):
+         return json.dumps(self.asdict(), indent=4)
+
+     @staticmethod
+     def fromdict(d):
+         md = IcfPartitionMetadata(**d)
+         for k, v in md.field_summaries.items():
+             md.field_summaries[k] = VcfFieldSummary.fromdict(v)
+         return md
+
+
+ def check_overlapping_partitions(partitions):
+     for i in range(1, len(partitions)):
+         prev_region = partitions[i - 1].region
+         current_region = partitions[i].region
+         if prev_region.contig == current_region.contig:
+             assert prev_region.end is not None
+             # Regions are *inclusive*
+             if prev_region.end >= current_region.start:
+                 raise ValueError(
+                     f"Overlapping VCF regions in partitions {i - 1} and {i}: "
+                     f"{prev_region} and {current_region}"
+                 )


  class IntermediateColumnarFormatWriter:
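
To illustrate the inclusive-coordinate check that replaces the old check_overlap, here is a hedged sketch using namedtuple stand-ins for the region and partition objects (the real ones are vcf_utils.Region and VcfPartition); partitions on the same contig conflict whenever the previous end is greater than or equal to the next start.

    import collections

    # Stand-ins for vcf_utils.Region and VcfPartition, for illustration only.
    Region = collections.namedtuple("Region", ["contig", "start", "end"])
    Partition = collections.namedtuple("Partition", ["region"])


    def check_overlapping_partitions(partitions):
        for i in range(1, len(partitions)):
            prev_region = partitions[i - 1].region
            current_region = partitions[i].region
            if prev_region.contig == current_region.contig:
                # Regions are *inclusive*, so sharing a position counts as overlap.
                if prev_region.end >= current_region.start:
                    raise ValueError(
                        f"Overlapping VCF regions in partitions {i - 1} and {i}: "
                        f"{prev_region} and {current_region}"
                    )


    ok = [Partition(Region("1", 1, 1000)), Partition(Region("1", 1001, 2000))]
    check_overlapping_partitions(ok)  # fine: 1000 < 1001

    bad = [Partition(Region("1", 1, 1000)), Partition(Region("1", 1000, 2000))]
    try:
        check_overlapping_partitions(bad)
    except ValueError as exc:
        print(exc)
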
@@ -988,11 +1033,8 @@ class IntermediateColumnarFormatWriter:
          not_found = []
          for j in range(self.num_partitions):
              try:
-                 with open(self.wip_path / f"p{j}_summary.json") as f:
-                     summary = json.load(f)
-                     for k, v in summary["field_summaries"].items():
-                         summary["field_summaries"][k] = VcfFieldSummary.fromdict(v)
-                     summaries.append(summary)
+                 with open(self.wip_path / f"p{j}.json") as f:
+                     summaries.append(IcfPartitionMetadata.fromdict(json.load(f)))
              except FileNotFoundError:
                  not_found.append(j)
          if len(not_found) > 0:
@@ -1009,7 +1051,7 @@ class IntermediateColumnarFormatWriter:

      def process_partition(self, partition_index):
          self.load_metadata()
-         summary_path = self.wip_path / f"p{partition_index}_summary.json"
+         summary_path = self.wip_path / f"p{partition_index}.json"
          # If someone is rewriting a summary path (for whatever reason), make sure it
          # doesn't look like it's already been completed.
          # NOTE to do this properly we probably need to take a lock on this file - but
@@ -1030,6 +1072,7 @@ class IntermediateColumnarFormatWriter:
              else:
                  format_fields.append(field)

+         last_position = None
          with IcfPartitionWriter(
              self.metadata,
              self.path,
@@ -1039,6 +1082,7 @@ class IntermediateColumnarFormatWriter:
                  num_records = 0
                  for variant in ivcf.variants(partition.region):
                      num_records += 1
+                     last_position = variant.POS
                      tcw.append("CHROM", variant.CHROM)
                      tcw.append("POS", variant.POS)
                      tcw.append("QUAL", variant.QUAL)
@@ -1063,37 +1107,32 @@ class IntermediateColumnarFormatWriter:
              f"flushing buffers"
          )

-         partition_metadata = {
-             "num_records": num_records,
-             "field_summaries": {k: v.asdict() for k, v in tcw.field_summaries.items()},
-         }
+         partition_metadata = IcfPartitionMetadata(
+             num_records=num_records,
+             last_position=last_position,
+             field_summaries=tcw.field_summaries,
+         )
          with open(summary_path, "w") as f:
-             json.dump(partition_metadata, f, indent=4)
+             f.write(partition_metadata.asjson())
          logger.info(
-             f"Finish p{partition_index} {partition.vcf_path}__{partition.region}="
-             f"{num_records} records"
+             f"Finish p{partition_index} {partition.vcf_path}__{partition.region} "
+             f"{num_records} records last_pos={last_position}"
          )

-     def process_partition_slice(
-         self,
-         start,
-         stop,
-         *,
-         worker_processes=1,
-         show_progress=False,
-     ):
+     def explode(self, *, worker_processes=1, show_progress=False):
          self.load_metadata()
-         if start == 0 and stop == self.num_partitions:
-             num_records = self.metadata.num_records
-         else:
-             # We only know the number of records if all partitions are done at once,
-             # and we signal this to tqdm by passing None as the total.
+         num_records = self.metadata.num_records
+         if np.isinf(num_records):
+             logger.warning(
+                 "Total records unknown, cannot show progress; "
+                 "reindex VCFs with bcftools index to fix"
+             )
              num_records = None
-         num_columns = len(self.metadata.fields)
+         num_fields = len(self.metadata.fields)
          num_samples = len(self.metadata.samples)
          logger.info(
-             f"Exploding columns={num_columns} samples={num_samples}; "
-             f"partitions={stop - start} "
+             f"Exploding fields={num_fields} samples={num_samples}; "
+             f"partitions={self.num_partitions} "
              f"variants={'unknown' if num_records is None else num_records}"
          )
          progress_config = core.ProgressConfig(
@@ -1103,48 +1142,43 @@ class IntermediateColumnarFormatWriter:
              show=show_progress,
          )
          with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
-             for j in range(start, stop):
+             for j in range(self.num_partitions):
                  pwm.submit(self.process_partition, j)

-     def explode(self, *, worker_processes=1, show_progress=False):
-         self.load_metadata()
-         return self.process_partition_slice(
-             0,
-             self.num_partitions,
-             worker_processes=worker_processes,
-             show_progress=show_progress,
-         )
-
-     def explode_partition(self, partition, *, show_progress=False, worker_processes=1):
+     def explode_partition(self, partition):
          self.load_metadata()
          if partition < 0 or partition >= self.num_partitions:
              raise ValueError(
                  "Partition index must be in the range 0 <= index < num_partitions"
              )
-         return self.process_partition_slice(
-             partition,
-             partition + 1,
-             worker_processes=worker_processes,
-             show_progress=show_progress,
-         )
+         self.process_partition(partition)

      def finalise(self):
          self.load_metadata()
          partition_summaries = self.load_partition_summaries()
          total_records = 0
          for index, summary in enumerate(partition_summaries):
-             partition_records = summary["num_records"]
+             partition_records = summary.num_records
              self.metadata.partitions[index].num_records = partition_records
+             self.metadata.partitions[index].region.end = summary.last_position
              total_records += partition_records
-         assert total_records == self.metadata.num_records
+         if not np.isinf(self.metadata.num_records):
+             # Note: this is just telling us that there's a bug in the
+             # index based record counting code, but it doesn't actually
+             # matter much. We may want to just make this a warning if
+             # we hit regular problems.
+             assert total_records == self.metadata.num_records
+         self.metadata.num_records = total_records
+
+         check_overlapping_partitions(self.metadata.partitions)

          for field in self.metadata.fields:
              for summary in partition_summaries:
-                 field.summary.update(summary["field_summaries"][field.full_name])
+                 field.summary.update(summary.field_summaries[field.full_name])

          logger.info("Finalising metadata")
          with open(self.path / "metadata.json", "w") as f:
-             json.dump(self.metadata.asdict(), f, indent=4)
+             f.write(self.metadata.asjson())

          logger.debug("Removing WIP directory")
          shutil.rmtree(self.wip_path)
@@ -1195,14 +1229,9 @@ def explode_init(
      )


- # NOTE only including worker_processes here so we can use the 0 option to get the
- # work done syncronously and so we can get test coverage on it. Should find a
- # better way to do this.
- def explode_partition(icf_path, partition, *, show_progress=False, worker_processes=1):
+ def explode_partition(icf_path, partition):
      writer = IntermediateColumnarFormatWriter(icf_path)
-     writer.explode_partition(
-         partition, show_progress=show_progress, worker_processes=worker_processes
-     )
+     writer.explode_partition(partition)


  def explode_finalise(icf_path):
@@ -1330,7 +1359,7 @@ class ZarrColumnSpec:
          return chunk_items * dt.itemsize


- ZARR_SCHEMA_FORMAT_VERSION = "0.2"
+ ZARR_SCHEMA_FORMAT_VERSION = "0.3"


  @dataclasses.dataclass
@@ -1339,11 +1368,10 @@ class VcfZarrSchema:
      samples_chunk_size: int
      variants_chunk_size: int
      dimensions: list
-     sample_id: list
-     contig_id: list
-     contig_length: list
-     filter_id: list
-     columns: dict
+     samples: list
+     contigs: list
+     filters: list
+     fields: dict

      def asdict(self):
          return dataclasses.asdict(self)
@@ -1359,8 +1387,11 @@ class VcfZarrSchema:
                  f"{d['format_version']} != {ZARR_SCHEMA_FORMAT_VERSION}"
              )
          ret = VcfZarrSchema(**d)
-         ret.columns = {
-             key: ZarrColumnSpec(**value) for key, value in d["columns"].items()
+         ret.samples = [Sample(**sd) for sd in d["samples"]]
+         ret.contigs = [Contig(**sd) for sd in d["contigs"]]
+         ret.filters = [Filter(**sd) for sd in d["filters"]]
+         ret.fields = {
+             key: ZarrColumnSpec(**value) for key, value in d["fields"].items()
          }
          return ret

@@ -1404,7 +1435,7 @@ class VcfZarrSchema:
                  chunks=[variants_chunk_size],
              )

-         alt_col = icf.columns["ALT"]
+         alt_col = icf.fields["ALT"]
          max_alleles = alt_col.vcf_field.summary.max_number + 1

          colspecs = [
@@ -1496,12 +1527,11 @@ class VcfZarrSchema:
              format_version=ZARR_SCHEMA_FORMAT_VERSION,
              samples_chunk_size=samples_chunk_size,
              variants_chunk_size=variants_chunk_size,
-             columns={col.name: col for col in colspecs},
+             fields={col.name: col for col in colspecs},
              dimensions=["variants", "samples", "ploidy", "alleles", "filters"],
-             sample_id=icf.metadata.samples,
-             contig_id=icf.metadata.contig_names,
-             contig_length=icf.metadata.contig_lengths,
-             filter_id=icf.metadata.filters,
+             samples=icf.metadata.samples,
+             contigs=icf.metadata.contigs,
+             filters=icf.metadata.filters,
          )


@@ -1509,14 +1539,12 @@ class VcfZarr:
      def __init__(self, path):
          if not (path / ".zmetadata").exists():
              raise ValueError("Not in VcfZarr format") # NEEDS TEST
+         self.path = path
          self.root = zarr.open(path, mode="r")

-     def __repr__(self):
-         return repr(self.root) # NEEDS TEST
-
      def summary_table(self):
          data = []
-         arrays = [(a.nbytes_stored, a) for _, a in self.root.arrays()]
+         arrays = [(core.du(self.path / a.basename), a) for _, a in self.root.arrays()]
          arrays.sort(key=lambda x: x[0])
          for stored, array in reversed(arrays):
              d = {
@@ -1549,10 +1577,8 @@ def parse_max_memory(max_memory):

  @dataclasses.dataclass
  class VcfZarrPartition:
-     start_index: int
-     stop_index: int
-     start_chunk: int
-     stop_chunk: int
+     start: int
+     stop: int

      @staticmethod
      def generate_partitions(num_records, chunk_size, num_partitions, max_chunks=None):
@@ -1566,9 +1592,7 @@ class VcfZarrPartition:
              stop_chunk = int(chunk_slice[-1]) + 1
              start_index = start_chunk * chunk_size
              stop_index = min(stop_chunk * chunk_size, num_records)
-             partitions.append(
-                 VcfZarrPartition(start_index, stop_index, start_chunk, stop_chunk)
-             )
+             partitions.append(VcfZarrPartition(start_index, stop_index))
          return partitions


@@ -1591,7 +1615,7 @@ class VcfZarrWriterMetadata:
      def fromdict(d):
          if d["format_version"] != VZW_METADATA_FORMAT_VERSION:
              raise ValueError(
-                 "VcfZarrWriter format version mismatch: "
+                 "VcfZarrWriter format version mismatch: "
                  f"{d['format_version']} != {VZW_METADATA_FORMAT_VERSION}"
              )
          ret = VcfZarrWriterMetadata(**d)
@@ -1675,8 +1699,8 @@ class VcfZarrWriter:
          store = zarr.DirectoryStore(self.arrays_path)
          root = zarr.group(store=store)

-         for column in self.schema.columns.values():
-             self.init_array(root, column, partitions[-1].stop_index)
+         for column in self.schema.fields.values():
+             self.init_array(root, column, partitions[-1].stop)

          logger.info("Writing WIP metadata")
          with open(self.wip_path / "metadata.json", "w") as f:
@@ -1684,13 +1708,13 @@ class VcfZarrWriter:
          return len(partitions)

      def encode_samples(self, root):
-         if not np.array_equal(self.schema.sample_id, self.icf.metadata.samples):
+         if self.schema.samples != self.icf.metadata.samples:
              raise ValueError(
                  "Subsetting or reordering samples not supported currently"
              ) # NEEDS TEST
          array = root.array(
              "sample_id",
-             self.schema.sample_id,
+             [sample.id for sample in self.schema.samples],
              dtype="str",
              compressor=DEFAULT_ZARR_COMPRESSOR,
              chunks=(self.schema.samples_chunk_size,),
@@ -1701,24 +1725,26 @@ class VcfZarrWriter:
      def encode_contig_id(self, root):
          array = root.array(
              "contig_id",
-             self.schema.contig_id,
+             [contig.id for contig in self.schema.contigs],
              dtype="str",
              compressor=DEFAULT_ZARR_COMPRESSOR,
          )
          array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
-         if self.schema.contig_length is not None:
+         if all(contig.length is not None for contig in self.schema.contigs):
              array = root.array(
                  "contig_length",
-                 self.schema.contig_length,
+                 [contig.length for contig in self.schema.contigs],
                  dtype=np.int64,
                  compressor=DEFAULT_ZARR_COMPRESSOR,
              )
              array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]

      def encode_filter_id(self, root):
+         # TODO need a way to store description also
+         # https://github.com/sgkit-dev/vcf-zarr-spec/issues/19
          array = root.array(
              "filter_id",
-             self.schema.filter_id,
+             [filt.id for filt in self.schema.filters],
              dtype="str",
              compressor=DEFAULT_ZARR_COMPRESSOR,
          )
@@ -1763,28 +1789,42 @@ class VcfZarrWriter:
      def partition_path(self, partition_index):
          return self.partitions_path / f"p{partition_index}"

+     def wip_partition_path(self, partition_index):
+         return self.partitions_path / f"wip_p{partition_index}"
+
      def wip_partition_array_path(self, partition_index, name):
-         return self.partition_path(partition_index) / f"wip_{name}"
+         return self.wip_partition_path(partition_index) / name

      def partition_array_path(self, partition_index, name):
          return self.partition_path(partition_index) / name

      def encode_partition(self, partition_index):
          self.load_metadata()
-         partition_path = self.partition_path(partition_index)
+         if partition_index < 0 or partition_index >= self.num_partitions:
+             raise ValueError(
+                 "Partition index must be in the range 0 <= index < num_partitions"
+             )
+         partition_path = self.wip_partition_path(partition_index)
          partition_path.mkdir(exist_ok=True)
          logger.info(f"Encoding partition {partition_index} to {partition_path}")

-         self.encode_alleles_partition(partition_index)
          self.encode_id_partition(partition_index)
          self.encode_filters_partition(partition_index)
          self.encode_contig_partition(partition_index)
-         for col in self.schema.columns.values():
+         self.encode_alleles_partition(partition_index)
+         for col in self.schema.fields.values():
              if col.vcf_field is not None:
                  self.encode_array_partition(col, partition_index)
-         if "call_genotype" in self.schema.columns:
+         if "call_genotype" in self.schema.fields:
              self.encode_genotypes_partition(partition_index)

+         final_path = self.partition_path(partition_index)
+         logger.info(f"Finalising {partition_index} at {final_path}")
+         if final_path.exists():
+             logger.warning(f"Removing existing partition at {final_path}")
+             shutil.rmtree(final_path)
+         os.rename(partition_path, final_path)
+
      def init_partition_array(self, partition_index, name):
          wip_path = self.wip_partition_array_path(partition_index, name)
          # Create an empty array like the definition
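
The encode_partition changes above write each partition into a wip_p<index> directory and only swap it to its final p<index> name once every array has been encoded. A self-contained sketch of that write-then-rename pattern (the directory names and file content here are illustrative, not the package's real layout):

    import os
    import pathlib
    import shutil
    import tempfile

    base = pathlib.Path(tempfile.mkdtemp())
    wip_path = base / "wip_p0"   # all arrays are written here first
    final_path = base / "p0"     # only ever holds a fully encoded partition

    wip_path.mkdir()
    (wip_path / "variant_position").write_text("...encoded chunks...")

    # A leftover partition from an earlier attempt is removed before the swap.
    if final_path.exists():
        shutil.rmtree(final_path)

    # os.rename of a directory is atomic on POSIX filesystems, so the final
    # name never points at a half-written partition.
    os.rename(wip_path, final_path)

    print(sorted(p.name for p in base.iterdir()))  # ['p0']
    shutil.rmtree(base)
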
@@ -1796,27 +1836,17 @@ class VcfZarrWriter:
          return array

      def finalise_partition_array(self, partition_index, name):
-         wip_path = self.wip_partition_array_path(partition_index, name)
-         final_path = self.partition_array_path(partition_index, name)
-         if final_path.exists():
-             # NEEDS TEST
-             logger.warning(f"Removing existing {final_path}")
-             shutil.rmtree(final_path)
-         # Atomic swap
-         os.rename(wip_path, final_path)
          logger.debug(f"Encoded {name} partition {partition_index}")

      def encode_array_partition(self, column, partition_index):
          array = self.init_partition_array(partition_index, column.name)

          partition = self.metadata.partitions[partition_index]
-         ba = core.BufferedArray(array, partition.start_index)
-         source_col = self.icf.columns[column.vcf_field]
+         ba = core.BufferedArray(array, partition.start)
+         source_col = self.icf.fields[column.vcf_field]
          sanitiser = source_col.sanitiser_factory(ba.buff.shape)

-         for value in source_col.iter_values(
-             partition.start_index, partition.stop_index
-         ):
+         for value in source_col.iter_values(partition.start, partition.stop):
              # We write directly into the buffer in the sanitiser function
              # to make it easier to reason about dimension padding
              j = ba.next_buffer_row()
@@ -1832,14 +1862,12 @@ class VcfZarrWriter:
          )

          partition = self.metadata.partitions[partition_index]
-         gt = core.BufferedArray(gt_array, partition.start_index)
-         gt_mask = core.BufferedArray(gt_mask_array, partition.start_index)
-         gt_phased = core.BufferedArray(gt_phased_array, partition.start_index)
+         gt = core.BufferedArray(gt_array, partition.start)
+         gt_mask = core.BufferedArray(gt_mask_array, partition.start)
+         gt_phased = core.BufferedArray(gt_phased_array, partition.start)

-         source_col = self.icf.columns["FORMAT/GT"]
-         for value in source_col.iter_values(
-             partition.start_index, partition.stop_index
-         ):
+         source_col = self.icf.fields["FORMAT/GT"]
+         for value in source_col.iter_values(partition.start, partition.stop):
              j = gt.next_buffer_row()
              sanitise_value_int_2d(gt.buff, j, value[:, :-1])
              j = gt_phased.next_buffer_row()
@@ -1860,13 +1888,13 @@ class VcfZarrWriter:
          array_name = "variant_allele"
          alleles_array = self.init_partition_array(partition_index, array_name)
          partition = self.metadata.partitions[partition_index]
-         alleles = core.BufferedArray(alleles_array, partition.start_index)
-         ref_col = self.icf.columns["REF"]
-         alt_col = self.icf.columns["ALT"]
+         alleles = core.BufferedArray(alleles_array, partition.start)
+         ref_col = self.icf.fields["REF"]
+         alt_col = self.icf.fields["ALT"]

          for ref, alt in zip(
-             ref_col.iter_values(partition.start_index, partition.stop_index),
-             alt_col.iter_values(partition.start_index, partition.stop_index),
+             ref_col.iter_values(partition.start, partition.stop),
+             alt_col.iter_values(partition.start, partition.stop),
          ):
              j = alleles.next_buffer_row()
              alleles.buff[j, :] = STR_FILL
@@ -1880,11 +1908,11 @@ class VcfZarrWriter:
          vid_array = self.init_partition_array(partition_index, "variant_id")
          vid_mask_array = self.init_partition_array(partition_index, "variant_id_mask")
          partition = self.metadata.partitions[partition_index]
-         vid = core.BufferedArray(vid_array, partition.start_index)
-         vid_mask = core.BufferedArray(vid_mask_array, partition.start_index)
-         col = self.icf.columns["ID"]
+         vid = core.BufferedArray(vid_array, partition.start)
+         vid_mask = core.BufferedArray(vid_mask_array, partition.start)
+         col = self.icf.fields["ID"]

-         for value in col.iter_values(partition.start_index, partition.stop_index):
+         for value in col.iter_values(partition.start, partition.stop):
              j = vid.next_buffer_row()
              k = vid_mask.next_buffer_row()
              assert j == k
@@ -1901,14 +1929,14 @@ class VcfZarrWriter:
          self.finalise_partition_array(partition_index, "variant_id_mask")

      def encode_filters_partition(self, partition_index):
-         lookup = {filt: index for index, filt in enumerate(self.schema.filter_id)}
+         lookup = {filt.id: index for index, filt in enumerate(self.schema.filters)}
          array_name = "variant_filter"
          array = self.init_partition_array(partition_index, array_name)
          partition = self.metadata.partitions[partition_index]
-         var_filter = core.BufferedArray(array, partition.start_index)
+         var_filter = core.BufferedArray(array, partition.start)

-         col = self.icf.columns["FILTERS"]
-         for value in col.iter_values(partition.start_index, partition.stop_index):
+         col = self.icf.fields["FILTERS"]
+         for value in col.iter_values(partition.start, partition.stop):
              j = var_filter.next_buffer_row()
              var_filter.buff[j] = False
              for f in value:
@@ -1923,14 +1951,14 @@ class VcfZarrWriter:
          self.finalise_partition_array(partition_index, array_name)

      def encode_contig_partition(self, partition_index):
-         lookup = {contig: index for index, contig in enumerate(self.schema.contig_id)}
+         lookup = {contig.id: index for index, contig in enumerate(self.schema.contigs)}
          array_name = "variant_contig"
          array = self.init_partition_array(partition_index, array_name)
          partition = self.metadata.partitions[partition_index]
-         contig = core.BufferedArray(array, partition.start_index)
-         col = self.icf.columns["CHROM"]
+         contig = core.BufferedArray(array, partition.start)
+         col = self.icf.fields["CHROM"]

-         for value in col.iter_values(partition.start_index, partition.stop_index):
+         for value in col.iter_values(partition.start, partition.stop):
              j = contig.next_buffer_row()
              # Note: because we are using the indexes to define the lookups
              # and we always have an index, it seems that we the contig lookup
@@ -1951,7 +1979,7 @@ class VcfZarrWriter:
          if final_path.exists():
              # NEEDS TEST
              raise ValueError(f"Array {name} already exists")
-         for partition in range(len(self.metadata.partitions)):
+         for partition in range(self.num_partitions):
              # Move all the files in partition dir to dest dir
              src = self.partition_array_path(partition, name)
              if not src.exists():
@@ -1978,8 +2006,17 @@ class VcfZarrWriter:
      def finalise(self, show_progress=False):
          self.load_metadata()

+         logger.info("Scanning {self.num_partitions} partitions")
+         missing = []
+         # TODO may need a progress bar here
+         for partition_id in range(self.num_partitions):
+             if not self.partition_path(partition_id).exists():
+                 missing.append(partition_id)
+         if len(missing) > 0:
+             raise FileNotFoundError(f"Partitions not encoded: {missing}")
+
          progress_config = core.ProgressConfig(
-             total=len(self.schema.columns),
+             total=len(self.schema.fields),
              title="Finalise",
              units="array",
              show=show_progress,
@@ -1993,8 +2030,11 @@ class VcfZarrWriter:
          # for multiple workers, or making a standard wrapper for tqdm
          # that allows us to have a consistent look and feel.
          with core.ParallelWorkManager(0, progress_config) as pwm:
-             for name in self.schema.columns:
+             for name in self.schema.fields:
                  pwm.submit(self.finalise_array, name)
+         logger.debug(f"Removing {self.wip_path}")
+         shutil.rmtree(self.wip_path)
+         logger.info("Consolidating Zarr metadata")
          zarr.consolidate_metadata(self.path)

      ######################
@@ -2006,18 +2046,17 @@ class VcfZarrWriter:
          Return the approximate maximum memory used to encode a variant chunk.
          """
          max_encoding_mem = max(
-             col.variant_chunk_nbytes for col in self.schema.columns.values()
+             col.variant_chunk_nbytes for col in self.schema.fields.values()
          )
          gt_mem = 0
-         if "call_genotype" in self.schema.columns:
+         if "call_genotype" in self.schema.fields:
              encoded_together = [
                  "call_genotype",
                  "call_genotype_phased",
                  "call_genotype_mask",
              ]
              gt_mem = sum(
-                 self.schema.columns[col].variant_chunk_nbytes
-                 for col in encoded_together
+                 self.schema.fields[col].variant_chunk_nbytes for col in encoded_together
              )
          return max(max_encoding_mem, gt_mem)

@@ -2049,7 +2088,7 @@ class VcfZarrWriter:
          num_workers = min(max_num_workers, worker_processes)

          total_bytes = 0
-         for col in self.schema.columns.values():
+         for col in self.schema.fields.values():
              # Open the array definition to get the total size
              total_bytes += zarr.open(self.arrays_path / col.name).nbytes