tskit 1.0.0b1__cp310-cp310-win_amd64.whl → 1.0.0b2__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Binary file
tskit/_version.py CHANGED
@@ -1,4 +1,4 @@
1
1
  # Definitive location for the version number.
2
2
  # During development, should be x.y.z.devN
3
3
  # For beta should be x.y.zbN
4
- tskit_version = "1.0.0b1"
4
+ tskit_version = "1.0.0b2"
tskit/jit/__init__.py ADDED
File without changes
tskit/jit/numba.py ADDED
@@ -0,0 +1,674 @@
1
+ import functools
2
+ import os
3
+
4
+ import numpy as np
5
+
6
+ import tskit
7
+
8
+ try:
9
+ import numba
10
+
11
+ except ImportError:
12
+ raise ImportError(
13
+ "Numba is not installed. Please install it with `pip install numba` "
14
+ "or `conda install numba` to use the tskit.jit.numba module."
15
+ )
16
+
17
+
18
#: Direction constant for forward tree traversal
FORWARD = 1
#: Direction constant for reverse tree traversal
REVERSE = -1

# Bind the tskit constants to module-level names once, so that tight loops
# do not have to look them up on the tskit module every time.
NULL = tskit.NULL
NODE_IS_SAMPLE = tskit.NODE_IS_SAMPLE

# Field layout passed to numba's jitclass for EdgeRange.
edge_range_spec = [
    ("start", numba.int32), ("stop", numba.int32),
    ("order", numba.int32[:]),
]

# Field layout passed to numba's jitclass for ParentIndex.
parent_index_spec = [
    ("edge_index", numba.int32[:]),
    ("index_range", numba.int32[:, :]),
]
35
+
36
+
37
@numba.experimental.jitclass(edge_range_spec)
class EdgeRange:
    """
    A window onto an ordered list of edge IDs.

    Stepping from one tree to a neighbouring tree removes one batch of
    edges and adds another. Each batch is described by one ``EdgeRange``:
    ``order`` holds edge IDs in processing order and ``start``/``stop``
    delimit the entries of ``order`` that belong to the batch.

    Attributes
    ----------
    start : int
        Index into ``order`` at which the batch begins.
    stop : int
        Index into ``order`` at which the batch ends.
    order : numpy.ndarray
        Array (dtype=np.int32) of edge IDs in the order they are processed.
        During forward traversal the batch is ``order[start:stop]``.
    """

    def __init__(self, start, stop, order):
        # Plain field stores; the types are fixed by ``edge_range_spec``.
        self.order = order
        self.start = start
        self.stop = stop
62
+
63
+
64
@numba.experimental.jitclass(parent_index_spec)
class ParentIndex:
    """
    Lookup structure for the edges in which a node is the child.

    The tskit edge table is not sorted by child, so the edges of one child
    are scattered through it. ``edge_index`` lists edge IDs grouped by child
    node, and ``index_range`` records, for every node, the [start, stop)
    slice of ``edge_index`` that holds that node's group.

    Attributes
    ----------
    edge_index : numpy.ndarray
        Array (dtype=np.int32) of edge IDs grouped by child node.
    index_range : numpy.ndarray
        Array (dtype=np.int32, shape=(num_nodes, 2)); row ``u`` is the
        [start, stop) range in ``edge_index`` of the edges whose child is ``u``.
    """

    def __init__(self, edge_index, index_range):
        # Pure data container: keep references to both arrays as given.
        self.index_range = index_range
        self.edge_index = edge_index
86
+
87
+
88
class TreeIndex:
    """
    Traverse trees in a numba compatible tree sequence.

    This class provides efficient forward and backward iteration through
    the trees in a tree sequence. It provides the tree interval,
    edge changes to create the current tree, along with its sites and mutations.
    A full pass over the trees using repeated `next` or `prev` requires O(E + M + S) time
    complexity.

    It should not be instantiated directly, but is returned by the `tree_index` method
    of `NumbaTreeSequence`.


    Attributes
    ----------
    ts : NumbaTreeSequence
        Reference to the tree sequence being traversed.
    index : int
        Current tree index. -1 indicates no current tree (null state).
    direction : int
        Direction of the last move: this module's ``FORWARD`` or ``REVERSE``
        constant, or ``NULL`` if no move has been made yet.
    interval : tuple
        Genomic interval (left, right) covered by the current tree.
    in_range : EdgeRange
        Edges being added to form this current tree, relative to the last state
    out_range : EdgeRange
        Edges being removed to form this current tree, relative to the last state
    site_range : tuple
        Range of sites in the current tree (start, stop).
    mutation_range : tuple
        Range of mutations in the current tree (start, stop).

    Example
    -------
    >>> tree_index = numba_ts.tree_index()
    >>> num_edges = 0
    >>> while tree_index.next():
    ...     num_edges += (tree_index.in_range.stop - tree_index.in_range.start)
    ...     num_edges -= (tree_index.out_range.stop - tree_index.out_range.start)
    ...     print(f"Tree {tree_index.index}: {num_edges} edges")
    """

    def __init__(self, ts):
        self.ts = ts
        self.index = -1
        self.direction = NULL
        self.interval = (0, 0)
        # Placeholder ranges; ``next``/``prev`` point them at the real
        # insertion/removal order arrays on the first move.
        self.in_range = EdgeRange(0, 0, np.zeros(0, dtype=np.int32))
        self.out_range = EdgeRange(0, 0, np.zeros(0, dtype=np.int32))
        self.site_range = (0, 0)
        self.mutation_range = (0, 0)

    def set_null(self):
        """
        Reset the tree index to null state.

        Only ``index``, ``interval``, ``site_range`` and ``mutation_range``
        are reset; ``direction`` and the edge ranges keep their last values.
        """
        self.index = -1
        self.interval = (0, 0)
        self.site_range = (0, 0)
        self.mutation_range = (0, 0)

    def next(self):  # noqa: A003
        """
        Move to the next tree in forward direction.

        Updates the tree index to the next tree in the sequence,
        computing the edges that need to be added and removed to
        transform from the previous tree to the current tree.
        On the first call, this initializes the iterator and moves to tree 0.

        :return: True if successfully moved to next tree, False if the end
            of the tree sequence is reached. When False is returned, the iterator
            is in null state (index=-1).
        :rtype: bool
        """
        M = self.ts.num_edges
        NS = self.ts.num_sites
        NM = self.ts.num_mutations
        breakpoints = self.ts.breakpoints
        left_coords = self.ts.edges_left
        left_order = self.ts.indexes_edge_insertion_order
        right_coords = self.ts.edges_right
        right_order = self.ts.indexes_edge_removal_order
        sites_position = self.ts.sites_position
        mutations_site = self.ts.mutations_site

        if self.index == -1:
            # Starting from the null state: begin at the left end of the
            # sequence with all cursors at zero.
            self.interval = (self.interval[0], 0)
            self.out_range.stop = 0
            self.in_range.stop = 0
            self.direction = FORWARD
            self.site_range = (0, 0)
            self.mutation_range = (0, 0)

        if self.direction == FORWARD:
            left_current_index = self.in_range.stop
            right_current_index = self.out_range.stop
        else:
            # The last move was in reverse: the ranges were filled from the
            # opposite order arrays and their stops sit one before the next
            # unvisited entry, so swap the roles and step forward by one.
            left_current_index = self.out_range.stop + 1
            right_current_index = self.in_range.stop + 1

        # The new tree starts where the current interval ends.
        left = self.interval[1]

        # Edges whose right coordinate equals ``left`` leave the tree.
        j = right_current_index
        self.out_range.start = j
        while j < M and right_coords[right_order[j]] == left:
            j += 1
        self.out_range.stop = j
        self.out_range.order = right_order

        # Edges whose left coordinate equals ``left`` enter the tree.
        j = left_current_index
        self.in_range.start = j
        while j < M and left_coords[left_order[j]] == left:
            j += 1
        self.in_range.stop = j
        self.in_range.order = left_order

        self.direction = FORWARD
        self.index += 1
        if self.index == self.ts.num_trees:
            self.set_null()
        else:
            right = breakpoints[self.index + 1]
            self.interval = (left, right)

            # Find sites in current tree interval [left, right): continue
            # from the end of the previous site range.
            site_start = self.site_range[1]
            j = site_start
            while j < NS and sites_position[j] < right:
                j += 1
            self.site_range = (site_start, j)

            # Find mutations for sites in this interval: continue from the
            # end of the previous mutation range up to the first mutation
            # whose site is at or beyond ``j``.
            mutation_start = self.mutation_range[1]
            k = mutation_start
            while k < NM and mutations_site[k] < j:
                k += 1
            self.mutation_range = (mutation_start, k)

        return self.index != -1

    def prev(self):
        """
        Move to the previous tree in reverse direction.

        Updates the tree index to the previous tree in the sequence,
        computing the edges that need to be added and removed to
        transform from the next tree to the current tree.
        On the first call, this initializes the iterator and moves to the most
        rightward tree.

        :return: True if successfully moved to previous tree, False if the beginning
            of the tree sequence is reached. When False is returned, the iterator
            is in null state (index=-1).
        :rtype: bool
        """
        M = self.ts.num_edges
        NS = self.ts.num_sites
        NM = self.ts.num_mutations
        breakpoints = self.ts.breakpoints
        right_coords = self.ts.edges_right
        right_order = self.ts.indexes_edge_removal_order
        left_coords = self.ts.edges_left
        left_order = self.ts.indexes_edge_insertion_order
        sites_position = self.ts.sites_position
        mutations_site = self.ts.mutations_site

        if self.index == -1:
            # Starting from the null state: begin at the right end of the
            # sequence with all cursors at their last positions.
            self.index = self.ts.num_trees
            self.interval = (self.ts.sequence_length, self.interval[1])
            self.in_range.stop = M - 1
            self.out_range.stop = M - 1
            self.direction = REVERSE
            self.site_range = (NS, NS)
            self.mutation_range = (NM, NM)

        if self.direction == REVERSE:
            left_current_index = self.out_range.stop
            right_current_index = self.in_range.stop
        else:
            # The last move was forward: the ranges were filled from the
            # opposite order arrays and their stops sit one past the last
            # visited entry, so swap the roles and step back by one.
            left_current_index = self.in_range.stop - 1
            right_current_index = self.out_range.stop - 1

        # The new tree ends where the current interval starts.
        right = self.interval[0]

        # Edges whose left coordinate equals ``right`` leave the tree.
        j = left_current_index
        self.out_range.start = j
        while j >= 0 and left_coords[left_order[j]] == right:
            j -= 1
        self.out_range.stop = j
        self.out_range.order = left_order

        # Edges whose right coordinate equals ``right`` enter the tree.
        j = right_current_index
        self.in_range.start = j
        while j >= 0 and right_coords[right_order[j]] == right:
            j -= 1
        self.in_range.stop = j
        self.in_range.order = right_order

        self.direction = REVERSE
        self.index -= 1
        if self.index == -1:
            self.set_null()
        else:
            left = breakpoints[self.index]
            self.interval = (left, right)

            # Find sites in current tree interval [left, right) going
            # backward from the start of the previous site range.
            site_stop = self.site_range[0]
            j = site_stop - 1
            while j >= 0 and sites_position[j] >= left:
                j -= 1
            first_site = j + 1
            self.site_range = (first_site, site_stop)

            # Find mutations for sites in this interval going backward from
            # the start of the previous mutation range.
            mutation_stop = self.mutation_range[0]
            k = mutation_stop - 1
            while k >= 0 and mutations_site[k] >= first_site:
                k -= 1
            self.mutation_range = (k + 1, mutation_stop)

        return self.index != -1
312
+
313
+
314
class NumbaTreeSequence:
    """
    A Numba-compatible representation of a tree sequence.

    This class exposes the data of a tree sequence in a form that can be
    used from within Numba "njit" compiled functions. Use :meth:`jitwrap`
    to build a JIT-compiled instance from a :class:`tskit.TreeSequence`
    before passing it to a Numba function.

    Attributes
    ----------
    num_trees : int
        Number of trees in the tree sequence.
    num_nodes : int
        Number of nodes in the tree sequence.
    num_samples : int
        Number of samples in the tree sequence.
    num_edges : int
        Number of edges in the tree sequence.
    num_sites : int
        Number of sites in the tree sequence.
    num_mutations : int
        Number of mutations in the tree sequence.
    sequence_length : float
        Total sequence length of the tree sequence.
    edges_left : numpy.ndarray
        Array (dtype=np.float64) of left coordinates of edges.
    edges_right : numpy.ndarray
        Array (dtype=np.float64) of right coordinates of edges.
    edges_parent : numpy.ndarray
        Array (dtype=np.int32) of parent node IDs for each edge.
    edges_child : numpy.ndarray
        Array (dtype=np.int32) of child node IDs for each edge.
    nodes_time : numpy.ndarray
        Array (dtype=np.float64) of time values for each node.
    nodes_flags : numpy.ndarray
        Array (dtype=np.uint32) of flag values for each node.
    nodes_population : numpy.ndarray
        Array (dtype=np.int32) of population IDs for each node.
    nodes_individual : numpy.ndarray
        Array (dtype=np.int32) of individual IDs for each node.
    individuals_flags : numpy.ndarray
        Array (dtype=np.uint32) of flag values for each individual.
    sites_position : numpy.ndarray
        Array (dtype=np.float64) of positions of sites along the sequence.
    sites_ancestral_state : numpy.ndarray
        Array of fixed-width unicode strings, one ancestral state per site.
    mutations_site : numpy.ndarray
        Array (dtype=np.int32) of site IDs for each mutation.
    mutations_node : numpy.ndarray
        Array (dtype=np.int32) of node IDs for each mutation.
    mutations_parent : numpy.ndarray
        Array (dtype=np.int32) of parent mutation IDs.
    mutations_time : numpy.ndarray
        Array (dtype=np.float64) of time values for each mutation.
    mutations_derived_state : numpy.ndarray
        Array of fixed-width unicode strings, one derived state per mutation.
    mutations_inherited_state : numpy.ndarray
        Array of fixed-width unicode strings, one inherited state per mutation.
    breakpoints : numpy.ndarray
        Array (dtype=np.float64) of genomic positions where trees change.
    indexes_edge_insertion_order : numpy.ndarray
        Array (dtype=np.int32) specifying the order in which edges are inserted
        during tree building.
    indexes_edge_removal_order : numpy.ndarray
        Array (dtype=np.int32) specifying the order in which edges are removed
        during tree building.
    max_ancestral_length : int
        Width of the strings in ``sites_ancestral_state``.
    max_derived_length : int
        Width of the strings in ``mutations_derived_state``.
    max_inherited_length : int
        Width of the strings in ``mutations_inherited_state``.

    """

    def __init__(
        self,
        num_trees,
        num_nodes,
        num_samples,
        num_edges,
        num_sites,
        num_mutations,
        sequence_length,
        edges_left,
        edges_right,
        indexes_edge_insertion_order,
        indexes_edge_removal_order,
        individuals_flags,
        nodes_time,
        nodes_flags,
        nodes_population,
        nodes_individual,
        edges_parent,
        edges_child,
        sites_position,
        sites_ancestral_state,
        mutations_site,
        mutations_node,
        mutations_parent,
        mutations_time,
        mutations_derived_state,
        mutations_inherited_state,
        breakpoints,
        max_ancestral_length,
        max_derived_length,
        max_inherited_length,
    ):
        # Scalar summary values.
        self.num_trees = num_trees
        self.num_nodes = num_nodes
        self.num_samples = num_samples
        self.num_edges = num_edges
        self.num_sites = num_sites
        self.num_mutations = num_mutations
        self.sequence_length = sequence_length
        self.max_ancestral_length = max_ancestral_length
        self.max_derived_length = max_derived_length
        self.max_inherited_length = max_inherited_length

        # Edge table columns and the two edge orderings used for traversal.
        self.edges_left = edges_left
        self.edges_right = edges_right
        self.edges_parent = edges_parent
        self.edges_child = edges_child
        self.indexes_edge_insertion_order = indexes_edge_insertion_order
        self.indexes_edge_removal_order = indexes_edge_removal_order

        # Node and individual table columns.
        self.nodes_time = nodes_time
        self.nodes_flags = nodes_flags
        self.nodes_population = nodes_population
        self.nodes_individual = nodes_individual
        self.individuals_flags = individuals_flags

        # Site table columns.
        self.sites_position = sites_position
        self.sites_ancestral_state = sites_ancestral_state

        # Mutation table columns.
        self.mutations_site = mutations_site
        self.mutations_node = mutations_node
        self.mutations_parent = mutations_parent
        self.mutations_time = mutations_time
        self.mutations_derived_state = mutations_derived_state
        self.mutations_inherited_state = mutations_inherited_state

        # Tree boundaries along the genome.
        self.breakpoints = breakpoints

    def tree_index(self):
        """
        Create a :class:`TreeIndex` for traversing this tree sequence.

        :return: A new tree index initialized to the null tree.
            Use next() or prev() to move to an actual tree.
        :rtype: TreeIndex
        """
        # `_jitwrap` overrides this method in a subclass so that the
        # JIT-compiled TreeIndex class is used instead of the plain one.
        return TreeIndex(self)  # pragma: no cover

    def child_index(self):
        """
        Build an array locating, for every node, the edges in which it is
        the parent. A single linear pass over the edge table is made, so the
        time complexity is O(E).

        The pass records where each run of equal parent values starts and
        ends, so it relies on all edges of a given parent being stored
        contiguously in the edge table.

        :return: A numpy array (dtype=np.int32, shape=(num_nodes, 2)) where each row
            contains the [start, stop) range of edges where this node is the parent.
            Rows of nodes that are never a parent are left as (-1, -1).
        :rtype: numpy.ndarray
        """
        ranges = np.full((self.num_nodes, 2), -1, dtype=np.int32)
        total_edges = self.num_edges
        if total_edges == 0:
            return ranges

        parents = self.edges_parent
        previous = -1
        for eid in range(total_edges):
            current = parents[eid]
            if current == previous:
                # Still inside the same run of edges.
                continue
            # A new run starts here, which also closes the previous run.
            ranges[current, 0] = eid
            if previous != -1:
                ranges[previous, 1] = eid
            previous = current

        # Close the final run at the end of the edge table.
        if previous != -1:
            ranges[previous, 1] = total_edges

        return ranges

    def parent_index(self):
        """
        Create a :class:`ParentIndex` for finding parent edges of nodes.

        Edges within each child's group are not guaranteed to be in any
        specific order. This operation uses a two-pass algorithm with
        O(N + E) time complexity and O(N) auxiliary space.

        :return: A new parent index container that can be used to
            efficiently find all edges where a given node is the child.
        :rtype: ParentIndex
        """
        n_nodes = self.num_nodes
        n_edges = self.num_edges
        children = self.edges_child

        edges_per_child = np.zeros(n_nodes, dtype=np.int32)
        edge_index = np.zeros(n_edges, dtype=np.int32)
        index_range = np.zeros((n_nodes, 2), dtype=np.int32)

        if n_edges != 0:
            # Pass 1: count the edges in which each node is the child.
            for node in children:
                edges_per_child[node] += 1

            # Turn the counts into start offsets. The stop column is also
            # set to the start so that it can act as the insertion cursor
            # below; once every edge is placed it holds the true stop.
            offset = 0
            for node in range(n_nodes):
                index_range[node, :] = offset
                offset += edges_per_child[node]

            # Pass 2: drop each edge at its child's cursor and advance it.
            for eid in range(n_edges):
                node = children[eid]
                slot = index_range[node, 1]
                edge_index[slot] = eid
                index_range[node, 1] += 1

        return ParentIndex(edge_index, index_range)
529
+
530
+
531
+ # We cache these classes to avoid repeated JIT compilation
532
@functools.lru_cache(None)
def _jitwrap(max_ancestral_length, max_derived_length, max_inherited_length):
    """
    Build the JIT-compiled tree sequence class for the given string widths.

    The three arguments fix the widths of the unicode string arrays in the
    class spec. Results are cached per width combination.

    :return: The jitclass-wrapped subclass of ``NumbaTreeSequence`` whose
        ``tree_index`` method returns the jitclass-wrapped ``TreeIndex``.
    """
    # CI sets this env var to collect coverage of the jitted functions.
    # With JIT disabled EdgeRange has no class_type, so no spec is built.
    jit_disabled = os.environ.get("NUMBA_DISABLE_JIT") == "1"

    # The tree sequence class and the tree index class refer to each other,
    # so a deferred type stands in for the tree sequence until both exist.
    tree_sequence_type = numba.deferred_type()

    if jit_disabled:
        tree_index_spec = []
    else:
        edge_range_type = EdgeRange.class_type.instance_type
        tree_index_spec = [
            ("ts", tree_sequence_type),
            ("index", numba.int32),
            ("direction", numba.int32),
            ("interval", numba.types.UniTuple(numba.float64, 2)),
            ("in_range", edge_range_type),
            ("out_range", edge_range_type),
            ("site_range", numba.types.UniTuple(numba.int32, 2)),
            ("mutation_range", numba.types.UniTuple(numba.int32, 2)),
        ]

    JittedTreeIndex = numba.experimental.jitclass(tree_index_spec)(TreeIndex)

    # Short aliases for the field types used repeatedly below.
    i32 = numba.int32
    f64 = numba.float64
    i32_array = numba.int32[:]
    u32_array = numba.uint32[:]
    f64_array = numba.float64[:]

    def char_array(width):
        # 1-D array of fixed-width unicode strings.
        return numba.types.UnicodeCharSeq(width)[:]

    tree_sequence_spec = [
        ("num_trees", i32),
        ("num_nodes", i32),
        ("num_samples", i32),
        ("num_edges", i32),
        ("num_sites", i32),
        ("num_mutations", i32),
        ("sequence_length", f64),
        ("edges_left", f64_array),
        ("edges_right", f64_array),
        ("indexes_edge_insertion_order", i32_array),
        ("indexes_edge_removal_order", i32_array),
        ("individuals_flags", u32_array),
        ("nodes_time", f64_array),
        ("nodes_flags", u32_array),
        ("nodes_population", i32_array),
        ("nodes_individual", i32_array),
        ("edges_parent", i32_array),
        ("edges_child", i32_array),
        ("sites_position", f64_array),
        ("sites_ancestral_state", char_array(max_ancestral_length)),
        ("mutations_site", i32_array),
        ("mutations_node", i32_array),
        ("mutations_parent", i32_array),
        ("mutations_time", f64_array),
        ("mutations_derived_state", char_array(max_derived_length)),
        ("mutations_inherited_state", char_array(max_inherited_length)),
        ("breakpoints", f64_array),
        ("max_ancestral_length", i32),
        ("max_derived_length", i32),
        ("max_inherited_length", i32),
    ]

    # NumbaTreeSequence.tree_index returns the uncompiled TreeIndex. The
    # compiled class only exists from this point on, so the method is
    # replaced in a subclass before that subclass is compiled.
    class _NumbaTreeSequence(NumbaTreeSequence):
        def tree_index(self):
            return JittedTreeIndex(self)

    JittedTreeSequence = numba.experimental.jitclass(tree_sequence_spec)(
        _NumbaTreeSequence
    )

    # Both classes exist now, so the deferred type can be resolved.
    if not jit_disabled:
        tree_sequence_type.define(JittedTreeSequence.class_type.instance_type)

    return JittedTreeSequence
610
+
611
+
612
def jitwrap(ts):
    """
    Convert a TreeSequence to a Numba-compatible format.

    Builds a NumbaTreeSequence object holding the arrays and counts of
    ``ts`` so that it can be used within Numba-compiled functions.

    :param tskit.TreeSequence ts: The tree sequence to convert.
    :return: A Numba-compatible representation of the input tree sequence.
        Contains all necessary data arrays and metadata for tree traversal.
    :rtype: NumbaTreeSequence
    """

    def fixed_width(strings):
        # Width of the longest entry, but never below 1 (this also covers
        # an empty array), together with the array cast to that width.
        width = max(1, max(map(len, strings), default=1))
        return width, strings.astype(f"U{width}")

    max_ancestral_length, ancestral_state = fixed_width(ts.sites_ancestral_state)
    max_derived_length, derived_state = fixed_width(ts.mutations_derived_state)
    max_inherited_length, inherited_state = fixed_width(
        ts.mutations_inherited_state
    )

    # The compiled class depends only on the three string widths.
    JittedTreeSequence = _jitwrap(
        max_ancestral_length, max_derived_length, max_inherited_length
    )

    return JittedTreeSequence(
        num_trees=ts.num_trees,
        num_nodes=ts.num_nodes,
        num_samples=ts.num_samples,
        num_edges=ts.num_edges,
        num_sites=ts.num_sites,
        num_mutations=ts.num_mutations,
        sequence_length=ts.sequence_length,
        edges_left=ts.edges_left,
        edges_right=ts.edges_right,
        indexes_edge_insertion_order=ts.indexes_edge_insertion_order,
        indexes_edge_removal_order=ts.indexes_edge_removal_order,
        individuals_flags=ts.individuals_flags,
        nodes_time=ts.nodes_time,
        nodes_flags=ts.nodes_flags,
        nodes_population=ts.nodes_population,
        nodes_individual=ts.nodes_individual,
        edges_parent=ts.edges_parent,
        edges_child=ts.edges_child,
        sites_position=ts.sites_position,
        sites_ancestral_state=ancestral_state,
        mutations_site=ts.mutations_site,
        mutations_node=ts.mutations_node,
        mutations_parent=ts.mutations_parent,
        mutations_time=ts.mutations_time,
        mutations_derived_state=derived_state,
        mutations_inherited_state=inherited_state,
        breakpoints=ts.breakpoints(as_array=True),
        max_ancestral_length=max_ancestral_length,
        max_derived_length=max_derived_length,
        max_inherited_length=max_inherited_length,
    )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tskit
3
- Version: 1.0.0b1
3
+ Version: 1.0.0b2
4
4
  Summary: The tree sequence toolkit.
5
5
  Author-email: Tskit Developers <admin@tskit.dev>
6
6
  License: MIT
@@ -1,7 +1,7 @@
1
- _tskit.cp310-win_amd64.pyd,sha256=gNYFMvO4zT6NYwS_S_pa_-aYsoWvHzkLX0haHGsDgew,570880
1
+ _tskit.cp310-win_amd64.pyd,sha256=TJBfgWYrXdEbt7cGW9hYvS2YDgazUeO1U2Jt4kIXbr4,570880
2
2
  tskit/__init__.py,sha256=iw2dYlGVXcJUCDUpIuzj-x6h9SvF-j2wvbWkx5LUsSI,3312
3
3
  tskit/__main__.py,sha256=Dnv25YfzrMDwYSW3pHqOdoa7EZI3oUEw81vU6x3qY_c,71
4
- tskit/_version.py,sha256=uhZFyLeOUoG8Pt-iWUoxiDdmzQWOiFgkUjSCJhSWF38,148
4
+ tskit/_version.py,sha256=TAaXO78UDuZn2Q6kyb5R2HZANmJaZnESfTFS6_Psu70,148
5
5
  tskit/cli.py,sha256=kcWbw-lujEAYT5GQXU2gzcZOW3RHhcegH1D7S1-RNFE,9029
6
6
  tskit/combinatorics.py,sha256=q0Wf95djp5dYv5NqMzx6hAiKfmNtauiFFAqFim5v9LQ,57227
7
7
  tskit/drawing.py,sha256=SVG2r0IOne57oEKwLb7rPOhoXpz8v9r6lJmKj7rfeXs,114112
@@ -17,9 +17,11 @@ tskit/text_formats.py,sha256=ZDGEhAhar9Z8Zjjg3SZwm5-AQxHVSTjZy4kifZqTNtQ,15262
17
17
  tskit/trees.py,sha256=VNYpyGQk3ahF2950x6eDmBN35utM_SQUbdH2EZZ3SWY,502418
18
18
  tskit/util.py,sha256=OYlDi5phgNiirgiG62_MMzyiMC1CSKI4IGrY4gbj4i4,36448
19
19
  tskit/vcf.py,sha256=U6IqXuw8mWO7K5DgjmKH9s7GjSJHrnCDK53sMoq9UNQ,8814
20
- tskit-1.0.0b1.dist-info/licenses/LICENSE,sha256=h92w-U3mbrRNLntc2_9qHbCB2MmJsvDSBia6BTUCwLE,1099
21
- tskit-1.0.0b1.dist-info/METADATA,sha256=y9r31ja-D18FCOwrxKrJCXOJs6VcXVNwslX1ZU9wCIY,4645
22
- tskit-1.0.0b1.dist-info/WHEEL,sha256=KUuBC6lxAbHCKilKua8R9W_TM71_-9Sg5uEP3uDWcoU,101
23
- tskit-1.0.0b1.dist-info/entry_points.txt,sha256=3Zik1X8C9Io1WvmTRBao5yEG5Kwy_xhFdM-ABC9TkWQ,47
24
- tskit-1.0.0b1.dist-info/top_level.txt,sha256=6GsXJYqSCR5Uhb4Js0BBzC0EFXE0FA5ywslsixSbwGM,13
25
- tskit-1.0.0b1.dist-info/RECORD,,
20
+ tskit/jit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
21
+ tskit/jit/numba.py,sha256=vyH3Hpa1_a3PJMGLli9Zw8_t8nPJx7pO46z5PFd7xP4,25453
22
+ tskit-1.0.0b2.dist-info/licenses/LICENSE,sha256=h92w-U3mbrRNLntc2_9qHbCB2MmJsvDSBia6BTUCwLE,1099
23
+ tskit-1.0.0b2.dist-info/METADATA,sha256=TAY19u828s_IqIjncv0hn5wqH0ev2vsIsdsaWT6PZQU,4645
24
+ tskit-1.0.0b2.dist-info/WHEEL,sha256=KUuBC6lxAbHCKilKua8R9W_TM71_-9Sg5uEP3uDWcoU,101
25
+ tskit-1.0.0b2.dist-info/entry_points.txt,sha256=3Zik1X8C9Io1WvmTRBao5yEG5Kwy_xhFdM-ABC9TkWQ,47
26
+ tskit-1.0.0b2.dist-info/top_level.txt,sha256=6GsXJYqSCR5Uhb4Js0BBzC0EFXE0FA5ywslsixSbwGM,13
27
+ tskit-1.0.0b2.dist-info/RECORD,,