bio2zarr 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bio2zarr might be problematic. Click here for more details.

bio2zarr/tskit.py ADDED
@@ -0,0 +1,301 @@
1
+ import logging
2
+ import pathlib
3
+
4
+ import numpy as np
5
+
6
+ from bio2zarr import constants, core, vcz
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
class TskitFormat(vcz.Source):
    """
    A :class:`vcz.Source` backed by a single tskit tree sequence.

    Sites of the tree sequence are exposed as variant records, and the
    rows of the model mapping's ``individuals_nodes`` array are exposed as
    samples. All variants are reported on a single contig.
    """

    @core.requires_optional_dependency("tskit", "tskit")
    def __init__(
        self,
        ts,
        *,
        model_mapping=None,
        contig_id=None,
        isolated_as_missing=False,
    ):
        """
        Wrap a tree sequence, or load one, and precompute sample indexing.

        :param ts: Either a ``tskit.TreeSequence`` instance, or a value that
            is handed to ``tskit.load`` (treated as a path).
        :param model_mapping: Object providing ``individuals_nodes`` and
            ``individuals_name``. If None, ``ts.map_to_vcf_model()`` is used.
        :param contig_id: ID of the single contig; defaults to ``"1"``.
        :param isolated_as_missing: Forwarded to ``ts.variants`` when
            genotypes are decoded.
        :raises ValueError: If ``individuals_nodes`` has no rows, has no
            non-negative entries, or its row count differs from the number
            of sample names.
        """
        import tskit

        self._path = None
        # Future versions here will need to deal with the complexities of
        # having lists of tree sequences for multiple chromosomes.
        if isinstance(ts, tskit.TreeSequence):
            self.ts = ts
        else:
            # input 'ts' is a path.
            self._path = ts
            logger.info(f"Loading from {ts}")
            self.ts = tskit.load(ts)
        logger.info(
            f"Input has {self.ts.num_individuals} individuals and "
            f"{self.ts.num_sites} sites"
        )

        self.contig_id = contig_id if contig_id is not None else "1"
        self.isolated_as_missing = isolated_as_missing

        self.positions = self.ts.sites_position

        if model_mapping is None:
            model_mapping = self.ts.map_to_vcf_model()

        # individuals_nodes is indexed as [sample, ploidy]; negative entries
        # are treated as "no node" below.
        individuals_nodes = model_mapping.individuals_nodes
        sample_ids = model_mapping.individuals_name

        self._num_samples = individuals_nodes.shape[0]
        logger.info(f"Converting for {self._num_samples} samples")
        if self._num_samples < 1:
            raise ValueError("individuals_nodes must have at least one sample")
        self.max_ploidy = individuals_nodes.shape[1]
        if len(sample_ids) != self._num_samples:
            raise ValueError(
                f"Length of sample_ids ({len(sample_ids)}) does not match "
                f"number of samples ({self._num_samples})"
            )

        self._samples = [vcz.Sample(id=sample_id) for sample_id in sample_ids]

        # The distinct node IDs are what we ask tskit to decode genotypes
        # for; genotype_indices maps each valid (sample, ploidy) slot to the
        # position of its node within that decoded genotype vector.
        self.tskit_samples = np.unique(individuals_nodes[individuals_nodes >= 0])
        if len(self.tskit_samples) < 1:
            raise ValueError("individuals_nodes must have at least one valid sample")
        node_id_to_index = {node_id: i for i, node_id in enumerate(self.tskit_samples)}
        valid_mask = individuals_nodes >= 0
        self.sample_indices, self.ploidy_indices = np.where(valid_mask)
        self.genotype_indices = np.array(
            [node_id_to_index[node_id] for node_id in individuals_nodes[valid_mask]]
        )

    @property
    def path(self):
        """The value the tree sequence was loaded from, or None."""
        return self._path

    @property
    def num_records(self):
        """Number of variant records, i.e. the number of sites."""
        return self.ts.num_sites

    @property
    def num_samples(self):
        """Number of rows in the model mapping's ``individuals_nodes``."""
        return self._num_samples

    @property
    def samples(self):
        """List of :class:`vcz.Sample`, one per sample name."""
        return self._samples

    @property
    def root_attrs(self):
        """Root attributes contributed by this source (none)."""
        return {}

    @property
    def contigs(self):
        """Single-element list holding the contig named ``contig_id``."""
        return [vcz.Contig(id=self.contig_id)]

    def iter_contig(self, start, stop):
        """Yield contig index 0 for every record in ``[start, stop)``."""
        yield from (0 for _ in range(start, stop))

    def iter_field(self, field_name, shape, start, stop):
        """
        Yield per-record values of ``field_name`` for ``[start, stop)``.

        Only ``"position"`` is supported; site positions are yielded as
        Python ints.

        :raises ValueError: If ``field_name`` is not ``"position"``.
        """
        if field_name == "position":
            for pos in self.ts.sites_position[start:stop]:
                yield int(pos)
        else:
            raise ValueError(f"Unknown field {field_name}")

    def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
        """
        Yield a :class:`vcz.VariantData` for each site in ``[start, stop)``.

        :param start: Index of the first site to decode.
        :param stop: Index one past the last site to decode.
        :param shape: Shape of the per-variant genotype array; ``shape[:-1]``
            is used for the phased array.
        :param num_alleles: Length of the per-variant alleles array.
        """
        # All genotypes in tskit are considered phased
        phased = np.ones(shape[:-1], dtype=bool)
        logger.debug(f"Getting genotypes start={start} stop={stop}")

        if start >= stop:
            # Empty request: yield nothing, like iter_contig/iter_field do.
            # This also avoids indexing self.positions out of range when
            # the tree sequence has no sites.
            return

        for variant in self.ts.variants(
            isolated_as_missing=self.isolated_as_missing,
            left=self.positions[start],
            # stop may equal num_records, in which case there is no site
            # position to use as the right bound.
            right=self.positions[stop] if stop < self.num_records else None,
            samples=self.tskit_samples,
            copy=False,
        ):
            gt = np.full(shape, constants.INT_FILL, dtype=np.int8)
            alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
            # length is the length of the REF allele unless other fields
            # are included.
            variant_length = len(variant.alleles[0])
            for i, allele in enumerate(variant.alleles):
                # None is returned by tskit in the case of a missing allele
                if allele is None:
                    continue
                assert i < num_alleles
                alleles[i] = allele
            # Scatter the per-node genotypes into their (sample, ploidy)
            # slots; slots with no node keep the fill value.
            gt[self.sample_indices, self.ploidy_indices] = variant.genotypes[
                self.genotype_indices
            ]

            yield vcz.VariantData(variant_length, alleles, gt, phased)

    def generate_schema(
        self,
        variants_chunk_size=None,
        samples_chunk_size=None,
    ):
        """
        Build a :class:`vcz.VcfZarrSchema` describing this tree sequence.

        Scans all sites to find the maximum number of distinct allelic
        states, and sizes integer dtypes from the observed position range,
        allele-string lengths and allele count.

        :param variants_chunk_size: Passed to ``vcz.standard_dimensions``.
        :param samples_chunk_size: Passed to ``vcz.standard_dimensions``.
        :return: The populated schema instance.
        """
        n = self.num_samples
        m = self.ts.num_sites

        # Determine max number of alleles
        max_alleles = 0
        for site in self.ts.sites():
            states = {site.ancestral_state}
            for mut in site.mutations:
                states.add(mut.derived_state)
            max_alleles = max(len(states), max_alleles)

        # Use the module logger (not the root logger) so these messages
        # follow the same configuration as the rest of this module.
        logger.info(f"Scanned tskit with {n} samples and {m} variants")
        logger.info(
            f"Maximum ploidy: {self.max_ploidy}, maximum alleles: {max_alleles}"
        )
        dimensions = vcz.standard_dimensions(
            variants_size=m,
            variants_chunk_size=variants_chunk_size,
            samples_size=n,
            samples_chunk_size=samples_chunk_size,
            ploidy_size=self.max_ploidy,
            alleles_size=max_alleles,
        )
        schema_instance = vcz.VcfZarrSchema(
            format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
            dimensions=dimensions,
            fields=[],
        )

        logger.info(
            "Generating schema with chunks="
            f"{schema_instance.dimensions['variants'].chunk_size}, "
            f"{schema_instance.dimensions['samples'].chunk_size}"
        )

        # Check if positions will fit in i4 (max ~2.1 billion)
        min_position = 0
        max_position = 0
        if self.ts.num_sites > 0:
            min_position = np.min(self.ts.sites_position)
            max_position = np.max(self.ts.sites_position)

        # Allele string lengths are the differences between consecutive
        # offsets in the site and mutation tables.
        tables = self.ts.tables
        ancestral_state_offsets = tables.sites.ancestral_state_offset
        derived_state_offsets = tables.mutations.derived_state_offset
        ancestral_lengths = ancestral_state_offsets[1:] - ancestral_state_offsets[:-1]
        derived_lengths = derived_state_offsets[1:] - derived_state_offsets[:-1]
        max_variant_length = max(
            np.max(ancestral_lengths) if len(ancestral_lengths) > 0 else 0,
            np.max(derived_lengths) if len(derived_lengths) > 0 else 0,
        )

        array_specs = [
            vcz.ZarrArraySpec(
                source="position",
                name="variant_position",
                dtype=core.min_int_dtype(min_position, max_position),
                dimensions=["variants"],
                description="Position of each variant",
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="variant_allele",
                dtype="O",
                dimensions=["variants", "alleles"],
                description="Alleles for each variant",
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="variant_length",
                dtype=core.min_int_dtype(0, max_variant_length),
                dimensions=["variants"],
                description="Length of each variant",
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="variant_contig",
                dtype=core.min_int_dtype(0, len(self.contigs)),
                dimensions=["variants"],
                description="Contig/chromosome index for each variant",
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="call_genotype_phased",
                dtype="bool",
                dimensions=["variants", "samples"],
                description="Whether the genotype is phased",
                compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="call_genotype",
                dtype=core.min_int_dtype(constants.INT_FILL, max_alleles - 1),
                dimensions=["variants", "samples", "ploidy"],
                description="Genotype for each variant and sample",
                compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(),
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="call_genotype_mask",
                dtype="bool",
                dimensions=["variants", "samples", "ploidy"],
                description="Mask for each genotype call",
                compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
            ),
        ]
        schema_instance.fields = array_specs
        return schema_instance
248
+
249
+
250
def convert(
    ts_or_path,
    vcz_path,
    *,
    model_mapping=None,
    contig_id=None,
    isolated_as_missing=False,
    variants_chunk_size=None,
    samples_chunk_size=None,
    worker_processes=core.DEFAULT_WORKER_PROCESSES,
    show_progress=False,
):
    """
    Convert a :class:`tskit.TreeSequence` (or path to a tree sequence
    file) to VCF Zarr format stored at the specified path.

    :param ts_or_path: Tree sequence, or path to one; passed to
        :class:`TskitFormat`.
    :param vcz_path: Output location; wrapped in :class:`pathlib.Path`.
    :param model_mapping: Passed to :class:`TskitFormat`.
    :param contig_id: Passed to :class:`TskitFormat`.
    :param isolated_as_missing: Passed to :class:`TskitFormat`.
    :param variants_chunk_size: Passed to ``generate_schema``.
    :param samples_chunk_size: Passed to ``generate_schema``.
    :param worker_processes: Passed to ``encode_all_partitions``; also
        determines the target number of partitions.
    :param show_progress: Passed to the encode and finalise steps.
    """
    # FIXME there's some tricky details here in how we're handling
    # parallelism that we'll need to tackle properly, and maybe
    # review the current structures a bit. Basically, it looks like
    # we're pickling/unpickling the format object when we have
    # multiple workers, and this results in several copies of the
    # tree sequence object being passed around. This is fine most
    # of the time, but results in lots of memory being used when
    # we're dealing with really massive files.
    # See https://github.com/sgkit-dev/bio2zarr/issues/403
    source = TskitFormat(
        ts_or_path,
        model_mapping=model_mapping,
        contig_id=contig_id,
        isolated_as_missing=isolated_as_missing,
    )
    schema = source.generate_schema(
        variants_chunk_size=variants_chunk_size,
        samples_chunk_size=samples_chunk_size,
    )

    writer = vcz.VcfZarrWriter(TskitFormat, pathlib.Path(vcz_path))
    writer.init(
        source,
        # Rough heuristic to split work up enough to keep utilisation high
        target_num_partitions=max(1, worker_processes * 4),
        schema=schema,
    )
    writer.encode_all_partitions(
        worker_processes=worker_processes,
        show_progress=show_progress,
    )
    writer.finalise(show_progress)
    writer.create_index()
bio2zarr/typing.py CHANGED
@@ -1,4 +1,3 @@
1
1
  from pathlib import Path
2
- from typing import Union
3
2
 
4
- PathType = Union[str, Path]
3
+ PathType = str | Path