bio2zarr 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.

bio2zarr/vcf.py CHANGED
@@ -1,29 +1,27 @@
1
1
  import collections
2
+ import contextlib
2
3
  import dataclasses
3
4
  import functools
5
+ import json
4
6
  import logging
7
+ import math
5
8
  import os
6
9
  import pathlib
7
10
  import pickle
8
- import sys
9
11
  import shutil
10
- import json
11
- import math
12
+ import sys
12
13
  import tempfile
13
- import contextlib
14
14
  from typing import Any, List
15
15
 
16
- import humanfriendly
17
16
  import cyvcf2
17
+ import humanfriendly
18
18
  import numcodecs
19
19
  import numpy as np
20
20
  import numpy.testing as nt
21
21
  import tqdm
22
22
  import zarr
23
23
 
24
- from . import core
25
- from . import provenance
26
- from . import vcf_utils
24
+ from . import core, provenance, vcf_utils
27
25
 
28
26
  logger = logging.getLogger(__name__)
29
27
 
@@ -113,9 +111,6 @@ class VcfField:
113
111
  return self.name
114
112
  return f"{self.category}/{self.name}"
115
113
 
116
- # TODO add method here to choose a good set compressor and
117
- # filters default here for this field.
118
-
119
114
  def smallest_dtype(self):
120
115
  """
121
116
  Returns the smallest dtype suitable for this field based
@@ -125,13 +120,13 @@ class VcfField:
125
120
  if self.vcf_type == "Float":
126
121
  ret = "f4"
127
122
  elif self.vcf_type == "Integer":
128
- dtype = "i4"
129
- for a_dtype in ["i1", "i2"]:
130
- info = np.iinfo(a_dtype)
131
- if info.min <= s.min_value and s.max_value <= info.max:
132
- dtype = a_dtype
133
- break
134
- ret = dtype
123
+ if not math.isfinite(s.max_value):
124
+ # All missing values; use i1. Note we should have some API to
125
+ # check more explicitly for missingness:
126
+ # https://github.com/sgkit-dev/bio2zarr/issues/131
127
+ ret = "i1"
128
+ else:
129
+ ret = core.min_int_dtype(s.min_value, s.max_value)
135
130
  elif self.vcf_type == "Flag":
136
131
  ret = "bool"
137
132
  elif self.vcf_type == "Character":
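The Integer branch above now delegates dtype selection to core.min_int_dtype, whose implementation is not shown in this diff. A minimal sketch of such a helper, assuming it only needs to cover the signed dtypes used here:

import numpy as np

def min_int_dtype(min_value, max_value):
    # Return the smallest signed integer dtype that can hold both bounds.
    for dtype in ["i1", "i2", "i4", "i8"]:
        info = np.iinfo(dtype)
        if info.min <= min_value and max_value <= info.max:
            return dtype
    raise OverflowError("Value range too large for int64")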
@@ -154,6 +149,10 @@ ICF_DEFAULT_COMPRESSOR = numcodecs.Blosc(
154
149
  cname="zstd", clevel=7, shuffle=numcodecs.Blosc.NOSHUFFLE
155
150
  )
156
151
 
152
+ # TODO refactor this to have embedded Contig dataclass, Filters
153
+ # and Samples dataclasses to allow for more information to be
154
+ # retained and to provide forward compatibility.
155
+
157
156
 
158
157
  @dataclasses.dataclass
159
158
  class IcfMetadata:
@@ -185,6 +184,14 @@ class IcfMetadata:
185
184
  fields.append(field)
186
185
  return fields
187
186
 
187
+ @property
188
+ def num_contigs(self):
189
+ return len(self.contig_names)
190
+
191
+ @property
192
+ def num_filters(self):
193
+ return len(self.filters)
194
+
188
195
  @property
189
196
  def num_records(self):
190
197
  return sum(self.contig_record_counts.values())
@@ -284,9 +291,25 @@ def scan_vcf(path, target_num_partitions):
284
291
  return metadata, vcf.raw_header
285
292
 
286
293
 
294
+ def check_overlap(partitions):
295
+ for i in range(1, len(partitions)):
296
+ prev_partition = partitions[i - 1]
297
+ current_partition = partitions[i]
298
+ if (
299
+ prev_partition.region.contig == current_partition.region.contig
300
+ and prev_partition.region.end > current_partition.region.start
301
+ ):
302
+ raise ValueError(
303
+ f"Multiple VCFs have the region "
304
+ f"{prev_partition.region.contig}:{prev_partition.region.start}-"
305
+ f"{current_partition.region.end}"
306
+ )
307
+
308
+
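A small illustration of the new overlap check, using hypothetical stand-ins for the partition and region objects (the real classes live elsewhere in bio2zarr and are not shown here):

import dataclasses

@dataclasses.dataclass
class Region:
    contig: str
    start: int
    end: int

@dataclasses.dataclass
class Partition:
    region: Region

partitions = [
    Partition(Region("chr1", 1, 1000)),
    Partition(Region("chr1", 900, 2000)),  # starts before the previous partition ends
]
# check_overlap(partitions) raises:
# ValueError: Multiple VCFs have the region chr1:1-2000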
287
309
  def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
288
310
  logger.info(
289
- f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions} partitions."
311
+ f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}"
312
+ f" partitions."
290
313
  )
291
314
  # An easy mistake to make is to pass the same file twice. Check this early on.
292
315
  for path, count in collections.Counter(paths).items():
@@ -331,6 +354,7 @@ def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
331
354
  all_partitions.sort(
332
355
  key=lambda x: (contig_index_map[x.region.contig], x.region.start)
333
356
  )
357
+ check_overlap(all_partitions)
334
358
  icf_metadata.partitions = all_partitions
335
359
  logger.info(f"Scan complete, resulting in {len(all_partitions)} partitions.")
336
360
  return icf_metadata, header
@@ -791,6 +815,8 @@ class IcfPartitionWriter(contextlib.AbstractContextManager):
791
815
  for vcf_field in icf_metadata.fields:
792
816
  field_path = get_vcf_field_path(out_path, vcf_field)
793
817
  field_partition_path = field_path / f"p{partition_index}"
818
+ # Should be robust to running explode_partition twice.
819
+ field_partition_path.mkdir(exist_ok=True)
794
820
  transformer = VcfValueTransformer.factory(vcf_field, num_samples)
795
821
  self.field_writers[vcf_field.full_name] = IcfFieldWriter(
796
822
  vcf_field,
@@ -832,7 +858,7 @@ class IntermediateColumnarFormat(collections.abc.Mapping):
832
858
  partition.num_records for partition in self.metadata.partitions
833
859
  ]
834
860
  # Allow us to find which partition a given record is in
835
- self.partition_record_index = np.cumsum([0] + partition_num_records)
861
+ self.partition_record_index = np.cumsum([0, *partition_num_records])
836
862
  for field in self.metadata.fields:
837
863
  self.columns[field.full_name] = IntermediateColumnarFormatField(self, field)
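The cumulative index built here is what makes record-to-partition lookups cheap; a sketch of one way such an index is typically used (hypothetical sizes, and the lookup itself is done elsewhere in the class):

import numpy as np

partition_num_records = [100, 250, 50]  # hypothetical partition sizes
partition_record_index = np.cumsum([0, *partition_num_records])
# -> array([  0, 100, 350, 400])

record_id = 120
partition = np.searchsorted(partition_record_index, record_id, side="right") - 1
# partition == 1: record 120 falls in the second partition (records 100-349)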
838
864
  logger.info(
@@ -842,7 +868,8 @@ class IntermediateColumnarFormat(collections.abc.Mapping):
842
868
 
843
869
  def __repr__(self):
844
870
  return (
845
- f"IntermediateColumnarFormat(fields={len(self)}, partitions={self.num_partitions}, "
871
+ f"IntermediateColumnarFormat(fields={len(self)}, "
872
+ f"partitions={self.num_partitions}, "
846
873
  f"records={self.num_records}, path={self.path})"
847
874
  )
848
875
 
@@ -890,15 +917,6 @@ class IntermediateColumnarFormat(collections.abc.Mapping):
890
917
  return len(self.columns)
891
918
 
892
919
 
893
-
894
- def mkdir_with_progress(path):
895
- logger.debug(f"mkdir f{path}")
896
- # NOTE we may have race-conditions here, I'm not sure. Hopefully allowing
897
- # parents=True will take care of it.
898
- path.mkdir(parents=True)
899
- core.update_progress(1)
900
-
901
-
902
920
  class IntermediateColumnarFormatWriter:
903
921
  def __init__(self, path):
904
922
  self.path = pathlib.Path(path)
@@ -941,45 +959,29 @@ class IntermediateColumnarFormatWriter:
941
959
  # dependencies as well.
942
960
  self.metadata.provenance = {"source": f"bio2zarr-{provenance.__version__}"}
943
961
 
944
- self.mkdirs(worker_processes, show_progress=show_progress)
962
+ self.mkdirs()
945
963
 
946
964
  # Note: this is needed for the current version of the vcfzarr spec, but it's
947
965
  # probably going to be dropped.
948
966
  # https://github.com/pystatgen/vcf-zarr-spec/issues/15
949
967
  # May be useful to keep lying around still though?
950
- logger.info(f"Writing VCF header")
968
+ logger.info("Writing VCF header")
951
969
  with open(self.path / "header.txt", "w") as f:
952
970
  f.write(header)
953
971
 
954
- logger.info(f"Writing WIP metadata")
972
+ logger.info("Writing WIP metadata")
955
973
  with open(self.wip_path / "metadata.json", "w") as f:
956
974
  json.dump(self.metadata.asdict(), f, indent=4)
957
975
  return self.num_partitions
958
976
 
959
- def mkdirs(self, worker_processes=1, show_progress=False):
960
- num_dirs = len(self.metadata.fields) * self.num_partitions
961
- logger.info(f"Creating {num_dirs} directories")
977
+ def mkdirs(self):
978
+ num_dirs = len(self.metadata.fields)
979
+ logger.info(f"Creating {num_dirs} field directories")
962
980
  self.path.mkdir()
963
981
  self.wip_path.mkdir()
964
- # Due to high latency batch system filesystems, we create all the directories in
965
- # parallel
966
- progress_config = core.ProgressConfig(
967
- total=num_dirs,
968
- units="dirs",
969
- title="Mkdirs",
970
- show=show_progress,
971
- )
972
- with core.ParallelWorkManager(
973
- worker_processes=worker_processes, progress_config=progress_config
974
- ) as manager:
975
- for field in self.metadata.fields:
976
- col_path = get_vcf_field_path(self.path, field)
977
- # Don't bother trying to count the intermediate directories towards
978
- # progress
979
- manager.submit(col_path.mkdir, parents=True)
980
- for j in range(self.num_partitions):
981
- part_path = col_path / f"p{j}"
982
- manager.submit(mkdir_with_progress, part_path)
982
+ for field in self.metadata.fields:
983
+ col_path = get_vcf_field_path(self.path, field)
984
+ col_path.mkdir(parents=True)
983
985
 
984
986
  def load_partition_summaries(self):
985
987
  summaries = []
@@ -995,13 +997,14 @@ class IntermediateColumnarFormatWriter:
995
997
  not_found.append(j)
996
998
  if len(not_found) > 0:
997
999
  raise FileNotFoundError(
998
- f"Partition metadata not found for {len(not_found)} partitions: {not_found}"
1000
+ f"Partition metadata not found for {len(not_found)}"
1001
+ f" partitions: {not_found}"
999
1002
  )
1000
1003
  return summaries
1001
1004
 
1002
1005
  def load_metadata(self):
1003
1006
  if self.metadata is None:
1004
- with open(self.wip_path / f"metadata.json") as f:
1007
+ with open(self.wip_path / "metadata.json") as f:
1005
1008
  self.metadata = IcfMetadata.fromdict(json.load(f))
1006
1009
 
1007
1010
  def process_partition(self, partition_index):
@@ -1050,12 +1053,14 @@ class IntermediateColumnarFormatWriter:
1050
1053
  for field in format_fields:
1051
1054
  val = variant.format(field.name)
1052
1055
  tcw.append(field.full_name, val)
1053
- # Note: an issue with updating the progress per variant here like this
1054
- # is that we get a significant pause at the end of the counter while
1055
- # all the "small" fields get flushed. Possibly not much to be done about it.
1056
+ # Note: an issue with updating the progress per variant here like
1057
+ # this is that we get a significant pause at the end of the counter
1058
+ # while all the "small" fields get flushed. Possibly not much to be
1059
+ # done about it.
1056
1060
  core.update_progress(1)
1057
1061
  logger.info(
1058
- f"Finished reading VCF for partition {partition_index}, flushing buffers"
1062
+ f"Finished reading VCF for partition {partition_index}, "
1063
+ f"flushing buffers"
1059
1064
  )
1060
1065
 
1061
1066
  partition_metadata = {
@@ -1137,11 +1142,11 @@ class IntermediateColumnarFormatWriter:
1137
1142
  for summary in partition_summaries:
1138
1143
  field.summary.update(summary["field_summaries"][field.full_name])
1139
1144
 
1140
- logger.info(f"Finalising metadata")
1145
+ logger.info("Finalising metadata")
1141
1146
  with open(self.path / "metadata.json", "w") as f:
1142
1147
  json.dump(self.metadata.asdict(), f, indent=4)
1143
1148
 
1144
- logger.debug(f"Removing WIP directory")
1149
+ logger.debug("Removing WIP directory")
1145
1150
  shutil.rmtree(self.wip_path)
1146
1151
 
1147
1152
 
@@ -1155,7 +1160,7 @@ def explode(
1155
1160
  compressor=None,
1156
1161
  ):
1157
1162
  writer = IntermediateColumnarFormatWriter(icf_path)
1158
- num_partitions = writer.init(
1163
+ writer.init(
1159
1164
  vcfs,
1160
1165
  # Heuristic to get reasonable worker utilisation with lumpy partition sizing
1161
1166
  target_num_partitions=max(1, worker_processes * 4),
@@ -1226,20 +1231,69 @@ class ZarrColumnSpec:
1226
1231
  dtype: str
1227
1232
  shape: tuple
1228
1233
  chunks: tuple
1229
- dimensions: list
1234
+ dimensions: tuple
1230
1235
  description: str
1231
1236
  vcf_field: str
1232
- compressor: dict = None
1233
- filters: list = None
1234
- # TODO add filters
1237
+ compressor: dict
1238
+ filters: list
1235
1239
 
1236
1240
  def __post_init__(self):
1241
+ # Ensure these are tuples for ease of comparison and consistency
1237
1242
  self.shape = tuple(self.shape)
1238
1243
  self.chunks = tuple(self.chunks)
1239
1244
  self.dimensions = tuple(self.dimensions)
1240
- self.compressor = DEFAULT_ZARR_COMPRESSOR.get_config()
1241
- self.filters = []
1242
- self._choose_compressor_settings()
1245
+
1246
+ @staticmethod
1247
+ def new(**kwargs):
1248
+ spec = ZarrColumnSpec(
1249
+ **kwargs, compressor=DEFAULT_ZARR_COMPRESSOR.get_config(), filters=[]
1250
+ )
1251
+ spec._choose_compressor_settings()
1252
+ return spec
1253
+
1254
+ @staticmethod
1255
+ def from_field(
1256
+ vcf_field,
1257
+ *,
1258
+ num_variants,
1259
+ num_samples,
1260
+ variants_chunk_size,
1261
+ samples_chunk_size,
1262
+ variable_name=None,
1263
+ ):
1264
+ shape = [num_variants]
1265
+ prefix = "variant_"
1266
+ dimensions = ["variants"]
1267
+ chunks = [variants_chunk_size]
1268
+ if vcf_field.category == "FORMAT":
1269
+ prefix = "call_"
1270
+ shape.append(num_samples)
1271
+ chunks.append(samples_chunk_size)
1272
+ dimensions.append("samples")
1273
+ if variable_name is None:
1274
+ variable_name = prefix + vcf_field.name
1275
+ # TODO make an option to add in the empty extra dimension
1276
+ if vcf_field.summary.max_number > 1:
1277
+ shape.append(vcf_field.summary.max_number)
1278
+ # TODO we should really be checking this to see if the named dimensions
1279
+ # are actually correct.
1280
+ if vcf_field.vcf_number == "R":
1281
+ dimensions.append("alleles")
1282
+ elif vcf_field.vcf_number == "A":
1283
+ dimensions.append("alt_alleles")
1284
+ elif vcf_field.vcf_number == "G":
1285
+ dimensions.append("genotypes")
1286
+ else:
1287
+ dimensions.append(f"{vcf_field.category}_{vcf_field.name}_dim")
1288
+ return ZarrColumnSpec.new(
1289
+ vcf_field=vcf_field.full_name,
1290
+ name=variable_name,
1291
+ dtype=vcf_field.smallest_dtype(),
1292
+ shape=shape,
1293
+ chunks=chunks,
1294
+ dimensions=dimensions,
1295
+ description=vcf_field.description,
1296
+ )
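The vcf_number-to-dimension mapping used in from_field above is easy to state on its own; a minimal sketch with a couple of hypothetical fields:

def extra_dimension_name(category, name, vcf_number):
    # Mirrors the naming rule in from_field above.
    if vcf_number == "R":
        return "alleles"
    if vcf_number == "A":
        return "alt_alleles"
    if vcf_number == "G":
        return "genotypes"
    return f"{category}_{name}_dim"

assert extra_dimension_name("INFO", "AC", "A") == "alt_alleles"
assert extra_dimension_name("FORMAT", "AD", "R") == "alleles"
assert extra_dimension_name("INFO", "CSQ", ".") == "INFO_CSQ_dim"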
1243
1297
 
1244
1298
  def _choose_compressor_settings(self):
1245
1299
  """
@@ -1249,17 +1303,32 @@ class ZarrColumnSpec:
1249
1303
 
1250
1304
  See https://github.com/pystatgen/bio2zarr/discussions/74
1251
1305
  """
1252
- dt = np.dtype(self.dtype)
1253
1306
  # Default is to not shuffle, because autoshuffle isn't recognised
1254
1307
  # by many Zarr implementations, and shuffling can lead to worse
1255
1308
  # performance in some cases anyway. Turning on shuffle should be a
1256
1309
  # deliberate choice.
1257
1310
  shuffle = numcodecs.Blosc.NOSHUFFLE
1258
- if dt.itemsize == 1:
1259
- # Any 1 byte field gets BITSHUFFLE by default
1311
+ if self.name == "call_genotype" and self.dtype == "i1":
1312
+ # call_genotype gets BITSHUFFLE by default as it gets
1313
+ # significantly better compression (at a cost of slower
1314
+ # decoding)
1260
1315
  shuffle = numcodecs.Blosc.BITSHUFFLE
1316
+ elif self.dtype == "bool":
1317
+ shuffle = numcodecs.Blosc.BITSHUFFLE
1318
+
1261
1319
  self.compressor["shuffle"] = shuffle
1262
1320
 
1321
+ @property
1322
+ def variant_chunk_nbytes(self):
1323
+ """
1324
+ Returns the nbytes for a single variant chunk of this array.
1325
+ """
1326
+ chunk_items = self.chunks[0]
1327
+ for size in self.shape[1:]:
1328
+ chunk_items *= size
1329
+ dt = np.dtype(self.dtype)
1330
+ return chunk_items * dt.itemsize
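A worked example of variant_chunk_nbytes with hypothetical sizes, to make the units concrete: a variant chunk spans the full extent of every non-variants dimension.

import numpy as np

shape = (1_000_000, 50_000, 2)   # (variants, samples, ploidy), e.g. call_genotype
chunks = (10_000, 1_000)         # (variants_chunk_size, samples_chunk_size)

chunk_items = chunks[0]
for size in shape[1:]:
    chunk_items *= size
nbytes = chunk_items * np.dtype("i1").itemsize
# 10_000 * 50_000 * 2 * 1 = 1_000_000_000 bytes, i.e. ~1 GB per variant chunk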
1331
+
1263
1332
 
1264
1333
  ZARR_SCHEMA_FORMAT_VERSION = "0.2"
1265
1334
 
@@ -1312,10 +1381,20 @@ class VcfZarrSchema:
1312
1381
  f"Generating schema with chunks={variants_chunk_size, samples_chunk_size}"
1313
1382
  )
1314
1383
 
1384
+ def spec_from_field(field, variable_name=None):
1385
+ return ZarrColumnSpec.from_field(
1386
+ field,
1387
+ num_samples=n,
1388
+ num_variants=m,
1389
+ samples_chunk_size=samples_chunk_size,
1390
+ variants_chunk_size=variants_chunk_size,
1391
+ variable_name=variable_name,
1392
+ )
1393
+
1315
1394
  def fixed_field_spec(
1316
1395
  name, dtype, vcf_field=None, shape=(m,), dimensions=("variants",)
1317
1396
  ):
1318
- return ZarrColumnSpec(
1397
+ return ZarrColumnSpec.new(
1319
1398
  vcf_field=vcf_field,
1320
1399
  name=name,
1321
1400
  dtype=dtype,
@@ -1327,88 +1406,58 @@ class VcfZarrSchema:
1327
1406
 
1328
1407
  alt_col = icf.columns["ALT"]
1329
1408
  max_alleles = alt_col.vcf_field.summary.max_number + 1
1330
- num_filters = len(icf.metadata.filters)
1331
1409
 
1332
- # # FIXME get dtype from lookup table
1333
1410
  colspecs = [
1334
1411
  fixed_field_spec(
1335
1412
  name="variant_contig",
1336
- dtype="i2", # FIXME
1413
+ dtype=core.min_int_dtype(0, icf.metadata.num_contigs),
1337
1414
  ),
1338
1415
  fixed_field_spec(
1339
1416
  name="variant_filter",
1340
1417
  dtype="bool",
1341
- shape=(m, num_filters),
1418
+ shape=(m, icf.metadata.num_filters),
1342
1419
  dimensions=["variants", "filters"],
1343
1420
  ),
1344
1421
  fixed_field_spec(
1345
1422
  name="variant_allele",
1346
1423
  dtype="str",
1347
- shape=[m, max_alleles],
1424
+ shape=(m, max_alleles),
1348
1425
  dimensions=["variants", "alleles"],
1349
1426
  ),
1350
1427
  fixed_field_spec(
1351
- vcf_field="POS",
1352
- name="variant_position",
1353
- dtype="i4",
1354
- ),
1355
- fixed_field_spec(
1356
- vcf_field=None,
1357
1428
  name="variant_id",
1358
1429
  dtype="str",
1359
1430
  ),
1360
1431
  fixed_field_spec(
1361
- vcf_field=None,
1362
1432
  name="variant_id_mask",
1363
1433
  dtype="bool",
1364
1434
  ),
1365
- fixed_field_spec(
1366
- vcf_field="QUAL",
1367
- name="variant_quality",
1368
- dtype="f4",
1369
- ),
1370
1435
  ]
1436
+ name_map = {field.full_name: field for field in icf.metadata.fields}
1437
+
1438
+ # Only two of the fixed fields have a direct one-to-one mapping.
1439
+ colspecs.extend(
1440
+ [
1441
+ spec_from_field(name_map["QUAL"], variable_name="variant_quality"),
1442
+ spec_from_field(name_map["POS"], variable_name="variant_position"),
1443
+ ]
1444
+ )
1445
+ colspecs.extend([spec_from_field(field) for field in icf.metadata.info_fields])
1371
1446
 
1372
1447
  gt_field = None
1373
- for field in icf.metadata.fields:
1374
- if field.category == "fixed":
1375
- continue
1448
+ for field in icf.metadata.format_fields:
1376
1449
  if field.name == "GT":
1377
1450
  gt_field = field
1378
1451
  continue
1379
- shape = [m]
1380
- prefix = "variant_"
1381
- dimensions = ["variants"]
1382
- chunks = [variants_chunk_size]
1383
- if field.category == "FORMAT":
1384
- prefix = "call_"
1385
- shape.append(n)
1386
- chunks.append(samples_chunk_size),
1387
- dimensions.append("samples")
1388
- # TODO make an option to add in the empty extra dimension
1389
- if field.summary.max_number > 1:
1390
- shape.append(field.summary.max_number)
1391
- dimensions.append(field.name)
1392
- variable_name = prefix + field.name
1393
- colspec = ZarrColumnSpec(
1394
- vcf_field=field.full_name,
1395
- name=variable_name,
1396
- dtype=field.smallest_dtype(),
1397
- shape=shape,
1398
- chunks=chunks,
1399
- dimensions=dimensions,
1400
- description=field.description,
1401
- )
1402
- colspecs.append(colspec)
1452
+ colspecs.append(spec_from_field(field))
1403
1453
 
1404
1454
  if gt_field is not None:
1405
1455
  ploidy = gt_field.summary.max_number - 1
1406
1456
  shape = [m, n]
1407
1457
  chunks = [variants_chunk_size, samples_chunk_size]
1408
1458
  dimensions = ["variants", "samples"]
1409
-
1410
1459
  colspecs.append(
1411
- ZarrColumnSpec(
1460
+ ZarrColumnSpec.new(
1412
1461
  vcf_field=None,
1413
1462
  name="call_genotype_phased",
1414
1463
  dtype="bool",
@@ -1421,7 +1470,7 @@ class VcfZarrSchema:
1421
1470
  shape += [ploidy]
1422
1471
  dimensions += ["ploidy"]
1423
1472
  colspecs.append(
1424
- ZarrColumnSpec(
1473
+ ZarrColumnSpec.new(
1425
1474
  vcf_field=None,
1426
1475
  name="call_genotype",
1427
1476
  dtype=gt_field.smallest_dtype(),
@@ -1432,7 +1481,7 @@ class VcfZarrSchema:
1432
1481
  )
1433
1482
  )
1434
1483
  colspecs.append(
1435
- ZarrColumnSpec(
1484
+ ZarrColumnSpec.new(
1436
1485
  vcf_field=None,
1437
1486
  name="call_genotype_mask",
1438
1487
  dtype="bool",
@@ -1488,15 +1537,6 @@ class VcfZarr:
1488
1537
  return data
1489
1538
 
1490
1539
 
1491
- @dataclasses.dataclass
1492
- class EncodingWork:
1493
- func: callable = dataclasses.field(repr=False)
1494
- start: int
1495
- stop: int
1496
- columns: list[str]
1497
- memory: int = 0
1498
-
1499
-
1500
1540
  def parse_max_memory(max_memory):
1501
1541
  if max_memory is None:
1502
1542
  # Effectively unbounded
@@ -1507,65 +1547,299 @@ def parse_max_memory(max_memory):
1507
1547
  return max_memory
1508
1548
 
1509
1549
 
1550
+ @dataclasses.dataclass
1551
+ class VcfZarrPartition:
1552
+ start_index: int
1553
+ stop_index: int
1554
+ start_chunk: int
1555
+ stop_chunk: int
1556
+
1557
+ @staticmethod
1558
+ def generate_partitions(num_records, chunk_size, num_partitions, max_chunks=None):
1559
+ num_chunks = int(np.ceil(num_records / chunk_size))
1560
+ if max_chunks is not None:
1561
+ num_chunks = min(num_chunks, max_chunks)
1562
+ partitions = []
1563
+ splits = np.array_split(np.arange(num_chunks), min(num_partitions, num_chunks))
1564
+ for chunk_slice in splits:
1565
+ start_chunk = int(chunk_slice[0])
1566
+ stop_chunk = int(chunk_slice[-1]) + 1
1567
+ start_index = start_chunk * chunk_size
1568
+ stop_index = min(stop_chunk * chunk_size, num_records)
1569
+ partitions.append(
1570
+ VcfZarrPartition(start_index, stop_index, start_chunk, stop_chunk)
1571
+ )
1572
+ return partitions
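A worked example of the chunk-aligned partitioning above, with hypothetical numbers:

# 25 records in chunks of 10 -> num_chunks = ceil(25 / 10) = 3
# np.array_split(np.arange(3), 2) -> [0, 1] and [2], giving:
#   partition 0: start_chunk=0, stop_chunk=2, start_index=0,  stop_index=20
#   partition 1: start_chunk=2, stop_chunk=3, start_index=20, stop_index=25
partitions = VcfZarrPartition.generate_partitions(
    num_records=25, chunk_size=10, num_partitions=2
)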
1573
+
1574
+
1575
+ VZW_METADATA_FORMAT_VERSION = "0.1"
1576
+
1577
+
1578
+ @dataclasses.dataclass
1579
+ class VcfZarrWriterMetadata:
1580
+ format_version: str
1581
+ icf_path: str
1582
+ schema: VcfZarrSchema
1583
+ dimension_separator: str
1584
+ partitions: list
1585
+ provenance: dict
1586
+
1587
+ def asdict(self):
1588
+ return dataclasses.asdict(self)
1589
+
1590
+ @staticmethod
1591
+ def fromdict(d):
1592
+ if d["format_version"] != VZW_METADATA_FORMAT_VERSION:
1593
+ raise ValueError(
1594
+ "VcfZarrWriter format version mismatch: "
1595
+ f"{d['format_version']} != {VZW_METADATA_FORMAT_VERSION}"
1596
+ )
1597
+ ret = VcfZarrWriterMetadata(**d)
1598
+ ret.schema = VcfZarrSchema.fromdict(ret.schema)
1599
+ ret.partitions = [VcfZarrPartition(**p) for p in ret.partitions]
1600
+ return ret
1601
+
1602
+
1510
1603
  class VcfZarrWriter:
1511
- def __init__(self, path, icf, schema, dimension_separator=None):
1604
+ def __init__(self, path):
1512
1605
  self.path = pathlib.Path(path)
1606
+ self.wip_path = self.path / "wip"
1607
+ self.arrays_path = self.wip_path / "arrays"
1608
+ self.partitions_path = self.wip_path / "partitions"
1609
+ self.metadata = None
1610
+ self.icf = None
1611
+
1612
+ @property
1613
+ def schema(self):
1614
+ return self.metadata.schema
1615
+
1616
+ @property
1617
+ def num_partitions(self):
1618
+ return len(self.metadata.partitions)
1619
+
1620
+ #######################
1621
+ # init
1622
+ #######################
1623
+
1624
+ def init(
1625
+ self,
1626
+ icf,
1627
+ *,
1628
+ target_num_partitions,
1629
+ schema,
1630
+ dimension_separator=None,
1631
+ max_variant_chunks=None,
1632
+ ):
1513
1633
  self.icf = icf
1514
- self.schema = schema
1634
+ if self.path.exists():
1635
+ raise ValueError("Zarr path already exists") # NEEDS TEST
1636
+ partitions = VcfZarrPartition.generate_partitions(
1637
+ self.icf.num_records,
1638
+ schema.variants_chunk_size,
1639
+ target_num_partitions,
1640
+ max_chunks=max_variant_chunks,
1641
+ )
1515
1642
  # Default to using nested directories following the Zarr v3 default.
1516
1643
  # This seems to require version 2.17+ to work properly
1517
- self.dimension_separator = "/" if dimension_separator is None else dimension_separator
1644
+ dimension_separator = (
1645
+ "/" if dimension_separator is None else dimension_separator
1646
+ )
1647
+ self.metadata = VcfZarrWriterMetadata(
1648
+ format_version=VZW_METADATA_FORMAT_VERSION,
1649
+ icf_path=str(self.icf.path),
1650
+ schema=schema,
1651
+ dimension_separator=dimension_separator,
1652
+ partitions=partitions,
1653
+ # Bare minimum here for provenance - see comments above
1654
+ provenance={"source": f"bio2zarr-{provenance.__version__}"},
1655
+ )
1656
+
1657
+ self.path.mkdir()
1518
1658
  store = zarr.DirectoryStore(self.path)
1519
- self.root = zarr.group(store=store)
1659
+ root = zarr.group(store=store)
1660
+ root.attrs.update(
1661
+ {
1662
+ "vcf_zarr_version": "0.2",
1663
+ "vcf_header": self.icf.vcf_header,
1664
+ "source": f"bio2zarr-{provenance.__version__}",
1665
+ }
1666
+ )
1667
+ # Doing this synchronously - this is fine surely
1668
+ self.encode_samples(root)
1669
+ self.encode_filter_id(root)
1670
+ self.encode_contig_id(root)
1671
+
1672
+ self.wip_path.mkdir()
1673
+ self.arrays_path.mkdir()
1674
+ self.partitions_path.mkdir()
1675
+ store = zarr.DirectoryStore(self.arrays_path)
1676
+ root = zarr.group(store=store)
1677
+
1678
+ for column in self.schema.columns.values():
1679
+ self.init_array(root, column, partitions[-1].stop_index)
1680
+
1681
+ logger.info("Writing WIP metadata")
1682
+ with open(self.wip_path / "metadata.json", "w") as f:
1683
+ json.dump(self.metadata.asdict(), f, indent=4)
1684
+ return len(partitions)
1685
+
1686
+ def encode_samples(self, root):
1687
+ if not np.array_equal(self.schema.sample_id, self.icf.metadata.samples):
1688
+ raise ValueError(
1689
+ "Subsetting or reordering samples not supported currently"
1690
+ ) # NEEDS TEST
1691
+ array = root.array(
1692
+ "sample_id",
1693
+ self.schema.sample_id,
1694
+ dtype="str",
1695
+ compressor=DEFAULT_ZARR_COMPRESSOR,
1696
+ chunks=(self.schema.samples_chunk_size,),
1697
+ )
1698
+ array.attrs["_ARRAY_DIMENSIONS"] = ["samples"]
1699
+ logger.debug("Samples done")
1520
1700
 
1521
- def init_array(self, variable):
1701
+ def encode_contig_id(self, root):
1702
+ array = root.array(
1703
+ "contig_id",
1704
+ self.schema.contig_id,
1705
+ dtype="str",
1706
+ compressor=DEFAULT_ZARR_COMPRESSOR,
1707
+ )
1708
+ array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
1709
+ if self.schema.contig_length is not None:
1710
+ array = root.array(
1711
+ "contig_length",
1712
+ self.schema.contig_length,
1713
+ dtype=np.int64,
1714
+ compressor=DEFAULT_ZARR_COMPRESSOR,
1715
+ )
1716
+ array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
1717
+
1718
+ def encode_filter_id(self, root):
1719
+ array = root.array(
1720
+ "filter_id",
1721
+ self.schema.filter_id,
1722
+ dtype="str",
1723
+ compressor=DEFAULT_ZARR_COMPRESSOR,
1724
+ )
1725
+ array.attrs["_ARRAY_DIMENSIONS"] = ["filters"]
1726
+
1727
+ def init_array(self, root, variable, variants_dim_size):
1522
1728
  object_codec = None
1523
1729
  if variable.dtype == "O":
1524
1730
  object_codec = numcodecs.VLenUTF8()
1525
- a = self.root.empty(
1526
- "wip_" + variable.name,
1527
- shape=variable.shape,
1731
+ shape = list(variable.shape)
1732
+ # Truncate the variants dimension if max_variant_chunks was specified
1733
+ shape[0] = variants_dim_size
1734
+ a = root.empty(
1735
+ variable.name,
1736
+ shape=shape,
1528
1737
  chunks=variable.chunks,
1529
1738
  dtype=variable.dtype,
1530
1739
  compressor=numcodecs.get_codec(variable.compressor),
1531
1740
  filters=[numcodecs.get_codec(filt) for filt in variable.filters],
1532
1741
  object_codec=object_codec,
1533
- dimension_separator=self.dimension_separator,
1742
+ dimension_separator=self.metadata.dimension_separator,
1743
+ )
1744
+ a.attrs.update(
1745
+ {
1746
+ "description": variable.description,
1747
+ # Dimension names are part of the spec in Zarr v3
1748
+ "_ARRAY_DIMENSIONS": variable.dimensions,
1749
+ }
1534
1750
  )
1535
- # Dimension names are part of the spec in Zarr v3
1536
- a.attrs["_ARRAY_DIMENSIONS"] = variable.dimensions
1751
+ logger.debug(f"Initialised {a}")
1752
+
1753
+ #######################
1754
+ # encode_partition
1755
+ #######################
1756
+
1757
+ def load_metadata(self):
1758
+ if self.metadata is None:
1759
+ with open(self.wip_path / "metadata.json") as f:
1760
+ self.metadata = VcfZarrWriterMetadata.fromdict(json.load(f))
1761
+ self.icf = IntermediateColumnarFormat(self.metadata.icf_path)
1537
1762
 
1538
- def get_array(self, name):
1539
- return self.root["wip_" + name]
1763
+ def partition_path(self, partition_index):
1764
+ return self.partitions_path / f"p{partition_index}"
1540
1765
 
1541
- def finalise_array(self, variable_name):
1542
- source = self.path / ("wip_" + variable_name)
1543
- dest = self.path / variable_name
1766
+ def wip_partition_array_path(self, partition_index, name):
1767
+ return self.partition_path(partition_index) / f"wip_{name}"
1768
+
1769
+ def partition_array_path(self, partition_index, name):
1770
+ return self.partition_path(partition_index) / name
1771
+
1772
+ def encode_partition(self, partition_index):
1773
+ self.load_metadata()
1774
+ partition_path = self.partition_path(partition_index)
1775
+ partition_path.mkdir(exist_ok=True)
1776
+ logger.info(f"Encoding partition {partition_index} to {partition_path}")
1777
+
1778
+ self.encode_alleles_partition(partition_index)
1779
+ self.encode_id_partition(partition_index)
1780
+ self.encode_filters_partition(partition_index)
1781
+ self.encode_contig_partition(partition_index)
1782
+ for col in self.schema.columns.values():
1783
+ if col.vcf_field is not None:
1784
+ self.encode_array_partition(col, partition_index)
1785
+ if "call_genotype" in self.schema.columns:
1786
+ self.encode_genotypes_partition(partition_index)
1787
+
1788
+ def init_partition_array(self, partition_index, name):
1789
+ wip_path = self.wip_partition_array_path(partition_index, name)
1790
+ # Create an empty array like the definition
1791
+ src = self.arrays_path / name
1792
+ # Overwrite any existing WIP files
1793
+ shutil.copytree(src, wip_path, dirs_exist_ok=True)
1794
+ array = zarr.open(wip_path)
1795
+ logger.debug(f"Opened empty array {array} @ {wip_path}")
1796
+ return array
1797
+
1798
+ def finalise_partition_array(self, partition_index, name):
1799
+ wip_path = self.wip_partition_array_path(partition_index, name)
1800
+ final_path = self.partition_array_path(partition_index, name)
1801
+ if final_path.exists():
1802
+ # NEEDS TEST
1803
+ logger.warning(f"Removing existing {final_path}")
1804
+ shutil.rmtree(final_path)
1544
1805
  # Atomic swap
1545
- os.rename(source, dest)
1546
- logger.info(f"Finalised {variable_name}")
1806
+ os.rename(wip_path, final_path)
1807
+ logger.debug(f"Encoded {name} partition {partition_index}")
1808
+
1809
+ def encode_array_partition(self, column, partition_index):
1810
+ array = self.init_partition_array(partition_index, column.name)
1547
1811
 
1548
- def encode_array_slice(self, column, start, stop):
1812
+ partition = self.metadata.partitions[partition_index]
1813
+ ba = core.BufferedArray(array, partition.start_index)
1549
1814
  source_col = self.icf.columns[column.vcf_field]
1550
- array = self.get_array(column.name)
1551
- ba = core.BufferedArray(array, start)
1552
1815
  sanitiser = source_col.sanitiser_factory(ba.buff.shape)
1553
1816
 
1554
- for value in source_col.iter_values(start, stop):
1817
+ for value in source_col.iter_values(
1818
+ partition.start_index, partition.stop_index
1819
+ ):
1555
1820
  # We write directly into the buffer in the sanitiser function
1556
1821
  # to make it easier to reason about dimension padding
1557
1822
  j = ba.next_buffer_row()
1558
1823
  sanitiser(ba.buff, j, value)
1559
1824
  ba.flush()
1560
- logger.debug(f"Encoded {column.name} slice {start}:{stop}")
1825
+ self.finalise_partition_array(partition_index, column.name)
1561
1826
 
1562
- def encode_genotypes_slice(self, start, stop):
1563
- source_col = self.icf.columns["FORMAT/GT"]
1564
- gt = core.BufferedArray(self.get_array("call_genotype"), start)
1565
- gt_mask = core.BufferedArray(self.get_array("call_genotype_mask"), start)
1566
- gt_phased = core.BufferedArray(self.get_array("call_genotype_phased"), start)
1827
+ def encode_genotypes_partition(self, partition_index):
1828
+ gt_array = self.init_partition_array(partition_index, "call_genotype")
1829
+ gt_mask_array = self.init_partition_array(partition_index, "call_genotype_mask")
1830
+ gt_phased_array = self.init_partition_array(
1831
+ partition_index, "call_genotype_phased"
1832
+ )
1567
1833
 
1568
- for value in source_col.iter_values(start, stop):
1834
+ partition = self.metadata.partitions[partition_index]
1835
+ gt = core.BufferedArray(gt_array, partition.start_index)
1836
+ gt_mask = core.BufferedArray(gt_mask_array, partition.start_index)
1837
+ gt_phased = core.BufferedArray(gt_phased_array, partition.start_index)
1838
+
1839
+ source_col = self.icf.columns["FORMAT/GT"]
1840
+ for value in source_col.iter_values(
1841
+ partition.start_index, partition.stop_index
1842
+ ):
1569
1843
  j = gt.next_buffer_row()
1570
1844
  sanitise_value_int_2d(gt.buff, j, value[:, :-1])
1571
1845
  j = gt_phased.next_buffer_row()
@@ -1577,29 +1851,40 @@ class VcfZarrWriter:
1577
1851
  gt.flush()
1578
1852
  gt_phased.flush()
1579
1853
  gt_mask.flush()
1580
- logger.debug(f"Encoded GT slice {start}:{stop}")
1581
1854
 
1582
- def encode_alleles_slice(self, start, stop):
1855
+ self.finalise_partition_array(partition_index, "call_genotype")
1856
+ self.finalise_partition_array(partition_index, "call_genotype_mask")
1857
+ self.finalise_partition_array(partition_index, "call_genotype_phased")
1858
+
1859
+ def encode_alleles_partition(self, partition_index):
1860
+ array_name = "variant_allele"
1861
+ alleles_array = self.init_partition_array(partition_index, array_name)
1862
+ partition = self.metadata.partitions[partition_index]
1863
+ alleles = core.BufferedArray(alleles_array, partition.start_index)
1583
1864
  ref_col = self.icf.columns["REF"]
1584
1865
  alt_col = self.icf.columns["ALT"]
1585
- alleles = core.BufferedArray(self.get_array("variant_allele"), start)
1586
1866
 
1587
1867
  for ref, alt in zip(
1588
- ref_col.iter_values(start, stop), alt_col.iter_values(start, stop)
1868
+ ref_col.iter_values(partition.start_index, partition.stop_index),
1869
+ alt_col.iter_values(partition.start_index, partition.stop_index),
1589
1870
  ):
1590
1871
  j = alleles.next_buffer_row()
1591
1872
  alleles.buff[j, :] = STR_FILL
1592
1873
  alleles.buff[j, 0] = ref[0]
1593
1874
  alleles.buff[j, 1 : 1 + len(alt)] = alt
1594
1875
  alleles.flush()
1595
- logger.debug(f"Encoded alleles slice {start}:{stop}")
1596
1876
 
1597
- def encode_id_slice(self, start, stop):
1877
+ self.finalise_partition_array(partition_index, array_name)
1878
+
1879
+ def encode_id_partition(self, partition_index):
1880
+ vid_array = self.init_partition_array(partition_index, "variant_id")
1881
+ vid_mask_array = self.init_partition_array(partition_index, "variant_id_mask")
1882
+ partition = self.metadata.partitions[partition_index]
1883
+ vid = core.BufferedArray(vid_array, partition.start_index)
1884
+ vid_mask = core.BufferedArray(vid_mask_array, partition.start_index)
1598
1885
  col = self.icf.columns["ID"]
1599
- vid = core.BufferedArray(self.get_array("variant_id"), start)
1600
- vid_mask = core.BufferedArray(self.get_array("variant_id_mask"), start)
1601
1886
 
1602
- for value in col.iter_values(start, stop):
1887
+ for value in col.iter_values(partition.start_index, partition.stop_index):
1603
1888
  j = vid.next_buffer_row()
1604
1889
  k = vid_mask.next_buffer_row()
1605
1890
  assert j == k
@@ -1611,28 +1896,41 @@ class VcfZarrWriter:
1611
1896
  vid_mask.buff[j] = True
1612
1897
  vid.flush()
1613
1898
  vid_mask.flush()
1614
- logger.debug(f"Encoded ID slice {start}:{stop}")
1615
1899
 
1616
- def encode_filters_slice(self, lookup, start, stop):
1617
- col = self.icf.columns["FILTERS"]
1618
- var_filter = core.BufferedArray(self.get_array("variant_filter"), start)
1900
+ self.finalise_partition_array(partition_index, "variant_id")
1901
+ self.finalise_partition_array(partition_index, "variant_id_mask")
1902
+
1903
+ def encode_filters_partition(self, partition_index):
1904
+ lookup = {filt: index for index, filt in enumerate(self.schema.filter_id)}
1905
+ array_name = "variant_filter"
1906
+ array = self.init_partition_array(partition_index, array_name)
1907
+ partition = self.metadata.partitions[partition_index]
1908
+ var_filter = core.BufferedArray(array, partition.start_index)
1619
1909
 
1620
- for value in col.iter_values(start, stop):
1910
+ col = self.icf.columns["FILTERS"]
1911
+ for value in col.iter_values(partition.start_index, partition.stop_index):
1621
1912
  j = var_filter.next_buffer_row()
1622
1913
  var_filter.buff[j] = False
1623
1914
  for f in value:
1624
1915
  try:
1625
1916
  var_filter.buff[j, lookup[f]] = True
1626
1917
  except KeyError:
1627
- raise ValueError(f"Filter '{f}' was not defined in the header.")
1918
+ raise ValueError(
1919
+ f"Filter '{f}' was not defined in the header."
1920
+ ) from None
1628
1921
  var_filter.flush()
1629
- logger.debug(f"Encoded FILTERS slice {start}:{stop}")
1630
1922
 
1631
- def encode_contig_slice(self, lookup, start, stop):
1923
+ self.finalise_partition_array(partition_index, array_name)
1924
+
1925
+ def encode_contig_partition(self, partition_index):
1926
+ lookup = {contig: index for index, contig in enumerate(self.schema.contig_id)}
1927
+ array_name = "variant_contig"
1928
+ array = self.init_partition_array(partition_index, array_name)
1929
+ partition = self.metadata.partitions[partition_index]
1930
+ contig = core.BufferedArray(array, partition.start_index)
1632
1931
  col = self.icf.columns["CHROM"]
1633
- contig = core.BufferedArray(self.get_array("variant_contig"), start)
1634
1932
 
1635
- for value in col.iter_values(start, stop):
1933
+ for value in col.iter_values(partition.start_index, partition.stop_index):
1636
1934
  j = contig.next_buffer_row()
1637
1935
  # Note: because we are using the indexes to define the lookups
1638
1936
  # and we always have an index, it seems that we the contig lookup
@@ -1640,160 +1938,120 @@ class VcfZarrWriter:
1640
1938
  # here, please do open an issue with a reproducible example!
1641
1939
  contig.buff[j] = lookup[value[0]]
1642
1940
  contig.flush()
1643
- logger.debug(f"Encoded CHROM slice {start}:{stop}")
1644
-
1645
- def encode_samples(self):
1646
- if not np.array_equal(self.schema.sample_id, self.icf.metadata.samples):
1647
- raise ValueError(
1648
- "Subsetting or reordering samples not supported currently"
1649
- ) # NEEDS TEST
1650
- array = self.root.array(
1651
- "sample_id",
1652
- self.schema.sample_id,
1653
- dtype="str",
1654
- compressor=DEFAULT_ZARR_COMPRESSOR,
1655
- chunks=(self.schema.samples_chunk_size,),
1656
- )
1657
- array.attrs["_ARRAY_DIMENSIONS"] = ["samples"]
1658
- logger.debug("Samples done")
1659
1941
 
1660
- def encode_contig_id(self):
1661
- array = self.root.array(
1662
- "contig_id",
1663
- self.schema.contig_id,
1664
- dtype="str",
1665
- compressor=DEFAULT_ZARR_COMPRESSOR,
1666
- )
1667
- array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
1668
- if self.schema.contig_length is not None:
1669
- array = self.root.array(
1670
- "contig_length",
1671
- self.schema.contig_length,
1672
- dtype=np.int64,
1673
- compressor=DEFAULT_ZARR_COMPRESSOR,
1942
+ self.finalise_partition_array(partition_index, array_name)
1943
+
1944
+ #######################
1945
+ # finalise
1946
+ #######################
1947
+
1948
+ def finalise_array(self, name):
1949
+ logger.info(f"Finalising {name}")
1950
+ final_path = self.path / name
1951
+ if final_path.exists():
1952
+ # NEEDS TEST
1953
+ raise ValueError(f"Array {name} already exists")
1954
+ for partition in range(len(self.metadata.partitions)):
1955
+ # Move all the files in partition dir to dest dir
1956
+ src = self.partition_array_path(partition, name)
1957
+ if not src.exists():
1958
+ # Needs test
1959
+ raise ValueError(f"Partition {partition} of {name} does not exist")
1960
+ dest = self.arrays_path / name
1961
+ # This is Zarr v2 specific. Chunks in v3 will start with a "c" prefix.
1962
+ chunk_files = [
1963
+ path for path in src.iterdir() if not path.name.startswith(".")
1964
+ ]
1965
+ # TODO check for a count of the number of files. If we require a
1966
+ # dimension_separator of "/" then we could make stronger assertions
1967
+ # here, as we'd always have num_variant_chunks
1968
+ logger.debug(
1969
+ f"Moving {len(chunk_files)} chunks for {name} partition {partition}"
1674
1970
  )
1675
- array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
1676
- return {v: j for j, v in enumerate(self.schema.contig_id)}
1971
+ for chunk_file in chunk_files:
1972
+ os.rename(chunk_file, dest / chunk_file.name)
1973
+ # Finally, once all the chunks have moved into the arrays dir,
1974
+ # we move it out of wip
1975
+ os.rename(self.arrays_path / name, self.path / name)
1976
+ core.update_progress(1)
1677
1977
 
1678
- def encode_filter_id(self):
1679
- array = self.root.array(
1680
- "filter_id",
1681
- self.schema.filter_id,
1682
- dtype="str",
1683
- compressor=DEFAULT_ZARR_COMPRESSOR,
1978
+ def finalise(self, show_progress=False):
1979
+ self.load_metadata()
1980
+
1981
+ progress_config = core.ProgressConfig(
1982
+ total=len(self.schema.columns),
1983
+ title="Finalise",
1984
+ units="array",
1985
+ show=show_progress,
1684
1986
  )
1685
- array.attrs["_ARRAY_DIMENSIONS"] = ["filters"]
1686
- return {v: j for j, v in enumerate(self.schema.filter_id)}
1987
+ # NOTE: it's not clear that adding more workers will make this quicker,
1988
+ # as it's just going to be causing contention on the file system.
1989
+ # Something to check empirically in some deployments.
1990
+ # FIXME we're just using worker_processes=0 here to hook into the
1991
+ # SynchronousExecutor which is intended for testing purposes so
1992
+ # that we get test coverage. Should fix this either by allowing
1993
+ # for multiple workers, or making a standard wrapper for tqdm
1994
+ # that allows us to have a consistent look and feel.
1995
+ with core.ParallelWorkManager(0, progress_config) as pwm:
1996
+ for name in self.schema.columns:
1997
+ pwm.submit(self.finalise_array, name)
1998
+ zarr.consolidate_metadata(self.path)
1687
1999
 
1688
- def init(self):
1689
- self.root.attrs["vcf_zarr_version"] = "0.2"
1690
- self.root.attrs["vcf_header"] = self.icf.vcf_header
1691
- self.root.attrs["source"] = f"bio2zarr-{provenance.__version__}"
1692
- for column in self.schema.columns.values():
1693
- self.init_array(column)
2000
+ ######################
2001
+ # encode_all_partitions
2002
+ ######################
1694
2003
 
1695
- def finalise(self):
1696
- zarr.consolidate_metadata(self.path)
2004
+ def get_max_encoding_memory(self):
2005
+ """
2006
+ Return the approximate maximum memory used to encode a variant chunk.
2007
+ """
2008
+ max_encoding_mem = max(
2009
+ col.variant_chunk_nbytes for col in self.schema.columns.values()
2010
+ )
2011
+ gt_mem = 0
2012
+ if "call_genotype" in self.schema.columns:
2013
+ encoded_together = [
2014
+ "call_genotype",
2015
+ "call_genotype_phased",
2016
+ "call_genotype_mask",
2017
+ ]
2018
+ gt_mem = sum(
2019
+ self.schema.columns[col].variant_chunk_nbytes
2020
+ for col in encoded_together
2021
+ )
2022
+ return max(max_encoding_mem, gt_mem)
1697
2023
 
1698
- def encode(
1699
- self,
1700
- worker_processes=1,
1701
- max_v_chunks=None,
1702
- show_progress=False,
1703
- max_memory=None,
2024
+ def encode_all_partitions(
2025
+ self, *, worker_processes=1, show_progress=False, max_memory=None
1704
2026
  ):
1705
2027
  max_memory = parse_max_memory(max_memory)
1706
-
1707
- # TODO this will move into the setup logic later when we're making it possible
1708
- # to split the work by slice
1709
- num_slices = max(1, worker_processes * 4)
1710
- # Using POS arbitrarily to get the array slices
1711
- slices = core.chunk_aligned_slices(
1712
- self.get_array("variant_position"), num_slices, max_chunks=max_v_chunks
2028
+ self.load_metadata()
2029
+ num_partitions = self.num_partitions
2030
+ per_worker_memory = self.get_max_encoding_memory()
2031
+ logger.info(
2032
+ f"Encoding Zarr over {num_partitions} partitions with "
2033
+ f"{worker_processes} workers and {display_size(per_worker_memory)} "
2034
+ "per worker"
1713
2035
  )
1714
- truncated = slices[-1][-1]
1715
- for array in self.root.values():
1716
- if array.attrs["_ARRAY_DIMENSIONS"][0] == "variants":
1717
- shape = list(array.shape)
1718
- shape[0] = truncated
1719
- array.resize(shape)
1720
-
1721
- total_bytes = 0
1722
- encoding_memory_requirements = {}
1723
- for col in self.schema.columns.values():
1724
- array = self.get_array(col.name)
1725
- # NOTE!! this is bad, we're potentially creating quite a large
1726
- # numpy array for basically nothing. We can compute this.
1727
- variant_chunk_size = array.blocks[0].nbytes
1728
- encoding_memory_requirements[col.name] = variant_chunk_size
1729
- logger.debug(
1730
- f"{col.name} requires at least {display_size(variant_chunk_size)} per worker"
2036
+ # Each partition requires per_worker_memory bytes, so to prevent more than
2037
+ # max_memory being used, we clamp the number of workers
2038
+ max_num_workers = max_memory // per_worker_memory
2039
+ if max_num_workers < worker_processes:
2040
+ logger.warning(
2041
+ f"Limiting number of workers to {max_num_workers} to "
2042
+ f"keep within specified memory budget of {display_size(max_memory)}"
1731
2043
  )
1732
- total_bytes += array.nbytes
1733
-
1734
- filter_id_map = self.encode_filter_id()
1735
- contig_id_map = self.encode_contig_id()
1736
-
1737
- work = []
1738
- for start, stop in slices:
1739
- for col in self.schema.columns.values():
1740
- if col.vcf_field is not None:
1741
- f = functools.partial(self.encode_array_slice, col)
1742
- work.append(
1743
- EncodingWork(
1744
- f,
1745
- start,
1746
- stop,
1747
- [col.name],
1748
- encoding_memory_requirements[col.name],
1749
- )
1750
- )
1751
- work.append(
1752
- EncodingWork(self.encode_alleles_slice, start, stop, ["variant_allele"])
1753
- )
1754
- work.append(
1755
- EncodingWork(
1756
- self.encode_id_slice, start, stop, ["variant_id", "variant_id_mask"]
1757
- )
1758
- )
1759
- work.append(
1760
- EncodingWork(
1761
- functools.partial(self.encode_filters_slice, filter_id_map),
1762
- start,
1763
- stop,
1764
- ["variant_filter"],
1765
- )
1766
- )
1767
- work.append(
1768
- EncodingWork(
1769
- functools.partial(self.encode_contig_slice, contig_id_map),
1770
- start,
1771
- stop,
1772
- ["variant_contig"],
1773
- )
2044
+ if max_num_workers <= 0:
2045
+ raise ValueError(
2046
+ f"Insufficient memory to encode a partition:"
2047
+ f"{display_size(per_worker_memory)} > {display_size(max_memory)}"
1774
2048
  )
1775
- if "call_genotype" in self.schema.columns:
1776
- variables = [
1777
- "call_genotype",
1778
- "call_genotype_phased",
1779
- "call_genotype_mask",
1780
- ]
1781
- gt_memory = sum(
1782
- encoding_memory_requirements[name] for name in variables
1783
- )
1784
- work.append(
1785
- EncodingWork(
1786
- self.encode_genotypes_slice, start, stop, variables, gt_memory
1787
- )
1788
- )
2049
+ num_workers = min(max_num_workers, worker_processes)
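The clamp above is simple integer arithmetic; a hypothetical example with an 8 GiB budget and roughly 1.6 GB needed per variant chunk:

max_memory = 8 * 1024**3           # 8 GiB budget
per_worker_memory = 1_600_000_000  # from get_max_encoding_memory(), say
max_num_workers = max_memory // per_worker_memory  # -> 5
num_workers = min(max_num_workers, 8)               # worker_processes=8 requested, 5 used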
1789
2050
 
1790
- # Fail early if we can't fit a particular column into memory
1791
- for wp in work:
1792
- if wp.memory > max_memory:
1793
- raise ValueError(
1794
- f"Insufficient memory for {wp.columns}: "
1795
- f"{display_size(wp.memory)} > {display_size(max_memory)}"
1796
- )
2051
+ total_bytes = 0
2052
+ for col in self.schema.columns.values():
2053
+ # Open the array definition to get the total size
2054
+ total_bytes += zarr.open(self.arrays_path / col.name).nbytes
1797
2055
 
1798
2056
  progress_config = core.ProgressConfig(
1799
2057
  total=total_bytes,
@@ -1801,53 +2059,9 @@ class VcfZarrWriter:
1801
2059
  units="B",
1802
2060
  show=show_progress,
1803
2061
  )
1804
-
1805
- used_memory = 0
1806
- # We need to keep some bounds on the queue size or the memory bounds algorithm
1807
- # below doesn't really work.
1808
- max_queued = 4 * max(1, worker_processes)
1809
- encoded_slices = collections.Counter()
1810
-
1811
- with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
1812
- future = pwm.submit(self.encode_samples)
1813
- future_to_work = {future: EncodingWork(None, 0, 0, [])}
1814
-
1815
- def service_completed_futures():
1816
- nonlocal used_memory
1817
-
1818
- completed = pwm.wait_for_completed()
1819
- for future in completed:
1820
- wp_done = future_to_work.pop(future)
1821
- used_memory -= wp_done.memory
1822
- logger.debug(
1823
- f"Complete {wp_done}: used mem={display_size(used_memory)}"
1824
- )
1825
- for column in wp_done.columns:
1826
- encoded_slices[column] += 1
1827
- if encoded_slices[column] == len(slices):
1828
- # Do this syncronously for simplicity. Should be
1829
- # fine as the workers will probably be busy with
1830
- # large encode tasks most of the time.
1831
- self.finalise_array(column)
1832
-
1833
- for wp in work:
1834
- while (
1835
- used_memory + wp.memory > max_memory
1836
- or len(future_to_work) > max_queued
1837
- ):
1838
- logger.debug(
1839
- f"Wait: mem_required={used_memory + wp.memory} max_mem={max_memory} "
1840
- f"queued={len(future_to_work)} max_queued={max_queued}"
1841
- )
1842
- service_completed_futures()
1843
- future = pwm.submit(wp.func, wp.start, wp.stop)
1844
- used_memory += wp.memory
1845
- logger.debug(f"Submit {wp}: used mem={display_size(used_memory)}")
1846
- future_to_work[future] = wp
1847
-
1848
- logger.debug("All work submitted")
1849
- while len(future_to_work) > 0:
1850
- service_completed_futures()
2062
+ with core.ParallelWorkManager(num_workers, progress_config) as pwm:
2063
+ for partition_index in range(num_partitions):
2064
+ pwm.submit(self.encode_partition, partition_index)
1851
2065
 
1852
2066
 
1853
2067
  def mkschema(if_path, out):
@@ -1862,13 +2076,48 @@ def encode(
1862
2076
  schema_path=None,
1863
2077
  variants_chunk_size=None,
1864
2078
  samples_chunk_size=None,
1865
- max_v_chunks=None,
2079
+ max_variant_chunks=None,
1866
2080
  dimension_separator=None,
1867
2081
  max_memory=None,
1868
2082
  worker_processes=1,
1869
2083
  show_progress=False,
1870
2084
  ):
1871
- icf = IntermediateColumnarFormat(if_path)
2085
+ # Rough heuristic to split work up enough to keep utilisation high
2086
+ target_num_partitions = max(1, worker_processes * 4)
2087
+ encode_init(
2088
+ if_path,
2089
+ zarr_path,
2090
+ target_num_partitions,
2091
+ schema_path=schema_path,
2092
+ variants_chunk_size=variants_chunk_size,
2093
+ samples_chunk_size=samples_chunk_size,
2094
+ max_variant_chunks=max_variant_chunks,
2095
+ dimension_separator=dimension_separator,
2096
+ )
2097
+ vzw = VcfZarrWriter(zarr_path)
2098
+ vzw.encode_all_partitions(
2099
+ worker_processes=worker_processes,
2100
+ show_progress=show_progress,
2101
+ max_memory=max_memory,
2102
+ )
2103
+ vzw.finalise(show_progress)
2104
+
2105
+
2106
+ def encode_init(
2107
+ icf_path,
2108
+ zarr_path,
2109
+ target_num_partitions,
2110
+ *,
2111
+ schema_path=None,
2112
+ variants_chunk_size=None,
2113
+ samples_chunk_size=None,
2114
+ max_variant_chunks=None,
2115
+ dimension_separator=None,
2116
+ max_memory=None,
2117
+ worker_processes=1,
2118
+ show_progress=False,
2119
+ ):
2120
+ icf = IntermediateColumnarFormat(icf_path)
1872
2121
  if schema_path is None:
1873
2122
  schema = VcfZarrSchema.generate(
1874
2123
  icf,
@@ -1881,21 +2130,28 @@ def encode(
1881
2130
  raise ValueError(
1882
2131
  "Cannot specify schema along with chunk sizes"
1883
2132
  ) # NEEDS TEST
1884
- with open(schema_path, "r") as f:
2133
+ with open(schema_path) as f:
1885
2134
  schema = VcfZarrSchema.fromjson(f.read())
1886
2135
  zarr_path = pathlib.Path(zarr_path)
1887
- if zarr_path.exists():
1888
- logger.warning(f"Deleting existing {zarr_path}")
1889
- shutil.rmtree(zarr_path)
1890
- vzw = VcfZarrWriter(zarr_path, icf, schema, dimension_separator=dimension_separator)
1891
- vzw.init()
1892
- vzw.encode(
1893
- max_v_chunks=max_v_chunks,
1894
- worker_processes=worker_processes,
1895
- max_memory=max_memory,
1896
- show_progress=show_progress,
2136
+ vzw = VcfZarrWriter(zarr_path)
2137
+ vzw.init(
2138
+ icf,
2139
+ target_num_partitions=target_num_partitions,
2140
+ schema=schema,
2141
+ dimension_separator=dimension_separator,
2142
+ max_variant_chunks=max_variant_chunks,
1897
2143
  )
1898
- vzw.finalise()
2144
+ return vzw.num_partitions, vzw.get_max_encoding_memory()
2145
+
2146
+
2147
+ def encode_partition(zarr_path, partition):
2148
+ writer = VcfZarrWriter(zarr_path)
2149
+ writer.encode_partition(partition)
2150
+
2151
+
2152
+ def encode_finalise(zarr_path, show_progress=False):
2153
+ writer = VcfZarrWriter(zarr_path)
2154
+ writer.finalise(show_progress=show_progress)
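Together, encode_init, encode_partition and encode_finalise form a three-step API that an external scheduler can drive; a hedged sketch (the paths and partition count are hypothetical):

num_partitions, per_worker_memory = encode_init(
    "sample.icf", "sample.vcz", target_num_partitions=16
)
for j in range(num_partitions):
    # Each call is independent, so these can run as separate jobs.
    encode_partition("sample.vcz", j)
encode_finalise("sample.vcz", show_progress=True)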
1899
2155
 
1900
2156
 
1901
2157
  def convert(
@@ -1962,7 +2218,7 @@ def assert_all_fill(zarr_val, vcf_type):
1962
2218
  elif vcf_type == "Float":
1963
2219
  assert_all_fill_float(zarr_val)
1964
2220
  else: # pragma: no cover
1965
- assert False
2221
+ assert False # noqa PT015
1966
2222
 
1967
2223
 
1968
2224
  def assert_all_missing(zarr_val, vcf_type):
@@ -1975,7 +2231,7 @@ def assert_all_missing(zarr_val, vcf_type):
1975
2231
  elif vcf_type == "Float":
1976
2232
  assert_all_missing_float(zarr_val)
1977
2233
  else: # pragma: no cover
1978
- assert False
2234
+ assert False # noqa PT015
1979
2235
 
1980
2236
 
1981
2237
  def assert_info_val_missing(zarr_val, vcf_type):
@@ -2105,7 +2361,7 @@ def validate(vcf_path, zarr_path, show_progress=False):
2105
2361
  assert pos[start_index] == first_pos
2106
2362
  vcf = cyvcf2.VCF(vcf_path)
2107
2363
  if show_progress:
2108
- iterator = tqdm.tqdm(vcf, desc=" Verify", total=vcf.num_records) # NEEDS TEST
2364
+ iterator = tqdm.tqdm(vcf, desc=" Verify", total=vcf.num_records) # NEEDS TEST
2109
2365
  else:
2110
2366
  iterator = vcf
2111
2367
  for j, row in enumerate(iterator, start_index):
@@ -2114,7 +2370,7 @@ def validate(vcf_path, zarr_path, show_progress=False):
2114
2370
  assert vid[j] == ("." if row.ID is None else row.ID)
2115
2371
  assert allele[j, 0] == row.REF
2116
2372
  k = len(row.ALT)
2117
- nt.assert_array_equal(allele[j, 1 : k + 1], row.ALT),
2373
+ nt.assert_array_equal(allele[j, 1 : k + 1], row.ALT)
2118
2374
  assert np.all(allele[j, k + 1 :] == "")
2119
2375
  # TODO FILTERS
2120
2376