bartz 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +4 -2
- bartz/{BART.py → _interface.py} +256 -132
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +269 -314
- bartz/grove.py +124 -68
- bartz/jaxext/__init__.py +101 -27
- bartz/jaxext/_autobatch.py +257 -51
- bartz/jaxext/scipy/__init__.py +1 -1
- bartz/jaxext/scipy/special.py +3 -4
- bartz/jaxext/scipy/stats.py +1 -1
- bartz/mcmcloop.py +399 -208
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +1 -1
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.7.0.dist-info → bartz-0.8.0.dist-info}/METADATA +17 -11
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.7.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/mcmcstep.py +0 -2616
- bartz-0.7.0.dist-info/RECORD +0 -17
bartz/grove.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/grove.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2024-
|
|
3
|
+
# Copyright (c) 2024-2026, The Bartz Contributors
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -28,10 +28,14 @@ import math
|
|
|
28
28
|
from functools import partial
|
|
29
29
|
from typing import Protocol
|
|
30
30
|
|
|
31
|
-
import
|
|
32
|
-
from jax import jit, lax
|
|
31
|
+
from jax import jit, lax, vmap
|
|
33
32
|
from jax import numpy as jnp
|
|
34
|
-
from jaxtyping import Array, Bool, DTypeLike, Float32, Int32,
|
|
33
|
+
from jaxtyping import Array, Bool, DTypeLike, Float32, Int32, Shaped, UInt
|
|
34
|
+
|
|
35
|
+
try:
|
|
36
|
+
from numpy.lib.array_utils import normalize_axis_tuple # numpy 2
|
|
37
|
+
except ImportError:
|
|
38
|
+
from numpy.core.numeric import normalize_axis_tuple # numpy 1
|
|
35
39
|
|
|
36
40
|
from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc
|
|
37
41
|
|
|
@@ -44,32 +48,33 @@ class TreeHeaps(Protocol):
|
|
|
44
48
|
(left child) and :math:`2i + 1` (right child). The array element at index 0
|
|
45
49
|
is unused.
|
|
46
50
|
|
|
47
|
-
Parameters
|
|
48
|
-
----------
|
|
49
|
-
leaf_tree
|
|
50
|
-
The values in the leaves of the trees. This array can be dirty, i.e.,
|
|
51
|
-
unused nodes can have whatever value.
|
|
52
|
-
var_tree
|
|
53
|
-
The axes along which the decision nodes operate. This array can be
|
|
54
|
-
dirty but for the always unused node at index 0 which must be set to 0.
|
|
55
|
-
split_tree
|
|
56
|
-
The decision boundaries of the trees. The boundaries are open on the
|
|
57
|
-
right, i.e., a point belongs to the left child iff x < split. Whether a
|
|
58
|
-
node is a leaf is indicated by the corresponding 'split' element being
|
|
59
|
-
0. Unused nodes also have split set to 0. This array can't be dirty.
|
|
60
|
-
|
|
61
|
-
Notes
|
|
62
|
-
-----
|
|
63
51
|
Since the nodes at the bottom can only be leaves and not decision nodes,
|
|
64
52
|
`var_tree` and `split_tree` are half as long as `leaf_tree`.
|
|
53
|
+
|
|
54
|
+
Arrays may have additional initial axes to represent multiple trees.
|
|
65
55
|
"""
|
|
66
56
|
|
|
67
|
-
leaf_tree:
|
|
68
|
-
|
|
69
|
-
|
|
57
|
+
leaf_tree: (
|
|
58
|
+
Float32[Array, '*batch_shape 2**d'] | Float32[Array, '*batch_shape k 2**d']
|
|
59
|
+
)
|
|
60
|
+
"""The values in the leaves of the trees. This array can be dirty, i.e.,
|
|
61
|
+
unused nodes can have whatever value. It may have an additional axis
|
|
62
|
+
for multivariate leaves."""
|
|
63
|
+
|
|
64
|
+
var_tree: UInt[Array, '*batch_shape 2**(d-1)']
|
|
65
|
+
"""The axes along which the decision nodes operate. This array can be
|
|
66
|
+
dirty but for the always unused node at index 0 which must be set to 0."""
|
|
67
|
+
|
|
68
|
+
split_tree: UInt[Array, '*batch_shape 2**(d-1)']
|
|
69
|
+
"""The decision boundaries of the trees. The boundaries are open on the
|
|
70
|
+
right, i.e., a point belongs to the left child iff x < split. Whether a
|
|
71
|
+
node is a leaf is indicated by the corresponding 'split' element being
|
|
72
|
+
0. Unused nodes also have split set to 0. This array can't be dirty."""
|
|
70
73
|
|
|
71
74
|
|
|
72
|
-
def make_tree(
|
|
75
|
+
def make_tree(
|
|
76
|
+
depth: int, dtype: DTypeLike, batch_shape: tuple[int, ...] = ()
|
|
77
|
+
) -> Shaped[Array, '*batch_shape 2**{depth}']:
|
|
73
78
|
"""
|
|
74
79
|
Make an array to represent a binary tree.
|
|
75
80
|
|
|
@@ -80,15 +85,19 @@ def make_tree(depth: int, dtype: DTypeLike) -> Shaped[Array, ' 2**{depth}']:
|
|
|
80
85
|
node.
|
|
81
86
|
dtype
|
|
82
87
|
The dtype of the array.
|
|
88
|
+
batch_shape
|
|
89
|
+
The leading shape of the array, to represent multiple trees and/or
|
|
90
|
+
multivariate trees.
|
|
83
91
|
|
|
84
92
|
Returns
|
|
85
93
|
-------
|
|
86
94
|
An array of zeroes with the appropriate shape.
|
|
87
95
|
"""
|
|
88
|
-
|
|
96
|
+
shape = (*batch_shape, 2**depth)
|
|
97
|
+
return jnp.zeros(shape, dtype)
|
|
89
98
|
|
|
90
99
|
|
|
91
|
-
def tree_depth(tree: Shaped[Array, '* 2**d']) -> int:
|
|
100
|
+
def tree_depth(tree: Shaped[Array, '*batch_shape 2**d']) -> int:
|
|
92
101
|
"""
|
|
93
102
|
Return the maximum depth of a tree.
|
|
94
103
|
|
|
@@ -106,10 +115,10 @@ def tree_depth(tree: Shaped[Array, '* 2**d']) -> int:
|
|
|
106
115
|
|
|
107
116
|
|
|
108
117
|
def traverse_tree(
|
|
109
|
-
x:
|
|
118
|
+
x: UInt[Array, ' p'],
|
|
110
119
|
var_tree: UInt[Array, ' 2**(d-1)'],
|
|
111
120
|
split_tree: UInt[Array, ' 2**(d-1)'],
|
|
112
|
-
) ->
|
|
121
|
+
) -> UInt[Array, '']:
|
|
113
122
|
"""
|
|
114
123
|
Find the leaf where a point falls into.
|
|
115
124
|
|
|
@@ -148,15 +157,16 @@ def traverse_tree(
|
|
|
148
157
|
return index
|
|
149
158
|
|
|
150
159
|
|
|
151
|
-
@
|
|
160
|
+
@jit
|
|
161
|
+
@partial(jnp.vectorize, excluded=(0,), signature='(hts),(hts)->(n)')
|
|
152
162
|
@partial(vmap_nodoc, in_axes=(1, None, None))
|
|
153
163
|
def traverse_forest(
|
|
154
|
-
X:
|
|
155
|
-
var_trees: UInt[Array, '
|
|
156
|
-
split_trees: UInt[Array, '
|
|
157
|
-
) ->
|
|
164
|
+
X: UInt[Array, 'p n'],
|
|
165
|
+
var_trees: UInt[Array, '*forest_shape 2**(d-1)'],
|
|
166
|
+
split_trees: UInt[Array, '*forest_shape 2**(d-1)'],
|
|
167
|
+
) -> UInt[Array, '*forest_shape n']:
|
|
158
168
|
"""
|
|
159
|
-
Find the leaves where points
|
|
169
|
+
Find the leaves where points falls into for each tree in a set.
|
|
160
170
|
|
|
161
171
|
Parameters
|
|
162
172
|
----------
|
|
@@ -174,35 +184,59 @@ def traverse_forest(
|
|
|
174
184
|
return traverse_tree(X, var_trees, split_trees)
|
|
175
185
|
|
|
176
186
|
|
|
187
|
+
@partial(jit, static_argnames=('sum_batch_axis',))
|
|
177
188
|
def evaluate_forest(
|
|
178
|
-
X: UInt[Array, 'p n'],
|
|
179
|
-
|
|
189
|
+
X: UInt[Array, 'p n'],
|
|
190
|
+
trees: TreeHeaps,
|
|
191
|
+
*,
|
|
192
|
+
sum_batch_axis: int | tuple[int, ...] = (),
|
|
193
|
+
) -> (
|
|
194
|
+
Float32[Array, '*reduced_batch_size n'] | Float32[Array, '*reduced_batch_size k n']
|
|
195
|
+
):
|
|
180
196
|
"""
|
|
181
|
-
Evaluate
|
|
197
|
+
Evaluate an ensemble of trees at an array of points.
|
|
182
198
|
|
|
183
199
|
Parameters
|
|
184
200
|
----------
|
|
185
201
|
X
|
|
186
202
|
The coordinates to evaluate the trees at.
|
|
187
203
|
trees
|
|
188
|
-
The
|
|
189
|
-
|
|
190
|
-
|
|
204
|
+
The trees.
|
|
205
|
+
sum_batch_axis
|
|
206
|
+
The batch axes to sum over. By default, no summation is performed.
|
|
207
|
+
Note that negative indices count from the end of the batch dimensions,
|
|
208
|
+
the core dimensions n and k can't be summed over by this function.
|
|
191
209
|
|
|
192
210
|
Returns
|
|
193
211
|
-------
|
|
194
212
|
The (sum of) the values of the trees at the points in `X`.
|
|
195
213
|
"""
|
|
214
|
+
indices: UInt[Array, '*forest_shape n']
|
|
196
215
|
indices = traverse_forest(X, trees.var_tree, trees.split_tree)
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
216
|
+
|
|
217
|
+
is_mv = trees.leaf_tree.ndim != trees.var_tree.ndim
|
|
218
|
+
|
|
219
|
+
bc_indices: UInt[Array, '*forest_shape n 1'] | UInt[Array, '*forest_shape 1 n 1']
|
|
220
|
+
bc_indices = indices[..., None, :, None] if is_mv else indices[..., None]
|
|
221
|
+
|
|
222
|
+
bc_leaf_tree: (
|
|
223
|
+
Float32[Array, '*forest_shape 1 tree_size']
|
|
224
|
+
| Float32[Array, '*forest_shape k 1 tree_size']
|
|
225
|
+
)
|
|
226
|
+
bc_leaf_tree = (
|
|
227
|
+
trees.leaf_tree[..., :, None, :] if is_mv else trees.leaf_tree[..., None, :]
|
|
228
|
+
)
|
|
229
|
+
|
|
230
|
+
bc_leaves: (
|
|
231
|
+
Float32[Array, '*forest_shape n 1'] | Float32[Array, '*forest_shape k n 1']
|
|
232
|
+
)
|
|
233
|
+
bc_leaves = jnp.take_along_axis(bc_leaf_tree, bc_indices, -1)
|
|
234
|
+
|
|
235
|
+
leaves: Float32[Array, '*forest_shape n'] | Float32[Array, '*forest_shape k n']
|
|
236
|
+
leaves = jnp.squeeze(bc_leaves, -1)
|
|
237
|
+
|
|
238
|
+
axis = normalize_axis_tuple(sum_batch_axis, trees.var_tree.ndim - 1)
|
|
239
|
+
return jnp.sum(leaves, axis=axis)
|
|
206
240
|
|
|
207
241
|
|
|
208
242
|
def is_actual_leaf(
|
|
@@ -259,13 +293,13 @@ def is_leaves_parent(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(
|
|
|
259
293
|
# the 0-th item has split == 0, so it's not counted
|
|
260
294
|
|
|
261
295
|
|
|
262
|
-
def tree_depths(
|
|
296
|
+
def tree_depths(tree_size: int) -> Int32[Array, ' {tree_size}']:
|
|
263
297
|
"""
|
|
264
298
|
Return the depth of each node in a binary tree.
|
|
265
299
|
|
|
266
300
|
Parameters
|
|
267
301
|
----------
|
|
268
|
-
|
|
302
|
+
tree_size
|
|
269
303
|
The length of the tree array, i.e., 2 ** d.
|
|
270
304
|
|
|
271
305
|
Returns
|
|
@@ -280,7 +314,7 @@ def tree_depths(tree_length: int) -> Int32[Array, ' {tree_length}']:
|
|
|
280
314
|
"""
|
|
281
315
|
depths = []
|
|
282
316
|
depth = 0
|
|
283
|
-
for i in range(
|
|
317
|
+
for i in range(tree_size):
|
|
284
318
|
if i == 2**depth:
|
|
285
319
|
depth += 1
|
|
286
320
|
depths.append(depth - 1)
|
|
@@ -288,7 +322,10 @@ def tree_depths(tree_length: int) -> Int32[Array, ' {tree_length}']:
|
|
|
288
322
|
return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
|
|
289
323
|
|
|
290
324
|
|
|
291
|
-
|
|
325
|
+
@partial(jnp.vectorize, signature='(half_tree_size)->(tree_size)')
|
|
326
|
+
def is_used(
|
|
327
|
+
split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
|
|
328
|
+
) -> Bool[Array, '*batch_shape 2**d']:
|
|
292
329
|
"""
|
|
293
330
|
Return a mask indicating the used nodes in a tree.
|
|
294
331
|
|
|
@@ -308,7 +345,7 @@ def is_used(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**d']:
|
|
|
308
345
|
|
|
309
346
|
|
|
310
347
|
@jit
|
|
311
|
-
def forest_fill(split_tree: UInt[Array, '
|
|
348
|
+
def forest_fill(split_tree: UInt[Array, '*batch_shape 2**(d-1)']) -> Float32[Array, '']:
|
|
312
349
|
"""
|
|
313
350
|
Return the fraction of used nodes in a set of trees.
|
|
314
351
|
|
|
@@ -321,36 +358,55 @@ def forest_fill(split_tree: UInt[Array, 'num_trees 2**(d-1)']) -> Float32[Array,
|
|
|
321
358
|
-------
|
|
322
359
|
Number of tree nodes over the maximum number that could be stored.
|
|
323
360
|
"""
|
|
324
|
-
|
|
325
|
-
used = jax.vmap(is_used)(split_tree)
|
|
361
|
+
used = is_used(split_tree)
|
|
326
362
|
count = jnp.count_nonzero(used)
|
|
327
|
-
|
|
363
|
+
batch_size = split_tree.size // split_tree.shape[-1]
|
|
364
|
+
return count / (used.size - batch_size)
|
|
328
365
|
|
|
329
366
|
|
|
367
|
+
@partial(jit, static_argnames=('p', 'sum_batch_axis'))
|
|
330
368
|
def var_histogram(
|
|
331
|
-
p: int,
|
|
332
|
-
|
|
369
|
+
p: int,
|
|
370
|
+
var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
|
|
371
|
+
split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
|
|
372
|
+
*,
|
|
373
|
+
sum_batch_axis: int | tuple[int, ...] = (),
|
|
374
|
+
) -> Int32[Array, '*reduced_batch_shape {p}']:
|
|
333
375
|
"""
|
|
334
376
|
Count how many times each variable appears in a tree.
|
|
335
377
|
|
|
336
378
|
Parameters
|
|
337
379
|
----------
|
|
338
380
|
p
|
|
339
|
-
The number of variables (the maximum value that can occur in
|
|
340
|
-
|
|
381
|
+
The number of variables (the maximum value that can occur in `var_tree`
|
|
382
|
+
is ``p - 1``).
|
|
341
383
|
var_tree
|
|
342
384
|
The decision axes of the tree.
|
|
343
385
|
split_tree
|
|
344
386
|
The decision boundaries of the tree.
|
|
387
|
+
sum_batch_axis
|
|
388
|
+
The batch axes to sum over. By default, no summation is performed. Note
|
|
389
|
+
that negative indices count from the end of the batch dimensions, the
|
|
390
|
+
core dimension p can't be summed over by this function.
|
|
345
391
|
|
|
346
392
|
Returns
|
|
347
393
|
-------
|
|
348
|
-
The histogram of the variables used in the tree.
|
|
349
|
-
|
|
350
|
-
Notes
|
|
351
|
-
-----
|
|
352
|
-
If there are leading axes in the tree arrays (i.e., multiple trees), the
|
|
353
|
-
returned counts are cumulative over trees.
|
|
394
|
+
The histogram(s) of the variables used in the tree.
|
|
354
395
|
"""
|
|
355
396
|
is_internal = split_tree.astype(bool)
|
|
356
|
-
|
|
397
|
+
|
|
398
|
+
def scatter_add(
|
|
399
|
+
var_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
|
|
400
|
+
is_internal: Bool[Array, '*summed_batch_axes half_tree_size'],
|
|
401
|
+
) -> Int32[Array, ' p']:
|
|
402
|
+
return jnp.zeros(p, int).at[var_tree].add(is_internal)
|
|
403
|
+
|
|
404
|
+
# vmap scatter_add over non-batched dims
|
|
405
|
+
batch_ndim = var_tree.ndim - 1
|
|
406
|
+
axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)
|
|
407
|
+
for i in reversed(range(batch_ndim)):
|
|
408
|
+
neg_i = i - var_tree.ndim
|
|
409
|
+
if i not in axes:
|
|
410
|
+
scatter_add = vmap(scatter_add, in_axes=neg_i)
|
|
411
|
+
|
|
412
|
+
return scatter_add(var_tree, is_internal)
|
bartz/jaxext/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/jaxext/__init__.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2024-
|
|
3
|
+
# Copyright (c) 2024-2026, The Bartz Contributors
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -24,13 +24,23 @@
|
|
|
24
24
|
|
|
25
25
|
"""Additions to jax."""
|
|
26
26
|
|
|
27
|
-
import functools
|
|
28
27
|
import math
|
|
29
28
|
from collections.abc import Sequence
|
|
29
|
+
from contextlib import nullcontext
|
|
30
|
+
from functools import partial
|
|
30
31
|
|
|
31
32
|
import jax
|
|
33
|
+
from jax import (
|
|
34
|
+
Device,
|
|
35
|
+
debug_key_reuse,
|
|
36
|
+
device_count,
|
|
37
|
+
ensure_compile_time_eval,
|
|
38
|
+
jit,
|
|
39
|
+
random,
|
|
40
|
+
vmap,
|
|
41
|
+
)
|
|
32
42
|
from jax import numpy as jnp
|
|
33
|
-
from jax import
|
|
43
|
+
from jax.dtypes import prng_key
|
|
34
44
|
from jax.lax import scan
|
|
35
45
|
from jax.scipy.special import ndtr
|
|
36
46
|
from jaxtyping import Array, Bool, Float32, Key, Scalar, Shaped
|
|
@@ -63,7 +73,7 @@ def minimal_unsigned_dtype(value):
|
|
|
63
73
|
return jnp.uint64
|
|
64
74
|
|
|
65
75
|
|
|
66
|
-
@
|
|
76
|
+
@partial(jax.jit, static_argnums=(1,))
|
|
67
77
|
def unique(
|
|
68
78
|
x: Shaped[Array, ' _'], size: int, fill_value: Scalar
|
|
69
79
|
) -> tuple[Shaped[Array, ' {size}'], int]:
|
|
@@ -114,24 +124,42 @@ class split:
|
|
|
114
124
|
The key to split.
|
|
115
125
|
num
|
|
116
126
|
The number of keys to split into.
|
|
127
|
+
|
|
128
|
+
Notes
|
|
129
|
+
-----
|
|
130
|
+
Unlike `jax.random.split`, this class supports a vector of keys as input. In
|
|
131
|
+
this case, it behaves as if everything had been vmapped over, so `keys.pop`
|
|
132
|
+
has an additional initial output dimension equal to the number of input
|
|
133
|
+
keys, and the deterministic dependency respects this axis.
|
|
117
134
|
"""
|
|
118
135
|
|
|
119
|
-
|
|
120
|
-
|
|
136
|
+
_keys: tuple[Key[Array, '*batch'], ...]
|
|
137
|
+
_num_used: int
|
|
138
|
+
|
|
139
|
+
def __init__(self, key: Key[Array, '*batch'], num: int = 2):
|
|
140
|
+
if key.ndim:
|
|
141
|
+
context = debug_key_reuse(False)
|
|
142
|
+
else:
|
|
143
|
+
context = nullcontext()
|
|
144
|
+
with context:
|
|
145
|
+
# jitted-vmapped key split seems to be triggering a false positive
|
|
146
|
+
# with key reuse checks
|
|
147
|
+
self._keys = _split_unpack(key, num)
|
|
148
|
+
self._num_used = 0
|
|
121
149
|
|
|
122
150
|
def __len__(self):
|
|
123
|
-
return self._keys.
|
|
151
|
+
return len(self._keys) - self._num_used
|
|
124
152
|
|
|
125
|
-
def pop(self, shape: int | tuple[int, ...]
|
|
153
|
+
def pop(self, shape: int | tuple[int, ...] = ()) -> Key[Array, '*batch {shape}']:
|
|
126
154
|
"""
|
|
127
155
|
Pop one or more keys from the list.
|
|
128
156
|
|
|
129
157
|
Parameters
|
|
130
158
|
----------
|
|
131
159
|
shape
|
|
132
|
-
The shape of the keys to pop. If
|
|
133
|
-
|
|
134
|
-
reshaped to
|
|
160
|
+
The shape of the keys to pop. If empty (default), a single key is
|
|
161
|
+
popped and returned. If not empty, the popped key is split and
|
|
162
|
+
reshaped to the target shape.
|
|
135
163
|
|
|
136
164
|
Returns
|
|
137
165
|
-------
|
|
@@ -140,24 +168,41 @@ class split:
|
|
|
140
168
|
Raises
|
|
141
169
|
------
|
|
142
170
|
IndexError
|
|
143
|
-
If
|
|
144
|
-
|
|
145
|
-
Notes
|
|
146
|
-
-----
|
|
147
|
-
The keys are popped from the beginning of the list, so for example
|
|
148
|
-
``list(keys.pop(2))`` is equivalent to ``[keys.pop(), keys.pop()]``.
|
|
171
|
+
If the list is empty.
|
|
149
172
|
"""
|
|
150
|
-
if
|
|
151
|
-
|
|
152
|
-
elif not isinstance(shape, tuple):
|
|
153
|
-
shape = (shape,)
|
|
154
|
-
size_to_pop = math.prod(shape)
|
|
155
|
-
if size_to_pop > self._keys.size:
|
|
156
|
-
msg = f'Cannot pop {size_to_pop} keys from {self._keys.size} keys'
|
|
173
|
+
if len(self) == 0:
|
|
174
|
+
msg = 'No keys left to pop'
|
|
157
175
|
raise IndexError(msg)
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
176
|
+
if not isinstance(shape, tuple):
|
|
177
|
+
shape = (shape,)
|
|
178
|
+
key = self._keys[self._num_used]
|
|
179
|
+
self._num_used += 1
|
|
180
|
+
if shape:
|
|
181
|
+
key = _split_shaped(key, shape)
|
|
182
|
+
return key
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@partial(jit, static_argnums=(1,))
|
|
186
|
+
def _split_unpack(
|
|
187
|
+
key: Key[Array, '*batch'], num: int
|
|
188
|
+
) -> tuple[Key[Array, '*batch'], ...]:
|
|
189
|
+
if key.ndim == 0:
|
|
190
|
+
keys = random.split(key, num)
|
|
191
|
+
elif key.ndim == 1:
|
|
192
|
+
keys = vmap(random.split, in_axes=(0, None), out_axes=1)(key, num)
|
|
193
|
+
return tuple(keys)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
@partial(jit, static_argnums=(1,))
|
|
197
|
+
def _split_shaped(
|
|
198
|
+
key: Key[Array, '*batch'], shape: tuple[int, ...]
|
|
199
|
+
) -> Key[Array, '*batch {shape}']:
|
|
200
|
+
num = math.prod(shape)
|
|
201
|
+
if key.ndim == 0:
|
|
202
|
+
keys = random.split(key, num)
|
|
203
|
+
elif key.ndim == 1:
|
|
204
|
+
keys = vmap(random.split, in_axes=(0, None))(key, num)
|
|
205
|
+
return keys.reshape(*key.shape, *shape)
|
|
161
206
|
|
|
162
207
|
|
|
163
208
|
def truncated_normal_onesided(
|
|
@@ -165,6 +210,8 @@ def truncated_normal_onesided(
|
|
|
165
210
|
shape: Sequence[int],
|
|
166
211
|
upper: Bool[Array, '*'],
|
|
167
212
|
bound: Float32[Array, '*'],
|
|
213
|
+
*,
|
|
214
|
+
clip: bool = True,
|
|
168
215
|
) -> Float32[Array, '*']:
|
|
169
216
|
"""
|
|
170
217
|
Sample from a one-sided truncated standard normal distribution.
|
|
@@ -179,6 +226,9 @@ def truncated_normal_onesided(
|
|
|
179
226
|
True for (-∞, bound], False for [bound, ∞).
|
|
180
227
|
bound
|
|
181
228
|
The truncation boundary.
|
|
229
|
+
clip
|
|
230
|
+
Whether to clip the truncated uniform samples to (0, 1) before
|
|
231
|
+
transforming them to truncated normal. Intended for debugging purposes.
|
|
182
232
|
|
|
183
233
|
Returns
|
|
184
234
|
-------
|
|
@@ -209,5 +259,29 @@ def truncated_normal_onesided(
|
|
|
209
259
|
left_u = scale * (1 - u) # ~ uniform in (0, ndtr(±bound)]
|
|
210
260
|
right_u = shift + scale * u # ~ uniform in [ndtr(∓bound), 1)
|
|
211
261
|
truncated_u = jnp.where(upper ^ bound_pos, left_u, right_u)
|
|
262
|
+
if clip:
|
|
263
|
+
# on gpu the accuracy is lower and sometimes u can reach the boundaries
|
|
264
|
+
zero = jnp.zeros((), truncated_u.dtype)
|
|
265
|
+
one = jnp.ones((), truncated_u.dtype)
|
|
266
|
+
truncated_u = jnp.clip(
|
|
267
|
+
truncated_u, jnp.nextafter(zero, one), jnp.nextafter(one, zero)
|
|
268
|
+
)
|
|
212
269
|
truncated_norm = ndtri(truncated_u)
|
|
213
270
|
return jnp.where(bound_pos, -truncated_norm, truncated_norm)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def get_default_device() -> Device:
|
|
274
|
+
"""Get the current default JAX device."""
|
|
275
|
+
with ensure_compile_time_eval():
|
|
276
|
+
return jnp.zeros(()).device
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def get_device_count() -> int:
|
|
280
|
+
"""Get the number of available devices on the default platform."""
|
|
281
|
+
device = get_default_device()
|
|
282
|
+
return device_count(device.platform)
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def is_key(x: object) -> bool:
|
|
286
|
+
"""Determine if `x` is a jax random key."""
|
|
287
|
+
return isinstance(x, Array) and jnp.issubdtype(x.dtype, prng_key)
|